Compare commits

..

68 Commits

Author SHA1 Message Date
James Hush
77c82c64c0 Simple content filter demo 2025-12-04 09:52:41 +01:00
Aleix Conchillo Flaqué
b5e79f9dc5 Merge pull request #3181 from pipecat-ai/aleix/sync-to-utils-sync
move pipecat.sync to pipecat.utils.sync
2025-12-03 19:41:18 -08:00
Aleix Conchillo Flaqué
613b96819f Merge pull request #3180 from pipecat-ai/aleix/deepgram-tts-service-fix
DeepgramTTSService: fix websocket header logging
2025-12-03 19:40:43 -08:00
Mark Backman
57c24670ea Merge pull request #3132 from pipecat-ai/mb/normalize-llm-text-frame-output
Add split_text_by_spaces string util, normalize aggregator input
2025-12-03 22:05:14 -05:00
Mark Backman
d79dd94019 Make aggregate return an AsyncIterator, other clean up 2025-12-03 22:00:34 -05:00
Mark Backman
fa8e7458e1 Clean up 2025-12-03 22:00:04 -05:00
Mark Backman
4d66191963 fix: PatternPairAggregator to process patterns only once 2025-12-03 22:00:04 -05:00
Mark Backman
7e9d67002e SkipTagsAggregator and PatternPairAggregator now subclass SimpleTextAggregator 2025-12-03 22:00:04 -05:00
Mark Backman
ffbb6e5937 Update SimpleTextAggregator to handle character by character input, use a buffer to handle ambiguous EOS scenarios, and add a flush method to all aggregators 2025-12-03 22:00:02 -05:00
Mark Backman
535b85cf90 Add split_text_by_spaces string util 2025-12-03 21:55:30 -05:00
Aleix Conchillo Flaqué
8dc9872ed5 deprecate pipecat.sync package 2025-12-03 18:44:41 -08:00
Aleix Conchillo Flaqué
f37a53cc25 utils(sync): move sync to utils.sync 2025-12-03 18:20:12 -08:00
Aleix Conchillo Flaqué
9cce28c64c DeepgramTTSService: use websocket response headers for logging 2025-12-03 18:16:25 -08:00
Aleix Conchillo Flaqué
3ca94363ec Merge pull request #3168 from pipecat-ai/aleix/dont-override-skip-tts
LLMTextFrame: don't override skip_tts
2025-12-03 18:15:50 -08:00
Mark Backman
050f287ec4 Merge pull request #3072 from jjmaldonis/deepgram/add-deepgram-request-ids-to-debug-logs
deepgram: added request IDs to debug logs
2025-12-03 09:37:25 -05:00
Jason Maldonis
e6f5561785 updated changelog 2025-12-03 08:18:09 -06:00
Jason Maldonis
2df91f4b37 fixed linting 2025-12-03 08:09:16 -06:00
Jason Maldonis
7db49b9067 deepgram: added request IDs to debug logs
Deepgram request IDs are necessary for investigating behavior at the
request level. This commit adds DEBUG logs that print Deepgram request
IDs when using Deepgram's STT or TTS.
2025-12-03 08:09:13 -06:00
Vanessa Pyne
7c497bdc89 Merge pull request #3130 from pipecat-ai/vp-nvidia-docs
update nvidia services naming
2025-12-02 13:04:16 -06:00
vipyne
1aa4247d2b remove nim from pyproject.toml 2025-12-02 12:55:13 -06:00
vipyne
acba544e6f pr notes for nvidia service name change 2025-12-01 22:41:17 -06:00
vipyne
5d93c64ee5 typo fixes and uv.lock update 2025-12-01 22:41:17 -06:00
vipyne
de10bc8803 changelog for riva,nim -> nvidia name change 2025-12-01 22:41:17 -06:00
vipyne
36f5c1722d deprecate riva and nim service paths in favor of nvidia 2025-12-01 22:41:17 -06:00
vipyne
a8280522e5 examples: rename nvidia foundational examples 2025-12-01 22:41:17 -06:00
vipyne
05d65dfdd3 Update NVIDIA NIM and Riva services to Nvidia
- pip install pipecat-ai[nim]
- pip install pipecat-ai[riva]

+ pip install pipecat-ai[nvidia]

and

- from pipecat.services.nim.llm import NimLLMService
+ from pipecat.services.nvidia.llm import NvidiaLLMService

- from pipecat.services.riva.stt import RivaSTTService
+ from pipecat.services.nvidia.stt import NvidiaSTTService

- from pipecat.services.riva.tts import RivaTTSService
+ from pipecat.services.nvidia.tts import NvidiaTTSService
2025-12-01 22:41:17 -06:00
Aleix Conchillo Flaqué
a3962e3b47 LLMTextFrame: don't override skip_tts 2025-12-01 18:37:07 -08:00
Aleix Conchillo Flaqué
cd231cf829 Merge pull request #3120 from pipecat-ai/aleix/function-calls-wait-for-all
allow waiting for all function calls to complete
2025-12-01 18:35:53 -08:00
Aleix Conchillo Flaqué
9fafc1692d update uv.lock 2025-12-01 18:32:00 -08:00
Aleix Conchillo Flaqué
7648d0436c examples(19): linting 2025-12-01 18:30:34 -08:00
Aleix Conchillo Flaqué
bff8747e38 LLMService: allow waiting for all function calls to complete 2025-12-01 18:30:25 -08:00
Mark Backman
d227c0c097 Merge pull request #3155 from pipecat-ai/mb/fix-sarvam-tts-not-flushing
fix: flush audio in SarvamTTSService
2025-12-01 17:22:33 -05:00
Mark Backman
9ccde60521 fix: flush audio in SarvamTTSService 2025-12-01 17:18:34 -05:00
Mark Backman
b84a40666c Merge pull request #3156 from pipecat-ai/mb/deepgram-stt-stopped-frame
fix: DeepgramTTSService, let the base class push TTSStoppedFrame
2025-12-01 17:18:19 -05:00
Mark Backman
e72b135a4c fix: DeepgramTTSService, let the base class push TTSStoppedFrame 2025-12-01 17:15:51 -05:00
Aleix Conchillo Flaqué
2235d8f5a2 CHANGELOG formatting 2025-12-01 10:24:42 -08:00
Mark Backman
6e20a50a4b Merge pull request #3153 from pipecat-ai/mb/fix-aws-stt-region
fix: AWSTranscribeSTTService always set to us-east-1
2025-12-01 13:07:22 -05:00
Mark Backman
89d9ca045a fix: AWSTranscribeSTTService always set to us-east-1 2025-12-01 13:02:08 -05:00
Mark Backman
4b95ee92eb Merge pull request #3166 from pipecat-ai/mb/update-changelog-AWSBedrockAgentCoreProcessor
Retroactively add changelog to 0.0.96 for AWSBedrockAgentCoreProcessor
2025-12-01 11:51:47 -05:00
Mark Backman
d481ac6cc6 Retroactively add changelog to 0.0.96 for AWSBedrockAgentCoreProcessor 2025-12-01 11:49:00 -05:00
Mark Backman
e5a91296b5 Merge pull request #3162 from ai-coustics/add-stt-optimized-model
Add Quail STT as default model for `AICFilter`
2025-11-30 09:59:37 -05:00
Corvin Jaedicke
d8d10a0685 add changelog entry 2025-11-28 15:24:19 +01:00
Corvin Jaedicke
6dd9ed03b1 bump version to include new STT model, noise gate deprecation warning 2025-11-28 15:14:43 +01:00
Filipi da Silva Fuchter
d486c80804 Merge pull request #3151 from pipecat-ai/filipi/fix_runner_ice_servers
Fixing runner ICE servers to be compatible with what is expected by the mobile SDKs.
2025-11-27 10:24:02 -03:00
Filipi Fuchter
dedea7c420 Fixing runner ICE servers to be compatible with what is expected by the mobile SDKs. 2025-11-27 09:27:26 -03:00
Aleix Conchillo Flaqué
b78eb5de6b Merge pull request #3148 from pipecat-ai/aleix/pipecat-0.0.96-update
update CHANGELOG for 0.0.96 with proper date
2025-11-26 17:21:31 -08:00
Aleix Conchillo Flaqué
95aa13beb1 update CHANGELOG for 0.0.96 with proper date 2025-11-26 17:16:54 -08:00
Mark Backman
88ce85342c Merge pull request #3147 from pipecat-ai/mb/fix-sagemaker-error-handling
Fix error handling in DeepramSageMakerSTTService
2025-11-26 20:15:45 -05:00
Mark Backman
bedd40ae8b Fix error handling in DeepramSageMakerSTTService 2025-11-26 20:12:31 -05:00
Mark Backman
fda327b3ee Merge pull request #3146 from pipecat-ai/mb/fix-aws-bedrock-region
fix: AWSBedrockLLMService was always set to us-east-1
2025-11-26 19:56:09 -05:00
Mark Backman
ace95b6e6d fix: AWSBedrockLLMService was always set to us-east-1 2025-11-26 19:52:04 -05:00
Aleix Conchillo Flaqué
26c5c28c5c Merge pull request #3145 from pipecat-ai/aleix/simli-enable-logging-param
SimliVideoService: add enable_logging input parameter
2025-11-26 16:49:12 -08:00
Aleix Conchillo Flaqué
81f862749d SimliVideoService: add enable_logging input parameter 2025-11-26 16:36:06 -08:00
Aleix Conchillo Flaqué
b8bf7b4132 Merge pull request #3143 from pipecat-ai/aleix/pipecat-0.0.96
update CHANGELOG for 0.0.96
2025-11-26 16:31:44 -08:00
Aleix Conchillo Flaqué
d90121ef3b update CHANGELOG for 0.0.96 2025-11-26 15:30:06 -08:00
Filipi da Silva Fuchter
d0b7b4fb0a Merge pull request #3144 from pipecat-ai/filipi/fix_flux_reconnection_issue
Fixed an issue with DeepgramFluxSTTService where it sometimes failed to reconnect.
2025-11-26 20:29:41 -03:00
Filipi Fuchter
4acc317923 Fixed an issue with DeepgramFluxSTTService where it sometimes failed to reconnect. 2025-11-26 20:23:03 -03:00
Filipi da Silva Fuchter
7caf5751ee Merge pull request #3084 from pipecat-ai/filipi/improve_error_handler
Improving error handler.
2025-11-26 18:40:44 -03:00
Filipi Fuchter
1330ef3ad6 Enhanced error handling across the framework.
Co-authored-by: Mark Backman <m.backman@gmail.com>
2025-11-26 18:34:25 -03:00
Mark Backman
9efb21d61e Merge pull request #3115 from pipecat-ai/mb/deepgram-websocket-tts
Update DeepgramTTSService to use Deepgram's Websocket TTS API
2025-11-26 13:30:52 -05:00
Mark Backman
6d93b8e9d8 Update DeepgramTTSService to use Deepgram's Websocket TTS API 2025-11-26 13:25:34 -05:00
Aleix Conchillo Flaqué
6f527e509e update CHANGELOG with FishAudioTTSService s1 model update 2025-11-26 10:22:59 -08:00
Aleix Conchillo Flaqué
6cf1d0417e Merge pull request #3136 from kcui5/patch-1
Update Fish Audio default model to s1
2025-11-26 10:19:26 -08:00
Mark Backman
19d8b0dfc2 Merge pull request #3011 from thsunkid/feat/add-cached-reasoning-tokens-metrics-to-opentel-spans 2025-11-26 07:45:33 -05:00
Kyle Cui
7fa0cbf2a9 Update Fish Audio default model to s1
Update default model from speech-1.5 to s1 for Fish Audio TTS service
2025-11-26 01:50:38 -08:00
Thu Nguyen
36c4bc2df2 Update changelog 2025-11-26 13:01:48 +07:00
Thu Nguyen
42be0183af Merge branch 'main' into feat/add-cached-reasoning-tokens-metrics-to-opentel-spans 2025-11-26 12:59:43 +07:00
Thu Nguyen
35593b8574 Add cached and reasoning token metrics to OpenTelemetry spans 2025-11-09 00:38:30 +07:00
123 changed files with 2420 additions and 2400 deletions

View File

@@ -9,6 +9,103 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `wait_for_all` argument to the base `LLMService`. When enabled, this
ensures all function calls complete before returning results to the LLM (i.e.,
before running a new inference with those results).
### Changed
- Improved interruption handling to prevent bots from repeating themselves.
LLM services that return multiple sentences in a single response (e.g.,
`GoogleLLMService`) are now split into individual sentences before being sent
to TTS. This ensures interruptions occur at sentence boundaries, preventing
the bot from repeating content after being interrupted during long responses.
- Text Aggregation Improvements:
- **Breaking Change**: `BaseTextAggregator.aggregate()` now returns
`AsyncIterator[Aggregation]` instead of `Optional[Aggregation]`. This
enables the aggregator to return multiple results based on the provided
text.
- Refactored text aggregators to use inheritance: `SkipTagsAggregator` and
`PatternPairAggregator` now inherit from `SimpleTextAggregator`, reusing
the base class's sentence detection logic.
- Updated `AICFilter` to use Quail STT as the default model
(`AICModelType.QUAIL_STT`). Quail STT is optimized for human-to-machine
interaction (e.g., voice agents, speech-to-text) and operates at a native
sample rate of 16 kHz with fixed enhancement parameters.
- Updated Deepgram logging to include Deepgram request IDs for improved debugging.
### Deprecated
- Package `pipecat.sync` is deprecated, use `pipecat.utils.sync` instead.
- The `noise_gate_enable` parameter in `AICFilter` is deprecated and no longer
has any effect. Noise gating is now handled automatically by the AIC VAD
system. Use `AICFilter.create_vad_analyzer()` for VAD functionality instead.
- NVIDIA Services name changes (all functionality is unchanged):
- `NimLLMService` is now deprecated, use `NvidiaLLMService` instead.
- `RivaSTTService` is now deprecated, use `NvidiaSTTService` instead.
- `RivaTTSService` is now deprecated, use `NvidiaTTSService` instead.
- Use `uv pip install pipecat-ai[nvidia]` instead of
`uv pip install pipecat-ai[riva]`
### Fixed
- Fixed an issue where `LLMTextFrame.skip_tts` was being overwritten by LLM
services.
- Fixed sentence aggregation to correctly handle ambiguous punctuation in
streaming text, such as currency ("$29.95") and abbreviations ("Mr. Smith").
- Fixed bug in `PatternPairAggregator` where pattern handlers could be called
multiple times for `KEEP` or `AGGREGATE` patterns.
- Fixed an issue in `SarvamTTSService` where the last sentence was not being
spoken. Now, audio is flushed when the TTS services receives the
`LLMFullResponseEndFrame` or `EndFrame`.
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was
always set to `us-east-1` when providing an AWS_REGION env var.
- Fixed an issue in `DeepgramTTSService` where a `TTSStoppedFrame` was
incorrectly pushed after a functional call. This caused an issue with the
voice-ui-kit's conversational panel rending of the LLM output after a
function call.
## [0.0.96] - 2025-11-26 🦃 "Happy Thanksgiving!" 🦃
### Added
- Added `AWSBedrockAgentCoreProcessor` to support invoking an AgentCore-hosted
agent in a Pipecat pipeline.
- Enhanced error handling across the framework:
- Added `on_error` callback to `FrameProcessor` for centralized error
handling.
- Renamed `push_error(error: ErrorFrame)` to `push_error_frame(error: ErrorFrame)`
for clarity.
- Added new `push_error` method for simplified error reporting:
```python
async def push_error(error_msg: str,
exception: Optional[Exception] = None,
fatal: bool = False)
```
- Standardized error logging by replacing `logger.exception` calls with
`logger.error` throughout the codebase.
- Added `cache_read_input_tokens`, `cache_creation_input_tokens` and
`reasoning_tokens` to OTel spans for LLM call
- Added `LiveKitRESTHelper` utility class for managing LiveKit rooms via REST API.
- Added `DeepgramSageMakerSTTService` which connects to a SageMaker hosted
@@ -88,8 +185,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new emotions: calm and fluent
- Added `enable_logging` to `SimliVideoService` input parameters. It's disabled
by default.
### Changed
- Updated `FishAudioTTSService` default model to `s1`.
- Updated `DeepgramTTSService` to use Deepgram's TTS websocket API. ⚠️ This is
a potential breaking change, which only affects you if you're self-hosting
`DeepgramTTSService`. The new service uses Websockets and improves TTFB
latency.
- Updated `daily-python` to 0.22.0.
- `BaseTextAggregator` changes:
@@ -247,6 +354,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed an issue in `AWSBedrockLLMService` where the `aws_region` arg was
always set to `us-east-1` when providing an AWS_REGION env var.
- Fixed an issue with `DeepgramFluxSTTService` where it sometimes failed to reconnect.
- Fixed an issue in `ElevenLabsRealtimeSTTService` where dynamic language
updates were not working.

View File

@@ -79,7 +79,7 @@ Once your PR is submitted, post in the `#community-integrations` Discord channel
**Examples:**
- [RivaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/riva/stt.py)
- [NvidiaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/nvidia/stt.py)
- [FalSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/fal/stt.py)
#### Key requirements:

View File

@@ -1,103 +0,0 @@
# TurnAwareTranscriptProcessor Example
## Overview
The `TurnAwareTranscriptProcessor` combines user and assistant transcript tracking with turn boundary detection. It correctly handles interruptions by only capturing what was actually spoken.
## Basic Usage
```python
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
# Create the processor
turn_processor = TurnAwareTranscriptProcessor()
# Register event handlers
@turn_processor.event_handler("on_turn_started")
async def handle_turn_started(processor, turn_number):
print(f"Turn {turn_number} started")
@turn_processor.event_handler("on_turn_ended")
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, was_interrupted):
print(f"\nTurn {turn_number} ended:")
print(f" User said: {user_text}")
print(f" Assistant said: {assistant_text}")
print(f" Was interrupted: {was_interrupted}")
@turn_processor.event_handler("on_transcript_update")
async def handle_transcript_update(processor, frame):
for msg in frame.messages:
print(f"[{msg.role}]: {msg.content}")
# Add to pipeline
pipeline = Pipeline([
transport.input(),
stt,
turn_processor, # Process transcripts and track turns
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
])
```
## Features
1. **Turn Boundary Detection**: Automatically detects when turns start and end based on user and bot speaking patterns
2. **Interruption Handling**: Correctly captures only what was actually spoken when interruptions occur
3. **Real-time Transcripts**: Emits transcript messages for both user and assistant speech
4. **Turn Events**: Provides start/end events with accumulated transcripts for each turn
## Events
### on_turn_started
Emitted when a new turn begins (user starts speaking).
**Handler signature**: `async def handler(processor, turn_number)`
### on_turn_ended
Emitted when a turn ends with accumulated transcripts.
**Handler signature**: `async def handler(processor, turn_number, user_transcript, assistant_transcript, was_interrupted)`
### on_transcript_update
Inherited from `BaseTranscriptProcessor`, emitted for individual transcript messages.
**Handler signature**: `async def handler(processor, frame)`
## Turn Logic
- Turns start when the user begins speaking (`UserStartedSpeakingFrame`)
- Turns end when:
- The user starts speaking again (previous turn ends, new turn starts)
- The bot is interrupted (`InterruptionFrame`)
- The pipeline ends (`EndFrame`/`CancelFrame`)
## Integration with OpenTelemetry
You can use turn events to enrich OpenTelemetry spans:
```python
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
turn_tracker = TurnTrackingObserver()
turn_tracer = TurnTraceObserver(turn_tracker)
turn_processor = TurnAwareTranscriptProcessor()
@turn_processor.event_handler("on_turn_ended")
async def add_transcripts_to_span(processor, turn_number, user_text, assistant_text, interrupted):
# Get current span and add transcript data
from opentelemetry import trace
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("turn.user_text", user_text)
current_span.set_attribute("turn.assistant_text", assistant_text)
```
## Notes
- The processor handles async frame processing correctly by delaying turn end until frames are processed
- Works with word-level timestamps from TTS services like Cartesia
- Accumulates both user (`TranscriptionFrame`) and assistant (`TTSTextFrame`) speech
- Emits individual transcript messages in addition to turn-level aggregation

