Merge branch 'pipecat-ai:main' into feature/genesys_serializer

This commit is contained in:
Sergio Sillero
2026-01-25 21:04:27 +01:00
committed by GitHub
91 changed files with 1764 additions and 337 deletions

View File

@@ -0,0 +1,40 @@
---
name: changelog
description: Create changelog files for important commits in a PR
---
Create changelog files for the important commits in this PR. The PR number is provided as an argument.
## Instructions
1. First, check what commits are on the current branch compared to main:
```
git log main..HEAD --oneline
```
2. For each significant change, create a changelog file in the `changelog/` folder using the format:
- `{PR_NUMBER}.added.md` - for new features
- `{PR_NUMBER}.added.2.md`, `{PR_NUMBER}.added.3.md` - for additional new features
- `{PR_NUMBER}.changed.md` - for changes to existing functionality
- `{PR_NUMBER}.fixed.md` - for bug fixes
- `{PR_NUMBER}.deprecated.md` - for deprecations
3. Each changelog file should at least contain a main single line starting with `- ` followed by a clear description of the change.
4. If the change is complicated, changelog files can have indented lines after the main line with additional details or code samples.
5. Use ⚠️ emoji prefix for breaking changes.
## Example
For PR #3519 with a new feature and a bug fix:
`changelog/3519.added.md`:
```
- Added `SomeNewFeature` for doing something useful.
```
`changelog/3519.fixed.md`:
```
- Fixed an issue where something was not working correctly.
```

View File

@@ -0,0 +1,257 @@
---
name: docstring
description: Document a Python module and its classes using Google style
---
Document a Python module and its classes using Google-style docstrings following project conventions. The class name is provided as an argument.
## Instructions
1. First, find the class in the codebase:
```
Search for "class ClassName" in src/pipecat/
```
2. If multiple files contain that class name:
- List all matches with their file paths
- Ask the user which one they want to document
- Wait for confirmation before proceeding
3. Once the file is identified, read the module to understand its structure:
- Identify all classes, functions, and important type aliases
- Understand the purpose of each component
4. Apply documentation in this order:
- Module docstring (at top, after imports)
- Class docstrings
- `__init__` methods (always document constructor parameters)
- Public methods (not starting with `_`)
- Dataclass/config classes with field descriptions
5. Skip documentation for:
- Private methods (starting with `_`)
- Simple dunder methods (`__str__`, `__repr__`, `__post_init__`)
- Very simple pass-through properties
- **Already documented code** - If a class, method, or function already has a complete docstring that follows the project style, do not modify it. A docstring is complete if it has:
- A one-line summary
- Args section (if it has parameters)
- Returns section (if it returns something meaningful)
- Only add or improve documentation where it is missing or incomplete
## Module Docstring Format
```python
"""[One-line description of module purpose].
[Optional: Longer explanation of functionality, key classes, or use cases.]
"""
```
Example:
```python
"""Neuphonic text-to-speech service implementations.
This module provides WebSocket and HTTP-based integrations with Neuphonic's
text-to-speech API for real-time audio synthesis.
"""
```
## Class Docstring Format
```python
class ClassName:
"""One-line summary describing what the class does.
[Longer description explaining purpose, behavior, and key features.
Use action-oriented language.]
[Optional: Event handlers, usage notes, or important caveats.]
"""
```
Example:
```python
class FrameProcessor(BaseObject):
"""Base class for all frame processors in the pipeline.
Frame processors are the building blocks of Pipecat pipelines, they can be
linked to form complex processing pipelines. They receive frames, process
them, and pass them to the next or previous processor in the chain.
Event handlers available:
- on_before_process_frame: Called before a frame is processed
- on_after_process_frame: Called after a frame is processed
Example::
@processor.event_handler("on_before_process_frame")
async def on_before_process_frame(processor, frame):
...
@processor.event_handler("on_after_process_frame")
async def on_after_process_frame(processor, frame):
...
"""
```
Note: When listing event handlers, do NOT use backticks. Include an `Example::` section (with double colon for Sphinx) showing the decorator pattern and function signature for each event.
## Constructor (`__init__`) Format
```python
def __init__(self, *, param1: Type, param2: Type = default, **kwargs):
"""Initialize the [ClassName].
Args:
param1: Description of param1 and its purpose.
param2: Description of param2. Defaults to [default].
**kwargs: Additional arguments passed to parent class.
"""
```
Example:
```python
def __init__(
self,
*,
api_key: str,
voice_id: Optional[str] = None,
sample_rate: Optional[int] = 22050,
**kwargs,
):
"""Initialize the Neuphonic TTS service.
Args:
api_key: Neuphonic API key for authentication.
voice_id: ID of the voice to use for synthesis.
sample_rate: Audio sample rate in Hz. Defaults to 22050.
**kwargs: Additional arguments passed to parent InterruptibleTTSService.
"""
```
## Method Docstring Format
```python
async def method_name(self, param1: Type) -> ReturnType:
"""One-line summary of what method does.
[Longer description if behavior isn't obvious.]
Args:
param1: Description of param1.
Returns:
Description of return value.
Raises:
ExceptionType: When this exception is raised.
"""
```
Example:
```python
async def put(self, item: Tuple[Frame, FrameDirection, FrameCallback]):
"""Put an item into the priority queue.
System frames (`SystemFrame`) have higher priority than any other
frames. If a non-frame item is provided it will have the highest priority.
Args:
item: The item to enqueue.
"""
```
## Dataclass/Config Format
```python
@dataclass
class ConfigName:
"""One-line description of configuration.
[Explanation of when/how to use this config.]
Parameters:
field1: Description of field1.
field2: Description of field2. Defaults to [default].
"""
field1: Type
field2: Type = default_value
```
Example:
```python
@dataclass
class FrameProcessorSetup:
"""Configuration parameters for frame processor initialization.
Parameters:
clock: The clock instance for timing operations.
task_manager: The task manager for handling async operations.
observer: Optional observer for monitoring frame processing events.
"""
clock: BaseClock
task_manager: BaseTaskManager
observer: Optional[BaseObserver] = None
```
## Enum Documentation Format
```python
class EnumName(Enum):
"""One-line description of the enum purpose.
[Longer description of how the enum is used.]
Parameters:
VALUE1: Description of VALUE1.
VALUE2: Description of VALUE2.
"""
VALUE1 = 1
VALUE2 = 2
```
## Writing Style Guidelines
- **Concise and professional** - No casual language or filler words
- **Action-oriented** - Start with verbs: "Processes...", "Manages...", "Converts..."
- **Purpose before implementation** - Explain WHY before HOW
- **Clear parameter descriptions** - Include type hints, defaults, and purpose
- **No redundant type info** - Type hints are in the signature, don't repeat in description
- **Use backticks for code references** - Wrap class names, method names, event names, parameter names, and code snippets in backticks
Good: "Neuphonic API key for authentication."
Bad: "str: The API key (string) that is used for authenticating with Neuphonic."
Good: "Triggers `on_speech_started` when the `VADAnalyzer` detects speech."
Bad: "Triggers on_speech_started when the VADAnalyzer detects speech."
## Deprecation Notice Format
When documenting deprecated code:
```python
"""[Description].
.. deprecated:: X.X.X
`ClassName` is deprecated and will be removed in a future version.
Use `NewClassName` instead.
"""
```
## Checklist
Before finishing, verify:
- [ ] Module has a docstring at the top (after copyright header and imports)
- [ ] All public classes have docstrings
- [ ] All `__init__` methods document their parameters
- [ ] All public methods have docstrings with Args/Returns/Raises as needed
- [ ] Dataclasses use "Parameters:" section for field descriptions
- [ ] Enums document each value in "Parameters:" section
- [ ] Writing is concise and action-oriented
- [ ] No documentation added to private methods (starting with `_`)
- [ ] Existing complete docstrings were left unchanged

View File

@@ -0,0 +1,128 @@
---
name: pr-description
description: Update a GitHub PR description with a summary of changes
---
Update a GitHub pull request description based on the changes in the PR.
## Arguments
```
/pr-description <PR_NUMBER> [--fixes <ISSUE_NUMBERS>]
```
- `PR_NUMBER` (required): The pull request number to update
- `--fixes` (optional): Comma-separated issue numbers that this PR fixes (e.g., `--fixes 123,456`)
Examples:
- `/pr-description 3534`
- `/pr-description 3534 --fixes 123`
- `/pr-description 3534 --fixes 123,456,789`
## Instructions
1. First, gather information about the PR:
- Use GitHub plugin to get PR details (title, current description, base branch)
- Use local git to get commits: `git log main..HEAD --oneline`
- Use local git to get the diff: `git diff main..HEAD`
- Parse any `--fixes` argument for issue numbers
2. Check the existing PR description:
- If it already has a complete, accurate description that reflects the changes, do nothing
- If it's missing sections, incomplete, or outdated compared to the actual changes, proceed to update
- If it only has the template placeholder text, generate a full description
3. Analyze the changes:
- Understand the purpose of each commit
- Identify any breaking changes (API changes, removed features, behavior changes)
- Look for new features, bug fixes, refactoring, or documentation changes
- Collect issue numbers from:
- The `--fixes` argument (if provided)
- Commit messages (patterns like "Fixes #123", "Closes #456", "Resolves #789")
4. Generate or update the PR description with these sections:
## PR Description Format
### Summary (always include)
Brief bullet points describing what changed and why. Focus on the *purpose* and *impact*, not implementation details.
```markdown
## Summary
- Added X to enable Y
- Fixed bug where Z would happen
- Refactored W for better maintainability
```
### Breaking Changes (include only if applicable)
Document any changes that affect existing users or APIs.
```markdown
## Breaking Changes
- `ClassName.method()` now requires a `param` argument
- Removed deprecated `old_function()` - use `new_function()` instead
```
### Testing (include when non-obvious)
How to verify the changes work. Skip for trivial changes.
```markdown
## Testing
- Run `uv run pytest tests/test_feature.py` to verify the fix
- Example usage: `uv run examples/new_feature.py`
```
### Fixes (include if issues are provided or found in commits)
List issues this PR fixes. GitHub will automatically close these issues when the PR is merged.
```markdown
## Fixes
- Fixes #123
- Fixes #456
```
Note: Use "Fixes #X" format (not "Closes" or "Resolves") for consistency. Each issue should be on its own line with "Fixes" to ensure GitHub auto-closes them.
## Guidelines
- **Be concise** - Reviewers should understand the PR in 30 seconds
- **Focus on why** - The diff shows *what* changed, explain *why*
- **Skip empty sections** - Only include sections that have content
- **Use bullet points** - Easier to scan than paragraphs
- **Don't duplicate the diff** - Avoid listing every file or line changed
## Example Output
```markdown
## Summary
- Added `/docstring` skill for documenting Python modules with Google-style docstrings
- Skill finds classes by name and handles conflicts when multiple matches exist
- Skips already-documented code to avoid unnecessary changes
## Testing
/docstring ClassName
## Fixes
- Fixes #123
```
## Checklist
Before updating the PR:
- [ ] Verified existing description needs updating (not already complete)
- [ ] Summary accurately reflects the changes
- [ ] Breaking changes are clearly documented (if any)
- [ ] No unnecessary sections included
- [ ] Description is concise and scannable

View File

@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies
run: |
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra livekit --extra websocket
- name: Run tests with coverage
run: |

View File

@@ -37,7 +37,7 @@ jobs:
- name: Install dependencies
run: |
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra livekit --extra websocket
- name: Test with pytest
run: |

11
.gitignore vendored
View File

@@ -4,7 +4,14 @@ __pycache__/
*~
venv
.venv
/.idea
.idea
.gradle
.next
next-env.d.ts
local.properties
*.log
*.lock
smart_turn_audio_log
#*#
# Distribution / Packaging
@@ -27,7 +34,7 @@ share/python-wheels/
*.egg
MANIFEST
.DS_Store
.env
.env*
fly.toml
# Examples

View File

