Compare commits
37 Commits
v0.0.100
...
kompfner-p
| Author | SHA1 | Date | |
|---|---|---|---|
| | c861beb066 | | |
| | 8951442b8e | | |
| | 7e6e3031e7 | | |
| | 308829f92b | | |
| | 82a799e63e | | |
| | 6b5bcae86f | | |
| | 836073849c | | |
| | b13b65d6e2 | | |
| | 3d545b718d | | |
| | f2fa5d9733 | | |
| | 76b774072c | | |
| | b6341ffaa5 | | |
| | 29fae67c9e | | |
| | 718ea1c15e | | |
| | 8e09d94614 | | |
| | de73e28563 | | |
| | 55250b4f7e | | |
| | 281145a991 | | |
| | 7bd32e2fe5 | | |
| | 8f05d95f50 | | |
| | 87c12f3098 | | |
| | 9c0bf89247 | | |
| | 6e44a2ab49 | | |
| | 7aa7b86aed | | |
| | 5ad9faeb4c | | |
| | 9e8f8b45c6 | | |
| | 0ee11ad333 | | |
| | 124a3c35af | | |
| | 054e504868 | | |
| | e85a00cc0e | | |
| | cc61cdbba3 | | |
| | 62f4708d43 | | |
| | ba0ddb1832 | | |
| | eacd2a4b71 | | |
| | 7ed110650d | | |
| | 4a724379fc | | |
| | 1ceb01665f | | |
40
.claude/skills/changelog/SKILL.md
Normal file
@@ -0,0 +1,40 @@
---
name: changelog
description: Create changelog files for important commits in a PR
---

Create changelog files for the important commits in this PR. The PR number is provided as an argument.

## Instructions

1. First, check what commits are on the current branch compared to main:
```
git log main..HEAD --oneline
```

2. For each significant change, create a changelog file in the `changelog/` folder using the format:
- `{PR_NUMBER}.added.md` - for new features
- `{PR_NUMBER}.added.2.md`, `{PR_NUMBER}.added.3.md` - for additional new features
- `{PR_NUMBER}.changed.md` - for changes to existing functionality
- `{PR_NUMBER}.fixed.md` - for bug fixes
- `{PR_NUMBER}.deprecated.md` - for deprecations

3. Each changelog file should at least contain a main single line starting with `- ` followed by a clear description of the change.

4. If the change is complicated, changelog files can have indented lines after the main line with additional details or code samples.

5. Use ⚠️ emoji prefix for breaking changes.

## Example

For PR #3519 with a new feature and a bug fix:

`changelog/3519.added.md`:
```
- Added `SomeNewFeature` for doing something useful.
```

`changelog/3519.fixed.md`:
```
- Fixed an issue where something was not working correctly.
```
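The naming scheme in the changelog skill can be sketched as a small helper. This is only an illustration: `changelog_path` and `changelog_entry` are hypothetical names, not part of the repository, and placing the ⚠️ prefix right after `- ` is an assumption, since the skill file only says to use it as a prefix. How the numeric suffix is chosen is left to the caller.

```python
from pathlib import Path

# Categories taken from the naming scheme listed in the skill file.
CATEGORIES = ("added", "changed", "fixed", "deprecated")


def changelog_path(pr_number: int, category: str, index: int = 1) -> Path:
    """Build the changelog file path for one entry of a PR (hypothetical helper)."""
    if category not in CATEGORIES:
        raise ValueError(f"unknown category: {category}")
    if index < 1:
        raise ValueError("index starts at 1")
    # The first entry has no numeric suffix; later ones get .2, .3, ...
    suffix = "" if index == 1 else f".{index}"
    return Path("changelog") / f"{pr_number}.{category}{suffix}.md"


def changelog_entry(description: str, breaking: bool = False) -> str:
    """Format the main line of a changelog file (hypothetical helper)."""
    # Assumption: the breaking-change emoji goes right after the "- " marker.
    prefix = "⚠️ " if breaking else ""
    return f"- {prefix}{description}\n"
```

For instance, `changelog_path(3519, "added", 2)` yields the path of the second entry for PR 3519, and `changelog_entry(...)` yields the single main line that each file should contain.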
257
.claude/skills/docstring/SKILL.md
Normal file
@@ -0,0 +1,257 @@
|
||||
---
|
||||
name: docstring
|
||||
description: Document a Python module and its classes using Google style
|
||||
---
|
||||
|
||||
Document a Python module and its classes using Google-style docstrings following project conventions. The class name is provided as an argument.
|
||||
|
||||
## Instructions
|
||||
|
||||
1. First, find the class in the codebase:
|
||||
```
|
||||
Search for "class ClassName" in src/pipecat/
|
||||
```
|
||||
|
||||
2. If multiple files contain that class name:
|
||||
- List all matches with their file paths
|
||||
- Ask the user which one they want to document
|
||||
- Wait for confirmation before proceeding
|
||||
|
||||
3. Once the file is identified, read the module to understand its structure:
|
||||
- Identify all classes, functions, and important type aliases
|
||||
- Understand the purpose of each component
|
||||
|
||||
4. Apply documentation in this order:
|
||||
- Module docstring (at top, after the copyright header and before imports)
|
||||
- Class docstrings
|
||||
- `__init__` methods (always document constructor parameters)
|
||||
- Public methods (not starting with `_`)
|
||||
- Dataclass/config classes with field descriptions
|
||||
|
||||
5. Skip documentation for:
|
||||
- Private methods (starting with `_`)
|
||||
- Simple dunder methods (`__str__`, `__repr__`, `__post_init__`)
|
||||
- Very simple pass-through properties
|
||||
- **Already documented code** - If a class, method, or function already has a complete docstring that follows the project style, do not modify it. A docstring is complete if it has:
|
||||
- A one-line summary
|
||||
- Args section (if it has parameters)
|
||||
- Returns section (if it returns something meaningful)
|
||||
- Only add or improve documentation where it is missing or incomplete
|
||||
|
||||
## Module Docstring Format
|
||||
|
||||
```python
|
||||
"""[One-line description of module purpose].
|
||||
|
||||
[Optional: Longer explanation of functionality, key classes, or use cases.]
|
||||
"""
|
||||
```
|
||||
|
||||
Example:
|
||||
```python
|
||||
"""Neuphonic text-to-speech service implementations.
|
||||
|
||||
This module provides WebSocket and HTTP-based integrations with Neuphonic's
|
||||
text-to-speech API for real-time audio synthesis.
|
||||
"""
|
||||
```
|
||||
|
||||
## Class Docstring Format
|
||||
|
||||
```python
|
||||
class ClassName:
|
||||
"""One-line summary describing what the class does.
|
||||
|
||||
[Longer description explaining purpose, behavior, and key features.
|
||||
Use action-oriented language.]
|
||||
|
||||
[Optional: Event handlers, usage notes, or important caveats.]
|
||||
"""
|
||||
```
|
||||
|
||||
Example:
|
||||
```python
|
||||
class FrameProcessor(BaseObject):
|
||||
"""Base class for all frame processors in the pipeline.
|
||||
|
||||
Frame processors are the building blocks of Pipecat pipelines, they can be
|
||||
linked to form complex processing pipelines. They receive frames, process
|
||||
them, and pass them to the next or previous processor in the chain.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_before_process_frame: Called before a frame is processed
|
||||
- on_after_process_frame: Called after a frame is processed
|
||||
|
||||
Example::
|
||||
|
||||
@processor.event_handler("on_before_process_frame")
|
||||
async def on_before_process_frame(processor, frame):
|
||||
...
|
||||
|
||||
@processor.event_handler("on_after_process_frame")
|
||||
async def on_after_process_frame(processor, frame):
|
||||
...
|
||||
"""
|
||||
```
|
||||
|
||||
Note: When listing event handlers, do NOT use backticks. Include an `Example::` section (with double colon for Sphinx) showing the decorator pattern and function signature for each event.
|
||||
|
||||
## Constructor (`__init__`) Format
|
||||
|
||||
```python
|
||||
def __init__(self, *, param1: Type, param2: Type = default, **kwargs):
|
||||
"""Initialize the [ClassName].
|
||||
|
||||
Args:
|
||||
param1: Description of param1 and its purpose.
|
||||
param2: Description of param2. Defaults to [default].
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
```
|
||||
|
||||
Example:
|
||||
```python
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: Optional[str] = None,
|
||||
sample_rate: Optional[int] = 22050,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Neuphonic TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Neuphonic API key for authentication.
|
||||
voice_id: ID of the voice to use for synthesis.
|
||||
sample_rate: Audio sample rate in Hz. Defaults to 22050.
|
||||
**kwargs: Additional arguments passed to parent InterruptibleTTSService.
|
||||
"""
|
||||
```
|
||||
|
||||
## Method Docstring Format
|
||||
|
||||
```python
|
||||
async def method_name(self, param1: Type) -> ReturnType:
|
||||
"""One-line summary of what method does.
|
||||
|
||||
[Longer description if behavior isn't obvious.]
|
||||
|
||||
Args:
|
||||
param1: Description of param1.
|
||||
|
||||
Returns:
|
||||
Description of return value.
|
||||
|
||||
Raises:
|
||||
ExceptionType: When this exception is raised.
|
||||
"""
|
||||
```
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def put(self, item: Tuple[Frame, FrameDirection, FrameCallback]):
|
||||
"""Put an item into the priority queue.
|
||||
|
||||
System frames (`SystemFrame`) have higher priority than any other
|
||||
frames. If a non-frame item is provided it will have the highest priority.
|
||||
|
||||
Args:
|
||||
item: The item to enqueue.
|
||||
"""
|
||||
```
|
||||
|
||||
## Dataclass/Config Format
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ConfigName:
|
||||
"""One-line description of configuration.
|
||||
|
||||
[Explanation of when/how to use this config.]
|
||||
|
||||
Parameters:
|
||||
field1: Description of field1.
|
||||
field2: Description of field2. Defaults to [default].
|
||||
"""
|
||||
|
||||
field1: Type
|
||||
field2: Type = default_value
|
||||
```
|
||||
|
||||
Example:
|
||||
```python
|
||||
@dataclass
|
||||
class FrameProcessorSetup:
|
||||
"""Configuration parameters for frame processor initialization.
|
||||
|
||||
Parameters:
|
||||
clock: The clock instance for timing operations.
|
||||
task_manager: The task manager for handling async operations.
|
||||
observer: Optional observer for monitoring frame processing events.
|
||||
"""
|
||||
|
||||
clock: BaseClock
|
||||
task_manager: BaseTaskManager
|
||||
observer: Optional[BaseObserver] = None
|
||||
```
|
||||
|
||||
## Enum Documentation Format
|
||||
|
||||
```python
|
||||
class EnumName(Enum):
|
||||
"""One-line description of the enum purpose.
|
||||
|
||||
[Longer description of how the enum is used.]
|
||||
|
||||
Parameters:
|
||||
VALUE1: Description of VALUE1.
|
||||
VALUE2: Description of VALUE2.
|
||||
"""
|
||||
|
||||
VALUE1 = 1
|
||||
VALUE2 = 2
|
||||
```
|
||||
|
||||
## Writing Style Guidelines
|
||||
|
||||
- **Concise and professional** - No casual language or filler words
|
||||
- **Action-oriented** - Start with verbs: "Processes...", "Manages...", "Converts..."
|
||||
- **Purpose before implementation** - Explain WHY before HOW
|
||||
- **Clear parameter descriptions** - Include defaults and purpose
|
||||
- **No redundant type info** - Type hints are in the signature, don't repeat in description
|
||||
- **Use backticks for code references** - Wrap class names, method names, event names, parameter names, and code snippets in backticks
|
||||
|
||||
Good: "Neuphonic API key for authentication."
|
||||
Bad: "str: The API key (string) that is used for authenticating with Neuphonic."
|
||||
|
||||
Good: "Triggers `on_speech_started` when the `VADAnalyzer` detects speech."
|
||||
Bad: "Triggers on_speech_started when the VADAnalyzer detects speech."
|
||||
|
||||
## Deprecation Notice Format
|
||||
|
||||
When documenting deprecated code:
|
||||
|
||||
```python
|
||||
"""[Description].
|
||||
|
||||
.. deprecated:: X.X.X
|
||||
`ClassName` is deprecated and will be removed in a future version.
|
||||
Use `NewClassName` instead.
|
||||
"""
|
||||
```
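A filled-in instance of the deprecation notice might look like the sketch below. The class names and the `0.0.0` version are placeholders, and pairing the docstring notice with a runtime `warnings.warn` call is an assumption about good practice, not something the skill file requires.

```python
import warnings


class NewGreeter:
    """Greet a user by name."""

    def greet(self, name: str) -> str:
        """Return a greeting for `name`.

        Args:
            name: Name of the person to greet.

        Returns:
            The greeting text.
        """
        return f"Hello, {name}!"


class OldGreeter(NewGreeter):
    """Greet a user by name.

    .. deprecated:: 0.0.0
        `OldGreeter` is deprecated and will be removed in a future version.
        Use `NewGreeter` instead.
    """

    def __init__(self):
        """Initialize the greeter and emit a deprecation warning."""
        # Assumption: the runtime warning mirrors the docstring notice.
        warnings.warn(
            "OldGreeter is deprecated, use NewGreeter instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
```

The docstring notice is what documentation tools render, while the runtime warning is what callers see when they instantiate the old class.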
|
||||
|
||||
## Checklist
|
||||
|
||||
Before finishing, verify:
|
||||
|
||||
- [ ] Module has a docstring at the top (after the copyright header, before imports)
|
||||
- [ ] All public classes have docstrings
|
||||
- [ ] All `__init__` methods document their parameters
|
||||
- [ ] All public methods have docstrings with Args/Returns/Raises as needed
|
||||
- [ ] Dataclasses use "Parameters:" section for field descriptions
|
||||
- [ ] Enums document each value in "Parameters:" section
|
||||
- [ ] Writing is concise and action-oriented
|
||||
- [ ] No documentation added to private methods (starting with `_`)
|
||||
- [ ] Existing complete docstrings were left unchanged
|
||||
128
.claude/skills/pr-description/SKILL.md
Normal file
@@ -0,0 +1,128 @@
|
||||
---
|
||||
name: pr-description
|
||||
description: Update a GitHub PR description with a summary of changes
|
||||
---
|
||||
|
||||
Update a GitHub pull request description based on the changes in the PR.
|
||||
|
||||
## Arguments
|
||||
|
||||
```
|
||||
/pr-description <PR_NUMBER> [--fixes <ISSUE_NUMBERS>]
|
||||
```
|
||||
|
||||
- `PR_NUMBER` (required): The pull request number to update
|
||||
- `--fixes` (optional): Comma-separated issue numbers that this PR fixes (e.g., `--fixes 123,456`)
|
||||
|
||||
Examples:
|
||||
- `/pr-description 3534`
|
||||
- `/pr-description 3534 --fixes 123`
|
||||
- `/pr-description 3534 --fixes 123,456,789`
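The argument shape above can be sketched with `argparse`. This is a hypothetical parser for illustration only; the skill itself is interpreted by the assistant, and nothing in the diff shows an actual parser.

```python
import argparse


def parse_pr_description_args(argv: list[str]) -> tuple[int, list[int]]:
    """Parse `<PR_NUMBER> [--fixes <ISSUE_NUMBERS>]` into numbers (hypothetical helper)."""
    parser = argparse.ArgumentParser(prog="/pr-description")
    parser.add_argument("pr_number", type=int)
    parser.add_argument("--fixes", default="")
    args = parser.parse_args(argv)
    # Split the comma-separated list and ignore empty pieces.
    issues = [int(part) for part in args.fixes.split(",") if part.strip()]
    return args.pr_number, issues
```

With this sketch, `["3534", "--fixes", "123,456"]` parses to PR 3534 with issues 123 and 456, and omitting `--fixes` gives an empty issue list.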
|
||||
|
||||
## Instructions
|
||||
|
||||
1. First, gather information about the PR:
|
||||
- Use GitHub plugin to get PR details (title, current description, base branch)
|
||||
- Use local git to get commits: `git log main..HEAD --oneline`
|
||||
- Use local git to get the diff: `git diff main..HEAD`
|
||||
- Parse any `--fixes` argument for issue numbers
|
||||
|
||||
2. Check the existing PR description:
|
||||
- If it already has a complete, accurate description that reflects the changes, do nothing
|
||||
- If it's missing sections, incomplete, or outdated compared to the actual changes, proceed to update
|
||||
- If it only has the template placeholder text, generate a full description
|
||||
|
||||
3. Analyze the changes:
|
||||
- Understand the purpose of each commit
|
||||
- Identify any breaking changes (API changes, removed features, behavior changes)
|
||||
- Look for new features, bug fixes, refactoring, or documentation changes
|
||||
- Collect issue numbers from:
|
||||
- The `--fixes` argument (if provided)
|
||||
- Commit messages (patterns like "Fixes #123", "Closes #456", "Resolves #789")
|
||||
|
||||
4. Generate or update the PR description with these sections:
|
||||
|
||||
## PR Description Format
|
||||
|
||||
### Summary (always include)
|
||||
|
||||
Brief bullet points describing what changed and why. Focus on the *purpose* and *impact*, not implementation details.
|
||||
|
||||
```markdown
|
||||
## Summary
|
||||
|
||||
- Added X to enable Y
|
||||
- Fixed bug where Z would happen
|
||||
- Refactored W for better maintainability
|
||||
```
|
||||
|
||||
### Breaking Changes (include only if applicable)
|
||||
|
||||
Document any changes that affect existing users or APIs.
|
||||
|
||||
```markdown
|
||||
## Breaking Changes
|
||||
|
||||
- `ClassName.method()` now requires a `param` argument
|
||||
- Removed deprecated `old_function()` - use `new_function()` instead
|
||||
```
|
||||
|
||||
### Testing (include when non-obvious)
|
||||
|
||||
How to verify the changes work. Skip for trivial changes.
|
||||
|
||||
```markdown
|
||||
## Testing
|
||||
|
||||
- Run `uv run pytest tests/test_feature.py` to verify the fix
|
||||
- Example usage: `uv run examples/new_feature.py`
|
||||
```
|
||||
|
||||
### Fixes (include if issues are provided or found in commits)
|
||||
|
||||
List issues this PR fixes. GitHub will automatically close these issues when the PR is merged.
|
||||
|
||||
```markdown
|
||||
## Fixes
|
||||
|
||||
- Fixes #123
|
||||
- Fixes #456
|
||||
```
|
||||
|
||||
Note: Use "Fixes #X" format (not "Closes" or "Resolves") for consistency. Each issue should be on its own line with "Fixes" to ensure GitHub auto-closes them.
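Collecting issue numbers and rendering the Fixes section could look like the sketch below. The function names are hypothetical, and the regular expression only covers the three keywords named in the instructions ("Fixes", "Closes", "Resolves"); other closing keywords are not handled.

```python
import re

# Keywords named in the instructions: "Fixes", "Closes", "Resolves".
ISSUE_PATTERN = re.compile(r"\b(?:fixes|closes|resolves)\s+#(\d+)", re.IGNORECASE)


def collect_issue_numbers(commit_messages: list[str], fixes_arg: list[int]) -> list[int]:
    """Merge issue numbers from `--fixes` and from commit messages."""
    found = list(fixes_arg)
    for message in commit_messages:
        found.extend(int(number) for number in ISSUE_PATTERN.findall(message))
    # Drop duplicates while keeping first-seen order.
    return list(dict.fromkeys(found))


def fixes_section(issues: list[int]) -> str:
    """Render the "Fixes" section, one issue per line."""
    if not issues:
        # Empty sections are skipped, per the guidelines.
        return ""
    lines = ["## Fixes", ""] + [f"- Fixes #{number}" for number in issues]
    return "\n".join(lines) + "\n"
```

Every issue is rendered with the "Fixes" keyword regardless of which keyword the commit message used, which matches the consistency rule in the note above.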
|
||||
|
||||
## Guidelines
|
||||
|
||||
- **Be concise** - Reviewers should understand the PR in 30 seconds
|
||||
- **Focus on why** - The diff shows *what* changed, explain *why*
|
||||
- **Skip empty sections** - Only include sections that have content
|
||||
- **Use bullet points** - Easier to scan than paragraphs
|
||||
- **Don't duplicate the diff** - Avoid listing every file or line changed
|
||||
|
||||
## Example Output
|
||||
|
||||
```markdown
|
||||
## Summary
|
||||
|
||||
- Added `/docstring` skill for documenting Python modules with Google-style docstrings
|
||||
- Skill finds classes by name and handles conflicts when multiple matches exist
|
||||
- Skips already-documented code to avoid unnecessary changes
|
||||
|
||||
## Testing
|
||||
|
||||
/docstring ClassName
|
||||
|
||||
## Fixes
|
||||
|
||||
- Fixes #123
|
||||
```
|
||||
|
||||
## Checklist
|
||||
|
||||
Before updating the PR:
|
||||
|
||||
- [ ] Verified existing description needs updating (not already complete)
|
||||
- [ ] Summary accurately reflects the changes
|
||||
- [ ] Breaking changes are clearly documented (if any)
|
||||
- [ ] No unnecessary sections included
|
||||
- [ ] Description is concise and scannable
|
||||
11
.gitignore
vendored
@@ -4,7 +4,14 @@ __pycache__/
|
||||
*~
|
||||
venv
|
||||
.venv
|
||||
/.idea
|
||||
.idea
|
||||
.gradle
|
||||
.next
|
||||
next-env.d.ts
|
||||
local.properties
|
||||
*.log
|
||||
*.lock
|
||||
smart_turn_audio_log
|
||||
#*#
|
||||
|
||||
# Distribution / Packaging
|
||||
@@ -27,7 +34,7 @@ share/python-wheels/
|
||||
*.egg
|
||||
MANIFEST
|
||||
.DS_Store
|
||||
.env
|
||||
.env*
|
||||
fly.toml
|
||||
|
||||
# Examples
|
||||
|
||||
@@ -81,7 +81,7 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
|
||||
| Serializers | [Exotel](https://docs.pipecat.ai/server/utilities/serializers/exotel), [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx), [Vonage](https://docs.pipecat.ai/server/utilities/serializers/vonage) |
|
||||
| Video | [HeyGen](https://docs.pipecat.ai/server/services/video/heygen), [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/google-imagen), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [ai-coustics](https://docs.pipecat.ai/server/utilities/audio/aic-filter) |
|
||||
| Analytics & Metrics | [OpenTelemetry](https://docs.pipecat.ai/server/utilities/opentelemetry), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
|
||||
|
||||
|
||||
1
changelog/3510.added.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `add_reached_upstream_filter()` and `add_reached_downstream_filter()` methods to `PipelineTask` for appending frame types.
|
||||
1
changelog/3510.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `reached_upstream_types` and `reached_downstream_types` read-only properties to `PipelineTask` for inspecting current frame filters.
|
||||
1
changelog/3510.changed.3.md
Normal file
@@ -0,0 +1 @@
|
||||
- Changed frame filter storage from tuples to sets in `PipelineTask`.
|
||||
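Taken together, the three 3510 entries describe set-backed frame filters with read-only accessors and additive update methods. The sketch below shows that pattern in isolation. `FilterHolder` and the frame classes are stand-ins, and the variadic signature of `add_reached_upstream_filter` here is an assumption; the real `PipelineTask` signature is not shown in this part of the diff.

```python
from typing import FrozenSet, Set, Type


class Frame:
    """Stand-in base class for frames (hypothetical)."""


class StartFrame(Frame):
    """Stand-in frame type (hypothetical)."""


class EndFrame(Frame):
    """Stand-in frame type (hypothetical)."""


class FilterHolder:
    """Hypothetical holder of upstream frame filters; not `PipelineTask`."""

    def __init__(self) -> None:
        # A set ignores repeated additions of the same type.
        self._reached_upstream_types: Set[Type[Frame]] = set()

    @property
    def reached_upstream_types(self) -> FrozenSet[Type[Frame]]:
        """Return a read-only snapshot of the current filter."""
        return frozenset(self._reached_upstream_types)

    def add_reached_upstream_filter(self, *types: Type[Frame]) -> None:
        """Append frame types to the filter without replacing existing ones."""
        self._reached_upstream_types.update(types)
```

Storing the types in a set makes repeated additions idempotent, which a tuple would not do without an explicit membership check.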
1
changelog/3519.added.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `RTVIProcessor.create_rtvi_observer()` factory method for creating RTVI observers.
|
||||
1
changelog/3519.added.3.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `FrameProcessor.broadcast_frame_instance(frame)` method to broadcast a frame instance by extracting its fields and creating new instances for each direction.
|
||||
1
changelog/3519.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- `PipelineTask` now automatically adds `RTVIProcessor` and registers `RTVIObserver` when `enable_rtvi=True` (default), simplifying pipeline setup.
|
||||
1
changelog/3519.fixed.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `FrameProcessor.broadcast_frame()` to deep copy kwargs, preventing shared mutable references between the downstream and upstream frame instances.
|
||||
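The shared-reference problem that the `broadcast_frame()` fix addresses can be reproduced in a few lines. `MessageFrame` and the two `broadcast_*` functions below are stand-ins written for illustration; they are not the `FrameProcessor` implementation.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass
class MessageFrame:
    """Stand-in frame carrying a mutable payload (hypothetical)."""

    payload: Dict[str, Any] = field(default_factory=dict)


def broadcast_shared(**kwargs) -> Tuple[MessageFrame, MessageFrame]:
    """Build two frames from the same kwargs; mutable values are shared."""
    return MessageFrame(**kwargs), MessageFrame(**kwargs)


def broadcast_copied(**kwargs) -> Tuple[MessageFrame, MessageFrame]:
    """Build two frames, each from its own deep copy of the kwargs."""
    return MessageFrame(**copy.deepcopy(kwargs)), MessageFrame(**copy.deepcopy(kwargs))
```

With `broadcast_shared`, mutating the downstream frame's payload also changes the upstream frame, because both hold the same dictionary. With `broadcast_copied`, each frame owns an independent copy.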
1
changelog/3519.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Transports now properly broadcast `InputTransportMessageFrame` frames both upstream and downstream instead of only pushing downstream.
|
||||
1
changelog/3520.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `video_out_codec` parameter to `TransportParams` allowing configuration of the preferred video codec (e.g., `"VP8"`, `"H264"`, `"H265"`) for video output in `DailyTransport`.
|
||||
1
changelog/3523.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `location` parameter to Google TTS services (`GoogleHttpTTSService`, `GoogleTTSService`, `GeminiTTSService`) for regional endpoint support.
|
||||
1
changelog/3525.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added new `SMART_TURN_LOG_DATA` environment variable, which causes Smart Turn input data to be saved to disk.
|
||||
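Reading such a flag from the environment typically looks like the sketch below. `env_flag` is a hypothetical stand-in, not Pipecat's own helper, and the set of values it treats as true is an assumption.

```python
import os

# Values treated as "on"; the exact set accepted by Pipecat is an assumption.
TRUTHY = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment (hypothetical helper)."""
    value = os.environ.get(name)
    if value is None:
        # Unset variables fall back to the caller's default.
        return default
    return value.strip().lower() in TRUTHY
```

Under these assumptions, setting `SMART_TURN_LOG_DATA=1` before starting the process would switch logging on, and leaving it unset keeps the default of off.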
2
changelog/3531.changed.md
Normal file
@@ -0,0 +1,2 @@
|
||||
- Changed default Inworld TTS model from `inworld-tts-1` to
|
||||
`inworld-tts-1.5-max`.
|
||||
@@ -23,7 +23,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
@@ -93,12 +92,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
rtvi = RTVIProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
rtvi,
|
||||
stt,
|
||||
user_aggregator,
|
||||
llm,
|
||||
@@ -115,7 +111,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
observers=[
|
||||
RTVIObserver(rtvi),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
|
||||
@@ -22,7 +22,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
@@ -88,12 +87,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
rtvi,
|
||||
stt,
|
||||
user_aggregator,
|
||||
llm,
|
||||
@@ -110,7 +106,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
observers=[
|
||||
RTVIObserver(rtvi),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
|
||||
@@ -22,7 +22,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
@@ -90,12 +89,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
rtvi,
|
||||
stt,
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
@@ -114,7 +110,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
observers=[
|
||||
RTVIObserver(rtvi),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
@@ -123,10 +118,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
],
|
||||
)
|
||||
|
||||
@rtvi.event_handler("on_client_ready")
|
||||
async def on_client_ready(rtvi):
|
||||
await rtvi.set_bot_ready()
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
|
||||
@@ -59,7 +59,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
@@ -255,12 +254,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
),
|
||||
)
|
||||
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
rtvi,
|
||||
stt,
|
||||
user_aggregator,
|
||||
memory,
|
||||
@@ -278,12 +275,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
observers=[RTVIObserver(rtvi)],
|
||||
)
|
||||
|
||||
@rtvi.event_handler("on_client_ready")
|
||||
@task.rtvi.event_handler("on_client_ready")
|
||||
async def on_client_ready(rtvi):
|
||||
await rtvi.set_bot_ready()
|
||||
# Get personalized greeting based on user memories. Can pass agent_id and run_id as per requirement of the application to manage short term memory or agent specific memory.
|
||||
greeting = await get_initial_greeting(
|
||||
memory_client=memory.memory_client, user_id=USER_ID, agent_id=None, run_id=None
|
||||
|
||||
@@ -22,7 +22,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
@@ -87,8 +86,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
rtvi = RTVIProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
@@ -108,13 +105,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
observers=[RTVIObserver(rtvi)],
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@rtvi.event_handler("on_client_ready")
|
||||
@task.rtvi.event_handler("on_client_ready")
|
||||
async def on_client_ready(rtvi):
|
||||
await rtvi.set_bot_ready()
|
||||
# Kick off the conversation
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
@@ -22,7 +22,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
@@ -125,14 +124,10 @@ async def run_bot(pipecat_transport):
|
||||
),
|
||||
)
|
||||
|
||||
# RTVI events for Pipecat client UI
|
||||
rtvi = RTVIProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
pipecat_transport.input(),
|
||||
user_aggregator,
|
||||
rtvi,
|
||||
llm, # LLM
|
||||
EdgeDetectionProcessor(
|
||||
pipecat_transport._params.video_out_width,
|
||||
@@ -149,13 +144,11 @@ async def run_bot(pipecat_transport):
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
observers=[RTVIObserver(rtvi)],
|
||||
)
|
||||
|
||||
@rtvi.event_handler("on_client_ready")
|
||||
@task.rtvi.event_handler("on_client_ready")
|
||||
async def on_client_ready(rtvi):
|
||||
logger.info("Pipecat client ready.")
|
||||
await rtvi.set_bot_ready()
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
||||
from pipecat.utils.env import env_truthy
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
@@ -48,6 +49,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._log_data = env_truthy("SMART_TURN_LOG_DATA", default=False)
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Load bundled model
|
||||
model_name = "smart-turn-v3.2-cpu.onnx"
|
||||
@@ -81,6 +84,49 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
|
||||
logger.debug("Loaded Local Smart Turn v3.x")
|
||||
|
||||
def _write_audio_to_wav(
|
||||
self, audio_array: np.ndarray, sample_rate: int = 16000, suffix: str = ""
|
||||
) -> None:
|
||||
"""Write audio data to a WAV file in a background thread.
|
||||
|
||||
Args:
|
||||
audio_array: The audio data as a numpy array (float32, normalized to [-1, 1]).
|
||||
sample_rate: The sample rate of the audio data.
|
||||
suffix: Optional suffix to append to the filename (e.g., "_raw", "_padded").
|
||||
"""
|
||||
import os
|
||||
import threading
|
||||
import wave
|
||||
from datetime import datetime
|
||||
|
||||
# Generate filename with current timestamp (millisecond precision)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d__%H:%M:%S.%f")[:-3]
|
||||
log_dir = "./smart_turn_audio_log"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
filename = os.path.join(log_dir, f"{timestamp}{suffix}.wav")
|
||||
|
||||
# Make a copy of the audio data to avoid issues with the array being modified
|
||||
audio_copy = audio_array.copy()
|
||||
|
||||
def write_wav():
|
||||
try:
|
||||
# Convert float32 audio to int16 for WAV file
|
||||
audio_int16 = (audio_copy * 32767).astype(np.int16)
|
||||
|
||||
with wave.open(filename, "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # Mono
|
||||
wav_file.setsampwidth(2) # 2 bytes for int16
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(audio_int16.tobytes())
|
||||
|
||||
logger.debug(f"Wrote audio to {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write audio to {filename}: {e}")
|
||||
|
||||
# Start background thread to write the WAV file
|
||||
thread = threading.Thread(target=write_wav, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Predict end-of-turn using local ONNX model."""
|
||||
|
||||
@@ -95,6 +141,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
|
||||
return audio_array
|
||||
|
||||
audio_for_logging = audio_array
|
||||
|
||||
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
|
||||
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
|
||||
|
||||
@@ -122,6 +170,10 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
# Make prediction (1 for Complete, 0 for Incomplete)
|
||||
prediction = 1 if probability > 0.5 else 0
|
||||
|
||||
if self._log_data:
|
||||
suffix = "_complete" if prediction == 1 else "_incomplete"
|
||||
self._write_audio_to_wav(audio_for_logging, sample_rate=16000, suffix=suffix)
|
||||
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"probability": probability,
|
||||
|
||||
@@ -15,7 +15,7 @@ import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Tuple, Type
|
||||
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Set, Tuple, Type
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -49,6 +49,7 @@ from pipecat.pipeline.pipeline import Pipeline, PipelineSink, PipelineSource
|
||||
from pipecat.pipeline.task_observer import TaskObserver
|
||||
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserverParams, RTVIProcessor
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager, TaskManager, TaskManagerParams
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
|
||||
@@ -225,9 +226,12 @@ class PipelineTask(BasePipelineTask):
|
||||
conversation_id: Optional[str] = None,
|
||||
enable_tracing: bool = False,
|
||||
enable_turn_tracking: bool = True,
|
||||
enable_rtvi: bool = True,
|
||||
idle_timeout_frames: Tuple[Type[Frame], ...] = (BotSpeakingFrame, UserSpeakingFrame),
|
||||
idle_timeout_secs: Optional[float] = IDLE_TIMEOUT_SECS,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
rtvi_processor: Optional[RTVIProcessor] = None,
|
||||
rtvi_observer_params: Optional[RTVIObserverParams] = None,
|
||||
task_manager: Optional[BaseTaskManager] = None,
|
||||
):
|
||||
"""Initialize the PipelineTask.
|
||||
@@ -244,6 +248,7 @@ class PipelineTask(BasePipelineTask):
|
||||
check_dangling_tasks: Whether to check for processors' tasks finishing properly.
|
||||
clock: Clock implementation for timing operations.
|
||||
conversation_id: Optional custom ID for the conversation.
|
||||
enable_rtvi: Whether to automatically add RTVI support to the pipeline.
|
||||
enable_tracing: Whether to enable tracing.
|
||||
enable_turn_tracking: Whether to enable turn tracking.
|
||||
idle_timeout_frames: A tuple with the frames that should trigger an idle
|
||||
@@ -252,6 +257,8 @@ class PipelineTask(BasePipelineTask):
|
||||
None. If a pipeline is idle the pipeline task will be cancelled
|
||||
automatically.
|
||||
observers: List of observers for monitoring pipeline execution.
|
||||
rtvi_observer_params: The RTVI observer parameters to use if RTVI is enabled.
|
||||
rtvi_processor: The RTVI processor to add if RTVI is enabled.
|
||||
task_manager: Optional task manager for handling asyncio tasks.
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -306,6 +313,16 @@ class PipelineTask(BasePipelineTask):
|
||||
self._heartbeat_push_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_monitor_task: Optional[asyncio.Task] = None
|
||||
|
||||
# RTVI support
|
||||
self._rtvi = None
|
||||
if enable_rtvi:
|
||||
self._rtvi = rtvi_processor or RTVIProcessor()
|
||||
observers.append(self._rtvi.create_rtvi_observer(params=rtvi_observer_params))
|
||||
|
||||
@self.rtvi.event_handler("on_client_ready")
|
||||
async def on_client_ready(rtvi: RTVIProcessor):
|
||||
await rtvi.set_bot_ready()
|
||||
|
||||
# This is the idle event. When selected frames are pushed from any
# processor we consider the pipeline not idle. We use an observer
# which listens to any part of the pipeline.
|
||||
@@ -335,7 +352,8 @@ class PipelineTask(BasePipelineTask):
|
||||
# allows us to receive and react to downstream frames.
|
||||
source = PipelineSource(self._source_push_frame, name=f"{self}::Source")
|
||||
sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink")
|
||||
self._pipeline = Pipeline([pipeline], source=source, sink=sink)
|
||||
processors = [self._rtvi, pipeline] if self._rtvi else [pipeline]
|
||||
self._pipeline = Pipeline(processors, source=source, sink=sink)
|
||||
|
||||
# The task observer acts as a proxy to the provided observers. This way,
|
||||
# we only need to pass a single observer (using the StartFrame) which
|
||||
@@ -348,8 +366,8 @@ class PipelineTask(BasePipelineTask):
|
||||
# in. This is mainly for efficiency reasons because each event handler
|
||||
# creates a task and most likely you only care about one or two frame
|
||||
# types.
|
||||
self._reached_upstream_types: Tuple[Type[Frame], ...] = ()
|
||||
self._reached_downstream_types: Tuple[Type[Frame], ...] = ()
|
||||
self._reached_upstream_types: Set[Type[Frame]] = set()
|
||||
self._reached_downstream_types: Set[Type[Frame]] = set()
|
||||
self._register_event_handler("on_frame_reached_upstream")
|
||||
self._register_event_handler("on_frame_reached_downstream")
|
||||
self._register_event_handler("on_idle_timeout")
|
||||
@@ -398,6 +416,35 @@ class PipelineTask(BasePipelineTask):
|
||||
"""
|
||||
return self._turn_trace_observer
|
||||
|
||||
@property
|
||||
def rtvi(self) -> RTVIProcessor:
|
||||
"""Get the RTVI processor if RTVI is enabled.
|
||||
|
||||
Returns:
|
||||
The RTVI processor added to the pipeline when RTVI is enabled.
|
||||
"""
|
||||
if not self._rtvi:
|
||||
raise Exception(f"{self} RTVI is not enabled.")
|
||||
return self._rtvi
|
||||
|
||||
@property
|
||||
def reached_upstream_types(self) -> Tuple[Type[Frame], ...]:
|
||||
"""Get the currently configured upstream frame type filters.
|
||||
|
||||
Returns:
|
||||
Tuple of frame types that trigger the on_frame_reached_upstream event.
|
||||
"""
|
||||
return tuple(self._reached_upstream_types)
|
||||
|
||||
@property
|
||||
def reached_downstream_types(self) -> Tuple[Type[Frame], ...]:
|
||||
"""Get the currently configured downstream frame type filters.
|
||||
|
||||
Returns:
|
||||
Tuple of frame types that trigger the on_frame_reached_downstream event.
|
||||
"""
|
||||
return tuple(self._reached_downstream_types)
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
"""Decorator for registering event handlers.
|
||||
|
||||
@@ -441,7 +488,7 @@ class PipelineTask(BasePipelineTask):
|
||||
Args:
|
||||
types: Tuple of frame types to monitor for upstream events.
|
||||
"""
|
||||
self._reached_upstream_types = types
|
||||
self._reached_upstream_types = set(types)
|
||||
|
||||
def set_reached_downstream_filter(self, types: Tuple[Type[Frame], ...]):
|
||||
"""Set which frame types trigger the on_frame_reached_downstream event.
|
||||
@@ -449,7 +496,23 @@ class PipelineTask(BasePipelineTask):
|
||||
Args:
|
||||
types: Tuple of frame types to monitor for downstream events.
|
||||
"""
|
||||
self._reached_downstream_types = types
|
||||
self._reached_downstream_types = set(types)
|
||||
|
||||
def add_reached_upstream_filter(self, types: Tuple[Type[Frame], ...]):
|
||||
"""Add frame types to trigger the on_frame_reached_upstream event.
|
||||
|
||||
Args:
|
||||
types: Tuple of frame types to add to upstream monitoring.
|
||||
"""
|
||||
self._reached_upstream_types.update(types)
|
||||
|
||||
def add_reached_downstream_filter(self, types: Tuple[Type[Frame], ...]):
|
||||
"""Add frame types to trigger the on_frame_reached_downstream event.
|
||||
|
||||
Args:
|
||||
types: Tuple of frame types to add to downstream monitoring.
|
||||
"""
|
||||
self._reached_downstream_types.update(types)
|
||||
|
||||
def has_finished(self) -> bool:
|
||||
"""Check if the pipeline task has finished execution.
|
||||
@@ -749,7 +812,7 @@ class PipelineTask(BasePipelineTask):
|
||||
pipeline to be stopped (e.g. EndTaskFrame) in which case we would send
|
||||
an EndFrame down the pipeline.
|
||||
"""
|
||||
if isinstance(frame, self._reached_upstream_types):
|
||||
if isinstance(frame, tuple(self._reached_upstream_types)):
|
||||
await self._call_event_handler("on_frame_reached_upstream", frame)
|
||||
|
||||
if isinstance(frame, EndTaskFrame):
|
||||
@@ -788,7 +851,7 @@ class PipelineTask(BasePipelineTask):
|
||||
processors have handled the EndFrame and therefore we can exit the task
|
||||
cleanly.
|
||||
"""
|
||||
if isinstance(frame, self._reached_downstream_types):
|
||||
if isinstance(frame, tuple(self._reached_downstream_types)):
|
||||
await self._call_event_handler("on_frame_reached_downstream", frame)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
|
||||
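The hunks above switch the reached-frame filters from tuples to sets, add `add_reached_*_filter` methods, and convert back to a tuple at the `isinstance` call. A minimal standalone sketch of that pattern follows. `FrameTypeFilter` and the toy frame classes are hypothetical stand-ins, not pipecat classes.

```python
from typing import Set, Tuple, Type


class Frame:
    """Toy base frame, standing in for the real frame hierarchy."""


class TextFrame(Frame):
    pass


class EndFrame(Frame):
    pass


class FrameTypeFilter:
    """Keeps a set of frame types and matches frames against it."""

    def __init__(self) -> None:
        self._types: Set[Type[Frame]] = set()

    def set_filter(self, types: Tuple[Type[Frame], ...]) -> None:
        # Replace the current filter; duplicates collapse in the set.
        self._types = set(types)

    def add_filter(self, types: Tuple[Type[Frame], ...]) -> None:
        # Extend the current filter without touching existing entries.
        self._types.update(types)

    @property
    def types(self) -> Tuple[Type[Frame], ...]:
        return tuple(self._types)

    def matches(self, frame: Frame) -> bool:
        # isinstance() accepts a type or a tuple of types, not a set,
        # so the set is converted at the call site.
        return isinstance(frame, tuple(self._types))


frame_filter = FrameTypeFilter()
empty_match = frame_filter.matches(TextFrame())
frame_filter.set_filter((TextFrame,))
frame_filter.add_filter((EndFrame, TextFrame))
```

Storing a set makes incremental additions idempotent, at the cost of building a tuple on each check. An empty tuple makes `isinstance` return `False`, so an unconfigured filter matches nothing.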
@@ -12,7 +12,9 @@ management, and frame flow control mechanisms.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
@@ -779,8 +781,40 @@ class FrameProcessor(BaseObject):
|
||||
frame_cls: The class of the frame to be broadcast.
|
||||
**kwargs: Keyword arguments to be passed to the frame's constructor.
|
||||
"""
|
||||
await self.push_frame(frame_cls(**kwargs))
|
||||
await self.push_frame(frame_cls(**kwargs), FrameDirection.UPSTREAM)
|
||||
await self.push_frame(frame_cls(**deepcopy(kwargs)))
|
||||
await self.push_frame(frame_cls(**deepcopy(kwargs)), FrameDirection.UPSTREAM)
|
||||
|
||||
    async def broadcast_frame_instance(self, frame: Frame):
        """Broadcasts a frame instance upstream and downstream.

        This method creates two new frame instances copying all fields from the
        original frame except `id` and `name`, which get fresh values.

        Args:
            frame: The frame instance to broadcast.

        Note:
            Prefer using `broadcast_frame()` when possible, as it is more
            efficient. This method should only be used when you are not the
            creator of the frame and need to broadcast an existing instance.
        """
        frame_cls = type(frame)
        init_fields = {f.name: getattr(frame, f.name) for f in dataclasses.fields(frame) if f.init}
        extra_fields = {
            f.name: getattr(frame, f.name)
            for f in dataclasses.fields(frame)
            if not f.init and f.name not in ("id", "name")
        }

        new_frame = frame_cls(**deepcopy(init_fields))
        for k, v in deepcopy(extra_fields).items():
            setattr(new_frame, k, v)
        await self.push_frame(new_frame)

        new_frame = frame_cls(**deepcopy(init_fields))
        for k, v in deepcopy(extra_fields).items():
            setattr(new_frame, k, v)
        await self.push_frame(new_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
async def __start(self, frame: StartFrame):
|
||||
"""Handle the start frame to initialize processor state.
|
||||
|
||||
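The copy logic in `broadcast_frame_instance` splits dataclass fields into constructor arguments (`init=True`) and attributes set after construction (`init=False`), then skips the identity fields so each copy gets fresh ones. The sketch below reproduces that idea on a toy dataclass. `DemoFrame`, `clone_with_fresh_id`, and the counter-based `id` are assumptions made for illustration, not the real frame classes.

```python
import dataclasses
import itertools
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List

_id_counter = itertools.count(1)


@dataclass
class DemoFrame:
    """Toy frame: two init fields plus non-init fields assigned automatically."""

    text: str = ""
    items: List[str] = field(default_factory=list)
    id: int = field(init=False, default_factory=lambda: next(_id_counter))
    pts: int = field(init=False, default=0)
    metadata: Dict[str, Any] = field(init=False, default_factory=dict)


def clone_with_fresh_id(frame: DemoFrame) -> DemoFrame:
    """Copy every field except `id`, which the new instance generates itself."""
    init_fields = {
        f.name: getattr(frame, f.name) for f in dataclasses.fields(frame) if f.init
    }
    extra_fields = {
        f.name: getattr(frame, f.name)
        for f in dataclasses.fields(frame)
        if not f.init and f.name != "id"
    }
    # Deep copies keep the clone from sharing mutable state with the original.
    clone = type(frame)(**deepcopy(init_fields))
    for name, value in deepcopy(extra_fields).items():
        setattr(clone, name, value)
    return clone


original = DemoFrame(text="hello", items=["a", "b"])
original.pts = 12345
original.metadata = {"key": "value", "nested": {"a": 1}}
clone = clone_with_fresh_id(original)
```

Deep-copying both dictionaries is what lets the downstream and upstream copies be mutated independently, which is the property the new tests check with `assertIsNot`.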
@@ -1100,13 +1100,11 @@ class RTVIObserver(BaseObserver):
|
||||
|
||||
if (
|
||||
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
|
||||
and (direction == FrameDirection.DOWNSTREAM)
|
||||
and self._params.user_speaking_enabled
|
||||
):
|
||||
await self._handle_interruptions(frame)
|
||||
elif (
|
||||
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
|
||||
and (direction == FrameDirection.UPSTREAM)
|
||||
and self._params.bot_speaking_enabled
|
||||
):
|
||||
await self._handle_bot_speaking(frame)
|
||||
@@ -1413,6 +1411,18 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
self._registered_services[service.name] = service
|
||||
|
||||
def create_rtvi_observer(self, *, params: Optional[RTVIObserverParams] = None, **kwargs):
|
||||
"""Creates a new RTVI Observer.
|
||||
|
||||
Args:
|
||||
params: Settings to enable/disable specific messages.
|
||||
**kwargs: Additional arguments passed to the observer.
|
||||
|
||||
Returns:
|
||||
A new RTVI observer.
|
||||
"""
|
||||
return RTVIObserver(self, params=params, **kwargs)
|
||||
|
||||
async def set_client_ready(self):
|
||||
"""Mark the client as ready and trigger the ready event."""
|
||||
self._client_ready = True
|
||||
|
||||
@@ -126,7 +126,7 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
if "pts" in args_dict:
|
||||
del args_dict["pts"]
|
||||
|
||||
# Special handling for MessageFrame -> OutputTransportMessageUrgentFrame
|
||||
# Special handling for MessageFrame -> InputTransportMessageFrame
|
||||
if class_name == MessageFrame:
|
||||
try:
|
||||
msg = json.loads(args_dict["data"])
|
||||
|
||||
@@ -1674,7 +1674,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
# start a timeout task to flush it later
|
||||
if self._user_transcription_buffer:
|
||||
self._transcription_timeout_task = self.create_task(
|
||||
self._transcription_timeout_handler()
|
||||
await self._transcription_timeout_handler()
|
||||
)
|
||||
|
||||
async def _handle_msg_output_transcription(self, message: LiveServerMessage):
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Google RTVI integration models and observer implementation.
|
||||
"""Google RTVI processor and observer implementation.
|
||||
|
||||
This module provides integration with Google's services through the RTVI framework,
|
||||
including models for search responses and an observer for handling Google-specific
|
||||
@@ -16,7 +16,7 @@ from typing import List, Literal, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.observers.base_observer import FramePushed
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIObserverParams, RTVIProcessor
|
||||
from pipecat.services.google.frames import LLMSearchOrigin, LLMSearchResponseFrame
|
||||
|
||||
|
||||
@@ -86,4 +86,23 @@ class GoogleRTVIObserver(RTVIObserver):
|
||||
rendered_content=frame.rendered_content,
|
||||
)
|
||||
)
|
||||
await self.push_transport_message_urgent(message)
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
|
||||
class GoogleRTVIProcessor(RTVIProcessor):
|
||||
"""RTVI processor for Google service integration.
|
||||
|
||||
Creates a specific Google RTVI Observer.
|
||||
"""
|
||||
|
||||
def create_rtvi_observer(self, *, params: Optional[RTVIObserverParams] = None, **kwargs):
|
||||
"""Creates a new RTVI Observer.
|
||||
|
||||
Args:
|
||||
params: Settings to enable/disable specific messages.
|
||||
**kwargs: Additional arguments passed to the observer.
|
||||
|
||||
Returns:
|
||||
A new RTVI observer.
|
||||
"""
|
||||
return GoogleRTVIObserver(self)
|
||||
|
||||
@@ -40,6 +40,7 @@ from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
|
||||
try:
|
||||
from google.api_core.client_options import ClientOptions
|
||||
from google.auth import default
|
||||
from google.auth.exceptions import GoogleAuthError
|
||||
from google.cloud import texttospeech_v1
|
||||
@@ -515,6 +516,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
*,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
voice_id: str = "en-US-Chirp3-HD-Charon",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
@@ -525,6 +527,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
Args:
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to Google Cloud service account JSON file.
|
||||
location: Google Cloud location for regional endpoint (e.g., "us-central1").
|
||||
voice_id: Google TTS voice identifier (e.g., "en-US-Standard-A").
|
||||
sample_rate: Audio sample rate in Hz. If None, uses default.
|
||||
params: Voice customization parameters including pitch, rate, volume, etc.
|
||||
@@ -534,6 +537,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
|
||||
params = params or GoogleHttpTTSService.InputParams()
|
||||
|
||||
self._location = location
|
||||
self._settings = {
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
@@ -586,7 +590,15 @@ class GoogleHttpTTSService(TTSService):
|
||||
if not creds:
|
||||
raise ValueError("No valid credentials provided.")
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
|
||||
client_options = None
|
||||
if self._location:
|
||||
client_options = ClientOptions(
|
||||
api_endpoint=f"{self._location}-texttospeech.googleapis.com"
|
||||
)
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(
|
||||
credentials=creds, client_options=client_options
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -783,7 +795,15 @@ class GoogleBaseTTSService(TTSService):
|
||||
if not creds:
|
||||
raise ValueError("No valid credentials provided.")
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
|
||||
client_options = None
|
||||
if self._location:
|
||||
client_options = ClientOptions(
|
||||
api_endpoint=f"{self._location}-texttospeech.googleapis.com"
|
||||
)
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(
|
||||
credentials=creds, client_options=client_options
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -903,6 +923,7 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
*,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
voice_id: str = "en-US-Chirp3-HD-Charon",
|
||||
voice_cloning_key: Optional[str] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
@@ -914,6 +935,7 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
Args:
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to Google Cloud service account JSON file.
|
||||
location: Google Cloud location for regional endpoint (e.g., "us-central1").
|
||||
voice_id: Google TTS voice identifier (e.g., "en-US-Chirp3-HD-Charon").
|
||||
voice_cloning_key: The voice cloning key for Chirp 3 custom voices.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses default.
|
||||
@@ -924,6 +946,7 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
|
||||
params = params or GoogleTTSService.InputParams()
|
||||
|
||||
self._location = location
|
||||
self._settings = {
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
@@ -1083,6 +1106,7 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
model: str = "gemini-2.5-flash-tts",
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
voice_id: str = "Kore",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
@@ -1101,6 +1125,7 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
"gemini-2.5-flash-tts" or "gemini-2.5-pro-tts".
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to Google Cloud service account JSON file.
|
||||
location: Google Cloud location for regional endpoint (e.g., "us-central1").
|
||||
voice_id: Voice name from the available Gemini voices.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses Google's default 24kHz.
|
||||
params: TTS configuration parameters.
|
||||
@@ -1127,6 +1152,7 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
if voice_id not in self.AVAILABLE_VOICES:
|
||||
logger.warning(f"Voice '{voice_id}' not in known voices list. Using anyway.")
|
||||
|
||||
self._location = location
|
||||
self._model = model
|
||||
self._voice_id = voice_id
|
||||
self._settings = {
|
||||
|
||||
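The `location` parameter added above turns into a regional API endpoint of the form `{location}-texttospeech.googleapis.com`, and no client options are passed when it is unset. A small sketch of that selection rule, without the Google client libraries, is shown here. `tts_api_endpoint` is a hypothetical helper, and the global default host is stated as an assumption about what the client falls back to.

```python
from typing import Optional

# Assumed global endpoint used when no regional location is requested.
DEFAULT_TTS_ENDPOINT = "texttospeech.googleapis.com"


def tts_api_endpoint(location: Optional[str] = None) -> str:
    """Return the regional endpoint for `location`, or the global default."""
    if location:
        return f"{location}-{DEFAULT_TTS_ENDPOINT}"
    return DEFAULT_TTS_ENDPOINT


regional = tts_api_endpoint("us-central1")
fallback = tts_api_endpoint()
```

In the diff itself the resulting string is wrapped in `ClientOptions(api_endpoint=...)` and handed to `TextToSpeechAsyncClient`; the helper only isolates the string-building step.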
@@ -72,7 +72,7 @@ class InworldHttpTTSService(WordTTSService):
|
||||
api_key: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
voice_id: str = "Ashley",
|
||||
model: str = "inworld-tts-1",
|
||||
model: str = "inworld-tts-1.5-max",
|
||||
streaming: bool = True,
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "LINEAR16",
|
||||
@@ -427,7 +427,7 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str = "Ashley",
|
||||
model: str = "inworld-tts-1",
|
||||
model: str = "inworld-tts-1.5-max",
|
||||
url: str = "wss://api.inworld.ai/tts/v1/voice:streamBidirectional",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "LINEAR16",
|
||||
|
||||
@@ -152,6 +152,8 @@ class STTService(AIService):
|
||||
self._settings[key] = value
|
||||
if key == "language":
|
||||
await self.set_language(value)
|
||||
elif key == "language":
|
||||
await self.set_language(value)
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
else:
|
||||
|
||||
@@ -123,9 +123,10 @@ class QueuedFrameProcessor(FrameProcessor):
|
||||
async def run_test(
|
||||
processor: FrameProcessor,
|
||||
*,
|
||||
frames_to_send: Sequence[Frame],
|
||||
enable_rtvi: bool = False,
|
||||
expected_down_frames: Optional[Sequence[type]] = None,
|
||||
expected_up_frames: Optional[Sequence[type]] = None,
|
||||
frames_to_send: Sequence[Frame],
|
||||
ignore_start: bool = True,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
pipeline_params: Optional[PipelineParams] = None,
|
||||
@@ -139,9 +140,10 @@ async def run_test(
|
||||
|
||||
Args:
|
||||
processor: The frame processor to test.
|
||||
frames_to_send: Sequence of frames to send through the processor.
|
||||
enable_rtvi: Whether RTVI should be enabled in this test.
|
||||
expected_down_frames: Expected frame types flowing downstream (optional).
|
||||
expected_up_frames: Expected frame types flowing upstream (optional).
|
||||
frames_to_send: Sequence of frames to send through the processor.
|
||||
ignore_start: Whether to ignore StartFrames in frame validation.
|
||||
observers: Optional list of observers to attach to the pipeline.
|
||||
pipeline_params: Optional pipeline parameters.
|
||||
@@ -173,9 +175,10 @@ async def run_test(
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=pipeline_params,
|
||||
observers=observers,
|
||||
cancel_on_idle_timeout=False,
|
||||
enable_rtvi=enable_rtvi,
|
||||
observers=observers,
|
||||
params=pipeline_params,
|
||||
)
|
||||
|
||||
async def push_frames():
|
||||
|
||||
@@ -98,6 +98,7 @@ class TransportParams(BaseModel):
|
||||
video_out_bitrate: Video output bitrate in bits per second.
|
||||
video_out_framerate: Video output frame rate in FPS.
|
||||
video_out_color_format: Video output color format string.
|
||||
video_out_codec: Preferred video codec for output (e.g., 'VP8', 'H264', 'H265').
|
||||
video_out_destinations: List of video output destination identifiers.
|
||||
vad_enabled: Enable Voice Activity Detection (deprecated).
|
||||
|
||||
@@ -151,6 +152,7 @@ class TransportParams(BaseModel):
|
||||
video_out_bitrate: int = 800000
|
||||
video_out_framerate: int = 30
|
||||
video_out_color_format: str = "RGB"
|
||||
video_out_codec: Optional[str] = None
|
||||
video_out_destinations: List[str] = Field(default_factory=list)
|
||||
vad_enabled: bool = False
|
||||
vad_audio_passthrough: bool = False
|
||||
|
||||
@@ -811,6 +811,11 @@ class DailyTransportClient(EventHandler):
|
||||
"camera": {
|
||||
"sendSettings": {
|
||||
"maxQuality": "low",
|
||||
**(
|
||||
{"preferredCodec": self._params.video_out_codec}
|
||||
if self._params.video_out_codec
|
||||
else {}
|
||||
),
|
||||
"encodings": {
|
||||
"low": {
|
||||
"maxBitrate": self._params.video_out_bitrate,
|
||||
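The `preferredCodec` entry above is added with a conditional dictionary unpacking, so the key is absent (rather than `None`) when no codec is configured. A reduced sketch of the idiom follows. `camera_send_settings` is a hypothetical function and the dictionary is trimmed to the keys visible in the hunk.

```python
from typing import Any, Dict, Optional


def camera_send_settings(bitrate: int, codec: Optional[str] = None) -> Dict[str, Any]:
    """Build send settings, including `preferredCodec` only when a codec is given."""
    return {
        "maxQuality": "low",
        # Unpacks to one key when `codec` is set, and to nothing otherwise.
        **({"preferredCodec": codec} if codec else {}),
        "encodings": {"low": {"maxBitrate": bitrate}},
    }


without_codec = camera_send_settings(800000)
with_codec = camera_send_settings(800000, "H264")
```

Omitting the key leaves codec selection to whatever default the receiving side applies, which is presumably why the diff avoids sending an explicit null.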
@@ -1728,8 +1733,9 @@ class DailyInputTransport(BaseInputTransport):
|
||||
message: The message data to send.
|
||||
sender: ID of the message sender.
|
||||
"""
|
||||
frame = DailyInputTransportMessageFrame(message=message, participant_id=sender)
|
||||
await self.push_frame(frame)
|
||||
await self.broadcast_frame_class(
|
||||
DailyInputTransportMessageFrame, message=message, participant_id=sender
|
||||
)
|
||||
|
||||
#
|
||||
# Audio in
|
||||
|
||||
@@ -698,8 +698,7 @@ class SmallWebRTCInputTransport(BaseInputTransport):
|
||||
message: The application message to process.
|
||||
"""
|
||||
logger.debug(f"Received app message inside SmallWebRTCInputTransport {message}")
|
||||
frame = InputTransportMessageFrame(message=message)
|
||||
await self.push_frame(frame)
|
||||
await self.broadcast_frame_class(InputTransportMessageFrame, message=message)
|
||||
|
||||
# Add this method similar to DailyInputTransport.request_participant_image
|
||||
async def request_participant_image(self, frame: UserImageRequestFrame):
|
||||
|
||||
@@ -27,6 +27,7 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputTransportMessageFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputTransportMessageFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
@@ -298,6 +299,8 @@ class WebsocketClientInputTransport(BaseInputTransport):
|
||||
return
|
||||
if isinstance(frame, InputAudioRawFrame) and self._params.audio_in_enabled:
|
||||
await self.push_audio_frame(frame)
|
||||
elif isinstance(frame, InputTransportMessageFrame):
|
||||
await self.broadcast_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputTransportMessageFrame,
|
||||
InterruptionFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputTransportMessageFrame,
|
||||
@@ -311,6 +312,8 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
|
||||
if isinstance(frame, InputAudioRawFrame):
|
||||
await self.push_audio_frame(frame)
|
||||
elif isinstance(frame, InputTransportMessageFrame):
|
||||
await self.broadcast_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
|
||||
@@ -25,6 +25,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputTransportMessageFrame,
|
||||
InputTransportMessageUrgentFrame,
|
||||
InterruptionFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputTransportMessageFrame,
|
||||
@@ -214,6 +216,8 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
|
||||
if isinstance(frame, InputAudioRawFrame):
|
||||
await self.push_audio_frame(frame)
|
||||
elif isinstance(frame, InputTransportMessageFrame):
|
||||
await self.broadcast_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
|
||||
@@ -51,6 +51,5 @@ class MuteUntilFirstBotCompleteUserMuteStrategy(BaseUserMuteStrategy):
|
||||
return not self._first_speech_handled
|
||||
|
||||
async def _handle_bot_stopped_speaking(self, frame: BotStoppedSpeakingFrame):
|
||||
self._bot_speaking = False
|
||||
if not self._first_speech_handled:
|
||||
self._first_speech_handled = True
|
||||
|
||||
53
src/pipecat/utils/env.py
Normal file
@@ -0,0 +1,53 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Environment variable helpers.

This module provides small, centralized parsing helpers for environment variables.
"""

from __future__ import annotations

import os


class InvalidEnvVarValueError(ValueError):
    """Raised when an environment variable value cannot be parsed."""

    def __init__(self, name: str, value: str, expected: str):
        """Initialize an InvalidEnvVarValueError."""
        super().__init__(f"Invalid value for env var {name!r}: {value!r}. Expected {expected}.")
        self.name = name
        self.value = value
        self.expected = expected


def env_truthy(name: str, default: bool = False) -> bool:
    """Interpret an environment variable as a boolean.

    - If the variable is **not set**, returns `default`.
    - If the variable is set to a recognized boolean string, returns the parsed value.
    - Otherwise, raises `InvalidEnvVarValueError`.

    Recognized values (case-insensitive, whitespace ignored):
    - Truthy: "1", "true", "yes", "y", "on"
    - Falsy: "0", "false", "no", "n", "off", ""
    """
    raw = os.getenv(name)
    if raw is None:
        return default

    val = raw.strip().lower()
    if val in {"1", "true", "yes", "y", "on"}:
        return True
    if val in {"0", "false", "no", "n", "off", ""}:
        return False

    raise InvalidEnvVarValueError(
        name=name,
        value=raw,
        expected="true or false",
    )
|
||||
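To show how the three documented cases of `env_truthy` behave, here is a compact restatement of the helper so the snippet runs on its own, followed by a few calls. `DEMO_FLAG` is a hypothetical variable name, and the sketch raises a plain `ValueError` where the real module raises its `InvalidEnvVarValueError` subclass.

```python
import os

_TRUTHY = {"1", "true", "yes", "y", "on"}
_FALSY = {"0", "false", "no", "n", "off", ""}


def env_truthy(name: str, default: bool = False) -> bool:
    """Compact copy of the helper from the diff, for demonstration only."""
    raw = os.getenv(name)
    if raw is None:
        return default
    val = raw.strip().lower()
    if val in _TRUTHY:
        return True
    if val in _FALSY:
        return False
    raise ValueError(f"Invalid value for env var {name!r}: {raw!r}. Expected true or false.")


# Unset variable: the default is returned.
os.environ.pop("DEMO_FLAG", None)
unset_result = env_truthy("DEMO_FLAG", default=True)

# Case and surrounding whitespace are ignored.
os.environ["DEMO_FLAG"] = "  Yes "
yes_result = env_truthy("DEMO_FLAG")

# A variable set to the empty string is falsy, even with default=True.
os.environ["DEMO_FLAG"] = ""
empty_result = env_truthy("DEMO_FLAG", default=True)

# Anything unrecognized is rejected instead of being silently coerced.
os.environ["DEMO_FLAG"] = "maybe"
try:
    env_truthy("DEMO_FLAG")
    error_message = ""
except ValueError as exc:
    error_message = str(exc)
```

The distinction between "unset" and "set to empty" is the part most likely to surprise callers: only the former falls back to `default`.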
@@ -6,7 +6,8 @@
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
DataFrame,
|
||||
@@ -24,6 +25,15 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
@dataclass
|
||||
class BroadcastTestFrame(DataFrame):
|
||||
"""Test frame with init fields for broadcast testing."""
|
||||
|
||||
text: str = ""
|
||||
value: int = 0
|
||||
items: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_before_after_events(self):
|
||||
identity = IdentityFilter()
|
||||
@@ -186,3 +196,157 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_broadcast_frame(self):
        """Test that broadcast_frame creates two separate frames with fresh IDs."""
        downstream_frames: List[Frame] = []
        upstream_frames: List[Frame] = []

        class BroadcastTestProcessor(FrameProcessor):
            async def process_frame(self, frame: Frame, direction: FrameDirection):
                await super().process_frame(frame, direction)
                if isinstance(frame, TextFrame):
                    await self.broadcast_frame(
                        BroadcastTestFrame, text="hello", value=42, items=["a", "b"]
                    )
                else:
                    await self.push_frame(frame, direction)

        class CaptureProcessor(FrameProcessor):
            def __init__(self, capture_list: List[Frame], direction: FrameDirection):
                super().__init__()
                self._capture_list = capture_list
                self._capture_direction = direction

            async def process_frame(self, frame: Frame, direction: FrameDirection):
                await super().process_frame(frame, direction)
                if direction == self._capture_direction and isinstance(frame, BroadcastTestFrame):
                    self._capture_list.append(frame)
                await self.push_frame(frame, direction)

        up_capture = CaptureProcessor(upstream_frames, FrameDirection.UPSTREAM)
        broadcaster = BroadcastTestProcessor()
        down_capture = CaptureProcessor(downstream_frames, FrameDirection.DOWNSTREAM)

        pipeline = Pipeline([up_capture, broadcaster, down_capture])

        frames_to_send = [TextFrame(text="trigger")]
        expected_down_frames = [BroadcastTestFrame]
        expected_up_frames = [BroadcastTestFrame]

        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
            expected_up_frames=expected_up_frames,
        )

        # Verify we got one frame in each direction
        self.assertEqual(len(downstream_frames), 1)
        self.assertEqual(len(upstream_frames), 1)

        down_frame = downstream_frames[0]
        up_frame = upstream_frames[0]

        # Verify the frames have different IDs (they are separate instances)
        self.assertNotEqual(down_frame.id, up_frame.id)

        # Verify the frames have the correct field values
        self.assertEqual(down_frame.text, "hello")
        self.assertEqual(down_frame.value, 42)
        self.assertEqual(down_frame.items, ["a", "b"])
        self.assertEqual(up_frame.text, "hello")
        self.assertEqual(up_frame.value, 42)
        self.assertEqual(up_frame.items, ["a", "b"])

        # Verify the items lists are separate instances (not shared references)
        self.assertIsNot(down_frame.items, up_frame.items)

    async def test_broadcast_frame_instance(self):
        """Test that broadcast_frame_instance copies all fields except id and name."""
        downstream_frames: List[Frame] = []
        upstream_frames: List[Frame] = []
        original_frame: List[Frame] = []

        class BroadcastInstanceTestProcessor(FrameProcessor):
            async def process_frame(self, frame: Frame, direction: FrameDirection):
                await super().process_frame(frame, direction)
                if isinstance(frame, BroadcastTestFrame):
                    # Set some non-init fields on the frame
                    frame.pts = 12345
                    frame.metadata = {"key": "value", "nested": {"a": 1}}
                    original_frame.append(frame)
                    await self.broadcast_frame_instance(frame)
                else:
                    await self.push_frame(frame, direction)

        class CaptureProcessor(FrameProcessor):
            def __init__(self, capture_list: List[Frame], direction: FrameDirection):
                super().__init__()
                self._capture_list = capture_list
                self._capture_direction = direction

            async def process_frame(self, frame: Frame, direction: FrameDirection):
                await super().process_frame(frame, direction)
                if direction == self._capture_direction and isinstance(frame, BroadcastTestFrame):
                    self._capture_list.append(frame)
                await self.push_frame(frame, direction)

        up_capture = CaptureProcessor(upstream_frames, FrameDirection.UPSTREAM)
        broadcaster = BroadcastInstanceTestProcessor()
        down_capture = CaptureProcessor(downstream_frames, FrameDirection.DOWNSTREAM)

        pipeline = Pipeline([up_capture, broadcaster, down_capture])

        # Create a frame with mutable fields to test deep copying
        test_frame = BroadcastTestFrame(text="test", value=99, items=["x", "y", "z"])

        frames_to_send = [test_frame]
        expected_down_frames = [BroadcastTestFrame]
        expected_up_frames = [BroadcastTestFrame]

        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
            expected_up_frames=expected_up_frames,
        )

        # Verify we got one frame in each direction
        self.assertEqual(len(downstream_frames), 1)
        self.assertEqual(len(upstream_frames), 1)
        self.assertEqual(len(original_frame), 1)

        orig = original_frame[0]
        down_frame = downstream_frames[0]
        up_frame = upstream_frames[0]

        # Verify the frames have different IDs and names (fresh values)
        self.assertNotEqual(down_frame.id, orig.id)
        self.assertNotEqual(up_frame.id, orig.id)
        self.assertNotEqual(down_frame.id, up_frame.id)
        self.assertNotEqual(down_frame.name, orig.name)
        self.assertNotEqual(up_frame.name, orig.name)

        # Verify init fields are copied correctly
        self.assertEqual(down_frame.text, "test")
        self.assertEqual(down_frame.value, 99)
        self.assertEqual(down_frame.items, ["x", "y", "z"])
        self.assertEqual(up_frame.text, "test")
        self.assertEqual(up_frame.value, 99)
        self.assertEqual(up_frame.items, ["x", "y", "z"])

        # Verify non-init fields (except id/name) are copied
        self.assertEqual(down_frame.pts, 12345)
        self.assertEqual(down_frame.metadata, {"key": "value", "nested": {"a": 1}})
        self.assertEqual(up_frame.pts, 12345)
        self.assertEqual(up_frame.metadata, {"key": "value", "nested": {"a": 1}})

        # Verify mutable fields are deep copied (not shared references)
        self.assertIsNot(down_frame.items, orig.items)
        self.assertIsNot(up_frame.items, orig.items)
        self.assertIsNot(down_frame.items, up_frame.items)
        self.assertIsNot(down_frame.metadata, orig.metadata)
        self.assertIsNot(up_frame.metadata, orig.metadata)
        self.assertIsNot(down_frame.metadata, up_frame.metadata)
        self.assertIsNot(down_frame.metadata["nested"], up_frame.metadata["nested"])
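
The copy semantics asserted in `test_broadcast_frame_instance` (all fields deep copied, `id` and `name` regenerated) can be sketched without the library. This is a minimal illustration under stated assumptions, not the actual implementation of `broadcast_frame_instance`: `SketchFrame`, `copy_with_fresh_identity`, and the counter-based id are hypothetical names invented for this example.

```python
import copy
import itertools
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List

_ids = itertools.count(1)  # hypothetical id source used only by this sketch


@dataclass
class SketchFrame:
    """Hypothetical stand-in for a frame; not the real Frame class."""

    text: str
    value: int
    items: List[str]
    # Non-init fields: assigned per instance rather than passed to __init__.
    id: int = field(init=False)
    name: str = field(init=False)
    pts: int = field(init=False, default=0)
    metadata: Dict[str, Any] = field(init=False, default_factory=dict)

    def __post_init__(self):
        self.id = next(_ids)
        self.name = f"{type(self).__name__}#{self.id}"


def copy_with_fresh_identity(frame: SketchFrame) -> SketchFrame:
    """Deep copy every field except id and name, which are regenerated."""
    init_kwargs = {
        f.name: copy.deepcopy(getattr(frame, f.name)) for f in fields(frame) if f.init
    }
    # Constructing a new instance runs __post_init__, giving a fresh id and name.
    new_frame = type(frame)(**init_kwargs)
    for f in fields(frame):
        if not f.init and f.name not in ("id", "name"):
            setattr(new_frame, f.name, copy.deepcopy(getattr(frame, f.name)))
    return new_frame


original = SketchFrame(text="test", value=99, items=["x", "y", "z"])
original.pts = 12345
original.metadata = {"key": "value", "nested": {"a": 1}}

# One independent copy per direction, mirroring what the test expects.
down_copy = copy_with_fresh_identity(original)
up_copy = copy_with_fresh_identity(original)
```

Because each copy is built from its own `copy.deepcopy` call, mutating `down_copy.metadata["nested"]` in this sketch would not affect `up_copy` or `original`, which is the property the `assertIsNot` checks above guard against losing.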