View File

@@ -119,7 +119,6 @@ def import_core_modules():
"pipecat.observers",
"pipecat.runner",
"pipecat.serializers",
"pipecat.sync",
"pipecat.transcriptions",
"pipecat.utils",
]

View File

@@ -30,7 +30,6 @@ Quick Links
Runner <api/pipecat.runner>
Serializers <api/pipecat.serializers>
Services <api/pipecat.services>
Sync <api/pipecat.sync>
Transcriptions <api/pipecat.transcriptions>
Transports <api/pipecat.transports>
Utils <api/pipecat.utils>
Utils <api/pipecat.utils>

View File

@@ -15,7 +15,7 @@ from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.riva.tts import FastPitchTTSService
from pipecat.services.nvidia.tts import NvidiaTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -36,7 +36,7 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
task = PipelineTask(
Pipeline([tts, transport.output()]),

View File

@@ -13,12 +13,13 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.frames.frames import Frame, LLMContextFrame, LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
@@ -30,6 +31,44 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
FILTERED_WORDS = ["apple", "banana", "car"]
class ContentFilterProcessor(FrameProcessor):
"""Processor that filters LLMContextFrames containing specific words.
If the user's message contains any of the filtered words, the context
is replaced with a message indicating the assistant cannot respond.
"""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, LLMContextFrame):
# Check the last user message for filtered words
messages = frame.context.messages
if messages:
last_message = messages[-1]
content = last_message.get("content", "")
if isinstance(content, str):
content_lower = content.lower()
if any(word in content_lower for word in FILTERED_WORDS):
logger.info(f"Filtered content detected: {content}")
# Create a new context with a filtered response instruction
filtered_context = LLMContext(
messages=[
{
"role": "system",
"content": "The user is asking about something you cannot give an answer about. Tell them you don't know how to respond.",
}
]
)
await self.push_frame(LLMContextFrame(filtered_context), direction)
return
await self.push_frame(frame, direction)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
@@ -76,12 +115,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
context = LLMContext(messages)
context_aggregator = LLMContextAggregatorPair(context)
content_filter = ContentFilterProcessor()
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
content_filter, # Content filter
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output

View File

