Compare commits
23 Commits
hush/conte
...
bot-output
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8640d84ae | ||
|
|
23e4e29999 | ||
|
|
713b488bb6 | ||
|
|
71b87fd420 | ||
|
|
3f269f9834 | ||
|
|
4c698777f3 | ||
|
|
5ca04ad741 | ||
|
|
9a3902a82c | ||
|
|
8ab0c92681 | ||
|
|
124f147a37 | ||
|
|
ed808a9246 | ||
|
|
e9de9daf8c | ||
|
|
82b9c4f0b6 | ||
|
|
5dfe20be91 | ||
|
|
0d2c5286fa | ||
|
|
29417ba44d | ||
|
|
bc6a9cac26 | ||
|
|
8a90decbc0 | ||
|
|
ccca6e8d81 | ||
|
|
e6dc1a510d | ||
|
|
69945c5e0d | ||
|
|
5c8635570d | ||
|
|
fe9aa3383e |
517
CHANGELOG.md
517
CHANGELOG.md
@@ -9,411 +9,94 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `wait_for_all` argument to the base `LLMService`. When enabled, this
|
||||
ensures all function calls complete before returning results to the LLM (i.e.,
|
||||
before running a new inference with those results).
|
||||
|
||||
### Changed
|
||||
|
||||
- Improved interruption handling to prevent bots from repeating themselves.
|
||||
LLM services that return multiple sentences in a single response (e.g.,
|
||||
`GoogleLLMService`) are now split into individual sentences before being sent
|
||||
to TTS. This ensures interruptions occur at sentence boundaries, preventing
|
||||
the bot from repeating content after being interrupted during long responses.
|
||||
|
||||
- Text Aggregation Improvements:
|
||||
|
||||
- **Breaking Change**: `BaseTextAggregator.aggregate()` now returns
|
||||
`AsyncIterator[Aggregation]` instead of `Optional[Aggregation]`. This
|
||||
enables the aggregator to return multiple results based on the provided
|
||||
text.
|
||||
- Refactored text aggregators to use inheritance: `SkipTagsAggregator` and
|
||||
`PatternPairAggregator` now inherit from `SimpleTextAggregator`, reusing
|
||||
the base class's sentence detection logic.
|
||||
|
||||
- Updated `AICFilter` to use Quail STT as the default model
|
||||
(`AICModelType.QUAIL_STT`). Quail STT is optimized for human-to-machine
|
||||
interaction (e.g., voice agents, speech-to-text) and operates at a native
|
||||
sample rate of 16 kHz with fixed enhancement parameters.
|
||||
|
||||
- Updated Deepgram logging to include Deepgram request IDs for improved debugging.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Package `pipecat.sync` is deprecated, use `pipecat.utils.sync` instead.
|
||||
|
||||
- The `noise_gate_enable` parameter in `AICFilter` is deprecated and no longer
|
||||
has any effect. Noise gating is now handled automatically by the AIC VAD
|
||||
system. Use `AICFilter.create_vad_analyzer()` for VAD functionality instead.
|
||||
|
||||
- NVIDIA Services name changes (all functionality is unchanged):
|
||||
|
||||
- `NimLLMService` is now deprecated, use `NvidiaLLMService` instead.
|
||||
- `RivaSTTService` is now deprecated, use `NvidiaSTTService` instead.
|
||||
- `RivaTTSService` is now deprecated, use `NvidiaTTSService` instead.
|
||||
- Use `uv pip install pipecat-ai[nvidia]` instead of
|
||||
`uv pip install pipecat-ai[riva]`
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where `LLMTextFrame.skip_tts` was being overwritten by LLM
|
||||
services.
|
||||
|
||||
- Fixed sentence aggregation to correctly handle ambiguous punctuation in
|
||||
streaming text, such as currency ("$29.95") and abbreviations ("Mr. Smith").
|
||||
|
||||
- Fixed bug in `PatternPairAggregator` where pattern handlers could be called
|
||||
multiple times for `KEEP` or `AGGREGATE` patterns.
|
||||
|
||||
- Fixed an issue in `SarvamTTSService` where the last sentence was not being
|
||||
spoken. Now, audio is flushed when the TTS services receives the
|
||||
`LLMFullResponseEndFrame` or `EndFrame`.
|
||||
|
||||
- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was
|
||||
always set to `us-east-1` when providing an AWS_REGION env var.
|
||||
|
||||
- Fixed an issue in `DeepgramTTSService` where a `TTSStoppedFrame` was
|
||||
incorrectly pushed after a functional call. This caused an issue with the
|
||||
voice-ui-kit's conversational panel rending of the LLM output after a
|
||||
function call.
|
||||
|
||||
## [0.0.96] - 2025-11-26 🦃 "Happy Thanksgiving!" 🦃
|
||||
|
||||
### Added
|
||||
|
||||
- Added `AWSBedrockAgentCoreProcessor` to support invoking an AgentCore-hosted
|
||||
agent in a Pipecat pipeline.
|
||||
|
||||
- Enhanced error handling across the framework:
|
||||
|
||||
- Added `on_error` callback to `FrameProcessor` for centralized error
|
||||
handling.
|
||||
|
||||
- Renamed `push_error(error: ErrorFrame)` to `push_error_frame(error: ErrorFrame)`
|
||||
for clarity.
|
||||
|
||||
- Added new `push_error` method for simplified error reporting:
|
||||
|
||||
```python
|
||||
async def push_error(error_msg: str,
|
||||
exception: Optional[Exception] = None,
|
||||
fatal: bool = False)
|
||||
```
|
||||
|
||||
- Standardized error logging by replacing `logger.exception` calls with
|
||||
`logger.error` throughout the codebase.
|
||||
|
||||
- Added `cache_read_input_tokens`, `cache_creation_input_tokens` and
|
||||
`reasoning_tokens` to OTel spans for LLM call
|
||||
|
||||
- Added `LiveKitRESTHelper` utility class for managing LiveKit rooms via REST API.
|
||||
|
||||
- Added `DeepgramSageMakerSTTService` which connects to a SageMaker hosted
|
||||
Deepgram STT model. Added `07c-interruptible-deepgram-sagemaker.py`
|
||||
foundational example.
|
||||
|
||||
- Added `SageMakerBidiClient` to connect to SageMaker hosted BiDi compatible
|
||||
services.
|
||||
|
||||
- Added support for `include_timestamps` and `enable_logging` in
|
||||
`ElevenLabsRealtimeSTTService`. When `include_timestamps` is enabled,
|
||||
timestamp data is included in the `TranscriptionFrame`'s `result`
|
||||
parameter.
|
||||
|
||||
- Added optional speaking rate control to `InworldTTSService`.
|
||||
|
||||
- Introduced a new `AggregatedTextFrame` type to support passing text along with
|
||||
an `aggregated_by` field to describe the type of text
|
||||
included. `TTSTextFrame`s now inherit from `AggregatedTextFrame`. With this
|
||||
inheritance, an observer can watch for `AggregatedTextFrame`s to accumlate the
|
||||
perceived output and determine whether or not the text was spoken based on if
|
||||
that frame is also a `TTSTextFrame`.
|
||||
|
||||
With this frame, the llm token stream can be transformed into custom
|
||||
composable chunks, allowing for aggregation outside the TTS service. This
|
||||
makes it possible to listen for or handle those aggregations and sets the
|
||||
stage for doing things like composing a best effort of the perceived llm
|
||||
output in a more digestable form and to do so whether or not it is processed
|
||||
by a TTS or if even a TTS exists.
|
||||
|
||||
- Introduced `LLMTextProcessor`: A new processor meant to allow customization
|
||||
for how LLMTextFrames should be aggregated and considered. It's purpose is to
|
||||
turn `LLMTextFrame`s into `AggregatedTextFrame`s. By default, a TTSService
|
||||
will still aggregate `LLMTextFrame`s by sentence for the service to
|
||||
consume. However, if you wish to override how the llm text is aggregated, you
|
||||
should no longer override the TTS's internal text_aggregator, but instead,
|
||||
insert this processor between your LLM and TTS in the pipeline.
|
||||
|
||||
- New `bot-output` RTVI message to represent what the bot actually "says".
|
||||
|
||||
- The `RTVIObserver` now emits `bot-output` messages based off the new
|
||||
`AggregatedTextFrame`s (`bot-tts-text` and `bot-llm-text` are still
|
||||
supported and generated, but `bot-transcript` is now deprecated in lieu of
|
||||
this new, more thorough, message).
|
||||
|
||||
- The new `RTVIBotOutputMessage` includes the fields:
|
||||
|
||||
- `spoken`: A boolean indicating whether the text was spoken by TTS
|
||||
|
||||
- `aggregated_by`: A string representing how the text was aggregated
|
||||
("sentence", "word", "my custom aggregation")
|
||||
|
||||
- Introduced new fields to `RTVIObserver` to support the new `bot-output`
|
||||
messaging:
|
||||
|
||||
- `bot_output_enabled`: Defaults to True. Set to false to disable bot-output
|
||||
messages.
|
||||
|
||||
- `skip_aggregator_types`: Defaults to `None`. Set to a list of strings that
|
||||
match aggregation types that should not be included in bot-output
|
||||
messages. (Ex. `credit_card`)
|
||||
|
||||
- Introduced new methods, `add_text_transformer()` and
|
||||
`remove_text_transformer()`, to `RTVIObserver` to support providing (and
|
||||
subsequently removing) callbacks for various types of aggregations (or all
|
||||
aggregations with `*`) that can modify the text before being sent as a
|
||||
`bot-output` or `tts-text` message. (Think obscuring the credit card or
|
||||
inserting extra detail the client might want that the context doesn't need.)
|
||||
|
||||
- In `MiniMaxHttpTTSService`:
|
||||
|
||||
- Added support for speech-2.6-hd and speech-2.6-turbo models
|
||||
|
||||
- Added languages: Afrikaans, Bulgarian, Catalan, Danish, Persian, Filipino,
|
||||
Hebrew, Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian,
|
||||
Swedish, and Tamil
|
||||
|
||||
- Added new emotions: calm and fluent
|
||||
|
||||
- Added `enable_logging` to `SimliVideoService` input parameters. It's disabled
|
||||
by default.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `FishAudioTTSService` default model to `s1`.
|
||||
|
||||
- Updated `DeepgramTTSService` to use Deepgram's TTS websocket API. ⚠️ This is
|
||||
a potential breaking change, which only affects you if you're self-hosting
|
||||
`DeepgramTTSService`. The new service uses Websockets and improves TTFB
|
||||
latency.
|
||||
|
||||
- Updated `daily-python` to 0.22.0.
|
||||
|
||||
- `BaseTextAggregator` changes:
|
||||
|
||||
Modified the BaseTextAggregator type so that when text gets aggregated,
|
||||
metadata can be associated with it. Currently, that just means a `type`, so
|
||||
that the aggregation can be classified or described. Changes made to support
|
||||
this:
|
||||
|
||||
- ⚠️ IMPORTANT: Aggregators are now expected to strip leading/trailing white
|
||||
space characters before returning their aggregation from `aggregation()` or
|
||||
`.text`. This way all aggregators have a consistent contract allowing
|
||||
downstream use to know how to stitch aggregations back together.
|
||||
|
||||
- Introduced a new `Aggregation` dataclass to represent both the aggregated
|
||||
`text` and a string identifying the `type` of aggregation (ex. "sentence",
|
||||
"word", "my custom aggregation")
|
||||
|
||||
- ⚠️ Breaking change: `BaseTextAggregator.text` now returns an `Aggregation`
|
||||
(instead of `str`).
|
||||
|
||||
Before:
|
||||
|
||||
```python
|
||||
aggregated_text = myAggregator.text
|
||||
```
|
||||
|
||||
Now:
|
||||
|
||||
```python
|
||||
aggregated_text = myAggregator.text.text
|
||||
```
|
||||
|
||||
- ⚠️ Breaking change: `BaseTextAggregator.aggregate()` now returns
|
||||
`Optional[Aggregation]` (instead of `Optional[str]`).
|
||||
|
||||
Before:
|
||||
|
||||
```python
|
||||
aggregation = myAggregator.aggregate(text)
|
||||
print(f"successfully aggregated text: {aggregation}")
|
||||
```
|
||||
|
||||
Now:
|
||||
|
||||
```python
|
||||
aggregation = myAggregator.aggregate(text)
|
||||
if aggregation:
|
||||
print(f"successfully aggregated text: {aggregation.text}")
|
||||
```
|
||||
|
||||
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator`
|
||||
updated to produce/consume `Aggregation` objects.
|
||||
|
||||
- All uses of the above Aggregators have been updated accordingly.
|
||||
|
||||
- Augmented the `PatternPairAggregator` so that matched patterns can be treated
|
||||
as their own aggregation, taking advantage of the new. To that end:
|
||||
|
||||
- Introduced a new, preferred version of `add_pattern` to support a new option
|
||||
for treating a match as a separate aggregation returned from
|
||||
`aggregate()`. This replaces the now deprecated `add_pattern_pair` method
|
||||
and you provide a `MatchAction` in lieu of the `remove_match` field.
|
||||
|
||||
- `MatchAction` enum: `REMOVE`, `KEEP`, `AGGREGATE`, allowing customization
|
||||
for how a match should be handled.
|
||||
|
||||
- `REMOVE`: The text along with its delimiters will be removed from the
|
||||
streaming text. Sentence aggregation will continue on as if this text
|
||||
did not exist.
|
||||
|
||||
- `KEEP`: The delimiters will be removed, but the content between them
|
||||
will be kept. Sentence aggregation will continue on with the internal
|
||||
text included.
|
||||
|
||||
- `AGGREGATE`: The delimiters will be removed and the content between will
|
||||
be treated as a separate aggregation. Any text before the start of the
|
||||
pattern will be returned early, whether or not a complete sentence was
|
||||
found. Then the pattern will be returned. Then the aggregation will
|
||||
continue on sentence matching after the closing delimiter is found. The
|
||||
content between the delimiters is not aggregated by sentence. It is
|
||||
aggregated as one single block of text.
|
||||
|
||||
- `PatternMatch` now extends `Aggregation` and provides richer info to
|
||||
handlers.
|
||||
|
||||
- ⚠️ Breaking change: The `PatternMatch` type returned to handlers registered
|
||||
via `on_pattern_match` has been updated to subclass from the new
|
||||
`Aggregation` type, which means that `content` has been replaced with
|
||||
`text` and `pattern_id` has been replaced with `type`:
|
||||
|
||||
```python
|
||||
async dev on_match_tag(match: PatternMatch):
|
||||
pattern = match.type # instead of match.pattern_id
|
||||
text = match.text # instead of match.content
|
||||
```
|
||||
|
||||
- `TextFrame` now includes the field `append_to_context` to support setting
|
||||
whether or not the encompassing text should be added to the LLM context (by
|
||||
the LLM assistant aggregator). It defaults to `True`.
|
||||
|
||||
- `TTSService` base class updates:
|
||||
|
||||
- `TTSService`s now accept a new `skip_aggregator_types` to avoid speaking
|
||||
certain aggregation types (now determined/returned by the aggregator)
|
||||
|
||||
- Introduced the ability to do a just-in-time transform of text before it gets
|
||||
sent to the TTS service via callbacks you can set up via a new init field,
|
||||
`text_transforms` or a new method `add_text_transformer()`. This makes it
|
||||
possible to do things like introduce TTS-specific tags for spelling or
|
||||
emotion or change the pronunciation of something on the
|
||||
fly. `remove_text_transformer` has also been added to support removing a
|
||||
registered transform callback.
|
||||
|
||||
- TTS services push `AggregatedTextFrame` in addition to `TTSTextFrame`s when
|
||||
either an aggregation occurs that should not be spoken or when the TTS
|
||||
service supports word-by-word timestamping. In the latter case, the
|
||||
`TTSService` preliminarily generates an `AggregatedTextFrame`, aggregated by
|
||||
sentence to generate the full sentence content as early as possible.
|
||||
|
||||
- Updated `CartesiaTTSService`:
|
||||
|
||||
- Modified use of custom default text_aggregator to avoid deprecation warnings
|
||||
and push users towards use of transformers or the `LLMTextProcessor`
|
||||
|
||||
- Added convenience methods for taking advantage of Cartesia's SSML tags:
|
||||
spell, emotion, pauses, volume, and speed.
|
||||
|
||||
- Updated `RimeTTSService`:
|
||||
|
||||
- Modified use of custom default text_aggregator to avoid deprecation warnings
|
||||
and push users towards use of transformers or the `LLMTextProcessor`
|
||||
|
||||
- Added convenience methods for taking advantage of Rime's customization
|
||||
options: spell, pauses, pronunciations, and inline speed control.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- The TTS constructor field, `text_aggregator` is deprecated in favor of the new
|
||||
`LLMTextProcessor`. TTSServices still have an internal aggregator for support
|
||||
of default behavior, but if you want to override the aggregation behavior, you
|
||||
should use the new processor.
|
||||
|
||||
- The RTVI `bot-transcription` event is deprecated in favor of the new
|
||||
`bot-output` message which is the canonical representation of bot output
|
||||
(spoken or not). The code still emits a transcription message for backwards
|
||||
compatibility while transition occurs.
|
||||
|
||||
- Deprecated `add_pattern_pair` in the `PatternPairAggregator` which takes a
|
||||
`pattern_id` and `remove_match` field in favor of the new `add_pattern` method
|
||||
which takes a `type` and an `action`
|
||||
|
||||
- `english_normalization` input parameter for `MiniMaxHttpTTSService` is
|
||||
deprecated, use `test_normalization` instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `AWSBedrockLLMService` where the `aws_region` arg was
|
||||
always set to `us-east-1` when providing an AWS_REGION env var.
|
||||
|
||||
- Fixed an issue with `DeepgramFluxSTTService` where it sometimes failed to reconnect.
|
||||
|
||||
- Fixed an issue in `ElevenLabsRealtimeSTTService` where dynamic language
|
||||
updates were not working.
|
||||
|
||||
- Fixed an issue in `ElevenLabsRealtimeSTTService` where setting the sample
|
||||
rate would result in transcripts failing.
|
||||
|
||||
- Fixed `InworldTTSService` audio config payload to use camelCase keys expected
|
||||
by the Inworld API.
|
||||
|
||||
## [0.0.95] - 2025-11-18
|
||||
|
||||
### Added
|
||||
|
||||
- Added ai-coustics integrated VAD (`AICVADAnalyzer`) with `AICFilter` factory and
|
||||
example wiring; leverages the enhancement model for robust detection with no
|
||||
ONNX dependency or added processing complexity.
|
||||
|
||||
- Added a watchdog to `DeepgramFluxSTTService` to prevent dangling tasks in case the
|
||||
user was speaking and we stop receiving audio.
|
||||
|
||||
- Introduced a minimum confidence parameter in `DeepgramFluxSTTService` to avoid
|
||||
generating transcriptions below a defined threshold.
|
||||
|
||||
- Added `ElevenLabsRealtimeSTTService` which implements the Realtime STT
|
||||
service from ElevenLabs.
|
||||
|
||||
- Added word-level timestamps support to Hume TTS service
|
||||
- Added a `TTSService.includes_inter_frame_spaces` property getter, so that TTS
|
||||
services that subclass `TTSService` can indicate whether the text in the
|
||||
`TTSTextFrame`s they push already contain any necessary inter-frame spaces.
|
||||
|
||||
- Introduced new `AggregatedTextFrame` type to support representing a best effort of
|
||||
the perceived llm output whether or not it is processed by the TTS. This new frame
|
||||
type includes the field `aggregated_by` to represent the conceptual format by which
|
||||
the given text is aggregated. `TTSTextFrame`s now inherit from `AggregatedTextFrame`.
|
||||
With this inheritance, an observer can watch for `AggregatedTextFrame`s to accumlate
|
||||
the perceived output and determine whether or not the text was spoken based on if that
|
||||
frame is also a `TTSTextFrame`. (See bullet below on new `bot-output` which takes
|
||||
advantage of this)
|
||||
|
||||
- Introduced `LLMTextProcessor`: A new processor meant to allow customization for how
|
||||
LLMTextFrames should be aggregated and considered. It's purpose is to turn
|
||||
`LLMTextFrame`s into `AggregatedTextFrame`s. By default, a TTSService will still
|
||||
aggregate `LLMTextFrame`s by sentence for the service to consume. However, if you
|
||||
wish to override how the llm text is aggregated, you should no longer override the
|
||||
TTS's internal aggregator, but instead, insert this processor between your LLM and
|
||||
TTS in the pipeline.
|
||||
|
||||
- New `bot-output` RTVI message to represent what the bot actually "says".
|
||||
- The `RTVIObserver` now emits `bot-output` messages based off the new `AggregatedTextFrame`s
|
||||
(`bot-tts-text` and `bot-llm-text` are still supported and generated, but `bot-transcript` is
|
||||
now deprecated in lieu of this new, more thorough, message).
|
||||
- The new `RTVIBotOutputMessage` includes the fields:
|
||||
- `spoken`: A boolean indicating whether the text was spoken by TTS
|
||||
- `aggregated_by`: A string representing how the text was aggregated ("sentence", "word",
|
||||
"my custom aggregation")
|
||||
- Introduced new fields to `RTVIObserver` to support the new `bot-output` messaging:
|
||||
- `bot_output_enabled`: Defaults to True. Set to false to disable bot-output messages.
|
||||
- `skip_aggregator_types`: Defaults to `None`. Set to a list of strings that match
|
||||
aggregation types that should not be included in bot-output messages. (Ex. `credit_card`)
|
||||
- Introduced new methods, `add_text_transformer()` and `remove_text_transformer()`, to `RTVIObserver` to support providing (and subsequently removing)
|
||||
callbacks for various types of aggregations (or all aggregations with `*`) that can modify the
|
||||
text before being sent as a `bot-output` or `tts-text` message. (Think obscuring the credit card
|
||||
or inserting extra detail the client might want that the context doesn't need.)
|
||||
|
||||
- Updated the base aggregator type:
|
||||
- Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
|
||||
a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom
|
||||
aggregation")
|
||||
- **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
|
||||
To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
|
||||
- **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
|
||||
(instead of `Optional[str]`). To update:
|
||||
```
|
||||
aggregation = myAggregator.aggregate(text)
|
||||
if (aggregation):
|
||||
print(f"successfully aggregated text: {aggregation.text}") // instead of {aggregation}
|
||||
```
|
||||
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
|
||||
produce/consume `Aggregation` objects.
|
||||
|
||||
- Augmented the `PatternPairAggregator`:
|
||||
- Introduced a new, preferred version of `add_pattern` to support a new option for treating a
|
||||
match as a separate aggregation returned from `aggregate()`. This replaces the now
|
||||
deprecated `add_pattern_pair` method and you provide a `MatchAction` in lieu of the `remove_match` field.
|
||||
- `MatchAction` enum: `REMOVE`, `KEEP`, `AGGREGATE`, allowing customization for how
|
||||
a match should be handled.
|
||||
- `REMOVE`: The text along with its delimiters will be removed from the streaming text.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
- `KEEP`: The delimiters will be removed, but the content between them will be kept.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
- `AGGREGATE`: The delimiters will be removed and the content between will be treated
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
- `PatternMatch` now extends `Aggregation` and provides richer info to handlers.
|
||||
- **BREAKING**: The `PatternMatch` type returned to handlers registered via `on_pattern_match`
|
||||
has been updated to subclass from the new `Aggregation` type, which means that `content`
|
||||
has been replaced with `text` and `pattern_id` has been replaced with `type`:
|
||||
```
|
||||
async dev on_match_tag(match: PatternMatch):
|
||||
pattern = match.type # instead of match.pattern_id
|
||||
text = match.text # instead of match.content
|
||||
```
|
||||
|
||||
### Changed
|
||||
|
||||
- ⚠️ Breaking change: `LLMContext.create_image_message()`,
|
||||
`LLMContext.create_audio_message()`, `LLMContext.add_image_frame_message()`
|
||||
and `LLMContext.add_audio_frames_message()` are now async methods. This fixes
|
||||
an issue where the asyncio event loop would be blocked while encoding audio or
|
||||
images.
|
||||
|
||||
- `ConsumerProcessor` now queues frames from the producer internally instead of
|
||||
pushing them directly. This allows us to subclass consumer processors and
|
||||
manipulate frames before they are pushed.
|
||||
|
||||
- `BaseTextFilter` only require subclasses to implement the `filter()` method.
|
||||
|
||||
- Extracted the logic for retrying connections, and create a new `send_with_retry`
|
||||
method inside `WebSocketService`.
|
||||
|
||||
- Refactored `DeepgramFluxSTTService` to automatically reconnect if sending a
|
||||
message fails.
|
||||
|
||||
- Updated all STT and TTS services to use consistent error handling pattern with
|
||||
`push_error()` method for better pipeline error event integration.
|
||||
|
||||
- Added support for `maybe_capture_participant_camera()` and
|
||||
`maybe_capture_participant_screen()` for `SmallWebRTCTransport` in the runner
|
||||
utils.
|
||||
|
||||
- Added Hindi support for Rime TTS services.
|
||||
|
||||
- Updated `GeminiTTSService` to use Google Cloud Text-to-Speech streaming API
|
||||
@@ -426,18 +109,44 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Updated language mappings for the Google and Gemini TTS services to match
|
||||
official documentation.
|
||||
|
||||
- `TextFrame` new field `append_to_context` used to indicate if the encompassing
|
||||
text should be added to the LLM context (by the LLM assistant aggregator). It
|
||||
defaults to `True`.
|
||||
|
||||
- TTS flow respects aggregation metadata
|
||||
- `TTSService` accepts a new `skip_aggregator_types` to avoid speaking certain aggregation types
|
||||
(now determined/returned by the aggregator)
|
||||
- TTS services push `AggregatedTextFrame` in addition to `TTSTextFrame`s when either an
|
||||
aggregation occurs that should not be spoken or when the TTS service supports word-by-word
|
||||
timestamping. In the latter case, the `TTSService` preliminarily generates an
|
||||
`AggregatedTextFrame`, aggregated by sentence to generate the full sentence content as early
|
||||
as possible.
|
||||
- Introduced a new methods, `add_text_transformer()` and `remove_text_transformer()`:
|
||||
These functions introduce the ability to provide (and subsequently remove) callbacks to the TTS to transform text based on
|
||||
its aggregated type prior to sending the text to the underlying TTS service. This makes it
|
||||
possible to do things like introduce TTS-specific tags for spelling or emotion or change the
|
||||
pronunciation of something on the fly.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- The `api_key` parameter in `GeminiTTSService` is deprecated. Use
|
||||
`credentials` or `credentials_path` instead for Google Cloud authentication.
|
||||
|
||||
- The RTVI `bot-transcription` event is deprecated in favor of the new `bot-output`
|
||||
message which is the canonical representation of bot output (spoken or not). The code
|
||||
still emits a transcription message for backwards compatibility while transition occurs.
|
||||
|
||||
- The TTS constructor field, `text_aggregator` is deprecated in favor of the new
|
||||
`LLMTextProcessor`. TTSServices still have an internal aggregator for support of default
|
||||
behavior, but if you want to override the aggregation behavior, you should use the new
|
||||
processor.
|
||||
|
||||
- Deprecated `add_pattern_pair` in the `PatternPairAggregator` which takes a `pattern_id`
|
||||
and `remove_match` field in favor of the new `add_pattern` method which takes a `type` and an
|
||||
`action`
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `SimliVideoService` connection issue.
|
||||
|
||||
- Fixed an issue in the `Runner` where, when using `SmallWebRTCTransport`, the
|
||||
`request_data` was not being passed to the `SmallWebRTCRunnerArguments` body.
|
||||
|
||||
- Fixed subtle issue of assistant context messages ending up with double spaces
|
||||
between words or sentences.
|
||||
|
||||
@@ -452,6 +161,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Prevented `HeyGenVideoService` from automatically disconnecting after 5 minutes.
|
||||
|
||||
### Added
|
||||
|
||||
- Added ai-coustics integrated VAD (`AICVADAnalyzer`) with `AICFilter` factory and
|
||||
example wiring; leverages the enhancement model for robust detection with no
|
||||
ONNX dependency or added processing complexity.
|
||||
|
||||
## [0.0.94] - 2025-11-10
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -79,7 +79,7 @@ Once your PR is submitted, post in the `#community-integrations` Discord channel
|
||||
|
||||
**Examples:**
|
||||
|
||||
- [NvidiaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/nvidia/stt.py)
|
||||
- [RivaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/riva/stt.py)
|
||||
- [FalSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/fal/stt.py)
|
||||
|
||||
#### Key requirements:
|
||||
|
||||
@@ -119,6 +119,7 @@ def import_core_modules():
|
||||
"pipecat.observers",
|
||||
"pipecat.runner",
|
||||
"pipecat.serializers",
|
||||
"pipecat.sync",
|
||||
"pipecat.transcriptions",
|
||||
"pipecat.utils",
|
||||
]
|
||||
|
||||
@@ -30,6 +30,7 @@ Quick Links
|
||||
Runner <api/pipecat.runner>
|
||||
Serializers <api/pipecat.serializers>
|
||||
Services <api/pipecat.services>
|
||||
Sync <api/pipecat.sync>
|
||||
Transcriptions <api/pipecat.transcriptions>
|
||||
Transports <api/pipecat.transports>
|
||||
Utils <api/pipecat.utils>
|
||||
Utils <api/pipecat.utils>
|
||||
@@ -44,7 +44,6 @@ DAILY_SAMPLE_ROOM_URL=https://...
|
||||
|
||||
# Deepgram
|
||||
DEEPGRAM_API_KEY=...
|
||||
SAGEMAKER_ENDPOINT_NAME=...
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=...
|
||||
|
||||
@@ -15,7 +15,7 @@ from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.nvidia.tts import NvidiaTTSService
|
||||
from pipecat.services.riva.tts import FastPitchTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -36,7 +36,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
task = PipelineTask(
|
||||
Pipeline([tts, transport.output()]),
|
||||
@@ -13,13 +13,12 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame, LLMRunFrame
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
@@ -31,44 +30,6 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
FILTERED_WORDS = ["apple", "banana", "car"]
|
||||
|
||||
|
||||
class ContentFilterProcessor(FrameProcessor):
|
||||
"""Processor that filters LLMContextFrames containing specific words.
|
||||
|
||||
If the user's message contains any of the filtered words, the context
|
||||
is replaced with a message indicating the assistant cannot respond.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
# Check the last user message for filtered words
|
||||
messages = frame.context.messages
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
content = last_message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
content_lower = content.lower()
|
||||
if any(word in content_lower for word in FILTERED_WORDS):
|
||||
logger.info(f"Filtered content detected: {content}")
|
||||
# Create a new context with a filtered response instruction
|
||||
filtered_context = LLMContext(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "The user is asking about something you cannot give an answer about. Tell them you don't know how to respond.",
|
||||
}
|
||||
]
|
||||
)
|
||||
await self.push_frame(LLMContextFrame(filtered_context), direction)
|
||||
return
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -115,14 +76,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
content_filter = ContentFilterProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
content_filter, # Content filter
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
|
||||
@@ -13,29 +13,24 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSTextFrame
|
||||
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.hume.tts import HUME_SAMPLE_RATE, HumeTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -93,7 +88,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS (HumeTTSService with word timestamps)
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
@@ -107,14 +102,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
audio_out_sample_rate=HUME_SAMPLE_RATE,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
observers=[
|
||||
RTVIObserver(rtvi),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
}
|
||||
),
|
||||
],
|
||||
observers=[RTVIObserver(rtvi)],
|
||||
)
|
||||
|
||||
@rtvi.event_handler("on_client_ready")
|
||||
@@ -124,9 +112,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
logger.info(
|
||||
"💡 Word timestamps are enabled! Watch the console for TTSTextFrame logs showing each word with its PTS."
|
||||
)
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@@ -52,10 +52,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramFluxSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY"),
|
||||
params=DeepgramFluxSTTService.InputParams(min_confidence=0.3),
|
||||
)
|
||||
stt = DeepgramFluxSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
|
||||
|
||||
|
||||
@@ -22,9 +22,9 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
from pipecat.services.nvidia.stt import NvidiaSTTService
|
||||
from pipecat.services.nvidia.tts import NvidiaTTSService
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.services.riva.stt import RivaSTTService
|
||||
from pipecat.services.riva.tts import RivaTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -59,13 +59,11 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = NvidiaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
stt = RivaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
llm = NvidiaLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
|
||||
)
|
||||
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct")
|
||||
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = await LLMContext.create_image_message(
|
||||
message = LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = await LLMContext.create_image_message(
|
||||
message = LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -117,7 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = await LLMContext.create_image_message(
|
||||
message = LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = await LLMContext.create_image_message(
|
||||
message = LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -15,21 +15,14 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
TextFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.frames.frames import LLMRunFrame, UserImageRequestFrame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
create_transport,
|
||||
@@ -73,27 +66,6 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
class MoondreamTextFrameWrapper(FrameProcessor):
|
||||
"""Wraps Moondream-provided TextFrames with LLM response start/end frames.
|
||||
|
||||
This processor detects TextFrames and automatically wraps them with
|
||||
LLMFullResponseStartFrame and LLMFullResponseEndFrame to provide proper
|
||||
response boundaries for downstream processors.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we receive a TextFrame, wrap it with response start/end frames
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_frame(LLMFullResponseStartFrame(), direction)
|
||||
await self.push_frame(frame, direction)
|
||||
await self.push_frame(LLMFullResponseEndFrame(), direction)
|
||||
else:
|
||||
# For all other frames, just pass them through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -158,12 +130,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# If you run into weird description, try with use_cpu=True
|
||||
moondream = MoondreamService()
|
||||
|
||||
# Wrap TextFrames with LLM response start/end frames, which makes Moondream
|
||||
# output be treated like LLM responses for the purpose of context
|
||||
# aggregation. Without this, the assistant context aggregator would ignore
|
||||
# Moondream output (if the TTS service is disabled).
|
||||
moondream_text_wrapper = MoondreamTextFrameWrapper()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
@@ -171,7 +137,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context_aggregator.user(), # User responses
|
||||
ParallelPipeline(
|
||||
[llm], # LLM
|
||||
[moondream, moondream_text_wrapper],
|
||||
[moondream],
|
||||
),
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -75,11 +75,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# text_filters=[MarkdownTextFilter()],
|
||||
)
|
||||
|
||||
llm = NvidiaLLMService(
|
||||
llm = NimLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"),
|
||||
model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
|
||||
# Recommended when turning thinking off
|
||||
params=NvidiaLLMService.InputParams(temperature=0.0),
|
||||
params=NimLLMService.InputParams(temperature=0.0),
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
@@ -14,13 +14,20 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame, TranscriptionMessage
|
||||
from pipecat.frames.frames import (
|
||||
LLMRunFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
TranscriptionMessage,
|
||||
)
|
||||
from pipecat.observers.loggers.transcription_log_observer import TranscriptionLogObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
|
||||
@@ -19,6 +19,7 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
|
||||
@@ -28,10 +28,10 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair, OpenAILLMService
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
@@ -45,11 +45,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -46,11 +46,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -47,11 +47,11 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
@@ -391,7 +391,7 @@ class AudioAccumulator(FrameProcessor):
|
||||
)
|
||||
self._user_speaking = False
|
||||
context = LLMContext()
|
||||
await context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self.push_frame(LLMContextFrame(context=context))
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
# Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest
|
||||
|
||||
@@ -150,7 +150,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
LLMLogObserver(),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.DESTINATION),
|
||||
UserStartedSpeakingFrame: (BaseInputTransport, FrameEndpoint.SOURCE),
|
||||
EndFrame: None,
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show a floral still life, use the `search_artwork` tool.
|
||||
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
|
||||
@@ -9,6 +9,7 @@ import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from mcp.client.session_group import SseServerParameters
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
@@ -22,16 +23,16 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.deepgram.stt_sagemaker import DeepgramSageMakerSTTService
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.mcp_service import MCPClient
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -60,42 +61,56 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# Initialize Deepgram SageMaker STT Service
|
||||
# This requires:
|
||||
# - AWS credentials configured (via environment variables or AWS CLI)
|
||||
# - A deployed SageMaker endpoint with Deepgram model
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name=os.getenv("SAGEMAKER_ENDPOINT_NAME"),
|
||||
region=os.getenv("AWS_REGION"),
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
|
||||
|
||||
llm = AWSBedrockLLMService(
|
||||
aws_region=os.getenv("AWS_REGION"),
|
||||
model="us.amazon.nova-pro-v1:0",
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8),
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-7-sonnet-latest"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
try:
|
||||
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
|
||||
mcp = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
context = LLMContext(messages)
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to a number of tools provided by mcp.run. Use any and all tools to help users.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
When asked for today's date, use 'https://www.datetoday.net/'.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
messages = [{"role": "system", "content": system}]
|
||||
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User responses
|
||||
stt,
|
||||
context_aggregator.user(), # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
@@ -110,9 +125,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
@@ -132,6 +146,14 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("MCP_RUN_SSE_URL"):
|
||||
logger.error(
|
||||
f"Please set MCP_RUN_SSE_URL environment variable for this example. See https://mcp.run"
|
||||
)
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -16,7 +15,7 @@ import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.client.session_group import StreamableHttpParameters
|
||||
from mcp.client.session_group import SseServerParameters
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
@@ -67,12 +66,10 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
|
||||
pattern = r"!\[[^\]]*\]\((https?://[^)]+\.(png|jpg|jpeg|PNG|JPG|JPEG|gif))\)"
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
@@ -135,11 +132,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to tools to search the Rijksmuseum collection and the user's GitHub repositories and account.
|
||||
Offer, for example, to show a floral still life, use the `search_artwork` tool.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
You can also offer to answer users questions about their GitHub repositories and account.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
@@ -149,11 +145,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages = [{"role": "system", "content": system}]
|
||||
|
||||
try:
|
||||
rijksmuseum_mcp = MCPClient(
|
||||
mcp = MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
args=["-y", "mcp-server-error setting up mcp"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
)
|
||||
@@ -161,32 +157,24 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.error(f"error setting up rijksmuseum mcp")
|
||||
logger.exception("error trace:")
|
||||
try:
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
github_mcp = MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={
|
||||
"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"
|
||||
},
|
||||
)
|
||||
)
|
||||
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
|
||||
# ie. "https://www.mcp.run/api/mcp/sse?..."
|
||||
# ensure the profile has a tool or few installed
|
||||
mcp_run = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp.run")
|
||||
logger.exception("error trace:")
|
||||
|
||||
rijksmuseum_tools = {}
|
||||
github_tools = {}
|
||||
tools = {}
|
||||
run_tools = {}
|
||||
try:
|
||||
rijksmuseum_tools = await rijksmuseum_mcp.register_tools(llm)
|
||||
github_tools = await github_mcp.register_tools(llm)
|
||||
tools = await mcp.register_tools(llm)
|
||||
run_tools = await mcp_run.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
all_standard_tools = rijksmuseum_tools.standard_tools + github_tools.standard_tools
|
||||
all_standard_tools = run_tools.standard_tools + tools.standard_tools
|
||||
all_tools = ToolsSchema(standard_tools=all_standard_tools)
|
||||
|
||||
context = LLMContext(messages, all_tools)
|
||||
@@ -238,9 +226,9 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("GITHUB_PERSONAL_ACCESS_TOKEN"):
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("MCP_RUN_SSE_URL"):
|
||||
logger.error(
|
||||
f"Please set `RIJKSMUSEUM_API_KEY` and `GITHUB_PERSONAL_ACCESS_TOKEN` environment variables. See https://github.com/r-huijts/rijksmuseum-mcp."
|
||||
f"Please set RIJKSMUSEUM_API_KEY and MCP_RUN_SSE_URL environment variables. See https://github.com/r-huijts/rijksmuseum-mcp and https://mcp.run"
|
||||
)
|
||||
import sys
|
||||
|
||||
@@ -45,18 +45,18 @@ Source = "https://github.com/pipecat-ai/pipecat"
|
||||
Website = "https://pipecat.ai"
|
||||
|
||||
[project.optional-dependencies]
|
||||
aic = [ "aic-sdk~=1.2.0" ]
|
||||
aic = [ "aic-sdk~=1.1.0" ]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "pipecat-ai[websockets-base]" ]
|
||||
asyncai = [ "pipecat-ai[websockets-base]" ]
|
||||
aws = [ "aioboto3~=15.5.0", "pipecat-ai[websockets-base]" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.2.0; python_version>='3.12'" ]
|
||||
aws = [ "aioboto3~=15.0.0", "pipecat-ai[websockets-base]" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.1.1; python_version>='3.12'" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
cerebras = []
|
||||
daily = [ "daily-python~=0.22.0" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0", "pipecat-ai[websockets-base]" ]
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.21.0" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
elevenlabs = [ "pipecat-ai[websockets-base]" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
@@ -69,38 +69,37 @@ gstreamer = [ "pygobject~=3.50.0" ]
|
||||
heygen = [ "livekit>=1.0.13", "pipecat-ai[websockets-base]" ]
|
||||
hume = [ "hume>=0.11.2" ]
|
||||
inworld = []
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
|
||||
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0", "pyjwt>=2.10.1" ]
|
||||
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "pipecat-ai[websockets-base]" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]
|
||||
mcp = [ "mcp[cli]>=1.11.0,<2" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
nim = []
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
nvidia = [ "nvidia-riva-client~=2.21.1" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe>=4.50.0,<6" ]
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "pipecat-ai[websockets-base]" ]
|
||||
qwen = []
|
||||
remote-smart-turn = []
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "pipecat-ai[nvidia]" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.122.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"]
|
||||
sambanova = []
|
||||
sarvam = [ "sarvamai==0.1.21", "pipecat-ai[websockets-base]" ]
|
||||
sentry = [ "sentry-sdk>=2.28.0,<3" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]
|
||||
remote-smart-turn = []
|
||||
silero = [ "onnxruntime>=1.20.1,<2" ]
|
||||
simli = [ "simli-ai~=1.0.3"]
|
||||
simli = [ "simli-ai~=0.1.25"]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soundfile = [ "soundfile~=0.13.1" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.5.0" ]
|
||||
|
||||
@@ -30,8 +30,8 @@ EVAL_SIMPLE_MATH = EvalConfig(
|
||||
)
|
||||
|
||||
EVAL_WEATHER = EvalConfig(
|
||||
prompt="What's the weather in San Francisco (in farhenheit or celsius)?",
|
||||
eval="The user says something specific about the current weather in San Francisco, including the degrees (in farhenheit or celsius).",
|
||||
prompt="What's the weather in San Francisco?",
|
||||
eval="The user says something specific about the current weather in San Francisco, including the degrees.",
|
||||
)
|
||||
|
||||
EVAL_ONLINE_SEARCH = EvalConfig(
|
||||
@@ -70,7 +70,7 @@ EVAL_VOICEMAIL = EvalConfig(
|
||||
|
||||
EVAL_CONVERSATION = EvalConfig(
|
||||
prompt="Hello, this is Mark.",
|
||||
eval="The user acknowledges the greeting.",
|
||||
eval="The user replies with a greeting.",
|
||||
eval_speaks_first=True,
|
||||
)
|
||||
|
||||
@@ -103,7 +103,7 @@ TESTS_07 = [
|
||||
("07o-interruptible-assemblyai.py", EVAL_SIMPLE_MATH),
|
||||
("07q-interruptible-rime.py", EVAL_SIMPLE_MATH),
|
||||
("07q-interruptible-rime-http.py", EVAL_SIMPLE_MATH),
|
||||
("07r-interruptible-nvidia.py", EVAL_SIMPLE_MATH),
|
||||
("07r-interruptible-riva-nim.py", EVAL_SIMPLE_MATH),
|
||||
("07s-interruptible-google-audio-in.py", EVAL_SIMPLE_MATH),
|
||||
("07t-interruptible-fish.py", EVAL_SIMPLE_MATH),
|
||||
("07v-interruptible-neuphonic.py", EVAL_SIMPLE_MATH),
|
||||
@@ -136,7 +136,7 @@ TESTS_14 = [
|
||||
("14g-function-calling-grok.py", EVAL_WEATHER),
|
||||
("14h-function-calling-azure.py", EVAL_WEATHER),
|
||||
("14i-function-calling-fireworks.py", EVAL_WEATHER),
|
||||
("14j-function-calling-nvidia.py", EVAL_WEATHER),
|
||||
("14j-function-calling-nim.py", EVAL_WEATHER),
|
||||
("14k-function-calling-cerebras.py", EVAL_WEATHER),
|
||||
("14m-function-calling-openrouter.py", EVAL_WEATHER),
|
||||
("14n-function-calling-perplexity.py", EVAL_WEATHER),
|
||||
|
||||
@@ -39,7 +39,7 @@ class AICFilter(BaseAudioFilter):
|
||||
self,
|
||||
*,
|
||||
license_key: str = "",
|
||||
model_type: AICModelType = AICModelType.QUAIL_STT,
|
||||
model_type: AICModelType = AICModelType.QUAIL_L,
|
||||
enhancement_level: Optional[float] = 1.0,
|
||||
voice_gain: Optional[float] = 1.0,
|
||||
noise_gate_enable: Optional[bool] = True,
|
||||
@@ -52,27 +52,12 @@ class AICFilter(BaseAudioFilter):
|
||||
enhancement_level: Optional overall enhancement strength (0.0..1.0).
|
||||
voice_gain: Optional linear gain applied to detected speech (0.0..4.0).
|
||||
noise_gate_enable: Optional enable/disable noise gate (default: True).
|
||||
|
||||
.. deprecated:: 1.3.0
|
||||
The `noise_gate_enable` parameter is deprecated and no longer has any effect.
|
||||
It will be removed in a future version.
|
||||
"""
|
||||
self._license_key = license_key
|
||||
self._model_type = model_type
|
||||
|
||||
self._enhancement_level = enhancement_level
|
||||
self._voice_gain = voice_gain
|
||||
if noise_gate_enable is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter `noise_gate_enable` is deprecated and no longer has any effect. "
|
||||
"It will be removed in a future version. Use AIC VAD instead (create_vad_analyzer()).",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._noise_gate_enable = noise_gate_enable
|
||||
|
||||
self._enabled = True
|
||||
@@ -164,6 +149,10 @@ class AICFilter(BaseAudioFilter):
|
||||
)
|
||||
if self._voice_gain is not None:
|
||||
self._aic.set_parameter(AICParameter.VOICE_GAIN, float(self._voice_gain))
|
||||
if self._noise_gate_enable is not None:
|
||||
self._aic.set_parameter(
|
||||
AICParameter.NOISE_GATE_ENABLE, 1.0 if bool(self._noise_gate_enable) else 0.0
|
||||
)
|
||||
|
||||
self._aic_ready = True
|
||||
|
||||
|
||||
@@ -18,10 +18,8 @@ from loguru import logger
|
||||
from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMTextFrame,
|
||||
OutputDTMFUrgentFrame,
|
||||
@@ -151,17 +149,10 @@ class IVRProcessor(FrameProcessor):
|
||||
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
# Process text through the pattern aggregator
|
||||
async for result in self._aggregator.aggregate(frame.text):
|
||||
result = await self._aggregator.aggregate(frame.text)
|
||||
if result:
|
||||
# Push aggregated text that doesn't contain XML patterns
|
||||
await self.push_frame(LLMTextFrame(result.text), direction)
|
||||
|
||||
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
# Flush any remaining text from the aggregator
|
||||
remaining = await self._aggregator.flush()
|
||||
if remaining:
|
||||
await self.push_frame(LLMTextFrame(remaining.text), direction)
|
||||
# Push the end frame
|
||||
await self.push_frame(frame, direction)
|
||||
await self.push_frame(LLMTextFrame(result), direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -40,8 +40,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
|
||||
|
||||
class NotifierGate(FrameProcessor):
|
||||
|
||||
@@ -330,7 +330,7 @@ class TextFrame(DataFrame):
|
||||
"""
|
||||
|
||||
text: str
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
skip_tts: bool = field(init=False)
|
||||
# Whether any necessary inter-frame (leading/trailing) spaces are already
|
||||
# included in the text.
|
||||
# NOTE: Ideally this would be available at init time with a default value,
|
||||
@@ -343,7 +343,7 @@ class TextFrame(DataFrame):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = None
|
||||
self.skip_tts = False
|
||||
self.includes_inter_frame_spaces = False
|
||||
self.append_to_context = True
|
||||
|
||||
@@ -356,10 +356,7 @@ class TextFrame(DataFrame):
|
||||
class LLMTextFrame(TextFrame):
|
||||
"""Text frame generated by LLM services."""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# LLM services send text frames with all necessary spaces included
|
||||
self.includes_inter_frame_spaces = True
|
||||
pass
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
@@ -835,13 +832,11 @@ class ErrorFrame(SystemFrame):
|
||||
error: Description of the error that occurred.
|
||||
fatal: Whether the error is fatal and requires bot shutdown.
|
||||
processor: The frame processor that generated the error.
|
||||
exception: The exception that occurred.
|
||||
"""
|
||||
|
||||
error: str
|
||||
fatal: bool = False
|
||||
processor: Optional["FrameProcessor"] = None
|
||||
exception: Optional[Exception] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(error: {self.error}, fatal: {self.fatal})"
|
||||
@@ -1632,22 +1627,22 @@ class LLMFullResponseStartFrame(ControlFrame):
|
||||
more TextFrames and a final LLMFullResponseEndFrame.
|
||||
"""
|
||||
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
skip_tts: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = None
|
||||
self.skip_tts = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMFullResponseEndFrame(ControlFrame):
|
||||
"""Frame indicating the end of an LLM response."""
|
||||
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
skip_tts: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = None
|
||||
self.skip_tts = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, LLMContextFrame, StartFrame
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class GatedLLMContextAggregator(FrameProcessor):
|
||||
|
||||
@@ -14,7 +14,6 @@ translation from this universal context into whatever format it needs, using a
|
||||
service-specific adapter.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import wave
|
||||
@@ -138,7 +137,7 @@ class LLMContext:
|
||||
return {"role": role, "content": content}
|
||||
|
||||
@staticmethod
|
||||
async def create_image_message(
|
||||
def create_image_message(
|
||||
*,
|
||||
role: str = "user",
|
||||
format: str,
|
||||
@@ -155,21 +154,15 @@ class LLMContext:
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
|
||||
def encode_image():
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
return encoded_image
|
||||
|
||||
encoded_image = await asyncio.to_thread(encode_image)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
url = f"data:image/jpeg;base64,{encoded_image}"
|
||||
|
||||
return LLMContext.create_image_url_message(role=role, url=url, text=text)
|
||||
|
||||
@staticmethod
|
||||
async def create_audio_message(
|
||||
def create_audio_message(
|
||||
*, role: str = "user", audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
) -> LLMContextMessage:
|
||||
"""Create a context message containing audio.
|
||||
@@ -179,26 +172,21 @@ class LLMContext:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
async def encode_audio():
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(data)
|
||||
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(data)
|
||||
|
||||
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
return encoded_audio
|
||||
|
||||
encoded_audio = await asyncio.to_thread(encode_audio)
|
||||
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
content.append(
|
||||
{
|
||||
@@ -333,7 +321,7 @@ class LLMContext:
|
||||
"""
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
async def add_image_frame_message(
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
|
||||
):
|
||||
"""Add a message containing an image frame.
|
||||
@@ -344,12 +332,10 @@ class LLMContext:
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
message = await LLMContext.create_image_message(
|
||||
format=format, size=size, image=image, text=text
|
||||
)
|
||||
message = LLMContext.create_image_message(format=format, size=size, image=image, text=text)
|
||||
self.add_message(message)
|
||||
|
||||
async def add_audio_frames_message(
|
||||
def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
):
|
||||
"""Add a message containing audio frames.
|
||||
@@ -358,7 +344,7 @@ class LLMContext:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
message = await LLMContext.create_audio_message(audio_frames=audio_frames, text=text)
|
||||
message = LLMContext.create_audio_message(audio_frames=audio_frames, text=text)
|
||||
self.add_message(message)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -66,7 +66,7 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.string import concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@@ -90,7 +90,15 @@ class LLMContextAggregator(FrameProcessor):
|
||||
self._context = context
|
||||
self._role = role
|
||||
|
||||
self._aggregation: List[TextPartForConcatenation] = []
|
||||
self._aggregation: List[str] = []
|
||||
|
||||
# Whether to add spaces between text parts.
|
||||
# (Currently only used by LLMAssistantAggregator, but could be expanded
|
||||
# to LLMUserAggregator in the future if needed; that would require
|
||||
# additional work since LLMUserAggregator currently trims spaces from
|
||||
# incoming frames before determining whether it "really" received any
|
||||
# text).
|
||||
self._add_spaces = True
|
||||
|
||||
@property
|
||||
def messages(self) -> List[LLMContextMessage]:
|
||||
@@ -183,7 +191,7 @@ class LLMContextAggregator(FrameProcessor):
|
||||
Returns:
|
||||
The concatenated aggregation string.
|
||||
"""
|
||||
return concatenate_aggregated_text(self._aggregation)
|
||||
return concatenate_aggregated_text(self._aggregation, self._add_spaces)
|
||||
|
||||
|
||||
class LLMUserAggregator(LLMContextAggregator):
|
||||
@@ -433,12 +441,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
# Transcriptions never include inter-part spaces (so far).
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
self._aggregation.append(text)
|
||||
# We just got a final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Reset aggregation timer.
|
||||
@@ -793,7 +796,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
logger.debug(f"{self} Appending UserImageRawFrame to LLM context (size: {frame.size})")
|
||||
|
||||
await self._context.add_image_frame_message(
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
@@ -818,11 +821,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if len(frame.text) == 0:
|
||||
return
|
||||
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
# Track whether we need to add spaces between text parts
|
||||
# Assumption: we can just keep track of the latest frame's value
|
||||
self._add_spaces = not frame.includes_inter_frame_spaces
|
||||
|
||||
self._aggregation.append(frame.text)
|
||||
|
||||
def _context_updated_task_finished(self, task: asyncio.Task):
|
||||
self._context_updated_tasks.discard(task)
|
||||
|
||||
@@ -83,7 +83,8 @@ class LLMTextProcessor(FrameProcessor):
|
||||
await self._text_aggregator.reset()
|
||||
|
||||
async def _handle_llm_text(self, in_frame: LLMTextFrame):
|
||||
async for aggregation in self._text_aggregator.aggregate(in_frame.text):
|
||||
aggregation = await self._text_aggregator.aggregate(in_frame.text)
|
||||
if aggregation:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=aggregation.text,
|
||||
aggregated_by=aggregation.type,
|
||||
@@ -91,13 +92,15 @@ class LLMTextProcessor(FrameProcessor):
|
||||
out_frame.skip_tts = in_frame.skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
async def _handle_llm_end(self, skip_tts: Optional[bool] = None):
|
||||
# Flush any remaining text
|
||||
remaining = await self._text_aggregator.flush()
|
||||
if remaining:
|
||||
async def _handle_llm_end(self, skip_tts: bool = False):
|
||||
# Flush any remaining aggregated text at the end of the LLM response
|
||||
aggregation = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
text = aggregation.text.strip()
|
||||
if text:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=remaining.text,
|
||||
aggregated_by=remaining.type,
|
||||
text=text,
|
||||
aggregated_by=aggregation.type,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
@@ -83,4 +83,4 @@ class ConsumerProcessor(FrameProcessor):
|
||||
while True:
|
||||
frame = await self._queue.get()
|
||||
new_frame = await self._transformer(frame)
|
||||
await self.queue_frame(new_frame, self._direction)
|
||||
await self.push_frame(new_frame, self._direction)
|
||||
|
||||
@@ -126,4 +126,6 @@ class WakeCheckFilter(FrameProcessor):
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error in wake word filter: {e}", exception=e)
|
||||
error_msg = f"Error in wake word filter: {e}"
|
||||
logger.exception(error_msg)
|
||||
await self.push_error(ErrorFrame(error_msg))
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Awaitable, Callable, Tuple, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class WakeNotifierFilter(FrameProcessor):
|
||||
|
||||
@@ -142,7 +142,6 @@ class FrameProcessor(BaseObject):
|
||||
- on_after_process_frame: Called after a frame is processed
|
||||
- on_before_push_frame: Called before a frame is pushed
|
||||
- on_after_push_frame: Called after a frame is pushed
|
||||
- on_error: Called when an error is raised in the frame processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -235,7 +234,6 @@ class FrameProcessor(BaseObject):
|
||||
self._register_event_handler("on_after_process_frame", sync=True)
|
||||
self._register_event_handler("on_before_push_frame", sync=True)
|
||||
self._register_event_handler("on_after_push_frame", sync=True)
|
||||
self._register_event_handler("on_error", sync=True)
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -632,43 +630,7 @@ class FrameProcessor(BaseObject):
|
||||
elif isinstance(frame, (FrameProcessorResumeFrame, FrameProcessorResumeUrgentFrame)):
|
||||
await self.__resume(frame)
|
||||
|
||||
async def push_error(
|
||||
self,
|
||||
error_msg: str,
|
||||
exception: Optional[Exception] = None,
|
||||
fatal: bool = False,
|
||||
):
|
||||
"""Creates and pushes an ErrorFrame upstream.
|
||||
|
||||
Creates and pushes an ErrorFrame upstream to notify other processors in the
|
||||
pipeline about an error condition. The error frame will include context about
|
||||
which processor generated the error.
|
||||
|
||||
Args:
|
||||
error_msg: Descriptive message explaining the error condition.
|
||||
exception: Optional exception object that caused the error, if available.
|
||||
This provides additional context for debugging and error handling.
|
||||
fatal: Whether this error should be considered fatal to the pipeline.
|
||||
Fatal errors typically cause the entire pipeline to stop processing.
|
||||
Defaults to False for non-fatal errors.
|
||||
|
||||
Example::
|
||||
|
||||
```python
|
||||
# Non-fatal error
|
||||
await self.push_error("Failed to process audio chunk, skipping")
|
||||
|
||||
# Fatal error with exception context
|
||||
try:
|
||||
result = some_critical_operation()
|
||||
except Exception as e:
|
||||
await self.push_error("Critical operation failed", exception=e, fatal=True)
|
||||
```
|
||||
"""
|
||||
error_frame = ErrorFrame(error=error_msg, fatal=fatal, exception=exception, processor=self)
|
||||
await self.push_error_frame(error=error_frame)
|
||||
|
||||
async def push_error_frame(self, error: ErrorFrame):
|
||||
async def push_error(self, error: ErrorFrame):
|
||||
"""Push an error frame upstream.
|
||||
|
||||
Args:
|
||||
@@ -676,8 +638,6 @@ class FrameProcessor(BaseObject):
|
||||
"""
|
||||
if not error.processor:
|
||||
error.processor = self
|
||||
await self._call_event_handler("on_error", error)
|
||||
logger.error(f"{error.processor} error: {error.error}")
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -799,10 +759,8 @@ class FrameProcessor(BaseObject):
|
||||
await self.__cancel_process_task()
|
||||
self.__create_process_task()
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Uncaught exception handling _start_interruption: {e}",
|
||||
exception=e,
|
||||
)
|
||||
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Internal method to push frames to adjacent processors.
|
||||
@@ -839,7 +797,8 @@ class FrameProcessor(BaseObject):
|
||||
await self._observer.on_push_frame(data)
|
||||
await self._prev.queue_frame(frame, direction)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Uncaught exception: {e}", exception=e)
|
||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
def _check_started(self, frame: Frame):
|
||||
"""Check if the processor has been started.
|
||||
@@ -915,7 +874,8 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self._call_event_handler("on_after_process_frame", frame)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error processing frame: {e}", exception=e)
|
||||
logger.exception(f"{self}: error processing frame: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
async def __input_frame_task_handler(self):
|
||||
"""Handle frames from the input queue.
|
||||
|
||||
@@ -24,7 +24,7 @@ try:
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables import Runnable
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error("In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. ")
|
||||
logger.exception("In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. ")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -113,6 +113,6 @@ class LangchainProcessor(FrameProcessor):
|
||||
except GeneratorExit:
|
||||
logger.warning(f"{self} generator was closed prematurely")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} an unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -23,7 +23,7 @@ try:
|
||||
from strands import Agent
|
||||
from strands.multiagent.graph import Graph
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error("In order to use Strands Agents, you need to `pip install strands-agents`.")
|
||||
logger.exception("In order to use Strands Agents, you need to `pip install strands-agents`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ class StrandsAgentsProcessor(FrameProcessor):
|
||||
except GeneratorExit:
|
||||
logger.warning(f"{self} generator was closed prematurely")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} an unknown error occurred: {e}")
|
||||
finally:
|
||||
if ttfb_tracking:
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -26,7 +26,7 @@ from pipecat.frames.frames import (
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.string import concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@@ -98,9 +98,15 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._current_text_parts: List[TextPartForConcatenation] = []
|
||||
self._current_text_parts: List[str] = []
|
||||
self._aggregation_start_time: Optional[str] = None
|
||||
|
||||
# Whether to add spaces between text parts.
|
||||
# (The use of this could be expanded to the UserTranscriptProcessor in
|
||||
# the future if needed; currently the UserTranscriptProcessor assumes
|
||||
# that user transcription frames do not need aggregation).
|
||||
self._add_spaces = True
|
||||
|
||||
async def _emit_aggregated_text(self):
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
@@ -141,7 +147,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
Result: "Hello there how are you"
|
||||
"""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
content = concatenate_aggregated_text(self._current_text_parts)
|
||||
content = concatenate_aggregated_text(self._current_text_parts, self._add_spaces)
|
||||
if content:
|
||||
logger.trace(f"Emitting aggregated assistant message: {content}")
|
||||
message = TranscriptionMessage(
|
||||
@@ -185,11 +191,11 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
if not self._aggregation_start_time:
|
||||
self._aggregation_start_time = time_now_iso8601()
|
||||
|
||||
self._current_text_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
# Track whether we need to add spaces between text parts
|
||||
# Assumption: we can just keep track of the latest frame's value
|
||||
self._add_spaces = not frame.includes_inter_frame_spaces
|
||||
|
||||
self._current_text_parts.append(frame.text)
|
||||
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -264,10 +264,7 @@ def _setup_webrtc_routes(
|
||||
# Prepare runner arguments with the callback to run your bot
|
||||
async def webrtc_connection_callback(connection):
|
||||
bot_module = _get_bot_module()
|
||||
|
||||
runner_args = SmallWebRTCRunnerArguments(
|
||||
webrtc_connection=connection, body=request.request_data
|
||||
)
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
# Delegate handling to SmallWebRTCRequestHandler
|
||||
@@ -302,7 +299,7 @@ def _setup_webrtc_routes(
|
||||
result: StartBotResult = {"sessionId": session_id}
|
||||
if request_data.get("enableDefaultIceServers"):
|
||||
result["iceConfig"] = IceConfig(
|
||||
iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])]
|
||||
iceServers=[IceServer(urls="stun:stun.l.google.com:19302")]
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -329,8 +326,7 @@ def _setup_webrtc_routes(
|
||||
type=request_data["type"],
|
||||
pc_id=request_data.get("pc_id"),
|
||||
restart_pc=request_data.get("restart_pc"),
|
||||
request_data=request_data.get("request_data")
|
||||
or request_data.get("requestData"),
|
||||
request_data=request_data,
|
||||
)
|
||||
return await offer(webrtc_request, background_tasks)
|
||||
elif request.method == HTTPMethod.PATCH.value:
|
||||
|
||||
@@ -281,14 +281,6 @@ async def maybe_capture_participant_camera(
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
|
||||
|
||||
if isinstance(transport, SmallWebRTCTransport):
|
||||
await transport.capture_participant_video(video_source="camera")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
async def maybe_capture_participant_screen(
|
||||
transport: BaseTransport, client: Any, framerate: int = 0
|
||||
@@ -311,14 +303,6 @@ async def maybe_capture_participant_screen(
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
|
||||
|
||||
if isinstance(transport, SmallWebRTCTransport):
|
||||
await transport.capture_participant_video(video_source="screenVideo")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _smallwebrtc_sdp_cleanup_ice_candidates(text: str, pattern: str) -> str:
|
||||
"""Clean up ICE candidates in SDP text for SmallWebRTC.
|
||||
|
||||
@@ -199,7 +199,7 @@ class PlivoFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to hang up Plivo call: {e}")
|
||||
logger.exception(f"Failed to hang up Plivo call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Plivo WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -225,7 +225,7 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to hang up Telnyx call: {e}")
|
||||
logger.exception(f"Failed to hang up Telnyx call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Telnyx WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -236,7 +236,7 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to hang up Twilio call: {e}")
|
||||
logger.exception(f"Failed to hang up Twilio call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Twilio WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -166,6 +166,6 @@ class AIService(FrameProcessor):
|
||||
async for f in generator:
|
||||
if f:
|
||||
if isinstance(f, ErrorFrame):
|
||||
await self.push_error_frame(f)
|
||||
await self.push_error(f)
|
||||
else:
|
||||
await self.push_frame(f)
|
||||
|
||||
@@ -373,7 +373,9 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
if event.type == "content_block_delta":
|
||||
if hasattr(event.delta, "text"):
|
||||
await self.push_frame(LLMTextFrame(event.delta.text))
|
||||
frame = LLMTextFrame(event.delta.text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
completion_tokens_estimate += self._estimate_tokens(event.delta.text)
|
||||
elif hasattr(event.delta, "partial_json") and tool_use_block:
|
||||
json_accumulator += event.delta.partial_json
|
||||
@@ -458,7 +460,8 @@ class AnthropicLLMService(LLMService):
|
||||
except httpx.TimeoutException:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"{e}"))
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -206,8 +206,9 @@ class AssemblyAISTTService(STTService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
self._connected = False
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -232,7 +233,8 @@ class AssemblyAISTTService(STTService):
|
||||
logger.warning("Timed out waiting for termination message from server")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
@@ -240,7 +242,8 @@ class AssemblyAISTTService(STTService):
|
||||
await self._websocket.close()
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
finally:
|
||||
self._websocket = None
|
||||
@@ -259,11 +262,13 @@ class AssemblyAISTTService(STTService):
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
break
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
|
||||
"""Parse a raw message into the appropriate message type."""
|
||||
@@ -292,7 +297,8 @@ class AssemblyAISTTService(STTService):
|
||||
elif isinstance(parsed_message, TerminationMessage):
|
||||
await self._handle_termination(parsed_message)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _handle_termination(self, message: TerminationMessage):
|
||||
"""Handle termination message."""
|
||||
|
||||
@@ -146,6 +146,15 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AsyncAI TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AsyncAI's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Async language format.
|
||||
|
||||
@@ -228,7 +237,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -240,7 +250,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Async")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._started = False
|
||||
@@ -285,11 +296,12 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
elif msg.get("error_code"):
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['message']}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['message']}"))
|
||||
else:
|
||||
await self.push_error(error_msg=f"Unknown message type: {msg}")
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic keepalive messages to maintain WebSocket connection."""
|
||||
@@ -332,14 +344,16 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class AsyncAIHttpTTSService(TTSService):
|
||||
@@ -419,6 +433,15 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AsyncAI TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AsyncAI's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Async language format.
|
||||
|
||||
@@ -472,7 +495,8 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
await self.push_error(error_msg=f"Async API error: {error_text}")
|
||||
logger.error(f"Async API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(error=f"Async API error: {error_text}"))
|
||||
raise Exception(f"Async API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
@@ -488,7 +512,8 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -8,10 +8,8 @@ import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .agent_core import *
|
||||
from .llm import *
|
||||
from .nova_sonic import *
|
||||
from .sagemaker import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""AWS AgentCore Processor Module.
|
||||
|
||||
This module defines the AWSAgentCoreProcessor, which invokes agents hosted on
|
||||
Amazon Bedrock AgentCore Runtime and streams their responses as LLMTextFrames.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import aioboto3
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
def default_context_to_payload_transformer(
|
||||
context: LLMContext | OpenAILLMContext,
|
||||
) -> Optional[str]:
|
||||
"""Default transformer to create AgentCore payload from LLM context.
|
||||
|
||||
Extracts the latest user or system message text and wraps it in {"prompt": "<text>"}.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation messages.
|
||||
|
||||
Returns:
|
||||
A JSON string payload for AgentCore, or None if no valid message found.
|
||||
"""
|
||||
messages = context.messages
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
if isinstance(last_message, LLMSpecificMessage) or last_message.get("role") not in (
|
||||
"user",
|
||||
"system",
|
||||
):
|
||||
return None
|
||||
|
||||
content = last_message.get("content")
|
||||
if not content:
|
||||
return None
|
||||
|
||||
if isinstance(content, str):
|
||||
prompt = content
|
||||
elif isinstance(content, list):
|
||||
prompt = " ".join([part.get("text", "") for part in content])
|
||||
else:
|
||||
return None
|
||||
|
||||
return json.dumps({"prompt": prompt})
|
||||
|
||||
|
||||
def default_response_to_output_transformer(response_line: str) -> Optional[str]:
|
||||
"""Default transformer to extract output text from AgentCore response.
|
||||
|
||||
Expects responses with {"response": "<text>"} format.
|
||||
|
||||
Args:
|
||||
response_line: The raw response line from AgentCore (without "data: " prefix).
|
||||
|
||||
Returns:
|
||||
The extracted output text, or None if no text found.
|
||||
"""
|
||||
response_json = json.loads(response_line)
|
||||
return response_json.get("response")
|
||||
|
||||
|
||||
class AWSAgentCoreProcessor(FrameProcessor):
|
||||
"""Processor that runs an Amazon Bedrock AgentCore agent.
|
||||
|
||||
Input:
|
||||
- LLMContextFrame: Supplies a context used to invoke the agent.
|
||||
|
||||
Output:
|
||||
- LLMTextFrame: The agent's text response(s).
|
||||
A single agent invocation may result in multiple text frames.
|
||||
|
||||
This processor transforms the input context to a payload for the AgentCore
|
||||
agent, and transforms the agent's response(s) into output text frame(s). Both
|
||||
mappings are configurable via transformers. Below is the default behavior.
|
||||
|
||||
Input transformer (context_to_payload_transformer):
|
||||
- Grabs the latest user or system message (if it's the latest message)
|
||||
- Extracts its text content
|
||||
- Constructs a payload that looks like {"prompt": "<text>"}
|
||||
|
||||
Output transformer (response_to_output_transformer):
|
||||
- Expects responses that look like {"response": "<text>"}
|
||||
- Extracts the text for use in the LLMTextFrame(s)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agentArn: str,
|
||||
aws_access_key: Optional[str] = None,
|
||||
aws_secret_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_region: Optional[str] = None,
|
||||
context_to_payload_transformer: Optional[
|
||||
Callable[[LLMContext | OpenAILLMContext], Optional[str]]
|
||||
] = None,
|
||||
response_to_output_transformer: Optional[Callable[[str], Optional[str]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the AWS AgentCore processor.
|
||||
|
||||
Args:
|
||||
agentArn: The Amazon Web Services Resource Name (ARN) of the agent.
|
||||
aws_access_key: AWS access key ID. If None, uses default credentials.
|
||||
aws_secret_key: AWS secret access key. If None, uses default credentials.
|
||||
aws_session_token: AWS session token for temporary credentials.
|
||||
aws_region: AWS region.
|
||||
context_to_payload_transformer: Optional callable to transform
|
||||
LLMContext into AgentCore payload string. If None, uses
|
||||
default_context_to_payload_transformer.
|
||||
response_to_output_transformer: Optional callable to extract output text
|
||||
from AgentCore response. If None, uses
|
||||
default_response_to_output_transformer.
|
||||
**kwargs: Additional arguments passed to parent FrameProcessor.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._agentArn = agentArn
|
||||
self._aws_session = aioboto3.Session()
|
||||
|
||||
# Store AWS session parameters for creating client in async context
|
||||
self._aws_params = {
|
||||
"aws_access_key_id": aws_access_key or os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": aws_secret_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
|
||||
"region_name": aws_region or os.getenv("AWS_REGION", "us-east-1"),
|
||||
}
|
||||
|
||||
# Set transformers with defaults
|
||||
self._context_to_payload_transformer = (
|
||||
context_to_payload_transformer or default_context_to_payload_transformer
|
||||
)
|
||||
self._response_to_output_transformer = (
|
||||
response_to_output_transformer or default_response_to_output_transformer
|
||||
)
|
||||
|
||||
# State for managing output response bookends
|
||||
self._output_response_open = False
|
||||
self._last_text_frame_time: Optional[float] = None
|
||||
self._close_task: Optional[asyncio.Task] = None
|
||||
self._output_response_timeout = 1.0 # seconds
|
||||
|
||||
async def _close_output_response_after_timeout(self):
|
||||
"""Close the output response after timeout if no new text frames arrive."""
|
||||
await asyncio.sleep(self._output_response_timeout)
|
||||
if self._output_response_open:
|
||||
self._output_response_open = False
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _push_text_frame(self, text: str):
|
||||
"""Push a text frame, managing output response bookends."""
|
||||
# Cancel any pending close task
|
||||
if self._close_task and not self._close_task.done():
|
||||
await self.cancel_task(self._close_task)
|
||||
|
||||
# Open output response if needed
|
||||
if not self._output_response_open:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
self._output_response_open = True
|
||||
|
||||
# Push the text frame
|
||||
await self.push_frame(LLMTextFrame(text))
|
||||
self._last_text_frame_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Schedule closing the output response after timeout
|
||||
self._close_task = self.create_task(self._close_output_response_after_timeout())
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames and handle LLM message frames.
|
||||
|
||||
Args:
|
||||
frame: The incoming frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
# Create payload to invoke AgentCore agent
|
||||
payload = self._context_to_payload_transformer(frame.context)
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
async with self._aws_session.client("bedrock-agentcore", **self._aws_params) as client:
|
||||
# Invoke the AgentCore agent
|
||||
response = await client.invoke_agent_runtime(
|
||||
agentRuntimeArn=self._agentArn, payload=payload.encode()
|
||||
)
|
||||
|
||||
# Determine if this is a streamed multi-part response, which
|
||||
# will affect our parsing
|
||||
is_multi_part_response = "text/event-stream" in response.get("contentType", "")
|
||||
|
||||
# Handle each response part (there may be one, for single
|
||||
# responses, or multiple, for streamed multi-part responses)
|
||||
async for part in response.get("response", []):
|
||||
part_string = part.decode("utf-8")
|
||||
|
||||
# In streamed multi-part responses, each part might have
|
||||
# one or more lines, each of which starts with "data: ".
|
||||
# Treat each line as a response.
|
||||
if is_multi_part_response:
|
||||
for line in part_string.split("\n"):
|
||||
# Get response text from this line
|
||||
if not line:
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
logger.warning(f"Expected line to start with 'data: ', got: {line}")
|
||||
continue
|
||||
line = line[6:] # omit "data: "
|
||||
|
||||
# Transform response line to output text
|
||||
text = self._response_to_output_transformer(line)
|
||||
if text:
|
||||
await self._push_text_frame(text)
|
||||
|
||||
# In single-part responses, the whole part is one response
|
||||
# and there's no "data: " prefix
|
||||
else:
|
||||
# Transform response part string to output text
|
||||
text = self._response_to_output_transformer(part_string)
|
||||
if text:
|
||||
await self._push_text_frame(text)
|
||||
|
||||
# Final close if output response is still open after all parts processed
|
||||
if self._output_response_open:
|
||||
if self._close_task and not self._close_task.done():
|
||||
await self.cancel_task(self._close_task)
|
||||
self._output_response_open = False
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -734,7 +734,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
aws_access_key: Optional[str] = None,
|
||||
aws_secret_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_region: Optional[str] = None,
|
||||
aws_region: str = "us-east-1",
|
||||
params: Optional[InputParams] = None,
|
||||
client_config: Optional[Config] = None,
|
||||
retry_timeout_secs: Optional[float] = 5.0,
|
||||
@@ -1078,7 +1078,9 @@ class AWSBedrockLLMService(LLMService):
|
||||
if "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
await self.push_frame(LLMTextFrame(delta["text"]))
|
||||
frame = LLMTextFrame(delta["text"])
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
completion_tokens_estimate += self._estimate_tokens(delta["text"])
|
||||
elif "toolUse" in delta and "input" in delta["toolUse"]:
|
||||
# Handle partial JSON for tool use
|
||||
@@ -1136,7 +1138,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
except (ReadTimeoutError, asyncio.TimeoutError):
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -453,7 +453,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._ready_to_send_context = True
|
||||
await self._finish_connecting_if_context_available()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Initialization error: {e}", exception=e)
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
async def _process_completed_function_calls(self, send_new_results: bool):
|
||||
@@ -577,7 +577,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
logger.info("Finished disconnecting")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
def _create_client(self) -> BedrockRuntimeClient:
|
||||
config = Config(
|
||||
@@ -885,7 +885,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# Errors are kind of expected while disconnecting, so just
|
||||
# ignore them and do nothing
|
||||
return
|
||||
await self.push_error(error_msg=f"Error processing responses: {e}", exception=e)
|
||||
logger.error(f"{self} error processing responses: {e}")
|
||||
if self._wants_connection:
|
||||
await self.reset_conversation()
|
||||
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""AWS SageMaker bidirectional streaming client.
|
||||
|
||||
This module provides a client for streaming bidirectional communication with
|
||||
SageMaker endpoints using the HTTP/2 protocol. Supports sending audio, text,
|
||||
and JSON data to SageMaker model endpoints and receiving streaming responses.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from aws_sdk_sagemaker_runtime_http2.client import SageMakerRuntimeHTTP2Client
|
||||
from aws_sdk_sagemaker_runtime_http2.config import Config, HTTPAuthSchemeResolver
|
||||
from aws_sdk_sagemaker_runtime_http2.models import (
|
||||
InvokeEndpointWithBidirectionalStreamInput,
|
||||
RequestPayloadPart,
|
||||
RequestStreamEventPayloadPart,
|
||||
ResponseStreamEvent,
|
||||
)
|
||||
from smithy_aws_core.auth.sigv4 import SigV4AuthScheme
|
||||
from smithy_aws_core.identity import EnvironmentCredentialsResolver
|
||||
from smithy_core.aio.eventstream import DuplexEventStream
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use SageMaker BiDi client, you need to `pip install pipecat-ai[sagemaker]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class SageMakerBidiClient:
|
||||
"""Client for bidirectional streaming with AWS SageMaker endpoints.
|
||||
|
||||
Handles low-level HTTP/2 bidirectional streaming protocol for communicating
|
||||
with SageMaker model endpoints. Provides methods for sending various data
|
||||
types (audio, text, JSON) and receiving streaming responses.
|
||||
|
||||
This client uses AWS SigV4 authentication and supports credential resolution
|
||||
from environment variables, AWS CLI configuration, and instance metadata.
|
||||
|
||||
Example::
|
||||
|
||||
client = SageMakerBidiClient(
|
||||
endpoint_name="my-deepgram-endpoint",
|
||||
region="us-east-2",
|
||||
model_invocation_path="v1/listen",
|
||||
model_query_string="model=nova-3&language=en"
|
||||
)
|
||||
await client.start_session()
|
||||
await client.send_audio_chunk(audio_bytes)
|
||||
response = await client.receive_response()
|
||||
await client.close_session()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
model_invocation_path: str = "",
|
||||
model_query_string: str = "",
|
||||
):
|
||||
"""Initialize the SageMaker BiDi client.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint to connect to.
|
||||
region: AWS region where the endpoint is deployed.
|
||||
model_invocation_path: API path for the model invocation (e.g., "v1/listen").
|
||||
model_query_string: Query string parameters for the model (e.g., "model=nova-3").
|
||||
"""
|
||||
self.endpoint_name = endpoint_name
|
||||
self.region = region
|
||||
self.model_invocation_path = model_invocation_path
|
||||
self.model_query_string = model_query_string
|
||||
self.bidi_endpoint = f"https://runtime.sagemaker.{region}.amazonaws.com:8443"
|
||||
self._client: Optional[SageMakerRuntimeHTTP2Client] = None
|
||||
self._stream: Optional[
|
||||
DuplexEventStream[RequestStreamEventPayloadPart, ResponseStreamEvent, any]
|
||||
] = None
|
||||
self._output_stream = None
|
||||
self._is_active = False
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the SageMaker Runtime HTTP2 client with AWS credentials.
|
||||
|
||||
Creates and configures the SageMaker Runtime HTTP2 client with SigV4
|
||||
authentication. Attempts to resolve AWS credentials from environment
|
||||
variables, AWS CLI configuration, or instance metadata.
|
||||
"""
|
||||
logger.debug(f"Initializing SageMaker BiDi client for region: {self.region}")
|
||||
logger.debug(f"Using endpoint URI: {self.bidi_endpoint}")
|
||||
|
||||
# Check for AWS credentials
|
||||
has_env_creds = bool(os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"))
|
||||
|
||||
if not has_env_creds:
|
||||
logger.warning(
|
||||
"AWS credentials not found in environment variables. "
|
||||
"Attempting to use EnvironmentCredentialsResolver which will check "
|
||||
"AWS CLI configuration and instance metadata."
|
||||
)
|
||||
|
||||
config = Config(
|
||||
endpoint_uri=self.bidi_endpoint,
|
||||
region=self.region,
|
||||
aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
|
||||
auth_scheme_resolver=HTTPAuthSchemeResolver(),
|
||||
auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="sagemaker")},
|
||||
)
|
||||
self._client = SageMakerRuntimeHTTP2Client(config=config)
|
||||
|
||||
async def start_session(self):
|
||||
"""Start a bidirectional streaming session with the SageMaker endpoint.
|
||||
|
||||
Initializes the client if needed, creates the bidirectional stream, and
|
||||
establishes the connection to the SageMaker endpoint. Must be called
|
||||
before sending or receiving data.
|
||||
|
||||
Returns:
|
||||
The output stream for receiving responses.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If client initialization or connection fails.
|
||||
"""
|
||||
if not self._client:
|
||||
self._initialize_client()
|
||||
|
||||
logger.debug(f"Starting BiDi session with endpoint: {self.endpoint_name}")
|
||||
logger.debug(f"Model invocation path: {self.model_invocation_path}")
|
||||
logger.debug(f"Model query string: {self.model_query_string}")
|
||||
|
||||
# Create the bidirectional stream
|
||||
stream_input = InvokeEndpointWithBidirectionalStreamInput(
|
||||
endpoint_name=self.endpoint_name,
|
||||
model_invocation_path=self.model_invocation_path,
|
||||
model_query_string=self.model_query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
self._stream = await self._client.invoke_endpoint_with_bidirectional_stream(
|
||||
stream_input
|
||||
)
|
||||
self._is_active = True
|
||||
|
||||
# Get output stream
|
||||
output = await self._stream.await_output()
|
||||
self._output_stream = output[1]
|
||||
|
||||
logger.debug("BiDi session started successfully")
|
||||
return self._output_stream
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start BiDi session: {e}")
|
||||
self._is_active = False
|
||||
raise RuntimeError(f"Failed to start SageMaker BiDi session: {e}")
|
||||
|
||||
async def send_data(self, data_bytes: bytes, data_type: Optional[str] = None):
|
||||
"""Send a chunk of data to the stream.
|
||||
|
||||
Generic method for sending any type of data to the SageMaker endpoint.
|
||||
Use the convenience methods (send_audio_chunk, send_text, send_json)
|
||||
for common data types.
|
||||
|
||||
Args:
|
||||
data_bytes: Raw bytes to send.
|
||||
data_type: Optional data type header. Common values are "BINARY" for
|
||||
audio/binary data and "UTF8" for text/JSON data.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active or send fails.
|
||||
"""
|
||||
if not self._is_active or not self._stream:
|
||||
raise RuntimeError("BiDi session not active")
|
||||
|
||||
try:
|
||||
payload = RequestPayloadPart(bytes_=data_bytes, data_type=data_type)
|
||||
event = RequestStreamEventPayloadPart(value=payload)
|
||||
await self._stream.input_stream.send(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send data: {e}")
|
||||
raise
|
||||
|
||||
async def send_audio_chunk(self, audio_bytes: bytes):
|
||||
"""Send a chunk of audio data to the stream.
|
||||
|
||||
Convenience method for sending audio data. Automatically sets the data
|
||||
type to "BINARY".
|
||||
|
||||
Args:
|
||||
audio_bytes: Raw audio bytes to send (e.g., PCM audio data).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active or send fails.
|
||||
"""
|
||||
await self.send_data(audio_bytes, data_type="BINARY")
|
||||
|
||||
async def send_text(self, text: str):
|
||||
"""Send text data to the stream.
|
||||
|
||||
Convenience method for sending text data. Automatically encodes the text
|
||||
as UTF-8 and sets the data type to "UTF8".
|
||||
|
||||
Args:
|
||||
text: Text string to send.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active or send fails.
|
||||
"""
|
||||
await self.send_data(text.encode("utf-8"), data_type="UTF8")
|
||||
|
||||
async def send_json(self, data: dict):
|
||||
"""Send JSON data to the stream.
|
||||
|
||||
Convenience method for sending JSON-encoded messages. Useful for control
|
||||
messages like KeepAlive or CloseStream. Automatically serializes the
|
||||
dictionary to JSON, encodes as UTF-8, and sets the data type to "UTF8".
|
||||
|
||||
Args:
|
||||
data: Dictionary to send as JSON (e.g., {"type": "KeepAlive"}).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active or send fails.
|
||||
"""
|
||||
import json
|
||||
|
||||
await self.send_data(json.dumps(data).encode("utf-8"), data_type="UTF8")
|
||||
|
||||
async def receive_response(self) -> Optional[ResponseStreamEvent]:
|
||||
"""Receive a response from the stream.
|
||||
|
||||
Blocks until a response is available from the SageMaker endpoint. Returns
|
||||
None when the stream is closed.
|
||||
|
||||
Returns:
|
||||
The response event containing payload data, or None if stream is closed.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active.
|
||||
"""
|
||||
if not self._is_active or not self._output_stream:
|
||||
raise RuntimeError("BiDi session not active")
|
||||
|
||||
try:
|
||||
result = await self._output_stream.receive()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to receive response: {e}")
|
||||
raise
|
||||
|
||||
async def close_session(self):
|
||||
"""Close the bidirectional streaming session.
|
||||
|
||||
Gracefully closes the input stream and marks the session as inactive.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if not self._is_active:
|
||||
return
|
||||
|
||||
logger.debug("Closing BiDi session...")
|
||||
self._is_active = False
|
||||
|
||||
try:
|
||||
if self._stream:
|
||||
await self._stream.input_stream.close()
|
||||
logger.debug("BiDi session closed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing BiDi session: {e}")
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if the session is currently active.
|
||||
|
||||
Returns:
|
||||
True if session is active, False otherwise.
|
||||
"""
|
||||
return self._is_active
|
||||
@@ -58,7 +58,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
api_key: Optional[str] = None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
region: Optional[str] = "us-east-1",
|
||||
sample_rate: int = 16000,
|
||||
language: Language = Language.EN,
|
||||
**kwargs,
|
||||
@@ -69,7 +69,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
api_key: AWS secret access key. If None, uses AWS_SECRET_ACCESS_KEY environment variable.
|
||||
aws_access_key_id: AWS access key ID. If None, uses AWS_ACCESS_KEY_ID environment variable.
|
||||
aws_session_token: AWS session token for temporary credentials. If None, uses AWS_SESSION_TOKEN environment variable.
|
||||
region: AWS region for the service.
|
||||
region: AWS region for the service. Defaults to "us-east-1".
|
||||
sample_rate: Audio sample rate in Hz. Must be 8000 or 16000. Defaults to 16000.
|
||||
language: Language for transcription. Defaults to English.
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
@@ -140,7 +140,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
return
|
||||
logger.warning("WebSocket connection not established after connect")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
await asyncio.sleep(1) # Wait before retrying
|
||||
@@ -181,7 +182,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
try:
|
||||
await self._connect()
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
return
|
||||
|
||||
# Format the audio data according to AWS event stream format
|
||||
@@ -198,11 +200,13 @@ class AWSTranscribeSTTService(STTService):
|
||||
await self._disconnect()
|
||||
# Don't yield error here - we'll retry on next frame
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
@@ -285,7 +289,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
await self._disconnect()
|
||||
raise
|
||||
|
||||
@@ -305,7 +310,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
await self._ws_client.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._ws_client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -523,15 +529,15 @@ class AWSTranscribeSTTService(STTService):
|
||||
)
|
||||
elif headers.get(":message-type") == "exception":
|
||||
error_msg = payload.get("Message", "Unknown error")
|
||||
await self.push_error(error_msg=f"AWS Transcribe error: {error_msg}")
|
||||
logger.error(f"{self} Exception from AWS: {error_msg}")
|
||||
await self.push_frame(ErrorFrame(f"AWS Transcribe error: {error_msg}"))
|
||||
else:
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
await self.push_error(
|
||||
error_msg=f"WebSocket connection closed in receive loop", exception=e
|
||||
)
|
||||
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
break
|
||||
|
||||
@@ -209,6 +209,15 @@ class AWSPollyTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AWS TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AWS's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to AWS Polly language format.
|
||||
|
||||
@@ -312,6 +321,7 @@ class AWSPollyTTSService(TTSService):
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
except (BotoCoreError, ClientError) as error:
|
||||
logger.exception(f"{self} error generating TTS: {error}")
|
||||
error_message = f"AWS Polly TTS error: {str(error)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
logger.error(f"{self} error: image generation timed out")
|
||||
yield ErrorFrame("Image generation timed out")
|
||||
return
|
||||
|
||||
@@ -103,6 +104,7 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
|
||||
@@ -61,5 +61,5 @@ class AzureRealtimeLLMService(OpenAIRealtimeLLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"initialization error: {e}", exception=e)
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
@@ -121,7 +121,8 @@ class AzureSTTService(STTService):
|
||||
self._audio_stream.write(audio)
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech recognition service.
|
||||
@@ -150,9 +151,8 @@ class AzureSTTService(STTService):
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Uncaught exception during initialization: {e}", exception=e
|
||||
)
|
||||
logger.error(f"{self} exception during initialization: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the speech recognition service.
|
||||
|
||||
@@ -151,6 +151,15 @@ class AzureBaseTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Azure TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Azure's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Azure language format.
|
||||
|
||||
@@ -327,6 +336,7 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
try:
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error=error_msg)
|
||||
return
|
||||
|
||||
@@ -354,13 +364,15 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class AzureHttpTTSService(AzureBaseTTSService):
|
||||
@@ -437,6 +449,5 @@ class AzureHttpTTSService(AzureBaseTTSService):
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
if cancellation_details.reason == CancellationReason.Error:
|
||||
yield ErrorFrame(
|
||||
error=f"Unknown error occurred: {cancellation_details.error_details}"
|
||||
)
|
||||
logger.error(f"{self} error: {cancellation_details.error_details}")
|
||||
yield ErrorFrame(error=f"{self} error: {cancellation_details.error_details}")
|
||||
|
||||
@@ -276,7 +276,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
@@ -284,7 +285,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from Cartesia STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -317,7 +319,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
|
||||
elif data["type"] == "error":
|
||||
error_msg = data.get("message", "Unknown error")
|
||||
await self.push_error(error_msg=error_msg)
|
||||
logger.error(f"Cartesia error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(error=error_msg))
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
|
||||
@@ -497,7 +497,8 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -509,7 +510,8 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
logger.debug("Disconnecting from Cartesia")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
@@ -562,12 +564,13 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
)
|
||||
await self.append_to_audio_context(msg["context_id"], frame)
|
||||
elif msg["type"] == "error":
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
|
||||
self._context_id = None
|
||||
else:
|
||||
await self.push_error(error_msg=f"Error, unknown message type: {msg}")
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
while True:
|
||||
@@ -605,14 +608,16 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class CartesiaHttpTTSService(TTSService):
|
||||
@@ -803,7 +808,8 @@ class CartesiaHttpTTSService(TTSService):
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
yield ErrorFrame(error=f"Cartesia API error: {error_text}")
|
||||
logger.error(f"Cartesia API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(error=f"Cartesia API error: {error_text}"))
|
||||
raise Exception(f"Cartesia API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
@@ -819,7 +825,8 @@ class CartesiaHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -6,9 +6,7 @@
|
||||
|
||||
"""Deepgram Flux speech-to-text service implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
@@ -96,7 +94,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
mip_opt_out: Optional. Opts out requests from the Deepgram Model Improvement Program
|
||||
(default False).
|
||||
tag: List of tags to label requests for identification during usage reporting.
|
||||
min_confidence: Optional. Minimum confidence required confidence to create a TranscriptionFrame
|
||||
"""
|
||||
|
||||
eager_eot_threshold: Optional[float] = None
|
||||
@@ -105,7 +102,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
keyterm: list = []
|
||||
mip_opt_out: Optional[bool] = None
|
||||
tag: list = []
|
||||
min_confidence: Optional[float] = None # New parameter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -150,17 +146,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
params=params
|
||||
)
|
||||
"""
|
||||
# Note: For DeepgramFluxSTTService, differently from other processes, we need to create
|
||||
# the _receive_task inside _connect_websocket, because the websocket should only be
|
||||
# considered connected and ready to send audio once we receive from Flux the message
|
||||
# which confirms the connection has been established.
|
||||
# If we try to keep the logic reconnect_on_error, when receiving a message, the
|
||||
# _receive_task_handler would try to reconnect in case of error, invoking the
|
||||
# _connect_websocket again and leading to a case where the first _receive_task_handler
|
||||
# was never destroyed.
|
||||
# So we can keep it here as false, because inside the method send_with_retry, it will
|
||||
# already try to reconnect if needed.
|
||||
super().__init__(sample_rate=sample_rate, reconnect_on_error=False, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
@@ -177,13 +163,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
self._register_event_handler("on_end_of_turn")
|
||||
self._register_event_handler("on_eager_end_of_turn")
|
||||
self._register_event_handler("on_update")
|
||||
self._connection_established_event = asyncio.Event()
|
||||
# Watchdog task to prevent dangling tasks
|
||||
# If we stop sending audio to Flux after we have received that the User has started speaking
|
||||
# we never receive the user stopped speaking event unless we resume sending audio to it.
|
||||
self._last_stt_time = None
|
||||
self._watchdog_task = None
|
||||
self._user_is_speaking = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to WebSocket and start background tasks.
|
||||
@@ -193,6 +172,9 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
"""
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from WebSocket and clean up tasks.
|
||||
|
||||
@@ -200,32 +182,21 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
and cleans up resources to prevent memory leaks.
|
||||
"""
|
||||
try:
|
||||
# Cancel background tasks BEFORE closing websocket
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=2.0)
|
||||
self._receive_task = None
|
||||
|
||||
# Now close the websocket
|
||||
await self._disconnect_websocket()
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# Reset state only after everything is cleaned up
|
||||
self._websocket = None
|
||||
|
||||
async def _send_silence(self, duration_secs: float = 0.5):
|
||||
"""Send a block of silence of the specified duration (default 500 ms)."""
|
||||
sample_width = 2 # bytes per sample for 16-bit PCM
|
||||
num_channels = 1 # mono
|
||||
num_samples = int(self.sample_rate * duration_secs)
|
||||
silence = b"\x00" * (num_samples * sample_width * num_channels)
|
||||
await self._websocket.send(silence)
|
||||
|
||||
async def _watchdog_task_handler(self):
|
||||
while self._websocket and self._websocket.state is State.OPEN:
|
||||
now = time.monotonic()
|
||||
# More than 500 ms without sending new audio to Flux
|
||||
if self._user_is_speaking and self._last_stt_time and now - self._last_stt_time > 0.5:
|
||||
logger.warning("Sending silence to Flux to prevent dangling task")
|
||||
await self._send_silence()
|
||||
self._last_stt_time = time.monotonic()
|
||||
# check every 100ms
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Establish WebSocket connection to API.
|
||||
|
||||
@@ -237,35 +208,15 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
self._connection_established_event.clear()
|
||||
self._user_is_speaking = False
|
||||
self._websocket = await websocket_connect(
|
||||
self._websocket_url,
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
# Creating the receiver task
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(
|
||||
self._receive_task_handler(self._report_error)
|
||||
)
|
||||
|
||||
# Creating the watchdog task
|
||||
if not self._watchdog_task:
|
||||
self._watchdog_task = self.create_task(self._watchdog_task_handler())
|
||||
|
||||
# Now wait for the connection established event
|
||||
logger.debug("WebSocket connected, waiting for server confirmation...")
|
||||
await self._connection_established_event.wait()
|
||||
logger.debug("Connected to Deepgram Flux Websocket")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -276,16 +227,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
metrics collection. Handles disconnection errors gracefully.
|
||||
"""
|
||||
try:
|
||||
# Cancel background tasks BEFORE closing websocket
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=2.0)
|
||||
self._receive_task = None
|
||||
if self._watchdog_task:
|
||||
await self.cancel_task(self._watchdog_task, timeout=2.0)
|
||||
self._watchdog_task = None
|
||||
self._last_stt_time = None
|
||||
|
||||
self._connection_established_event.clear()
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
@@ -293,7 +234,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from Deepgram Flux Websocket")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -303,13 +245,10 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
|
||||
This signals to the server that no more audio data will be sent.
|
||||
"""
|
||||
try:
|
||||
if self._websocket:
|
||||
logger.debug("Sending CloseStream message to Deepgram Flux")
|
||||
message = {"type": "CloseStream"}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error sending closeStream: {e}", exception=e)
|
||||
if self._websocket:
|
||||
logger.debug("Sending CloseStream message to Deepgram Flux")
|
||||
message = {"type": "CloseStream"}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -396,13 +335,15 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
are issues sending the audio data.
|
||||
"""
|
||||
if not self._websocket:
|
||||
logger.error("Not connected to Deepgram Flux.")
|
||||
yield ErrorFrame("Not connected to Deepgram Flux.")
|
||||
return
|
||||
|
||||
try:
|
||||
self._last_stt_time = time.monotonic()
|
||||
await self.send_with_retry(audio, self._report_error)
|
||||
await self._websocket.send(audio)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
return
|
||||
|
||||
yield None
|
||||
@@ -479,7 +420,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
# Skip malformed messages
|
||||
continue
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
# Error will be handled inside WebsocketService->_receive_task_handler
|
||||
raise
|
||||
else:
|
||||
@@ -521,8 +463,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
transcription processing.
|
||||
"""
|
||||
logger.info("Connected to Flux - ready to stream audio")
|
||||
# Notify connection is established
|
||||
self._connection_established_event.set()
|
||||
|
||||
async def _handle_fatal_error(self, data: Dict[str, Any]):
|
||||
"""Handle fatal error messages from Deepgram Flux.
|
||||
@@ -590,7 +530,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
transcript: maybe the first few words of the turn.
|
||||
"""
|
||||
logger.debug("User started speaking")
|
||||
self._user_is_speaking = True
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.start_metrics()
|
||||
@@ -611,22 +550,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
logger.trace(f"Received event TurnResumed: {event}")
|
||||
await self._call_event_handler("on_turn_resumed")
|
||||
|
||||
def _calculate_average_confidence(self, transcript_data) -> Optional[float]:
|
||||
"""Calculate the average confidence from transcript data.
|
||||
|
||||
Return None if the data is missing or invalid.
|
||||
"""
|
||||
# Example: Assume transcript_data has a list of words with confidence
|
||||
words = transcript_data.get("words")
|
||||
if not words or not isinstance(words, list):
|
||||
return None
|
||||
confidences = [
|
||||
w.get("confidence") for w in words if isinstance(w.get("confidence"), (float, int))
|
||||
]
|
||||
if not confidences:
|
||||
return None
|
||||
return sum(confidences) / len(confidences)
|
||||
|
||||
async def _handle_end_of_turn(self, transcript: str, data: Dict[str, Any]):
|
||||
"""Handle EndOfTurn events from Deepgram Flux.
|
||||
|
||||
@@ -646,26 +569,16 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
data: The TurnInfo message data containing event type, transcript and some extra metadata.
|
||||
"""
|
||||
logger.debug("User stopped speaking")
|
||||
self._user_is_speaking = False
|
||||
|
||||
# Compute the average confidence
|
||||
average_confidence = self._calculate_average_confidence(data)
|
||||
|
||||
if not self._params.min_confidence or average_confidence > self._params.min_confidence:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
result=data,
|
||||
)
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
result=data,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Transcription confidence below min_confidence threshold: {average_confidence}"
|
||||
)
|
||||
|
||||
)
|
||||
await self._handle_transcription(transcript, True, self._language)
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(UserStoppedSpeakingFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
@@ -233,14 +233,7 @@ class DeepgramSTTService(STTService):
|
||||
)
|
||||
|
||||
if not await self._connection.start(options=self._settings, addons=self._addons):
|
||||
await self.push_error(error_msg=f"Unable to connect to Deepgram")
|
||||
else:
|
||||
headers = {
|
||||
k: v
|
||||
for k, v in self._connection._socket.response.headers.items()
|
||||
if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
logger.error(f"{self}: unable to connect to Deepgram")
|
||||
|
||||
async def _disconnect(self):
|
||||
if await self._connection.is_connected():
|
||||
@@ -263,7 +256,7 @@ class DeepgramSTTService(STTService):
|
||||
async def _on_error(self, *args, **kwargs):
|
||||
error: ErrorResponse = kwargs["error"]
|
||||
logger.warning(f"{self} connection error, will retry: {error}")
|
||||
await self.push_error(error_msg=f"{error}")
|
||||
await self.push_error(ErrorFrame(error=f"{error}"))
|
||||
await self.stop_all_metrics()
|
||||
# NOTE(aleix): we don't disconnect (i.e. call finish on the connection)
|
||||
# because this triggers more errors internally in the Deepgram SDK. So,
|
||||
|
||||
@@ -1,444 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
This module provides a Pipecat STT service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time transcription with support for interim results, multiple
|
||||
languages, and various Deepgram features.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from deepgram import LiveOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepgramSageMakerSTTService, you need to `pip install pipecat-ai[deepgram,sagemaker]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class DeepgramSageMakerSTTService(STTService):
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech recognition using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
transcription with support for interim results, speaker diarization, and
|
||||
multiple languages.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- Deepgram SDK for LiveOptions configuration
|
||||
|
||||
Example::
|
||||
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name="my-deepgram-endpoint",
|
||||
region="us-east-2",
|
||||
live_options=LiveOptions(
|
||||
model="nova-3",
|
||||
language="en",
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
),
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
live_options: Optional[LiveOptions] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram model
|
||||
deployed (e.g., "my-deepgram-nova-3-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
sample_rate: Audio sample rate in Hz. If None, uses value from
|
||||
live_options or defaults to the value from StartFrame.
|
||||
live_options: Deepgram LiveOptions for detailed configuration. If None,
|
||||
uses sensible defaults (nova-3 model, English, interim results enabled).
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
# Create default options similar to DeepgramSTTService
|
||||
default_options = LiveOptions(
|
||||
encoding="linear16",
|
||||
language=Language.EN,
|
||||
model="nova-3",
|
||||
channels=1,
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
)
|
||||
|
||||
# Merge with provided options
|
||||
merged_options = default_options.to_dict()
|
||||
if live_options:
|
||||
default_model = default_options.model
|
||||
merged_options.update(live_options.to_dict())
|
||||
# Handle the "None" string bug from deepgram-sdk
|
||||
if "model" in merged_options and merged_options["model"] == "None":
|
||||
merged_options["model"] = default_model
|
||||
|
||||
# Convert Language enum to string if needed
|
||||
if "language" in merged_options and isinstance(merged_options["language"], Language):
|
||||
merged_options["language"] = merged_options["language"].value
|
||||
|
||||
self.set_model_name(merged_options["model"])
|
||||
self._settings = merged_options
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the Deepgram model and reconnect.
|
||||
|
||||
Disconnects from the current session, updates the model setting, and
|
||||
establishes a new connection with the updated model.
|
||||
|
||||
Args:
|
||||
model: The Deepgram model name to use (e.g., "nova-3").
|
||||
"""
|
||||
await super().set_model(model)
|
||||
logger.info(f"Switching STT model to: [{model}]")
|
||||
self._settings["model"] = model
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the recognition language and reconnect.
|
||||
|
||||
Disconnects from the current session, updates the language setting, and
|
||||
establishes a new connection with the updated language.
|
||||
|
||||
Args:
|
||||
language: The language to use for speech recognition (e.g., Language.EN,
|
||||
Language.ES).
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._settings["language"] = language
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._settings["sample_rate"] = self.sample_rate
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to Deepgram for transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
Frame: None (transcription results come via BiDi stream callbacks).
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_audio_chunk(audio)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram query string from settings, creates the BiDi client,
|
||||
starts the streaming session, and launches background tasks for processing
|
||||
responses and sending KeepAlive messages.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram on SageMaker...")
|
||||
|
||||
# Update sample rate in settings
|
||||
self._settings["sample_rate"] = self.sample_rate
|
||||
|
||||
# Build query string from settings, converting booleans to strings
|
||||
query_params = {}
|
||||
for key, value in self._settings.items():
|
||||
if value is not None:
|
||||
# Convert boolean values to lowercase strings for Deepgram API
|
||||
if isinstance(value, bool):
|
||||
query_params[key] = str(value).lower()
|
||||
else:
|
||||
query_params[key] = str(value)
|
||||
|
||||
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
|
||||
|
||||
# Create BiDi client
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/listen",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
# Start the session
|
||||
await self._client.start_session()
|
||||
|
||||
# Start processing responses in the background
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
# Start keepalive task to maintain connection
|
||||
self._keepalive_task = self.create_task(self._send_keepalive())
|
||||
|
||||
logger.debug("Connected to Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a CloseStream message to Deepgram, cancels background tasks
|
||||
(KeepAlive and response processing), and closes the BiDi session.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram on SageMaker...")
|
||||
|
||||
# Send CloseStream message to Deepgram
|
||||
try:
|
||||
await self._client.send_json({"type": "CloseStream"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send CloseStream message: {e}")
|
||||
|
||||
# Cancel keepalive task
|
||||
if self._keepalive_task and not self._keepalive_task.done():
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
|
||||
# Cancel response processing task
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
# Close the BiDi session
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _send_keepalive(self):
|
||||
"""Send periodic KeepAlive messages to maintain the connection.
|
||||
|
||||
Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
|
||||
connection is active. This prevents the connection from timing out during
|
||||
periods of silence.
|
||||
"""
|
||||
while self._client and self._client.is_active:
|
||||
await asyncio.sleep(5)
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "KeepAlive"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send KeepAlive: {e}")
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream, decodes the payload,
|
||||
parses JSON responses from Deepgram, and processes transcription results.
|
||||
Runs as a background task until the connection is closed or cancelled.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
# Check if this is a PayloadPart with bytes
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
response_data = result.value.bytes_.decode("utf-8")
|
||||
|
||||
try:
|
||||
# Parse JSON response from Deepgram
|
||||
parsed = json.loads(response_data)
|
||||
|
||||
# Extract and process transcript if available
|
||||
if "channel" in parsed:
|
||||
await self._handle_transcript_response(parsed)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Non-JSON response: {response_data}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("Response processor stopped")
|
||||
|
||||
async def _handle_transcript_response(self, parsed: dict):
|
||||
"""Handle a transcript response from Deepgram.
|
||||
|
||||
Extracts the transcript text, determines if it's final or interim, extracts
|
||||
language information, and pushes the appropriate frame (TranscriptionFrame
|
||||
or InterimTranscriptionFrame) downstream.
|
||||
|
||||
Args:
|
||||
parsed: The parsed JSON response from Deepgram containing channel,
|
||||
alternatives, transcript, and metadata.
|
||||
"""
|
||||
alternatives = parsed.get("channel", {}).get("alternatives", [])
|
||||
if not alternatives or not alternatives[0].get("transcript"):
|
||||
return
|
||||
|
||||
transcript = alternatives[0]["transcript"]
|
||||
if not transcript.strip():
|
||||
return
|
||||
|
||||
# Stop TTFB metrics on first transcript
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
is_final = parsed.get("is_final", False)
|
||||
speech_final = parsed.get("speech_final", False)
|
||||
|
||||
# Extract language if available
|
||||
language = None
|
||||
if alternatives[0].get("languages"):
|
||||
language = alternatives[0]["languages"][0]
|
||||
language = Language(language)
|
||||
|
||||
if is_final and speech_final:
|
||||
# Final transcription
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
# Interim transcription
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing.
|
||||
|
||||
This method is decorated with @traced_stt for observability and tracing
|
||||
integration. The actual transcription processing is handled by the parent
|
||||
class and observers.
|
||||
|
||||
Args:
|
||||
transcript: The transcribed text.
|
||||
is_final: Whether this is a final transcription result.
|
||||
language: The detected language of the transcription, if available.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def start_metrics(self):
|
||||
"""Start TTFB and processing metrics collection."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with Deepgram SageMaker-specific handling.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.start_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# Send finalize message to Deepgram when user stops speaking
|
||||
# This tells Deepgram to flush any remaining audio and return final results
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Finalize"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending Finalize message: {e}")
|
||||
@@ -10,45 +10,35 @@ This module provides integration with Deepgram's text-to-speech API
|
||||
for generating speech from text using various voice models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import TTSService, WebsocketTTSService
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
from deepgram import DeepgramClient, DeepgramClientOptions, SpeakOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepgramWebsocketTTSService, you need to `pip install pipecat-ai[deepgram]`."
|
||||
)
|
||||
logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class DeepgramTTSService(WebsocketTTSService):
|
||||
"""Deepgram WebSocket-based text-to-speech service.
|
||||
class DeepgramTTSService(TTSService):
|
||||
"""Deepgram text-to-speech service.
|
||||
|
||||
Provides real-time text-to-speech synthesis using Deepgram's WebSocket API.
|
||||
Supports streaming audio generation with interruption handling via the Clear
|
||||
message for conversational AI use cases.
|
||||
Provides text-to-speech synthesis using Deepgram's streaming API.
|
||||
Supports various voice models and audio encoding formats with
|
||||
configurable sample rates and quality settings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -56,220 +46,51 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice: str = "aura-2-helena-en",
|
||||
base_url: str = "wss://api.deepgram.com",
|
||||
base_url: str = "",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram WebSocket TTS service.
|
||||
"""Initialize the Deepgram TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Deepgram API key for authentication.
|
||||
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
|
||||
base_url: WebSocket base URL for Deepgram API. Defaults to "wss://api.deepgram.com".
|
||||
base_url: Custom base URL for Deepgram API. Uses default if empty.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to parent InterruptibleTTSService class.
|
||||
**kwargs: Additional arguments passed to parent TTSService class.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
**kwargs,
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self._settings = {
|
||||
"encoding": encoding,
|
||||
}
|
||||
self.set_voice(voice)
|
||||
|
||||
self._receive_task = None
|
||||
client_options = DeepgramClientOptions(url=base_url)
|
||||
self._deepgram_client = DeepgramClient(api_key, config=client_options)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram WebSocket TTS service supports metrics generation.
|
||||
True, as Deepgram TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram WebSocket TTS service.
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Deepgram TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
Returns:
|
||||
True, indicating that Deepgram's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram WebSocket TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram WebSocket TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for LLM response end.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# When the LLM finishes responding, flush any remaining text in Deepgram's buffer
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to Deepgram WebSocket and start receive task."""
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Deepgram WebSocket and clean up tasks."""
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to Deepgram WebSocket API with configured settings."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug("Connecting to Deepgram WebSocket")
|
||||
|
||||
# Build WebSocket URL with query parameters
|
||||
params = []
|
||||
params.append(f"model={self._voice_id}")
|
||||
params.append(f"encoding={self._settings['encoding']}")
|
||||
params.append(f"sample_rate={self.sample_rate}")
|
||||
|
||||
url = f"{self._base_url}/v1/speak?{'&'.join(params)}"
|
||||
|
||||
headers = {"Authorization": f"Token {self._api_key}"}
|
||||
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close WebSocket connection and reset state."""
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Deepgram WebSocket")
|
||||
# Send Close message to gracefully close the connection
|
||||
await self._websocket.send(json.dumps({"type": "Close"}))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get active websocket connection or raise exception."""
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by sending Clear message to Deepgram.
|
||||
|
||||
The Clear message will clear Deepgram's internal text buffer and stop
|
||||
sending audio, allowing for a new response to be generated.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
# Send Clear message to stop current audio generation
|
||||
if self._websocket:
|
||||
try:
|
||||
clear_msg = {"type": "Clear"}
|
||||
await self._websocket.send(json.dumps(clear_msg))
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Clear message: {e}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process messages from Deepgram WebSocket."""
|
||||
async for message in self._get_websocket():
|
||||
if isinstance(message, bytes):
|
||||
# Binary message contains audio data
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(message, self.sample_rate, 1)
|
||||
await self.push_frame(frame)
|
||||
elif isinstance(message, str):
|
||||
# Text message contains metadata or control messages
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
msg_type = msg.get("type")
|
||||
|
||||
if msg_type == "Metadata":
|
||||
logger.trace(f"Received metadata: {msg}")
|
||||
elif msg_type == "Flushed":
|
||||
logger.trace(f"Received Flushed: {msg}")
|
||||
# Flushed indicates the end of audio generation for the current buffer
|
||||
# This happens after flush_audio() is called
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {msg}")
|
||||
# Buffer has been cleared after interruption
|
||||
# TTSStoppedFrame will be sent by the interruption handler
|
||||
elif msg_type == "Warning":
|
||||
logger.warning(
|
||||
f"{self} warning: {msg.get('description', 'Unknown warning')}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Received unknown message type: {msg}")
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
generation of audio from Deepgram's internal text buffer.
|
||||
"""
|
||||
if self._websocket:
|
||||
try:
|
||||
flush_msg = {"type": "Flush"}
|
||||
await self._websocket.send(json.dumps(flush_msg))
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Flush message: {e}")
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram's WebSocket TTS API.
|
||||
"""Generate speech from text using Deepgram's TTS API.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
@@ -279,27 +100,33 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
options = SpeakOptions(
|
||||
model=self._voice_id,
|
||||
encoding=self._settings["encoding"],
|
||||
sample_rate=self.sample_rate,
|
||||
container="none",
|
||||
)
|
||||
|
||||
try:
|
||||
# Reconnect if the websocket is closed
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
response = await self._deepgram_client.speak.asyncrest.v("1").stream_raw(
|
||||
{"text": text}, options
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# Send text message to Deepgram
|
||||
# Note: We don't send Flush here - that should only be sent when the
|
||||
# LLM finishes a complete response via flush_audio()
|
||||
speak_msg = {"type": "Speak", "text": text}
|
||||
await self._get_websocket().send(json.dumps(speak_msg))
|
||||
async for data in response.aiter_bytes():
|
||||
await self.stop_ttfb_metrics()
|
||||
if data:
|
||||
yield TTSAudioRawFrame(audio=data, sample_rate=self.sample_rate, num_channels=1)
|
||||
|
||||
# The audio frames will be handled in _receive_messages
|
||||
yield None
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class DeepgramHttpTTSService(TTSService):
|
||||
@@ -350,6 +177,15 @@ class DeepgramHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Deepgram TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Deepgram's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram's TTS API.
|
||||
@@ -409,4 +245,5 @@ class DeepgramHttpTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
|
||||
@@ -351,7 +351,8 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
def audio_format_from_sample_rate(sample_rate: int) -> str:
|
||||
@@ -415,8 +416,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
Only used when commit_strategy is VAD. None uses ElevenLabs default.
|
||||
min_silence_duration_ms: Minimum silence duration for VAD (50-2000ms).
|
||||
Only used when commit_strategy is VAD. None uses ElevenLabs default.
|
||||
include_timestamps: Whether to include word-level timestamps in transcripts.
|
||||
enable_logging: Whether to enable logging on ElevenLabs' side.
|
||||
"""
|
||||
|
||||
language_code: Optional[str] = None
|
||||
@@ -425,8 +424,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
vad_threshold: Optional[float] = None
|
||||
min_speech_duration_ms: Optional[int] = None
|
||||
min_silence_duration_ms: Optional[int] = None
|
||||
include_timestamps: bool = False
|
||||
enable_logging: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -462,8 +459,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
self._audio_format = "" # initialized in start()
|
||||
self._receive_task = None
|
||||
|
||||
self._settings = {"language": params.language_code}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate processing metrics.
|
||||
|
||||
@@ -482,13 +477,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
Changing language requires reconnecting to the WebSocket.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
new_language = (
|
||||
language_to_elevenlabs_language(language)
|
||||
if isinstance(language, Language)
|
||||
else language
|
||||
)
|
||||
self._params.language_code = new_language
|
||||
self._settings["language"] = new_language
|
||||
self._params.language_code = language.value if isinstance(language, Language) else language
|
||||
# Reconnect with new settings
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -597,6 +586,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio: {e}")
|
||||
yield ErrorFrame(f"ElevenLabs Realtime STT error: {str(e)}")
|
||||
|
||||
yield None
|
||||
@@ -630,16 +620,10 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
if self._params.language_code:
|
||||
params.append(f"language_code={self._params.language_code}")
|
||||
|
||||
params.append(f"audio_format={self._audio_format}")
|
||||
params.append(f"encoding={self._audio_format}")
|
||||
params.append(f"sample_rate={self.sample_rate}")
|
||||
params.append(f"commit_strategy={self._params.commit_strategy.value}")
|
||||
|
||||
# Add optional parameters
|
||||
if self._params.include_timestamps:
|
||||
params.append(f"include_timestamps={str(self._params.include_timestamps).lower()}")
|
||||
|
||||
if self._params.enable_logging:
|
||||
params.append(f"enable_logging={str(self._params.enable_logging).lower()}")
|
||||
|
||||
# Add VAD parameters if using VAD commit strategy and values are specified
|
||||
if self._params.commit_strategy == CommitStrategy.VAD:
|
||||
if self._params.vad_silence_threshold_secs is not None:
|
||||
@@ -661,9 +645,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug("Connected to ElevenLabs Realtime STT")
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Unable to connect to ElevenLabs Realtime STT: {e}", exception=e
|
||||
)
|
||||
logger.error(f"{self}: unable to connect to ElevenLabs Realtime STT: {e}")
|
||||
await self.push_error(ErrorFrame(f"Connection error: {str(e)}"))
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Disconnect from ElevenLabs Realtime STT WebSocket."""
|
||||
@@ -672,7 +655,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from ElevenLabs Realtime STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -729,20 +712,15 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
elif message_type == "committed_transcript_with_timestamps":
|
||||
await self._on_committed_transcript_with_timestamps(data)
|
||||
|
||||
elif message_type == "error":
|
||||
error_msg = data.get("error", "Unknown error")
|
||||
logger.error(f"ElevenLabs error: {error_msg}")
|
||||
await self.push_error(error_msg=f"Error: {error_msg}")
|
||||
elif message_type == "input_error":
|
||||
error_msg = data.get("error", "Unknown input error")
|
||||
logger.error(f"ElevenLabs input error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Input error: {error_msg}"))
|
||||
|
||||
elif message_type == "auth_error":
|
||||
error_msg = data.get("error", "Authentication error")
|
||||
logger.error(f"ElevenLabs auth error: {error_msg}")
|
||||
await self.push_error(error_msg=f"Auth error: {error_msg}")
|
||||
|
||||
elif message_type == "quota_exceeded_error":
|
||||
error_msg = data.get("error", "Quota exceeded")
|
||||
logger.error(f"ElevenLabs quota exceeded: {error_msg}")
|
||||
await self.push_error(error_msg=f"Quota exceeded: {error_msg}")
|
||||
elif message_type in ["auth_error", "quota_exceeded", "transcriber_error", "error"]:
|
||||
error_msg = data.get("error", data.get("message", "Unknown error"))
|
||||
logger.error(f"ElevenLabs error ({message_type}): {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"{message_type}: {error_msg}"))
|
||||
|
||||
else:
|
||||
logger.debug(f"Unknown message type: {message_type}")
|
||||
@@ -787,11 +765,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
Args:
|
||||
data: Committed transcript data.
|
||||
"""
|
||||
# If timestamps are enabled, skip this message and wait for the
|
||||
# committed_transcript_with_timestamps message which contains all the data
|
||||
if self._params.include_timestamps:
|
||||
return
|
||||
|
||||
text = data.get("text", "").strip()
|
||||
if not text:
|
||||
return
|
||||
@@ -819,18 +792,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
async def _on_committed_transcript_with_timestamps(self, data: dict):
|
||||
"""Handle committed transcript with word-level timestamps.
|
||||
|
||||
This message is sent when include_timestamps=true. The result data includes:
|
||||
- text: The transcribed text
|
||||
- language_code: Detected language (if available)
|
||||
- words: Array of word objects with timing information:
|
||||
- text: The word text
|
||||
- start: Start time in seconds
|
||||
- end: End time in seconds
|
||||
- type: "word" or "spacing"
|
||||
- speaker_id: Speaker identifier (if available)
|
||||
- logprob: Log probability score (if available)
|
||||
- characters: Array of character strings (if available)
|
||||
|
||||
Args:
|
||||
data: Committed transcript data with timestamps.
|
||||
"""
|
||||
@@ -838,24 +799,9 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Get language if provided
|
||||
language = data.get("language_code")
|
||||
|
||||
logger.debug(f"Committed transcript with timestamps: [{text}]")
|
||||
logger.trace(f"Timestamps: {data.get('words', [])}")
|
||||
|
||||
await self._handle_transcription(text, True, language)
|
||||
|
||||
# This message is sent after committed_transcript when include_timestamps=true.
|
||||
# It contains the full transcript data including text and word-level timestamps.
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
# This is sent after the committed_transcript, so we don't need to
|
||||
# push another TranscriptionFrame, but we could use the timestamps
|
||||
# for additional processing if needed in the future
|
||||
|
||||
@@ -424,7 +424,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
|
||||
@@ -535,8 +536,9 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
self._websocket = None
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
@@ -551,7 +553,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from ElevenLabs")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
@@ -581,7 +584,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
self._partial_word = ""
|
||||
@@ -736,13 +740,15 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
else:
|
||||
await self._send_text(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
self._started = False
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class ElevenLabsHttpTTSService(WordTTSService):
|
||||
@@ -1037,6 +1043,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"{self} error: {error_text}")
|
||||
yield ErrorFrame(error=f"ElevenLabs API error: {error_text}")
|
||||
return
|
||||
|
||||
@@ -1084,7 +1091,8 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
logger.warning(f"Failed to parse JSON from stream: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
continue
|
||||
|
||||
# After processing all chunks, emit any remaining partial word
|
||||
@@ -1108,7 +1116,8 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
self._previous_text = text
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
# Let the parent class handle TTSStoppedFrame
|
||||
|
||||
@@ -110,6 +110,7 @@ class FalImageGenService(ImageGenService):
|
||||
image_url = response["images"][0]["url"] if response else None
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
|
||||
@@ -290,4 +290,5 @@ class FalSTTService(SegmentedSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -76,7 +76,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
api_key: str,
|
||||
reference_id: Optional[str] = None, # This is the voice ID
|
||||
model: Optional[str] = None, # Deprecated
|
||||
model_id: str = "s1",
|
||||
model_id: str = "speech-1.5",
|
||||
output_format: FishAudioOutputFormat = "pcm",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
@@ -93,7 +93,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
The `model` parameter is deprecated and will be removed in version 0.1.0.
|
||||
Use `reference_id` instead to specify the voice model.
|
||||
|
||||
model_id: Specify which Fish Audio TTS model to use (e.g. "s1")
|
||||
model_id: Specify which Fish Audio TTS model to use (e.g. "speech-1.5")
|
||||
output_format: Audio output format. Defaults to "pcm".
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
params: Additional input parameters for voice customization.
|
||||
@@ -159,6 +159,15 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Fish Audio TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Fish Audio's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the TTS model and reconnect.
|
||||
|
||||
@@ -228,7 +237,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -242,7 +252,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
await self._websocket.send(ormsgpack.packb(stop_message))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._started = False
|
||||
@@ -284,7 +295,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -320,7 +332,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
flush_message = {"event": "flush"}
|
||||
await self._get_websocket().send(ormsgpack.packb(flush_message))
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -328,4 +341,5 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -468,7 +468,8 @@ class GladiaSTTService(STTService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._connection_active = False
|
||||
|
||||
if not self._should_reconnect:
|
||||
@@ -558,7 +559,8 @@ class GladiaSTTService(STTService):
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.debug("Connection closed during keepalive")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
@@ -621,7 +623,8 @@ class GladiaSTTService(STTService):
|
||||
# Expected when closing the connection
|
||||
pass
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _maybe_reconnect(self) -> bool:
|
||||
"""Handle exponential backoff reconnection logic."""
|
||||
@@ -629,9 +632,7 @@ class GladiaSTTService(STTService):
|
||||
return False
|
||||
self._reconnection_attempts += 1
|
||||
if self._reconnection_attempts > self._max_reconnection_attempts:
|
||||
await self.push_error(
|
||||
error_msg=f"Max reconnection attempts ({self._max_reconnection_attempts}) reached",
|
||||
)
|
||||
logger.error(f"Max reconnection attempts ({self._max_reconnection_attempts}) reached")
|
||||
self._should_reconnect = False
|
||||
return False
|
||||
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
|
||||
|
||||
@@ -1175,7 +1175,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._connection_task = self.create_task(self._connection_task_handler(config=config))
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Initialization error: {e}", exception=e)
|
||||
await self.push_error(ErrorFrame(error=f"{self} Initialization error: {e}"))
|
||||
|
||||
async def _connection_task_handler(self, config: LiveConnectConfig):
|
||||
async with self._client.aio.live.connect(model=self._model_name, config=config) as session:
|
||||
@@ -1252,11 +1252,11 @@ class GeminiLiveLLMService(LLMService):
|
||||
)
|
||||
|
||||
if self._consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
||||
error_msg = (
|
||||
logger.error(
|
||||
f"Max consecutive failures ({MAX_CONSECUTIVE_FAILURES}) reached, "
|
||||
"treating as fatal error"
|
||||
)
|
||||
await self.push_error(error_msg=error_msg, exception=error)
|
||||
await self.push_error(ErrorFrame(error=f"{self} Error in receive loop: {error}"))
|
||||
return False
|
||||
else:
|
||||
logger.info(
|
||||
@@ -1284,7 +1284,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._completed_tool_calls = set()
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
"""Send user audio frame to Gemini Live API."""
|
||||
@@ -1453,6 +1453,8 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._bot_text_buffer += text
|
||||
self._search_result_buffer += text # Also accumulate for grounding
|
||||
frame = LLMTextFrame(text=text)
|
||||
# Gemini Live text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Check for grounding metadata in server content
|
||||
@@ -1723,8 +1725,6 @@ class GeminiLiveLLMService(LLMService):
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_input_tokens=usage.cached_content_token_count,
|
||||
reasoning_tokens=usage.thoughts_token_count,
|
||||
)
|
||||
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
@@ -1745,7 +1745,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
# state management, and that exponential backoff for retries can have
|
||||
# cost/stability implications for a service cluster, let's just treat a
|
||||
# send-side error as fatal.
|
||||
await self.push_error(error_msg=f"Send error: {error}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} Send error: {error}", fatal=True))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
|
||||
@@ -110,6 +110,7 @@ class GoogleImageGenService(ImageGenService):
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not response or not response.generated_images:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
@@ -127,4 +128,5 @@ class GoogleImageGenService(ImageGenService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating image: {e}")
|
||||
yield ErrorFrame(f"Image generation error: {str(e)}")
|
||||
|
||||
@@ -793,7 +793,7 @@ class GoogleLLMService(LLMService):
|
||||
return
|
||||
generation_params.setdefault("thinking_config", {})["thinking_budget"] = 0
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to unset thinking budget: {e}")
|
||||
logger.exception(f"Failed to unset thinking budget: {e}")
|
||||
|
||||
async def _stream_content(
|
||||
self, params_from_context: GeminiLLMInvocationParams
|
||||
@@ -920,7 +920,9 @@ class GoogleLLMService(LLMService):
|
||||
for part in candidate.content.parts:
|
||||
if not part.thought and part.text:
|
||||
search_result += part.text
|
||||
await self.push_frame(LLMTextFrame(part.text))
|
||||
frame = LLMTextFrame(part.text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
elif part.function_call:
|
||||
function_call = part.function_call
|
||||
id = function_call.id or str(uuid.uuid4())
|
||||
@@ -983,7 +985,7 @@ class GoogleLLMService(LLMService):
|
||||
except DeadlineExceeded:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
if grounding_metadata and isinstance(grounding_metadata, dict):
|
||||
llm_search_frame = LLMSearchResponseFrame(
|
||||
|
||||
@@ -774,7 +774,8 @@ class GoogleSTTService(STTService):
|
||||
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
raise
|
||||
|
||||
async def _stream_audio(self):
|
||||
@@ -805,13 +806,15 @@ class GoogleSTTService(STTService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
await asyncio.sleep(1) # Brief delay before reconnecting
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process an audio chunk for STT transcription.
|
||||
@@ -899,7 +902,8 @@ class GoogleSTTService(STTService):
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
# Re-raise the exception to let it propagate (e.g. in the case of a
|
||||
# timeout, propagate to _stream_audio to reconnect)
|
||||
raise
|
||||
|
||||
@@ -596,6 +596,15 @@ class GoogleHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Google TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Google's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Google TTS language format.
|
||||
|
||||
@@ -737,6 +746,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -793,6 +803,15 @@ class GoogleBaseTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Google and Gemini TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Google's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Google TTS language format.
|
||||
|
||||
@@ -995,7 +1014,9 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"TTS generation error: {str(e)}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
|
||||
class GeminiTTSService(GoogleBaseTTSService):
|
||||
@@ -1245,5 +1266,6 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"Gemini TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -123,8 +123,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._cache_read_input_tokens = None
|
||||
self._reasoning_tokens = None
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
@@ -139,8 +137,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
cache_read_input_tokens=self._cache_read_input_tokens,
|
||||
reasoning_tokens=self._reasoning_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
@@ -153,7 +149,7 @@ class GrokLLMService(OpenAILLMService):
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens, completion_tokens, and optional cached/reasoning tokens.
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
@@ -168,13 +164,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
|
||||
# Capture cached & reasoning tokens (these typically only appear once per request)
|
||||
if tokens.cache_read_input_tokens is not None:
|
||||
self._cache_read_input_tokens = tokens.cache_read_input_tokens
|
||||
|
||||
if tokens.reasoning_tokens is not None:
|
||||
self._reasoning_tokens = tokens.reasoning_tokens
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
|
||||
@@ -111,6 +111,15 @@ class GroqTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Groq TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Groq's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Groq's TTS API.
|
||||
@@ -146,6 +155,7 @@ class GroqTTSService(TTSService):
|
||||
bytes = w.readframes(num_frames)
|
||||
yield TTSAudioRawFrame(bytes, frame_rate, channels)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -179,7 +179,7 @@ class HeyGenClient:
|
||||
await self._task_manager.cancel_task(self._event_task)
|
||||
self._event_task = None
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during cleanup: {e}")
|
||||
logger.exception(f"Exception during cleanup: {e}")
|
||||
|
||||
async def start(self, frame: StartFrame, audio_chunk_size: int) -> None:
|
||||
"""Start the client and establish all necessary connections.
|
||||
|
||||
@@ -14,14 +14,12 @@ from pydantic import BaseModel
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import WordTTSService
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
@@ -31,7 +29,6 @@ try:
|
||||
PostedUtterance,
|
||||
PostedUtteranceVoiceWithId,
|
||||
)
|
||||
from hume.tts.types import TimestampMessage
|
||||
except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Hume, you need to `pip install pipecat-ai[hume]`.")
|
||||
@@ -41,7 +38,7 @@ except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
HUME_SAMPLE_RATE = 48_000 # Hume TTS streams at 48 kHz
|
||||
|
||||
|
||||
class HumeTTSService(WordTTSService):
|
||||
class HumeTTSService(TTSService):
|
||||
"""Hume Octave Text-to-Speech service.
|
||||
|
||||
Streams PCM audio via Hume's HTTP output streaming (JSON chunks) endpoint
|
||||
@@ -51,7 +48,6 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
- Generates speech from text using Hume TTS.
|
||||
- Streams PCM audio.
|
||||
- Supports word-level timestamps for precise audio-text synchronization.
|
||||
- Supports dynamic updates of voice and synthesis parameters at runtime.
|
||||
- Provides metrics for Time To First Byte (TTFB) and TTS usage.
|
||||
"""
|
||||
@@ -96,13 +92,7 @@ class HumeTTSService(WordTTSService):
|
||||
f"Hume TTS streams at {HUME_SAMPLE_RATE} Hz; configured sample_rate={sample_rate}"
|
||||
)
|
||||
|
||||
# WordTTSService sets push_text_frames=False by default, which we want
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
**kwargs,
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._client = AsyncHumeClient(api_key=api_key)
|
||||
self._params = params or HumeTTSService.InputParams()
|
||||
@@ -112,10 +102,6 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
self._audio_bytes = b""
|
||||
|
||||
# Track cumulative time for word timestamps across utterances
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Can generate metrics.
|
||||
|
||||
@@ -124,6 +110,15 @@ class HumeTTSService(WordTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Hume TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Hume's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame) -> None:
|
||||
"""Start the service.
|
||||
|
||||
@@ -131,27 +126,6 @@ class HumeTTSService(WordTTSService):
|
||||
frame: The start frame.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._reset_state()
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset internal state variables."""
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (InterruptionFrame, TTSStoppedFrame)):
|
||||
# Reset timing on interruption or stop
|
||||
self._reset_state()
|
||||
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
|
||||
async def update_setting(self, key: str, value: Any) -> None:
|
||||
"""Runtime updates via `TTSUpdateSettingsFrame`.
|
||||
@@ -168,7 +142,7 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
if key_l == "voice_id":
|
||||
self.set_voice(str(value))
|
||||
logger.debug(f"HumeTTSService voice_id set to: {self.voice}")
|
||||
logger.info(f"HumeTTSService voice_id set to: {self.voice}")
|
||||
elif key_l == "description":
|
||||
self._params.description = None if value is None else str(value)
|
||||
elif key_l == "speed":
|
||||
@@ -181,7 +155,7 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Hume TTS with word timestamps.
|
||||
"""Generate speech from text using Hume TTS.
|
||||
|
||||
Args:
|
||||
text: The text to be synthesized.
|
||||
@@ -212,12 +186,7 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Start TTS sequence if not already started
|
||||
if not self._started:
|
||||
self.start_word_timestamps()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
yield TTSStartedFrame()
|
||||
|
||||
try:
|
||||
# Instant mode is always enabled here (not user-configurable)
|
||||
@@ -228,50 +197,23 @@ class HumeTTSService(WordTTSService):
|
||||
# Use version "2" by default if no description is provided
|
||||
# Version "1" is needed when description is used
|
||||
version = "1" if self._params.description is not None else "2"
|
||||
|
||||
# Track the duration of this utterance based on the last timestamp
|
||||
utterance_duration = 0.0
|
||||
|
||||
async for chunk in self._client.tts.synthesize_json_streaming(
|
||||
utterances=[utterance],
|
||||
format=pcm_fmt,
|
||||
instant_mode=True,
|
||||
version=version,
|
||||
include_timestamp_types=["word"], # Request word-level timestamps
|
||||
):
|
||||
# Process audio chunks
|
||||
audio_b64 = getattr(chunk, "audio", None)
|
||||
if audio_b64:
|
||||
await self.stop_ttfb_metrics()
|
||||
pcm_bytes = base64.b64decode(audio_b64)
|
||||
self._audio_bytes += pcm_bytes
|
||||
if not audio_b64:
|
||||
continue
|
||||
|
||||
# Buffer audio until we have enough to avoid glitches
|
||||
if len(self._audio_bytes) >= self.chunk_size:
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=self._audio_bytes,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
self._audio_bytes = b""
|
||||
pcm_bytes = base64.b64decode(audio_b64)
|
||||
self._audio_bytes += pcm_bytes
|
||||
|
||||
# Process timestamp messages
|
||||
if isinstance(chunk, TimestampMessage):
|
||||
timestamp = chunk.timestamp
|
||||
if timestamp.type == "word":
|
||||
# Convert milliseconds to seconds and add cumulative offset
|
||||
word_start_time = self._cumulative_time + (timestamp.time.begin / 1000.0)
|
||||
word_end_time = self._cumulative_time + (timestamp.time.end / 1000.0)
|
||||
# Buffer audio until we have enough to avoid glitches
|
||||
if len(self._audio_bytes) < self.chunk_size:
|
||||
continue
|
||||
|
||||
# Track the maximum end time for this utterance
|
||||
utterance_duration = max(utterance_duration, word_end_time)
|
||||
|
||||
# Add word timestamp
|
||||
await self.add_word_timestamps([(timestamp.text, word_start_time)])
|
||||
|
||||
# Flush any remaining audio bytes
|
||||
if self._audio_bytes:
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=self._audio_bytes,
|
||||
sample_rate=self.sample_rate,
|
||||
@@ -282,13 +224,10 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
self._audio_bytes = b""
|
||||
|
||||
# Update cumulative time for next utterance
|
||||
if utterance_duration > 0:
|
||||
self._cumulative_time = utterance_duration
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# Ensure TTFB timer is stopped even on early failures
|
||||
await self.stop_ttfb_metrics()
|
||||
# Let the parent class handle TTSStoppedFrame via push_stop_frames
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -146,8 +146,6 @@ class InworldTTSService(TTSService):
|
||||
Parameters:
|
||||
temperature: Voice temperature control for synthesis variability (e.g., 1.1).
|
||||
Valid range: [0, 2]. Higher values increase variability.
|
||||
speaking_rate: Speaking speed control (range: [0.5, 1.5]). Defaults to 1.0 when
|
||||
unset.
|
||||
|
||||
Note:
|
||||
Language is automatically inferred from the input text by Inworld's TTS models,
|
||||
@@ -155,7 +153,6 @@ class InworldTTSService(TTSService):
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None # optional temperature control (range: [0, 2])
|
||||
speaking_rate: Optional[float] = None # optional speaking rate control (range: [0.5, 1.5])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -201,7 +198,6 @@ class InworldTTSService(TTSService):
|
||||
- Other formats as supported by Inworld API
|
||||
params: Optional input parameters for additional configuration. Use this to specify:
|
||||
- temperature: Voice temperature control for variability (range: [0, 2], e.g., 1.1, optional)
|
||||
- speaking_rate: Set desired speaking speed (range: [0.5, 1.5], optional)
|
||||
Language is automatically inferred from input text.
|
||||
**kwargs: Additional arguments passed to the parent TTSService class.
|
||||
|
||||
@@ -232,18 +228,15 @@ class InworldTTSService(TTSService):
|
||||
self._settings = {
|
||||
"voiceId": voice_id, # Voice selection from direct parameter
|
||||
"modelId": model, # TTS model selection from direct parameter
|
||||
"audioConfig": { # Audio format configuration
|
||||
"audioEncoding": encoding, # Format: LINEAR16, MP3, etc.
|
||||
"sampleRateHertz": 0, # Will be set in start() from parent service
|
||||
"audio_config": { # Audio format configuration
|
||||
"audio_encoding": encoding, # Format: LINEAR16, MP3, etc.
|
||||
"sample_rate_hertz": 0, # Will be set in start() from parent service
|
||||
},
|
||||
}
|
||||
|
||||
# Add optional temperature parameter if provided (valid range: [0, 2])
|
||||
if params and params.temperature is not None:
|
||||
self._settings["temperature"] = params.temperature
|
||||
# Add optional speaking rate if provided (valid range: [0.5, 1.5])
|
||||
if params and params.speaking_rate is not None:
|
||||
self._settings["audioConfig"]["speakingRate"] = params.speaking_rate
|
||||
|
||||
# Register voice and model with parent service for metrics and tracking
|
||||
self.set_voice(voice_id) # Used for logging and metrics
|
||||
@@ -257,6 +250,15 @@ class InworldTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Inworld TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Inworld's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Inworld TTS service.
|
||||
|
||||
@@ -264,7 +266,7 @@ class InworldTTSService(TTSService):
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._settings["audioConfig"]["sampleRateHertz"] = self.sample_rate
|
||||
self._settings["audio_config"]["sample_rate_hertz"] = self.sample_rate
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Inworld TTS service.
|
||||
@@ -330,7 +332,9 @@ class InworldTTSService(TTSService):
|
||||
"text": text, # Text to synthesize
|
||||
"voiceId": self._settings["voiceId"], # Voice selection (Ashley, Hades, etc.)
|
||||
"modelId": self._settings["modelId"], # TTS model (inworld-tts-1)
|
||||
"audioConfig": self._settings["audioConfig"], # Audio format settings (LINEAR16, 48kHz)
|
||||
"audio_config": self._settings[
|
||||
"audio_config"
|
||||
], # Audio format settings (LINEAR16, 48kHz)
|
||||
}
|
||||
|
||||
# Add optional temperature parameter if configured (valid range: [0, 2])
|
||||
@@ -397,7 +401,8 @@ class InworldTTSService(TTSService):
|
||||
# STEP 7: ERROR HANDLING
|
||||
# ================================================================================
|
||||
# Log any unexpected errors and notify the pipeline
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# ================================================================================
|
||||
# STEP 8: CLEANUP AND COMPLETION
|
||||
@@ -512,7 +517,7 @@ class InworldTTSService(TTSService):
|
||||
# Extract the base64-encoded audio content from response
|
||||
if "audioContent" not in response_data:
|
||||
logger.error("No audioContent in Inworld API response")
|
||||
yield ErrorFrame(error="No audioContent in response")
|
||||
await self.push_error(ErrorFrame(error="No audioContent in response"))
|
||||
return
|
||||
|
||||
# ================================================================================
|
||||
|
||||
@@ -166,27 +166,23 @@ class LLMService(AIService):
|
||||
# However, subclasses should override this with a more specific adapter when necessary.
|
||||
adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter
|
||||
|
||||
def __init__(self, run_in_parallel: bool = True, wait_for_all: bool = False, **kwargs):
|
||||
def __init__(self, run_in_parallel: bool = True, **kwargs):
|
||||
"""Initialize the LLM service.
|
||||
|
||||
Args:
|
||||
run_in_parallel: Whether to run function calls in parallel or sequentially.
|
||||
Defaults to True.
|
||||
wait_for_all: Whether to wait for all function calls (parallel or
|
||||
sequential) to complete. Defaults to False.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._run_in_parallel = run_in_parallel
|
||||
self._wait_for_all = wait_for_all
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
self._function_call_tasks: Dict[Optional[asyncio.Task], FunctionCallRunnerItem] = {}
|
||||
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
|
||||
self._sequential_runner_task: Optional[asyncio.Task] = None
|
||||
self._tracing_enabled: bool = False
|
||||
self._skip_tts: Optional[bool] = None
|
||||
self._skip_tts: bool = False
|
||||
|
||||
self._register_event_handler("on_function_calls_started")
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
@@ -297,8 +293,7 @@ class LLMService(AIService):
|
||||
direction: The direction of frame pushing.
|
||||
"""
|
||||
if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
|
||||
if self._skip_tts is not None:
|
||||
frame.skip_tts = self._skip_tts
|
||||
frame.skip_tts = self._skip_tts
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
@@ -440,7 +435,6 @@ class LLMService(AIService):
|
||||
|
||||
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)
|
||||
|
||||
runner_items = []
|
||||
for function_call in function_calls:
|
||||
if function_call.function_name in self._functions.keys():
|
||||
item = self._functions[function_call.function_name]
|
||||
@@ -452,20 +446,28 @@ class LLMService(AIService):
|
||||
)
|
||||
continue
|
||||
|
||||
runner_items.append(
|
||||
FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
)
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
)
|
||||
|
||||
if self._run_in_parallel:
|
||||
await self._run_parallel_function_calls(runner_items)
|
||||
else:
|
||||
await self._run_sequential_function_calls(runner_items)
|
||||
if self._run_in_parallel:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
else:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
return await self._start_callbacks[None](function_name, self, context)
|
||||
|
||||
async def request_image_frame(
|
||||
self,
|
||||
@@ -538,46 +540,6 @@ class LLMService(AIService):
|
||||
await task
|
||||
del self._function_call_tasks[task]
|
||||
|
||||
async def _run_parallel_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
tasks = []
|
||||
for runner_item in runner_items:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
tasks.append(task)
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
|
||||
if self._wait_for_all:
|
||||
# Protect gather from being cancelled. This will protect all tasks
|
||||
# form being cancelled. That is fine, because we cancel them
|
||||
# explicitly when handling the interruption (InterruptionFrame). We
|
||||
# need to set `return_exceptions=True` because `asyncio.shield()`
|
||||
# will get cancelled (from FrameProcessor process task), then
|
||||
# `asyncio.gather()` will keep running (because it was protected by
|
||||
# the shield). Then, individiaul function call tasks will be
|
||||
# cancelled by us and we don't need to propagate those
|
||||
# CancelledErrors at that point.
|
||||
await asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))
|
||||
|
||||
async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
if self._wait_for_all:
|
||||
# Run each function call sequentially, waiting for each to complete.
|
||||
for runner_item in runner_items:
|
||||
self._function_call_tasks[None] = runner_item
|
||||
await self._run_function_call(runner_item)
|
||||
del self._function_call_tasks[None]
|
||||
else:
|
||||
# Enqueue all function calls for background execution.
|
||||
for runner_item in runner_items:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
return await self._start_callbacks[None](function_name, self, context)
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
item = self._functions[runner_item.function_name]
|
||||
@@ -661,19 +623,20 @@ class LLMService(AIService):
|
||||
name = runner_item.function_name
|
||||
tool_call_id = runner_item.tool_call_id
|
||||
|
||||
# We remove the callback because we are going to cancel the task
|
||||
# now, otherwise we will be removing it from the set while we
|
||||
# are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
|
||||
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
|
||||
|
||||
if task:
|
||||
# We remove the callback because we are going to cancel the
|
||||
# task next, otherwise we will be removing it from the set
|
||||
# while we are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
await self.cancel_task(task)
|
||||
cancelled_tasks.add(task)
|
||||
await self.cancel_task(task)
|
||||
|
||||
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
|
||||
await self.push_frame(frame)
|
||||
|
||||
cancelled_tasks.add(task)
|
||||
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
|
||||
# Remove all cancelled tasks from our set.
|
||||
|
||||
@@ -124,6 +124,15 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that LMNT TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that LMNT's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to LMNT service language format.
|
||||
|
||||
@@ -214,7 +223,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -230,7 +240,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
# await self._websocket.send(json.dumps({"eof": True}))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting from LMNT: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -264,9 +275,10 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
if "error" in msg:
|
||||
logger.error(f"{self} error: {msg['error']}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['error']}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
@@ -299,11 +311,13 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(json.dumps({"flush": True}))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -176,6 +176,7 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _stdio_list_tools(self) -> ToolsSchema:
|
||||
@@ -206,6 +207,7 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _streamable_http_list_tools(self) -> ToolsSchema:
|
||||
@@ -244,6 +246,7 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _call_tool(self, session, function_name, arguments, result_callback):
|
||||
@@ -299,6 +302,7 @@ class MCPClient(BaseObject):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read tool '{tool_name}': {str(e)}")
|
||||
logger.exception("Full exception details:")
|
||||
continue
|
||||
|
||||
logger.debug(f"Completed reading {len(tool_schemas)} tools")
|
||||
|
||||
@@ -253,9 +253,8 @@ class Mem0MemoryService(FrameProcessor):
|
||||
# Otherwise, pass the enhanced context frame downstream
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Error processing with Mem0: {str(e)}", exception=e
|
||||
)
|
||||
logger.error(f"Error processing with Mem0: {str(e)}")
|
||||
await self.push_frame(ErrorFrame(f"Error processing with Mem0: {str(e)}"))
|
||||
await self.push_frame(frame) # Still pass the original frame through
|
||||
else:
|
||||
# For non-context frames, just pass them through
|
||||
|
||||
@@ -40,40 +40,24 @@ def language_to_minimax_language(language: Language) -> Optional[str]:
|
||||
The corresponding MiniMax language name, or None if not supported.
|
||||
"""
|
||||
LANGUAGE_MAP = {
|
||||
Language.AF: "Afrikaans",
|
||||
Language.AR: "Arabic",
|
||||
Language.BG: "Bulgarian",
|
||||
Language.CA: "Catalan",
|
||||
Language.CS: "Czech",
|
||||
Language.DA: "Danish",
|
||||
Language.DE: "German",
|
||||
Language.EL: "Greek",
|
||||
Language.EN: "English",
|
||||
Language.ES: "Spanish",
|
||||
Language.FA: "Persian", # ⚠️ Only supported by speech-2.6-* models
|
||||
Language.FI: "Finnish",
|
||||
Language.FIL: "Filipino", # ⚠️ Only supported by speech-2.6-* models
|
||||
Language.FR: "French",
|
||||
Language.HE: "Hebrew",
|
||||
Language.HI: "Hindi",
|
||||
Language.HR: "Croatian",
|
||||
Language.HU: "Hungarian",
|
||||
Language.ID: "Indonesian",
|
||||
Language.IT: "Italian",
|
||||
Language.JA: "Japanese",
|
||||
Language.KO: "Korean",
|
||||
Language.MS: "Malay",
|
||||
Language.NB: "Norwegian",
|
||||
Language.NN: "Nynorsk",
|
||||
Language.NL: "Dutch",
|
||||
Language.PL: "Polish",
|
||||
Language.PT: "Portuguese",
|
||||
Language.RO: "Romanian",
|
||||
Language.RU: "Russian",
|
||||
Language.SK: "Slovak",
|
||||
Language.SL: "Slovenian",
|
||||
Language.SV: "Swedish",
|
||||
Language.TA: "Tamil", # ⚠️ Only supported by speech-2.6-* models
|
||||
Language.TH: "Thai",
|
||||
Language.TR: "Turkish",
|
||||
Language.UK: "Ukrainian",
|
||||
@@ -100,22 +84,13 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"""Configuration parameters for MiniMax TTS.
|
||||
|
||||
Parameters:
|
||||
language: Language for TTS generation. Supports 40 languages.
|
||||
Note: Filipino, Tamil, and Persian require speech-2.6-* models.
|
||||
language: Language for TTS generation.
|
||||
speed: Speech speed (range: 0.5 to 2.0).
|
||||
volume: Speech volume (range: 0 to 10).
|
||||
pitch: Pitch adjustment (range: -12 to 12).
|
||||
emotion: Emotional tone (options: "happy", "sad", "angry", "fearful",
|
||||
"disgusted", "surprised", "calm", "fluent").
|
||||
english_normalization: Deprecated; use `text_normalization` instead
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
The `english_normalization` parameter is deprecated and will be removed in a future version.
|
||||
Use the `text_normalization` parameter instead.
|
||||
|
||||
text_normalization: Enable text normalization (Chinese/English).
|
||||
latex_read: Enable LaTeX formula reading.
|
||||
exclude_aggregated_audio: Whether to exclude aggregated audio in final chunk.
|
||||
"disgusted", "surprised", "neutral").
|
||||
english_normalization: Whether to apply English text normalization.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
@@ -123,10 +98,7 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
volume: Optional[float] = 1.0
|
||||
pitch: Optional[int] = 0
|
||||
emotion: Optional[str] = None
|
||||
english_normalization: Optional[bool] = None # Deprecated
|
||||
text_normalization: Optional[bool] = None
|
||||
latex_read: Optional[bool] = None
|
||||
exclude_aggregated_audio: Optional[bool] = None
|
||||
english_normalization: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -148,12 +120,9 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
base_url: API base URL, defaults to MiniMax's T2A endpoint.
|
||||
Global: https://api.minimax.io/v1/t2a_v2
|
||||
Mainland China: https://api.minimaxi.chat/v1/t2a_v2
|
||||
Western United States: https://api-uw.minimax.io/v1/t2a_v2
|
||||
group_id: MiniMax Group ID to identify project.
|
||||
model: TTS model name. Defaults to "speech-02-turbo". Options include:
|
||||
"speech-2.6-hd", "speech-2.6-turbo" (latest, supports Filipino/Tamil/Persian),
|
||||
"speech-02-hd", "speech-02-turbo",
|
||||
"speech-01-hd", "speech-01-turbo".
|
||||
model: TTS model name. Defaults to "speech-02-turbo". Options include
|
||||
"speech-02-hd", "speech-02-turbo", "speech-01-hd", "speech-01-turbo".
|
||||
voice_id: Voice identifier. Defaults to "Calm_Woman".
|
||||
aiohttp_session: aiohttp.ClientSession for API communication.
|
||||
sample_rate: Output audio sample rate in Hz. If None, uses pipeline default.
|
||||
@@ -207,34 +176,15 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"disgusted",
|
||||
"surprised",
|
||||
"neutral",
|
||||
"fluent",
|
||||
]
|
||||
if params.emotion in supported_emotions:
|
||||
self._settings["voice_setting"]["emotion"] = params.emotion
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unsupported emotion: {params.emotion}. Supported emotions: {supported_emotions}"
|
||||
)
|
||||
logger.warning(f"Unsupported emotion: {params.emotion}. Using default.")
|
||||
|
||||
# If `english_normalization`, add `text_normalization` and print warning
|
||||
# Add english_normalization if provided
|
||||
if params.english_normalization is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter `english_normalization` is deprecated and will be removed in a future version. Use `text_normalization` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self._settings["voice_setting"]["text_normalization"] = params.english_normalization
|
||||
|
||||
# Add text_normalization if provided (corrected parameter name)
|
||||
if params.text_normalization is not None:
|
||||
self._settings["voice_setting"]["text_normalization"] = params.text_normalization
|
||||
|
||||
# Add latex_read if provided
|
||||
if params.latex_read is not None:
|
||||
self._settings["voice_setting"]["latex_read"] = params.latex_read
|
||||
self._settings["english_normalization"] = params.english_normalization
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -244,6 +194,15 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that MiniMax TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that MiniMax's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to MiniMax service language format.
|
||||
|
||||
@@ -281,7 +240,7 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._settings["audio_setting"]["sample_rate"] = self.sample_rate
|
||||
logger.debug(f"MiniMax TTS initialized with sample_rate: {self.sample_rate}")
|
||||
logger.debug(f"MiniMax TTS initialized with sample rate: {self.sample_rate}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -314,6 +273,7 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_message = f"MiniMax TTS error: HTTP {response.status}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
@@ -379,19 +339,16 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
num_channels=1,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Error converting hex to binary: {e}",
|
||||
)
|
||||
logger.error(f"Error converting hex to binary: {e}")
|
||||
continue
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
f"Error decoding JSON: {e}, data: {data_block[:100]}",
|
||||
)
|
||||
logger.error(f"Error decoding JSON: {e}, data: {data_block[:100]}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -110,6 +110,7 @@ class MoondreamService(VisionService):
|
||||
if analysis fails.
|
||||
"""
|
||||
if not self._model:
|
||||
logger.error(f"{self} error: Moondream model not available ({self.model_name})")
|
||||
yield ErrorFrame("Moondream model not available")
|
||||
return
|
||||
|
||||
|
||||
@@ -151,6 +151,15 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Neuphonic TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Neuphonic's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Neuphonic service language format.
|
||||
|
||||
@@ -285,7 +294,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -298,7 +308,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Neuphonic")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -363,14 +374,16 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class NeuphonicHttpTTSService(TTSService):
|
||||
@@ -436,6 +449,15 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Neuphonic TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Neuphonic's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Neuphonic service language format.
|
||||
|
||||
@@ -534,6 +556,7 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
error_message = f"Neuphonic API error: HTTP {response.status} - {error_text}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
@@ -563,7 +586,8 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
yield TTSAudioRawFrame(audio_bytes, self.sample_rate, 1)
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
# Don't yield error frame for individual message failures
|
||||
continue
|
||||
|
||||
@@ -571,7 +595,8 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
logger.debug("TTS generation cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -8,23 +8,98 @@
|
||||
|
||||
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
|
||||
Microservice) API while maintaining compatibility with the OpenAI-style interface.
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
This module is deprecated. Please NvidiaLLMService from
|
||||
pipecat.services.nvidia.llm instead.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"NimLLMService from pipecat.services.nim.llm is deprecated. "
|
||||
"Please use NvidiaLLMService from pipecat.services.nvidia.llm instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
class NimLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
|
||||
NimLLMService = NvidiaLLMService
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NimLLMService.
|
||||
|
||||
Args:
|
||||
api_key: The API key for accessing NVIDIA's NIM API.
|
||||
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
|
||||
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing messages and other information
|
||||
needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA NIM API service implementation.
|
||||
|
||||
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
|
||||
Microservice) API while maintaining compatibility with the OpenAI-style interface.
|
||||
"""
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class NvidiaLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NvidiaLLMService.
|
||||
|
||||
Args:
|
||||
api_key: The API key for accessing NVIDIA's NIM API.
|
||||
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
|
||||
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing messages and other information
|
||||
needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
@@ -1,663 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import CancelledError as FuturesCancelledError
|
||||
from typing import AsyncGenerator, List, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import SegmentedSTTService, STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_nvidia_riva_language(language: Language) -> Optional[str]:
|
||||
"""Maps Language enum to NVIDIA Riva ASR language codes.
|
||||
|
||||
Source:
|
||||
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Optional[str]: NVIDIA Riva language code or None if not supported.
|
||||
"""
|
||||
LANGUAGE_MAP = {
|
||||
# Arabic
|
||||
Language.AR: "ar-AR",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_GB: "en-GB",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Portuguese
|
||||
Language.PT: "pt-BR", # Default to Brazilian
|
||||
Language.PT_BR: "pt-BR",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_US: "es-US", # US Spanish
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
|
||||
|
||||
|
||||
class NvidiaSTTService(STTService):
|
||||
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
|
||||
|
||||
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
|
||||
through streaming recognition. Supports interim results and continuous audio
|
||||
processing for low-latency applications.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for NVIDIA Riva STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
"model_name": "parakeet-ctc-1.1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: NVIDIA Riva server address. Defaults to NVIDIA Cloud Function endpoint.
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for NVIDIA Riva.
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaSTTService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._profanity_filter = False
|
||||
self._automatic_punctuation = True
|
||||
self._no_verbatim_transcripts = False
|
||||
self._language_code = params.language
|
||||
self._boosted_lm_words = None
|
||||
self._boosted_lm_score = 4.0
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self._settings = {
|
||||
"language": str(params.language),
|
||||
"profanity_filter": self._profanity_filter,
|
||||
"automatic_punctuation": self._automatic_punctuation,
|
||||
"verbatim_transcripts": not self._no_verbatim_transcripts,
|
||||
"boosted_lm_words": self._boosted_lm_words,
|
||||
"boosted_lm_score": self._boosted_lm_score,
|
||||
}
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
self._queue = None
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
self._response_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
False - this service does not support metrics generation.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the NVIDIA Riva STT service and initialize streaming configuration.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
if self._config:
|
||||
return
|
||||
|
||||
config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
self._config = config
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
if not self._thread_task:
|
||||
self._thread_task = self.create_task(self._thread_task_handler())
|
||||
|
||||
if not self._response_task:
|
||||
self._response_queue = asyncio.Queue()
|
||||
self._response_task = self.create_task(self._response_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the NVIDIA Riva STT service and clean up resources.
|
||||
|
||||
Args:
|
||||
frame: EndFrame indicating pipeline stop.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the NVIDIA Riva STT service operation.
|
||||
|
||||
Args:
|
||||
frame: CancelFrame indicating operation cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def _stop_tasks(self):
|
||||
if self._thread_task:
|
||||
await self.cancel_task(self._thread_task)
|
||||
self._thread_task = None
|
||||
|
||||
if self._response_task:
|
||||
await self.cancel_task(self._response_task)
|
||||
self._response_task = None
|
||||
|
||||
def _response_handler(self):
|
||||
responses = self._asr_service.streaming_response_generator(
|
||||
audio_chunks=self,
|
||||
streaming_config=self._config,
|
||||
)
|
||||
for response in responses:
|
||||
if not response.results:
|
||||
continue
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(response), self.get_event_loop()
|
||||
)
|
||||
|
||||
async def _thread_task_handler(self):
|
||||
try:
|
||||
self._thread_running = True
|
||||
await asyncio.to_thread(self._response_handler)
|
||||
except asyncio.CancelledError:
|
||||
self._thread_running = False
|
||||
raise
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def _handle_response(self, response):
|
||||
for result in response.results:
|
||||
if result and not result.alternatives:
|
||||
continue
|
||||
|
||||
transcript = result.alternatives[0].transcript
|
||||
if transcript and len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if result.is_final:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=result.is_final,
|
||||
language=self._language_code,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
async def _response_task_handler(self):
|
||||
while True:
|
||||
response = await self._response_queue.get()
|
||||
await self._handle_response(response)
|
||||
self._response_queue.task_done()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
None - transcription results are pushed to the pipeline via frames.
|
||||
"""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._queue.put(audio)
|
||||
yield None
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
"""Get the next audio chunk for NVIDIA Riva processing.
|
||||
|
||||
Returns:
|
||||
Audio bytes from the queue.
|
||||
|
||||
Raises:
|
||||
StopIteration: When the thread is no longer running.
|
||||
"""
|
||||
if not self._thread_running:
|
||||
raise StopIteration
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
|
||||
return future.result()
|
||||
except FuturesCancelledError:
|
||||
raise StopIteration
|
||||
|
||||
def __iter__(self):
|
||||
"""Return iterator for audio chunk processing.
|
||||
|
||||
Returns:
|
||||
Self as iterator.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
|
||||
|
||||
By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
|
||||
transcription on audio segments. It inherits from SegmentedSTTService to handle
|
||||
audio buffering and speech detection.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for NVIDIA Riva segmented STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
profanity_filter: Whether to filter profanity from results.
|
||||
automatic_punctuation: Whether to add automatic punctuation.
|
||||
verbatim_transcripts: Whether to return verbatim transcripts.
|
||||
boosted_lm_words: List of words to boost in language model.
|
||||
boosted_lm_score: Score boost for specified words.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
profanity_filter: bool = False
|
||||
automatic_punctuation: bool = True
|
||||
verbatim_transcripts: bool = False
|
||||
boosted_lm_words: Optional[List[str]] = None
|
||||
boosted_lm_score: float = 4.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
|
||||
"model_name": "canary-1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva segmented STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication
|
||||
server: NVIDIA Riva server address (defaults to NVIDIA Cloud Function endpoint)
|
||||
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
|
||||
params: Additional configuration parameters for NVIDIA Riva
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaSegmentedSTTService.InputParams()
|
||||
|
||||
# Set model name
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
# Initialize NVIDIA Riva settings
|
||||
self._api_key = api_key
|
||||
self._server = server
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
self._model_name = model_function_map.get("model_name")
|
||||
|
||||
# Store the language as a Language enum and as a string
|
||||
self._language_enum = params.language or Language.EN_US
|
||||
self._language = self.language_to_service_language(self._language_enum) or "en-US"
|
||||
|
||||
# Configure transcription parameters
|
||||
self._profanity_filter = params.profanity_filter
|
||||
self._automatic_punctuation = params.automatic_punctuation
|
||||
self._verbatim_transcripts = params.verbatim_transcripts
|
||||
self._boosted_lm_words = params.boosted_lm_words
|
||||
self._boosted_lm_score = params.boosted_lm_score
|
||||
|
||||
# Voice activity detection thresholds (use NVIDIA Riva defaults)
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
|
||||
# Create NVIDIA Riva client
|
||||
self._config = None
|
||||
self._asr_service = None
|
||||
self._settings = {"language": self._language_enum}
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert pipecat Language enum to NVIDIA Riva's language code.
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
NVIDIA Riva language code or None if not supported.
|
||||
"""
|
||||
return language_to_nvidia_riva_language(language)
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the NVIDIA Riva ASR client with authentication metadata."""
|
||||
if self._asr_service is not None:
|
||||
return
|
||||
|
||||
# Set up authentication metadata for NVIDIA Cloud Functions
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
]
|
||||
|
||||
# Create authenticated client
|
||||
auth = riva.client.Auth(None, True, self._server, metadata)
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the NVIDIA Riva ASR recognition configuration."""
|
||||
# Create base configuration
|
||||
config = riva.client.RecognitionConfig(
|
||||
language_code=self._language, # Now using the string, not a tuple
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=self._verbatim_transcripts,
|
||||
)
|
||||
|
||||
# Add word boosting if specified
|
||||
if self._boosted_lm_words:
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
# Add voice activity detection parameters
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
|
||||
# Add any custom configuration
|
||||
if self._custom_configuration:
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
return config
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True - this service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Initialize the service when the pipeline starts.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the language for the STT service.
|
||||
|
||||
Args:
|
||||
language: Target language for transcription.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._language_enum = language
|
||||
self._language = self.language_to_service_language(language) or "en-US"
|
||||
self._settings["language"] = language
|
||||
|
||||
# Update configuration with new language
|
||||
if self._config:
|
||||
self._config.language_code = self._language
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Transcribe an audio segment.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes in WAV format (already converted by base class).
|
||||
|
||||
Yields:
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Make sure the client is initialized
|
||||
if self._asr_service is None:
|
||||
self._initialize_client()
|
||||
|
||||
# Make sure the config is created
|
||||
if self._config is None:
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
# Type assertion to satisfy the IDE
|
||||
assert self._asr_service is not None, "ASR service not initialized"
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
|
||||
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Process the response - handle different possible return types
|
||||
try:
|
||||
# If it's a future-like object, get the result
|
||||
if hasattr(raw_response, "result"):
|
||||
response = raw_response.result()
|
||||
else:
|
||||
response = raw_response
|
||||
|
||||
# Process transcription results
|
||||
transcription_found = False
|
||||
|
||||
# Now we can safely check results
|
||||
# Type hint for the IDE
|
||||
results = getattr(response, "results", [])
|
||||
|
||||
for result in results:
|
||||
alternatives = getattr(result, "alternatives", [])
|
||||
if alternatives:
|
||||
text = alternatives[0].transcript.strip()
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_enum,
|
||||
)
|
||||
transcription_found = True
|
||||
|
||||
await self._handle_transcription(text, True, self._language_enum)
|
||||
|
||||
if not transcription_found:
|
||||
logger.debug("No transcription results found in NVIDIA Riva response")
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
|
||||
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
@@ -1,187 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva text-to-speech service implementation.
|
||||
|
||||
This module provides integration with NVIDIA Riva's TTS services through
|
||||
gRPC API for high-quality speech synthesis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Mapping, Optional
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
NVIDIA_TTS_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class NvidiaTTSService(TTSService):
|
||||
"""NVIDIA Riva text-to-speech service.
|
||||
|
||||
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
|
||||
cloud-based TTS models. Supports multiple voices, languages, and
|
||||
configurable quality settings.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Riva TTS configuration.
|
||||
|
||||
Parameters:
|
||||
language: Language code for synthesis. Defaults to US English.
|
||||
quality: Audio quality setting (0-100). Defaults to 20.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
|
||||
sample_rate: Optional[int] = None,
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
|
||||
"model_name": "magpie-tts-multilingual",
|
||||
},
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
voice_id: Voice model identifier. Defaults to multilingual Ray voice.
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for the TTS model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
|
||||
Note: Model cannot be changed after initialization for Riva service.
|
||||
|
||||
Args:
|
||||
model: The model name to set (operation not supported).
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech data.
|
||||
"""
|
||||
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
try:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user