@@ -7,6 +7,129 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
<!-- towncrier release notes start -->
## [0.0.100] - 2026-01-20
### Added
- Added Hathora service to support Hathora-hosted TTS and STT models (only
non-streaming)
(PR [#3169](https://github.com/pipecat-ai/pipecat/pull/3169))
- Added `CambTTSService`, using Camb.ai's TTS integration with MARS models
(mars-flash, mars-pro, mars-instruct) for high-quality text-to-speech
synthesis.
(PR [#3349](https://github.com/pipecat-ai/pipecat/pull/3349))
- Added the `additional_headers` param to `WebsocketClientParams`, allowing
`WebsocketClientTransport` to send custom headers on connect, for cases such
as authentication.
(PR [#3461](https://github.com/pipecat-ai/pipecat/pull/3461))
- Added `UserIdleController` for detecting user idle state, integrated into
`LLMUserAggregator` and `UserTurnProcessor` via optional `user_idle_timeout`
parameter. Emits `on_user_turn_idle` event for application-level handling.
Deprecated `UserIdleProcessor` in favor of the new compositional approach.
(PR [#3482](https://github.com/pipecat-ai/pipecat/pull/3482))
- Added `on_user_mute_started` and `on_user_mute_stopped` event handlers to
`LLMUserAggregator` for tracking user mute state changes.
(PR [#3490](https://github.com/pipecat-ai/pipecat/pull/3490))
### Changed
- Enhanced interruption handling in `AsyncAITTSService` by supporting
multi-context WebSocket sessions for more robust context management.
(PR [#3287](https://github.com/pipecat-ai/pipecat/pull/3287))
- Throttle `UserSpeakingFrame` to broadcast at most every 200ms instead of on
every audio chunk, reducing frame processing overhead during user speech.
(PR [#3483](https://github.com/pipecat-ai/pipecat/pull/3483))
### Deprecated
- For consistency with other package names, we just deprecated
`pipecat.turns.mute` (introduced in Pipecat 0.0.99) in favor of
`pipecat.turns.user_mute`.
(PR [#3479](https://github.com/pipecat-ai/pipecat/pull/3479))
### Fixed
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
(PR [#3287](https://github.com/pipecat-ai/pipecat/pull/3287))
- Fixed an issue where the "bot-llm-text" RTVI event would not fire for
realtime (speech-to-speech) services:
- `AWSNovaSonicLLMService`
- `GeminiLiveLLMService`
- `OpenAIRealtimeLLMService`
- `GrokRealtimeLLMService`
The issue was that these services weren't pushing `LLMTextFrame`s. Now
they do.
(PR [#3446](https://github.com/pipecat-ai/pipecat/pull/3446))
- Fixed an issue where `on_user_turn_stop_timeout` could fire while a user is
talking when using `ExternalUserTurnStrategies`.
(PR [#3454](https://github.com/pipecat-ai/pipecat/pull/3454))
- Fixed an issue where user turn start strategies were not being reset after a
user turn started, causing incorrect strategy behavior.
(PR [#3455](https://github.com/pipecat-ai/pipecat/pull/3455))
- Fixed `MinWordsUserTurnStartStrategy` to not aggregate transcriptions,
preventing incorrect turn starts when words are spoken with pauses between
them.
(PR [#3462](https://github.com/pipecat-ai/pipecat/pull/3462))
- Fixed an issue where Grok Realtime would error out when running with
SmallWebRTC transport.
(PR [#3480](https://github.com/pipecat-ai/pipecat/pull/3480))
- Fixed a `Mem0MemoryService` issue where passing `async_mode: true` was
causing an error. See
https://docs.mem0.ai/platform/features/async-mode-default-change.
(PR [#3484](https://github.com/pipecat-ai/pipecat/pull/3484))
- Fixed `AWSNovaSonicLLMService.reset_conversation()`, which would previously
error out. Now it successfully reconnects and "rehydrates" from the context
object.
(PR [#3486](https://github.com/pipecat-ai/pipecat/pull/3486))
- Fixed `AzureTTSService` transcript formatting issues:
- Punctuation now appears without extra spaces (e.g., "Hello!" instead of
"Hello !")
- CJK languages (Chinese, Japanese, Korean) no longer have unwanted spaces
between characters
(PR [#3489](https://github.com/pipecat-ai/pipecat/pull/3489))
- Fixed an issue where `UninterruptibleFrame` frames would not be preserved in
some cases.
(PR [#3494](https://github.com/pipecat-ai/pipecat/pull/3494))
- Fixed memory leak in `LiveKitTransport` when `video_in_enabled` is `False`.
(PR [#3499](https://github.com/pipecat-ai/pipecat/pull/3499))
- Fixed an issue in `AIService` where unhandled exceptions in `start()`,
`stop()`, or `cancel()` implementations would prevent `process_frame()` to
continue and therefore `StartFrame`, `EndFrame`, or `CancelFrame` from being
pushed downstream, causing the pipeline to not start or stop properly.
(PR [#3503](https://github.com/pipecat-ai/pipecat/pull/3503))
- Moved `NVIDIATTSService` and `NVIDIASTTService` client initialization from
constructor to `start()` for better error handling.
(PR [#3504](https://github.com/pipecat-ai/pipecat/pull/3504))
- Optimized `NVIDIATTSService` to process incoming audio frames immediately.
(PR [#3509](https://github.com/pipecat-ai/pipecat/pull/3509))
- Optimized `NVIDIASTTService` by removing unnecessary queue and task.
(PR [#3509](https://github.com/pipecat-ai/pipecat/pull/3509))
- Fixed a `CambTTSService` issue where client was being initialized in the
constructor which wouldn't allow for proper Pipeline error handling.
(PR [#3511](https://github.com/pipecat-ai/pipecat/pull/3511))
## [0.0.99] - 2026-01-13
### Added

View File

@@ -81,7 +81,7 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
| Serializers | [Exotel](https://docs.pipecat.ai/server/utilities/serializers/exotel), [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx), [Vonage](https://docs.pipecat.ai/server/utilities/serializers/vonage) |
| Video | [HeyGen](https://docs.pipecat.ai/server/services/video/heygen), [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/google-imagen), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [ai-coustics](https://docs.pipecat.ai/server/utilities/audio/aic-filter) |
| Analytics & Metrics | [OpenTelemetry](https://docs.pipecat.ai/server/utilities/opentelemetry), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |

View File

@@ -1 +0,0 @@
- Added Hathora service to support Hathora-hosted TTS and STT models (only non-streaming)

View File

@@ -1 +0,0 @@
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.

View File

@@ -1 +0,0 @@
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.

View File

@@ -1 +0,0 @@
- Added `CambTTSService`, using Camb.ai's TTS integration with MARS models (mars-flash, mars-pro, mars-instruct) for high-quality text-to-speech synthesis.

View File

@@ -1,8 +0,0 @@
- Fixed an issue where the "bot-llm-text" RTVI event would not fire for realtime (speech-to-speech) services:
- `AWSNovaSonicLLMService`
- `GeminiLiveLLMService`
- `OpenAIRealtimeLLMService`
- `GrokRealtimeLLMService`
The issue was that these services weren't pushing `LLMTextFrame`s. Now they do.

View File

@@ -1 +0,0 @@
- Fixed an issue where `on_user_turn_stop_timeout` could fire while a user is talking when using `ExternalUserTurnStrategies`.

View File

@@ -1 +0,0 @@
- Fixed an issue where user turn start strategies were not being reset after a user turn started, causing incorrect strategy behavior.

View File

@@ -1 +0,0 @@
- Added the `additional_headers` param to `WebsocketClientParams`, allowing `WebsocketClientTransport` to send custom headers on connect, for cases such as authentication.

View File

@@ -1 +0,0 @@
- Fixed `MinWordsUserTurnStartStrategy` to not aggregate transcriptions, preventing incorrect turn starts when words are spoken with pauses between them.

View File

@@ -1 +0,0 @@
- For consistency with other package names, we just deprecated `pipecat.turns.mute` (introduced in Pipecat 0.0.99) in favor of `pipecat.turns.user_mute`.

View File

@@ -1 +0,0 @@
- Fixed an issue where Grok Realtime would error out when running with SmallWebRTC transport.

View File

@@ -1 +0,0 @@
- Added `UserIdleController` for detecting user idle state, integrated into `LLMUserAggregator` and `UserTurnProcessor` via optional `user_idle_timeout` parameter. Emits `on_user_turn_idle` event for application-level handling. Deprecated `UserIdleProcessor` in favor of the new compositional approach.

View File

@@ -1 +0,0 @@
- Throttle `UserSpeakingFrame` to broadcast at most every 200ms instead of on every audio chunk, reducing frame processing overhead during user speech.

View File

@@ -1 +0,0 @@
- Fixed a `Mem0MemoryService` issue where passing `async_mode: true` was causing an error. See https://docs.mem0.ai/platform/features/async-mode-default-change.

View File

@@ -1,3 +0,0 @@
- Fixed `AzureTTSService` transcript formatting issues:
- Punctuation now appears without extra spaces (e.g., "Hello!" instead of "Hello !")
- CJK languages (Chinese, Japanese, Korean) no longer have unwanted spaces between characters

View File

@@ -1 +0,0 @@
- Added `on_user_mute_started` and `on_user_mute_stopped` event handlers to `LLMUserAggregator` for tracking user mute state changes.

View File

@@ -0,0 +1 @@
- `SarvamSTTService` now defaults `vad_signals` and `high_vad_sensitivity` to `None` (omitted from connection parameters), improving latency by ~300ms compared to the previous defaults.

View File

@@ -0,0 +1 @@
- Improved the STT TTFB (Time To First Byte) measurement, reporting the delay between when the user stops speaking and when the final transcription is received. Note: Unlike traditional TTFB which measures from a discrete request, STT services receive continuous audio input—so we measure from speech end to final transcript, which captures the latency that matters for voice AI applications. In support of this change, added `finalized` field to `TranscriptionFrame` to indicate when a transcript is the final result for an utterance.

View File

@@ -0,0 +1 @@
- Added `add_reached_upstream_filter()` and `add_reached_downstream_filter()` methods to `PipelineTask` for appending frame types.

1
changelog/3510.added.md Normal file
View File

@@ -0,0 +1 @@
- Added `reached_upstream_types` and `reached_downstream_types` read-only properties to `PipelineTask` for inspecting current frame filters.

View File

@@ -0,0 +1 @@
- Changed frame filter storage from tuples to sets in `PipelineTask`.

View File

@@ -0,0 +1 @@
- Added `RTVIProcessor.create_rtvi_observer()` factory method for creating RTVI observers.

View File

@@ -0,0 +1 @@
- Added `FrameProcessor.broadcast_frame_instance(frame)` method to broadcast a frame instance by extracting its fields and creating new instances for each direction.

1
changelog/3519.added.md Normal file
View File

@@ -0,0 +1 @@
- `PipelineTask` now automatically adds `RTVIProcessor` and registers `RTVIObserver` when `enable_rtvi=True` (default), simplifying pipeline setup.

View File

@@ -0,0 +1 @@
- Fixed `FrameProcessor.broadcast_frame()` to deep copy kwargs, preventing shared mutable references between the downstream and upstream frame instances.

1
changelog/3519.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Transports now properly broadcast `InputTransportMessageFrame` frames both upstream and downstream instead of only pushing downstream.

1
changelog/3520.added.md Normal file
View File

@@ -0,0 +1 @@
- Added `video_out_codec` parameter to `TransportParams` allowing configuration of the preferred video codec (e.g., `"VP8"`, `"H264"`, `"H265"`) for video output in `DailyTransport`.

1
changelog/3523.added.md Normal file
View File

@@ -0,0 +1 @@
- Added `location` parameter to Google TTS services (`GoogleHttpTTSService`, `GoogleTTSService`, `GeminiTTSService`) for regional endpoint support.

1
changelog/3525.added.md Normal file
View File

@@ -0,0 +1 @@
- Added new `PIPECAT_SMART_TURN_LOG_DATA` environment variable, which causes Smart Turn input data to be saved to disk

View File

@@ -0,0 +1,2 @@
- Changed default Inworld TTS model from `inworld-tts-1` to
`inworld-tts-1.5-max`.

View File

@@ -23,7 +23,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
@@ -93,12 +92,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
),
)
rtvi = RTVIProcessor()
pipeline = Pipeline(
[
transport.input(),
rtvi,
stt,
user_aggregator,
llm,
@@ -115,7 +111,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
enable_usage_metrics=True,
),
observers=[
RTVIObserver(rtvi),
DebugLogObserver(
frame_types={
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),

View File

@@ -22,7 +22,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
@@ -88,12 +87,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
),
)
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
pipeline = Pipeline(
[
transport.input(),
rtvi,
stt,
user_aggregator,
llm,
@@ -110,7 +106,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
enable_usage_metrics=True,
),
observers=[
RTVIObserver(rtvi),
DebugLogObserver(
frame_types={
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),

View File

@@ -22,7 +22,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
@@ -90,12 +89,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
),
)
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
pipeline = Pipeline(
[
transport.input(), # Transport user input
rtvi,
stt,
user_aggregator, # User responses
llm, # LLM
@@ -114,7 +110,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
observers=[
RTVIObserver(rtvi),
DebugLogObserver(
frame_types={
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
@@ -123,10 +118,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
],
)
@rtvi.event_handler("on_client_ready")
async def on_client_ready(rtvi):
await rtvi.set_bot_ready()
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")

View File

@@ -113,6 +113,14 @@ async def load_conversation(params: FunctionCallParams):
# "content": f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}",
# }
# )
# If the last message isn't from the user, add a message asking for a recap
if messages and messages[-1].get("role") != "user":
messages.append(
{
"role": "user",
"content": "Can you catch me up on what we were talking about?",
}
)
params.context.set_messages(messages)
await params.llm.reset_conversation()
# await params.llm.trigger_assistant_response()

View File

@@ -59,7 +59,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
@@ -255,12 +254,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
),
),
)
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
pipeline = Pipeline(
[
transport.input(),
rtvi,
stt,
user_aggregator,
memory,
@@ -278,12 +275,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
observers=[RTVIObserver(rtvi)],
)
@rtvi.event_handler("on_client_ready")
@task.rtvi.event_handler("on_client_ready")
async def on_client_ready(rtvi):
await rtvi.set_bot_ready()
# Get personalized greeting based on user memories. Can pass agent_id and run_id as per requirement of the application to manage short term memory or agent specific memory.
greeting = await get_initial_greeting(
memory_client=memory.memory_client, user_id=USER_ID, agent_id=None, run_id=None

View File

@@ -22,7 +22,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
@@ -87,8 +86,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
),
)
rtvi = RTVIProcessor()
pipeline = Pipeline(
[
transport.input(), # Transport user input
@@ -108,13 +105,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
enable_metrics=True,
enable_usage_metrics=True,
),
observers=[RTVIObserver(rtvi)],
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@rtvi.event_handler("on_client_ready")
@task.rtvi.event_handler("on_client_ready")
async def on_client_ready(rtvi):
await rtvi.set_bot_ready()
# Kick off the conversation
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])

View File

@@ -1,5 +1,5 @@
#
# Copyright (c) 2025, Daily
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
@@ -22,7 +22,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
@@ -125,14 +124,10 @@ async def run_bot(pipecat_transport):
),
)
# RTVI events for Pipecat client UI
rtvi = RTVIProcessor()
pipeline = Pipeline(
[
pipecat_transport.input(),
user_aggregator,
rtvi,
llm, # LLM
EdgeDetectionProcessor(
pipecat_transport._params.video_out_width,
@@ -149,13 +144,11 @@ async def run_bot(pipecat_transport):
enable_metrics=True,
enable_usage_metrics=True,
),
observers=[RTVIObserver(rtvi)],
)
@rtvi.event_handler("on_client_ready")
@task.rtvi.event_handler("on_client_ready")
async def on_client_ready(rtvi):
logger.info("Pipecat client ready.")
await rtvi.set_bot_ready()
# Kick off the conversation.
await task.queue_frames([LLMRunFrame()])

View File

@@ -293,12 +293,13 @@ async def run_eval_pipeline(
"You should only call the eval function if:\n"
"- The user explicitly attempts to answer the question, AND\n"
f"- Their answer can be cleanly evaluated using: {eval_config.eval}\n"
"Ignore greetings, comments, non-answers, or requests for clarification."
"Ignore greetings, comments, non-answers, or requests for clarification.\n"
"Numerical word answers are allowed (e.g., 'five' is the same as '5').\n"
)
if eval_config.eval_speaks_first:
system_prompt = f"You are an evaluation agent, be extremly brief. Numerical word answers are allowed. You will start the conversation by saying: '{example_prompt}'. {common_system_prompt}"
system_prompt = f"You are an evaluation agent, be extremly brief. You will start the conversation by saying: '{example_prompt}'. {common_system_prompt}"
else:
system_prompt = f"You are an evaluation agent, be extremly brief. Numerical word answers are allowed. First, ask one question: {example_prompt}. {common_system_prompt}"
system_prompt = f"You are an evaluation agent, be extremly brief. First, ask one question: {example_prompt}. {common_system_prompt}"
messages = [
{

View File

@@ -16,6 +16,7 @@ import numpy as np
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
from pipecat.utils.env import env_truthy
try:
import onnxruntime as ort
@@ -48,6 +49,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
"""
super().__init__(**kwargs)
self._log_data = env_truthy("PIPECAT_SMART_TURN_LOG_DATA", default=False)
if not smart_turn_model_path:
# Load bundled model
model_name = "smart-turn-v3.2-cpu.onnx"
@@ -81,6 +84,49 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
logger.debug("Loaded Local Smart Turn v3.x")
def _write_audio_to_wav(
self, audio_array: np.ndarray, sample_rate: int = 16000, suffix: str = ""
) -> None:
"""Write audio data to a WAV file in a background thread.
Args:
audio_array: The audio data as a numpy array (float32, normalized to [-1, 1]).
sample_rate: The sample rate of the audio data.
suffix: Optional suffix to append to the filename (e.g., "_raw", "_padded").
"""
import os
import threading
import wave
from datetime import datetime
# Generate filename with current timestamp (millisecond precision)
timestamp = datetime.now().strftime("%Y-%m-%d__%H:%M:%S.%f")[:-3]
log_dir = "./smart_turn_audio_log"
os.makedirs(log_dir, exist_ok=True)
filename = os.path.join(log_dir, f"{timestamp}{suffix}.wav")
# Make a copy of the audio data to avoid issues with the array being modified
audio_copy = audio_array.copy()
def write_wav():
try:
# Convert float32 audio to int16 for WAV file
audio_int16 = (audio_copy * 32767).astype(np.int16)
with wave.open(filename, "wb") as wav_file:
wav_file.setnchannels(1) # Mono
wav_file.setsampwidth(2) # 2 bytes for int16
wav_file.setframerate(sample_rate)
wav_file.writeframes(audio_int16.tobytes())
logger.debug(f"Wrote audio to {filename}")
except Exception as e:
logger.error(f"Failed to write audio to {filename}: {e}")
# Start background thread to write the WAV file
thread = threading.Thread(target=write_wav, daemon=True)
thread.start()
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
"""Predict end-of-turn using local ONNX model."""
@@ -95,6 +141,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
return audio_array
audio_for_logging = audio_array
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
@@ -122,6 +170,10 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
# Make prediction (1 for Complete, 0 for Incomplete)
prediction = 1 if probability > 0.5 else 0
if self._log_data:
suffix = "_complete" if prediction == 1 else "_incomplete"
self._write_audio_to_wav(audio_for_logging, sample_rate=16000, suffix=suffix)
return {
"prediction": prediction,
"probability": probability,

View File

@@ -426,12 +426,15 @@ class TranscriptionFrame(TextFrame):
timestamp: When the transcription occurred.
language: Detected or specified language of the speech.
result: Raw result from the STT service.
finalized: Whether this is the final transcription for an utterance.
Set by STT services that support commit/finalize signals.
"""
user_id: str
timestamp: str
language: Optional[Language] = None
result: Optional[Any] = None
finalized: bool = False
def __str__(self):
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"

View File

@@ -15,7 +15,7 @@ import asyncio
import importlib.util
import os
from pathlib import Path
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Tuple, Type
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Set, Tuple, Type
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field
@@ -49,6 +49,7 @@ from pipecat.pipeline.pipeline import Pipeline, PipelineSink, PipelineSource
from pipecat.pipeline.task_observer import TaskObserver
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.processors.frameworks.rtvi import RTVIObserverParams, RTVIProcessor
from pipecat.utils.asyncio.task_manager import BaseTaskManager, TaskManager, TaskManagerParams
from pipecat.utils.tracing.setup import is_tracing_available
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
@@ -225,9 +226,12 @@ class PipelineTask(BasePipelineTask):
conversation_id: Optional[str] = None,
enable_tracing: bool = False,
enable_turn_tracking: bool = True,
enable_rtvi: bool = True,
idle_timeout_frames: Tuple[Type[Frame], ...] = (BotSpeakingFrame, UserSpeakingFrame),
idle_timeout_secs: Optional[float] = IDLE_TIMEOUT_SECS,
observers: Optional[List[BaseObserver]] = None,
rtvi_processor: Optional[RTVIProcessor] = None,
rtvi_observer_params: Optional[RTVIObserverParams] = None,
task_manager: Optional[BaseTaskManager] = None,
):
"""Initialize the PipelineTask.
@@ -244,6 +248,7 @@ class PipelineTask(BasePipelineTask):
check_dangling_tasks: Whether to check for processors' tasks finishing properly.
clock: Clock implementation for timing operations.
conversation_id: Optional custom ID for the conversation.
enable_rtvi: Whether to automatically add RTVI support to the pipeline.
enable_tracing: Whether to enable tracing.
enable_turn_tracking: Whether to enable turn tracking.
idle_timeout_frames: A tuple with the frames that should trigger an idle
@@ -252,6 +257,8 @@ class PipelineTask(BasePipelineTask):
None. If a pipeline is idle the pipeline task will be cancelled
automatically.
observers: List of observers for monitoring pipeline execution.
rtvi_observer_params: The RTVI observer parameter to use if RTVI is enabled.
rtvi_processor: The RTVI processor to add if RTVI is enabled.
task_manager: Optional task manager for handling asyncio tasks.
"""
super().__init__()
@@ -306,6 +313,16 @@ class PipelineTask(BasePipelineTask):
self._heartbeat_push_task: Optional[asyncio.Task] = None
self._heartbeat_monitor_task: Optional[asyncio.Task] = None
# RTVI support
self._rtvi = None
if enable_rtvi:
self._rtvi = rtvi_processor or RTVIProcessor()
observers.append(self._rtvi.create_rtvi_observer(params=rtvi_observer_params))
@self.rtvi.event_handler("on_client_ready")
async def on_client_ready(rtvi: RTVIProcessor):
await rtvi.set_bot_ready()
# This is the idle event. When selected frames are pushed from any
# processor we consider the pipeline is not idle. We use an observer
# which will be listening any part of the pipeline.
@@ -335,7 +352,8 @@ class PipelineTask(BasePipelineTask):
# allows us to receive and react to downstream frames.
source = PipelineSource(self._source_push_frame, name=f"{self}::Source")
sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink")
self._pipeline = Pipeline([pipeline], source=source, sink=sink)
processors = [self._rtvi, pipeline] if self._rtvi else [pipeline]
self._pipeline = Pipeline(processors, source=source, sink=sink)
# The task observer acts as a proxy to the provided observers. This way,
# we only need to pass a single observer (using the StartFrame) which
@@ -348,8 +366,8 @@ class PipelineTask(BasePipelineTask):
# in. This is mainly for efficiency reason because each event handler
# creates a task and most likely you only care about one or two frame
# types.
self._reached_upstream_types: Tuple[Type[Frame], ...] = ()
self._reached_downstream_types: Tuple[Type[Frame], ...] = ()
self._reached_upstream_types: Set[Type[Frame]] = set()
self._reached_downstream_types: Set[Type[Frame]] = set()
self._register_event_handler("on_frame_reached_upstream")
self._register_event_handler("on_frame_reached_downstream")
self._register_event_handler("on_idle_timeout")
@@ -398,6 +416,35 @@ class PipelineTask(BasePipelineTask):
"""
return self._turn_trace_observer
@property
def rtvi(self) -> RTVIProcessor:
"""Get the RTVI processor if RTVI is enabled.
Returns:
The RTVI processor added to the pipeline when RTVI is enabled.
"""
if not self._rtvi:
raise Exception(f"{self} RTVI is not enabled.")
return self._rtvi
@property
def reached_upstream_types(self) -> Tuple[Type[Frame], ...]:
"""Get the currently configured upstream frame type filters.
Returns:
Tuple of frame types that trigger the on_frame_reached_upstream event.
"""
return tuple(self._reached_upstream_types)
@property
def reached_downstream_types(self) -> Tuple[Type[Frame], ...]:
"""Get the currently configured downstream frame type filters.
Returns:
Tuple of frame types that trigger the on_frame_reached_downstream event.
"""
return tuple(self._reached_downstream_types)
def event_handler(self, event_name: str):
"""Decorator for registering event handlers.
@@ -441,7 +488,7 @@ class PipelineTask(BasePipelineTask):
Args:
types: Tuple of frame types to monitor for upstream events.
"""
self._reached_upstream_types = types
self._reached_upstream_types = set(types)
def set_reached_downstream_filter(self, types: Tuple[Type[Frame], ...]):
"""Set which frame types trigger the on_frame_reached_downstream event.
@@ -449,7 +496,23 @@ class PipelineTask(BasePipelineTask):
Args:
types: Tuple of frame types to monitor for downstream events.
"""
self._reached_downstream_types = types
self._reached_downstream_types = set(types)
def add_reached_upstream_filter(self, types: Tuple[Type[Frame], ...]):
"""Add frame types to trigger the on_frame_reached_upstream event.
Args:
types: Tuple of frame types to add to upstream monitoring.
"""
self._reached_upstream_types.update(types)
def add_reached_downstream_filter(self, types: Tuple[Type[Frame], ...]):
"""Add frame types to trigger the on_frame_reached_downstream event.
Args:
types: Tuple of frame types to add to downstream monitoring.
"""
self._reached_downstream_types.update(types)
def has_finished(self) -> bool:
"""Check if the pipeline task has finished execution.
@@ -749,7 +812,7 @@ class PipelineTask(BasePipelineTask):
pipeline to be stopped (e.g. EndTaskFrame) in which case we would send
an EndFrame down the pipeline.
"""
if isinstance(frame, self._reached_upstream_types):
if isinstance(frame, tuple(self._reached_upstream_types)):
await self._call_event_handler("on_frame_reached_upstream", frame)
if isinstance(frame, EndTaskFrame):
@@ -788,7 +851,7 @@ class PipelineTask(BasePipelineTask):
processors have handled the EndFrame and therefore we can exit the task
cleanly.
"""
if isinstance(frame, self._reached_downstream_types):
if isinstance(frame, tuple(self._reached_downstream_types)):
await self._call_event_handler("on_frame_reached_downstream", frame)
if isinstance(frame, StartFrame):

View File

@@ -12,7 +12,9 @@ management, and frame flow control mechanisms.
"""
import asyncio
import dataclasses
import traceback
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import (
@@ -779,8 +781,40 @@ class FrameProcessor(BaseObject):
frame_cls: The class of the frame to be broadcasted.
**kwargs: Keyword arguments to be passed to the frame's constructor.
"""
await self.push_frame(frame_cls(**kwargs))
await self.push_frame(frame_cls(**kwargs), FrameDirection.UPSTREAM)
await self.push_frame(frame_cls(**deepcopy(kwargs)))
await self.push_frame(frame_cls(**deepcopy(kwargs)), FrameDirection.UPSTREAM)
async def broadcast_frame_instance(self, frame: Frame):
"""Broadcasts a frame instance upstream and downstream.
This method creates two new frame instances copying all fields from the
original frame except `id` and `name`, which get fresh values.
Args:
frame: The frame instance to broadcast.
Note:
Prefer using `broadcast_frame()` when possible, as it is more
efficient. This method should only be used when you are not the
creator of the frame and need to broadcast an existing instance.
"""
frame_cls = type(frame)
init_fields = {f.name: getattr(frame, f.name) for f in dataclasses.fields(frame) if f.init}
extra_fields = {
f.name: getattr(frame, f.name)
for f in dataclasses.fields(frame)
if not f.init and f.name not in ("id", "name")
}
new_frame = frame_cls(**deepcopy(init_fields))
for k, v in deepcopy(extra_fields).items():
setattr(new_frame, k, v)
await self.push_frame(new_frame)
new_frame = frame_cls(**deepcopy(init_fields))
for k, v in deepcopy(extra_fields).items():
setattr(new_frame, k, v)
await self.push_frame(new_frame, FrameDirection.UPSTREAM)
async def __start(self, frame: StartFrame):
"""Handle the start frame to initialize processor state.
@@ -950,7 +984,8 @@ class FrameProcessor(BaseObject):
# Process current queue and keep UninterruptibleFrame frames.
while not self.__process_queue.empty():
item = self.__process_queue.get_nowait()
if isinstance(item, UninterruptibleFrame):
frame = item[0]
if isinstance(frame, UninterruptibleFrame):
new_queue.put_nowait(item)
self.__process_queue.task_done()

View File

@@ -1100,13 +1100,11 @@ class RTVIObserver(BaseObserver):
if (
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
and (direction == FrameDirection.DOWNSTREAM)
and self._params.user_speaking_enabled
):
await self._handle_interruptions(frame)
elif (
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
and (direction == FrameDirection.UPSTREAM)
and self._params.bot_speaking_enabled
):
await self._handle_bot_speaking(frame)
@@ -1413,6 +1411,18 @@ class RTVIProcessor(FrameProcessor):
self._registered_services[service.name] = service
def create_rtvi_observer(self, *, params: Optional[RTVIObserverParams] = None, **kwargs):
"""Creates a new RTVI Observer.
Args:
params: Settings to enable/disable specific messages.
**kwargs: Additional arguments passed to the observer.
Returns:
A new RTVI observer.
"""
return RTVIObserver(self, params=params, **kwargs)
async def set_client_ready(self):
"""Mark the client as ready and trigger the ready event."""
self._client_ready = True

View File

@@ -126,7 +126,7 @@ class ProtobufFrameSerializer(FrameSerializer):
if "pts" in args_dict:
del args_dict["pts"]
# Special handling for MessageFrame -> OutputTransportMessageUrgentFrame
# Special handling for MessageFrame -> InputTransportMessageFrame
if class_name == MessageFrame:
try:
msg = json.loads(args_dict["data"])

View File

@@ -148,11 +148,11 @@ class AIService(FrameProcessor):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self.start(frame)
elif isinstance(frame, CancelFrame):
await self.cancel(frame)
await self._start(frame)
elif isinstance(frame, EndFrame):
await self.stop(frame)
await self._stop(frame)
elif isinstance(frame, CancelFrame):
await self._cancel(frame)
async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
"""Process frames from an async generator.
@@ -169,3 +169,21 @@ class AIService(FrameProcessor):
await self.push_error_frame(f)
else:
await self.push_frame(f)
async def _start(self, frame: StartFrame):
try:
await self.start(frame)
except Exception as e:
logger.error(f"{self}: exception processing {frame}: {e}")
async def _stop(self, frame: EndFrame):
try:
await self.stop(frame)
except Exception as e:
logger.error(f"{self}: exception processing {frame}: {e}")
async def _cancel(self, frame: CancelFrame):
try:
await self.cancel(frame)
except Exception as e:
logger.error(f"{self}: exception processing {frame}: {e}")

View File

@@ -161,7 +161,7 @@ class AssemblyAISTTService(WebsocketSTTService):
"""
await super().process_frame(frame, direction)
if isinstance(frame, VADUserStartedSpeakingFrame):
await self.start_ttfb_metrics()
pass
elif isinstance(frame, VADUserStoppedSpeakingFrame):
if (
self._vad_force_turn_endpoint
@@ -354,7 +354,6 @@ class AssemblyAISTTService(WebsocketSTTService):
"""Handle transcription results."""
if not message.transcript:
return
await self.stop_ttfb_metrics()
if message.end_of_turn and (
not self._connection_params.formatted_finals or message.turn_is_formatted
):

View File

@@ -296,6 +296,7 @@ class AWSNovaSonicLLMService(LLMService):
self._user_text_buffer = ""
self._assistant_text_buffer = ""
self._completed_tool_calls = set()
self._audio_input_started = False
file_path = files("pipecat.services.aws.nova_sonic").joinpath("ready.wav")
with wave.open(file_path.open("rb"), "rb") as wav_file:
@@ -532,14 +533,30 @@ class AWSNovaSonicLLMService(LLMService):
if system_instruction:
await self._send_text_event(text=system_instruction, role=Role.SYSTEM)
# Send conversation history
for message in llm_connection_params["messages"]:
# Send conversation history (except for the last message if it's from the
# user, which we'll send as interactive after starting audio input)
messages = llm_connection_params["messages"]
last_user_message = None
for i, message in enumerate(messages):
# logger.debug(f"Seeding conversation history with message: {message}")
await self._send_text_event(text=message.text, role=message.role)
is_last_message = i == len(messages) - 1
if is_last_message and message.role == Role.USER:
# Save for sending after audio input starts
last_user_message = message
else:
await self._send_text_event(text=message.text, role=message.role)
# Start audio input
await self._send_audio_input_start_event()
# Now send the last user message as interactive to trigger bot response
if last_user_message:
# logger.debug(
# f"Sending last user message as interactive to trigger bot response: {last_user_message}")
await self._send_text_event(
text=last_user_message.text, role=last_user_message.role, interactive=True
)
# Start receiving events
self._receive_task = self.create_task(self._receive_task_handler())
@@ -602,6 +619,7 @@ class AWSNovaSonicLLMService(LLMService):
self._user_text_buffer = ""
self._assistant_text_buffer = ""
self._completed_tool_calls = set()
self._audio_input_started = False
logger.info("Finished disconnecting")
except Exception as e:
@@ -727,8 +745,18 @@ class AWSNovaSonicLLMService(LLMService):
}}
'''
await self._send_client_event(audio_content_start)
self._audio_input_started = True
async def _send_text_event(self, text: str, role: Role):
async def _send_text_event(self, text: str, role: Role, interactive: bool = False):
"""Send a text event to the LLM.
Args:
text: The text content to send.
role: The role associated with the text (e.g., USER, ASSISTANT, SYSTEM).
interactive: Whether the content is interactive. Defaults to False.
False: conversation history or system instruction, sent prior to interactive audio
True: text input sent during (or at the start of) interactive audio
"""
if not self._stream or not self._prompt_name or not text:
return
@@ -741,7 +769,7 @@ class AWSNovaSonicLLMService(LLMService):
"promptName": "{self._prompt_name}",
"contentName": "{content_name}",
"type": "TEXT",
"interactive": true,
"interactive": {json.dumps(interactive)},
"role": "{role.value}",
"textInputConfiguration": {{
"mediaType": "text/plain"
@@ -779,7 +807,7 @@ class AWSNovaSonicLLMService(LLMService):
await self._send_client_event(text_content_end)
async def _send_user_audio_event(self, audio: bytes):
if not self._stream:
if not self._stream or not self._audio_input_started:
return
blob = base64.b64encode(audio)

View File

@@ -158,7 +158,6 @@ class AWSTranscribeSTTService(WebsocketSTTService):
await self._websocket.send(event_message)
# Start metrics after first chunk sent
await self.start_processing_metrics()
await self.start_ttfb_metrics()
except Exception as e:
yield ErrorFrame(error=f"Error sending audio: {e}")
@@ -470,7 +469,6 @@ class AWSTranscribeSTTService(WebsocketSTTService):
is_final = not result.get("IsPartial", True)
if transcript:
await self.stop_ttfb_metrics()
if is_final:
await self.push_frame(
TranscriptionFrame(

View File

@@ -116,7 +116,6 @@ class AzureSTTService(STTService):
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
if self._audio_stream:
self._audio_stream.write(audio)
yield None
@@ -191,7 +190,6 @@ class AzureSTTService(STTService):
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
def _on_handle_recognized(self, event):

View File

@@ -199,9 +199,10 @@ class CambTTSService(TTSService):
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or CambTTSService.InputParams()
self._api_key = api_key
self._timeout = timeout
self._client = AsyncCambAI(api_key=api_key, timeout=timeout)
params = params or CambTTSService.InputParams()
# Warn if sample rate doesn't match model's supported rate
if sample_rate and sample_rate != MODEL_SAMPLE_RATES.get(model):
@@ -222,6 +223,8 @@ class CambTTSService(TTSService):
self.set_voice(str(voice_id))
self._voice_id = voice_id
self._client = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -249,6 +252,8 @@ class CambTTSService(TTSService):
"""
await super().start(frame)
self._client = AsyncCambAI(api_key=self._api_key, timeout=self._timeout)
# Use model-specific sample rate if not explicitly specified
if not self._init_sample_rate:
self._sample_rate = MODEL_SAMPLE_RATES.get(self.model_name, 22050)
@@ -289,6 +294,8 @@ class CambTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame()
assert self._client is not None, "Camb.ai TTS service not initialized"
# Buffer for aligning chunks to 2-byte boundaries (16-bit PCM)
audio_buffer = b""

View File

@@ -207,9 +207,8 @@ class CartesiaSTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def start_metrics(self):
async def _start_metrics(self):
"""Start performance metrics collection for transcription processing."""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -222,10 +221,13 @@ class CartesiaSTTService(WebsocketSTTService):
await super().process_frame(frame, direction)
if isinstance(frame, VADUserStartedSpeakingFrame):
await self.start_metrics()
# Reset finalize state for new utterance
self.set_finalize_pending(False)
await self._start_metrics()
elif isinstance(frame, VADUserStoppedSpeakingFrame):
# Send finalize command to flush the transcription session
if self._websocket and self._websocket.state is State.OPEN:
self.set_finalize_pending(True)
await self._websocket.send("finalize")
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
@@ -342,7 +344,6 @@ class CartesiaSTTService(WebsocketSTTService):
pass
if len(transcript) > 0:
await self.stop_ttfb_metrics()
if is_final:
await self.push_frame(
TranscriptionFrame(

View File

@@ -659,6 +659,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
average_confidence = self._calculate_average_confidence(data)
if not self._params.min_confidence or average_confidence > self._params.min_confidence:
# EndOfTurn means Flux has determined the turn is complete,
# so this TranscriptionFrame is always finalized
await self.push_frame(
TranscriptionFrame(
transcript,
@@ -666,6 +668,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
time_now_iso8601(),
self._language,
result=data,
finalized=True,
)
)
else:

View File

@@ -276,9 +276,8 @@ class DeepgramSTTService(STTService):
# GH issue: https://github.com/deepgram/deepgram-python-sdk/issues/570
await self._connection.finish()
async def start_metrics(self):
"""Start TTFB and processing metrics collection."""
await self.start_ttfb_metrics()
async def _start_metrics(self):
"""Start processing metrics collection for this utterance."""
await self.start_processing_metrics()
async def _on_error(self, *args, **kwargs):
@@ -292,7 +291,7 @@ class DeepgramSTTService(STTService):
await self._connect()
async def _on_speech_started(self, *args, **kwargs):
await self.start_metrics()
await self._start_metrics()
await self._call_event_handler("on_speech_started", *args, **kwargs)
await self.broadcast_frame(UserStartedSpeakingFrame)
if self._should_interrupt:
@@ -320,8 +319,12 @@ class DeepgramSTTService(STTService):
language = result.channel.alternatives[0].languages[0]
language = Language(language)
if len(transcript) > 0:
await self.stop_ttfb_metrics()
if is_final:
# Check if this response is from a finalize() call.
# Only mark as finalized when both we requested it AND Deepgram confirms it.
from_finalize = getattr(result, "from_finalize", False)
if from_finalize:
self.confirm_finalize()
await self.push_frame(
TranscriptionFrame(
transcript,
@@ -356,8 +359,10 @@ class DeepgramSTTService(STTService):
if isinstance(frame, VADUserStartedSpeakingFrame) and not self.vad_enabled:
# Start metrics if Deepgram VAD is disabled & pipeline VAD has detected speech
await self.start_metrics()
await self._start_metrics()
elif isinstance(frame, VADUserStoppedSpeakingFrame):
# https://developers.deepgram.com/docs/finalize
# Mark that we're awaiting a from_finalize response
self.request_finalize()
await self._connection.finalize()
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")

View File

@@ -363,9 +363,6 @@ class DeepgramSageMakerSTTService(STTService):
if not transcript.strip():
return
# Stop TTFB metrics on first transcript
await self.stop_ttfb_metrics()
is_final = parsed.get("is_final", False)
speech_final = parsed.get("speech_final", False)
@@ -417,9 +414,8 @@ class DeepgramSageMakerSTTService(STTService):
"""
pass
async def start_metrics(self):
"""Start TTFB and processing metrics collection."""
await self.start_ttfb_metrics()
async def _start_metrics(self):
"""Start processing metrics collection."""
await self.start_processing_metrics()
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -433,7 +429,7 @@ class DeepgramSageMakerSTTService(STTService):
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
if isinstance(frame, VADUserStartedSpeakingFrame):
await self.start_metrics()
await self._start_metrics()
elif isinstance(frame, VADUserStoppedSpeakingFrame):
# Send finalize message to Deepgram when user stops speaking
# This tells Deepgram to flush any remaining audio and return final results

View File

@@ -310,7 +310,6 @@ class ElevenLabsSTTService(SegmentedSTTService):
self, transcript: str, is_final: bool, language: Optional[str] = None
):
"""Handle a transcription result with tracing."""
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
@@ -328,7 +327,6 @@ class ElevenLabsSTTService(SegmentedSTTService):
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Upload audio and get transcription result directly
result = await self._transcribe_audio(audio)
@@ -539,9 +537,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
await super().cancel(frame)
await self._disconnect()
async def start_metrics(self):
async def _start_metrics(self):
"""Start performance metrics collection for transcription processing."""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -554,13 +551,17 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
await super().process_frame(frame, direction)
if isinstance(frame, VADUserStartedSpeakingFrame):
# Reset finalize state for new utterance
self.set_finalize_pending(False)
# Start metrics when user starts speaking
await self.start_metrics()
await self._start_metrics()
elif isinstance(frame, VADUserStoppedSpeakingFrame):
# Send commit when user stops speaking (manual commit mode)
if self._params.commit_strategy == CommitStrategy.MANUAL:
if self._websocket and self._websocket.state is State.OPEN:
try:
# Mark that the next committed transcript should be finalized
self.set_finalize_pending(True)
commit_message = {
"message_type": "input_audio_chunk",
"audio_base_64": "",
@@ -764,8 +765,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
if not text:
return
await self.stop_ttfb_metrics()
# Get language if provided
language = data.get("language_code")
@@ -803,7 +802,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
if not text:
return
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
# Get language if provided
@@ -845,7 +843,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
if not text:
return
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
# Get language if provided

View File

@@ -249,7 +249,6 @@ class FalSTTService(SegmentedSTTService):
self, transcript: str, is_final: bool, language: Optional[str] = None
):
"""Handle a transcription result with tracing."""
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
@@ -267,7 +266,6 @@ class FalSTTService(SegmentedSTTService):
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Send to Fal directly (audio is already in WAV format from base class)
data_uri = fal_client.encode(audio, "audio/x-wav")

View File

@@ -385,7 +385,6 @@ class GladiaSTTService(WebsocketSTTService):
Yields:
None (processing is handled asynchronously via WebSocket).
"""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
# Add audio to buffer
@@ -513,7 +512,6 @@ class GladiaSTTService(WebsocketSTTService):
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[str] = None
):
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
async def _on_speech_started(self):

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Google RTVI integration models and observer implementation.
"""Google RTVI processor and observer implementation.
This module provides integration with Google's services through the RTVI framework,
including models for search responses and an observer for handling Google-specific
@@ -16,7 +16,7 @@ from typing import List, Literal, Optional
from pydantic import BaseModel
from pipecat.observers.base_observer import FramePushed
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIObserverParams, RTVIProcessor
from pipecat.services.google.frames import LLMSearchOrigin, LLMSearchResponseFrame
@@ -86,4 +86,23 @@ class GoogleRTVIObserver(RTVIObserver):
rendered_content=frame.rendered_content,
)
)
await self.push_transport_message_urgent(message)
await self.send_rtvi_message(message)
class GoogleRTVIProcessor(RTVIProcessor):
"""RTVI processor for Google service integration.
Creates a specific Google RTVI Observer.
"""
def create_rtvi_observer(self, *, params: Optional[RTVIObserverParams] = None, **kwargs):
"""Creates a new RTVI Observer.
Args:
params: Settings to enable/disable specific messages.
**kwargs: Additional arguments passed to the observer.
Returns:
A new RTVI observer.
"""
return GoogleRTVIObserver(self)

View File

@@ -823,7 +823,6 @@ class GoogleSTTService(STTService):
"""
if self._streaming_task:
# Queue the audio data
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._request_queue.put(audio)
yield None
@@ -875,7 +874,6 @@ class GoogleSTTService(STTService):
)
else:
self._last_transcript_was_final = False
await self.stop_ttfb_metrics()
await self.push_frame(
InterimTranscriptionFrame(
transcript,

View File

@@ -40,6 +40,7 @@ from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language, resolve_language
try:
from google.api_core.client_options import ClientOptions
from google.auth import default
from google.auth.exceptions import GoogleAuthError
from google.cloud import texttospeech_v1
@@ -515,6 +516,7 @@ class GoogleHttpTTSService(TTSService):
*,
credentials: Optional[str] = None,
credentials_path: Optional[str] = None,
location: Optional[str] = None,
voice_id: str = "en-US-Chirp3-HD-Charon",
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
@@ -525,6 +527,7 @@ class GoogleHttpTTSService(TTSService):
Args:
credentials: JSON string containing Google Cloud service account credentials.
credentials_path: Path to Google Cloud service account JSON file.
location: Google Cloud location for regional endpoint (e.g., "us-central1").
voice_id: Google TTS voice identifier (e.g., "en-US-Standard-A").
sample_rate: Audio sample rate in Hz. If None, uses default.
params: Voice customization parameters including pitch, rate, volume, etc.
@@ -534,6 +537,7 @@ class GoogleHttpTTSService(TTSService):
params = params or GoogleHttpTTSService.InputParams()
self._location = location
self._settings = {
"pitch": params.pitch,
"rate": params.rate,
@@ -586,7 +590,15 @@ class GoogleHttpTTSService(TTSService):
if not creds:
raise ValueError("No valid credentials provided.")
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
client_options = None
if self._location:
client_options = ClientOptions(
api_endpoint=f"{self._location}-texttospeech.googleapis.com"
)
return texttospeech_v1.TextToSpeechAsyncClient(
credentials=creds, client_options=client_options
)
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -783,7 +795,15 @@ class GoogleBaseTTSService(TTSService):
if not creds:
raise ValueError("No valid credentials provided.")
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
client_options = None
if self._location:
client_options = ClientOptions(
api_endpoint=f"{self._location}-texttospeech.googleapis.com"
)
return texttospeech_v1.TextToSpeechAsyncClient(
credentials=creds, client_options=client_options
)
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -903,6 +923,7 @@ class GoogleTTSService(GoogleBaseTTSService):
*,
credentials: Optional[str] = None,
credentials_path: Optional[str] = None,
location: Optional[str] = None,
voice_id: str = "en-US-Chirp3-HD-Charon",
voice_cloning_key: Optional[str] = None,
sample_rate: Optional[int] = None,
@@ -914,6 +935,7 @@ class GoogleTTSService(GoogleBaseTTSService):
Args:
credentials: JSON string containing Google Cloud service account credentials.
credentials_path: Path to Google Cloud service account JSON file.
location: Google Cloud location for regional endpoint (e.g., "us-central1").
voice_id: Google TTS voice identifier (e.g., "en-US-Chirp3-HD-Charon").
voice_cloning_key: The voice cloning key for Chirp 3 custom voices.
sample_rate: Audio sample rate in Hz. If None, uses default.
@@ -924,6 +946,7 @@ class GoogleTTSService(GoogleBaseTTSService):
params = params or GoogleTTSService.InputParams()
self._location = location
self._settings = {
"language": self.language_to_service_language(params.language)
if params.language
@@ -1083,6 +1106,7 @@ class GeminiTTSService(GoogleBaseTTSService):
model: str = "gemini-2.5-flash-tts",
credentials: Optional[str] = None,
credentials_path: Optional[str] = None,
location: Optional[str] = None,
voice_id: str = "Kore",
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
@@ -1101,6 +1125,7 @@ class GeminiTTSService(GoogleBaseTTSService):
"gemini-2.5-flash-tts" or "gemini-2.5-pro-tts".
credentials: JSON string containing Google Cloud service account credentials.
credentials_path: Path to Google Cloud service account JSON file.
location: Google Cloud location for regional endpoint (e.g., "us-central1").
voice_id: Voice name from the available Gemini voices.
sample_rate: Audio sample rate in Hz. If None, uses Google's default 24kHz.
params: TTS configuration parameters.
@@ -1127,6 +1152,7 @@ class GeminiTTSService(GoogleBaseTTSService):
if voice_id not in self.AVAILABLE_VOICES:
logger.warning(f"Voice '{voice_id}' not in known voices list. Using anyway.")
self._location = location
self._model = model
self._voice_id = voice_id
self._settings = {

View File

@@ -122,7 +122,6 @@ class GradiumSTTService(WebsocketSTTService):
None (processing handled via WebSocket messages).
"""
self._audio_buffer.extend(audio)
await self.start_ttfb_metrics()
await self.start_processing_metrics()
while len(self._audio_buffer) >= self._chunk_size_bytes:

View File

@@ -111,7 +111,6 @@ class HathoraSTTService(SegmentedSTTService):
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
url = f"{self._base_url}"
@@ -153,7 +152,6 @@ class HathoraSTTService(SegmentedSTTService):
result=response,
)
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
except Exception as e:

View File

@@ -72,7 +72,7 @@ class InworldHttpTTSService(WordTTSService):
api_key: str,
aiohttp_session: aiohttp.ClientSession,
voice_id: str = "Ashley",
model: str = "inworld-tts-1",
model: str = "inworld-tts-1.5-max",
streaming: bool = True,
sample_rate: Optional[int] = None,
encoding: str = "LINEAR16",
@@ -427,7 +427,7 @@ class InworldTTSService(AudioContextWordTTSService):
*,
api_key: str,
voice_id: str = "Ashley",
model: str = "inworld-tts-1",
model: str = "inworld-tts-1.5-max",
url: str = "wss://api.inworld.ai/tts/v1/voice:streamBidirectional",
sample_rate: Optional[int] = None,
encoding: str = "LINEAR16",

View File

@@ -134,6 +134,7 @@ class NvidiaSTTService(STTService):
params = params or NvidiaSTTService.InputParams()
self._server = server
self._api_key = api_key
self._use_ssl = use_ssl
self._profanity_filter = False
@@ -162,18 +163,53 @@ class NvidiaSTTService(STTService):
self.set_model_name(model_function_map.get("model_name"))
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
self._asr_service = riva.client.ASRService(auth)
self._asr_service = None
self._queue = None
self._config = None
self._thread_task = None
self._response_task = None
def _initialize_client(self):
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {self._api_key}"],
]
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
self._asr_service = riva.client.ASRService(auth)
def _create_recognition_config(self):
"""Create the NVIDIA Riva ASR recognition configuration."""
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=self._language_code,
model="",
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=not self._no_verbatim_transcripts,
sample_rate_hertz=self.sample_rate,
audio_channel_count=1,
),
interim_results=True,
)
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
return config
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -206,49 +242,15 @@ class NvidiaSTTService(STTService):
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
self._initialize_client()
self._config = self._create_recognition_config()
if self._config:
return
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=self._language_code,
model="",
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=not self._no_verbatim_transcripts,
sample_rate_hertz=self.sample_rate,
audio_channel_count=1,
),
interim_results=True,
)
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
self._config = config
self._queue = asyncio.Queue()
if not self._thread_task:
self._thread_task = self.create_task(self._thread_task_handler())
if not self._response_task:
self._response_queue = asyncio.Queue()
self._response_task = self.create_task(self._response_task_handler())
logger.debug(f"Initialized NvidiaSTTService with model: {self.model_name}")
async def stop(self, frame: EndFrame):
"""Stop the NVIDIA Riva STT service and clean up resources.
@@ -273,10 +275,6 @@ class NvidiaSTTService(STTService):
await self.cancel_task(self._thread_task)
self._thread_task = None
if self._response_task:
await self.cancel_task(self._response_task)
self._response_task = None
def _response_handler(self):
responses = self._asr_service.streaming_response_generator(
audio_chunks=self,
@@ -285,9 +283,7 @@ class NvidiaSTTService(STTService):
for response in responses:
if not response.results:
continue
asyncio.run_coroutine_threadsafe(
self._response_queue.put(response), self.get_event_loop()
)
asyncio.run_coroutine_threadsafe(self._handle_response(response), self.get_event_loop())
async def _thread_task_handler(self):
try:
@@ -311,7 +307,6 @@ class NvidiaSTTService(STTService):
transcript = result.alternatives[0].transcript
if transcript and len(transcript) > 0:
await self.stop_ttfb_metrics()
if result.is_final:
await self.stop_processing_metrics()
await self.push_frame(
@@ -339,12 +334,6 @@ class NvidiaSTTService(STTService):
)
)
async def _response_task_handler(self):
while True:
response = await self._response_queue.get()
await self._handle_response(response)
self._response_queue.task_done()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data for speech-to-text transcription.
@@ -354,7 +343,6 @@ class NvidiaSTTService(STTService):
Yields:
None - transcription results are pushed to the pipeline via frames.
"""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._queue.put(audio)
yield None
@@ -503,8 +491,6 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
self._asr_service = riva.client.ASRService(auth)
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
def _create_recognition_config(self):
"""Create the NVIDIA Riva ASR recognition configuration."""
# Create base configuration
@@ -572,6 +558,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
await super().start(frame)
self._initialize_client()
self._config = self._create_recognition_config()
logger.debug(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
async def set_language(self, language: Language):
"""Set the language for the STT service.
@@ -605,65 +592,51 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
Frame: TranscriptionFrame containing the transcribed text.
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Make sure the client is initialized
if self._asr_service is None:
self._initialize_client()
# Make sure the config is created
if self._config is None:
self._config = self._create_recognition_config()
# Type assertion to satisfy the IDE
assert self._asr_service is not None, "ASR service not initialized"
assert self._config is not None, "Recognition config not created"
await self.start_processing_metrics()
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
# Process the response - handle different possible return types
try:
# If it's a future-like object, get the result
if hasattr(raw_response, "result"):
response = raw_response.result()
else:
response = raw_response
# If it's a future-like object, get the result
if hasattr(raw_response, "result"):
response = raw_response.result()
else:
response = raw_response
# Process transcription results
transcription_found = False
# Process transcription results
transcription_found = False
# Now we can safely check results
# Type hint for the IDE
results = getattr(response, "results", [])
# Now we can safely check results
# Type hint for the IDE
results = getattr(response, "results", [])
for result in results:
alternatives = getattr(result, "alternatives", [])
if alternatives:
text = alternatives[0].transcript.strip()
if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
self._language_enum,
)
transcription_found = True
for result in results:
alternatives = getattr(result, "alternatives", [])
if alternatives:
text = alternatives[0].transcript.strip()
if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
self._language_enum,
)
transcription_found = True
await self._handle_transcription(text, True, self._language_enum)
if not transcription_found:
logger.debug("No transcription results found in NVIDIA Riva response")
except AttributeError as ae:
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
await self._handle_transcription(text, True, self._language_enum)
if not transcription_found:
logger.debug(f"{self}: No transcription results found in NVIDIA Riva response")
except AttributeError as ae:
logger.error(f"{self}: Unexpected response structure from NVIDIA Riva: {ae}")
yield ErrorFrame(f"{self}: Unexpected NVIDIA Riva response format: {str(ae)}")
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")

View File

@@ -12,7 +12,7 @@ gRPC API for high-quality speech synthesis.
import asyncio
import os
from typing import AsyncGenerator, Mapping, Optional
from typing import AsyncGenerator, AsyncIterable, Generator, Mapping, Optional
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -25,6 +25,7 @@ from pydantic import BaseModel
from pipecat.frames.frames import (
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -34,14 +35,12 @@ from pipecat.transcriptions.language import Language
try:
import riva.client
import riva.client.proto.riva_tts_pb2 as rtts
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
raise Exception(f"Missing module: {e}")
NVIDIA_TTS_TIMEOUT_SECS = 5
class NvidiaTTSService(TTSService):
"""NVIDIA Riva text-to-speech service.
@@ -93,6 +92,7 @@ class NvidiaTTSService(TTSService):
params = params or NvidiaTTSService.InputParams()
self._server = server
self._api_key = api_key
self._voice_id = voice_id
self._language_code = params.language
@@ -102,18 +102,8 @@ class NvidiaTTSService(TTSService):
self.set_model_name(model_function_map.get("model_name"))
self.set_voice(voice_id)
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)
# warm up the service
config_response = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
self._service = None
self._config = None
async def set_model(self, model: str):
"""Attempt to set the TTS model.
@@ -129,6 +119,39 @@ class NvidiaTTSService(TTSService):
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
def _initialize_client(self):
if self._service is not None:
return
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {self._api_key}"],
]
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)
def _create_synthesis_config(self):
if not self._service:
return
# warm up the service
config = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
return config
async def start(self, frame: StartFrame):
"""Start the Cartesia TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._initialize_client()
self._config = self._create_synthesis_config()
logger.debug(f"Initialized NvidiaTTSService with model: {self.model_name}")
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using NVIDIA Riva TTS.
@@ -140,39 +163,43 @@ class NvidiaTTSService(TTSService):
Frame: Audio frames containing the synthesized speech data.
"""
def read_audio_responses(queue: asyncio.Queue):
def add_response(r):
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
def read_audio_responses() -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
responses = self._service.synthesize_online(
text,
self._voice_id,
self._language_code,
sample_rate_hz=self.sample_rate,
zero_shot_audio_prompt_file=None,
zero_shot_quality=self._quality,
custom_dictionary={},
)
return responses
def async_next(it):
try:
responses = self._service.synthesize_online(
text,
self._voice_id,
self._language_code,
sample_rate_hz=self.sample_rate,
zero_shot_audio_prompt_file=None,
zero_shot_quality=self._quality,
custom_dictionary={},
)
for r in responses:
add_response(r)
add_response(None)
except Exception as e:
logger.error(f"{self} exception: {e}")
add_response(None)
return next(it)
except StopIteration:
return None
await self.start_ttfb_metrics()
yield TTSStartedFrame()
logger.debug(f"{self}: Generating TTS [{text}]")
async def async_iterator(iterator) -> AsyncIterable[rtts.SynthesizeSpeechResponse]:
while True:
item = await asyncio.to_thread(async_next, iterator)
if item is None:
return
yield item
try:
queue = asyncio.Queue()
await asyncio.to_thread(read_audio_responses, queue)
assert self._service is not None, "TTS service not initialized"
assert self._config is not None, "Synthesis configuration not created"
# Wait for the thread to start.
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
while resp:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
logger.debug(f"{self}: Generating TTS [{text}]")
responses = await asyncio.to_thread(read_audio_responses)
async for resp in async_iterator(responses):
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=resp.audio,
@@ -180,10 +207,12 @@ class NvidiaTTSService(TTSService):
num_channels=1,
)
yield frame
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()
except asyncio.TimeoutError:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")

View File

@@ -10,7 +10,7 @@ This module provides an OpenAI-compatible interface for interacting with OpenRou
extending the base OpenAI LLM service functionality.
"""
from typing import Optional
from typing import Any, Dict, Optional
from loguru import logger
@@ -61,3 +61,35 @@ class OpenRouterLLMService(OpenAILLMService):
"""
logger.debug(f"Creating OpenRouter client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)
def build_chat_completion_params(self, params_from_context: Dict[str, Any]) -> Dict[str, Any]:
"""Builds chat parameters, handling model-specific constraints.
Args:
params_from_context: Parameters from the LLM context.
Returns:
Transformed parameters ready for the API call.
"""
params = super().build_chat_completion_params(params_from_context)
model = getattr(self, "model_name", getattr(self, "model", "")).lower()
if "gemini" in model:
messages = params.get("messages", [])
if not messages:
return params
transformed_messages = []
system_message_seen = False
for msg in messages:
if msg.get("role") == "system":
if not system_message_seen:
transformed_messages.append(msg)
system_message_seen = True
else:
new_msg = msg.copy()
new_msg["role"] = "user"
transformed_messages.append(new_msg)
else:
transformed_messages.append(msg)
params["messages"] = transformed_messages
return params

View File

@@ -15,9 +15,15 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.sarvam._sdk import sdk_headers
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language, resolve_language
@@ -75,14 +81,14 @@ class SarvamSTTService(STTService):
language: Target language for transcription. Defaults to None (required for saarika models).
prompt: Optional prompt to guide translation style/context for STT-Translate models.
Only applicable to saaras (STT-Translate) models. Defaults to None.
vad_signals: Enable VAD signals in response. Defaults to True.
high_vad_sensitivity: Enable high VAD (Voice Activity Detection) sensitivity. Defaults to False.
vad_signals: Enable VAD signals in response. Defaults to None.
high_vad_sensitivity: Enable high VAD (Voice Activity Detection) sensitivity. Defaults to None.
"""
language: Optional[Language] = None
prompt: Optional[str] = None
vad_signals: bool = True
high_vad_sensitivity: bool = False
vad_signals: bool = None
high_vad_sensitivity: bool = None
def __init__(
self,
@@ -155,6 +161,7 @@ class SarvamSTTService(STTService):
self._websocket_context = None
self._socket_client = None
self._receive_task = None
logger.info(f"Sarvam STT initialized with SDK headers: {self._sdk_headers}")
def language_to_service_language(self, language: Language) -> str:
"""Convert pipecat Language enum to Sarvam's language code.
@@ -175,6 +182,24 @@ class SarvamSTTService(STTService):
"""
return True
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames.
Handles VAD frames for TTFB tracking when using Pipecat's VAD
instead of Sarvam's built-in VAD.
"""
await super().process_frame(frame, direction)
# Only handle VAD frames when not using Sarvam's VAD signals
if not self._vad_signals:
if isinstance(frame, VADUserStartedSpeakingFrame):
self.set_finalize_pending(False)
await self._start_metrics()
elif isinstance(frame, VADUserStoppedSpeakingFrame):
if self._socket_client:
self.set_finalize_pending(True)
await self._socket_client.flush()
async def set_language(self, language: Language):
"""Set the recognition language and reconnect.
@@ -411,16 +436,18 @@ class SarvamSTTService(STTService):
logger.debug(f"VAD Signal: {signal}, Occurred at: {timestamp}")
if signal == "START_SPEECH":
await self.start_metrics()
await self._start_metrics()
logger.debug("User started speaking")
await self._call_event_handler("on_speech_started")
await self.broadcast_frame(UserStartedSpeakingFrame)
await self.push_interruption_task_frame_and_wait()
elif signal == "END_SPEECH":
logger.debug("User stopped speaking")
await self._call_event_handler("on_speech_stopped")
await self.broadcast_frame(UserStoppedSpeakingFrame)
elif message.type == "data":
await self.stop_ttfb_metrics()
transcript = message.data.transcript
language_code = message.data.language_code
# Prefer language from message (auto-detected for translate models). Fallback to configured.
@@ -482,7 +509,6 @@ class SarvamSTTService(STTService):
}
return mapping.get(language_code, Language.HI_IN)
async def start_metrics(self):
"""Start TTFB and processing metrics collection."""
await self.start_ttfb_metrics()
async def _start_metrics(self):
"""Start processing metrics collection."""
await self.start_processing_metrics()

View File

@@ -21,7 +21,7 @@ from pipecat.frames.frames import (
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
UserStoppedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.stt_service import WebsocketSTTService
@@ -162,7 +162,7 @@ class SonioxSTTService(WebsocketSTTService):
sample_rate: Audio sample rate.
params: Additional configuration parameters, such as language hints, context and
speaker diarization.
vad_force_turn_endpoint: Listen to `UserStoppedSpeakingFrame` to send finalize message to Soniox. If disabled, Soniox will detect the end of the speech.
vad_force_turn_endpoint: Listen to `VADUserStoppedSpeakingFrame` to send finalize message to Soniox. If disabled, Soniox will detect the end of the speech.
**kwargs: Additional arguments passed to the STTService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
@@ -247,7 +247,7 @@ class SonioxSTTService(WebsocketSTTService):
"""
await super().process_frame(frame, direction)
if isinstance(frame, UserStoppedSpeakingFrame) and self._vad_force_turn_endpoint:
if isinstance(frame, VADUserStoppedSpeakingFrame) and self._vad_force_turn_endpoint:
# Send finalize message to Soniox so we get the final tokens asap.
if self._websocket and self._websocket.state is State.OPEN:
await self._websocket.send(FINALIZE_MESSAGE)
@@ -374,12 +374,15 @@ class SonioxSTTService(WebsocketSTTService):
async def send_endpoint_transcript():
if self._final_transcription_buffer:
text = "".join(map(lambda token: token["text"], self._final_transcription_buffer))
# Soniox only pushes TranscriptionFrame when an end token is received,
# so every TranscriptionFrame is inherently finalized
await self.push_frame(
TranscriptionFrame(
text=text,
user_id=self._user_id,
timestamp=time_now_iso8601(),
result=self._final_transcription_buffer,
finalized=True,
)
)
await self._handle_transcription(text, is_final=True)

View File

@@ -6,7 +6,9 @@
"""Base classes for Speech-to-Text services with continuous and segmented processing."""
import asyncio
import io
import time
import wave
from abc import abstractmethod
from typing import Any, AsyncGenerator, Dict, Mapping, Optional
@@ -17,12 +19,17 @@ from pipecat.frames.frames import (
AudioRawFrame,
ErrorFrame,
Frame,
InterruptionFrame,
MetricsFrame,
SpeechControlParamsFrame,
StartFrame,
STTMuteFrame,
STTUpdateSettingsFrame,
TranscriptionFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.metrics.metrics import TTFBMetricsData
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.services.websocket_service import WebsocketService
@@ -61,6 +68,8 @@ class STTService(AIService):
audio_passthrough=True,
# STT input sample rate
sample_rate: Optional[int] = None,
# STT TTFB timeout - time to wait after VAD stop before reporting TTFB
stt_ttfb_timeout: float = 2.0,
**kwargs,
):
"""Initialize the STT service.
@@ -70,6 +79,12 @@ class STTService(AIService):
Defaults to True.
sample_rate: The sample rate for audio input. If None, will be determined
from the start frame.
stt_ttfb_timeout: Time in seconds to wait after VAD stop before reporting
TTFB. This delay allows the final transcription to arrive. Defaults to 2.0.
Note: STT "TTFB" differs from traditional TTFB (which measures from a discrete
request to first response byte). Since STT receives continuous audio, we measure
from when the user stops speaking to when the final transcript arrives—capturing
the latency that matters for voice AI applications.
**kwargs: Additional arguments passed to the parent AIService.
"""
super().__init__(**kwargs)
@@ -81,6 +96,16 @@ class STTService(AIService):
self._muted: bool = False
self._user_id: str = ""
# STT TTFB tracking state
self._stt_ttfb_timeout = stt_ttfb_timeout
self._ttfb_timeout_task: Optional[asyncio.Task] = None
self._vad_stop_secs: Optional[float] = None
self._speech_end_time: Optional[float] = None
self._user_speaking: bool = False
self._last_transcription_time: Optional[float] = None
self._finalize_pending: bool = False
self._finalize_requested: bool = False
self._register_event_handler("on_connected")
self._register_event_handler("on_disconnected")
self._register_event_handler("on_connection_error")
@@ -94,6 +119,44 @@ class STTService(AIService):
"""
return self._muted
def set_finalize_pending(self, value: bool):
"""Set whether the next TranscriptionFrame should be marked as finalized.
When True, the next TranscriptionFrame pushed will have its `finalized`
field set to True, and this flag will automatically reset to False.
This is used to signal that a transcript is the final result for an
utterance, enabling immediate TTFB reporting.
Args:
value: True to mark the next transcription as finalized.
"""
self._finalize_pending = value
def request_finalize(self):
"""Mark that a finalize request has been sent, awaiting server confirmation.
For providers that require server confirmation before marking transcripts
as finalized (e.g., Deepgram's from_finalize field), call this when sending
the finalize request. Then call confirm_finalize() when the server confirms.
This is an alternative to set_finalize_pending() for providers that need
two-step finalization.
"""
self._finalize_requested = True
def confirm_finalize(self):
"""Confirm that the server has acknowledged the finalize request.
Call this when the server response confirms finalization (e.g., Deepgram's
from_finalize=True). The next TranscriptionFrame pushed will be marked
as finalized.
Only has effect if request_finalize() was previously called.
"""
if self._finalize_requested:
self._finalize_pending = True
self._finalize_requested = False
@property
def sample_rate(self) -> int:
"""Get the current sample rate for audio processing.
@@ -144,6 +207,11 @@ class STTService(AIService):
self._sample_rate = self._init_sample_rate or frame.audio_in_sample_rate
self._tracing_enabled = frame.enable_tracing
async def cleanup(self):
"""Clean up STT service resources."""
await super().cleanup()
await self._cancel_ttfb_timeout()
async def _update_settings(self, settings: Mapping[str, Any]):
logger.info(f"Updating STT settings: {self._settings}")
for key, value in settings.items():
@@ -152,6 +220,8 @@ class STTService(AIService):
self._settings[key] = value
if key == "language":
await self.set_language(value)
elif key == "language":
await self.set_language(value)
elif key == "model":
self.set_model_name(value)
else:
@@ -204,14 +274,166 @@ class STTService(AIService):
await self.process_audio_frame(frame, direction)
if self._audio_passthrough:
await self.push_frame(frame, direction)
elif isinstance(frame, SpeechControlParamsFrame):
await self._handle_speech_control_params(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, VADUserStartedSpeakingFrame):
await self._handle_vad_user_started_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, VADUserStoppedSpeakingFrame):
await self._handle_vad_user_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, STTUpdateSettingsFrame):
await self._update_settings(frame.settings)
elif isinstance(frame, STTMuteFrame):
self._muted = frame.mute
logger.debug(f"STT service {'muted' if frame.mute else 'unmuted'}")
elif isinstance(frame, InterruptionFrame):
await self._reset_stt_ttfb_state()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame downstream, tracking TranscriptionFrame timestamps for TTFB.
Stores the timestamp of each TranscriptionFrame for TTFB calculation.
If the frame is marked as finalized (either directly or via set_finalize_pending),
reports TTFB immediately and cancels any pending timeout. Otherwise, TTFB is
reported after a timeout.
Args:
frame: The frame to push.
direction: The direction to push the frame.
"""
if isinstance(frame, TranscriptionFrame):
# Store the transcription time for TTFB calculation
self._last_transcription_time = time.time()
# Set finalized from pending state and auto-reset
if self._finalize_pending:
frame.finalized = True
self._finalize_pending = False
# If this is a finalized transcription, report TTFB immediately
if frame.finalized and self._speech_end_time is not None:
ttfb = self._last_transcription_time - self._speech_end_time
await self._emit_stt_ttfb_metric(ttfb)
# Cancel the timeout since we've already reported
await self._cancel_ttfb_timeout()
# Clear state
self._speech_end_time = None
self._last_transcription_time = None
await super().push_frame(frame, direction)
async def _handle_speech_control_params(self, frame: SpeechControlParamsFrame):
"""Handle speech control parameters frame to extract VAD stop_secs.
Args:
frame: The speech control parameters frame.
"""
if frame.vad_params is not None:
self._vad_stop_secs = frame.vad_params.stop_secs
async def _cancel_ttfb_timeout(self):
"""Cancel any pending TTFB timeout task."""
if self._ttfb_timeout_task:
await self.cancel_task(self._ttfb_timeout_task)
self._ttfb_timeout_task = None
async def _reset_stt_ttfb_state(self):
"""Reset STT TTFB measurement state.
Called when starting a new utterance or on interruption to ensure
we don't use stale state for TTFB calculations. This specifically guards
against the case where a TranscriptionFrame is received without corresponding
VADUserStartedSpeakingFrame and VADUserStoppedSpeakingFrame frames.
Note: Does not reset _user_speaking since InterruptionFrame can arrive
while user is still speaking.
"""
await self._cancel_ttfb_timeout()
self._speech_end_time = None
self._last_transcription_time = None
async def _handle_vad_user_started_speaking(self, frame: VADUserStartedSpeakingFrame):
"""Handle VAD user started speaking frame to start tracking transcriptions.
Cancels any pending TTFB timeout, resets TTFB tracking state, and marks user as speaking.
Args:
frame: The VAD user started speaking frame.
"""
await self._reset_stt_ttfb_state()
self._user_speaking = True
self._finalize_requested = False
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
"""Handle VAD user stopped speaking frame.
Calculates the actual speech end time and starts a timeout task to wait
for the final transcription before reporting TTFB.
Args:
frame: The VAD user stopped speaking frame.
"""
self._user_speaking = False
# Skip TTFB measurement if we don't have VAD params
if self._vad_stop_secs is None:
return
# Calculate the actual speech end time (current time minus VAD stop delay).
# This approximates when the last user audio was sent to the STT service,
# which we use to measure against the eventual transcription response.
self._speech_end_time = time.time() - self._vad_stop_secs
# Start timeout task (any previous timeout was cancelled by VADUserStartedSpeakingFrame
# or InterruptionFrame)
self._ttfb_timeout_task = self.create_task(
self._ttfb_timeout_handler(), name="stt_ttfb_timeout"
)
async def _ttfb_timeout_handler(self):
"""Wait for timeout then report TTFB using the last transcription timestamp.
This timeout allows the final transcription to arrive before we calculate
and report TTFB. If no transcription arrived, no TTFB is reported.
"""
try:
await asyncio.sleep(self._stt_ttfb_timeout)
# Report TTFB if we have both speech end time and transcription time
if self._speech_end_time is not None and self._last_transcription_time is not None:
ttfb = self._last_transcription_time - self._speech_end_time
await self._emit_stt_ttfb_metric(ttfb)
# Clear state after reporting
self._speech_end_time = None
self._last_transcription_time = None
except asyncio.CancelledError:
# Task was cancelled (new utterance or interruption), which is expected behavior
pass
finally:
self._ttfb_timeout_task = None
async def _emit_stt_ttfb_metric(self, ttfb: float):
"""Emit STT TTFB metric if value is non-negative.
Args:
ttfb: The TTFB value in seconds.
"""
if ttfb >= 0:
logger.debug(f"{self} TTFB: {ttfb:.3f}s")
if self.metrics_enabled:
ttfb_data = TTFBMetricsData(
processor=self.name,
model=self.model_name,
value=ttfb,
)
await super().push_frame(MetricsFrame(data=[ttfb_data]))
class SegmentedSTTService(STTService):
"""STT service that processes speech in segments using VAD events.
@@ -248,6 +470,20 @@ class SegmentedSTTService(STTService):
await super().start(frame)
self._audio_buffer_size_1s = self.sample_rate * 2
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame, marking TranscriptionFrames as finalized.
Segmented STT services process complete speech segments and return a single
TranscriptionFrame per segment, so every transcription is inherently finalized.
Args:
frame: The frame to push.
direction: The direction of frame flow in the pipeline.
"""
if isinstance(frame, TranscriptionFrame):
frame.finalized = True
await super().push_frame(frame, direction)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames, handling VAD events and audio segmentation."""
await super().process_frame(frame, direction)

View File

@@ -204,11 +204,9 @@ class BaseWhisperSTTService(SegmentedSTTService):
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
response = await self._transcribe(audio)
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
text = response.text.strip()

View File

@@ -289,7 +289,6 @@ class WhisperSTTService(SegmentedSTTService):
return
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Divide by 32768 because we have signed 16-bit data.
audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
@@ -303,7 +302,6 @@ class WhisperSTTService(SegmentedSTTService):
if segment.no_speech_prob < self._no_speech_prob:
text += f"{segment.text} "
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
if text:
@@ -388,7 +386,6 @@ class WhisperSTTServiceMLX(WhisperSTTService):
import mlx_whisper
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Divide by 32768 because we have signed 16-bit data.
audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
@@ -413,7 +410,6 @@ class WhisperSTTServiceMLX(WhisperSTTService):
if len(text.strip()) == 0:
text = None
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
if text:

View File

@@ -123,9 +123,10 @@ class QueuedFrameProcessor(FrameProcessor):
async def run_test(
processor: FrameProcessor,
*,
frames_to_send: Sequence[Frame],
enable_rtvi: bool = False,
expected_down_frames: Optional[Sequence[type]] = None,
expected_up_frames: Optional[Sequence[type]] = None,
frames_to_send: Sequence[Frame],
ignore_start: bool = True,
observers: Optional[List[BaseObserver]] = None,
pipeline_params: Optional[PipelineParams] = None,
@@ -139,9 +140,10 @@ async def run_test(
Args:
processor: The frame processor to test.
frames_to_send: Sequence of frames to send through the processor.
enable_rtvi: Whether RTVI should be enabled in this test.
expected_down_frames: Expected frame types flowing downstream (optional).
expected_up_frames: Expected frame types flowing upstream (optional).
frames_to_send: Sequence of frames to send through the processor.
ignore_start: Whether to ignore StartFrames in frame validation.
observers: Optional list of observers to attach to the pipeline.
pipeline_params: Optional pipeline parameters.
@@ -173,9 +175,10 @@ async def run_test(
task = PipelineTask(
pipeline,
params=pipeline_params,
observers=observers,
cancel_on_idle_timeout=False,
enable_rtvi=enable_rtvi,
observers=observers,
params=pipeline_params,
)
async def push_frames():

View File

@@ -98,6 +98,7 @@ class TransportParams(BaseModel):
video_out_bitrate: Video output bitrate in bits per second.
video_out_framerate: Video output frame rate in FPS.
video_out_color_format: Video output color format string.
video_out_codec: Preferred video codec for output (e.g., 'VP8', 'H264', 'H265').
video_out_destinations: List of video output destination identifiers.
vad_enabled: Enable Voice Activity Detection (deprecated).
@@ -151,6 +152,7 @@ class TransportParams(BaseModel):
video_out_bitrate: int = 800000
video_out_framerate: int = 30
video_out_color_format: str = "RGB"
video_out_codec: Optional[str] = None
video_out_destinations: List[str] = Field(default_factory=list)
vad_enabled: bool = False
vad_audio_passthrough: bool = False

View File

@@ -759,7 +759,11 @@ class DailyTransportClient(EventHandler):
# Increment leave counter if we successfully joined.
self._leave_counter += 1
logger.info(f"Joined {self._room_url}")
participant_id = data.get("participants", {}).get("local", {}).get("id")
meeting_id = data.get("meetingSession", {}).get("id")
logger.info(
f"Joined {self._room_url}. Participant ID: {participant_id}, Meeting ID: {meeting_id}"
)
await self._callbacks.on_joined(data)
@@ -807,6 +811,11 @@ class DailyTransportClient(EventHandler):
"camera": {
"sendSettings": {
"maxQuality": "low",
**(
{"preferredCodec": self._params.video_out_codec}
if self._params.video_out_codec
else {}
),
"encodings": {
"low": {
"maxBitrate": self._params.video_out_bitrate,
@@ -1724,8 +1733,9 @@ class DailyInputTransport(BaseInputTransport):
message: The message data to send.
sender: ID of the message sender.
"""
frame = DailyInputTransportMessageFrame(message=message, participant_id=sender)
await self.push_frame(frame)
await self.broadcast_frame_class(
DailyInputTransportMessageFrame, message=message, participant_id=sender
)
#
# Audio in

View File

@@ -539,11 +539,14 @@ class LiveKitTransportClient:
elif track.kind == rtc.TrackKind.KIND_VIDEO:
logger.info(f"Video track subscribed: {track.sid} from participant {participant.sid}")
self._video_tracks[participant.sid] = track
video_stream = rtc.VideoStream(track)
self._task_manager.create_task(
self._process_video_stream(video_stream, participant.sid),
f"{self}::_process_video_stream",
)
# Only process video stream if video input is enabled to prevent
# unbounded queue growth when there is no consumer for video frames.
if self._params.video_in_enabled:
video_stream = rtc.VideoStream(track)
self._task_manager.create_task(
self._process_video_stream(video_stream, participant.sid),
f"{self}::_process_video_stream",
)
await self._callbacks.on_video_track_subscribed(participant.sid)
async def _async_on_track_unsubscribed(

View File

@@ -698,8 +698,7 @@ class SmallWebRTCInputTransport(BaseInputTransport):
message: The application message to process.
"""
logger.debug(f"Received app message inside SmallWebRTCInputTransport {message}")
frame = InputTransportMessageFrame(message=message)
await self.push_frame(frame)
await self.broadcast_frame_class(InputTransportMessageFrame, message=message)
# Add this method similar to DailyInputTransport.request_participant_image
async def request_participant_image(self, frame: UserImageRequestFrame):

View File

@@ -27,6 +27,7 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputAudioRawFrame,
InputTransportMessageFrame,
OutputAudioRawFrame,
OutputTransportMessageFrame,
OutputTransportMessageUrgentFrame,
@@ -298,6 +299,8 @@ class WebsocketClientInputTransport(BaseInputTransport):
return
if isinstance(frame, InputAudioRawFrame) and self._params.audio_in_enabled:
await self.push_audio_frame(frame)
elif isinstance(frame, InputTransportMessageFrame):
await self.broadcast_frame(frame)
else:
await self.push_frame(frame)

View File

@@ -26,6 +26,7 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputAudioRawFrame,
InputTransportMessageFrame,
InterruptionFrame,
OutputAudioRawFrame,
OutputTransportMessageFrame,
@@ -311,6 +312,8 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
if isinstance(frame, InputAudioRawFrame):
await self.push_audio_frame(frame)
elif isinstance(frame, InputTransportMessageFrame):
await self.broadcast_frame(frame)
else:
await self.push_frame(frame)
except Exception as e:

View File

@@ -25,6 +25,8 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputAudioRawFrame,
InputTransportMessageFrame,
InputTransportMessageUrgentFrame,
InterruptionFrame,
OutputAudioRawFrame,
OutputTransportMessageFrame,
@@ -214,6 +216,8 @@ class WebsocketServerInputTransport(BaseInputTransport):
if isinstance(frame, InputAudioRawFrame):
await self.push_audio_frame(frame)
elif isinstance(frame, InputTransportMessageFrame):
await self.broadcast_frame(frame)
else:
await self.push_frame(frame)
except Exception as e:

View File

@@ -51,6 +51,5 @@ class MuteUntilFirstBotCompleteUserMuteStrategy(BaseUserMuteStrategy):
return not self._first_speech_handled
async def _handle_bot_stopped_speaking(self, frame: BotStoppedSpeakingFrame):
self._bot_speaking = False
if not self._first_speech_handled:
self._first_speech_handled = True

53
src/pipecat/utils/env.py Normal file
View File

@@ -0,0 +1,53 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Environment variable helpers.
This module provides small, centralized parsing helpers for environment variables.
"""
from __future__ import annotations
import os
class InvalidEnvVarValueError(ValueError):
"""Raised when an environment variable value cannot be parsed."""
def __init__(self, name: str, value: str, expected: str):
"""Initialize an InvalidEnvVarValueError."""
super().__init__(f"Invalid value for env var {name!r}: {value!r}. Expected {expected}.")
self.name = name
self.value = value
self.expected = expected
def env_truthy(name: str, default: bool = False) -> bool:
"""Interpret an environment variable as a boolean.
- If the variable is **not set**, returns `default`.
- If the variable is set to a recognized boolean string, returns the parsed value.
- Otherwise, raises `InvalidEnvVarValueError`.
Recognized values (case-insensitive, whitespace ignored):
- Truthy: "1", "true", "yes", "y", "on"
- Falsy: "0", "false", "no", "n", "off", ""
"""
raw = os.getenv(name)
if raw is None:
return default
val = raw.strip().lower()
if val in {"1", "true", "yes", "y", "on"}:
return True
if val in {"0", "false", "no", "n", "off", ""}:
return False
raise InvalidEnvVarValueError(
name=name,
value=raw,
expected="true or false",
)

View File

@@ -6,7 +6,8 @@
import asyncio
import unittest
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List
from pipecat.frames.frames import (
DataFrame,
@@ -24,6 +25,15 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import SleepFrame, run_test
@dataclass
class BroadcastTestFrame(DataFrame):
"""Test frame with init fields for broadcast testing."""
text: str = ""
value: int = 0
items: List[str] = field(default_factory=list)
class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
async def test_before_after_events(self):
identity = IdentityFilter()
@@ -186,3 +196,157 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
async def test_broadcast_frame(self):
"""Test that broadcast_frame creates two separate frames with fresh IDs."""
downstream_frames: List[Frame] = []
upstream_frames: List[Frame] = []
class BroadcastTestProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self.broadcast_frame(
BroadcastTestFrame, text="hello", value=42, items=["a", "b"]
)
else:
await self.push_frame(frame, direction)
class CaptureProcessor(FrameProcessor):
def __init__(self, capture_list: List[Frame], direction: FrameDirection):
super().__init__()
self._capture_list = capture_list
self._capture_direction = direction
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if direction == self._capture_direction and isinstance(frame, BroadcastTestFrame):
self._capture_list.append(frame)
await self.push_frame(frame, direction)
up_capture = CaptureProcessor(upstream_frames, FrameDirection.UPSTREAM)
broadcaster = BroadcastTestProcessor()
down_capture = CaptureProcessor(downstream_frames, FrameDirection.DOWNSTREAM)
pipeline = Pipeline([up_capture, broadcaster, down_capture])
frames_to_send = [TextFrame(text="trigger")]
expected_down_frames = [BroadcastTestFrame]
expected_up_frames = [BroadcastTestFrame]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
# Verify we got one frame in each direction
self.assertEqual(len(downstream_frames), 1)
self.assertEqual(len(upstream_frames), 1)
down_frame = downstream_frames[0]
up_frame = upstream_frames[0]
# Verify the frames have different IDs (they are separate instances)
self.assertNotEqual(down_frame.id, up_frame.id)
# Verify the frames have the correct field values
self.assertEqual(down_frame.text, "hello")
self.assertEqual(down_frame.value, 42)
self.assertEqual(down_frame.items, ["a", "b"])
self.assertEqual(up_frame.text, "hello")
self.assertEqual(up_frame.value, 42)
self.assertEqual(up_frame.items, ["a", "b"])
# Verify the items lists are separate instances (not shared references)
self.assertIsNot(down_frame.items, up_frame.items)
async def test_broadcast_frame_instance(self):
"""Test that broadcast_frame_instance copies all fields except id and name."""
downstream_frames: List[Frame] = []
upstream_frames: List[Frame] = []
original_frame: List[Frame] = []
class BroadcastInstanceTestProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, BroadcastTestFrame):
# Set some non-init fields on the frame
frame.pts = 12345
frame.metadata = {"key": "value", "nested": {"a": 1}}
original_frame.append(frame)
await self.broadcast_frame_instance(frame)
else:
await self.push_frame(frame, direction)
class CaptureProcessor(FrameProcessor):
def __init__(self, capture_list: List[Frame], direction: FrameDirection):
super().__init__()
self._capture_list = capture_list
self._capture_direction = direction
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if direction == self._capture_direction and isinstance(frame, BroadcastTestFrame):
self._capture_list.append(frame)
await self.push_frame(frame, direction)
up_capture = CaptureProcessor(upstream_frames, FrameDirection.UPSTREAM)
broadcaster = BroadcastInstanceTestProcessor()
down_capture = CaptureProcessor(downstream_frames, FrameDirection.DOWNSTREAM)
pipeline = Pipeline([up_capture, broadcaster, down_capture])
# Create a frame with mutable fields to test deep copying
test_frame = BroadcastTestFrame(text="test", value=99, items=["x", "y", "z"])
frames_to_send = [test_frame]
expected_down_frames = [BroadcastTestFrame]
expected_up_frames = [BroadcastTestFrame]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
# Verify we got one frame in each direction
self.assertEqual(len(downstream_frames), 1)
self.assertEqual(len(upstream_frames), 1)
self.assertEqual(len(original_frame), 1)
orig = original_frame[0]
down_frame = downstream_frames[0]
up_frame = upstream_frames[0]
# Verify the frames have different IDs and names (fresh values)
self.assertNotEqual(down_frame.id, orig.id)
self.assertNotEqual(up_frame.id, orig.id)
self.assertNotEqual(down_frame.id, up_frame.id)
self.assertNotEqual(down_frame.name, orig.name)
self.assertNotEqual(up_frame.name, orig.name)
# Verify init fields are copied correctly
self.assertEqual(down_frame.text, "test")
self.assertEqual(down_frame.value, 99)
self.assertEqual(down_frame.items, ["x", "y", "z"])
self.assertEqual(up_frame.text, "test")
self.assertEqual(up_frame.value, 99)
self.assertEqual(up_frame.items, ["x", "y", "z"])
# Verify non-init fields (except id/name) are copied
self.assertEqual(down_frame.pts, 12345)
self.assertEqual(down_frame.metadata, {"key": "value", "nested": {"a": 1}})
self.assertEqual(up_frame.pts, 12345)
self.assertEqual(up_frame.metadata, {"key": "value", "nested": {"a": 1}})
# Verify mutable fields are deep copied (not shared references)
self.assertIsNot(down_frame.items, orig.items)
self.assertIsNot(up_frame.items, orig.items)
self.assertIsNot(down_frame.items, up_frame.items)
self.assertIsNot(down_frame.metadata, orig.metadata)
self.assertIsNot(up_frame.metadata, orig.metadata)
self.assertIsNot(down_frame.metadata, up_frame.metadata)
self.assertIsNot(down_frame.metadata["nested"], up_frame.metadata["nested"])

View File

@@ -0,0 +1,124 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for LiveKit transport video stream handling.
Regression tests for issue #3116: Memory leak when video_in_enabled=False
but video tracks are subscribed. The fix ensures video stream processing
only starts when there is a consumer for the frames.
"""
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
try:
from livekit import rtc
from pipecat.transports.livekit.transport import (
LiveKitCallbacks,
LiveKitParams,
LiveKitTransportClient,
)
LIVEKIT_AVAILABLE = True
except ImportError:
LIVEKIT_AVAILABLE = False
@unittest.skipUnless(LIVEKIT_AVAILABLE, "livekit package not installed")
class TestLiveKitVideoStreamMemoryLeak(unittest.IsolatedAsyncioTestCase):
"""Regression tests for video queue memory leak (#3116).
The bug: When video_in_enabled=False, subscribing to a video track would
start a producer that fills _video_queue, but no consumer would drain it,
causing unbounded memory growth (~3GB/min).
The fix: Only start video stream processing when video_in_enabled=True.
"""
def _create_client(self, video_in_enabled: bool) -> LiveKitTransportClient:
"""Create a client with the specified video input setting."""
params = LiveKitParams(video_in_enabled=video_in_enabled)
callbacks = LiveKitCallbacks(
on_connected=AsyncMock(),
on_disconnected=AsyncMock(),
on_before_disconnect=AsyncMock(),
on_participant_connected=AsyncMock(),
on_participant_disconnected=AsyncMock(),
on_audio_track_subscribed=AsyncMock(),
on_audio_track_unsubscribed=AsyncMock(),
on_video_track_subscribed=AsyncMock(),
on_video_track_unsubscribed=AsyncMock(),
on_data_received=AsyncMock(),
on_first_participant_joined=AsyncMock(),
)
client = LiveKitTransportClient(
url="wss://test.livekit.cloud",
token="test-token",
room_name="test-room",
params=params,
callbacks=callbacks,
transport_name="test-transport",
)
client._task_manager = MagicMock()
return client
def _create_mock_video_track(self):
"""Create a mock video track subscription event."""
track = MagicMock()
track.kind = rtc.TrackKind.KIND_VIDEO
track.sid = "video-track-123"
publication = MagicMock()
participant = MagicMock()
participant.sid = "participant-456"
return track, publication, participant
async def test_disabled_video_input_does_not_start_queue_producer(self):
"""When video input is disabled, no producer should fill the queue.
This prevents the memory leak where frames accumulate with no consumer.
"""
client = self._create_client(video_in_enabled=False)
track, publication, participant = self._create_mock_video_track()
await client._async_on_track_subscribed(track, publication, participant)
# Verify no video processing task was started
task_names = [call[0][1] for call in client._task_manager.create_task.call_args_list]
video_tasks = [name for name in task_names if "video" in name.lower()]
self.assertEqual(video_tasks, [], "No video processing task should be started")
# Queue should remain empty
self.assertEqual(client._video_queue.qsize(), 0)
# Track metadata should still be recorded
self.assertIn(participant.sid, client._video_tracks)
# Callback should still fire for user code
client._callbacks.on_video_track_subscribed.assert_called_once()
async def test_enabled_video_input_starts_queue_producer(self):
"""When video input is enabled, the producer should start."""
client = self._create_client(video_in_enabled=True)
track, publication, participant = self._create_mock_video_track()
with patch.object(rtc, "VideoStream"):
await client._async_on_track_subscribed(track, publication, participant)
# Verify video processing task was started
task_names = [call[0][1] for call in client._task_manager.create_task.call_args_list]
video_tasks = [name for name in task_names if "video" in name.lower()]
self.assertEqual(len(video_tasks), 1, "Video processing task should be started")
# Track metadata should be recorded
self.assertIn(participant.sid, client._video_tracks)
# Callback should fire
client._callbacks.on_video_track_subscribed.assert_called_once()
if __name__ == "__main__":
unittest.main()