@@ -22,9 +22,9 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.nim.llm import NimLLMService
from pipecat.services.riva.stt import RivaSTTService
from pipecat.services.riva.tts import RivaTTSService
from pipecat.services.nvidia.llm import NvidiaLLMService
from pipecat.services.nvidia.stt import NvidiaSTTService
from pipecat.services.nvidia.tts import NvidiaTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -59,11 +59,13 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = RivaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
stt = NvidiaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct")
llm = NvidiaLLMService(
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
)
tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
messages = [
{

View File

@@ -27,7 +27,7 @@ from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.nim.llm import NimLLMService
from pipecat.services.nvidia.llm import NvidiaLLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -75,11 +75,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# text_filters=[MarkdownTextFilter()],
)
llm = NimLLMService(
llm = NvidiaLLMService(
api_key=os.getenv("NVIDIA_API_KEY"),
model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
# Recommended when turning thinking off
params=NimLLMService.InputParams(temperature=0.0),
params=NvidiaLLMService.InputParams(temperature=0.0),
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.

View File

@@ -14,20 +14,13 @@ from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
LLMRunFrame,
LLMSetToolsFrame,
LLMUpdateSettingsFrame,
TranscriptionMessage,
)
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame, TranscriptionMessage
from pipecat.observers.loggers.transcription_log_observer import TranscriptionLogObserver
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.transcript_processor import TranscriptProcessor
from pipecat.runner.types import RunnerArguments

View File

@@ -19,7 +19,6 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport

View File

@@ -28,10 +28,10 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import LLMService
from pipecat.services.openai.llm import OpenAIContextAggregatorPair, OpenAILLMService
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.sync.event_notifier import EventNotifier
load_dotenv(override=True)

View File

@@ -45,11 +45,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams, LLMService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
from pipecat.utils.time import time_now_iso8601
load_dotenv(override=True)

View File

@@ -46,11 +46,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams, LLMService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
from pipecat.utils.time import time_now_iso8601
load_dotenv(override=True)

View File

@@ -47,11 +47,11 @@ from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.llm_service import LLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
from pipecat.utils.time import time_now_iso8601
load_dotenv(override=True)

View File

@@ -45,7 +45,7 @@ Source = "https://github.com/pipecat-ai/pipecat"
Website = "https://pipecat.ai"
[project.optional-dependencies]
aic = [ "aic-sdk~=1.1.0" ]
aic = [ "aic-sdk~=1.2.0" ]
anthropic = [ "anthropic~=0.49.0" ]
assemblyai = [ "pipecat-ai[websockets-base]" ]
asyncai = [ "pipecat-ai[websockets-base]" ]
@@ -55,7 +55,7 @@ azure = [ "azure-cognitiveservices-speech~=1.42.0"]
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
cerebras = []
daily = [ "daily-python~=0.22.0" ]
deepgram = [ "deepgram-sdk~=4.7.0" ]
deepgram = [ "deepgram-sdk~=4.7.0", "pipecat-ai[websockets-base]" ]
deepseek = []
elevenlabs = [ "pipecat-ai[websockets-base]" ]
fal = [ "fal-client~=0.5.9" ]
@@ -83,8 +83,8 @@ mistral = []
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
neuphonic = [ "pipecat-ai[websockets-base]" ]
nim = []
noisereduce = [ "noisereduce~=3.0.3" ]
nvidia = [ "nvidia-riva-client~=2.21.1" ]
openai = [ "pipecat-ai[websockets-base]" ]
openpipe = [ "openpipe>=4.50.0,<6" ]
openrouter = []
@@ -93,7 +93,7 @@ playht = [ "pipecat-ai[websockets-base]" ]
qwen = []
remote-smart-turn = []
rime = [ "pipecat-ai[websockets-base]" ]
riva = [ "nvidia-riva-client~=2.21.1" ]
riva = [ "pipecat-ai[nvidia]" ]
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.122.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"]
sambanova = []

View File

@@ -103,7 +103,7 @@ TESTS_07 = [
("07o-interruptible-assemblyai.py", EVAL_SIMPLE_MATH),
("07q-interruptible-rime.py", EVAL_SIMPLE_MATH),
("07q-interruptible-rime-http.py", EVAL_SIMPLE_MATH),
("07r-interruptible-riva-nim.py", EVAL_SIMPLE_MATH),
("07r-interruptible-nvidia.py", EVAL_SIMPLE_MATH),
("07s-interruptible-google-audio-in.py", EVAL_SIMPLE_MATH),
("07t-interruptible-fish.py", EVAL_SIMPLE_MATH),
("07v-interruptible-neuphonic.py", EVAL_SIMPLE_MATH),
@@ -136,7 +136,7 @@ TESTS_14 = [
("14g-function-calling-grok.py", EVAL_WEATHER),
("14h-function-calling-azure.py", EVAL_WEATHER),
("14i-function-calling-fireworks.py", EVAL_WEATHER),
("14j-function-calling-nim.py", EVAL_WEATHER),
("14j-function-calling-nvidia.py", EVAL_WEATHER),
("14k-function-calling-cerebras.py", EVAL_WEATHER),
("14m-function-calling-openrouter.py", EVAL_WEATHER),
("14n-function-calling-perplexity.py", EVAL_WEATHER),

View File

@@ -39,7 +39,7 @@ class AICFilter(BaseAudioFilter):
self,
*,
license_key: str = "",
model_type: AICModelType = AICModelType.QUAIL_L,
model_type: AICModelType = AICModelType.QUAIL_STT,
enhancement_level: Optional[float] = 1.0,
voice_gain: Optional[float] = 1.0,
noise_gate_enable: Optional[bool] = True,
@@ -52,12 +52,27 @@ class AICFilter(BaseAudioFilter):
enhancement_level: Optional overall enhancement strength (0.0..1.0).
voice_gain: Optional linear gain applied to detected speech (0.0..4.0).
noise_gate_enable: Optional enable/disable noise gate (default: True).
.. deprecated:: 1.3.0
The `noise_gate_enable` parameter is deprecated and no longer has any effect.
It will be removed in a future version.
"""
self._license_key = license_key
self._model_type = model_type
self._enhancement_level = enhancement_level
self._voice_gain = voice_gain
if noise_gate_enable is not None:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter `noise_gate_enable` is deprecated and no longer has any effect. "
"It will be removed in a future version. Use AIC VAD instead (create_vad_analyzer()).",
DeprecationWarning,
)
self._noise_gate_enable = noise_gate_enable
self._enabled = True
@@ -149,10 +164,6 @@ class AICFilter(BaseAudioFilter):
)
if self._voice_gain is not None:
self._aic.set_parameter(AICParameter.VOICE_GAIN, float(self._voice_gain))
if self._noise_gate_enable is not None:
self._aic.set_parameter(
AICParameter.NOISE_GATE_ENABLE, 1.0 if bool(self._noise_gate_enable) else 0.0
)
self._aic_ready = True

View File

@@ -18,8 +18,10 @@ from loguru import logger
from pipecat.audio.dtmf.types import KeypadEntry
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
EndFrame,
Frame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMMessagesUpdateFrame,
LLMTextFrame,
OutputDTMFUrgentFrame,
@@ -149,11 +151,18 @@ class IVRProcessor(FrameProcessor):
elif isinstance(frame, LLMTextFrame):
# Process text through the pattern aggregator
result = await self._aggregator.aggregate(frame.text)
if result:
async for result in self._aggregator.aggregate(frame.text):
# Push aggregated text that doesn't contain XML patterns
await self.push_frame(LLMTextFrame(result.text), direction)
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
# Flush any remaining text from the aggregator
remaining = await self._aggregator.flush()
if remaining:
await self.push_frame(LLMTextFrame(remaining.text), direction)
# Push the end frame
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

View File

@@ -40,8 +40,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.services.llm_service import LLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
from pipecat.utils.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.event_notifier import EventNotifier
class NotifierGate(FrameProcessor):

View File

@@ -330,7 +330,7 @@ class TextFrame(DataFrame):
"""
text: str
skip_tts: bool = field(init=False)
skip_tts: Optional[bool] = field(init=False)
# Whether any necessary inter-frame (leading/trailing) spaces are already
# included in the text.
# NOTE: Ideally this would be available at init time with a default value,
@@ -343,7 +343,7 @@ class TextFrame(DataFrame):
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.skip_tts = None
self.includes_inter_frame_spaces = False
self.append_to_context = True
@@ -835,11 +835,13 @@ class ErrorFrame(SystemFrame):
error: Description of the error that occurred.
fatal: Whether the error is fatal and requires bot shutdown.
processor: The frame processor that generated the error.
exception: The exception that occurred.
"""
error: str
fatal: bool = False
processor: Optional["FrameProcessor"] = None
exception: Optional[Exception] = None
def __str__(self):
return f"{self.name}(error: {self.error}, fatal: {self.fatal})"
@@ -1630,22 +1632,22 @@ class LLMFullResponseStartFrame(ControlFrame):
more TextFrames and a final LLMFullResponseEndFrame.
"""
skip_tts: bool = field(init=False)
skip_tts: Optional[bool] = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.skip_tts = None
@dataclass
class LLMFullResponseEndFrame(ControlFrame):
"""Frame indicating the end of an LLM response."""
skip_tts: bool = field(init=False)
skip_tts: Optional[bool] = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.skip_tts = None
@dataclass

View File

@@ -9,7 +9,7 @@
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, LLMContextFrame, StartFrame
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.base_notifier import BaseNotifier
class GatedLLMContextAggregator(FrameProcessor):

View File

@@ -83,8 +83,7 @@ class LLMTextProcessor(FrameProcessor):
await self._text_aggregator.reset()
async def _handle_llm_text(self, in_frame: LLMTextFrame):
aggregation = await self._text_aggregator.aggregate(in_frame.text)
if aggregation:
async for aggregation in self._text_aggregator.aggregate(in_frame.text):
out_frame = AggregatedTextFrame(
text=aggregation.text,
aggregated_by=aggregation.type,
@@ -92,15 +91,13 @@ class LLMTextProcessor(FrameProcessor):
out_frame.skip_tts = in_frame.skip_tts
await self.push_frame(out_frame)
async def _handle_llm_end(self, skip_tts: bool = False):
# Flush any remaining aggregated text at the end of the LLM response
aggregation = self._text_aggregator.text
await self._text_aggregator.reset()
text = aggregation.text.strip()
if text:
async def _handle_llm_end(self, skip_tts: Optional[bool] = None):
# Flush any remaining text
remaining = await self._text_aggregator.flush()
if remaining:
out_frame = AggregatedTextFrame(
text=text,
aggregated_by=aggregation.type,
text=remaining.text,
aggregated_by=remaining.type,
)
out_frame.skip_tts = skip_tts
await self.push_frame(out_frame)

View File

@@ -126,6 +126,4 @@ class WakeCheckFilter(FrameProcessor):
else:
await self.push_frame(frame, direction)
except Exception as e:
error_msg = f"Error in wake word filter: {e}"
logger.exception(error_msg)
await self.push_error(ErrorFrame(error_msg))
await self.push_error(error_msg=f"Error in wake word filter: {e}", exception=e)

View File

@@ -10,7 +10,7 @@ from typing import Awaitable, Callable, Tuple, Type
from pipecat.frames.frames import Frame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.utils.sync.base_notifier import BaseNotifier
class WakeNotifierFilter(FrameProcessor):

View File

@@ -142,6 +142,7 @@ class FrameProcessor(BaseObject):
- on_after_process_frame: Called after a frame is processed
- on_before_push_frame: Called before a frame is pushed
- on_after_push_frame: Called after a frame is pushed
- on_error: Called when an error is raised in the frame processing.
"""
def __init__(
@@ -234,6 +235,7 @@ class FrameProcessor(BaseObject):
self._register_event_handler("on_after_process_frame", sync=True)
self._register_event_handler("on_before_push_frame", sync=True)
self._register_event_handler("on_after_push_frame", sync=True)
self._register_event_handler("on_error", sync=True)
@property
def id(self) -> int:
@@ -630,7 +632,43 @@ class FrameProcessor(BaseObject):
elif isinstance(frame, (FrameProcessorResumeFrame, FrameProcessorResumeUrgentFrame)):
await self.__resume(frame)
async def push_error(self, error: ErrorFrame):
async def push_error(
self,
error_msg: str,
exception: Optional[Exception] = None,
fatal: bool = False,
):
"""Creates and pushes an ErrorFrame upstream.
Creates and pushes an ErrorFrame upstream to notify other processors in the
pipeline about an error condition. The error frame will include context about
which processor generated the error.
Args:
error_msg: Descriptive message explaining the error condition.
exception: Optional exception object that caused the error, if available.
This provides additional context for debugging and error handling.
fatal: Whether this error should be considered fatal to the pipeline.
Fatal errors typically cause the entire pipeline to stop processing.
Defaults to False for non-fatal errors.
Example::
```python
# Non-fatal error
await self.push_error("Failed to process audio chunk, skipping")
# Fatal error with exception context
try:
result = some_critical_operation()
except Exception as e:
await self.push_error("Critical operation failed", exception=e, fatal=True)
```
"""
error_frame = ErrorFrame(error=error_msg, fatal=fatal, exception=exception, processor=self)
await self.push_error_frame(error=error_frame)
async def push_error_frame(self, error: ErrorFrame):
"""Push an error frame upstream.
Args:
@@ -638,6 +676,8 @@ class FrameProcessor(BaseObject):
"""
if not error.processor:
error.processor = self
await self._call_event_handler("on_error", error)
logger.error(f"{error.processor} error: {error.error}")
await self.push_frame(error, FrameDirection.UPSTREAM)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -759,8 +799,10 @@ class FrameProcessor(BaseObject):
await self.__cancel_process_task()
self.__create_process_task()
except Exception as e:
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
await self.push_error(ErrorFrame(str(e)))
await self.push_error(
error_msg=f"Uncaught exception handling _start_interruption: {e}",
exception=e,
)
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
"""Internal method to push frames to adjacent processors.
@@ -797,8 +839,7 @@ class FrameProcessor(BaseObject):
await self._observer.on_push_frame(data)
await self._prev.queue_frame(frame, direction)
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
await self.push_error(ErrorFrame(str(e)))
await self.push_error(error_msg=f"Uncaught exception: {e}", exception=e)
def _check_started(self, frame: Frame):
"""Check if the processor has been started.
@@ -874,8 +915,7 @@ class FrameProcessor(BaseObject):
await self._call_event_handler("on_after_process_frame", frame)
except Exception as e:
logger.exception(f"{self}: error processing frame: {e}")
await self.push_error(ErrorFrame(str(e)))
await self.push_error(error_msg=f"Error processing frame: {e}", exception=e)
async def __input_frame_task_handler(self):
"""Handle frames from the input queue.

View File

@@ -24,7 +24,7 @@ try:
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import Runnable
except ModuleNotFoundError as e:
logger.exception("In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. ")
logger.error("In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. ")
raise Exception(f"Missing module: {e}")
@@ -113,6 +113,6 @@ class LangchainProcessor(FrameProcessor):
except GeneratorExit:
logger.warning(f"{self} generator was closed prematurely")
except Exception as e:
logger.exception(f"{self} an unknown error occurred: {e}")
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.push_frame(LLMFullResponseEndFrame())

View File

@@ -23,7 +23,7 @@ try:
from strands import Agent
from strands.multiagent.graph import Graph
except ModuleNotFoundError as e:
logger.exception("In order to use Strands Agents, you need to `pip install strands-agents`.")
logger.error("In order to use Strands Agents, you need to `pip install strands-agents`.")
raise Exception(f"Missing module: {e}")
@@ -143,7 +143,7 @@ class StrandsAgentsProcessor(FrameProcessor):
except GeneratorExit:
logger.warning(f"{self} generator was closed prematurely")
except Exception as e:
logger.exception(f"{self} an unknown error occurred: {e}")
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
if ttfb_tracking:
await self.stop_ttfb_metrics()

View File

@@ -15,7 +15,6 @@ from typing import List, Optional
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -25,7 +24,6 @@ from pipecat.frames.frames import (
TranscriptionMessage,
TranscriptionUpdateFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
@@ -308,267 +306,3 @@ class TranscriptProcessor:
return handler
return decorator
class TurnAwareTranscriptProcessor(BaseTranscriptProcessor):
"""Processes transcripts with turn boundary awareness.
This processor combines user and assistant transcript tracking with turn
detection, emitting events when turns start and end. It correctly handles
interruptions by only capturing what was actually spoken.
Turn boundaries are detected based on:
- User started speaking (UserStartedSpeakingFrame)
- Bot stopped speaking (BotStoppedSpeakingFrame)
- Interruptions (InterruptionFrame)
Events:
on_turn_started: Emitted when a new turn begins.
Handler signature: async def handler(processor, turn_number)
on_turn_ended: Emitted when a turn ends.
Handler signature: async def handler(processor, turn_number,
user_transcript, assistant_transcript,
was_interrupted)
on_transcript_update: Inherited from BaseTranscriptProcessor, emitted for
individual transcript messages.
Example::
turn_processor = TurnAwareTranscriptProcessor()
@turn_processor.event_handler("on_turn_started")
async def handle_turn_started(processor, turn_number):
print(f"Turn {turn_number} started")
@turn_processor.event_handler("on_turn_ended")
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, interrupted):
print(f"Turn {turn_number} ended")
print(f"User said: {user_text}")
print(f"Assistant said: {assistant_text}")
print(f"Was interrupted: {interrupted}")
pipeline = Pipeline([
transport.input(),
stt,
turn_processor,
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
])
"""
def __init__(self, **kwargs):
"""Initialize the turn-aware transcript processor.
Args:
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
# Turn tracking state
self._turn_number = 0
self._turn_active = False
self._turn_start_time: Optional[str] = None
# Accumulate text for current turn
self._current_turn_user_parts: List[TextPartForConcatenation] = []
self._current_turn_assistant_parts: List[TextPartForConcatenation] = []
# Track bot speaking state
self._bot_is_speaking = False
# Register turn events
self._register_event_handler("on_turn_started")
self._register_event_handler("on_turn_ended")
async def _start_turn(self):
"""Start a new turn."""
if not self._turn_active:
self._turn_number += 1
self._turn_active = True
self._turn_start_time = time_now_iso8601()
self._current_turn_user_parts = []
self._current_turn_assistant_parts = []
logger.debug(f"Turn {self._turn_number} started")
await self._call_event_handler("on_turn_started", self._turn_number)
async def _end_turn(self, was_interrupted: bool = False):
"""End the current turn and emit aggregated transcripts.
Args:
was_interrupted: Whether the turn ended due to an interruption.
"""
if not self._turn_active:
return
# Aggregate user text
user_transcript = ""
if self._current_turn_user_parts:
user_transcript = concatenate_aggregated_text(self._current_turn_user_parts)
# Aggregate assistant text
assistant_transcript = ""
if self._current_turn_assistant_parts:
assistant_transcript = concatenate_aggregated_text(self._current_turn_assistant_parts)
# Emit turn ended event
logger.debug(
f"Turn {self._turn_number} ended (interrupted={was_interrupted}). "
f"User: '{user_transcript}', Assistant: '{assistant_transcript}'"
)
await self._call_event_handler(
"on_turn_ended",
self._turn_number,
user_transcript,
assistant_transcript,
was_interrupted,
)
# Reset turn state
self._turn_active = False
self._current_turn_user_parts = []
self._current_turn_assistant_parts = []
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for turn-aware transcript tracking.
Handles:
- UserStartedSpeakingFrame: Start new turn
- TranscriptionFrame: Accumulate user speech and emit transcript message
- BotStartedSpeakingFrame: Track bot speaking state
- TTSTextFrame: Accumulate assistant speech
- BotStoppedSpeakingFrame: End turn if no interruption pending
- InterruptionFrame: End turn immediately as interrupted
- EndFrame/CancelFrame: End any active turn
Args:
frame: Input frame to process.
direction: Frame processing direction.
"""
await super().process_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
# User started speaking
if self._bot_is_speaking:
# This is an interruption - end the current turn with what was spoken
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=True)
self._bot_is_speaking = False
elif self._turn_active:
# Previous turn is ending normally (bot finished speaking)
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=False)
# Start a new turn
await self._start_turn()
await self.push_frame(frame, direction)
elif isinstance(frame, TranscriptionFrame):
# Accumulate user speech for the current turn
if self._turn_active:
self._current_turn_user_parts.append(
TextPartForConcatenation(frame.text, includes_inter_part_spaces=True)
)
# Also emit individual transcript message
message = TranscriptionMessage(
role="user",
user_id=frame.user_id,
content=frame.text,
timestamp=frame.timestamp,
)
await self._emit_update([message])
await self.push_frame(frame, direction)
elif isinstance(frame, BotStartedSpeakingFrame):
# Bot started speaking
self._bot_is_speaking = True
await self.push_frame(frame, direction)
elif isinstance(frame, TTSTextFrame):
# Accumulate assistant speech for the current turn
if self._turn_active:
self._current_turn_assistant_parts.append(
TextPartForConcatenation(
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
)
)
await self.push_frame(frame, direction)
elif isinstance(frame, BotStoppedSpeakingFrame):
# Bot stopped speaking - just mark it, don't end turn yet
# Turn will end when next user speaks or pipeline ends
self._bot_is_speaking = False
await self.push_frame(frame, direction)
elif isinstance(frame, InterruptionFrame):
# Emit assistant transcript message with what was spoken before interruption
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(self._current_turn_assistant_parts)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
# Push frame first to ensure proper cleanup
await self.push_frame(frame, direction)
# End turn as interrupted
await self._end_turn(was_interrupted=True)
self._bot_is_speaking = False
elif isinstance(frame, (EndFrame, CancelFrame)):
# Pipeline ending - finalize any active turn
if self._turn_active:
# Emit any pending assistant transcript (allow time for TTSTextFrames to be processed)
# Give a brief moment for any pending frames to process
import asyncio
await asyncio.sleep(0.001)
if self._current_turn_assistant_parts:
assistant_content = concatenate_aggregated_text(
self._current_turn_assistant_parts
)
if assistant_content:
message = TranscriptionMessage(
role="assistant",
content=assistant_content,
timestamp=self._turn_start_time or time_now_iso8601(),
)
await self._emit_update([message])
await self._end_turn(was_interrupted=isinstance(frame, CancelFrame))
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

View File

@@ -302,7 +302,7 @@ def _setup_webrtc_routes(
result: StartBotResult = {"sessionId": session_id}
if request_data.get("enableDefaultIceServers"):
result["iceConfig"] = IceConfig(
iceServers=[IceServer(urls="stun:stun.l.google.com:19302")]
iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])]
)
return result

View File

@@ -199,7 +199,7 @@ class PlivoFrameSerializer(FrameSerializer):
)
except Exception as e:
logger.exception(f"Failed to hang up Plivo call: {e}")
logger.error(f"Failed to hang up Plivo call: {e}")
async def deserialize(self, data: str | bytes) -> Frame | None:
"""Deserializes Plivo WebSocket data to Pipecat frames.

View File

@@ -225,7 +225,7 @@ class TelnyxFrameSerializer(FrameSerializer):
)
except Exception as e:
logger.exception(f"Failed to hang up Telnyx call: {e}")
logger.error(f"Failed to hang up Telnyx call: {e}")
async def deserialize(self, data: str | bytes) -> Frame | None:
"""Deserializes Telnyx WebSocket data to Pipecat frames.

View File

@@ -236,7 +236,7 @@ class TwilioFrameSerializer(FrameSerializer):
)
except Exception as e:
logger.exception(f"Failed to hang up Twilio call: {e}")
logger.error(f"Failed to hang up Twilio call: {e}")
async def deserialize(self, data: str | bytes) -> Frame | None:
"""Deserializes Twilio WebSocket data to Pipecat frames.

View File

@@ -166,6 +166,6 @@ class AIService(FrameProcessor):
async for f in generator:
if f:
if isinstance(f, ErrorFrame):
await self.push_error(f)
await self.push_error_frame(f)
else:
await self.push_frame(f)

View File

@@ -458,8 +458,7 @@ class AnthropicLLMService(LLMService):
except httpx.TimeoutException:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
await self.push_error(ErrorFrame(f"{e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())

View File

@@ -206,9 +206,8 @@ class AssemblyAISTTService(STTService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
self._connected = False
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
raise
async def _disconnect(self):
@@ -233,8 +232,7 @@ class AssemblyAISTTService(STTService):
logger.warning("Timed out waiting for termination message from server")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
if self._receive_task:
await self.cancel_task(self._receive_task)
@@ -242,8 +240,7 @@ class AssemblyAISTTService(STTService):
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._websocket = None
@@ -262,13 +259,11 @@ class AssemblyAISTTService(STTService):
except websockets.exceptions.ConnectionClosedOK:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
"""Parse a raw message into the appropriate message type."""
@@ -297,8 +292,7 @@ class AssemblyAISTTService(STTService):
elif isinstance(parsed_message, TerminationMessage):
await self._handle_termination(parsed_message)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
async def _handle_termination(self, message: TerminationMessage):
"""Handle termination message."""

View File

@@ -228,8 +228,7 @@ class AsyncAITTSService(InterruptibleTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -241,8 +240,7 @@ class AsyncAITTSService(InterruptibleTTSService):
logger.debug("Disconnecting from Async")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._websocket = None
self._started = False
@@ -287,12 +285,11 @@ class AsyncAITTSService(InterruptibleTTSService):
)
await self.push_frame(frame)
elif msg.get("error_code"):
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(error=f"{self} error: {msg['message']}"))
await self.push_error(error_msg=f"Error: {msg['message']}")
else:
logger.error(f"{self} error, unknown message type: {msg}")
await self.push_error(error_msg=f"Unknown message type: {msg}")
async def _keepalive_task_handler(self):
"""Send periodic keepalive messages to maintain WebSocket connection."""
@@ -335,16 +332,14 @@ class AsyncAITTSService(InterruptibleTTSService):
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class AsyncAIHttpTTSService(TTSService):
@@ -477,8 +472,7 @@ class AsyncAIHttpTTSService(TTSService):
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Async API error: {error_text}")
await self.push_error(ErrorFrame(error=f"Async API error: {error_text}"))
await self.push_error(error_msg=f"Async API error: {error_text}")
raise Exception(f"Async API returned status {response.status}: {error_text}")
audio_data = await response.read()
@@ -494,8 +488,7 @@ class AsyncAIHttpTTSService(TTSService):
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

View File

@@ -734,7 +734,7 @@ class AWSBedrockLLMService(LLMService):
aws_access_key: Optional[str] = None,
aws_secret_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region: str = "us-east-1",
aws_region: Optional[str] = None,
params: Optional[InputParams] = None,
client_config: Optional[Config] = None,
retry_timeout_secs: Optional[float] = 5.0,
@@ -1136,7 +1136,7 @@ class AWSBedrockLLMService(LLMService):
except (ReadTimeoutError, asyncio.TimeoutError):
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())

View File

@@ -453,7 +453,7 @@ class AWSNovaSonicLLMService(LLMService):
self._ready_to_send_context = True
await self._finish_connecting_if_context_available()
except Exception as e:
logger.error(f"{self} initialization error: {e}")
await self.push_error(error_msg=f"Initialization error: {e}", exception=e)
await self._disconnect()
async def _process_completed_function_calls(self, send_new_results: bool):
@@ -577,7 +577,7 @@ class AWSNovaSonicLLMService(LLMService):
logger.info("Finished disconnecting")
except Exception as e:
logger.error(f"{self} error disconnecting: {e}")
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
def _create_client(self) -> BedrockRuntimeClient:
config = Config(
@@ -885,7 +885,7 @@ class AWSNovaSonicLLMService(LLMService):
# Errors are kind of expected while disconnecting, so just
# ignore them and do nothing
return
logger.error(f"{self} error processing responses: {e}")
await self.push_error(error_msg=f"Error processing responses: {e}", exception=e)
if self._wants_connection:
await self.reset_conversation()

View File

@@ -58,7 +58,7 @@ class AWSTranscribeSTTService(STTService):
api_key: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_session_token: Optional[str] = None,
region: Optional[str] = "us-east-1",
region: Optional[str] = None,
sample_rate: int = 16000,
language: Language = Language.EN,
**kwargs,
@@ -69,7 +69,7 @@ class AWSTranscribeSTTService(STTService):
api_key: AWS secret access key. If None, uses AWS_SECRET_ACCESS_KEY environment variable.
aws_access_key_id: AWS access key ID. If None, uses AWS_ACCESS_KEY_ID environment variable.
aws_session_token: AWS session token for temporary credentials. If None, uses AWS_SESSION_TOKEN environment variable.
region: AWS region for the service. Defaults to "us-east-1".
region: AWS region for the service.
sample_rate: Audio sample rate in Hz. Must be 8000 or 16000. Defaults to 16000.
language: Language for transcription. Defaults to English.
**kwargs: Additional arguments passed to parent STTService class.
@@ -140,8 +140,7 @@ class AWSTranscribeSTTService(STTService):
return
logger.warning("WebSocket connection not established after connect")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
retry_count += 1
if retry_count < max_retries:
await asyncio.sleep(1) # Wait before retrying
@@ -182,8 +181,7 @@ class AWSTranscribeSTTService(STTService):
try:
await self._connect()
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
return
# Format the audio data according to AWS event stream format
@@ -200,13 +198,11 @@ class AWSTranscribeSTTService(STTService):
await self._disconnect()
# Don't yield error here - we'll retry on next frame
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
await self._disconnect()
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
await self._disconnect()
async def _connect(self):
@@ -289,8 +285,7 @@ class AWSTranscribeSTTService(STTService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
await self._disconnect()
raise
@@ -310,8 +305,7 @@ class AWSTranscribeSTTService(STTService):
await self._ws_client.send(json.dumps(end_stream))
await self._ws_client.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._ws_client = None
await self._call_event_handler("on_disconnected")
@@ -529,15 +523,15 @@ class AWSTranscribeSTTService(STTService):
)
elif headers.get(":message-type") == "exception":
error_msg = payload.get("Message", "Unknown error")
logger.error(f"{self} Exception from AWS: {error_msg}")
await self.push_frame(ErrorFrame(f"AWS Transcribe error: {error_msg}"))
await self.push_error(error_msg=f"AWS Transcribe error: {error_msg}")
else:
logger.debug(f"{self} Other message type received: {headers}")
logger.debug(f"{self} Payload: {payload}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
await self.push_error(
error_msg=f"WebSocket connection closed in receive loop", exception=e
)
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
break

View File

@@ -312,7 +312,6 @@ class AWSPollyTTSService(TTSService):
yield TTSStoppedFrame()
except (BotoCoreError, ClientError) as error:
logger.exception(f"{self} error generating TTS: {error}")
error_message = f"AWS Polly TTS error: {str(error)}"
yield ErrorFrame(error=error_message)

View File

@@ -91,7 +91,6 @@ class AzureImageGenServiceREST(ImageGenService):
while status != "succeeded":
attempts_left -= 1
if attempts_left == 0:
logger.error(f"{self} error: image generation timed out")
yield ErrorFrame("Image generation timed out")
return
@@ -104,7 +103,6 @@ class AzureImageGenServiceREST(ImageGenService):
image_url = json_response["result"]["data"][0]["url"] if json_response else None
if not image_url:
logger.error(f"{self} error: image generation failed")
yield ErrorFrame("Image generation failed")
return

View File

@@ -61,5 +61,5 @@ class AzureRealtimeLLMService(OpenAIRealtimeLLMService):
)
self._receive_task = self.create_task(self._receive_task_handler())
except Exception as e:
logger.error(f"{self} initialization error: {e}")
await self.push_error(error_msg=f"initialization error: {e}", exception=e)
self._websocket = None

View File

@@ -121,8 +121,7 @@ class AzureSTTService(STTService):
self._audio_stream.write(audio)
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
async def start(self, frame: StartFrame):
"""Start the speech recognition service.
@@ -151,8 +150,9 @@ class AzureSTTService(STTService):
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
self._speech_recognizer.start_continuous_recognition_async()
except Exception as e:
logger.error(f"{self} exception during initialization: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(
error_msg=f"Uncaught exception during initialization: {e}", exception=e
)
async def stop(self, frame: EndFrame):
"""Stop the speech recognition service.

View File

@@ -327,7 +327,6 @@ class AzureTTSService(AzureBaseTTSService):
try:
if self._speech_synthesizer is None:
error_msg = "Speech synthesizer not initialized."
logger.error(error_msg)
yield ErrorFrame(error=error_msg)
return
@@ -355,15 +354,13 @@ class AzureTTSService(AzureBaseTTSService):
yield TTSStoppedFrame()
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
# Could add reconnection logic here if needed
return
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class AzureHttpTTSService(AzureBaseTTSService):
@@ -440,5 +437,6 @@ class AzureHttpTTSService(AzureBaseTTSService):
cancellation_details = result.cancellation_details
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
if cancellation_details.reason == CancellationReason.Error:
logger.error(f"{self} error: {cancellation_details.error_details}")
yield ErrorFrame(error=f"{self} error: {cancellation_details.error_details}")
yield ErrorFrame(
error=f"Unknown error occurred: {cancellation_details.error_details}"
)

View File

@@ -276,8 +276,7 @@ class CartesiaSTTService(WebsocketSTTService):
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
async def _disconnect_websocket(self):
try:
@@ -285,8 +284,7 @@ class CartesiaSTTService(WebsocketSTTService):
logger.debug("Disconnecting from Cartesia STT")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -319,8 +317,7 @@ class CartesiaSTTService(WebsocketSTTService):
elif data["type"] == "error":
error_msg = data.get("message", "Unknown error")
logger.error(f"Cartesia error: {error_msg}")
await self.push_error(ErrorFrame(error=error_msg))
await self.push_error(error_msg=error_msg)
@traced_stt
async def _handle_transcription(

View File

@@ -497,8 +497,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
)
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -510,8 +509,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
logger.debug("Disconnecting from Cartesia")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
@@ -564,13 +562,12 @@ class CartesiaTTSService(AudioContextWordTTSService):
)
await self.append_to_audio_context(msg["context_id"], frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
await self.push_error(error_msg=f"Error: {msg}")
self._context_id = None
else:
logger.error(f"{self} error, unknown message type: {msg}")
await self.push_error(error_msg=f"Error, unknown message type: {msg}")
async def _receive_messages(self):
while True:
@@ -608,16 +605,14 @@ class CartesiaTTSService(AudioContextWordTTSService):
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class CartesiaHttpTTSService(TTSService):
@@ -808,8 +803,7 @@ class CartesiaHttpTTSService(TTSService):
async with session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Cartesia API error: {error_text}")
await self.push_error(ErrorFrame(error=f"Cartesia API error: {error_text}"))
yield ErrorFrame(error=f"Cartesia API error: {error_text}")
raise Exception(f"Cartesia API returned status {response.status}: {error_text}")
audio_data = await response.read()
@@ -825,8 +819,7 @@ class CartesiaHttpTTSService(TTSService):
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

View File

@@ -150,7 +150,17 @@ class DeepgramFluxSTTService(WebsocketSTTService):
params=params
)
"""
super().__init__(sample_rate=sample_rate, **kwargs)
# Note: For DeepgramFluxSTTService, differently from other processes, we need to create
# the _receive_task inside _connect_websocket, because the websocket should only be
# considered connected and ready to send audio once we receive from Flux the message
# which confirms the connection has been established.
# If we try to keep the logic reconnect_on_error, when receiving a message, the
# _receive_task_handler would try to reconnect in case of error, invoking the
# _connect_websocket again and leading to a case where the first _receive_task_handler
# was never destroyed.
# So we can keep it here as false, because inside the method send_with_retry, it will
# already try to reconnect if needed.
super().__init__(sample_rate=sample_rate, reconnect_on_error=False, **kwargs)
self._api_key = api_key
self._url = url
@@ -183,14 +193,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
"""
await self._connect_websocket()
# Creating the receiver task (only created once during initial connection)
if not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
# Creating the watchdog task (only created once during initial connection)
if not self._watchdog_task:
self._watchdog_task = self.create_task(self._watchdog_task_handler())
async def _disconnect(self):
"""Disconnect from WebSocket and clean up tasks.
@@ -200,8 +202,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
try:
await self._disconnect_websocket()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
# Reset state only after everything is cleaned up
self._websocket = None
@@ -243,14 +244,28 @@ class DeepgramFluxSTTService(WebsocketSTTService):
additional_headers={"Authorization": f"Token {self._api_key}"},
)
headers = {
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
}
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
# Creating the receiver task
if not self._receive_task:
self._receive_task = self.create_task(
self._receive_task_handler(self._report_error)
)
# Creating the watchdog task
if not self._watchdog_task:
self._watchdog_task = self.create_task(self._watchdog_task_handler())
# Now wait for the connection established event
logger.debug("WebSocket connected, waiting for server confirmation...")
await self._connection_established_event.wait()
logger.debug("Connected to Deepgram Flux Websocket")
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -278,8 +293,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
logger.debug("Disconnecting from Deepgram Flux Websocket")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -289,10 +303,13 @@ class DeepgramFluxSTTService(WebsocketSTTService):
This signals to the server that no more audio data will be sent.
"""
if self._websocket:
logger.debug("Sending CloseStream message to Deepgram Flux")
message = {"type": "CloseStream"}
await self._websocket.send(json.dumps(message))
try:
if self._websocket:
logger.debug("Sending CloseStream message to Deepgram Flux")
message = {"type": "CloseStream"}
await self._websocket.send(json.dumps(message))
except Exception as e:
await self.push_error(error_msg=f"Error sending closeStream: {e}", exception=e)
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -379,16 +396,13 @@ class DeepgramFluxSTTService(WebsocketSTTService):
are issues sending the audio data.
"""
if not self._websocket:
logger.error("Not connected to Deepgram Flux.")
yield ErrorFrame("Not connected to Deepgram Flux.")
return
try:
self._last_stt_time = time.monotonic()
await self.send_with_retry(audio, self._report_error)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
return
yield None
@@ -465,8 +479,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
# Skip malformed messages
continue
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
# Error will be handled inside WebsocketService->_receive_task_handler
raise
else:

View File

@@ -233,7 +233,14 @@ class DeepgramSTTService(STTService):
)
if not await self._connection.start(options=self._settings, addons=self._addons):
logger.error(f"{self}: unable to connect to Deepgram")
await self.push_error(error_msg=f"Unable to connect to Deepgram")
else:
headers = {
k: v
for k, v in self._connection._socket.response.headers.items()
if k.startswith("dg-")
}
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
async def _disconnect(self):
if await self._connection.is_connected():
@@ -256,7 +263,7 @@ class DeepgramSTTService(STTService):
async def _on_error(self, *args, **kwargs):
error: ErrorResponse = kwargs["error"]
logger.warning(f"{self} connection error, will retry: {error}")
await self.push_error(ErrorFrame(error=f"{error}"))
await self.push_error(error_msg=f"{error}")
await self.stop_all_metrics()
# NOTE(aleix): we don't disconnect (i.e. call finish on the connection)
# because this triggers more errors internally in the Deepgram SDK. So,

View File

@@ -210,8 +210,7 @@ class DeepgramSageMakerSTTService(STTService):
try:
await self._client.send_audio_chunk(audio)
except Exception as e:
logger.error(f"Error sending audio to SageMaker: {e}")
await self.push_error(ErrorFrame(error=f"SageMaker STT error: {e}"))
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield None
async def _connect(self):
@@ -260,8 +259,7 @@ class DeepgramSageMakerSTTService(STTService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"Failed to connect to SageMaker: {e}")
await self.push_error(ErrorFrame(error=f"SageMaker connection error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
await self._call_event_handler("on_connection_error", str(e))
async def _disconnect(self):
@@ -342,8 +340,7 @@ class DeepgramSageMakerSTTService(STTService):
except asyncio.CancelledError:
logger.debug("Response processor cancelled")
except Exception as e:
logger.error(f"Error processing responses: {e}", exc_info=True)
await self.push_error(ErrorFrame(error=f"SageMaker response error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
logger.debug("Response processor stopped")

View File

@@ -10,35 +10,45 @@ This module provides integration with Deepgram's text-to-speech API
for generating speech from text using various voice models.
"""
import json
from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import TTSService, WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
from deepgram import DeepgramClient, DeepgramClientOptions, SpeakOptions
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.")
logger.error(
"In order to use DeepgramWebsocketTTSService, you need to `pip install pipecat-ai[deepgram]`."
)
raise Exception(f"Missing module: {e}")
class DeepgramTTSService(TTSService):
"""Deepgram text-to-speech service.
class DeepgramTTSService(WebsocketTTSService):
"""Deepgram WebSocket-based text-to-speech service.
Provides text-to-speech synthesis using Deepgram's streaming API.
Supports various voice models and audio encoding formats with
configurable sample rates and quality settings.
Provides real-time text-to-speech synthesis using Deepgram's WebSocket API.
Supports streaming audio generation with interruption handling via the Clear
message for conversational AI use cases.
"""
def __init__(
@@ -46,42 +56,220 @@ class DeepgramTTSService(TTSService):
*,
api_key: str,
voice: str = "aura-2-helena-en",
base_url: str = "",
base_url: str = "wss://api.deepgram.com",
sample_rate: Optional[int] = None,
encoding: str = "linear16",
**kwargs,
):
"""Initialize the Deepgram TTS service.
"""Initialize the Deepgram WebSocket TTS service.
Args:
api_key: Deepgram API key for authentication.
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
base_url: Custom base URL for Deepgram API. Uses default if empty.
base_url: WebSocket base URL for Deepgram API. Defaults to "wss://api.deepgram.com".
sample_rate: Audio sample rate in Hz. If None, uses service default.
encoding: Audio encoding format. Defaults to "linear16".
**kwargs: Additional arguments passed to parent TTSService class.
**kwargs: Additional arguments passed to parent InterruptibleTTSService class.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
super().__init__(
sample_rate=sample_rate,
pause_frame_processing=True,
push_stop_frames=True,
**kwargs,
)
self._api_key = api_key
self._base_url = base_url
self._settings = {
"encoding": encoding,
}
self.set_voice(voice)
client_options = DeepgramClientOptions(url=base_url)
self._deepgram_client = DeepgramClient(api_key, config=client_options)
self._receive_task = None
def can_generate_metrics(self) -> bool:
"""Check if the service can generate metrics.
Returns:
True, as Deepgram TTS service supports metrics generation.
True, as Deepgram WebSocket TTS service supports metrics generation.
"""
return True
async def start(self, frame: StartFrame):
"""Start the Deepgram WebSocket TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Deepgram WebSocket TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Deepgram WebSocket TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with special handling for LLM response end.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
# When the LLM finishes responding, flush any remaining text in Deepgram's buffer
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
async def _connect(self):
"""Connect to Deepgram WebSocket and start receive task."""
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
"""Disconnect from Deepgram WebSocket and clean up tasks."""
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
await self._disconnect_websocket()
async def _connect_websocket(self):
"""Connect to Deepgram WebSocket API with configured settings."""
try:
if self._websocket and self._websocket.state is State.OPEN:
return
logger.debug("Connecting to Deepgram WebSocket")
# Build WebSocket URL with query parameters
params = []
params.append(f"model={self._voice_id}")
params.append(f"encoding={self._settings['encoding']}")
params.append(f"sample_rate={self.sample_rate}")
url = f"{self._base_url}/v1/speak?{'&'.join(params)}"
headers = {"Authorization": f"Token {self._api_key}"}
self._websocket = await websocket_connect(url, additional_headers=headers)
headers = {
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
}
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
"""Close WebSocket connection and reset state."""
try:
await self.stop_all_metrics()
if self._websocket:
logger.debug("Disconnecting from Deepgram WebSocket")
# Send Close message to gracefully close the connection
await self._websocket.send(json.dumps({"type": "Close"}))
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
def _get_websocket(self):
"""Get active websocket connection or raise exception."""
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by sending Clear message to Deepgram.
The Clear message will clear Deepgram's internal text buffer and stop
sending audio, allowing for a new response to be generated.
"""
await super()._handle_interruption(frame, direction)
# Send Clear message to stop current audio generation
if self._websocket:
try:
clear_msg = {"type": "Clear"}
await self._websocket.send(json.dumps(clear_msg))
except Exception as e:
logger.error(f"{self} error sending Clear message: {e}")
async def _receive_messages(self):
"""Receive and process messages from Deepgram WebSocket."""
async for message in self._get_websocket():
if isinstance(message, bytes):
# Binary message contains audio data
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self.sample_rate, 1)
await self.push_frame(frame)
elif isinstance(message, str):
# Text message contains metadata or control messages
try:
msg = json.loads(message)
msg_type = msg.get("type")
if msg_type == "Metadata":
logger.trace(f"Received metadata: {msg}")
elif msg_type == "Flushed":
logger.trace(f"Received Flushed: {msg}")
# Flushed indicates the end of audio generation for the current buffer
# This happens after flush_audio() is called
elif msg_type == "Cleared":
logger.trace(f"Received Cleared: {msg}")
# Buffer has been cleared after interruption
# TTSStoppedFrame will be sent by the interruption handler
elif msg_type == "Warning":
logger.warning(
f"{self} warning: {msg.get('description', 'Unknown warning')}"
)
else:
logger.debug(f"Received unknown message type: {msg}")
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
async def flush_audio(self):
"""Flush any pending audio synthesis by sending Flush command.
This should be called when the LLM finishes a complete response to force
generation of audio from Deepgram's internal text buffer.
"""
if self._websocket:
try:
flush_msg = {"type": "Flush"}
await self._websocket.send(json.dumps(flush_msg))
except Exception as e:
logger.error(f"{self} error sending Flush message: {e}")
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Deepgram's TTS API.
"""Generate speech from text using Deepgram's WebSocket TTS API.
Args:
text: The text to synthesize into speech.
@@ -91,33 +279,27 @@ class DeepgramTTSService(TTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
options = SpeakOptions(
model=self._voice_id,
encoding=self._settings["encoding"],
sample_rate=self.sample_rate,
container="none",
)
try:
# Reconnect if the websocket is closed
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
await self.start_ttfb_metrics()
response = await self._deepgram_client.speak.asyncrest.v("1").stream_raw(
{"text": text}, options
)
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame()
async for data in response.aiter_bytes():
await self.stop_ttfb_metrics()
if data:
yield TTSAudioRawFrame(audio=data, sample_rate=self.sample_rate, num_channels=1)
# Send text message to Deepgram
# Note: We don't send Flush here - that should only be sent when the
# LLM finishes a complete response via flush_audio()
speak_msg = {"type": "Speak", "text": text}
await self._get_websocket().send(json.dumps(speak_msg))
yield TTSStoppedFrame()
# The audio frames will be handled in _receive_messages
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class DeepgramHttpTTSService(TTSService):
@@ -227,5 +409,4 @@ class DeepgramHttpTTSService(TTSService):
yield TTSStoppedFrame()
except Exception as e:
logger.exception(f"{self} exception: {e}")
yield ErrorFrame(f"Error getting audio: {str(e)}")

View File

@@ -351,8 +351,7 @@ class ElevenLabsSTTService(SegmentedSTTService):
)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
def audio_format_from_sample_rate(sample_rate: int) -> str:
@@ -598,7 +597,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
}
await self._websocket.send(json.dumps(message))
except Exception as e:
logger.error(f"Error sending audio: {e}")
yield ErrorFrame(f"ElevenLabs Realtime STT error: {str(e)}")
yield None
@@ -663,8 +661,9 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
await self._call_event_handler("on_connected")
logger.debug("Connected to ElevenLabs Realtime STT")
except Exception as e:
logger.error(f"{self}: unable to connect to ElevenLabs Realtime STT: {e}")
await self.push_error(ErrorFrame(f"Connection error: {str(e)}"))
await self.push_error(
error_msg=f"Unable to connect to ElevenLabs Realtime STT: {e}", exception=e
)
async def _disconnect_websocket(self):
"""Disconnect from ElevenLabs Realtime STT WebSocket."""
@@ -673,7 +672,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
logger.debug("Disconnecting from ElevenLabs Realtime STT")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -733,17 +732,17 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
elif message_type == "error":
error_msg = data.get("error", "Unknown error")
logger.error(f"ElevenLabs error: {error_msg}")
await self.push_error(ErrorFrame(f"Error: {error_msg}"))
await self.push_error(error_msg=f"Error: {error_msg}")
elif message_type == "auth_error":
error_msg = data.get("error", "Authentication error")
logger.error(f"ElevenLabs auth error: {error_msg}")
await self.push_error(ErrorFrame(f"Auth error: {error_msg}"))
await self.push_error(error_msg=f"Auth error: {error_msg}")
elif message_type == "quota_exceeded_error":
error_msg = data.get("error", "Quota exceeded")
logger.error(f"ElevenLabs quota exceeded: {error_msg}")
await self.push_error(ErrorFrame(f"Quota exceeded: {error_msg}"))
await self.push_error(error_msg=f"Quota exceeded: {error_msg}")
else:
logger.debug(f"Unknown message type: {message_type}")

View File

@@ -424,8 +424,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
json.dumps({"context_id": self._context_id, "close_context": True})
)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._context_id = None
self._started = False
@@ -536,9 +535,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
self._websocket = None
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
@@ -553,8 +551,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
await self._websocket.close()
logger.debug("Disconnected from ElevenLabs")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._started = False
self._context_id = None
@@ -584,8 +581,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
json.dumps({"context_id": self._context_id, "close_context": True})
)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._context_id = None
self._started = False
self._partial_word = ""
@@ -740,15 +736,13 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
else:
await self._send_text(text)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield TTSStoppedFrame()
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
self._started = False
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class ElevenLabsHttpTTSService(WordTTSService):
@@ -1043,7 +1037,6 @@ class ElevenLabsHttpTTSService(WordTTSService):
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"{self} error: {error_text}")
yield ErrorFrame(error=f"ElevenLabs API error: {error_text}")
return
@@ -1091,8 +1084,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
logger.warning(f"Failed to parse JSON from stream: {e}")
continue
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
continue
# After processing all chunks, emit any remaining partial word
@@ -1116,8 +1108,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
self._previous_text = text
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
# Let the parent class handle TTSStoppedFrame

View File

@@ -110,7 +110,6 @@ class FalImageGenService(ImageGenService):
image_url = response["images"][0]["url"] if response else None
if not image_url:
logger.error(f"{self} error: image generation failed")
yield ErrorFrame("Image generation failed")
return

View File

@@ -290,5 +290,4 @@ class FalSTTService(SegmentedSTTService):
)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -76,7 +76,7 @@ class FishAudioTTSService(InterruptibleTTSService):
api_key: str,
reference_id: Optional[str] = None, # This is the voice ID
model: Optional[str] = None, # Deprecated
model_id: str = "speech-1.5",
model_id: str = "s1",
output_format: FishAudioOutputFormat = "pcm",
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
@@ -93,7 +93,7 @@ class FishAudioTTSService(InterruptibleTTSService):
The `model` parameter is deprecated and will be removed in version 0.1.0.
Use `reference_id` instead to specify the voice model.
model_id: Specify which Fish Audio TTS model to use (e.g. "speech-1.5")
model_id: Specify which Fish Audio TTS model to use (e.g. "s1")
output_format: Audio output format. Defaults to "pcm".
sample_rate: Audio sample rate. If None, uses default.
params: Additional input parameters for voice customization.
@@ -228,8 +228,7 @@ class FishAudioTTSService(InterruptibleTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -243,8 +242,7 @@ class FishAudioTTSService(InterruptibleTTSService):
await self._websocket.send(ormsgpack.packb(stop_message))
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._request_id = None
self._started = False
@@ -286,8 +284,7 @@ class FishAudioTTSService(InterruptibleTTSService):
continue
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
@@ -323,8 +320,7 @@ class FishAudioTTSService(InterruptibleTTSService):
flush_message = {"event": "flush"}
await self._get_websocket().send(ormsgpack.packb(flush_message))
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
@@ -332,5 +328,4 @@ class FishAudioTTSService(InterruptibleTTSService):
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -468,8 +468,7 @@ class GladiaSTTService(STTService):
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._connection_active = False
if not self._should_reconnect:
@@ -559,8 +558,7 @@ class GladiaSTTService(STTService):
except websockets.exceptions.ConnectionClosed:
logger.debug("Connection closed during keepalive")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
async def _receive_task_handler(self):
try:
@@ -623,8 +621,7 @@ class GladiaSTTService(STTService):
# Expected when closing the connection
pass
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
async def _maybe_reconnect(self) -> bool:
"""Handle exponential backoff reconnection logic."""
@@ -632,7 +629,9 @@ class GladiaSTTService(STTService):
return False
self._reconnection_attempts += 1
if self._reconnection_attempts > self._max_reconnection_attempts:
logger.error(f"Max reconnection attempts ({self._max_reconnection_attempts}) reached")
await self.push_error(
error_msg=f"Max reconnection attempts ({self._max_reconnection_attempts}) reached",
)
self._should_reconnect = False
return False
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))

View File

@@ -1175,7 +1175,7 @@ class GeminiLiveLLMService(LLMService):
self._connection_task = self.create_task(self._connection_task_handler(config=config))
except Exception as e:
await self.push_error(ErrorFrame(error=f"{self} Initialization error: {e}"))
await self.push_error(error_msg=f"Initialization error: {e}", exception=e)
async def _connection_task_handler(self, config: LiveConnectConfig):
async with self._client.aio.live.connect(model=self._model_name, config=config) as session:
@@ -1252,11 +1252,11 @@ class GeminiLiveLLMService(LLMService):
)
if self._consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
logger.error(
error_msg = (
f"Max consecutive failures ({MAX_CONSECUTIVE_FAILURES}) reached, "
"treating as fatal error"
)
await self.push_error(ErrorFrame(error=f"{self} Error in receive loop: {error}"))
await self.push_error(error_msg=error_msg, exception=error)
return False
else:
logger.info(
@@ -1284,7 +1284,7 @@ class GeminiLiveLLMService(LLMService):
self._completed_tool_calls = set()
self._disconnecting = False
except Exception as e:
logger.error(f"{self} error disconnecting: {e}")
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
async def _send_user_audio(self, frame):
"""Send user audio frame to Gemini Live API."""
@@ -1723,6 +1723,8 @@ class GeminiLiveLLMService(LLMService):
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cache_read_input_tokens=usage.cached_content_token_count,
reasoning_tokens=usage.thoughts_token_count,
)
await self.start_llm_usage_metrics(tokens)
@@ -1743,7 +1745,7 @@ class GeminiLiveLLMService(LLMService):
# state management, and that exponential backoff for retries can have
# cost/stability implications for a service cluster, let's just treat a
# send-side error as fatal.
await self.push_error(ErrorFrame(error=f"{self} Send error: {error}", fatal=True))
await self.push_error(error_msg=f"Send error: {error}")
def create_context_aggregator(
self,

View File

@@ -110,7 +110,6 @@ class GoogleImageGenService(ImageGenService):
await self.stop_ttfb_metrics()
if not response or not response.generated_images:
logger.error(f"{self} error: image generation failed")
yield ErrorFrame("Image generation failed")
return
@@ -128,5 +127,4 @@ class GoogleImageGenService(ImageGenService):
yield frame
except Exception as e:
logger.error(f"{self} error generating image: {e}")
yield ErrorFrame(f"Image generation error: {str(e)}")

View File

@@ -793,7 +793,7 @@ class GoogleLLMService(LLMService):
return
generation_params.setdefault("thinking_config", {})["thinking_budget"] = 0
except Exception as e:
logger.exception(f"Failed to unset thinking budget: {e}")
logger.error(f"Failed to unset thinking budget: {e}")
async def _stream_content(
self, params_from_context: GeminiLLMInvocationParams
@@ -983,7 +983,7 @@ class GoogleLLMService(LLMService):
except DeadlineExceeded:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
if grounding_metadata and isinstance(grounding_metadata, dict):
llm_search_frame = LLMSearchResponseFrame(

View File

@@ -774,8 +774,7 @@ class GoogleSTTService(STTService):
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
raise
async def _stream_audio(self):
@@ -806,15 +805,13 @@ class GoogleSTTService(STTService):
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
await asyncio.sleep(1) # Brief delay before reconnecting
self._stream_start_time = int(time.time() * 1000)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process an audio chunk for STT transcription.
@@ -902,8 +899,7 @@ class GoogleSTTService(STTService):
)
raise
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
# Re-raise the exception to let it propagate (e.g. in the case of a
# timeout, propagate to _stream_audio to reconnect)
raise

View File

@@ -737,7 +737,6 @@ class GoogleHttpTTSService(TTSService):
yield TTSStoppedFrame()
except Exception as e:
logger.error(f"{self} exception: {e}")
error_message = f"TTS generation error: {str(e)}"
yield ErrorFrame(error=error_message)
@@ -996,9 +995,7 @@ class GoogleTTSService(GoogleBaseTTSService):
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
error_message = f"TTS generation error: {str(e)}"
yield ErrorFrame(error=error_message)
await self.push_error(error_msg=f"TTS generation error: {str(e)}", exception=e)
class GeminiTTSService(GoogleBaseTTSService):
@@ -1248,6 +1245,5 @@ class GeminiTTSService(GoogleBaseTTSService):
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
error_message = f"Gemini TTS generation error: {str(e)}"
yield ErrorFrame(error=error_message)

View File

@@ -123,6 +123,8 @@ class GrokLLMService(OpenAILLMService):
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._cache_read_input_tokens = None
self._reasoning_tokens = None
self._has_reported_prompt_tokens = False
self._is_processing = True
@@ -137,6 +139,8 @@ class GrokLLMService(OpenAILLMService):
prompt_tokens=self._prompt_tokens,
completion_tokens=self._completion_tokens,
total_tokens=self._total_tokens,
cache_read_input_tokens=self._cache_read_input_tokens,
reasoning_tokens=self._reasoning_tokens,
)
await super().start_llm_usage_metrics(tokens)
@@ -149,7 +153,7 @@ class GrokLLMService(OpenAILLMService):
Args:
tokens: The token usage metrics for the current chunk of processing,
containing prompt_tokens and completion_tokens counts.
containing prompt_tokens, completion_tokens, and optional cached/reasoning tokens.
"""
# Only accumulate metrics during active processing
if not self._is_processing:
@@ -164,6 +168,13 @@ class GrokLLMService(OpenAILLMService):
if tokens.completion_tokens > self._completion_tokens:
self._completion_tokens = tokens.completion_tokens
# Capture cached & reasoning tokens (these typically only appear once per request)
if tokens.cache_read_input_tokens is not None:
self._cache_read_input_tokens = tokens.cache_read_input_tokens
if tokens.reasoning_tokens is not None:
self._reasoning_tokens = tokens.reasoning_tokens
def create_context_aggregator(
self,
context: OpenAILLMContext,

View File

@@ -146,7 +146,6 @@ class GroqTTSService(TTSService):
bytes = w.readframes(num_frames)
yield TTSAudioRawFrame(bytes, frame_rate, channels)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()

View File

@@ -179,7 +179,7 @@ class HeyGenClient:
await self._task_manager.cancel_task(self._event_task)
self._event_task = None
except Exception as e:
logger.exception(f"Exception during cleanup: {e}")
logger.error(f"Exception during cleanup: {e}")
async def start(self, frame: StartFrame, audio_chunk_size: int) -> None:
"""Start the client and establish all necessary connections.

View File

@@ -287,8 +287,7 @@ class HumeTTSService(WordTTSService):
self._cumulative_time = utterance_duration
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
# Ensure TTFB timer is stopped even on early failures
await self.stop_ttfb_metrics()

View File

@@ -397,8 +397,7 @@ class InworldTTSService(TTSService):
# STEP 7: ERROR HANDLING
# ================================================================================
# Log any unexpected errors and notify the pipeline
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
# ================================================================================
# STEP 8: CLEANUP AND COMPLETION
@@ -513,7 +512,7 @@ class InworldTTSService(TTSService):
# Extract the base64-encoded audio content from response
if "audioContent" not in response_data:
logger.error("No audioContent in Inworld API response")
await self.push_error(ErrorFrame(error="No audioContent in response"))
yield ErrorFrame(error="No audioContent in response")
return
# ================================================================================

View File

@@ -166,23 +166,27 @@ class LLMService(AIService):
# However, subclasses should override this with a more specific adapter when necessary.
adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter
def __init__(self, run_in_parallel: bool = True, **kwargs):
def __init__(self, run_in_parallel: bool = True, wait_for_all: bool = False, **kwargs):
"""Initialize the LLM service.
Args:
run_in_parallel: Whether to run function calls in parallel or sequentially.
Defaults to True.
wait_for_all: Whether to wait for all function calls (parallel or
sequential) to complete. Defaults to False.
**kwargs: Additional arguments passed to the parent AIService.
"""
super().__init__(**kwargs)
self._run_in_parallel = run_in_parallel
self._wait_for_all = wait_for_all
self._start_callbacks = {}
self._adapter = self.adapter_class()
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
self._function_call_tasks: Dict[Optional[asyncio.Task], FunctionCallRunnerItem] = {}
self._sequential_runner_task: Optional[asyncio.Task] = None
self._tracing_enabled: bool = False
self._skip_tts: bool = False
self._skip_tts: Optional[bool] = None
self._register_event_handler("on_function_calls_started")
self._register_event_handler("on_completion_timeout")
@@ -293,7 +297,8 @@ class LLMService(AIService):
direction: The direction of frame pushing.
"""
if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
frame.skip_tts = self._skip_tts
if self._skip_tts is not None:
frame.skip_tts = self._skip_tts
await super().push_frame(frame, direction)
@@ -435,6 +440,7 @@ class LLMService(AIService):
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)
runner_items = []
for function_call in function_calls:
if function_call.function_name in self._functions.keys():
item = self._functions[function_call.function_name]
@@ -446,28 +452,20 @@ class LLMService(AIService):
)
continue
runner_item = FunctionCallRunnerItem(
registry_item=item,
function_name=function_call.function_name,
tool_call_id=function_call.tool_call_id,
arguments=function_call.arguments,
context=function_call.context,
runner_items.append(
FunctionCallRunnerItem(
registry_item=item,
function_name=function_call.function_name,
tool_call_id=function_call.tool_call_id,
arguments=function_call.arguments,
context=function_call.context,
)
)
if self._run_in_parallel:
task = self.create_task(self._run_function_call(runner_item))
self._function_call_tasks[task] = runner_item
task.add_done_callback(self._function_call_task_finished)
else:
await self._sequential_runner_queue.put(runner_item)
async def _call_start_function(
self, context: OpenAILLMContext | LLMContext, function_name: str
):
if function_name in self._start_callbacks.keys():
await self._start_callbacks[function_name](function_name, self, context)
elif None in self._start_callbacks.keys():
return await self._start_callbacks[None](function_name, self, context)
if self._run_in_parallel:
await self._run_parallel_function_calls(runner_items)
else:
await self._run_sequential_function_calls(runner_items)
async def request_image_frame(
self,
@@ -540,6 +538,46 @@ class LLMService(AIService):
await task
del self._function_call_tasks[task]
async def _run_parallel_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
tasks = []
for runner_item in runner_items:
task = self.create_task(self._run_function_call(runner_item))
tasks.append(task)
self._function_call_tasks[task] = runner_item
task.add_done_callback(self._function_call_task_finished)
if self._wait_for_all:
# Protect gather from being cancelled. This will protect all tasks
# form being cancelled. That is fine, because we cancel them
# explicitly when handling the interruption (InterruptionFrame). We
# need to set `return_exceptions=True` because `asyncio.shield()`
# will get cancelled (from FrameProcessor process task), then
# `asyncio.gather()` will keep running (because it was protected by
# the shield). Then, individiaul function call tasks will be
# cancelled by us and we don't need to propagate those
# CancelledErrors at that point.
await asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))
async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
if self._wait_for_all:
# Run each function call sequentially, waiting for each to complete.
for runner_item in runner_items:
self._function_call_tasks[None] = runner_item
await self._run_function_call(runner_item)
del self._function_call_tasks[None]
else:
# Enqueue all function calls for background execution.
for runner_item in runner_items:
await self._sequential_runner_queue.put(runner_item)
async def _call_start_function(
self, context: OpenAILLMContext | LLMContext, function_name: str
):
if function_name in self._start_callbacks.keys():
await self._start_callbacks[function_name](function_name, self, context)
elif None in self._start_callbacks.keys():
return await self._start_callbacks[None](function_name, self, context)
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
if runner_item.function_name in self._functions.keys():
item = self._functions[runner_item.function_name]
@@ -623,20 +661,19 @@ class LLMService(AIService):
name = runner_item.function_name
tool_call_id = runner_item.tool_call_id
# We remove the callback because we are going to cancel the task
# now, otherwise we will be removing it from the set while we
# are iterating.
task.remove_done_callback(self._function_call_task_finished)
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
await self.cancel_task(task)
if task:
# We remove the callback because we are going to cancel the
# task next, otherwise we will be removing it from the set
# while we are iterating.
task.remove_done_callback(self._function_call_task_finished)
await self.cancel_task(task)
cancelled_tasks.add(task)
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
await self.push_frame(frame)
cancelled_tasks.add(task)
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
# Remove all cancelled tasks from our set.

View File

@@ -214,8 +214,7 @@ class LmntTTSService(InterruptibleTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -231,8 +230,7 @@ class LmntTTSService(InterruptibleTTSService):
# await self._websocket.send(json.dumps({"eof": True}))
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error disconnecting from LMNT: {e}", exception=e)
finally:
self._started = False
self._websocket = None
@@ -266,10 +264,9 @@ class LmntTTSService(InterruptibleTTSService):
try:
msg = json.loads(message)
if "error" in msg:
logger.error(f"{self} error: {msg['error']}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
await self.push_error(error_msg=f"Error: {msg['error']}")
return
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
@@ -302,13 +299,11 @@ class LmntTTSService(InterruptibleTTSService):
await self._get_websocket().send(json.dumps({"flush": True}))
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -176,7 +176,6 @@ class MCPClient(BaseObject):
except Exception as e:
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
logger.error(error_msg)
logger.exception("Full exception details:")
await params.result_callback(error_msg)
async def _stdio_list_tools(self) -> ToolsSchema:
@@ -207,7 +206,6 @@ class MCPClient(BaseObject):
except Exception as e:
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
logger.error(error_msg)
logger.exception("Full exception details:")
await params.result_callback(error_msg)
async def _streamable_http_list_tools(self) -> ToolsSchema:
@@ -246,7 +244,6 @@ class MCPClient(BaseObject):
except Exception as e:
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
logger.error(error_msg)
logger.exception("Full exception details:")
await params.result_callback(error_msg)
async def _call_tool(self, session, function_name, arguments, result_callback):
@@ -302,7 +299,6 @@ class MCPClient(BaseObject):
except Exception as e:
logger.error(f"Failed to read tool '{tool_name}': {str(e)}")
logger.exception("Full exception details:")
continue
logger.debug(f"Completed reading {len(tool_schemas)} tools")

View File

@@ -253,8 +253,9 @@ class Mem0MemoryService(FrameProcessor):
# Otherwise, pass the enhanced context frame downstream
await self.push_frame(frame)
except Exception as e:
logger.error(f"Error processing with Mem0: {str(e)}")
await self.push_frame(ErrorFrame(f"Error processing with Mem0: {str(e)}"))
await self.push_error(
error_msg=f"Error processing with Mem0: {str(e)}", exception=e
)
await self.push_frame(frame) # Still pass the original frame through
else:
# For non-context frames, just pass them through

View File

@@ -314,7 +314,6 @@ class MiniMaxHttpTTSService(TTSService):
) as response:
if response.status != 200:
error_message = f"MiniMax TTS error: HTTP {response.status}"
logger.error(error_message)
yield ErrorFrame(error=error_message)
return
@@ -392,8 +391,7 @@ class MiniMaxHttpTTSService(TTSService):
continue
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

View File

@@ -110,7 +110,6 @@ class MoondreamService(VisionService):
if analysis fails.
"""
if not self._model:
logger.error(f"{self} error: Moondream model not available ({self.model_name})")
yield ErrorFrame("Moondream model not available")
return

View File

@@ -285,8 +285,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -299,8 +298,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
logger.debug("Disconnecting from Neuphonic")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._started = False
self._websocket = None
@@ -365,16 +363,14 @@ class NeuphonicTTSService(InterruptibleTTSService):
await self._send_text(text)
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class NeuphonicHttpTTSService(TTSService):
@@ -538,7 +534,6 @@ class NeuphonicHttpTTSService(TTSService):
if response.status != 200:
error_text = await response.text()
error_message = f"Neuphonic API error: HTTP {response.status} - {error_text}"
logger.error(error_message)
yield ErrorFrame(error=error_message)
return
@@ -568,8 +563,7 @@ class NeuphonicHttpTTSService(TTSService):
yield TTSAudioRawFrame(audio_bytes, self.sample_rate, 1)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
# Don't yield error frame for individual message failures
continue
@@ -577,8 +571,7 @@ class NeuphonicHttpTTSService(TTSService):
logger.debug("TTS generation cancelled")
raise
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

View File

@@ -8,98 +8,23 @@
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
Microservice) API while maintaining compatibility with the OpenAI-style interface.
.. deprecated:: 0.0.96
This module is deprecated. Please NvidiaLLMService from
pipecat.services.nvidia.llm instead.
"""
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.llm import OpenAILLMService
import warnings
from pipecat.services.nvidia.llm import NvidiaLLMService
class NimLLMService(OpenAILLMService):
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"NimLLMService from pipecat.services.nim.llm is deprecated. "
"Please use NvidiaLLMService from pipecat.services.nvidia.llm instead.",
DeprecationWarning,
stacklevel=2,
)
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
compatibility with the OpenAI-style interface. It specifically handles the difference
in token usage reporting between NIM (incremental) and OpenAI (final summary).
"""
def __init__(
self,
*,
api_key: str,
base_url: str = "https://integrate.api.nvidia.com/v1",
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
**kwargs,
):
"""Initialize the NimLLMService.
Args:
api_key: The API key for accessing NVIDIA's NIM API.
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
# Counters for accumulating token usage metrics
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._has_reported_prompt_tokens = False
self._is_processing = False
async def _process_context(self, context: OpenAILLMContext | LLMContext):
"""Process a context through the LLM and accumulate token usage metrics.
This method overrides the parent class implementation to handle NVIDIA's
incremental token reporting style, accumulating the counts and reporting
them once at the end of processing.
Args:
context: The context to process, containing messages and other information
needed for the LLM interaction.
"""
# Reset all counters and flags at the start of processing
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._has_reported_prompt_tokens = False
self._is_processing = True
try:
await super()._process_context(context)
finally:
self._is_processing = False
# Report final accumulated token usage at the end of processing
if self._prompt_tokens > 0 or self._completion_tokens > 0:
self._total_tokens = self._prompt_tokens + self._completion_tokens
tokens = LLMTokenUsage(
prompt_tokens=self._prompt_tokens,
completion_tokens=self._completion_tokens,
total_tokens=self._total_tokens,
)
await super().start_llm_usage_metrics(tokens)
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
"""Accumulate token usage metrics during processing.
This method intercepts the incremental token updates from NVIDIA's API
and accumulates them instead of passing each update to the metrics system.
The final accumulated totals are reported at the end of processing.
Args:
tokens: The token usage metrics for the current chunk of processing,
containing prompt_tokens and completion_tokens counts.
"""
# Only accumulate metrics during active processing
if not self._is_processing:
return
# Record prompt tokens the first time we see them
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
self._prompt_tokens = tokens.prompt_tokens
self._has_reported_prompt_tokens = True
# Update completion tokens count if it has increased
if tokens.completion_tokens > self._completion_tokens:
self._completion_tokens = tokens.completion_tokens
NimLLMService = NvidiaLLMService

View File

View File

@@ -0,0 +1,105 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA NIM API service implementation.
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
Microservice) API while maintaining compatibility with the OpenAI-style interface.
"""
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.llm import OpenAILLMService
class NvidiaLLMService(OpenAILLMService):
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
compatibility with the OpenAI-style interface. It specifically handles the difference
in token usage reporting between NIM (incremental) and OpenAI (final summary).
"""
def __init__(
self,
*,
api_key: str,
base_url: str = "https://integrate.api.nvidia.com/v1",
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
**kwargs,
):
"""Initialize the NvidiaLLMService.
Args:
api_key: The API key for accessing NVIDIA's NIM API.
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
# Counters for accumulating token usage metrics
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._has_reported_prompt_tokens = False
self._is_processing = False
async def _process_context(self, context: OpenAILLMContext | LLMContext):
"""Process a context through the LLM and accumulate token usage metrics.
This method overrides the parent class implementation to handle NVIDIA's
incremental token reporting style, accumulating the counts and reporting
them once at the end of processing.
Args:
context: The context to process, containing messages and other information
needed for the LLM interaction.
"""
# Reset all counters and flags at the start of processing
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_tokens = 0
self._has_reported_prompt_tokens = False
self._is_processing = True
try:
await super()._process_context(context)
finally:
self._is_processing = False
# Report final accumulated token usage at the end of processing
if self._prompt_tokens > 0 or self._completion_tokens > 0:
self._total_tokens = self._prompt_tokens + self._completion_tokens
tokens = LLMTokenUsage(
prompt_tokens=self._prompt_tokens,
completion_tokens=self._completion_tokens,
total_tokens=self._total_tokens,
)
await super().start_llm_usage_metrics(tokens)
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
"""Accumulate token usage metrics during processing.
This method intercepts the incremental token updates from NVIDIA's API
and accumulates them instead of passing each update to the metrics system.
The final accumulated totals are reported at the end of processing.
Args:
tokens: The token usage metrics for the current chunk of processing,
containing prompt_tokens and completion_tokens counts.
"""
# Only accumulate metrics during active processing
if not self._is_processing:
return
# Record prompt tokens the first time we see them
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
self._prompt_tokens = tokens.prompt_tokens
self._has_reported_prompt_tokens = True
# Update completion tokens count if it has increased
if tokens.completion_tokens > self._completion_tokens:
self._completion_tokens = tokens.completion_tokens

View File

@@ -0,0 +1,663 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
import asyncio
from concurrent.futures import CancelledError as FuturesCancelledError
from typing import AsyncGenerator, List, Mapping, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.stt_service import SegmentedSTTService, STTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
import riva.client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[nvidia]`.")
raise Exception(f"Missing module: {e}")
def language_to_nvidia_riva_language(language: Language) -> Optional[str]:
"""Maps Language enum to NVIDIA Riva ASR language codes.
Source:
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
Args:
language: Language enum value.
Returns:
Optional[str]: NVIDIA Riva language code or None if not supported.
"""
LANGUAGE_MAP = {
# Arabic
Language.AR: "ar-AR",
# English
Language.EN: "en-US", # Default to US
Language.EN_US: "en-US",
Language.EN_GB: "en-GB",
# French
Language.FR: "fr-FR",
Language.FR_FR: "fr-FR",
# German
Language.DE: "de-DE",
Language.DE_DE: "de-DE",
# Hindi
Language.HI: "hi-IN",
Language.HI_IN: "hi-IN",
# Italian
Language.IT: "it-IT",
Language.IT_IT: "it-IT",
# Japanese
Language.JA: "ja-JP",
Language.JA_JP: "ja-JP",
# Korean
Language.KO: "ko-KR",
Language.KO_KR: "ko-KR",
# Portuguese
Language.PT: "pt-BR", # Default to Brazilian
Language.PT_BR: "pt-BR",
# Russian
Language.RU: "ru-RU",
Language.RU_RU: "ru-RU",
# Spanish
Language.ES: "es-ES", # Default to Spain
Language.ES_ES: "es-ES",
Language.ES_US: "es-US", # US Spanish
}
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
class NvidiaSTTService(STTService):
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
through streaming recognition. Supports interim results and continuous audio
processing for low-latency applications.
"""
class InputParams(BaseModel):
"""Configuration parameters for NVIDIA Riva STT service.
Parameters:
language: Target language for transcription. Defaults to EN_US.
"""
language: Optional[Language] = Language.EN_US
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
"model_name": "parakeet-ctc-1.1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the NVIDIA Riva STT service.
Args:
api_key: NVIDIA API key for authentication.
server: NVIDIA Riva server address. Defaults to NVIDIA Cloud Function endpoint.
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
params: Additional configuration parameters for NVIDIA Riva.
**kwargs: Additional arguments passed to STTService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or NvidiaSTTService.InputParams()
self._api_key = api_key
self._profanity_filter = False
self._automatic_punctuation = True
self._no_verbatim_transcripts = False
self._language_code = params.language
self._boosted_lm_words = None
self._boosted_lm_score = 4.0
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
self._function_id = model_function_map.get("function_id")
self._settings = {
"language": str(params.language),
"profanity_filter": self._profanity_filter,
"automatic_punctuation": self._automatic_punctuation,
"verbatim_transcripts": not self._no_verbatim_transcripts,
"boosted_lm_words": self._boosted_lm_words,
"boosted_lm_score": self._boosted_lm_score,
}
self.set_model_name(model_function_map.get("model_name"))
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._asr_service = riva.client.ASRService(auth)
self._queue = None
self._config = None
self._thread_task = None
self._response_task = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
False - this service does not support metrics generation.
"""
return False
async def set_model(self, model: str):
"""Set the ASR model for transcription.
Args:
model: Model name to set.
Note:
Model cannot be changed after initialization. Use model_function_map
parameter in constructor instead.
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
async def start(self, frame: StartFrame):
"""Start the NVIDIA Riva STT service and initialize streaming configuration.
Args:
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
if self._config:
return
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=self._language_code,
model="",
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=not self._no_verbatim_transcripts,
sample_rate_hertz=self.sample_rate,
audio_channel_count=1,
),
interim_results=True,
)
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
self._config = config
self._queue = asyncio.Queue()
if not self._thread_task:
self._thread_task = self.create_task(self._thread_task_handler())
if not self._response_task:
self._response_queue = asyncio.Queue()
self._response_task = self.create_task(self._response_task_handler())
async def stop(self, frame: EndFrame):
"""Stop the NVIDIA Riva STT service and clean up resources.
Args:
frame: EndFrame indicating pipeline stop.
"""
await super().stop(frame)
await self._stop_tasks()
async def cancel(self, frame: CancelFrame):
"""Cancel the NVIDIA Riva STT service operation.
Args:
frame: CancelFrame indicating operation cancellation.
"""
await super().cancel(frame)
await self._stop_tasks()
async def _stop_tasks(self):
if self._thread_task:
await self.cancel_task(self._thread_task)
self._thread_task = None
if self._response_task:
await self.cancel_task(self._response_task)
self._response_task = None
def _response_handler(self):
responses = self._asr_service.streaming_response_generator(
audio_chunks=self,
streaming_config=self._config,
)
for response in responses:
if not response.results:
continue
asyncio.run_coroutine_threadsafe(
self._response_queue.put(response), self.get_event_loop()
)
async def _thread_task_handler(self):
try:
self._thread_running = True
await asyncio.to_thread(self._response_handler)
except asyncio.CancelledError:
self._thread_running = False
raise
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
pass
async def _handle_response(self, response):
for result in response.results:
if result and not result.alternatives:
continue
transcript = result.alternatives[0].transcript
if transcript and len(transcript) > 0:
await self.stop_ttfb_metrics()
if result.is_final:
await self.stop_processing_metrics()
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language_code,
result=result,
)
)
await self._handle_transcription(
transcript=transcript,
is_final=result.is_final,
language=self._language_code,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language_code,
result=result,
)
)
async def _response_task_handler(self):
while True:
response = await self._response_queue.get()
await self._handle_response(response)
self._response_queue.task_done()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data for speech-to-text transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
None - transcription results are pushed to the pipeline via frames.
"""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._queue.put(audio)
yield None
def __next__(self) -> bytes:
"""Get the next audio chunk for NVIDIA Riva processing.
Returns:
Audio bytes from the queue.
Raises:
StopIteration: When the thread is no longer running.
"""
if not self._thread_running:
raise StopIteration
try:
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
return future.result()
except FuturesCancelledError:
raise StopIteration
def __iter__(self):
"""Return iterator for audio chunk processing.
Returns:
Self as iterator.
"""
return self
class NvidiaSegmentedSTTService(SegmentedSTTService):
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
transcription on audio segments. It inherits from SegmentedSTTService to handle
audio buffering and speech detection.
"""
class InputParams(BaseModel):
"""Configuration parameters for NVIDIA Riva segmented STT service.
Parameters:
language: Target language for transcription. Defaults to EN_US.
profanity_filter: Whether to filter profanity from results.
automatic_punctuation: Whether to add automatic punctuation.
verbatim_transcripts: Whether to return verbatim transcripts.
boosted_lm_words: List of words to boost in language model.
boosted_lm_score: Score boost for specified words.
"""
language: Optional[Language] = Language.EN_US
profanity_filter: bool = False
automatic_punctuation: bool = True
verbatim_transcripts: bool = False
boosted_lm_words: Optional[List[str]] = None
boosted_lm_score: float = 4.0
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
"model_name": "canary-1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the NVIDIA Riva segmented STT service.
Args:
api_key: NVIDIA API key for authentication
server: NVIDIA Riva server address (defaults to NVIDIA Cloud Function endpoint)
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
params: Additional configuration parameters for NVIDIA Riva
**kwargs: Additional arguments passed to SegmentedSTTService
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or NvidiaSegmentedSTTService.InputParams()
# Set model name
self.set_model_name(model_function_map.get("model_name"))
# Initialize NVIDIA Riva settings
self._api_key = api_key
self._server = server
self._function_id = model_function_map.get("function_id")
self._model_name = model_function_map.get("model_name")
# Store the language as a Language enum and as a string
self._language_enum = params.language or Language.EN_US
self._language = self.language_to_service_language(self._language_enum) or "en-US"
# Configure transcription parameters
self._profanity_filter = params.profanity_filter
self._automatic_punctuation = params.automatic_punctuation
self._verbatim_transcripts = params.verbatim_transcripts
self._boosted_lm_words = params.boosted_lm_words
self._boosted_lm_score = params.boosted_lm_score
# Voice activity detection thresholds (use NVIDIA Riva defaults)
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
# Create NVIDIA Riva client
self._config = None
self._asr_service = None
self._settings = {"language": self._language_enum}
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert pipecat Language enum to NVIDIA Riva's language code.
Args:
language: Language enum value.
Returns:
NVIDIA Riva language code or None if not supported.
"""
return language_to_nvidia_riva_language(language)
def _initialize_client(self):
"""Initialize the NVIDIA Riva ASR client with authentication metadata."""
if self._asr_service is not None:
return
# Set up authentication metadata for NVIDIA Cloud Functions
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {self._api_key}"],
]
# Create authenticated client
auth = riva.client.Auth(None, True, self._server, metadata)
self._asr_service = riva.client.ASRService(auth)
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
def _create_recognition_config(self):
"""Create the NVIDIA Riva ASR recognition configuration."""
# Create base configuration
config = riva.client.RecognitionConfig(
language_code=self._language, # Now using the string, not a tuple
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=self._verbatim_transcripts,
)
# Add word boosting if specified
if self._boosted_lm_words:
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
# Add voice activity detection parameters
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
# Add any custom configuration
if self._custom_configuration:
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
return config
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True - this service supports metrics generation.
"""
return True
async def set_model(self, model: str):
"""Set the ASR model for transcription.
Args:
model: Model name to set.
Note:
Model cannot be changed after initialization. Use model_function_map
parameter in constructor instead.
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
async def start(self, frame: StartFrame):
"""Initialize the service when the pipeline starts.
Args:
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
self._initialize_client()
self._config = self._create_recognition_config()
async def set_language(self, language: Language):
"""Set the language for the STT service.
Args:
language: Target language for transcription.
"""
logger.info(f"Switching STT language to: [{language}]")
self._language_enum = language
self._language = self.language_to_service_language(language) or "en-US"
self._settings["language"] = language
# Update configuration with new language
if self._config:
self._config.language_code = self._language
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
pass
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Transcribe an audio segment.
Args:
audio: Raw audio bytes in WAV format (already converted by base class).
Yields:
Frame: TranscriptionFrame containing the transcribed text.
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Make sure the client is initialized
if self._asr_service is None:
self._initialize_client()
# Make sure the config is created
if self._config is None:
self._config = self._create_recognition_config()
# Type assertion to satisfy the IDE
assert self._asr_service is not None, "ASR service not initialized"
assert self._config is not None, "Recognition config not created"
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
# Process the response - handle different possible return types
try:
# If it's a future-like object, get the result
if hasattr(raw_response, "result"):
response = raw_response.result()
else:
response = raw_response
# Process transcription results
transcription_found = False
# Now we can safely check results
# Type hint for the IDE
results = getattr(response, "results", [])
for result in results:
alternatives = getattr(result, "alternatives", [])
if alternatives:
text = alternatives[0].transcript.strip()
if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
self._language_enum,
)
transcription_found = True
await self._handle_transcription(text, True, self._language_enum)
if not transcription_found:
logger.debug("No transcription results found in NVIDIA Riva response")
except AttributeError as ae:
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")

View File

@@ -0,0 +1,187 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA Riva text-to-speech service implementation.
This module provides integration with NVIDIA Riva's TTS services through
gRPC API for high-quality speech synthesis.
"""
import asyncio
import os
from typing import AsyncGenerator, Mapping, Optional
from pipecat.utils.tracing.service_decorators import traced_tts
# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
try:
import riva.client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
raise Exception(f"Missing module: {e}")
NVIDIA_TTS_TIMEOUT_SECS = 5
class NvidiaTTSService(TTSService):
"""NVIDIA Riva text-to-speech service.
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
cloud-based TTS models. Supports multiple voices, languages, and
configurable quality settings.
"""
class InputParams(BaseModel):
"""Input parameters for Riva TTS configuration.
Parameters:
language: Language code for synthesis. Defaults to US English.
quality: Audio quality setting (0-100). Defaults to 20.
"""
language: Optional[Language] = Language.EN_US
quality: Optional[int] = 20
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
sample_rate: Optional[int] = None,
model_function_map: Mapping[str, str] = {
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
"model_name": "magpie-tts-multilingual",
},
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the NVIDIA Riva TTS service.
Args:
api_key: NVIDIA API key for authentication.
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
voice_id: Voice model identifier. Defaults to multilingual Ray voice.
sample_rate: Audio sample rate. If None, uses service default.
model_function_map: Dictionary containing function_id and model_name for the TTS model.
params: Additional configuration parameters for TTS synthesis.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or NvidiaTTSService.InputParams()
self._api_key = api_key
self._voice_id = voice_id
self._language_code = params.language
self._quality = params.quality
self._function_id = model_function_map.get("function_id")
self.set_model_name(model_function_map.get("model_name"))
self.set_voice(voice_id)
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)
# warm up the service
config_response = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
async def set_model(self, model: str):
"""Attempt to set the TTS model.
Note: Model cannot be changed after initialization for Riva service.
Args:
model: The model name to set (operation not supported).
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using NVIDIA Riva TTS.
Args:
text: The text to synthesize into speech.
Yields:
Frame: Audio frames containing the synthesized speech data.
"""
def read_audio_responses(queue: asyncio.Queue):
def add_response(r):
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
try:
responses = self._service.synthesize_online(
text,
self._voice_id,
self._language_code,
sample_rate_hz=self.sample_rate,
zero_shot_audio_prompt_file=None,
zero_shot_quality=self._quality,
custom_dictionary={},
)
for r in responses:
add_response(r)
add_response(None)
except Exception as e:
logger.error(f"{self} exception: {e}")
add_response(None)
await self.start_ttfb_metrics()
yield TTSStartedFrame()
logger.debug(f"{self}: Generating TTS [{text}]")
try:
queue = asyncio.Queue()
await asyncio.to_thread(read_audio_responses, queue)
# Wait for the thread to start.
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
while resp:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=resp.audio,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
except asyncio.TimeoutError:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()

View File

@@ -346,11 +346,17 @@ class BaseOpenAILLMService(LLMService):
if chunk.usage.prompt_tokens_details
else None
)
reasoning_tokens = (
chunk.usage.completion_tokens_details.reasoning_tokens
if chunk.usage.completion_tokens_details
else None
)
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
cache_read_input_tokens=cached_tokens,
reasoning_tokens=reasoning_tokens,
)
await self.start_llm_usage_metrics(tokens)

View File

@@ -76,7 +76,6 @@ class OpenAIImageGenService(ImageGenService):
image_url = image.data[0].url
if not image_url:
logger.error(f"{self} No image provided in response: {image}")
yield ErrorFrame("Image generation failed")
return

View File

@@ -57,7 +57,6 @@ from pipecat.processors.aggregators.openai_llm_context import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_openai_realtime, traced_stt
@@ -444,7 +443,7 @@ class OpenAIRealtimeLLMService(LLMService):
)
self._receive_task = self.create_task(self._receive_task_handler())
except Exception as e:
logger.error(f"{self} initialization error: {e}")
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
self._websocket = None
async def _disconnect(self):
@@ -461,7 +460,7 @@ class OpenAIRealtimeLLMService(LLMService):
self._completed_tool_calls = set()
self._disconnecting = False
except Exception as e:
logger.error(f"{self} error disconnecting: {e}")
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
async def _ws_send(self, realtime_message):
try:
@@ -474,12 +473,11 @@ class OpenAIRealtimeLLMService(LLMService):
# somehow *started* the websocket send attempt while we still
# had a connection)
return
logger.error(f"Error sending message to websocket: {e}")
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
# it is to recover from a send-side error with proper state management, and that exponential
# backoff for retries can have cost/stability implications for a service cluster, let's just
# treat a send-side error as fatal.
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}"))
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
async def _update_settings(self):
settings = self._session_properties
@@ -657,10 +655,17 @@ class OpenAIRealtimeLLMService(LLMService):
async def _handle_evt_response_done(self, evt):
# todo: figure out whether there's anything we need to do for "cancelled" events
# usage metrics
cached_tokens = (
evt.response.usage.input_token_details.cached_tokens
if hasattr(evt.response.usage, "input_token_details")
and evt.response.usage.input_token_details
else None
)
tokens = LLMTokenUsage(
prompt_tokens=evt.response.usage.input_tokens,
completion_tokens=evt.response.usage.output_tokens,
total_tokens=evt.response.usage.total_tokens,
cache_read_input_tokens=cached_tokens,
)
await self.start_llm_usage_metrics(tokens)
await self.stop_processing_metrics()
@@ -668,7 +673,7 @@ class OpenAIRealtimeLLMService(LLMService):
self._current_assistant_response = None
# error handling
if evt.response.status == "failed":
await self.push_error(ErrorFrame(error=evt.response.status_details["error"]["message"]))
await self.push_error(error_msg=evt.response.status_details["error"]["message"])
return
# response content
for item in evt.response.output:
@@ -760,7 +765,7 @@ class OpenAIRealtimeLLMService(LLMService):
async def _handle_evt_error(self, evt):
# Errors are fatal to this connection. Send an ErrorFrame.
await self.push_error(ErrorFrame(error=f"Error: {evt}"))
await self.push_error(error_msg=f"Error: {evt}")
#
# state and client events for the current conversation
@@ -810,7 +815,7 @@ class OpenAIRealtimeLLMService(LLMService):
# We're done configuring the LLM for this session
self._llm_needs_conversation_setup = False
logger.debug(f"Creating response")
logger.debug("Creating response")
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()

View File

@@ -206,5 +206,4 @@ class OpenAITTSService(TTSService):
yield frame
yield TTSStoppedFrame()
except BadRequestError as e:
logger.exception(f"{self} error generating TTS: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -79,5 +79,5 @@ class AzureRealtimeBetaLLMService(OpenAIRealtimeBetaLLMService):
)
self._receive_task = self.create_task(self._receive_task_handler())
except Exception as e:
logger.error(f"{self} initialization error: {e}")
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
self._websocket = None

View File

@@ -425,7 +425,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
)
self._receive_task = self.create_task(self._receive_task_handler())
except Exception as e:
logger.error(f"{self} initialization error: {e}")
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
self._websocket = None
async def _disconnect(self):
@@ -441,7 +441,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
self._receive_task = None
self._disconnecting = False
except Exception as e:
logger.error(f"{self} error disconnecting: {e}")
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
async def _ws_send(self, realtime_message):
try:
@@ -450,12 +450,11 @@ class OpenAIRealtimeBetaLLMService(LLMService):
except Exception as e:
if self._disconnecting:
return
logger.error(f"Error sending message to websocket: {e}")
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
# it is to recover from a send-side error with proper state management, and that exponential
# backoff for retries can have cost/stability implications for a service cluster, let's just
# treat a send-side error as fatal.
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}"))
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
async def _update_settings(self):
settings = self._session_properties
@@ -686,7 +685,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _handle_evt_error(self, evt):
# Errors are fatal to this connection. Send an ErrorFrame.
await self.push_error(ErrorFrame(error=f"Error: {evt}"))
await self.push_error(error_msg=f"Error: {evt}")
async def _handle_assistant_output(self, output):
# We haven't seen intermixed audio and function_call items in the same response. But let's

View File

@@ -88,9 +88,6 @@ class PiperTTSService(TTSService):
) as response:
if response.status != 200:
error = await response.text()
logger.error(
f"{self} error getting audio (status: {response.status}, error: {error})"
)
yield ErrorFrame(
error=f"Error getting audio (status: {response.status}, error: {error})"
)
@@ -109,7 +106,7 @@ class PiperTTSService(TTSService):
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
logger.debug(f"{self}: Finished TTS [{text}]")
await self.stop_ttfb_metrics()

View File

@@ -266,8 +266,7 @@ class PlayHTTTSService(InterruptibleTTSService):
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -280,8 +279,7 @@ class PlayHTTTSService(InterruptibleTTSService):
logger.debug("Disconnecting from PlayHT")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
finally:
self._request_id = None
self._websocket = None
@@ -351,8 +349,7 @@ class PlayHTTTSService(InterruptibleTTSService):
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
await self.push_error(error_msg=f"Error: {msg['error']}")
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
@@ -394,8 +391,7 @@ class PlayHTTTSService(InterruptibleTTSService):
await self._get_websocket().send(json.dumps(tts_command))
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
@@ -405,8 +401,7 @@ class PlayHTTTSService(InterruptibleTTSService):
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class PlayHTHttpTTSService(TTSService):
@@ -626,8 +621,7 @@ class PlayHTHttpTTSService(TTSService):
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

View File

@@ -300,8 +300,7 @@ class RimeTTSService(AudioContextWordTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -313,8 +312,7 @@ class RimeTTSService(AudioContextWordTTSService):
await self._websocket.send(json.dumps(self._build_eos_msg()))
await self._websocket.close()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
@@ -407,10 +405,9 @@ class RimeTTSService(AudioContextWordTTSService):
logger.debug(f"Updated cumulative time to: {self._cumulative_time}")
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(error=f"{self} error: {msg['message']}"))
await self.push_error(error_msg=f"Error: {msg['message']}")
self._context_id = None
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -452,16 +449,14 @@ class RimeTTSService(AudioContextWordTTSService):
await self._get_websocket().send(json.dumps(msg))
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class RimeHttpTTSService(TTSService):
@@ -592,7 +587,6 @@ class RimeHttpTTSService(TTSService):
) as response:
if response.status != 200:
error_message = f"Rime TTS error: HTTP {response.status}"
logger.error(error_message)
yield ErrorFrame(error=error_message)
return
@@ -610,8 +604,7 @@ class RimeHttpTTSService(TTSService):
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

View File

@@ -4,709 +4,32 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription.
import asyncio
from concurrent.futures import CancelledError as FuturesCancelledError
from typing import AsyncGenerator, List, Mapping, Optional
.. deprecated:: 0.0.96
This module is deprecated. Please NvidiaSTTService from
pipecat.services.nvidia.stt instead.
"""
from loguru import logger
from pydantic import BaseModel
import warnings
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
from pipecat.services.nvidia.stt import (
NvidiaSegmentedSTTService,
NvidiaSTTService,
language_to_nvidia_riva_language,
)
from pipecat.services.stt_service import SegmentedSTTService, STTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
import riva.client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.")
raise Exception(f"Missing module: {e}")
def language_to_riva_language(language: Language) -> Optional[str]:
"""Maps Language enum to Riva ASR language codes.
Source:
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
Args:
language: Language enum value.
Returns:
Optional[str]: Riva language code or None if not supported.
"""
LANGUAGE_MAP = {
# Arabic
Language.AR: "ar-AR",
# English
Language.EN: "en-US", # Default to US
Language.EN_US: "en-US",
Language.EN_GB: "en-GB",
# French
Language.FR: "fr-FR",
Language.FR_FR: "fr-FR",
# German
Language.DE: "de-DE",
Language.DE_DE: "de-DE",
# Hindi
Language.HI: "hi-IN",
Language.HI_IN: "hi-IN",
# Italian
Language.IT: "it-IT",
Language.IT_IT: "it-IT",
# Japanese
Language.JA: "ja-JP",
Language.JA_JP: "ja-JP",
# Korean
Language.KO: "ko-KR",
Language.KO_KR: "ko-KR",
# Portuguese
Language.PT: "pt-BR", # Default to Brazilian
Language.PT_BR: "pt-BR",
# Russian
Language.RU: "ru-RU",
Language.RU_RU: "ru-RU",
# Spanish
Language.ES: "es-ES", # Default to Spain
Language.ES_ES: "es-ES",
Language.ES_US: "es-US", # US Spanish
}
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
class RivaSTTService(STTService):
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
through streaming recognition. Supports interim results and continuous audio
processing for low-latency applications.
"""
class InputParams(BaseModel):
"""Configuration parameters for Riva STT service.
Parameters:
language: Target language for transcription. Defaults to EN_US.
"""
language: Optional[Language] = Language.EN_US
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
"model_name": "parakeet-ctc-1.1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the Riva STT service.
Args:
api_key: NVIDIA API key for authentication.
server: Riva server address. Defaults to NVIDIA Cloud Function endpoint.
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
params: Additional configuration parameters for Riva.
**kwargs: Additional arguments passed to STTService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or RivaSTTService.InputParams()
self._api_key = api_key
self._profanity_filter = False
self._automatic_punctuation = True
self._no_verbatim_transcripts = False
self._language_code = params.language
self._boosted_lm_words = None
self._boosted_lm_score = 4.0
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
self._function_id = model_function_map.get("function_id")
self._settings = {
"language": str(params.language),
"profanity_filter": self._profanity_filter,
"automatic_punctuation": self._automatic_punctuation,
"verbatim_transcripts": not self._no_verbatim_transcripts,
"boosted_lm_words": self._boosted_lm_words,
"boosted_lm_score": self._boosted_lm_score,
}
self.set_model_name(model_function_map.get("model_name"))
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._asr_service = riva.client.ASRService(auth)
self._queue = None
self._config = None
self._thread_task = None
self._response_task = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
False - this service does not support metrics generation.
"""
return False
async def set_model(self, model: str):
"""Set the ASR model for transcription.
Args:
model: Model name to set.
Note:
Model cannot be changed after initialization. Use model_function_map
parameter in constructor instead.
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
async def start(self, frame: StartFrame):
"""Start the Riva STT service and initialize streaming configuration.
Args:
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
if self._config:
return
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=self._language_code,
model="",
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=not self._no_verbatim_transcripts,
sample_rate_hertz=self.sample_rate,
audio_channel_count=1,
),
interim_results=True,
)
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
self._config = config
self._queue = asyncio.Queue()
if not self._thread_task:
self._thread_task = self.create_task(self._thread_task_handler())
if not self._response_task:
self._response_queue = asyncio.Queue()
self._response_task = self.create_task(self._response_task_handler())
async def stop(self, frame: EndFrame):
"""Stop the Riva STT service and clean up resources.
Args:
frame: EndFrame indicating pipeline stop.
"""
await super().stop(frame)
await self._stop_tasks()
async def cancel(self, frame: CancelFrame):
"""Cancel the Riva STT service operation.
Args:
frame: CancelFrame indicating operation cancellation.
"""
await super().cancel(frame)
await self._stop_tasks()
async def _stop_tasks(self):
if self._thread_task:
await self.cancel_task(self._thread_task)
self._thread_task = None
if self._response_task:
await self.cancel_task(self._response_task)
self._response_task = None
def _response_handler(self):
responses = self._asr_service.streaming_response_generator(
audio_chunks=self,
streaming_config=self._config,
)
for response in responses:
if not response.results:
continue
asyncio.run_coroutine_threadsafe(
self._response_queue.put(response), self.get_event_loop()
)
async def _thread_task_handler(self):
try:
self._thread_running = True
await asyncio.to_thread(self._response_handler)
except asyncio.CancelledError:
self._thread_running = False
raise
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
pass
async def _handle_response(self, response):
for result in response.results:
if result and not result.alternatives:
continue
transcript = result.alternatives[0].transcript
if transcript and len(transcript) > 0:
await self.stop_ttfb_metrics()
if result.is_final:
await self.stop_processing_metrics()
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language_code,
result=result,
)
)
await self._handle_transcription(
transcript=transcript,
is_final=result.is_final,
language=self._language_code,
)
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language_code,
result=result,
)
)
async def _response_task_handler(self):
while True:
response = await self._response_queue.get()
await self._handle_response(response)
self._response_queue.task_done()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data for speech-to-text transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
None - transcription results are pushed to the pipeline via frames.
"""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._queue.put(audio)
yield None
def __next__(self) -> bytes:
"""Get the next audio chunk for Riva processing.
Returns:
Audio bytes from the queue.
Raises:
StopIteration: When the thread is no longer running.
"""
if not self._thread_running:
raise StopIteration
try:
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
return future.result()
except FuturesCancelledError:
raise StopIteration
def __iter__(self):
"""Return iterator for audio chunk processing.
Returns:
Self as iterator.
"""
return self
class RivaSegmentedSTTService(SegmentedSTTService):
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
transcription on audio segments. It inherits from SegmentedSTTService to handle
audio buffering and speech detection.
"""
class InputParams(BaseModel):
"""Configuration parameters for Riva segmented STT service.
Parameters:
language: Target language for transcription. Defaults to EN_US.
profanity_filter: Whether to filter profanity from results.
automatic_punctuation: Whether to add automatic punctuation.
verbatim_transcripts: Whether to return verbatim transcripts.
boosted_lm_words: List of words to boost in language model.
boosted_lm_score: Score boost for specified words.
"""
language: Optional[Language] = Language.EN_US
profanity_filter: bool = False
automatic_punctuation: bool = True
verbatim_transcripts: bool = False
boosted_lm_words: Optional[List[str]] = None
boosted_lm_score: float = 4.0
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
"model_name": "canary-1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the Riva segmented STT service.
Args:
api_key: NVIDIA API key for authentication
server: Riva server address (defaults to NVIDIA Cloud Function endpoint)
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
params: Additional configuration parameters for Riva
**kwargs: Additional arguments passed to SegmentedSTTService
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or RivaSegmentedSTTService.InputParams()
# Set model name
self.set_model_name(model_function_map.get("model_name"))
# Initialize Riva settings
self._api_key = api_key
self._server = server
self._function_id = model_function_map.get("function_id")
self._model_name = model_function_map.get("model_name")
# Store the language as a Language enum and as a string
self._language_enum = params.language or Language.EN_US
self._language = self.language_to_service_language(self._language_enum) or "en-US"
# Configure transcription parameters
self._profanity_filter = params.profanity_filter
self._automatic_punctuation = params.automatic_punctuation
self._verbatim_transcripts = params.verbatim_transcripts
self._boosted_lm_words = params.boosted_lm_words
self._boosted_lm_score = params.boosted_lm_score
# Voice activity detection thresholds (use Riva defaults)
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
# Create Riva client
self._config = None
self._asr_service = None
self._settings = {"language": self._language_enum}
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert pipecat Language enum to Riva's language code.
Args:
language: Language enum value.
Returns:
Riva language code or None if not supported.
"""
return language_to_riva_language(language)
def _initialize_client(self):
"""Initialize the Riva ASR client with authentication metadata."""
if self._asr_service is not None:
return
# Set up authentication metadata for NVIDIA Cloud Functions
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {self._api_key}"],
]
# Create authenticated client
auth = riva.client.Auth(None, True, self._server, metadata)
self._asr_service = riva.client.ASRService(auth)
logger.info(f"Initialized RivaSegmentedSTTService with model: {self.model_name}")
def _create_recognition_config(self):
"""Create the Riva ASR recognition configuration."""
# Create base configuration
config = riva.client.RecognitionConfig(
language_code=self._language, # Now using the string, not a tuple
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=self._verbatim_transcripts,
)
# Add word boosting if specified
if self._boosted_lm_words:
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
# Add voice activity detection parameters
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
# Add any custom configuration
if self._custom_configuration:
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
return config
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True - this service supports metrics generation.
"""
return True
async def set_model(self, model: str):
"""Set the ASR model for transcription.
Args:
model: Model name to set.
Note:
Model cannot be changed after initialization. Use model_function_map
parameter in constructor instead.
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
async def start(self, frame: StartFrame):
"""Initialize the service when the pipeline starts.
Args:
frame: StartFrame indicating pipeline start.
"""
await super().start(frame)
self._initialize_client()
self._config = self._create_recognition_config()
async def set_language(self, language: Language):
"""Set the language for the STT service.
Args:
language: Target language for transcription.
"""
logger.info(f"Switching STT language to: [{language}]")
self._language_enum = language
self._language = self.language_to_service_language(language) or "en-US"
self._settings["language"] = language
# Update configuration with new language
if self._config:
self._config.language_code = self._language
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
pass
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Transcribe an audio segment.
Args:
audio: Raw audio bytes in WAV format (already converted by base class).
Yields:
Frame: TranscriptionFrame containing the transcribed text.
"""
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Make sure the client is initialized
if self._asr_service is None:
self._initialize_client()
# Make sure the config is created
if self._config is None:
self._config = self._create_recognition_config()
# Type assertion to satisfy the IDE
assert self._asr_service is not None, "ASR service not initialized"
assert self._config is not None, "Recognition config not created"
# Process audio with Riva ASR - explicitly request non-future response
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
# Process the response - handle different possible return types
try:
# If it's a future-like object, get the result
if hasattr(raw_response, "result"):
response = raw_response.result()
else:
response = raw_response
# Process transcription results
transcription_found = False
# Now we can safely check results
# Type hint for the IDE
results = getattr(response, "results", [])
for result in results:
alternatives = getattr(result, "alternatives", [])
if alternatives:
text = alternatives[0].transcript.strip()
if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text,
self._user_id,
time_now_iso8601(),
self._language_enum,
)
transcription_found = True
await self._handle_transcription(text, True, self._language_enum)
if not transcription_found:
logger.debug("No transcription results found in Riva response")
except AttributeError as ae:
logger.error(f"Unexpected response structure from Riva: {ae}")
yield ErrorFrame(f"Unexpected Riva response format: {str(ae)}")
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
class ParakeetSTTService(RivaSTTService):
"""Deprecated speech-to-text service using NVIDIA Parakeet models.
.. deprecated:: 0.0.66
This class is deprecated. Use `RivaSTTService` instead for equivalent functionality
with Parakeet models by specifying the appropriate model_function_map.
"""
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
model_function_map: Mapping[str, str] = {
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
"model_name": "parakeet-ctc-1.1b-asr",
},
sample_rate: Optional[int] = None,
params: Optional[RivaSTTService.InputParams] = None, # Use parent class's type
**kwargs,
):
"""Initialize the Parakeet STT service.
Args:
api_key: NVIDIA API key for authentication.
server: Riva server address. Defaults to NVIDIA Cloud Function endpoint.
model_function_map: Mapping containing 'function_id' and 'model_name' for Parakeet model.
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
params: Additional configuration parameters for Riva.
**kwargs: Additional arguments passed to RivaSTTService.
"""
super().__init__(
api_key=api_key,
server=server,
model_function_map=model_function_map,
sample_rate=sample_rate,
params=params,
**kwargs,
)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`ParakeetSTTService` is deprecated, use `RivaSTTService` instead.",
DeprecationWarning,
)
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"RivaSTTService and ParakeetSTTService "
"from pipecat.services.riva.stt is deprecated. "
"Please use NvidiaSTTService from pipecat.services.nvidia.stt instead.",
DeprecationWarning,
stacklevel=2,
)
RivaSTTService = NvidiaSTTService
language_to_riva_language = language_to_nvidia_riva_language
RivaSegmentedSTTService = NvidiaSegmentedSTTService
ParakeetSTTService = NvidiaSTTService

View File

@@ -8,232 +8,26 @@
This module provides integration with NVIDIA Riva's TTS services through
gRPC API for high-quality speech synthesis.
.. deprecated:: 0.0.96
This module is deprecated. Please NvidiaTTSService from
pipecat.services.nvidia.tts instead.
"""
import asyncio
import os
from typing import AsyncGenerator, Mapping, Optional
import warnings
from pipecat.utils.tracing.service_decorators import traced_tts
from pipecat.services.nvidia.tts import NVIDIA_TTS_TIMEOUT_SECS, NvidiaTTSService
# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"FastPitchTTSService and RivaTTSService "
"from pipecat.services.nim.llm are deprecated. "
"Please use NvidiaLLMService from pipecat.services.nvidia.tts instead.",
DeprecationWarning,
stacklevel=2,
)
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
try:
import riva.client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.")
raise Exception(f"Missing module: {e}")
RIVA_TTS_TIMEOUT_SECS = 5
class RivaTTSService(TTSService):
"""NVIDIA Riva text-to-speech service.
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
cloud-based TTS models. Supports multiple voices, languages, and
configurable quality settings.
"""
class InputParams(BaseModel):
"""Input parameters for Riva TTS configuration.
Parameters:
language: Language code for synthesis. Defaults to US English.
quality: Audio quality setting (0-100). Defaults to 20.
"""
language: Optional[Language] = Language.EN_US
quality: Optional[int] = 20
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
sample_rate: Optional[int] = None,
model_function_map: Mapping[str, str] = {
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
"model_name": "magpie-tts-multilingual",
},
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the NVIDIA Riva TTS service.
Args:
api_key: NVIDIA API key for authentication.
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
voice_id: Voice model identifier. Defaults to multilingual Ray voice.
sample_rate: Audio sample rate. If None, uses service default.
model_function_map: Dictionary containing function_id and model_name for the TTS model.
params: Additional configuration parameters for TTS synthesis.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or RivaTTSService.InputParams()
self._api_key = api_key
self._voice_id = voice_id
self._language_code = params.language
self._quality = params.quality
self._function_id = model_function_map.get("function_id")
self.set_model_name(model_function_map.get("model_name"))
self.set_voice(voice_id)
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)
# warm up the service
config_response = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
async def set_model(self, model: str):
"""Attempt to set the TTS model.
Note: Model cannot be changed after initialization for Riva service.
Args:
model: The model name to set (operation not supported).
"""
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
logger.warning(
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using NVIDIA Riva TTS.
Args:
text: The text to synthesize into speech.
Yields:
Frame: Audio frames containing the synthesized speech data.
"""
def read_audio_responses(queue: asyncio.Queue):
def add_response(r):
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
try:
responses = self._service.synthesize_online(
text,
self._voice_id,
self._language_code,
sample_rate_hz=self.sample_rate,
zero_shot_audio_prompt_file=None,
zero_shot_quality=self._quality,
custom_dictionary={},
)
for r in responses:
add_response(r)
add_response(None)
except Exception as e:
logger.error(f"{self} exception: {e}")
add_response(None)
await self.start_ttfb_metrics()
yield TTSStartedFrame()
logger.debug(f"{self}: Generating TTS [{text}]")
try:
queue = asyncio.Queue()
await asyncio.to_thread(read_audio_responses, queue)
# Wait for the thread to start.
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
while resp:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=resp.audio,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
except asyncio.TimeoutError:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()
class FastPitchTTSService(RivaTTSService):
"""Deprecated FastPitch TTS service.
.. deprecated:: 0.0.66
This class is deprecated. Use RivaTTSService instead for new implementations.
Provides backward compatibility for existing FastPitch TTS integrations.
"""
def __init__(
self,
*,
api_key: str,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: str = "English-US.Female-1",
sample_rate: Optional[int] = None,
model_function_map: Mapping[str, str] = {
"function_id": "0149dedb-2be8-4195-b9a0-e57e0e14f972",
"model_name": "fastpitch-hifigan-tts",
},
params: Optional[RivaTTSService.InputParams] = None,
**kwargs,
):
"""Initialize the deprecated FastPitch TTS service.
Args:
api_key: NVIDIA API key for authentication.
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
voice_id: Voice model identifier. Defaults to Female-1 voice.
sample_rate: Audio sample rate. If None, uses service default.
model_function_map: Dictionary containing function_id and model_name for FastPitch model.
params: Additional configuration parameters for TTS synthesis.
**kwargs: Additional arguments passed to parent RivaTTSService.
"""
super().__init__(
api_key=api_key,
server=server,
voice_id=voice_id,
sample_rate=sample_rate,
model_function_map=model_function_map,
params=params,
**kwargs,
)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`FastPitchTTSService` is deprecated, use `RivaTTSService` instead.",
DeprecationWarning,
)
RivaTTSService = NvidiaTTSService
FastPitchTTSService = NvidiaTTSService
RIVA_TTS_TIMEOUT_SECS = NVIDIA_TTS_TIMEOUT_SECS

View File

@@ -275,8 +275,7 @@ class SarvamSTTService(STTService):
await self._socket_client.translate(**method_kwargs)
except Exception as e:
logger.error(f"Error sending audio to Sarvam: {e}")
await self.push_error(ErrorFrame(f"Failed to send audio: {e}"))
yield ErrorFrame(error=f"Error sending audio to Sarvam: {e}", exception=e)
yield None
@@ -332,13 +331,11 @@ class SarvamSTTService(STTService):
logger.info("Connected to Sarvam successfully")
except ApiError as e:
logger.error(f"Sarvam API error: {e}")
await self.push_error(ErrorFrame(f"Sarvam API error: {e}"))
await self.push_error(error_msg=f"Sarvam API error: {e}", exception=e)
except Exception as e:
logger.error(f"Failed to connect to Sarvam: {e}")
self._socket_client = None
self._websocket_context = None
await self.push_error(ErrorFrame(f"Failed to connect to Sarvam: {e}"))
await self.push_error(error_msg=f"Failed to connect to Sarvam: {e}", exception=e)
async def _disconnect(self):
"""Disconnect from Sarvam WebSocket API using SDK."""
@@ -351,7 +348,9 @@ class SarvamSTTService(STTService):
# Exit the async context manager
await self._websocket_context.__aexit__(None, None, None)
except Exception as e:
logger.error(f"Error closing WebSocket connection: {e}")
await self.push_error(
error_msg=f"Error closing WebSocket connection: {e}", exception=e
)
finally:
logger.debug("Disconnected from Sarvam WebSocket")
self._socket_client = None
@@ -371,8 +370,7 @@ class SarvamSTTService(STTService):
# Messages will be handled via the _message_handler callback
await self._socket_client.start_listening()
except Exception as e:
logger.error(f"Error in Sarvam receive task: {e}")
await self.push_error(ErrorFrame(f"Sarvam receive task error: {e}"))
await self.push_error(error_msg=f"Sarvam receive task error: {e}", exception=e)
async def _handle_message(self, message):
"""Handle incoming WebSocket message from Sarvam SDK.
@@ -427,8 +425,7 @@ class SarvamSTTService(STTService):
await self.stop_processing_metrics()
except Exception as e:
logger.error(f"Error handling Sarvam message: {e}")
await self.push_error(ErrorFrame(f"Failed to handle message: {e}"))
await self.push_error(error_msg=f"Failed to handle message: {e}", exception=e)
await self.stop_all_metrics()
@traced_stt

View File

@@ -254,8 +254,7 @@ class SarvamHttpTTSService(TTSService):
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Sarvam API error: {error_text}")
await self.push_error(ErrorFrame(error=f"Sarvam API error: {error_text}"))
yield ErrorFrame(error=f"Sarvam API error: {error_text}")
return
response_data = await response.json()
@@ -264,8 +263,7 @@ class SarvamHttpTTSService(TTSService):
# Decode base64 audio data
if "audios" not in response_data or not response_data["audios"]:
logger.error("No audio data received from Sarvam API")
await self.push_error(ErrorFrame(error="No audio data received"))
yield ErrorFrame(error="No audio data received")
return
# Get the first audio (there should be only one for single text input)
@@ -286,8 +284,7 @@ class SarvamHttpTTSService(TTSService):
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
yield ErrorFrame(error=f"Error generating TTS: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()
@@ -517,9 +514,11 @@ class SarvamTTSService(InterruptibleTTSService):
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process a frame and flush audio if it's the end of a full response."""
if isinstance(frame, LLMFullResponseEndFrame):
await super().process_frame(frame, direction)
# When the LLM finishes responding, flush any remaining text in Sarvam's buffer
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
return await super().process_frame(frame, direction)
async def _update_settings(self, settings: Mapping[str, Any]):
"""Update service settings and reconnect if voice changed."""
@@ -560,8 +559,7 @@ class SarvamTTSService(InterruptibleTTSService):
await self._disconnect_websocket()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
# Reset state only after everything is cleaned up
self._started = False
@@ -585,8 +583,9 @@ class SarvamTTSService(InterruptibleTTSService):
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(
error_msg=f"Error connecting to Sarvam TTS Websocket: {e}", exception=e
)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
@@ -602,8 +601,7 @@ class SarvamTTSService(InterruptibleTTSService):
await self._websocket.send(json.dumps(config_message))
logger.debug("Configuration sent successfully")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
raise
async def _disconnect_websocket(self):
@@ -615,8 +613,7 @@ class SarvamTTSService(InterruptibleTTSService):
logger.debug("Disconnecting from Sarvam")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
finally:
self._started = False
self._websocket = None
@@ -640,7 +637,7 @@ class SarvamTTSService(InterruptibleTTSService):
await self.push_frame(frame)
elif msg.get("type") == "error":
error_msg = msg["data"]["message"]
logger.error(f"TTS Error: {error_msg}")
await self.push_error(error_msg=f"TTS Error: {error_msg}")
# If it's a timeout error, the connection might need to be reset
if "too long" in error_msg.lower() or "timeout" in error_msg.lower():
@@ -702,13 +699,11 @@ class SarvamTTSService(InterruptibleTTSService):
await self._send_text(text)
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -48,12 +48,14 @@ class SimliVideoService(FrameProcessor):
"""Input parameters for Simli video configuration.
Parameters:
enable_logging: Whether to enable Simli logging.
max_session_length: Absolute maximum session duration in seconds.
Avatar will disconnect after this time even if it's speaking.
max_idle_time: Maximum duration in seconds the avatar is not speaking
before the avatar disconnects.
"""
enable_logging: Optional[bool] = None
max_session_length: Optional[int] = None
max_idle_time: Optional[int] = None
@@ -154,6 +156,7 @@ class SimliVideoService(FrameProcessor):
config=config,
latencyInterval=latency_interval,
simliURL=simli_url,
enable_logging=params.enable_logging or False,
)
self._pipecat_resampler: AudioResampler = None
@@ -178,7 +181,7 @@ class SimliVideoService(FrameProcessor):
self._audio_task = self.create_task(self._consume_and_process_audio())
self._video_task = self.create_task(self._consume_and_process_video())
except Exception as e:
logger.error(f"{self}: unable to start connection: {e}")
await self.push_error(error_msg=f"Unable to start connection: {e}", exception=e)
async def _consume_and_process_audio(self):
"""Consume audio frames from Simli and push them downstream."""
@@ -256,7 +259,7 @@ class SimliVideoService(FrameProcessor):
await self._simli_client.send(audioBytes)
return
except Exception as e:
logger.exception(f"{self} exception: {e}")
await self.push_error(error_msg=f"Error sending audio: {e}", exception=e)
elif isinstance(frame, TTSStoppedFrame):
try:
if self._previously_interrupted and len(self._audio_buffer) > 0:
@@ -264,7 +267,7 @@ class SimliVideoService(FrameProcessor):
self._previously_interrupted = False
self._audio_buffer = bytearray()
except Exception as e:
logger.exception(f"{self} exception: {e}")
await self.push_error(error_msg=f"Error stopping TTS: {e}", exception=e)
return
elif isinstance(frame, (EndFrame, CancelFrame)):
await self._stop()

View File

@@ -194,7 +194,7 @@ class SonioxSTTService(STTService):
self._websocket = await websocket_connect(self._url)
if not self._websocket:
logger.error(f"Unable to connect to Soniox API at {self._url}")
await self.push_error(error_msg=f"Unable to connect to Soniox API at {self._url}")
# If vad_force_turn_endpoint is not enabled, we need to enable endpoint detection.
# Either one or the other is required.
@@ -327,8 +327,7 @@ class SonioxSTTService(STTService):
# Expected when closing the connection
logger.debug("WebSocket connection closed, keepalive task stopped.")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
async def _receive_task_handler(self):
if not self._websocket:
@@ -404,13 +403,8 @@ class SonioxSTTService(STTService):
if error_code or error_message:
# In case of error, still send the final transcript (if any remaining in the buffer).
await send_endpoint_transcript()
logger.error(
f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
)
await self.push_error(
ErrorFrame(
error=f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
)
error_msg=f"Error: {error_code} (_receive_task_handler) - {error_message}"
)
finished = content.get("finished")
@@ -425,5 +419,4 @@ class SonioxSTTService(STTService):
# Expected when closing the connection.
pass
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error receiving message: {e}", exception=e)

View File

@@ -467,8 +467,7 @@ class SpeechmaticsSTTService(STTService):
await self._client.send_audio(audio)
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
await self._disconnect()
def update_params(
@@ -514,8 +513,7 @@ class SpeechmaticsSTTService(STTService):
self._client.send_message(payload), self.get_event_loop()
)
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
raise RuntimeError(f"error sending message to STT: {e}")
async def _connect(self) -> None:
@@ -581,8 +579,7 @@ class SpeechmaticsSTTService(STTService):
logger.debug(f"{self} Connected to Speechmatics STT service")
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Error connecting to Speechmatics: {e}", exception=e)
self._client = None
async def _disconnect(self) -> None:
@@ -596,8 +593,9 @@ class SpeechmaticsSTTService(STTService):
except asyncio.TimeoutError:
logger.warning(f"{self} Timeout while closing Speechmatics client connection")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(
error_msg=f"Error disconnecting from Speechmatics: {e}", exception=e
)
finally:
self._client = None
await self._call_event_handler("on_disconnected")

View File

@@ -163,7 +163,7 @@ class SpeechmaticsTTSService(TTSService):
# Report error frame
yield ErrorFrame(
error=f"{self} Service unavailable [503] (attempt {attempt}, retry in {backoff_time:.2f}s)"
error=f"Service unavailable [503] (attempt {attempt}, retry in {backoff_time:.2f}s)"
)
# Wait before retrying
@@ -174,16 +174,13 @@ class SpeechmaticsTTSService(TTSService):
except (ValueError, ArithmeticError):
yield ErrorFrame(
error=f"{self} Service unavailable [503] (attempts {attempt})",
fatal=True,
error=f"Service unavailable [503] (attempts {attempt})",
)
return
# != 200 : Error
if response.status != 200:
yield ErrorFrame(
error=f"{self} Service unavailable [{response.status}]", fatal=True
)
yield ErrorFrame(error=f"Service unavailable [{response.status}]")
return
# Update Pipecat metrics
@@ -225,7 +222,7 @@ class SpeechmaticsTTSService(TTSService):
break
except Exception as e:
yield ErrorFrame(error=f"{self}: Error generating TTS: {e}", fatal=True)
yield ErrorFrame(error=f"Error generating TTS: {e}")
finally:
# Emit the TTS stopped frame
yield TTSStoppedFrame()

View File

@@ -329,4 +329,4 @@ class WebsocketSTTService(STTService, WebsocketService):
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error(error)
await self.push_error_frame(error)

View File

@@ -425,16 +425,13 @@ class TTSService(AIService):
# pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
pending_aggregation = self._text_aggregator.text
# Flush any remaining text (including text waiting for lookahead)
remaining = await self._text_aggregator.flush()
if remaining:
await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type))
# Reset aggregator state
await self._text_aggregator.reset()
self._processing_text = False
if pending_aggregation.text:
await self._push_tts_frames(
AggregatedTextFrame(pending_aggregation.text, pending_aggregation.type)
)
if isinstance(frame, LLMFullResponseEndFrame):
if self._push_text_frames:
await self.push_frame(frame, direction)
@@ -539,17 +536,20 @@ class TTSService(AIService):
text = frame.text
includes_inter_frame_spaces = frame.includes_inter_frame_spaces
aggregated_by = "token"
if text:
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
)
else:
aggregate = await self._text_aggregator.aggregate(frame.text)
if aggregate:
async for aggregate in self._text_aggregator.aggregate(frame.text):
text = aggregate.text
aggregated_by = aggregate.type
if text:
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
)
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
)
async def _push_tts_frames(
self, src_frame: AggregatedTextFrame, includes_inter_frame_spaces: Optional[bool] = False
@@ -781,7 +781,7 @@ class WebsocketTTSService(TTSService, WebsocketService):
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error(error)
await self.push_error_frame(error)
class InterruptibleTTSService(WebsocketTTSService):
@@ -843,7 +843,7 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error(error)
await self.push_error_frame(error)
class InterruptibleWordTTSService(WebsocketWordTTSService):

View File

@@ -246,8 +246,7 @@ class UltravoxSTTService(AIService):
logger.info("Model warm-up completed successfully")
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
def _generate_silent_audio(self, sample_rate=16000, duration_sec=1.0):
"""Generate silent audio as a numpy array.
@@ -377,7 +376,7 @@ class UltravoxSTTService(AIService):
if arr.size > 0: # Check if array is not empty
audio_arrays.append(arr)
except Exception as e:
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
# Handle numpy array data
elif isinstance(f.audio, np.ndarray):
if f.audio.size > 0: # Check if array is not empty
@@ -437,17 +436,11 @@ class UltravoxSTTService(AIService):
yield LLMFullResponseEndFrame()
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
else:
logger.error("No model available for text generation")
yield ErrorFrame("No model available for text generation")
except Exception as e:
logger.error(f"{self} exception: {e}")
import traceback
logger.error(traceback.format_exc())
yield ErrorFrame(f"Error processing audio: {str(e)}")
finally:
self._buffer.is_processing = False

View File

@@ -226,8 +226,7 @@ class BaseWhisperSTTService(SegmentedSTTService):
logger.warning("Received empty transcription from API")
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
async def _transcribe(self, audio: bytes) -> Transcription:
"""Transcribe audio data to text.

Some files were not shown because too many files have changed in this diff Show More