Compare commits
1 Commits
hush/TurnT
...
hush/endFr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1b28fc8e8e |
298
CHANGELOG.md
298
CHANGELOG.md
@@ -9,299 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `LiveKitRESTHelper` utility class for managing LiveKit rooms via REST API.
|
||||
|
||||
- Added `DeepgramSageMakerSTTService` which connects to a SageMaker hosted
|
||||
Deepgram STT model. Added `07c-interruptible-deepgram-sagemaker.py`
|
||||
foundational example.
|
||||
|
||||
- Added `SageMakerBidiClient` to connect to SageMaker hosted BiDi compatible
|
||||
services.
|
||||
|
||||
- Added support for `include_timestamps` and `enable_logging` in
|
||||
`ElevenLabsRealtimeSTTService`. When `include_timestamps` is enabled,
|
||||
timestamp data is included in the `TranscriptionFrame`'s `result`
|
||||
parameter.
|
||||
|
||||
- Added optional speaking rate control to `InworldTTSService`.
|
||||
|
||||
- Introduced a new `AggregatedTextFrame` type to support passing text along with
|
||||
an `aggregated_by` field to describe the type of text
|
||||
included. `TTSTextFrame`s now inherit from `AggregatedTextFrame`. With this
|
||||
inheritance, an observer can watch for `AggregatedTextFrame`s to accumlate the
|
||||
perceived output and determine whether or not the text was spoken based on if
|
||||
that frame is also a `TTSTextFrame`.
|
||||
|
||||
With this frame, the llm token stream can be transformed into custom
|
||||
composable chunks, allowing for aggregation outside the TTS service. This
|
||||
makes it possible to listen for or handle those aggregations and sets the
|
||||
stage for doing things like composing a best effort of the perceived llm
|
||||
output in a more digestable form and to do so whether or not it is processed
|
||||
by a TTS or if even a TTS exists.
|
||||
|
||||
- Introduced `LLMTextProcessor`: A new processor meant to allow customization
|
||||
for how LLMTextFrames should be aggregated and considered. It's purpose is to
|
||||
turn `LLMTextFrame`s into `AggregatedTextFrame`s. By default, a TTSService
|
||||
will still aggregate `LLMTextFrame`s by sentence for the service to
|
||||
consume. However, if you wish to override how the llm text is aggregated, you
|
||||
should no longer override the TTS's internal text_aggregator, but instead,
|
||||
insert this processor between your LLM and TTS in the pipeline.
|
||||
|
||||
- New `bot-output` RTVI message to represent what the bot actually "says".
|
||||
|
||||
- The `RTVIObserver` now emits `bot-output` messages based off the new
|
||||
`AggregatedTextFrame`s (`bot-tts-text` and `bot-llm-text` are still
|
||||
supported and generated, but `bot-transcript` is now deprecated in lieu of
|
||||
this new, more thorough, message).
|
||||
|
||||
- The new `RTVIBotOutputMessage` includes the fields:
|
||||
|
||||
- `spoken`: A boolean indicating whether the text was spoken by TTS
|
||||
|
||||
- `aggregated_by`: A string representing how the text was aggregated
|
||||
("sentence", "word", "my custom aggregation")
|
||||
|
||||
- Introduced new fields to `RTVIObserver` to support the new `bot-output`
|
||||
messaging:
|
||||
|
||||
- `bot_output_enabled`: Defaults to True. Set to false to disable bot-output
|
||||
messages.
|
||||
|
||||
- `skip_aggregator_types`: Defaults to `None`. Set to a list of strings that
|
||||
match aggregation types that should not be included in bot-output
|
||||
messages. (Ex. `credit_card`)
|
||||
|
||||
- Introduced new methods, `add_text_transformer()` and
|
||||
`remove_text_transformer()`, to `RTVIObserver` to support providing (and
|
||||
subsequently removing) callbacks for various types of aggregations (or all
|
||||
aggregations with `*`) that can modify the text before being sent as a
|
||||
`bot-output` or `tts-text` message. (Think obscuring the credit card or
|
||||
inserting extra detail the client might want that the context doesn't need.)
|
||||
|
||||
- In `MiniMaxHttpTTSService`:
|
||||
|
||||
- Added support for speech-2.6-hd and speech-2.6-turbo models
|
||||
|
||||
- Added languages: Afrikaans, Bulgarian, Catalan, Danish, Persian, Filipino,
|
||||
Hebrew, Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian,
|
||||
Swedish, and Tamil
|
||||
|
||||
- Added new emotions: calm and fluent
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `daily-python` to 0.22.0.
|
||||
|
||||
- `BaseTextAggregator` changes:
|
||||
|
||||
Modified the BaseTextAggregator type so that when text gets aggregated,
|
||||
metadata can be associated with it. Currently, that just means a `type`, so
|
||||
that the aggregation can be classified or described. Changes made to support
|
||||
this:
|
||||
|
||||
- ⚠️ IMPORTANT: Aggregators are now expected to strip leading/trailing white
|
||||
space characters before returning their aggregation from `aggregation()` or
|
||||
`.text`. This way all aggregators have a consistent contract allowing
|
||||
downstream use to know how to stitch aggregations back together.
|
||||
|
||||
- Introduced a new `Aggregation` dataclass to represent both the aggregated
|
||||
`text` and a string identifying the `type` of aggregation (ex. "sentence",
|
||||
"word", "my custom aggregation")
|
||||
|
||||
- ⚠️ Breaking change: `BaseTextAggregator.text` now returns an `Aggregation`
|
||||
(instead of `str`).
|
||||
|
||||
Before:
|
||||
|
||||
```python
|
||||
aggregated_text = myAggregator.text
|
||||
```
|
||||
|
||||
Now:
|
||||
|
||||
```python
|
||||
aggregated_text = myAggregator.text.text
|
||||
```
|
||||
|
||||
- ⚠️ Breaking change: `BaseTextAggregator.aggregate()` now returns
|
||||
`Optional[Aggregation]` (instead of `Optional[str]`).
|
||||
|
||||
Before:
|
||||
|
||||
```python
|
||||
aggregation = myAggregator.aggregate(text)
|
||||
print(f"successfully aggregated text: {aggregation}")
|
||||
```
|
||||
|
||||
Now:
|
||||
|
||||
```python
|
||||
aggregation = myAggregator.aggregate(text)
|
||||
if aggregation:
|
||||
print(f"successfully aggregated text: {aggregation.text}")
|
||||
```
|
||||
|
||||
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator`
|
||||
updated to produce/consume `Aggregation` objects.
|
||||
|
||||
- All uses of the above Aggregators have been updated accordingly.
|
||||
|
||||
- Augmented the `PatternPairAggregator` so that matched patterns can be treated
|
||||
as their own aggregation, taking advantage of the new. To that end:
|
||||
|
||||
- Introduced a new, preferred version of `add_pattern` to support a new option
|
||||
for treating a match as a separate aggregation returned from
|
||||
`aggregate()`. This replaces the now deprecated `add_pattern_pair` method
|
||||
and you provide a `MatchAction` in lieu of the `remove_match` field.
|
||||
|
||||
- `MatchAction` enum: `REMOVE`, `KEEP`, `AGGREGATE`, allowing customization
|
||||
for how a match should be handled.
|
||||
|
||||
- `REMOVE`: The text along with its delimiters will be removed from the
|
||||
streaming text. Sentence aggregation will continue on as if this text
|
||||
did not exist.
|
||||
|
||||
- `KEEP`: The delimiters will be removed, but the content between them
|
||||
will be kept. Sentence aggregation will continue on with the internal
|
||||
text included.
|
||||
|
||||
- `AGGREGATE`: The delimiters will be removed and the content between will
|
||||
be treated as a separate aggregation. Any text before the start of the
|
||||
pattern will be returned early, whether or not a complete sentence was
|
||||
found. Then the pattern will be returned. Then the aggregation will
|
||||
continue on sentence matching after the closing delimiter is found. The
|
||||
content between the delimiters is not aggregated by sentence. It is
|
||||
aggregated as one single block of text.
|
||||
|
||||
- `PatternMatch` now extends `Aggregation` and provides richer info to
|
||||
handlers.
|
||||
|
||||
- ⚠️ Breaking change: The `PatternMatch` type returned to handlers registered
|
||||
via `on_pattern_match` has been updated to subclass from the new
|
||||
`Aggregation` type, which means that `content` has been replaced with
|
||||
`text` and `pattern_id` has been replaced with `type`:
|
||||
|
||||
```python
|
||||
async dev on_match_tag(match: PatternMatch):
|
||||
pattern = match.type # instead of match.pattern_id
|
||||
text = match.text # instead of match.content
|
||||
```
|
||||
|
||||
- `TextFrame` now includes the field `append_to_context` to support setting
|
||||
whether or not the encompassing text should be added to the LLM context (by
|
||||
the LLM assistant aggregator). It defaults to `True`.
|
||||
|
||||
- `TTSService` base class updates:
|
||||
|
||||
- `TTSService`s now accept a new `skip_aggregator_types` to avoid speaking
|
||||
certain aggregation types (now determined/returned by the aggregator)
|
||||
|
||||
- Introduced the ability to do a just-in-time transform of text before it gets
|
||||
sent to the TTS service via callbacks you can set up via a new init field,
|
||||
`text_transforms` or a new method `add_text_transformer()`. This makes it
|
||||
possible to do things like introduce TTS-specific tags for spelling or
|
||||
emotion or change the pronunciation of something on the
|
||||
fly. `remove_text_transformer` has also been added to support removing a
|
||||
registered transform callback.
|
||||
|
||||
- TTS services push `AggregatedTextFrame` in addition to `TTSTextFrame`s when
|
||||
either an aggregation occurs that should not be spoken or when the TTS
|
||||
service supports word-by-word timestamping. In the latter case, the
|
||||
`TTSService` preliminarily generates an `AggregatedTextFrame`, aggregated by
|
||||
sentence to generate the full sentence content as early as possible.
|
||||
|
||||
- Updated `CartesiaTTSService`:
|
||||
|
||||
- Modified use of custom default text_aggregator to avoid deprecation warnings
|
||||
and push users towards use of transformers or the `LLMTextProcessor`
|
||||
|
||||
- Added convenience methods for taking advantage of Cartesia's SSML tags:
|
||||
spell, emotion, pauses, volume, and speed.
|
||||
|
||||
- Updated `RimeTTSService`:
|
||||
|
||||
- Modified use of custom default text_aggregator to avoid deprecation warnings
|
||||
and push users towards use of transformers or the `LLMTextProcessor`
|
||||
|
||||
- Added convenience methods for taking advantage of Rime's customization
|
||||
options: spell, pauses, pronunciations, and inline speed control.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- The TTS constructor field, `text_aggregator` is deprecated in favor of the new
|
||||
`LLMTextProcessor`. TTSServices still have an internal aggregator for support
|
||||
of default behavior, but if you want to override the aggregation behavior, you
|
||||
should use the new processor.
|
||||
|
||||
- The RTVI `bot-transcription` event is deprecated in favor of the new
|
||||
`bot-output` message which is the canonical representation of bot output
|
||||
(spoken or not). The code still emits a transcription message for backwards
|
||||
compatibility while transition occurs.
|
||||
|
||||
- Deprecated `add_pattern_pair` in the `PatternPairAggregator` which takes a
|
||||
`pattern_id` and `remove_match` field in favor of the new `add_pattern` method
|
||||
which takes a `type` and an `action`
|
||||
|
||||
- `english_normalization` input parameter for `MiniMaxHttpTTSService` is
|
||||
deprecated, use `test_normalization` instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `ElevenLabsRealtimeSTTService` where dynamic language
|
||||
updates were not working.
|
||||
|
||||
- Fixed an issue in `ElevenLabsRealtimeSTTService` where setting the sample
|
||||
rate would result in transcripts failing.
|
||||
|
||||
- Fixed `InworldTTSService` audio config payload to use camelCase keys expected
|
||||
by the Inworld API.
|
||||
|
||||
## [0.0.95] - 2025-11-18
|
||||
|
||||
### Added
|
||||
|
||||
- Added ai-coustics integrated VAD (`AICVADAnalyzer`) with `AICFilter` factory and
|
||||
example wiring; leverages the enhancement model for robust detection with no
|
||||
ONNX dependency or added processing complexity.
|
||||
|
||||
- Added a watchdog to `DeepgramFluxSTTService` to prevent dangling tasks in case the
|
||||
user was speaking and we stop receiving audio.
|
||||
|
||||
- Introduced a minimum confidence parameter in `DeepgramFluxSTTService` to avoid
|
||||
generating transcriptions below a defined threshold.
|
||||
|
||||
- Added `ElevenLabsRealtimeSTTService` which implements the Realtime STT
|
||||
service from ElevenLabs.
|
||||
|
||||
- Added word-level timestamps support to Hume TTS service
|
||||
- Added a `TTSService.includes_inter_frame_spaces` property getter, so that TTS
|
||||
services that subclass `TTSService` can indicate whether the text in the
|
||||
`TTSTextFrame`s they push already contain any necessary inter-frame spaces.
|
||||
|
||||
### Changed
|
||||
|
||||
- ⚠️ Breaking change: `LLMContext.create_image_message()`,
|
||||
`LLMContext.create_audio_message()`, `LLMContext.add_image_frame_message()`
|
||||
and `LLMContext.add_audio_frames_message()` are now async methods. This fixes
|
||||
an issue where the asyncio event loop would be blocked while encoding audio or
|
||||
images.
|
||||
|
||||
- `ConsumerProcessor` now queues frames from the producer internally instead of
|
||||
pushing them directly. This allows us to subclass consumer processors and
|
||||
manipulate frames before they are pushed.
|
||||
|
||||
- `BaseTextFilter` only require subclasses to implement the `filter()` method.
|
||||
|
||||
- Extracted the logic for retrying connections, and create a new `send_with_retry`
|
||||
method inside `WebSocketService`.
|
||||
|
||||
- Refactored `DeepgramFluxSTTService` to automatically reconnect if sending a
|
||||
message fails.
|
||||
|
||||
- Updated all STT and TTS services to use consistent error handling pattern with
|
||||
`push_error()` method for better pipeline error event integration.
|
||||
|
||||
- Added support for `maybe_capture_participant_camera()` and
|
||||
`maybe_capture_participant_screen()` for `SmallWebRTCTransport` in the runner
|
||||
utils.
|
||||
|
||||
- Added Hindi support for Rime TTS services.
|
||||
|
||||
- Updated `GeminiTTSService` to use Google Cloud Text-to-Speech streaming API
|
||||
@@ -321,11 +40,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `SimliVideoService` connection issue.
|
||||
|
||||
- Fixed an issue in the `Runner` where, when using `SmallWebRTCTransport`, the
|
||||
`request_data` was not being passed to the `SmallWebRTCRunnerArguments` body.
|
||||
|
||||
- Fixed subtle issue of assistant context messages ending up with double spaces
|
||||
between words or sentences.
|
||||
|
||||
@@ -340,6 +54,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Prevented `HeyGenVideoService` from automatically disconnecting after 5 minutes.
|
||||
|
||||
### Added
|
||||
|
||||
- Added ai-coustics integrated VAD (`AICVADAnalyzer`) with `AICFilter` factory and
|
||||
example wiring; leverages the enhancement model for robust detection with no
|
||||
ONNX dependency or added processing complexity.
|
||||
|
||||
## [0.0.94] - 2025-11-10
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
# TurnAwareTranscriptProcessor Example
|
||||
|
||||
## Overview
|
||||
|
||||
The `TurnAwareTranscriptProcessor` combines user and assistant transcript tracking with turn boundary detection. It correctly handles interruptions by only capturing what was actually spoken.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
|
||||
|
||||
# Create the processor
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Register event handlers
|
||||
@turn_processor.event_handler("on_turn_started")
|
||||
async def handle_turn_started(processor, turn_number):
|
||||
print(f"Turn {turn_number} started")
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, was_interrupted):
|
||||
print(f"\nTurn {turn_number} ended:")
|
||||
print(f" User said: {user_text}")
|
||||
print(f" Assistant said: {assistant_text}")
|
||||
print(f" Was interrupted: {was_interrupted}")
|
||||
|
||||
@turn_processor.event_handler("on_transcript_update")
|
||||
async def handle_transcript_update(processor, frame):
|
||||
for msg in frame.messages:
|
||||
print(f"[{msg.role}]: {msg.content}")
|
||||
|
||||
# Add to pipeline
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt,
|
||||
turn_processor, # Process transcripts and track turns
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
])
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
1. **Turn Boundary Detection**: Automatically detects when turns start and end based on user and bot speaking patterns
|
||||
2. **Interruption Handling**: Correctly captures only what was actually spoken when interruptions occur
|
||||
3. **Real-time Transcripts**: Emits transcript messages for both user and assistant speech
|
||||
4. **Turn Events**: Provides start/end events with accumulated transcripts for each turn
|
||||
|
||||
## Events
|
||||
|
||||
### on_turn_started
|
||||
Emitted when a new turn begins (user starts speaking).
|
||||
|
||||
**Handler signature**: `async def handler(processor, turn_number)`
|
||||
|
||||
### on_turn_ended
|
||||
Emitted when a turn ends with accumulated transcripts.
|
||||
|
||||
**Handler signature**: `async def handler(processor, turn_number, user_transcript, assistant_transcript, was_interrupted)`
|
||||
|
||||
### on_transcript_update
|
||||
Inherited from `BaseTranscriptProcessor`, emitted for individual transcript messages.
|
||||
|
||||
**Handler signature**: `async def handler(processor, frame)`
|
||||
|
||||
## Turn Logic
|
||||
|
||||
- Turns start when the user begins speaking (`UserStartedSpeakingFrame`)
|
||||
- Turns end when:
|
||||
- The user starts speaking again (previous turn ends, new turn starts)
|
||||
- The bot is interrupted (`InterruptionFrame`)
|
||||
- The pipeline ends (`EndFrame`/`CancelFrame`)
|
||||
|
||||
## Integration with OpenTelemetry
|
||||
|
||||
You can use turn events to enrich OpenTelemetry spans:
|
||||
|
||||
```python
|
||||
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
|
||||
|
||||
turn_tracker = TurnTrackingObserver()
|
||||
turn_tracer = TurnTraceObserver(turn_tracker)
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def add_transcripts_to_span(processor, turn_number, user_text, assistant_text, interrupted):
|
||||
# Get current span and add transcript data
|
||||
from opentelemetry import trace
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("turn.user_text", user_text)
|
||||
current_span.set_attribute("turn.assistant_text", assistant_text)
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The processor handles async frame processing correctly by delaying turn end until frames are processed
|
||||
- Works with word-level timestamps from TTS services like Cartesia
|
||||
- Accumulates both user (`TranscriptionFrame`) and assistant (`TTSTextFrame`) speech
|
||||
- Emits individual transcript messages in addition to turn-level aggregation
|
||||
@@ -44,7 +44,6 @@ DAILY_SAMPLE_ROOM_URL=https://...
|
||||
|
||||
# Deepgram
|
||||
DEEPGRAM_API_KEY=...
|
||||
SAGEMAKER_ENDPOINT_NAME=...
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=...
|
||||
|
||||
@@ -13,29 +13,24 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSTextFrame
|
||||
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.hume.tts import HUME_SAMPLE_RATE, HumeTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -93,7 +88,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS (HumeTTSService with word timestamps)
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
@@ -107,14 +102,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
audio_out_sample_rate=HUME_SAMPLE_RATE,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
observers=[
|
||||
RTVIObserver(rtvi),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
}
|
||||
),
|
||||
],
|
||||
observers=[RTVIObserver(rtvi)],
|
||||
)
|
||||
|
||||
@rtvi.event_handler("on_client_ready")
|
||||
@@ -124,9 +112,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
logger.info(
|
||||
"💡 Word timestamps are enabled! Watch the console for TTSTextFrame logs showing each word with its PTS."
|
||||
)
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@@ -52,10 +52,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramFluxSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY"),
|
||||
params=DeepgramFluxSTTService.InputParams(min_confidence=0.3),
|
||||
)
|
||||
stt = DeepgramFluxSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = await LLMContext.create_image_message(
|
||||
message = LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = await LLMContext.create_image_message(
|
||||
message = LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -117,7 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = await LLMContext.create_image_message(
|
||||
message = LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = await LLMContext.create_image_message(
|
||||
message = LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -15,21 +15,14 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
TextFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.frames.frames import LLMRunFrame, UserImageRequestFrame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
create_transport,
|
||||
@@ -73,27 +66,6 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
class MoondreamTextFrameWrapper(FrameProcessor):
|
||||
"""Wraps Moondream-provided TextFrames with LLM response start/end frames.
|
||||
|
||||
This processor detects TextFrames and automatically wraps them with
|
||||
LLMFullResponseStartFrame and LLMFullResponseEndFrame to provide proper
|
||||
response boundaries for downstream processors.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we receive a TextFrame, wrap it with response start/end frames
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_frame(LLMFullResponseStartFrame(), direction)
|
||||
await self.push_frame(frame, direction)
|
||||
await self.push_frame(LLMFullResponseEndFrame(), direction)
|
||||
else:
|
||||
# For all other frames, just pass them through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -158,12 +130,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# If you run into weird description, try with use_cpu=True
|
||||
moondream = MoondreamService()
|
||||
|
||||
# Wrap TextFrames with LLM response start/end frames, which makes Moondream
|
||||
# output be treated like LLM responses for the purpose of context
|
||||
# aggregation. Without this, the assistant context aggregator would ignore
|
||||
# Moondream output (if the TTS service is disabled).
|
||||
moondream_text_wrapper = MoondreamTextFrameWrapper()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
@@ -171,7 +137,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context_aggregator.user(), # User responses
|
||||
ParallelPipeline(
|
||||
[llm], # LLM
|
||||
[moondream, moondream_text_wrapper],
|
||||
[moondream],
|
||||
),
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
|
||||
@@ -391,7 +391,7 @@ class AudioAccumulator(FrameProcessor):
|
||||
)
|
||||
self._user_speaking = False
|
||||
context = LLMContext()
|
||||
await context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self.push_frame(LLMContextFrame(context=context))
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
# Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest
|
||||
|
||||
@@ -150,7 +150,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
LLMLogObserver(),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.DESTINATION),
|
||||
UserStartedSpeakingFrame: (BaseInputTransport, FrameEndpoint.SOURCE),
|
||||
EndFrame: None,
|
||||
}
|
||||
|
||||
@@ -62,11 +62,7 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -110,16 +106,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
pattern_aggregator = PatternPairAggregator()
|
||||
|
||||
# Add pattern for voice switching
|
||||
pattern_aggregator.add_pattern(
|
||||
type="voice",
|
||||
pattern_aggregator.add_pattern_pair(
|
||||
pattern_id="voice_tag",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
action=MatchAction.REMOVE, # Remove tags from final text
|
||||
remove_match=True,
|
||||
)
|
||||
|
||||
# Register handler for voice switching
|
||||
async def on_voice_tag(match: PatternMatch):
|
||||
voice_name = match.text.strip().lower()
|
||||
voice_name = match.content.strip().lower()
|
||||
if voice_name in VOICE_IDS:
|
||||
# First flush any existing audio to finish the current context
|
||||
await tts.flush_audio()
|
||||
@@ -129,7 +125,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
else:
|
||||
logger.warning(f"Unknown voice: {voice_name}")
|
||||
|
||||
pattern_aggregator.on_pattern_match("voice", on_voice_tag)
|
||||
pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show a floral still life, use the `search_artwork` tool.
|
||||
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
|
||||
@@ -9,6 +9,7 @@ import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from mcp.client.session_group import SseServerParameters
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
@@ -22,16 +23,16 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.deepgram.stt_sagemaker import DeepgramSageMakerSTTService
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.mcp_service import MCPClient
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -60,42 +61,56 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# Initialize Deepgram SageMaker STT Service
|
||||
# This requires:
|
||||
# - AWS credentials configured (via environment variables or AWS CLI)
|
||||
# - A deployed SageMaker endpoint with Deepgram model
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name=os.getenv("SAGEMAKER_ENDPOINT_NAME"),
|
||||
region=os.getenv("AWS_REGION"),
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
|
||||
|
||||
llm = AWSBedrockLLMService(
|
||||
aws_region=os.getenv("AWS_REGION"),
|
||||
model="us.amazon.nova-pro-v1:0",
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8),
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-7-sonnet-latest"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
try:
|
||||
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
|
||||
mcp = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
context = LLMContext(messages)
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to a number of tools provided by mcp.run. Use any and all tools to help users.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
When asked for today's date, use 'https://www.datetoday.net/'.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
messages = [{"role": "system", "content": system}]
|
||||
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User responses
|
||||
stt,
|
||||
context_aggregator.user(), # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
@@ -110,9 +125,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
@@ -132,6 +146,14 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("MCP_RUN_SSE_URL"):
|
||||
logger.error(
|
||||
f"Please set MCP_RUN_SSE_URL environment variable for this example. See https://mcp.run"
|
||||
)
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -16,7 +15,7 @@ import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.client.session_group import StreamableHttpParameters
|
||||
from mcp.client.session_group import SseServerParameters
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
@@ -67,12 +66,10 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
|
||||
pattern = r"!\[[^\]]*\]\((https?://[^)]+\.(png|jpg|jpeg|PNG|JPG|JPEG|gif))\)"
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
@@ -135,11 +132,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to tools to search the Rijksmuseum collection and the user's GitHub repositories and account.
|
||||
Offer, for example, to show a floral still life, use the `search_artwork` tool.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
You can also offer to answer users questions about their GitHub repositories and account.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
@@ -149,11 +145,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages = [{"role": "system", "content": system}]
|
||||
|
||||
try:
|
||||
rijksmuseum_mcp = MCPClient(
|
||||
mcp = MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
args=["-y", "mcp-server-error setting up mcp"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
)
|
||||
@@ -161,32 +157,24 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.error(f"error setting up rijksmuseum mcp")
|
||||
logger.exception("error trace:")
|
||||
try:
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
github_mcp = MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={
|
||||
"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"
|
||||
},
|
||||
)
|
||||
)
|
||||
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
|
||||
# ie. "https://www.mcp.run/api/mcp/sse?..."
|
||||
# ensure the profile has a tool or few installed
|
||||
mcp_run = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp.run")
|
||||
logger.exception("error trace:")
|
||||
|
||||
rijksmuseum_tools = {}
|
||||
github_tools = {}
|
||||
tools = {}
|
||||
run_tools = {}
|
||||
try:
|
||||
rijksmuseum_tools = await rijksmuseum_mcp.register_tools(llm)
|
||||
github_tools = await github_mcp.register_tools(llm)
|
||||
tools = await mcp.register_tools(llm)
|
||||
run_tools = await mcp_run.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
all_standard_tools = rijksmuseum_tools.standard_tools + github_tools.standard_tools
|
||||
all_standard_tools = run_tools.standard_tools + tools.standard_tools
|
||||
all_tools = ToolsSchema(standard_tools=all_standard_tools)
|
||||
|
||||
context = LLMContext(messages, all_tools)
|
||||
@@ -238,9 +226,9 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("GITHUB_PERSONAL_ACCESS_TOKEN"):
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("MCP_RUN_SSE_URL"):
|
||||
logger.error(
|
||||
f"Please set `RIJKSMUSEUM_API_KEY` and `GITHUB_PERSONAL_ACCESS_TOKEN` environment variables. See https://github.com/r-huijts/rijksmuseum-mcp."
|
||||
f"Please set RIJKSMUSEUM_API_KEY and MCP_RUN_SSE_URL environment variables. See https://github.com/r-huijts/rijksmuseum-mcp and https://mcp.run"
|
||||
)
|
||||
import sys
|
||||
|
||||
@@ -49,14 +49,14 @@ aic = [ "aic-sdk~=1.1.0" ]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "pipecat-ai[websockets-base]" ]
|
||||
asyncai = [ "pipecat-ai[websockets-base]" ]
|
||||
aws = [ "aioboto3~=15.5.0", "pipecat-ai[websockets-base]" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.2.0; python_version>='3.12'" ]
|
||||
aws = [ "aioboto3~=15.0.0", "pipecat-ai[websockets-base]" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.1.1; python_version>='3.12'" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
cerebras = []
|
||||
daily = [ "daily-python~=0.22.0" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.21.0" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
elevenlabs = [ "pipecat-ai[websockets-base]" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
@@ -69,21 +69,19 @@ gstreamer = [ "pygobject~=3.50.0" ]
|
||||
heygen = [ "livekit>=1.0.13", "pipecat-ai[websockets-base]" ]
|
||||
hume = [ "hume>=0.11.2" ]
|
||||
inworld = []
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
|
||||
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0", "pyjwt>=2.10.1" ]
|
||||
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "pipecat-ai[websockets-base]" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]
|
||||
mcp = [ "mcp[cli]>=1.11.0,<2" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
nim = []
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe>=4.50.0,<6" ]
|
||||
@@ -91,16 +89,17 @@ openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "pipecat-ai[websockets-base]" ]
|
||||
qwen = []
|
||||
remote-smart-turn = []
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.122.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"]
|
||||
sambanova = []
|
||||
sarvam = [ "sarvamai==0.1.21", "pipecat-ai[websockets-base]" ]
|
||||
sentry = [ "sentry-sdk>=2.28.0,<3" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]
|
||||
remote-smart-turn = []
|
||||
silero = [ "onnxruntime>=1.20.1,<2" ]
|
||||
simli = [ "simli-ai~=1.0.3"]
|
||||
simli = [ "simli-ai~=0.1.25"]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soundfile = [ "soundfile~=0.13.1" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.5.0" ]
|
||||
|
||||
@@ -30,8 +30,8 @@ EVAL_SIMPLE_MATH = EvalConfig(
|
||||
)
|
||||
|
||||
EVAL_WEATHER = EvalConfig(
|
||||
prompt="What's the weather in San Francisco (in farhenheit or celsius)?",
|
||||
eval="The user says something specific about the current weather in San Francisco, including the degrees (in farhenheit or celsius).",
|
||||
prompt="What's the weather in San Francisco?",
|
||||
eval="The user says something specific about the current weather in San Francisco, including the degrees.",
|
||||
)
|
||||
|
||||
EVAL_ONLINE_SEARCH = EvalConfig(
|
||||
@@ -70,7 +70,7 @@ EVAL_VOICEMAIL = EvalConfig(
|
||||
|
||||
EVAL_CONVERSATION = EvalConfig(
|
||||
prompt="Hello, this is Mark.",
|
||||
eval="The user acknowledges the greeting.",
|
||||
eval="The user replies with a greeting.",
|
||||
eval_speaks_first=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -31,11 +31,7 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
|
||||
|
||||
class IVRStatus(Enum):
|
||||
@@ -118,15 +114,15 @@ class IVRProcessor(FrameProcessor):
|
||||
def _setup_xml_patterns(self):
|
||||
"""Set up XML pattern detection and handlers."""
|
||||
# Register DTMF pattern
|
||||
self._aggregator.add_pattern("dtmf", "<dtmf>", "</dtmf>", action=MatchAction.REMOVE)
|
||||
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", remove_match=True)
|
||||
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
|
||||
|
||||
# Register mode pattern
|
||||
self._aggregator.add_pattern("mode", "<mode>", "</mode>", action=MatchAction.REMOVE)
|
||||
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", remove_match=True)
|
||||
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
|
||||
|
||||
# Register IVR pattern
|
||||
self._aggregator.add_pattern("ivr", "<ivr>", "</ivr>", action=MatchAction.REMOVE)
|
||||
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", remove_match=True)
|
||||
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -152,7 +148,7 @@ class IVRProcessor(FrameProcessor):
|
||||
result = await self._aggregator.aggregate(frame.text)
|
||||
if result:
|
||||
# Push aggregated text that doesn't contain XML patterns
|
||||
await self.push_frame(LLMTextFrame(result.text), direction)
|
||||
await self.push_frame(LLMTextFrame(result), direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -163,7 +159,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing DTMF content.
|
||||
"""
|
||||
value = match.text
|
||||
value = match.content
|
||||
logger.debug(f"DTMF detected: {value}")
|
||||
|
||||
try:
|
||||
@@ -184,7 +180,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing IVR status content.
|
||||
"""
|
||||
status = match.text
|
||||
status = match.content
|
||||
logger.trace(f"IVR status detected: {status}")
|
||||
|
||||
# Convert string to enum, with validation
|
||||
@@ -215,7 +211,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing mode content.
|
||||
"""
|
||||
mode = match.text
|
||||
mode = match.content
|
||||
logger.debug(f"Mode detected: {mode}")
|
||||
if mode == "conversation":
|
||||
await self._handle_conversation()
|
||||
|
||||
@@ -12,7 +12,6 @@ and LLM processing.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -338,14 +337,11 @@ class TextFrame(DataFrame):
|
||||
# mandatory fields of theirs to have defaults to preserve
|
||||
# non-default-before-default argument order)
|
||||
includes_inter_frame_spaces: bool = field(init=False)
|
||||
# Whether this text frame should be appended to the LLM context.
|
||||
append_to_context: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
self.includes_inter_frame_spaces = False
|
||||
self.append_to_context = True
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
@@ -356,38 +352,11 @@ class TextFrame(DataFrame):
|
||||
class LLMTextFrame(TextFrame):
|
||||
"""Text frame generated by LLM services."""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# LLM services send text frames with all necessary spaces included
|
||||
self.includes_inter_frame_spaces = True
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatedTextFrame(TextFrame):
|
||||
"""Text frame representing an aggregation of TextFrames.
|
||||
|
||||
This frame contains multiple TextFrames aggregated together for processing
|
||||
or output along with a field to indicate how they are aggregated.
|
||||
|
||||
Parameters:
|
||||
aggregated_by: Method used to aggregate the text frames.
|
||||
"""
|
||||
|
||||
aggregated_by: AggregationType | str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSTextFrame(AggregatedTextFrame):
|
||||
class TTSTextFrame(TextFrame):
|
||||
"""Text frame generated by Text-to-Speech services."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -14,7 +14,6 @@ translation from this universal context into whatever format it needs, using a
|
||||
service-specific adapter.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import wave
|
||||
@@ -138,7 +137,7 @@ class LLMContext:
|
||||
return {"role": role, "content": content}
|
||||
|
||||
@staticmethod
|
||||
async def create_image_message(
|
||||
def create_image_message(
|
||||
*,
|
||||
role: str = "user",
|
||||
format: str,
|
||||
@@ -155,21 +154,15 @@ class LLMContext:
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
|
||||
def encode_image():
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
return encoded_image
|
||||
|
||||
encoded_image = await asyncio.to_thread(encode_image)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
url = f"data:image/jpeg;base64,{encoded_image}"
|
||||
|
||||
return LLMContext.create_image_url_message(role=role, url=url, text=text)
|
||||
|
||||
@staticmethod
|
||||
async def create_audio_message(
|
||||
def create_audio_message(
|
||||
*, role: str = "user", audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
) -> LLMContextMessage:
|
||||
"""Create a context message containing audio.
|
||||
@@ -179,26 +172,21 @@ class LLMContext:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
async def encode_audio():
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(data)
|
||||
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(data)
|
||||
|
||||
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
return encoded_audio
|
||||
|
||||
encoded_audio = await asyncio.to_thread(encode_audio)
|
||||
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
content.append(
|
||||
{
|
||||
@@ -333,7 +321,7 @@ class LLMContext:
|
||||
"""
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
async def add_image_frame_message(
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
|
||||
):
|
||||
"""Add a message containing an image frame.
|
||||
@@ -344,12 +332,10 @@ class LLMContext:
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
message = await LLMContext.create_image_message(
|
||||
format=format, size=size, image=image, text=text
|
||||
)
|
||||
message = LLMContext.create_image_message(format=format, size=size, image=image, text=text)
|
||||
self.add_message(message)
|
||||
|
||||
async def add_audio_frames_message(
|
||||
def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
):
|
||||
"""Add a message containing audio frames.
|
||||
@@ -358,7 +344,7 @@ class LLMContext:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
message = await LLMContext.create_audio_message(audio_frames=audio_frames, text=text)
|
||||
message = LLMContext.create_audio_message(audio_frames=audio_frames, text=text)
|
||||
self.add_message(message)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1001,7 +1001,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started or not frame.append_to_context:
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
if self._params.expect_stripped_words:
|
||||
|
||||
@@ -66,7 +66,7 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.string import concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@@ -90,7 +90,15 @@ class LLMContextAggregator(FrameProcessor):
|
||||
self._context = context
|
||||
self._role = role
|
||||
|
||||
self._aggregation: List[TextPartForConcatenation] = []
|
||||
self._aggregation: List[str] = []
|
||||
|
||||
# Whether to add spaces between text parts.
|
||||
# (Currently only used by LLMAssistantAggregator, but could be expanded
|
||||
# to LLMUserAggregator in the future if needed; that would require
|
||||
# additional work since LLMUserAggregator currently trims spaces from
|
||||
# incoming frames before determining whether it "really" received any
|
||||
# text).
|
||||
self._add_spaces = True
|
||||
|
||||
@property
|
||||
def messages(self) -> List[LLMContextMessage]:
|
||||
@@ -183,7 +191,7 @@ class LLMContextAggregator(FrameProcessor):
|
||||
Returns:
|
||||
The concatenated aggregation string.
|
||||
"""
|
||||
return concatenate_aggregated_text(self._aggregation)
|
||||
return concatenate_aggregated_text(self._aggregation, self._add_spaces)
|
||||
|
||||
|
||||
class LLMUserAggregator(LLMContextAggregator):
|
||||
@@ -433,12 +441,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
# Transcriptions never include inter-part spaces (so far).
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
self._aggregation.append(text)
|
||||
# We just got a final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Reset aggregation timer.
|
||||
@@ -793,7 +796,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
logger.debug(f"{self} Appending UserImageRawFrame to LLM context (size: {frame.size})")
|
||||
|
||||
await self._context.add_image_frame_message(
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
@@ -811,18 +814,18 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started or not frame.append_to_context:
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
# Make sure we really have text (spaces count, too!)
|
||||
if len(frame.text) == 0:
|
||||
return
|
||||
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
# Track whether we need to add spaces between text parts
|
||||
# Assumption: we can just keep track of the latest frame's value
|
||||
self._add_spaces = not frame.includes_inter_frame_spaces
|
||||
|
||||
self._aggregation.append(frame.text)
|
||||
|
||||
def _context_updated_task_finished(self, task: asyncio.Task):
|
||||
self._context_updated_tasks.discard(task)
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""LLM text processor module for processing and aggregating raw LLM output text.
|
||||
|
||||
This processor will convert LLMTextFrames into AggregatedTextFrames based on the
|
||||
configured text aggregator. Using the customizable aggregator, it provides
|
||||
functionality to handle or manipulate LLM text frames before they are sent to other
|
||||
components such as TTS services or context aggregators. It can be used to pre-aggregate
|
||||
and categorize, modify, or filter direct output tokens from the LLM.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
class LLMTextProcessor(FrameProcessor):
|
||||
"""A processor for handling or manipulating LLM text frames before they are processed further.
|
||||
|
||||
This processor will convert LLMTextFrames into AggregatedTextFrames based on the configured
|
||||
text aggregator. Using the customizable aggregator, it provides functionality to handle or
|
||||
manipulate LLM text frames before they are sent to other components such as TTS services or
|
||||
context aggregators. It can be used to pre-aggregate and categorize, modify, or filter direct
|
||||
output tokens from the LLM.
|
||||
"""
|
||||
|
||||
def __init__(self, *, text_aggregator: Optional[BaseTextAggregator] = None, **kwargs):
|
||||
"""Initialize the LLM text processor.
|
||||
|
||||
Args:
|
||||
text_aggregator: An optional text aggregator to use for processing LLM text frames. By
|
||||
default, a SimpleTextAggregator aggregating by sentence will be used.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
|
||||
TODO: Allow transformations per aggregation type or all (and deprecate the TTS filters).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process an LLMTextFrames using the aggregator to generate AggregatedTextFrames.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
await self._handle_llm_text(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._handle_llm_end(frame.skip_tts)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._handle_llm_end()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_interruption(self, _):
|
||||
"""Handle interruptions by resetting the text aggregator."""
|
||||
await self._text_aggregator.handle_interruption()
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the internal state of the text processor and its aggregator."""
|
||||
await self._text_aggregator.reset()
|
||||
|
||||
async def _handle_llm_text(self, in_frame: LLMTextFrame):
|
||||
aggregation = await self._text_aggregator.aggregate(in_frame.text)
|
||||
if aggregation:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=aggregation.text,
|
||||
aggregated_by=aggregation.type,
|
||||
)
|
||||
out_frame.skip_tts = in_frame.skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
async def _handle_llm_end(self, skip_tts: bool = False):
|
||||
# Flush any remaining aggregated text at the end of the LLM response
|
||||
aggregation = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
text = aggregation.text.strip()
|
||||
if text:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=text,
|
||||
aggregated_by=aggregation.type,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
@@ -83,4 +83,4 @@ class ConsumerProcessor(FrameProcessor):
|
||||
while True:
|
||||
frame = await self._queue.get()
|
||||
new_frame = await self._transformer(frame)
|
||||
await self.queue_frame(new_frame, self._direction)
|
||||
await self.push_frame(new_frame, self._direction)
|
||||
|
||||
@@ -24,7 +24,6 @@ from typing import (
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
@@ -33,8 +32,6 @@ from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
|
||||
from pipecat.audio.utils import calculate_audio_volume
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -707,29 +704,6 @@ class RTVITextMessageData(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RTVIBotOutputMessageData(RTVITextMessageData):
|
||||
"""Data for bot output RTVI messages.
|
||||
|
||||
Extends RTVITextMessageData to include metadata about the output.
|
||||
"""
|
||||
|
||||
spoken: bool = False # Indicates if the text has been spoken by TTS
|
||||
aggregated_by: AggregationType | str
|
||||
# Indicates what form the text is in (e.g., by word, sentence, etc.)
|
||||
|
||||
|
||||
class RTVIBotOutputMessage(BaseModel):
|
||||
"""Message containing bot output text.
|
||||
|
||||
An event meant to holistically represent what the bot is outputting,
|
||||
along with metadata about the output and if it has been spoken.
|
||||
"""
|
||||
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-output"] = "bot-output"
|
||||
data: RTVIBotOutputMessageData
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessage(BaseModel):
|
||||
"""Message containing bot transcription text.
|
||||
|
||||
@@ -922,7 +896,6 @@ class RTVIObserverParams:
|
||||
Parameter `errors_enabled` is deprecated. Error messages are always enabled.
|
||||
|
||||
Parameters:
|
||||
bot_output_enabled: Indicates if bot output messages should be sent.
|
||||
bot_llm_enabled: Indicates if the bot's LLM messages should be sent.
|
||||
bot_tts_enabled: Indicates if the bot's TTS messages should be sent.
|
||||
bot_speaking_enabled: Indicates if the bot's started/stopped speaking messages should be sent.
|
||||
@@ -934,17 +907,9 @@ class RTVIObserverParams:
|
||||
metrics_enabled: Indicates if metrics messages should be sent.
|
||||
system_logs_enabled: Indicates if system logs should be sent.
|
||||
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
|
||||
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
bot_output_transforms: A list of callables to transform text before just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
(aggregation_type | '*', transform_function).
|
||||
audio_level_period_secs: How often audio levels should be sent if enabled.
|
||||
"""
|
||||
|
||||
bot_output_enabled: bool = True
|
||||
bot_llm_enabled: bool = True
|
||||
bot_tts_enabled: bool = True
|
||||
bot_speaking_enabled: bool = True
|
||||
@@ -956,15 +921,6 @@ class RTVIObserverParams:
|
||||
metrics_enabled: bool = True
|
||||
system_logs_enabled: bool = False
|
||||
errors_enabled: Optional[bool] = None
|
||||
skip_aggregator_types: Optional[List[AggregationType | str]] = None
|
||||
bot_output_transforms: Optional[
|
||||
List[
|
||||
Tuple[
|
||||
AggregationType | str,
|
||||
Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
]
|
||||
]
|
||||
] = None
|
||||
audio_level_period_secs: float = 0.15
|
||||
|
||||
|
||||
@@ -1017,45 +973,8 @@ class RTVIObserver(BaseObserver):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._aggregation_transforms: List[
|
||||
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
|
||||
] = self._params.bot_output_transforms or []
|
||||
|
||||
def add_bot_output_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Transform text for a specific aggregation type before sending as Bot Output or TTS.
|
||||
|
||||
Args:
|
||||
transform_function: The function to apply for transformation. This function should take
|
||||
the text and aggregation type as input and return the transformed text.
|
||||
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
|
||||
aggregation_type: The type of aggregation to transform. This value defaults to "*" to
|
||||
handle all text before sending to the client.
|
||||
"""
|
||||
self._aggregation_transforms.append((aggregation_type, transform_function))
|
||||
|
||||
def remove_bot_output_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Remove a text transformer for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to remove.
|
||||
aggregation_type: The type of aggregation to remove the transformer for.
|
||||
"""
|
||||
self._aggregation_transforms = [
|
||||
(agg_type, func)
|
||||
for agg_type, func in self._aggregation_transforms
|
||||
if not (agg_type == aggregation_type and func == transform_function)
|
||||
]
|
||||
|
||||
async def _logger_sink(self, message):
|
||||
"""Logger sink so we can send system logs to RTVI clients."""
|
||||
"""Logger sink so we cna send system logs to RTVI clients."""
|
||||
message = RTVISystemLogMessage(data=RTVITextMessageData(text=message))
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
@@ -1129,15 +1048,12 @@ class RTVIObserver(BaseObserver):
|
||||
await self.send_rtvi_message(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
|
||||
await self.send_rtvi_message(RTVIBotTTSStoppedMessage())
|
||||
elif isinstance(frame, AggregatedTextFrame) and (
|
||||
self._params.bot_output_enabled or self._params.bot_tts_enabled
|
||||
):
|
||||
if isinstance(frame, TTSTextFrame) and not isinstance(src, BaseOutputTransport):
|
||||
# This check is to make sure we handle the frame when it has gone
|
||||
# through the transport and has correct timing.
|
||||
mark_as_seen = False
|
||||
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
|
||||
if isinstance(src, BaseOutputTransport):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.send_rtvi_message(message)
|
||||
else:
|
||||
await self._handle_aggregated_llm_text(frame)
|
||||
mark_as_seen = False
|
||||
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
|
||||
await self._handle_metrics(frame)
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
@@ -1168,6 +1084,15 @@ class RTVIObserver(BaseObserver):
|
||||
if mark_as_seen:
|
||||
self._frames_seen.add(frame.id)
|
||||
|
||||
async def _push_bot_transcription(self):
|
||||
"""Push accumulated bot transcription as a message."""
|
||||
if len(self._bot_transcription) > 0:
|
||||
message = RTVIBotTranscriptionMessage(
|
||||
data=RTVITextMessageData(text=self._bot_transcription)
|
||||
)
|
||||
await self.send_rtvi_message(message)
|
||||
self._bot_transcription = ""
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
"""Handle user speaking interruption frames."""
|
||||
message = None
|
||||
@@ -1190,45 +1115,14 @@ class RTVIObserver(BaseObserver):
|
||||
if message:
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
async def _handle_aggregated_llm_text(self, frame: AggregatedTextFrame):
|
||||
"""Handle aggregated LLM text output frames."""
|
||||
# Skip certain aggregator types if configured to do so.
|
||||
if (
|
||||
self._params.skip_aggregator_types
|
||||
and frame.aggregated_by in self._params.skip_aggregator_types
|
||||
):
|
||||
return
|
||||
|
||||
text = frame.text
|
||||
type = frame.aggregated_by
|
||||
for aggregation_type, transform in self._aggregation_transforms:
|
||||
if aggregation_type == type or aggregation_type == "*":
|
||||
text = await transform(text, type)
|
||||
|
||||
isTTS = isinstance(frame, TTSTextFrame)
|
||||
if self._params.bot_output_enabled:
|
||||
message = RTVIBotOutputMessage(
|
||||
data=RTVIBotOutputMessageData(text=text, spoken=isTTS, aggregated_by=type)
|
||||
)
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
if isTTS and self._params.bot_tts_enabled:
|
||||
tts_message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self.send_rtvi_message(tts_message)
|
||||
|
||||
async def _handle_llm_text_frame(self, frame: LLMTextFrame):
|
||||
"""Handle LLM text output frames."""
|
||||
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
# TODO (mrkb): Remove all this logic when we fully deprecate bot-transcription messages.
|
||||
self._bot_transcription += frame.text
|
||||
|
||||
if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0:
|
||||
await self.send_rtvi_message(
|
||||
RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription))
|
||||
)
|
||||
self._bot_transcription = ""
|
||||
if match_endofsentence(self._bot_transcription):
|
||||
await self._push_bot_transcription()
|
||||
|
||||
async def _handle_user_transcriptions(self, frame: Frame):
|
||||
"""Handle user transcription frames."""
|
||||
@@ -1354,7 +1248,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
# Default to 0.3.0 which is the last version before actually having a
|
||||
# "client-version".
|
||||
self._client_version = [0, 3, 0]
|
||||
self._llm_skip_tts: bool = False # Keep in sync with llm_service.py's configuration.
|
||||
self._skip_tts: bool = False # Keep in sync with llm_service.py
|
||||
|
||||
self._registered_actions: Dict[str, RTVIAction] = {}
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
@@ -1547,7 +1441,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
elif isinstance(frame, RTVIActionFrame):
|
||||
await self._action_queue.put(frame)
|
||||
elif isinstance(frame, LLMConfigureOutputFrame):
|
||||
self._llm_skip_tts = frame.skip_tts
|
||||
self._skip_tts = frame.skip_tts
|
||||
await self.push_frame(frame, direction)
|
||||
# Other frames
|
||||
else:
|
||||
@@ -1803,9 +1697,9 @@ class RTVIProcessor(FrameProcessor):
|
||||
opts = data.options if data.options is not None else RTVISendTextOptions()
|
||||
if opts.run_immediately:
|
||||
await self.interrupt_bot()
|
||||
cur_llm_skip_tts = self._llm_skip_tts
|
||||
cur_skip_tts = self._skip_tts
|
||||
should_skip_tts = not opts.audio_response
|
||||
toggle_skip_tts = cur_llm_skip_tts != should_skip_tts
|
||||
toggle_skip_tts = cur_skip_tts != should_skip_tts
|
||||
if toggle_skip_tts:
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=should_skip_tts)
|
||||
await self.push_frame(output_frame)
|
||||
@@ -1815,7 +1709,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
)
|
||||
await self.push_frame(text_frame)
|
||||
if toggle_skip_tts:
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=cur_llm_skip_tts)
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=cur_skip_tts)
|
||||
await self.push_frame(output_frame)
|
||||
|
||||
async def _handle_update_context(self, data: RTVIAppendToContextData):
|
||||
|
||||
@@ -15,7 +15,6 @@ from typing import List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -25,10 +24,9 @@ from pipecat.frames.frames import (
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.string import concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@@ -100,9 +98,15 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._current_text_parts: List[TextPartForConcatenation] = []
|
||||
self._current_text_parts: List[str] = []
|
||||
self._aggregation_start_time: Optional[str] = None
|
||||
|
||||
# Whether to add spaces between text parts.
|
||||
# (The use of this could be expanded to the UserTranscriptProcessor in
|
||||
# the future if needed; currently the UserTranscriptProcessor assumes
|
||||
# that user transcription frames do not need aggregation).
|
||||
self._add_spaces = True
|
||||
|
||||
async def _emit_aggregated_text(self):
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
@@ -143,7 +147,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
Result: "Hello there how are you"
|
||||
"""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
content = concatenate_aggregated_text(self._current_text_parts)
|
||||
content = concatenate_aggregated_text(self._current_text_parts, self._add_spaces)
|
||||
if content:
|
||||
logger.trace(f"Emitting aggregated assistant message: {content}")
|
||||
message = TranscriptionMessage(
|
||||
@@ -187,11 +191,11 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
if not self._aggregation_start_time:
|
||||
self._aggregation_start_time = time_now_iso8601()
|
||||
|
||||
self._current_text_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
# Track whether we need to add spaces between text parts
|
||||
# Assumption: we can just keep track of the latest frame's value
|
||||
self._add_spaces = not frame.includes_inter_frame_spaces
|
||||
|
||||
self._current_text_parts.append(frame.text)
|
||||
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -308,267 +312,3 @@ class TranscriptProcessor:
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class TurnAwareTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Processes transcripts with turn boundary awareness.
|
||||
|
||||
This processor combines user and assistant transcript tracking with turn
|
||||
detection, emitting events when turns start and end. It correctly handles
|
||||
interruptions by only capturing what was actually spoken.
|
||||
|
||||
Turn boundaries are detected based on:
|
||||
- User started speaking (UserStartedSpeakingFrame)
|
||||
- Bot stopped speaking (BotStoppedSpeakingFrame)
|
||||
- Interruptions (InterruptionFrame)
|
||||
|
||||
Events:
|
||||
on_turn_started: Emitted when a new turn begins.
|
||||
Handler signature: async def handler(processor, turn_number)
|
||||
|
||||
on_turn_ended: Emitted when a turn ends.
|
||||
Handler signature: async def handler(processor, turn_number,
|
||||
user_transcript, assistant_transcript,
|
||||
was_interrupted)
|
||||
|
||||
on_transcript_update: Inherited from BaseTranscriptProcessor, emitted for
|
||||
individual transcript messages.
|
||||
|
||||
Example::
|
||||
|
||||
turn_processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
@turn_processor.event_handler("on_turn_started")
|
||||
async def handle_turn_started(processor, turn_number):
|
||||
print(f"Turn {turn_number} started")
|
||||
|
||||
@turn_processor.event_handler("on_turn_ended")
|
||||
async def handle_turn_ended(processor, turn_number, user_text, assistant_text, interrupted):
|
||||
print(f"Turn {turn_number} ended")
|
||||
print(f"User said: {user_text}")
|
||||
print(f"Assistant said: {assistant_text}")
|
||||
print(f"Was interrupted: {interrupted}")
|
||||
|
||||
pipeline = Pipeline([
|
||||
transport.input(),
|
||||
stt,
|
||||
turn_processor,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
])
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the turn-aware transcript processor.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Turn tracking state
|
||||
self._turn_number = 0
|
||||
self._turn_active = False
|
||||
self._turn_start_time: Optional[str] = None
|
||||
|
||||
# Accumulate text for current turn
|
||||
self._current_turn_user_parts: List[TextPartForConcatenation] = []
|
||||
self._current_turn_assistant_parts: List[TextPartForConcatenation] = []
|
||||
|
||||
# Track bot speaking state
|
||||
self._bot_is_speaking = False
|
||||
|
||||
# Register turn events
|
||||
self._register_event_handler("on_turn_started")
|
||||
self._register_event_handler("on_turn_ended")
|
||||
|
||||
async def _start_turn(self):
|
||||
"""Start a new turn."""
|
||||
if not self._turn_active:
|
||||
self._turn_number += 1
|
||||
self._turn_active = True
|
||||
self._turn_start_time = time_now_iso8601()
|
||||
self._current_turn_user_parts = []
|
||||
self._current_turn_assistant_parts = []
|
||||
|
||||
logger.debug(f"Turn {self._turn_number} started")
|
||||
await self._call_event_handler("on_turn_started", self._turn_number)
|
||||
|
||||
async def _end_turn(self, was_interrupted: bool = False):
|
||||
"""End the current turn and emit aggregated transcripts.
|
||||
|
||||
Args:
|
||||
was_interrupted: Whether the turn ended due to an interruption.
|
||||
"""
|
||||
if not self._turn_active:
|
||||
return
|
||||
|
||||
# Aggregate user text
|
||||
user_transcript = ""
|
||||
if self._current_turn_user_parts:
|
||||
user_transcript = concatenate_aggregated_text(self._current_turn_user_parts)
|
||||
|
||||
# Aggregate assistant text
|
||||
assistant_transcript = ""
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_transcript = concatenate_aggregated_text(self._current_turn_assistant_parts)
|
||||
|
||||
# Emit turn ended event
|
||||
logger.debug(
|
||||
f"Turn {self._turn_number} ended (interrupted={was_interrupted}). "
|
||||
f"User: '{user_transcript}', Assistant: '{assistant_transcript}'"
|
||||
)
|
||||
await self._call_event_handler(
|
||||
"on_turn_ended",
|
||||
self._turn_number,
|
||||
user_transcript,
|
||||
assistant_transcript,
|
||||
was_interrupted,
|
||||
)
|
||||
|
||||
# Reset turn state
|
||||
self._turn_active = False
|
||||
self._current_turn_user_parts = []
|
||||
self._current_turn_assistant_parts = []
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for turn-aware transcript tracking.
|
||||
|
||||
Handles:
|
||||
- UserStartedSpeakingFrame: Start new turn
|
||||
- TranscriptionFrame: Accumulate user speech and emit transcript message
|
||||
- BotStartedSpeakingFrame: Track bot speaking state
|
||||
- TTSTextFrame: Accumulate assistant speech
|
||||
- BotStoppedSpeakingFrame: End turn if no interruption pending
|
||||
- InterruptionFrame: End turn immediately as interrupted
|
||||
- EndFrame/CancelFrame: End any active turn
|
||||
|
||||
Args:
|
||||
frame: Input frame to process.
|
||||
direction: Frame processing direction.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
# User started speaking
|
||||
if self._bot_is_speaking:
|
||||
# This is an interruption - end the current turn with what was spoken
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self._end_turn(was_interrupted=True)
|
||||
self._bot_is_speaking = False
|
||||
elif self._turn_active:
|
||||
# Previous turn is ending normally (bot finished speaking)
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self._end_turn(was_interrupted=False)
|
||||
|
||||
# Start a new turn
|
||||
await self._start_turn()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
# Accumulate user speech for the current turn
|
||||
if self._turn_active:
|
||||
self._current_turn_user_parts.append(
|
||||
TextPartForConcatenation(frame.text, includes_inter_part_spaces=True)
|
||||
)
|
||||
|
||||
# Also emit individual transcript message
|
||||
message = TranscriptionMessage(
|
||||
role="user",
|
||||
user_id=frame.user_id,
|
||||
content=frame.text,
|
||||
timestamp=frame.timestamp,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
# Bot started speaking
|
||||
self._bot_is_speaking = True
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
# Accumulate assistant speech for the current turn
|
||||
if self._turn_active:
|
||||
self._current_turn_assistant_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
# Bot stopped speaking - just mark it, don't end turn yet
|
||||
# Turn will end when next user speaks or pipeline ends
|
||||
self._bot_is_speaking = False
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
# Emit assistant transcript message with what was spoken before interruption
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(self._current_turn_assistant_parts)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
|
||||
# Push frame first to ensure proper cleanup
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# End turn as interrupted
|
||||
await self._end_turn(was_interrupted=True)
|
||||
self._bot_is_speaking = False
|
||||
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
# Pipeline ending - finalize any active turn
|
||||
if self._turn_active:
|
||||
# Emit any pending assistant transcript (allow time for TTSTextFrames to be processed)
|
||||
# Give a brief moment for any pending frames to process
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
if self._current_turn_assistant_parts:
|
||||
assistant_content = concatenate_aggregated_text(
|
||||
self._current_turn_assistant_parts
|
||||
)
|
||||
if assistant_content:
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
timestamp=self._turn_start_time or time_now_iso8601(),
|
||||
)
|
||||
await self._emit_update([message])
|
||||
|
||||
await self._end_turn(was_interrupted=isinstance(frame, CancelFrame))
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -264,10 +264,7 @@ def _setup_webrtc_routes(
|
||||
# Prepare runner arguments with the callback to run your bot
|
||||
async def webrtc_connection_callback(connection):
|
||||
bot_module = _get_bot_module()
|
||||
|
||||
runner_args = SmallWebRTCRunnerArguments(
|
||||
webrtc_connection=connection, body=request.request_data
|
||||
)
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
# Delegate handling to SmallWebRTCRequestHandler
|
||||
@@ -329,8 +326,7 @@ def _setup_webrtc_routes(
|
||||
type=request_data["type"],
|
||||
pc_id=request_data.get("pc_id"),
|
||||
restart_pc=request_data.get("restart_pc"),
|
||||
request_data=request_data.get("request_data")
|
||||
or request_data.get("requestData"),
|
||||
request_data=request_data,
|
||||
)
|
||||
return await offer(webrtc_request, background_tasks)
|
||||
elif request.method == HTTPMethod.PATCH.value:
|
||||
|
||||
@@ -281,14 +281,6 @@ async def maybe_capture_participant_camera(
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
|
||||
|
||||
if isinstance(transport, SmallWebRTCTransport):
|
||||
await transport.capture_participant_video(video_source="camera")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
async def maybe_capture_participant_screen(
|
||||
transport: BaseTransport, client: Any, framerate: int = 0
|
||||
@@ -311,14 +303,6 @@ async def maybe_capture_participant_screen(
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
|
||||
|
||||
if isinstance(transport, SmallWebRTCTransport):
|
||||
await transport.capture_participant_video(video_source="screenVideo")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _smallwebrtc_sdp_cleanup_ice_candidates(text: str, pattern: str) -> str:
|
||||
"""Clean up ICE candidates in SDP text for SmallWebRTC.
|
||||
|
||||
@@ -152,6 +152,9 @@ class AIService(FrameProcessor):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() may wait on tasks to
|
||||
# finish and downstream processors need to receive the EndFrame.
|
||||
await self.push_frame(frame, direction)
|
||||
await self.stop(frame)
|
||||
|
||||
async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
|
||||
|
||||
@@ -373,7 +373,9 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
if event.type == "content_block_delta":
|
||||
if hasattr(event.delta, "text"):
|
||||
await self.push_frame(LLMTextFrame(event.delta.text))
|
||||
frame = LLMTextFrame(event.delta.text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
completion_tokens_estimate += self._estimate_tokens(event.delta.text)
|
||||
elif hasattr(event.delta, "partial_json") and tool_use_block:
|
||||
json_accumulator += event.delta.partial_json
|
||||
|
||||
@@ -146,6 +146,15 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AsyncAI TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AsyncAI's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Async language format.
|
||||
|
||||
@@ -424,6 +433,15 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AsyncAI TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AsyncAI's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Async language format.
|
||||
|
||||
|
||||
@@ -8,10 +8,8 @@ import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .agent_core import *
|
||||
from .llm import *
|
||||
from .nova_sonic import *
|
||||
from .sagemaker import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""AWS AgentCore Processor Module.
|
||||
|
||||
This module defines the AWSAgentCoreProcessor, which invokes agents hosted on
|
||||
Amazon Bedrock AgentCore Runtime and streams their responses as LLMTextFrames.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import aioboto3
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
def default_context_to_payload_transformer(
|
||||
context: LLMContext | OpenAILLMContext,
|
||||
) -> Optional[str]:
|
||||
"""Default transformer to create AgentCore payload from LLM context.
|
||||
|
||||
Extracts the latest user or system message text and wraps it in {"prompt": "<text>"}.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation messages.
|
||||
|
||||
Returns:
|
||||
A JSON string payload for AgentCore, or None if no valid message found.
|
||||
"""
|
||||
messages = context.messages
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
if isinstance(last_message, LLMSpecificMessage) or last_message.get("role") not in (
|
||||
"user",
|
||||
"system",
|
||||
):
|
||||
return None
|
||||
|
||||
content = last_message.get("content")
|
||||
if not content:
|
||||
return None
|
||||
|
||||
if isinstance(content, str):
|
||||
prompt = content
|
||||
elif isinstance(content, list):
|
||||
prompt = " ".join([part.get("text", "") for part in content])
|
||||
else:
|
||||
return None
|
||||
|
||||
return json.dumps({"prompt": prompt})
|
||||
|
||||
|
||||
def default_response_to_output_transformer(response_line: str) -> Optional[str]:
|
||||
"""Default transformer to extract output text from AgentCore response.
|
||||
|
||||
Expects responses with {"response": "<text>"} format.
|
||||
|
||||
Args:
|
||||
response_line: The raw response line from AgentCore (without "data: " prefix).
|
||||
|
||||
Returns:
|
||||
The extracted output text, or None if no text found.
|
||||
"""
|
||||
response_json = json.loads(response_line)
|
||||
return response_json.get("response")
|
||||
|
||||
|
||||
class AWSAgentCoreProcessor(FrameProcessor):
|
||||
"""Processor that runs an Amazon Bedrock AgentCore agent.
|
||||
|
||||
Input:
|
||||
- LLMContextFrame: Supplies a context used to invoke the agent.
|
||||
|
||||
Output:
|
||||
- LLMTextFrame: The agent's text response(s).
|
||||
A single agent invocation may result in multiple text frames.
|
||||
|
||||
This processor transforms the input context to a payload for the AgentCore
|
||||
agent, and transforms the agent's response(s) into output text frame(s). Both
|
||||
mappings are configurable via transformers. Below is the default behavior.
|
||||
|
||||
Input transformer (context_to_payload_transformer):
|
||||
- Grabs the latest user or system message (if it's the latest message)
|
||||
- Extracts its text content
|
||||
- Constructs a payload that looks like {"prompt": "<text>"}
|
||||
|
||||
Output transformer (response_to_output_transformer):
|
||||
- Expects responses that look like {"response": "<text>"}
|
||||
- Extracts the text for use in the LLMTextFrame(s)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agentArn: str,
|
||||
aws_access_key: Optional[str] = None,
|
||||
aws_secret_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_region: Optional[str] = None,
|
||||
context_to_payload_transformer: Optional[
|
||||
Callable[[LLMContext | OpenAILLMContext], Optional[str]]
|
||||
] = None,
|
||||
response_to_output_transformer: Optional[Callable[[str], Optional[str]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the AWS AgentCore processor.
|
||||
|
||||
Args:
|
||||
agentArn: The Amazon Web Services Resource Name (ARN) of the agent.
|
||||
aws_access_key: AWS access key ID. If None, uses default credentials.
|
||||
aws_secret_key: AWS secret access key. If None, uses default credentials.
|
||||
aws_session_token: AWS session token for temporary credentials.
|
||||
aws_region: AWS region.
|
||||
context_to_payload_transformer: Optional callable to transform
|
||||
LLMContext into AgentCore payload string. If None, uses
|
||||
default_context_to_payload_transformer.
|
||||
response_to_output_transformer: Optional callable to extract output text
|
||||
from AgentCore response. If None, uses
|
||||
default_response_to_output_transformer.
|
||||
**kwargs: Additional arguments passed to parent FrameProcessor.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._agentArn = agentArn
|
||||
self._aws_session = aioboto3.Session()
|
||||
|
||||
# Store AWS session parameters for creating client in async context
|
||||
self._aws_params = {
|
||||
"aws_access_key_id": aws_access_key or os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": aws_secret_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
|
||||
"region_name": aws_region or os.getenv("AWS_REGION", "us-east-1"),
|
||||
}
|
||||
|
||||
# Set transformers with defaults
|
||||
self._context_to_payload_transformer = (
|
||||
context_to_payload_transformer or default_context_to_payload_transformer
|
||||
)
|
||||
self._response_to_output_transformer = (
|
||||
response_to_output_transformer or default_response_to_output_transformer
|
||||
)
|
||||
|
||||
# State for managing output response bookends
|
||||
self._output_response_open = False
|
||||
self._last_text_frame_time: Optional[float] = None
|
||||
self._close_task: Optional[asyncio.Task] = None
|
||||
self._output_response_timeout = 1.0 # seconds
|
||||
|
||||
async def _close_output_response_after_timeout(self):
|
||||
"""Close the output response after timeout if no new text frames arrive."""
|
||||
await asyncio.sleep(self._output_response_timeout)
|
||||
if self._output_response_open:
|
||||
self._output_response_open = False
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _push_text_frame(self, text: str):
|
||||
"""Push a text frame, managing output response bookends."""
|
||||
# Cancel any pending close task
|
||||
if self._close_task and not self._close_task.done():
|
||||
await self.cancel_task(self._close_task)
|
||||
|
||||
# Open output response if needed
|
||||
if not self._output_response_open:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
self._output_response_open = True
|
||||
|
||||
# Push the text frame
|
||||
await self.push_frame(LLMTextFrame(text))
|
||||
self._last_text_frame_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Schedule closing the output response after timeout
|
||||
self._close_task = self.create_task(self._close_output_response_after_timeout())
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames and handle LLM message frames.
|
||||
|
||||
Args:
|
||||
frame: The incoming frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
# Create payload to invoke AgentCore agent
|
||||
payload = self._context_to_payload_transformer(frame.context)
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
async with self._aws_session.client("bedrock-agentcore", **self._aws_params) as client:
|
||||
# Invoke the AgentCore agent
|
||||
response = await client.invoke_agent_runtime(
|
||||
agentRuntimeArn=self._agentArn, payload=payload.encode()
|
||||
)
|
||||
|
||||
# Determine if this is a streamed multi-part response, which
|
||||
# will affect our parsing
|
||||
is_multi_part_response = "text/event-stream" in response.get("contentType", "")
|
||||
|
||||
# Handle each response part (there may be one, for single
|
||||
# responses, or multiple, for streamed multi-part responses)
|
||||
async for part in response.get("response", []):
|
||||
part_string = part.decode("utf-8")
|
||||
|
||||
# In streamed multi-part responses, each part might have
|
||||
# one or more lines, each of which starts with "data: ".
|
||||
# Treat each line as a response.
|
||||
if is_multi_part_response:
|
||||
for line in part_string.split("\n"):
|
||||
# Get response text from this line
|
||||
if not line:
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
logger.warning(f"Expected line to start with 'data: ', got: {line}")
|
||||
continue
|
||||
line = line[6:] # omit "data: "
|
||||
|
||||
# Transform response line to output text
|
||||
text = self._response_to_output_transformer(line)
|
||||
if text:
|
||||
await self._push_text_frame(text)
|
||||
|
||||
# In single-part responses, the whole part is one response
|
||||
# and there's no "data: " prefix
|
||||
else:
|
||||
# Transform response part string to output text
|
||||
text = self._response_to_output_transformer(part_string)
|
||||
if text:
|
||||
await self._push_text_frame(text)
|
||||
|
||||
# Final close if output response is still open after all parts processed
|
||||
if self._output_response_open:
|
||||
if self._close_task and not self._close_task.done():
|
||||
await self.cancel_task(self._close_task)
|
||||
self._output_response_open = False
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -1078,7 +1078,9 @@ class AWSBedrockLLMService(LLMService):
|
||||
if "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
await self.push_frame(LLMTextFrame(delta["text"]))
|
||||
frame = LLMTextFrame(delta["text"])
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
completion_tokens_estimate += self._estimate_tokens(delta["text"])
|
||||
elif "toolUse" in delta and "input" in delta["toolUse"]:
|
||||
# Handle partial JSON for tool use
|
||||
|
||||
@@ -27,7 +27,6 @@ from pydantic import BaseModel, Field
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter, Role
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -1028,7 +1027,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
logger.debug(f"Assistant response text added: {text}")
|
||||
|
||||
# Report the text of the assistant response.
|
||||
frame = TTSTextFrame(text, aggregated_by=AggregationType.SENTENCE)
|
||||
frame = TTSTextFrame(text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -1063,9 +1062,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# TTSTextFrame would be ignored otherwise (the interruption frame
|
||||
# would have cleared the assistant aggregator state).
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
frame = TTSTextFrame(
|
||||
self._assistant_text_buffer, aggregated_by=AggregationType.SENTENCE
|
||||
)
|
||||
frame = TTSTextFrame(self._assistant_text_buffer)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
self._may_need_repush_assistant_text = False
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""AWS SageMaker bidirectional streaming client.
|
||||
|
||||
This module provides a client for streaming bidirectional communication with
|
||||
SageMaker endpoints using the HTTP/2 protocol. Supports sending audio, text,
|
||||
and JSON data to SageMaker model endpoints and receiving streaming responses.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from aws_sdk_sagemaker_runtime_http2.client import SageMakerRuntimeHTTP2Client
|
||||
from aws_sdk_sagemaker_runtime_http2.config import Config, HTTPAuthSchemeResolver
|
||||
from aws_sdk_sagemaker_runtime_http2.models import (
|
||||
InvokeEndpointWithBidirectionalStreamInput,
|
||||
RequestPayloadPart,
|
||||
RequestStreamEventPayloadPart,
|
||||
ResponseStreamEvent,
|
||||
)
|
||||
from smithy_aws_core.auth.sigv4 import SigV4AuthScheme
|
||||
from smithy_aws_core.identity import EnvironmentCredentialsResolver
|
||||
from smithy_core.aio.eventstream import DuplexEventStream
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use SageMaker BiDi client, you need to `pip install pipecat-ai[sagemaker]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class SageMakerBidiClient:
|
||||
"""Client for bidirectional streaming with AWS SageMaker endpoints.
|
||||
|
||||
Handles low-level HTTP/2 bidirectional streaming protocol for communicating
|
||||
with SageMaker model endpoints. Provides methods for sending various data
|
||||
types (audio, text, JSON) and receiving streaming responses.
|
||||
|
||||
This client uses AWS SigV4 authentication and supports credential resolution
|
||||
from environment variables, AWS CLI configuration, and instance metadata.
|
||||
|
||||
Example::
|
||||
|
||||
client = SageMakerBidiClient(
|
||||
endpoint_name="my-deepgram-endpoint",
|
||||
region="us-east-2",
|
||||
model_invocation_path="v1/listen",
|
||||
model_query_string="model=nova-3&language=en"
|
||||
)
|
||||
await client.start_session()
|
||||
await client.send_audio_chunk(audio_bytes)
|
||||
response = await client.receive_response()
|
||||
await client.close_session()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
model_invocation_path: str = "",
|
||||
model_query_string: str = "",
|
||||
):
|
||||
"""Initialize the SageMaker BiDi client.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint to connect to.
|
||||
region: AWS region where the endpoint is deployed.
|
||||
model_invocation_path: API path for the model invocation (e.g., "v1/listen").
|
||||
model_query_string: Query string parameters for the model (e.g., "model=nova-3").
|
||||
"""
|
||||
self.endpoint_name = endpoint_name
|
||||
self.region = region
|
||||
self.model_invocation_path = model_invocation_path
|
||||
self.model_query_string = model_query_string
|
||||
self.bidi_endpoint = f"https://runtime.sagemaker.{region}.amazonaws.com:8443"
|
||||
self._client: Optional[SageMakerRuntimeHTTP2Client] = None
|
||||
self._stream: Optional[
|
||||
DuplexEventStream[RequestStreamEventPayloadPart, ResponseStreamEvent, any]
|
||||
] = None
|
||||
self._output_stream = None
|
||||
self._is_active = False
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the SageMaker Runtime HTTP2 client with AWS credentials.
|
||||
|
||||
Creates and configures the SageMaker Runtime HTTP2 client with SigV4
|
||||
authentication. Attempts to resolve AWS credentials from environment
|
||||
variables, AWS CLI configuration, or instance metadata.
|
||||
"""
|
||||
logger.debug(f"Initializing SageMaker BiDi client for region: {self.region}")
|
||||
logger.debug(f"Using endpoint URI: {self.bidi_endpoint}")
|
||||
|
||||
# Check for AWS credentials
|
||||
has_env_creds = bool(os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"))
|
||||
|
||||
if not has_env_creds:
|
||||
logger.warning(
|
||||
"AWS credentials not found in environment variables. "
|
||||
"Attempting to use EnvironmentCredentialsResolver which will check "
|
||||
"AWS CLI configuration and instance metadata."
|
||||
)
|
||||
|
||||
config = Config(
|
||||
endpoint_uri=self.bidi_endpoint,
|
||||
region=self.region,
|
||||
aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
|
||||
auth_scheme_resolver=HTTPAuthSchemeResolver(),
|
||||
auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="sagemaker")},
|
||||
)
|
||||
self._client = SageMakerRuntimeHTTP2Client(config=config)
|
||||
|
||||
async def start_session(self):
|
||||
"""Start a bidirectional streaming session with the SageMaker endpoint.
|
||||
|
||||
Initializes the client if needed, creates the bidirectional stream, and
|
||||
establishes the connection to the SageMaker endpoint. Must be called
|
||||
before sending or receiving data.
|
||||
|
||||
Returns:
|
||||
The output stream for receiving responses.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If client initialization or connection fails.
|
||||
"""
|
||||
if not self._client:
|
||||
self._initialize_client()
|
||||
|
||||
logger.debug(f"Starting BiDi session with endpoint: {self.endpoint_name}")
|
||||
logger.debug(f"Model invocation path: {self.model_invocation_path}")
|
||||
logger.debug(f"Model query string: {self.model_query_string}")
|
||||
|
||||
# Create the bidirectional stream
|
||||
stream_input = InvokeEndpointWithBidirectionalStreamInput(
|
||||
endpoint_name=self.endpoint_name,
|
||||
model_invocation_path=self.model_invocation_path,
|
||||
model_query_string=self.model_query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
self._stream = await self._client.invoke_endpoint_with_bidirectional_stream(
|
||||
stream_input
|
||||
)
|
||||
self._is_active = True
|
||||
|
||||
# Get output stream
|
||||
output = await self._stream.await_output()
|
||||
self._output_stream = output[1]
|
||||
|
||||
logger.debug("BiDi session started successfully")
|
||||
return self._output_stream
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start BiDi session: {e}")
|
||||
self._is_active = False
|
||||
raise RuntimeError(f"Failed to start SageMaker BiDi session: {e}")
|
||||
|
||||
async def send_data(self, data_bytes: bytes, data_type: Optional[str] = None):
|
||||
"""Send a chunk of data to the stream.
|
||||
|
||||
Generic method for sending any type of data to the SageMaker endpoint.
|
||||
Use the convenience methods (send_audio_chunk, send_text, send_json)
|
||||
for common data types.
|
||||
|
||||
Args:
|
||||
data_bytes: Raw bytes to send.
|
||||
data_type: Optional data type header. Common values are "BINARY" for
|
||||
audio/binary data and "UTF8" for text/JSON data.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active or send fails.
|
||||
"""
|
||||
if not self._is_active or not self._stream:
|
||||
raise RuntimeError("BiDi session not active")
|
||||
|
||||
try:
|
||||
payload = RequestPayloadPart(bytes_=data_bytes, data_type=data_type)
|
||||
event = RequestStreamEventPayloadPart(value=payload)
|
||||
await self._stream.input_stream.send(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send data: {e}")
|
||||
raise
|
||||
|
||||
async def send_audio_chunk(self, audio_bytes: bytes):
|
||||
"""Send a chunk of audio data to the stream.
|
||||
|
||||
Convenience method for sending audio data. Automatically sets the data
|
||||
type to "BINARY".
|
||||
|
||||
Args:
|
||||
audio_bytes: Raw audio bytes to send (e.g., PCM audio data).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active or send fails.
|
||||
"""
|
||||
await self.send_data(audio_bytes, data_type="BINARY")
|
||||
|
||||
async def send_text(self, text: str):
|
||||
"""Send text data to the stream.
|
||||
|
||||
Convenience method for sending text data. Automatically encodes the text
|
||||
as UTF-8 and sets the data type to "UTF8".
|
||||
|
||||
Args:
|
||||
text: Text string to send.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active or send fails.
|
||||
"""
|
||||
await self.send_data(text.encode("utf-8"), data_type="UTF8")
|
||||
|
||||
async def send_json(self, data: dict):
|
||||
"""Send JSON data to the stream.
|
||||
|
||||
Convenience method for sending JSON-encoded messages. Useful for control
|
||||
messages like KeepAlive or CloseStream. Automatically serializes the
|
||||
dictionary to JSON, encodes as UTF-8, and sets the data type to "UTF8".
|
||||
|
||||
Args:
|
||||
data: Dictionary to send as JSON (e.g., {"type": "KeepAlive"}).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active or send fails.
|
||||
"""
|
||||
import json
|
||||
|
||||
await self.send_data(json.dumps(data).encode("utf-8"), data_type="UTF8")
|
||||
|
||||
async def receive_response(self) -> Optional[ResponseStreamEvent]:
|
||||
"""Receive a response from the stream.
|
||||
|
||||
Blocks until a response is available from the SageMaker endpoint. Returns
|
||||
None when the stream is closed.
|
||||
|
||||
Returns:
|
||||
The response event containing payload data, or None if stream is closed.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is not active.
|
||||
"""
|
||||
if not self._is_active or not self._output_stream:
|
||||
raise RuntimeError("BiDi session not active")
|
||||
|
||||
try:
|
||||
result = await self._output_stream.receive()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to receive response: {e}")
|
||||
raise
|
||||
|
||||
async def close_session(self):
|
||||
"""Close the bidirectional streaming session.
|
||||
|
||||
Gracefully closes the input stream and marks the session as inactive.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if not self._is_active:
|
||||
return
|
||||
|
||||
logger.debug("Closing BiDi session...")
|
||||
self._is_active = False
|
||||
|
||||
try:
|
||||
if self._stream:
|
||||
await self._stream.input_stream.close()
|
||||
logger.debug("BiDi session closed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing BiDi session: {e}")
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if the session is currently active.
|
||||
|
||||
Returns:
|
||||
True if session is active, False otherwise.
|
||||
"""
|
||||
return self._is_active
|
||||
@@ -209,6 +209,15 @@ class AWSPollyTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AWS TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AWS's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to AWS Polly language format.
|
||||
|
||||
|
||||
@@ -151,6 +151,15 @@ class AzureBaseTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Azure TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Azure's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Azure language format.
|
||||
|
||||
|
||||
@@ -10,8 +10,7 @@ import base64
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Literal, Optional
|
||||
from typing import AsyncGenerator, List, Literal, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -126,72 +125,6 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
|
||||
|
||||
|
||||
class CartesiaEmotion(str, Enum):
|
||||
"""Predefined Emotions supported by Cartesia."""
|
||||
|
||||
# Primary emotions supported by Cartesia
|
||||
NEUTRAL = "neutral"
|
||||
ANGRY = "angry"
|
||||
EXCITED = "excited"
|
||||
CONTENT = "content"
|
||||
SAD = "sad"
|
||||
SCARED = "scared"
|
||||
# Additional emotions supported by Cartesia
|
||||
HAPPY = "happy"
|
||||
ENTHUSIASTIC = "enthusiastic"
|
||||
ELATED = "elated"
|
||||
EUPHORIC = "euphoric"
|
||||
TRIUMPHANT = "triumphant"
|
||||
AMAZED = "amazed"
|
||||
SURPRISED = "surprised"
|
||||
FLIRTATIOUS = "flirtatious"
|
||||
JOKING_COMEDIC = "joking/comedic"
|
||||
CURIOUS = "curious"
|
||||
PEACEFUL = "peaceful"
|
||||
SERENE = "serene"
|
||||
CALM = "calm"
|
||||
GRATEFUL = "grateful"
|
||||
AFFECTIONATE = "affectionate"
|
||||
TRUST = "trust"
|
||||
SYMPATHETIC = "sympathetic"
|
||||
ANTICIPATION = "anticipation"
|
||||
MYSTERIOUS = "mysterious"
|
||||
MAD = "mad"
|
||||
OUTRAGED = "outraged"
|
||||
FRUSTRATED = "frustrated"
|
||||
AGITATED = "agitated"
|
||||
THREATENED = "threatened"
|
||||
DISGUSTED = "disgusted"
|
||||
CONTEMPT = "contempt"
|
||||
ENVIOUS = "envious"
|
||||
SARCASTIC = "sarcastic"
|
||||
IRONIC = "ironic"
|
||||
DEJECTED = "dejected"
|
||||
MELANCHOLIC = "melancholic"
|
||||
DISAPPOINTED = "disappointed"
|
||||
HURT = "hurt"
|
||||
GUILTY = "guilty"
|
||||
BORED = "bored"
|
||||
TIRED = "tired"
|
||||
REJECTED = "rejected"
|
||||
NOSTALGIC = "nostalgic"
|
||||
WISTFUL = "wistful"
|
||||
APOLOGETIC = "apologetic"
|
||||
HESITANT = "hesitant"
|
||||
INSECURE = "insecure"
|
||||
CONFUSED = "confused"
|
||||
RESIGNED = "resigned"
|
||||
ANXIOUS = "anxious"
|
||||
PANICKED = "panicked"
|
||||
ALARMED = "alarmed"
|
||||
PROUD = "proud"
|
||||
CONFIDENT = "confident"
|
||||
DISTANT = "distant"
|
||||
SKEPTICAL = "skeptical"
|
||||
CONTEMPLATIVE = "contemplative"
|
||||
DETERMINED = "determined"
|
||||
|
||||
|
||||
class CartesiaTTSService(AudioContextWordTTSService):
|
||||
"""Cartesia TTS service with WebSocket streaming and word timestamps.
|
||||
|
||||
@@ -249,10 +182,6 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
container: Audio container format.
|
||||
params: Additional input parameters for voice customization.
|
||||
text_aggregator: Custom text aggregator for processing input text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
@@ -271,18 +200,10 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
text_aggregator=text_aggregator,
|
||||
text_aggregator=text_aggregator or SkipTagsAggregator([("<spell>", "</spell>")]),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not text_aggregator:
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Cartesia SSML Tags is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("<spell>", "</spell>")])
|
||||
|
||||
params = params or CartesiaTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
@@ -336,27 +257,6 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
"""
|
||||
return language_to_cartesia_language(language)
|
||||
|
||||
# A set of Cartesia-specific helpers for text transformations
|
||||
def SPELL(text: str) -> str:
|
||||
"""Wrap text in Cartesia spell tag."""
|
||||
return f"<spell>{text}</spell>"
|
||||
|
||||
def EMOTION_TAG(emotion: CartesiaEmotion) -> str:
|
||||
"""Convenience method to create an emotion tag."""
|
||||
return f'<emotion value="{emotion}" />'
|
||||
|
||||
def PAUSE_TAG(seconds: float) -> str:
|
||||
"""Convenience method to create a pause tag."""
|
||||
return f'<break time="{seconds}s" />'
|
||||
|
||||
def VOLUME_TAG(volume: float) -> str:
|
||||
"""Convenience method to create a volume tag."""
|
||||
return f'<volume ratio="{volume}" />'
|
||||
|
||||
def SPEED_TAG(speed: float) -> str:
|
||||
"""Convenience method to create a speed tag."""
|
||||
return f'<speed ratio="{speed}" />'
|
||||
|
||||
def _is_cjk_language(self, language: str) -> bool:
|
||||
"""Check if the given language is CJK (Chinese, Japanese, Korean).
|
||||
|
||||
|
||||
@@ -6,9 +6,7 @@
|
||||
|
||||
"""Deepgram Flux speech-to-text service implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
@@ -96,7 +94,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
mip_opt_out: Optional. Opts out requests from the Deepgram Model Improvement Program
|
||||
(default False).
|
||||
tag: List of tags to label requests for identification during usage reporting.
|
||||
min_confidence: Optional. Minimum confidence required confidence to create a TranscriptionFrame
|
||||
"""
|
||||
|
||||
eager_eot_threshold: Optional[float] = None
|
||||
@@ -105,7 +102,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
keyterm: list = []
|
||||
mip_opt_out: Optional[bool] = None
|
||||
tag: list = []
|
||||
min_confidence: Optional[float] = None # New parameter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -167,13 +163,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
self._register_event_handler("on_end_of_turn")
|
||||
self._register_event_handler("on_eager_end_of_turn")
|
||||
self._register_event_handler("on_update")
|
||||
self._connection_established_event = asyncio.Event()
|
||||
# Watchdog task to prevent dangling tasks
|
||||
# If we stop sending audio to Flux after we have received that the User has started speaking
|
||||
# we never receive the user stopped speaking event unless we resume sending audio to it.
|
||||
self._last_stt_time = None
|
||||
self._watchdog_task = None
|
||||
self._user_is_speaking = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to WebSocket and start background tasks.
|
||||
@@ -183,14 +172,9 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
"""
|
||||
await self._connect_websocket()
|
||||
|
||||
# Creating the receiver task (only created once during initial connection)
|
||||
if not self._receive_task:
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
# Creating the watchdog task (only created once during initial connection)
|
||||
if not self._watchdog_task:
|
||||
self._watchdog_task = self.create_task(self._watchdog_task_handler())
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from WebSocket and clean up tasks.
|
||||
|
||||
@@ -198,7 +182,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
and cleans up resources to prevent memory leaks.
|
||||
"""
|
||||
try:
|
||||
# Cancel background tasks BEFORE closing websocket
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=2.0)
|
||||
self._receive_task = None
|
||||
|
||||
# Now close the websocket
|
||||
await self._disconnect_websocket()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
@@ -206,25 +197,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
# Reset state only after everything is cleaned up
|
||||
self._websocket = None
|
||||
|
||||
async def _send_silence(self, duration_secs: float = 0.5):
|
||||
"""Send a block of silence of the specified duration (default 500 ms)."""
|
||||
sample_width = 2 # bytes per sample for 16-bit PCM
|
||||
num_channels = 1 # mono
|
||||
num_samples = int(self.sample_rate * duration_secs)
|
||||
silence = b"\x00" * (num_samples * sample_width * num_channels)
|
||||
await self._websocket.send(silence)
|
||||
|
||||
async def _watchdog_task_handler(self):
|
||||
while self._websocket and self._websocket.state is State.OPEN:
|
||||
now = time.monotonic()
|
||||
# More than 500 ms without sending new audio to Flux
|
||||
if self._user_is_speaking and self._last_stt_time and now - self._last_stt_time > 0.5:
|
||||
logger.warning("Sending silence to Flux to prevent dangling task")
|
||||
await self._send_silence()
|
||||
self._last_stt_time = time.monotonic()
|
||||
# check every 100ms
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Establish WebSocket connection to API.
|
||||
|
||||
@@ -236,16 +208,10 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
self._connection_established_event.clear()
|
||||
self._user_is_speaking = False
|
||||
self._websocket = await websocket_connect(
|
||||
self._websocket_url,
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
|
||||
# Now wait for the connection established event
|
||||
logger.debug("WebSocket connected, waiting for server confirmation...")
|
||||
await self._connection_established_event.wait()
|
||||
logger.debug("Connected to Deepgram Flux Websocket")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
@@ -261,16 +227,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
metrics collection. Handles disconnection errors gracefully.
|
||||
"""
|
||||
try:
|
||||
# Cancel background tasks BEFORE closing websocket
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=2.0)
|
||||
self._receive_task = None
|
||||
if self._watchdog_task:
|
||||
await self.cancel_task(self._watchdog_task, timeout=2.0)
|
||||
self._watchdog_task = None
|
||||
self._last_stt_time = None
|
||||
|
||||
self._connection_established_event.clear()
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
@@ -384,8 +340,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
return
|
||||
|
||||
try:
|
||||
self._last_stt_time = time.monotonic()
|
||||
await self.send_with_retry(audio, self._report_error)
|
||||
await self._websocket.send(audio)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
@@ -508,8 +463,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
transcription processing.
|
||||
"""
|
||||
logger.info("Connected to Flux - ready to stream audio")
|
||||
# Notify connection is established
|
||||
self._connection_established_event.set()
|
||||
|
||||
async def _handle_fatal_error(self, data: Dict[str, Any]):
|
||||
"""Handle fatal error messages from Deepgram Flux.
|
||||
@@ -577,7 +530,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
transcript: maybe the first few words of the turn.
|
||||
"""
|
||||
logger.debug("User started speaking")
|
||||
self._user_is_speaking = True
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.start_metrics()
|
||||
@@ -598,22 +550,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
logger.trace(f"Received event TurnResumed: {event}")
|
||||
await self._call_event_handler("on_turn_resumed")
|
||||
|
||||
def _calculate_average_confidence(self, transcript_data) -> Optional[float]:
|
||||
"""Calculate the average confidence from transcript data.
|
||||
|
||||
Return None if the data is missing or invalid.
|
||||
"""
|
||||
# Example: Assume transcript_data has a list of words with confidence
|
||||
words = transcript_data.get("words")
|
||||
if not words or not isinstance(words, list):
|
||||
return None
|
||||
confidences = [
|
||||
w.get("confidence") for w in words if isinstance(w.get("confidence"), (float, int))
|
||||
]
|
||||
if not confidences:
|
||||
return None
|
||||
return sum(confidences) / len(confidences)
|
||||
|
||||
async def _handle_end_of_turn(self, transcript: str, data: Dict[str, Any]):
|
||||
"""Handle EndOfTurn events from Deepgram Flux.
|
||||
|
||||
@@ -633,26 +569,16 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
data: The TurnInfo message data containing event type, transcript and some extra metadata.
|
||||
"""
|
||||
logger.debug("User stopped speaking")
|
||||
self._user_is_speaking = False
|
||||
|
||||
# Compute the average confidence
|
||||
average_confidence = self._calculate_average_confidence(data)
|
||||
|
||||
if not self._params.min_confidence or average_confidence > self._params.min_confidence:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
result=data,
|
||||
)
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
result=data,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Transcription confidence below min_confidence threshold: {average_confidence}"
|
||||
)
|
||||
|
||||
)
|
||||
await self._handle_transcription(transcript, True, self._language)
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(UserStoppedSpeakingFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
@@ -1,447 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
This module provides a Pipecat STT service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time transcription with support for interim results, multiple
|
||||
languages, and various Deepgram features.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from deepgram import LiveOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepgramSageMakerSTTService, you need to `pip install pipecat-ai[deepgram,sagemaker]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class DeepgramSageMakerSTTService(STTService):
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech recognition using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
transcription with support for interim results, speaker diarization, and
|
||||
multiple languages.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- Deepgram SDK for LiveOptions configuration
|
||||
|
||||
Example::
|
||||
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name="my-deepgram-endpoint",
|
||||
region="us-east-2",
|
||||
live_options=LiveOptions(
|
||||
model="nova-3",
|
||||
language="en",
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
),
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
live_options: Optional[LiveOptions] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram model
|
||||
deployed (e.g., "my-deepgram-nova-3-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
sample_rate: Audio sample rate in Hz. If None, uses value from
|
||||
live_options or defaults to the value from StartFrame.
|
||||
live_options: Deepgram LiveOptions for detailed configuration. If None,
|
||||
uses sensible defaults (nova-3 model, English, interim results enabled).
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
# Create default options similar to DeepgramSTTService
|
||||
default_options = LiveOptions(
|
||||
encoding="linear16",
|
||||
language=Language.EN,
|
||||
model="nova-3",
|
||||
channels=1,
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
)
|
||||
|
||||
# Merge with provided options
|
||||
merged_options = default_options.to_dict()
|
||||
if live_options:
|
||||
default_model = default_options.model
|
||||
merged_options.update(live_options.to_dict())
|
||||
# Handle the "None" string bug from deepgram-sdk
|
||||
if "model" in merged_options and merged_options["model"] == "None":
|
||||
merged_options["model"] = default_model
|
||||
|
||||
# Convert Language enum to string if needed
|
||||
if "language" in merged_options and isinstance(merged_options["language"], Language):
|
||||
merged_options["language"] = merged_options["language"].value
|
||||
|
||||
self.set_model_name(merged_options["model"])
|
||||
self._settings = merged_options
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the Deepgram model and reconnect.
|
||||
|
||||
Disconnects from the current session, updates the model setting, and
|
||||
establishes a new connection with the updated model.
|
||||
|
||||
Args:
|
||||
model: The Deepgram model name to use (e.g., "nova-3").
|
||||
"""
|
||||
await super().set_model(model)
|
||||
logger.info(f"Switching STT model to: [{model}]")
|
||||
self._settings["model"] = model
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the recognition language and reconnect.
|
||||
|
||||
Disconnects from the current session, updates the language setting, and
|
||||
establishes a new connection with the updated language.
|
||||
|
||||
Args:
|
||||
language: The language to use for speech recognition (e.g., Language.EN,
|
||||
Language.ES).
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._settings["language"] = language
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._settings["sample_rate"] = self.sample_rate
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to Deepgram for transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
Frame: None (transcription results come via BiDi stream callbacks).
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_audio_chunk(audio)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio to SageMaker: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"SageMaker STT error: {e}"))
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram query string from settings, creates the BiDi client,
|
||||
starts the streaming session, and launches background tasks for processing
|
||||
responses and sending KeepAlive messages.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram on SageMaker...")
|
||||
|
||||
# Update sample rate in settings
|
||||
self._settings["sample_rate"] = self.sample_rate
|
||||
|
||||
# Build query string from settings, converting booleans to strings
|
||||
query_params = {}
|
||||
for key, value in self._settings.items():
|
||||
if value is not None:
|
||||
# Convert boolean values to lowercase strings for Deepgram API
|
||||
if isinstance(value, bool):
|
||||
query_params[key] = str(value).lower()
|
||||
else:
|
||||
query_params[key] = str(value)
|
||||
|
||||
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
|
||||
|
||||
# Create BiDi client
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/listen",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
# Start the session
|
||||
await self._client.start_session()
|
||||
|
||||
# Start processing responses in the background
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
# Start keepalive task to maintain connection
|
||||
self._keepalive_task = self.create_task(self._send_keepalive())
|
||||
|
||||
logger.debug("Connected to Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to SageMaker: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"SageMaker connection error: {e}"))
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a CloseStream message to Deepgram, cancels background tasks
|
||||
(KeepAlive and response processing), and closes the BiDi session.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram on SageMaker...")
|
||||
|
||||
# Send CloseStream message to Deepgram
|
||||
try:
|
||||
await self._client.send_json({"type": "CloseStream"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send CloseStream message: {e}")
|
||||
|
||||
# Cancel keepalive task
|
||||
if self._keepalive_task and not self._keepalive_task.done():
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
|
||||
# Cancel response processing task
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
# Close the BiDi session
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _send_keepalive(self):
|
||||
"""Send periodic KeepAlive messages to maintain the connection.
|
||||
|
||||
Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
|
||||
connection is active. This prevents the connection from timing out during
|
||||
periods of silence.
|
||||
"""
|
||||
while self._client and self._client.is_active:
|
||||
await asyncio.sleep(5)
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "KeepAlive"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send KeepAlive: {e}")
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream, decodes the payload,
|
||||
parses JSON responses from Deepgram, and processes transcription results.
|
||||
Runs as a background task until the connection is closed or cancelled.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
# Check if this is a PayloadPart with bytes
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
response_data = result.value.bytes_.decode("utf-8")
|
||||
|
||||
try:
|
||||
# Parse JSON response from Deepgram
|
||||
parsed = json.loads(response_data)
|
||||
|
||||
# Extract and process transcript if available
|
||||
if "channel" in parsed:
|
||||
await self._handle_transcript_response(parsed)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Non-JSON response: {response_data}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Response processor cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing responses: {e}", exc_info=True)
|
||||
await self.push_error(ErrorFrame(error=f"SageMaker response error: {e}"))
|
||||
finally:
|
||||
logger.debug("Response processor stopped")
|
||||
|
||||
async def _handle_transcript_response(self, parsed: dict):
|
||||
"""Handle a transcript response from Deepgram.
|
||||
|
||||
Extracts the transcript text, determines if it's final or interim, extracts
|
||||
language information, and pushes the appropriate frame (TranscriptionFrame
|
||||
or InterimTranscriptionFrame) downstream.
|
||||
|
||||
Args:
|
||||
parsed: The parsed JSON response from Deepgram containing channel,
|
||||
alternatives, transcript, and metadata.
|
||||
"""
|
||||
alternatives = parsed.get("channel", {}).get("alternatives", [])
|
||||
if not alternatives or not alternatives[0].get("transcript"):
|
||||
return
|
||||
|
||||
transcript = alternatives[0]["transcript"]
|
||||
if not transcript.strip():
|
||||
return
|
||||
|
||||
# Stop TTFB metrics on first transcript
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
is_final = parsed.get("is_final", False)
|
||||
speech_final = parsed.get("speech_final", False)
|
||||
|
||||
# Extract language if available
|
||||
language = None
|
||||
if alternatives[0].get("languages"):
|
||||
language = alternatives[0]["languages"][0]
|
||||
language = Language(language)
|
||||
|
||||
if is_final and speech_final:
|
||||
# Final transcription
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
# Interim transcription
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing.
|
||||
|
||||
This method is decorated with @traced_stt for observability and tracing
|
||||
integration. The actual transcription processing is handled by the parent
|
||||
class and observers.
|
||||
|
||||
Args:
|
||||
transcript: The transcribed text.
|
||||
is_final: Whether this is a final transcription result.
|
||||
language: The detected language of the transcription, if available.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def start_metrics(self):
|
||||
"""Start TTFB and processing metrics collection."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with Deepgram SageMaker-specific handling.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.start_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# Send finalize message to Deepgram when user stops speaking
|
||||
# This tells Deepgram to flush any remaining audio and return final results
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Finalize"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending Finalize message: {e}")
|
||||
@@ -79,6 +79,15 @@ class DeepgramTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Deepgram TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Deepgram's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram's TTS API.
|
||||
@@ -168,6 +177,15 @@ class DeepgramHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Deepgram TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Deepgram's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram's TTS API.
|
||||
|
||||
@@ -416,8 +416,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
Only used when commit_strategy is VAD. None uses ElevenLabs default.
|
||||
min_silence_duration_ms: Minimum silence duration for VAD (50-2000ms).
|
||||
Only used when commit_strategy is VAD. None uses ElevenLabs default.
|
||||
include_timestamps: Whether to include word-level timestamps in transcripts.
|
||||
enable_logging: Whether to enable logging on ElevenLabs' side.
|
||||
"""
|
||||
|
||||
language_code: Optional[str] = None
|
||||
@@ -426,8 +424,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
vad_threshold: Optional[float] = None
|
||||
min_speech_duration_ms: Optional[int] = None
|
||||
min_silence_duration_ms: Optional[int] = None
|
||||
include_timestamps: bool = False
|
||||
enable_logging: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -463,8 +459,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
self._audio_format = "" # initialized in start()
|
||||
self._receive_task = None
|
||||
|
||||
self._settings = {"language": params.language_code}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate processing metrics.
|
||||
|
||||
@@ -483,13 +477,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
Changing language requires reconnecting to the WebSocket.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
new_language = (
|
||||
language_to_elevenlabs_language(language)
|
||||
if isinstance(language, Language)
|
||||
else language
|
||||
)
|
||||
self._params.language_code = new_language
|
||||
self._settings["language"] = new_language
|
||||
self._params.language_code = language.value if isinstance(language, Language) else language
|
||||
# Reconnect with new settings
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -632,16 +620,10 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
if self._params.language_code:
|
||||
params.append(f"language_code={self._params.language_code}")
|
||||
|
||||
params.append(f"audio_format={self._audio_format}")
|
||||
params.append(f"encoding={self._audio_format}")
|
||||
params.append(f"sample_rate={self.sample_rate}")
|
||||
params.append(f"commit_strategy={self._params.commit_strategy.value}")
|
||||
|
||||
# Add optional parameters
|
||||
if self._params.include_timestamps:
|
||||
params.append(f"include_timestamps={str(self._params.include_timestamps).lower()}")
|
||||
|
||||
if self._params.enable_logging:
|
||||
params.append(f"enable_logging={str(self._params.enable_logging).lower()}")
|
||||
|
||||
# Add VAD parameters if using VAD commit strategy and values are specified
|
||||
if self._params.commit_strategy == CommitStrategy.VAD:
|
||||
if self._params.vad_silence_threshold_secs is not None:
|
||||
@@ -730,20 +712,15 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
elif message_type == "committed_transcript_with_timestamps":
|
||||
await self._on_committed_transcript_with_timestamps(data)
|
||||
|
||||
elif message_type == "error":
|
||||
error_msg = data.get("error", "Unknown error")
|
||||
logger.error(f"ElevenLabs error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Error: {error_msg}"))
|
||||
elif message_type == "input_error":
|
||||
error_msg = data.get("error", "Unknown input error")
|
||||
logger.error(f"ElevenLabs input error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Input error: {error_msg}"))
|
||||
|
||||
elif message_type == "auth_error":
|
||||
error_msg = data.get("error", "Authentication error")
|
||||
logger.error(f"ElevenLabs auth error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Auth error: {error_msg}"))
|
||||
|
||||
elif message_type == "quota_exceeded_error":
|
||||
error_msg = data.get("error", "Quota exceeded")
|
||||
logger.error(f"ElevenLabs quota exceeded: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Quota exceeded: {error_msg}"))
|
||||
elif message_type in ["auth_error", "quota_exceeded", "transcriber_error", "error"]:
|
||||
error_msg = data.get("error", data.get("message", "Unknown error"))
|
||||
logger.error(f"ElevenLabs error ({message_type}): {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"{message_type}: {error_msg}"))
|
||||
|
||||
else:
|
||||
logger.debug(f"Unknown message type: {message_type}")
|
||||
@@ -788,11 +765,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
Args:
|
||||
data: Committed transcript data.
|
||||
"""
|
||||
# If timestamps are enabled, skip this message and wait for the
|
||||
# committed_transcript_with_timestamps message which contains all the data
|
||||
if self._params.include_timestamps:
|
||||
return
|
||||
|
||||
text = data.get("text", "").strip()
|
||||
if not text:
|
||||
return
|
||||
@@ -820,18 +792,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
async def _on_committed_transcript_with_timestamps(self, data: dict):
|
||||
"""Handle committed transcript with word-level timestamps.
|
||||
|
||||
This message is sent when include_timestamps=true. The result data includes:
|
||||
- text: The transcribed text
|
||||
- language_code: Detected language (if available)
|
||||
- words: Array of word objects with timing information:
|
||||
- text: The word text
|
||||
- start: Start time in seconds
|
||||
- end: End time in seconds
|
||||
- type: "word" or "spacing"
|
||||
- speaker_id: Speaker identifier (if available)
|
||||
- logprob: Log probability score (if available)
|
||||
- characters: Array of character strings (if available)
|
||||
|
||||
Args:
|
||||
data: Committed transcript data with timestamps.
|
||||
"""
|
||||
@@ -839,24 +799,9 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Get language if provided
|
||||
language = data.get("language_code")
|
||||
|
||||
logger.debug(f"Committed transcript with timestamps: [{text}]")
|
||||
logger.trace(f"Timestamps: {data.get('words', [])}")
|
||||
|
||||
await self._handle_transcription(text, True, language)
|
||||
|
||||
# This message is sent after committed_transcript when include_timestamps=true.
|
||||
# It contains the full transcript data including text and word-level timestamps.
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
# This is sent after the committed_transcript, so we don't need to
|
||||
# push another TranscriptionFrame, but we could use the timestamps
|
||||
# for additional processing if needed in the future
|
||||
|
||||
@@ -159,6 +159,15 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Fish Audio TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Fish Audio's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the TTS model and reconnect.
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ from pydantic import BaseModel, Field
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -1453,6 +1452,8 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._bot_text_buffer += text
|
||||
self._search_result_buffer += text # Also accumulate for grounding
|
||||
frame = LLMTextFrame(text=text)
|
||||
# Gemini Live text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Check for grounding metadata in server content
|
||||
@@ -1645,7 +1646,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.SENTENCE)
|
||||
frame = TTSTextFrame(text=text)
|
||||
# Gemini Live text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
|
||||
|
||||
@@ -920,7 +920,9 @@ class GoogleLLMService(LLMService):
|
||||
for part in candidate.content.parts:
|
||||
if not part.thought and part.text:
|
||||
search_result += part.text
|
||||
await self.push_frame(LLMTextFrame(part.text))
|
||||
frame = LLMTextFrame(part.text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
elif part.function_call:
|
||||
function_call = part.function_call
|
||||
id = function_call.id or str(uuid.uuid4())
|
||||
|
||||
@@ -596,6 +596,15 @@ class GoogleHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Google TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Google's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Google TTS language format.
|
||||
|
||||
@@ -794,6 +803,15 @@ class GoogleBaseTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Google and Gemini TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Google's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Google TTS language format.
|
||||
|
||||
|
||||
@@ -111,6 +111,15 @@ class GroqTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Groq TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Groq's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Groq's TTS API.
|
||||
|
||||
@@ -14,14 +14,12 @@ from pydantic import BaseModel
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import WordTTSService
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
@@ -31,7 +29,6 @@ try:
|
||||
PostedUtterance,
|
||||
PostedUtteranceVoiceWithId,
|
||||
)
|
||||
from hume.tts.types import TimestampMessage
|
||||
except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Hume, you need to `pip install pipecat-ai[hume]`.")
|
||||
@@ -41,7 +38,7 @@ except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
HUME_SAMPLE_RATE = 48_000 # Hume TTS streams at 48 kHz
|
||||
|
||||
|
||||
class HumeTTSService(WordTTSService):
|
||||
class HumeTTSService(TTSService):
|
||||
"""Hume Octave Text-to-Speech service.
|
||||
|
||||
Streams PCM audio via Hume's HTTP output streaming (JSON chunks) endpoint
|
||||
@@ -51,7 +48,6 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
- Generates speech from text using Hume TTS.
|
||||
- Streams PCM audio.
|
||||
- Supports word-level timestamps for precise audio-text synchronization.
|
||||
- Supports dynamic updates of voice and synthesis parameters at runtime.
|
||||
- Provides metrics for Time To First Byte (TTFB) and TTS usage.
|
||||
"""
|
||||
@@ -96,13 +92,7 @@ class HumeTTSService(WordTTSService):
|
||||
f"Hume TTS streams at {HUME_SAMPLE_RATE} Hz; configured sample_rate={sample_rate}"
|
||||
)
|
||||
|
||||
# WordTTSService sets push_text_frames=False by default, which we want
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
**kwargs,
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._client = AsyncHumeClient(api_key=api_key)
|
||||
self._params = params or HumeTTSService.InputParams()
|
||||
@@ -112,10 +102,6 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
self._audio_bytes = b""
|
||||
|
||||
# Track cumulative time for word timestamps across utterances
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Can generate metrics.
|
||||
|
||||
@@ -124,6 +110,15 @@ class HumeTTSService(WordTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Hume TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Hume's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame) -> None:
|
||||
"""Start the service.
|
||||
|
||||
@@ -131,27 +126,6 @@ class HumeTTSService(WordTTSService):
|
||||
frame: The start frame.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._reset_state()
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset internal state variables."""
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (InterruptionFrame, TTSStoppedFrame)):
|
||||
# Reset timing on interruption or stop
|
||||
self._reset_state()
|
||||
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
|
||||
async def update_setting(self, key: str, value: Any) -> None:
|
||||
"""Runtime updates via `TTSUpdateSettingsFrame`.
|
||||
@@ -168,7 +142,7 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
if key_l == "voice_id":
|
||||
self.set_voice(str(value))
|
||||
logger.debug(f"HumeTTSService voice_id set to: {self.voice}")
|
||||
logger.info(f"HumeTTSService voice_id set to: {self.voice}")
|
||||
elif key_l == "description":
|
||||
self._params.description = None if value is None else str(value)
|
||||
elif key_l == "speed":
|
||||
@@ -181,7 +155,7 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Hume TTS with word timestamps.
|
||||
"""Generate speech from text using Hume TTS.
|
||||
|
||||
Args:
|
||||
text: The text to be synthesized.
|
||||
@@ -212,12 +186,7 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Start TTS sequence if not already started
|
||||
if not self._started:
|
||||
self.start_word_timestamps()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
yield TTSStartedFrame()
|
||||
|
||||
try:
|
||||
# Instant mode is always enabled here (not user-configurable)
|
||||
@@ -228,50 +197,23 @@ class HumeTTSService(WordTTSService):
|
||||
# Use version "2" by default if no description is provided
|
||||
# Version "1" is needed when description is used
|
||||
version = "1" if self._params.description is not None else "2"
|
||||
|
||||
# Track the duration of this utterance based on the last timestamp
|
||||
utterance_duration = 0.0
|
||||
|
||||
async for chunk in self._client.tts.synthesize_json_streaming(
|
||||
utterances=[utterance],
|
||||
format=pcm_fmt,
|
||||
instant_mode=True,
|
||||
version=version,
|
||||
include_timestamp_types=["word"], # Request word-level timestamps
|
||||
):
|
||||
# Process audio chunks
|
||||
audio_b64 = getattr(chunk, "audio", None)
|
||||
if audio_b64:
|
||||
await self.stop_ttfb_metrics()
|
||||
pcm_bytes = base64.b64decode(audio_b64)
|
||||
self._audio_bytes += pcm_bytes
|
||||
if not audio_b64:
|
||||
continue
|
||||
|
||||
# Buffer audio until we have enough to avoid glitches
|
||||
if len(self._audio_bytes) >= self.chunk_size:
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=self._audio_bytes,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
self._audio_bytes = b""
|
||||
pcm_bytes = base64.b64decode(audio_b64)
|
||||
self._audio_bytes += pcm_bytes
|
||||
|
||||
# Process timestamp messages
|
||||
if isinstance(chunk, TimestampMessage):
|
||||
timestamp = chunk.timestamp
|
||||
if timestamp.type == "word":
|
||||
# Convert milliseconds to seconds and add cumulative offset
|
||||
word_start_time = self._cumulative_time + (timestamp.time.begin / 1000.0)
|
||||
word_end_time = self._cumulative_time + (timestamp.time.end / 1000.0)
|
||||
# Buffer audio until we have enough to avoid glitches
|
||||
if len(self._audio_bytes) < self.chunk_size:
|
||||
continue
|
||||
|
||||
# Track the maximum end time for this utterance
|
||||
utterance_duration = max(utterance_duration, word_end_time)
|
||||
|
||||
# Add word timestamp
|
||||
await self.add_word_timestamps([(timestamp.text, word_start_time)])
|
||||
|
||||
# Flush any remaining audio bytes
|
||||
if self._audio_bytes:
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=self._audio_bytes,
|
||||
sample_rate=self.sample_rate,
|
||||
@@ -282,14 +224,10 @@ class HumeTTSService(WordTTSService):
|
||||
|
||||
self._audio_bytes = b""
|
||||
|
||||
# Update cumulative time for next utterance
|
||||
if utterance_duration > 0:
|
||||
self._cumulative_time = utterance_duration
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# Ensure TTFB timer is stopped even on early failures
|
||||
await self.stop_ttfb_metrics()
|
||||
# Let the parent class handle TTSStoppedFrame via push_stop_frames
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -146,8 +146,6 @@ class InworldTTSService(TTSService):
|
||||
Parameters:
|
||||
temperature: Voice temperature control for synthesis variability (e.g., 1.1).
|
||||
Valid range: [0, 2]. Higher values increase variability.
|
||||
speaking_rate: Speaking speed control (range: [0.5, 1.5]). Defaults to 1.0 when
|
||||
unset.
|
||||
|
||||
Note:
|
||||
Language is automatically inferred from the input text by Inworld's TTS models,
|
||||
@@ -155,7 +153,6 @@ class InworldTTSService(TTSService):
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None # optional temperature control (range: [0, 2])
|
||||
speaking_rate: Optional[float] = None # optional speaking rate control (range: [0.5, 1.5])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -201,7 +198,6 @@ class InworldTTSService(TTSService):
|
||||
- Other formats as supported by Inworld API
|
||||
params: Optional input parameters for additional configuration. Use this to specify:
|
||||
- temperature: Voice temperature control for variability (range: [0, 2], e.g., 1.1, optional)
|
||||
- speaking_rate: Set desired speaking speed (range: [0.5, 1.5], optional)
|
||||
Language is automatically inferred from input text.
|
||||
**kwargs: Additional arguments passed to the parent TTSService class.
|
||||
|
||||
@@ -232,18 +228,15 @@ class InworldTTSService(TTSService):
|
||||
self._settings = {
|
||||
"voiceId": voice_id, # Voice selection from direct parameter
|
||||
"modelId": model, # TTS model selection from direct parameter
|
||||
"audioConfig": { # Audio format configuration
|
||||
"audioEncoding": encoding, # Format: LINEAR16, MP3, etc.
|
||||
"sampleRateHertz": 0, # Will be set in start() from parent service
|
||||
"audio_config": { # Audio format configuration
|
||||
"audio_encoding": encoding, # Format: LINEAR16, MP3, etc.
|
||||
"sample_rate_hertz": 0, # Will be set in start() from parent service
|
||||
},
|
||||
}
|
||||
|
||||
# Add optional temperature parameter if provided (valid range: [0, 2])
|
||||
if params and params.temperature is not None:
|
||||
self._settings["temperature"] = params.temperature
|
||||
# Add optional speaking rate if provided (valid range: [0.5, 1.5])
|
||||
if params and params.speaking_rate is not None:
|
||||
self._settings["audioConfig"]["speakingRate"] = params.speaking_rate
|
||||
|
||||
# Register voice and model with parent service for metrics and tracking
|
||||
self.set_voice(voice_id) # Used for logging and metrics
|
||||
@@ -257,6 +250,15 @@ class InworldTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Inworld TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Inworld's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Inworld TTS service.
|
||||
|
||||
@@ -264,7 +266,7 @@ class InworldTTSService(TTSService):
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._settings["audioConfig"]["sampleRateHertz"] = self.sample_rate
|
||||
self._settings["audio_config"]["sample_rate_hertz"] = self.sample_rate
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Inworld TTS service.
|
||||
@@ -330,7 +332,9 @@ class InworldTTSService(TTSService):
|
||||
"text": text, # Text to synthesize
|
||||
"voiceId": self._settings["voiceId"], # Voice selection (Ashley, Hades, etc.)
|
||||
"modelId": self._settings["modelId"], # TTS model (inworld-tts-1)
|
||||
"audioConfig": self._settings["audioConfig"], # Audio format settings (LINEAR16, 48kHz)
|
||||
"audio_config": self._settings[
|
||||
"audio_config"
|
||||
], # Audio format settings (LINEAR16, 48kHz)
|
||||
}
|
||||
|
||||
# Add optional temperature parameter if configured (valid range: [0, 2])
|
||||
|
||||
@@ -124,6 +124,15 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that LMNT TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that LMNT's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to LMNT service language format.
|
||||
|
||||
|
||||
@@ -40,40 +40,24 @@ def language_to_minimax_language(language: Language) -> Optional[str]:
|
||||
The corresponding MiniMax language name, or None if not supported.
|
||||
"""
|
||||
LANGUAGE_MAP = {
|
||||
Language.AF: "Afrikaans",
|
||||
Language.AR: "Arabic",
|
||||
Language.BG: "Bulgarian",
|
||||
Language.CA: "Catalan",
|
||||
Language.CS: "Czech",
|
||||
Language.DA: "Danish",
|
||||
Language.DE: "German",
|
||||
Language.EL: "Greek",
|
||||
Language.EN: "English",
|
||||
Language.ES: "Spanish",
|
||||
Language.FA: "Persian", # ⚠️ Only supported by speech-2.6-* models
|
||||
Language.FI: "Finnish",
|
||||
Language.FIL: "Filipino", # ⚠️ Only supported by speech-2.6-* models
|
||||
Language.FR: "French",
|
||||
Language.HE: "Hebrew",
|
||||
Language.HI: "Hindi",
|
||||
Language.HR: "Croatian",
|
||||
Language.HU: "Hungarian",
|
||||
Language.ID: "Indonesian",
|
||||
Language.IT: "Italian",
|
||||
Language.JA: "Japanese",
|
||||
Language.KO: "Korean",
|
||||
Language.MS: "Malay",
|
||||
Language.NB: "Norwegian",
|
||||
Language.NN: "Nynorsk",
|
||||
Language.NL: "Dutch",
|
||||
Language.PL: "Polish",
|
||||
Language.PT: "Portuguese",
|
||||
Language.RO: "Romanian",
|
||||
Language.RU: "Russian",
|
||||
Language.SK: "Slovak",
|
||||
Language.SL: "Slovenian",
|
||||
Language.SV: "Swedish",
|
||||
Language.TA: "Tamil", # ⚠️ Only supported by speech-2.6-* models
|
||||
Language.TH: "Thai",
|
||||
Language.TR: "Turkish",
|
||||
Language.UK: "Ukrainian",
|
||||
@@ -100,22 +84,13 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"""Configuration parameters for MiniMax TTS.
|
||||
|
||||
Parameters:
|
||||
language: Language for TTS generation. Supports 40 languages.
|
||||
Note: Filipino, Tamil, and Persian require speech-2.6-* models.
|
||||
language: Language for TTS generation.
|
||||
speed: Speech speed (range: 0.5 to 2.0).
|
||||
volume: Speech volume (range: 0 to 10).
|
||||
pitch: Pitch adjustment (range: -12 to 12).
|
||||
emotion: Emotional tone (options: "happy", "sad", "angry", "fearful",
|
||||
"disgusted", "surprised", "calm", "fluent").
|
||||
english_normalization: Deprecated; use `text_normalization` instead
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
The `english_normalization` parameter is deprecated and will be removed in a future version.
|
||||
Use the `text_normalization` parameter instead.
|
||||
|
||||
text_normalization: Enable text normalization (Chinese/English).
|
||||
latex_read: Enable LaTeX formula reading.
|
||||
exclude_aggregated_audio: Whether to exclude aggregated audio in final chunk.
|
||||
"disgusted", "surprised", "neutral").
|
||||
english_normalization: Whether to apply English text normalization.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
@@ -123,10 +98,7 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
volume: Optional[float] = 1.0
|
||||
pitch: Optional[int] = 0
|
||||
emotion: Optional[str] = None
|
||||
english_normalization: Optional[bool] = None # Deprecated
|
||||
text_normalization: Optional[bool] = None
|
||||
latex_read: Optional[bool] = None
|
||||
exclude_aggregated_audio: Optional[bool] = None
|
||||
english_normalization: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -148,12 +120,9 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
base_url: API base URL, defaults to MiniMax's T2A endpoint.
|
||||
Global: https://api.minimax.io/v1/t2a_v2
|
||||
Mainland China: https://api.minimaxi.chat/v1/t2a_v2
|
||||
Western United States: https://api-uw.minimax.io/v1/t2a_v2
|
||||
group_id: MiniMax Group ID to identify project.
|
||||
model: TTS model name. Defaults to "speech-02-turbo". Options include:
|
||||
"speech-2.6-hd", "speech-2.6-turbo" (latest, supports Filipino/Tamil/Persian),
|
||||
"speech-02-hd", "speech-02-turbo",
|
||||
"speech-01-hd", "speech-01-turbo".
|
||||
model: TTS model name. Defaults to "speech-02-turbo". Options include
|
||||
"speech-02-hd", "speech-02-turbo", "speech-01-hd", "speech-01-turbo".
|
||||
voice_id: Voice identifier. Defaults to "Calm_Woman".
|
||||
aiohttp_session: aiohttp.ClientSession for API communication.
|
||||
sample_rate: Output audio sample rate in Hz. If None, uses pipeline default.
|
||||
@@ -207,34 +176,15 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"disgusted",
|
||||
"surprised",
|
||||
"neutral",
|
||||
"fluent",
|
||||
]
|
||||
if params.emotion in supported_emotions:
|
||||
self._settings["voice_setting"]["emotion"] = params.emotion
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unsupported emotion: {params.emotion}. Supported emotions: {supported_emotions}"
|
||||
)
|
||||
logger.warning(f"Unsupported emotion: {params.emotion}. Using default.")
|
||||
|
||||
# If `english_normalization`, add `text_normalization` and print warning
|
||||
# Add english_normalization if provided
|
||||
if params.english_normalization is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter `english_normalization` is deprecated and will be removed in a future version. Use `text_normalization` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self._settings["voice_setting"]["text_normalization"] = params.english_normalization
|
||||
|
||||
# Add text_normalization if provided (corrected parameter name)
|
||||
if params.text_normalization is not None:
|
||||
self._settings["voice_setting"]["text_normalization"] = params.text_normalization
|
||||
|
||||
# Add latex_read if provided
|
||||
if params.latex_read is not None:
|
||||
self._settings["voice_setting"]["latex_read"] = params.latex_read
|
||||
self._settings["english_normalization"] = params.english_normalization
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -244,6 +194,15 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that MiniMax TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that MiniMax's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to MiniMax service language format.
|
||||
|
||||
@@ -281,7 +240,7 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._settings["audio_setting"]["sample_rate"] = self.sample_rate
|
||||
logger.debug(f"MiniMax TTS initialized with sample_rate: {self.sample_rate}")
|
||||
logger.debug(f"MiniMax TTS initialized with sample rate: {self.sample_rate}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -380,15 +339,11 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
num_channels=1,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Error converting hex to binary: {e}",
|
||||
)
|
||||
logger.error(f"Error converting hex to binary: {e}")
|
||||
continue
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
f"Error decoding JSON: {e}, data: {data_block[:100]}",
|
||||
)
|
||||
logger.error(f"Error decoding JSON: {e}, data: {data_block[:100]}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -151,6 +151,15 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Neuphonic TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Neuphonic's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Neuphonic service language format.
|
||||
|
||||
@@ -440,6 +449,15 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Neuphonic TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Neuphonic's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Neuphonic service language format.
|
||||
|
||||
|
||||
@@ -390,7 +390,9 @@ class BaseOpenAILLMService(LLMService):
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
frame = LLMTextFrame(chunk.choices[0].delta.content)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
|
||||
# we need to get LLMTextFrame for the transcript
|
||||
|
||||
@@ -19,7 +19,6 @@ from pipecat.adapters.services.open_ai_realtime_adapter import (
|
||||
OpenAIRealtimeLLMAdapter,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -679,13 +678,15 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# the output modality is "text"
|
||||
if evt.delta:
|
||||
frame = LLMTextFrame(evt.delta)
|
||||
# OpenAI Realtime text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
# We receive audio transcript deltas (as opposed to text deltas) when
|
||||
# the output modality is "audio" (the default)
|
||||
if evt.delta:
|
||||
frame = TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE)
|
||||
frame = TTSTextFrame(evt.delta)
|
||||
# OpenAI Realtime text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -131,6 +131,15 @@ class OpenAITTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that OpenAI TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that OpenAI's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the TTS model to use.
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -653,7 +652,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(LLMTextFrame(evt.delta))
|
||||
await self.push_frame(TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE))
|
||||
await self.push_frame(TTSTextFrame(evt.delta))
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
|
||||
@@ -66,6 +66,15 @@ class PiperTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Piper TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Piper's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Piper's HTTP API.
|
||||
|
||||
@@ -113,10 +113,6 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
params: Additional configuration parameters.
|
||||
text_aggregator: Custom text aggregator for processing input text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
@@ -127,17 +123,10 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
text_aggregator=text_aggregator or SkipTagsAggregator([("spell(", ")")]),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not text_aggregator:
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Rime spelling is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
|
||||
|
||||
params = params or RimeTTSService.InputParams()
|
||||
|
||||
# Store service configuration
|
||||
@@ -163,7 +152,6 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
self._context_id = None # Tracks current turn
|
||||
self._receive_task = None
|
||||
self._cumulative_time = 0 # Accumulates time across messages
|
||||
self._extra_msg_fields = {} # Extra fields for next message
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -193,31 +181,6 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
self._model = model
|
||||
await super().set_model(model)
|
||||
|
||||
# A set of Rime-specific helpers for text transformations
|
||||
def SPELL(text: str) -> str:
|
||||
"""Wrap text in Rime spell function."""
|
||||
return f"spell({text})"
|
||||
|
||||
def PAUSE_TAG(seconds: float) -> str:
|
||||
"""Convenience method to create a pause tag."""
|
||||
return f"<{seconds * 1000}>"
|
||||
|
||||
def PRONOUNCE(self, text: str, word: str, phoneme: str) -> str:
|
||||
"""Convenience method to support Rime's custom pronunciations feature.
|
||||
|
||||
https://docs.rime.ai/api-reference/custom-pronunciation
|
||||
"""
|
||||
self._extra_msg_fields["phonemizeBetweenBrackets"] = True
|
||||
return text.replace(word, f"{phoneme}")
|
||||
|
||||
def INLINE_SPEED(self, text: str, speed: float) -> str:
|
||||
"""Convenience method to support inline speeds."""
|
||||
if not self._extra_msg_fields:
|
||||
self._extra_msg_fields = {}
|
||||
speed_vals = self._extra_msg_fields.get("inlineSpeedAlpha", "").split(",")
|
||||
self._extra_msg_fields["inlineSpeedAlpha"] = ",".join(speed_vals + [str(speed)])
|
||||
return f"[{text}]"
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
prev_voice = self._voice_id
|
||||
@@ -230,11 +193,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
"""Build JSON message for Rime API."""
|
||||
msg = {"text": text, "contextId": self._context_id}
|
||||
if self._extra_msg_fields:
|
||||
msg |= self._extra_msg_fields
|
||||
self._extra_msg_fields = {}
|
||||
return msg
|
||||
return {"text": text, "contextId": self._context_id}
|
||||
|
||||
def _build_clear_msg(self) -> dict:
|
||||
"""Build clear operation message."""
|
||||
@@ -542,6 +501,15 @@ class RimeHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Rime TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Rime's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
"""Convert pipecat language to Rime language code.
|
||||
|
||||
|
||||
@@ -113,6 +113,15 @@ class RivaTTSService(TTSService):
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Riva TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Riva's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
|
||||
@@ -157,6 +166,7 @@ class RivaTTSService(TTSService):
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
@@ -181,7 +191,6 @@ class RivaTTSService(TTSService):
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -176,7 +176,9 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
frame = LLMTextFrame(chunk.choices[0].delta.content)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
|
||||
# we need to get LLMTextFrame for the transcript
|
||||
|
||||
@@ -195,6 +195,15 @@ class SarvamHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Sarvam TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Sarvam's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Sarvam AI language format.
|
||||
|
||||
@@ -458,6 +467,15 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Sarvam TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Sarvam's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Sarvam AI language format.
|
||||
|
||||
|
||||
@@ -84,10 +84,6 @@ class SimliVideoService(FrameProcessor):
|
||||
Please use 'api_key' and 'face_id' parameters instead.
|
||||
|
||||
use_turn_server: Whether to use TURN server for connection. Defaults to False.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
The 'use_turn_server' parameter is deprecated and will be removed in a future version.
|
||||
|
||||
latency_interval: Latency interval setting for sending health checks to check
|
||||
the latency to Simli Servers. Defaults to 0.
|
||||
simli_url: URL of the simli servers. Can be changed for custom deployments
|
||||
@@ -139,20 +135,14 @@ class SimliVideoService(FrameProcessor):
|
||||
|
||||
config = SimliConfig(**config_kwargs)
|
||||
|
||||
if use_turn_server:
|
||||
warnings.warn(
|
||||
"The 'use_turn_server' parameter is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._initialized = False
|
||||
# Add buffer time to session limits
|
||||
config.maxIdleTime += 5
|
||||
config.maxSessionLength += 5
|
||||
self._simli_client = SimliClient(
|
||||
config=config,
|
||||
latencyInterval=latency_interval,
|
||||
config,
|
||||
use_turn_server,
|
||||
latency_interval,
|
||||
simliURL=simli_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -105,6 +105,15 @@ class SpeechmaticsTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Speechmatics TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Speechmatics's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Speechmatics' HTTP API.
|
||||
|
||||
@@ -38,7 +38,7 @@ class STTService(AIService):
|
||||
|
||||
Event handlers:
|
||||
on_connected: Called when connected to the STT service.
|
||||
on_disconnected: Called when disconnected from the STT service.
|
||||
on_connected: Called when disconnected from the STT service.
|
||||
on_connection_error: Called when a connection to the STT service error occurs.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -12,8 +12,6 @@ from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
@@ -25,8 +23,6 @@ from typing import (
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -105,16 +101,6 @@ class TTSService(AIService):
|
||||
sample_rate: Optional[int] = None,
|
||||
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
# Types of text aggregations that should not be spoken.
|
||||
skip_aggregator_types: Optional[List[str]] = [],
|
||||
# A list of callables to transform text before just before sending it to TTS.
|
||||
# Each callable takes the aggregated text and its type, and returns the transformed text.
|
||||
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
|
||||
text_transforms: Optional[
|
||||
List[
|
||||
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
|
||||
]
|
||||
] = None,
|
||||
# Text filter executed after text has been aggregated.
|
||||
text_filters: Optional[Sequence[BaseTextFilter]] = None,
|
||||
text_filter: Optional[BaseTextFilter] = None,
|
||||
@@ -134,16 +120,6 @@ class TTSService(AIService):
|
||||
pause_frame_processing: Whether to pause frame processing during audio generation.
|
||||
sample_rate: Output sample rate for generated audio.
|
||||
text_aggregator: Custom text aggregator for processing incoming text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
skip_aggregator_types: List of aggregation types that should not be spoken.
|
||||
text_transforms: A list of callables to transform text before just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
(aggregation_type | '*', transform_function).
|
||||
|
||||
text_filters: Sequence of text filters to apply after aggregation.
|
||||
text_filter: Single text filter (deprecated, use text_filters).
|
||||
|
||||
@@ -166,21 +142,6 @@ class TTSService(AIService):
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
if text_aggregator:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'text_aggregator' is deprecated. Use an LLMTextProcessor before the TTSService for custom text aggregation.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
|
||||
self._text_transforms: List[
|
||||
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
|
||||
] = text_transforms or []
|
||||
# TODO: Deprecate _text_filters when added to LLMTextProcessor
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
self._tracing_enabled: bool = False
|
||||
@@ -231,6 +192,23 @@ class TTSService(AIService):
|
||||
CHUNK_SECONDS = 0.5
|
||||
return int(self.sample_rate * CHUNK_SECONDS * 2) # 2 bytes/sample
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates whether TTSTextFrames include necesary inter-frame spaces.
|
||||
|
||||
When True, the TTSTextFrame objects pushed by this service already
|
||||
include all necessary spaces between subsequent frames. When False,
|
||||
downstream processors (like the assistant context aggregator) may need
|
||||
to add spacing.
|
||||
|
||||
Subclasses should override this property to return True if their text
|
||||
generation process already includes necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
False by default. Subclasses can override to return True.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the TTS model to use.
|
||||
|
||||
@@ -320,39 +298,6 @@ class TTSService(AIService):
|
||||
await self.cancel_task(self._stop_frame_task)
|
||||
self._stop_frame_task = None
|
||||
|
||||
def add_text_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Transform text for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to apply for transformation. This function should take
|
||||
the text and aggregation type as input and return the transformed text.
|
||||
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
|
||||
aggregation_type: The type of aggregation to transform. This value defaults to "*" indicating
|
||||
the function should handle all text before sending to TTS.
|
||||
"""
|
||||
self._text_transforms.append((aggregation_type, transform_function))
|
||||
|
||||
def remove_text_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Remove a text transformer for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to remove.
|
||||
aggregation_type: The type of aggregation to remove the transformer for.
|
||||
"""
|
||||
self._text_transforms = [
|
||||
(agg_type, func)
|
||||
for agg_type, func in self._text_transforms
|
||||
if not (agg_type == aggregation_type and func == transform_function)
|
||||
]
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
@@ -408,8 +353,6 @@ class TTSService(AIService):
|
||||
and frame.skip_tts
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AggregatedTextFrame):
|
||||
await self._push_tts_frames(frame)
|
||||
elif (
|
||||
isinstance(frame, TextFrame)
|
||||
and not isinstance(frame, InterimTranscriptionFrame)
|
||||
@@ -425,16 +368,10 @@ class TTSService(AIService):
|
||||
# pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
|
||||
pending_aggregation = self._text_aggregator.text
|
||||
|
||||
# Reset aggregator state
|
||||
sentence = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
self._processing_text = False
|
||||
|
||||
if pending_aggregation.text:
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(pending_aggregation.text, pending_aggregation.type)
|
||||
)
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -443,8 +380,7 @@ class TTSService(AIService):
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
# Store if we were processing text or not so we can set it back.
|
||||
processing_text = self._processing_text
|
||||
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
|
||||
await self._push_tts_frames(AggregatedTextFrame(frame.text, AggregationType.SENTENCE))
|
||||
await self._push_tts_frames(frame.text)
|
||||
# We pause processing incoming frames because we are sending data to
|
||||
# the TTS. We pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
@@ -534,35 +470,15 @@ class TTSService(AIService):
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: Optional[str] = None
|
||||
includes_inter_frame_spaces: bool = False
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
includes_inter_frame_spaces = frame.includes_inter_frame_spaces
|
||||
aggregated_by = "token"
|
||||
else:
|
||||
aggregate = await self._text_aggregator.aggregate(frame.text)
|
||||
if aggregate:
|
||||
text = aggregate.text
|
||||
aggregated_by = aggregate.type
|
||||
text = await self._text_aggregator.aggregate(frame.text)
|
||||
|
||||
if text:
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
|
||||
async def _push_tts_frames(
|
||||
self, src_frame: AggregatedTextFrame, includes_inter_frame_spaces: Optional[bool] = False
|
||||
):
|
||||
type = src_frame.aggregated_by
|
||||
text = src_frame.text
|
||||
|
||||
# Skip sending to TTS if the aggregation type is in the skip list. Simply
|
||||
# push the original frame downstream.
|
||||
if type in self._skip_aggregator_types:
|
||||
await self.push_frame(src_frame)
|
||||
return
|
||||
await self._push_tts_frames(text)
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
# Remove leading newlines only
|
||||
text = text.lstrip("\n")
|
||||
|
||||
@@ -578,45 +494,21 @@ class TTSService(AIService):
|
||||
|
||||
await self.start_processing_metrics()
|
||||
|
||||
# Process all filters.
|
||||
# Process all filter.
|
||||
for filter in self._text_filters:
|
||||
await filter.reset_interruption()
|
||||
text = await filter.filter(text)
|
||||
|
||||
if not text.strip():
|
||||
await self.stop_processing_metrics()
|
||||
return
|
||||
|
||||
# To support use cases that may want to know the text before it's spoken, we
|
||||
# push the AggregatedTextFrame version before transforming and sending to TTS.
|
||||
# However, we do not want to add this text to the assistant context until it
|
||||
# is spoken, so we set append_to_context to False.
|
||||
src_frame.append_to_context = False
|
||||
await self.push_frame(src_frame)
|
||||
|
||||
# Note: Text transformations are meant to only affect the text sent to the TTS for
|
||||
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
|
||||
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
|
||||
# services that support word-level timestamps, this CAN affect the resulting context
|
||||
# since the TTSTextFrames are generated from the TTS output stream
|
||||
transformed_text = text
|
||||
for aggregation_type, transform in self._text_transforms:
|
||||
if aggregation_type == type or aggregation_type == "*":
|
||||
transformed_text = await transform(transformed_text, type)
|
||||
await self.process_generator(self.run_tts(transformed_text))
|
||||
if text:
|
||||
await self.process_generator(self.run_tts(text))
|
||||
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
if self._push_text_frames:
|
||||
# In TTS services that support word timestamps, the TTSTextFrames
|
||||
# are pushed as words are spoken. However, in the case where the TTS service
|
||||
# does not support word timestamps (i.e. _push_text_frames is True), we send
|
||||
# the original (non-transformed) text after the TTS generation has completed.
|
||||
# This way, if we are interrupted, the text is not added to the assistant
|
||||
# context and the context that IS added does not include TTS-specific tags
|
||||
# or transformations.
|
||||
frame = TTSTextFrame(text, aggregated_by=type)
|
||||
frame.includes_inter_frame_spaces = includes_inter_frame_spaces
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
frame = TTSTextFrame(text)
|
||||
frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _stop_frame_handler(self):
|
||||
@@ -743,9 +635,7 @@ class WordTTSService(TTSService):
|
||||
frame = TTSStoppedFrame()
|
||||
frame.pts = last_pts
|
||||
else:
|
||||
# Assumption: word-by-word text frames don't include spaces, so
|
||||
# we can rely on the default includes_inter_frame_spaces=False
|
||||
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
|
||||
frame = TTSTextFrame(word)
|
||||
frame.pts = self._initial_word_timestamp + timestamp
|
||||
if frame:
|
||||
last_pts = frame.pts
|
||||
|
||||
@@ -36,7 +36,6 @@ class WebsocketService(ABC):
|
||||
"""
|
||||
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._reconnect_on_error = reconnect_on_error
|
||||
self._reconnect_in_progress: bool = False # Add this flag
|
||||
|
||||
async def _verify_connection(self) -> bool:
|
||||
"""Verify the websocket connection is active and responsive.
|
||||
@@ -67,59 +66,6 @@ class WebsocketService(ABC):
|
||||
await self._connect_websocket()
|
||||
return await self._verify_connection()
|
||||
|
||||
async def _try_reconnect(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
report_error: Optional[Callable[[ErrorFrame], Awaitable[None]]] = None,
|
||||
) -> bool:
|
||||
# Prevent concurrent reconnection attempts
|
||||
if self._reconnect_in_progress:
|
||||
logger.warning(f"{self} reconnect attempt aborted: already in progress")
|
||||
return False
|
||||
|
||||
self._reconnect_in_progress = True
|
||||
last_exception: Optional[Exception] = None
|
||||
try:
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
logger.warning(f"{self} reconnecting, attempt {attempt}")
|
||||
if await self._reconnect_websocket(attempt):
|
||||
logger.info(f"{self} reconnected successfully on attempt {attempt}")
|
||||
return True
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.error(f"{self} reconnection attempt {attempt} failed: {e}")
|
||||
if report_error:
|
||||
await report_error(
|
||||
ErrorFrame(f"{self} reconnection attempt {attempt} failed: {e}")
|
||||
)
|
||||
wait_time = exponential_backoff_time(attempt)
|
||||
await asyncio.sleep(wait_time)
|
||||
fatal_msg = f"{self} failed to reconnect after {max_retries} attempts"
|
||||
if last_exception:
|
||||
fatal_msg += f": {last_exception}"
|
||||
logger.error(fatal_msg)
|
||||
if report_error:
|
||||
await report_error(ErrorFrame(fatal_msg, fatal=True))
|
||||
return False
|
||||
finally:
|
||||
self._reconnect_in_progress = False
|
||||
|
||||
async def send_with_retry(self, message, report_error: Callable[[ErrorFrame], Awaitable[None]]):
|
||||
"""Attempt to send a message, retrying after reconnect if necessary."""
|
||||
try:
|
||||
await self._websocket.send(message)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} send failed: {e}, will try to reconnect")
|
||||
# Try to reconnect before retrying
|
||||
success = await self._try_reconnect(report_error=report_error)
|
||||
if success:
|
||||
logger.info(f"{self} reconnected successfully, will retry send the message")
|
||||
# trying to send the message one more time
|
||||
await self._websocket.send(message)
|
||||
else:
|
||||
logger.error(f"{self} send failed; unable to reconnect")
|
||||
|
||||
async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]):
|
||||
"""Handle websocket message receiving with automatic retry logic.
|
||||
|
||||
@@ -130,9 +76,13 @@ class WebsocketService(ABC):
|
||||
Args:
|
||||
report_error: Callback function to report connection errors.
|
||||
"""
|
||||
retry_count = 0
|
||||
MAX_RETRIES = 3
|
||||
|
||||
while True:
|
||||
try:
|
||||
await self._receive_messages()
|
||||
retry_count = 0 # Reset counter on successful message receive
|
||||
except ConnectionClosedOK as e:
|
||||
# Normal closure, don't retry
|
||||
logger.debug(f"{self} connection closed normally: {e}")
|
||||
@@ -142,9 +92,21 @@ class WebsocketService(ABC):
|
||||
logger.error(message)
|
||||
|
||||
if self._reconnect_on_error:
|
||||
success = await self._try_reconnect(report_error=report_error)
|
||||
if not success:
|
||||
retry_count += 1
|
||||
if retry_count >= MAX_RETRIES:
|
||||
await report_error(ErrorFrame(message))
|
||||
break
|
||||
|
||||
logger.warning(f"{self} connection error, will retry: {e}")
|
||||
await report_error(ErrorFrame(message))
|
||||
|
||||
try:
|
||||
if await self._reconnect_websocket(retry_count):
|
||||
retry_count = 0 # Reset counter on successful reconnection
|
||||
wait_time = exponential_backoff_time(retry_count)
|
||||
await asyncio.sleep(wait_time)
|
||||
except Exception as reconnect_error:
|
||||
logger.error(f"{self} reconnection failed: {reconnect_error}")
|
||||
else:
|
||||
await report_error(ErrorFrame(message))
|
||||
break
|
||||
|
||||
@@ -203,16 +203,8 @@ async def run_test(
|
||||
if not isinstance(frame, EndFrame) or not send_end_frame:
|
||||
received_down_frames.append(frame)
|
||||
|
||||
down_frames_printed = "["
|
||||
for frame in received_down_frames:
|
||||
down_frames_printed += f"{frame.__class__.__name__}, "
|
||||
down_frames_printed += "]"
|
||||
expected_frames_printed = "["
|
||||
for frame in expected_down_frames:
|
||||
expected_frames_printed += f"{frame.__name__}, "
|
||||
expected_frames_printed += "]"
|
||||
print("received DOWN frames =", down_frames_printed)
|
||||
print("expected DOWN frames =", expected_frames_printed)
|
||||
print("received DOWN frames =", received_down_frames)
|
||||
print("expected DOWN frames =", expected_down_frames)
|
||||
|
||||
assert len(received_down_frames) == len(expected_down_frames)
|
||||
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""LiveKit REST Helpers.
|
||||
|
||||
Methods that wrap the LiveKit API for room management.
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class LiveKitRESTHelper:
|
||||
"""Helper class for interacting with LiveKit's REST API.
|
||||
|
||||
Provides methods for managing LiveKit rooms.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_secret: str,
|
||||
api_url: str = "https://your-livekit-host.com",
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
):
|
||||
"""Initialize the LiveKit REST helper.
|
||||
|
||||
Args:
|
||||
api_key: Your LiveKit API key.
|
||||
api_secret: Your LiveKit API secret.
|
||||
api_url: LiveKit server URL (e.g. "https://your-livekit-host.com").
|
||||
aiohttp_session: Async HTTP session for making requests.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
self.api_url = api_url.rstrip("/")
|
||||
self.aiohttp_session = aiohttp_session
|
||||
|
||||
def _create_access_token(self, room_create: bool = True) -> str:
|
||||
"""Create a signed access token for LiveKit API authentication.
|
||||
|
||||
Args:
|
||||
room_create: Whether to grant roomCreate permission.
|
||||
|
||||
Returns:
|
||||
Signed JWT access token.
|
||||
"""
|
||||
import time
|
||||
|
||||
import jwt
|
||||
|
||||
claims = {
|
||||
"iss": self.api_key,
|
||||
"sub": self.api_key,
|
||||
"nbf": int(time.time()),
|
||||
"exp": int(time.time()) + 60, # Token valid for 60 seconds
|
||||
"video": {
|
||||
"roomCreate": room_create,
|
||||
},
|
||||
}
|
||||
|
||||
return jwt.encode(claims, self.api_secret, algorithm="HS256")
|
||||
|
||||
async def delete_room_by_name(self, room_name: str) -> bool:
|
||||
"""Delete a LiveKit room by name.
|
||||
|
||||
This will forcibly disconnect all participants currently in the room.
|
||||
|
||||
Args:
|
||||
room_name: Name of the room to delete.
|
||||
|
||||
Returns:
|
||||
True if deletion was successful.
|
||||
|
||||
Raises:
|
||||
Exception: If deletion fails.
|
||||
"""
|
||||
token = self._create_access_token(room_create=True)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with self.aiohttp_session.post(
|
||||
f"{self.api_url}/twirp/livekit.RoomService/DeleteRoom",
|
||||
headers=headers,
|
||||
json={"room": room_name},
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
raise Exception(f"Failed to delete room [{room_name}] (status: {r.status}): {text}")
|
||||
|
||||
return True
|
||||
@@ -18,7 +18,6 @@ Dependencies:
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import FrozenSet, List, Optional, Sequence, Tuple
|
||||
|
||||
import nltk
|
||||
@@ -199,24 +198,7 @@ def parse_start_end_tags(
|
||||
return (None, current_tag_index)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextPartForConcatenation:
|
||||
"""Class representing a part of text for concatenation with concatenate_aggregated_text.
|
||||
|
||||
Attributes:
|
||||
text: The text content.
|
||||
includes_inter_part_spaces: Whether any necessary inter-frame
|
||||
(leading/trailing) spaces are already included in the text.
|
||||
"""
|
||||
|
||||
text: str
|
||||
includes_inter_part_spaces: bool
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(text: [{self.text}], includes_inter_part_spaces: {self.includes_inter_part_spaces})"
|
||||
|
||||
|
||||
def concatenate_aggregated_text(text_parts: List[TextPartForConcatenation]) -> str:
|
||||
def concatenate_aggregated_text(text_parts: List[str], add_spaces: bool) -> str:
|
||||
"""Concatenate a list of text parts into a single string.
|
||||
|
||||
This function joins the provided list of text parts into a single string,
|
||||
@@ -226,55 +208,15 @@ def concatenate_aggregated_text(text_parts: List[TextPartForConcatenation]) -> s
|
||||
transcription services.
|
||||
|
||||
Args:
|
||||
text_parts: A list of text parts to concatenate.
|
||||
text_parts: A list of strings representing parts of text to concatenate.
|
||||
add_spaces: Whether to add spaces between text parts during concatenation.
|
||||
|
||||
Returns:
|
||||
A single concatenated string.
|
||||
"""
|
||||
result = ""
|
||||
last_includes_inter_part_spaces = False
|
||||
|
||||
if not text_parts:
|
||||
return result
|
||||
|
||||
def append_part(part: TextPartForConcatenation):
|
||||
nonlocal result
|
||||
nonlocal last_includes_inter_part_spaces
|
||||
result += part.text
|
||||
last_includes_inter_part_spaces = part.includes_inter_part_spaces
|
||||
|
||||
for part in text_parts:
|
||||
# Part is empty.
|
||||
# Skip.
|
||||
if not part.text:
|
||||
continue
|
||||
|
||||
# Result is as yet empty.
|
||||
# Just append.
|
||||
if not result:
|
||||
append_part(part)
|
||||
continue
|
||||
|
||||
if part.includes_inter_part_spaces and last_includes_inter_part_spaces:
|
||||
# This part is part of an ongoing run that has spaces already included.
|
||||
# Just append.
|
||||
append_part(part)
|
||||
elif not part.includes_inter_part_spaces and not last_includes_inter_part_spaces:
|
||||
# This part is part of an ongoing run that has no spaces included.
|
||||
# Add a space before appending.
|
||||
result += " "
|
||||
append_part(part)
|
||||
else:
|
||||
# This part represents a transition to a new run (spaces -> no spaces, or vice versa).
|
||||
# Add a space if needed, before appending.
|
||||
if not result[-1].isspace() and not part.text[0].isspace():
|
||||
result += " "
|
||||
append_part(part)
|
||||
|
||||
# NOTE: the above logic assumes that runs of text parts with
|
||||
# includes_inter_part_spaces=True are well-formed, i.e. they're not
|
||||
# actually multiple separate runs with a space-less boundary, like
|
||||
# "hello ", "world.", "goodnight ", "moon."
|
||||
# Concatenate text parts with or without spaces based on the flag
|
||||
separator = " " if add_spaces else ""
|
||||
result = separator.join(text_parts)
|
||||
|
||||
# Clean up any excessive whitespace
|
||||
result = result.strip()
|
||||
|
||||
@@ -12,47 +12,9 @@ aggregated text should be sent for speech synthesis.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aggregation:
|
||||
"""Data class representing aggregated text and its type.
|
||||
|
||||
An Aggregation object is created whenever a stream of text is aggregated by
|
||||
a text aggregator. It contains the aggregated text and a type indicating
|
||||
the nature of the aggregation.
|
||||
|
||||
Parameters:
|
||||
text: The aggregated text content.
|
||||
type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token',
|
||||
'my_custom_aggregation').
|
||||
"""
|
||||
|
||||
text: str
|
||||
type: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the aggregation.
|
||||
|
||||
Returns:
|
||||
A descriptive string showing the type and text of the aggregation.
|
||||
"""
|
||||
return f"Aggregation by {self.type}: {self.text}"
|
||||
|
||||
|
||||
class BaseTextAggregator(ABC):
|
||||
"""Base class for text aggregators in the Pipecat framework.
|
||||
|
||||
@@ -68,7 +30,7 @@ class BaseTextAggregator(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def text(self) -> Aggregation:
|
||||
def text(self) -> str:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Subclasses must implement this property to return the text that has
|
||||
@@ -80,33 +42,25 @@ class BaseTextAggregator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
"""Aggregate the specified text with the currently accumulated text.
|
||||
|
||||
This method should be implemented to define how the new text contributes
|
||||
to the aggregation process. It returns the aggregated text and a string
|
||||
describing how it was aggregated if it's ready to be processed,
|
||||
or None otherwise.
|
||||
to the aggregation process. It returns the updated aggregated text if
|
||||
it's ready to be processed, or None otherwise.
|
||||
|
||||
Subclasses should implement their specific logic for:
|
||||
|
||||
- How to combine new text with existing accumulated text
|
||||
- When to consider the aggregated text ready for processing
|
||||
- What criteria determine text completion (e.g., sentence boundaries)
|
||||
- When a completion occurs, the method should return an Aggregation object
|
||||
containing the aggregated text and its type. The text should be stripped
|
||||
of leading/trailing whitespace so that consumers can rely on a consistent
|
||||
format.
|
||||
|
||||
Args:
|
||||
text: The text to be aggregated.
|
||||
|
||||
Returns:
|
||||
An Aggregation object if ready for processing, or None if more
|
||||
text is needed before the aggregated content is ready. If an Aggregation
|
||||
object is returned, it should consist of the updated aggregated text,
|
||||
stripped of leading/trailing whitespace, and a string indicating the
|
||||
type of aggregation (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
|
||||
The updated aggregated text if ready for processing, or None if more
|
||||
text is needed before the aggregated content is ready.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class BaseTextFilter(ABC):
|
||||
behavior, settings management, and interruption handling logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update the filter's configuration settings.
|
||||
|
||||
@@ -52,6 +53,7 @@ class BaseTextFilter(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruption events in the processing pipeline.
|
||||
|
||||
@@ -60,6 +62,7 @@ class BaseTextFilter(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def reset_interruption(self):
|
||||
"""Reset the filter state after an interruption has been handled.
|
||||
|
||||
|
||||
@@ -8,41 +8,19 @@
|
||||
|
||||
This module provides an aggregator that identifies and processes content between
|
||||
pattern pairs (like XML tags or custom delimiters) in streaming text, with
|
||||
support for custom handlers and configurable actions for when a pattern is found.
|
||||
support for custom handlers and configurable pattern removal.
|
||||
"""
|
||||
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, List, Optional, Tuple
|
||||
from typing import Awaitable, Callable, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
|
||||
|
||||
class MatchAction(Enum):
|
||||
"""Actions to take when a pattern pair is matched.
|
||||
|
||||
Parameters:
|
||||
REMOVE: The text along with its delimiters will be removed from the streaming text.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
KEEP: The delimiters will be removed, but the content between them will be kept.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
AGGREGATE: The delimiters will be removed and the content between will be treated
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
"""
|
||||
|
||||
REMOVE = "remove"
|
||||
KEEP = "keep"
|
||||
AGGREGATE = "aggregate"
|
||||
|
||||
|
||||
class PatternMatch(Aggregation):
|
||||
class PatternMatch:
|
||||
"""Represents a matched pattern pair with its content.
|
||||
|
||||
A PatternMatch object is created when a complete pattern pair is found
|
||||
@@ -51,25 +29,25 @@ class PatternMatch(Aggregation):
|
||||
content between the patterns.
|
||||
"""
|
||||
|
||||
def __init__(self, content: str, type: str, full_match: str):
|
||||
def __init__(self, pattern_id: str, full_match: str, content: str):
|
||||
"""Initialize a pattern match.
|
||||
|
||||
Args:
|
||||
type: The type of the matched pattern pair. It should be representative
|
||||
of the content type (e.g., 'sentence', 'code', 'speaker', 'custom').
|
||||
pattern_id: The identifier of the matched pattern pair.
|
||||
full_match: The complete text including start and end patterns.
|
||||
content: The text content between the start and end patterns.
|
||||
"""
|
||||
super().__init__(text=content, type=type)
|
||||
self.pattern_id = pattern_id
|
||||
self.full_match = full_match
|
||||
self.content = content
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the pattern match.
|
||||
|
||||
Returns:
|
||||
A descriptive string showing the pattern type and content.
|
||||
A descriptive string showing the pattern ID and content.
|
||||
"""
|
||||
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
|
||||
return f"PatternMatch(id={self.pattern_id}, content={self.content})"
|
||||
|
||||
|
||||
class PatternPairAggregator(BaseTextAggregator):
|
||||
@@ -77,21 +55,16 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
This aggregator buffers text until it can identify complete pattern pairs
|
||||
(defined by start and end patterns), processes the content between these
|
||||
patterns using registered handlers. By default, its aggregation method
|
||||
returns text at sentence boundaries, and remove the content found between
|
||||
any matched patterns. However, matched patterns can also be configured to
|
||||
returned as a separate aggregation object containing the content between
|
||||
their start and end patterns or left in, so that only the delimiters are
|
||||
removed and a callback can be triggered.
|
||||
|
||||
This aggregator is particularly useful for processing structured content in
|
||||
streaming text, such as XML tags, markdown formatting, or custom delimiters.
|
||||
patterns using registered handlers, and returns text at sentence boundaries.
|
||||
It's particularly useful for processing structured content in streaming text,
|
||||
such as XML tags, markdown formatting, or custom delimiters.
|
||||
|
||||
The aggregator ensures that patterns spanning multiple text chunks are
|
||||
correctly identified.
|
||||
correctly identified and handles cases where patterns contain sentence
|
||||
boundaries.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self):
|
||||
"""Initialize the pattern pair aggregator.
|
||||
|
||||
Creates an empty aggregator with no patterns or handlers registered.
|
||||
@@ -102,27 +75,16 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._handlers = {}
|
||||
|
||||
@property
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
def text(self) -> str:
|
||||
"""Get the currently buffered text.
|
||||
|
||||
Returns:
|
||||
The text that has been accumulated in the buffer.
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
"""
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
stripped_text = self._text.strip()
|
||||
type = (
|
||||
pattern_start[1].get("type", AggregationType.SENTENCE)
|
||||
if pattern_start
|
||||
else AggregationType.SENTENCE
|
||||
)
|
||||
return Aggregation(text=stripped_text, type=type)
|
||||
return self._text
|
||||
|
||||
def add_pattern(
|
||||
self,
|
||||
type: str,
|
||||
start_pattern: str,
|
||||
end_pattern: str,
|
||||
action: MatchAction = MatchAction.REMOVE,
|
||||
def add_pattern_pair(
|
||||
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
|
||||
) -> "PatternPairAggregator":
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
@@ -131,94 +93,41 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
the end pattern, and treat the content between them as a match.
|
||||
|
||||
Args:
|
||||
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
|
||||
those are reserved for the default behavior.
|
||||
pattern_id: Unique identifier for this pattern pair.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
action: What to do when a complete pattern is matched:
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
|
||||
normal text. This allows you to register handlers for
|
||||
the pattern without affecting the aggregation logic.
|
||||
- MatchAction.AGGREGATE: Return the matched pattern as a separate
|
||||
aggregation object.
|
||||
remove_match: Whether to remove the matched content from the text.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
|
||||
raise ValueError(
|
||||
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
|
||||
)
|
||||
self._patterns[type] = {
|
||||
self._patterns[pattern_id] = {
|
||||
"start": start_pattern,
|
||||
"end": end_pattern,
|
||||
"type": type,
|
||||
"action": action,
|
||||
"remove_match": remove_match,
|
||||
}
|
||||
return self
|
||||
|
||||
def add_pattern_pair(
|
||||
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
|
||||
):
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use `add_pattern` with a type and MatchAction instead.
|
||||
|
||||
This method calls `add_pattern` setting type with the provided pattern_id and action
|
||||
to either MatchAction.REMOVE or MatchAction.KEEP based on `remove_match`.
|
||||
|
||||
Args:
|
||||
pattern_id: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). pattern_id can not be 'sentence' or 'word'
|
||||
as those arereserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
remove_match: If True, the matched pattern will be removed from the text. (Same as MatchAction.REMOVE)
|
||||
If False, it will be kept and treated as normal text. (Same as MatchAction.KEEP)
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once")
|
||||
warnings.warn(
|
||||
"add_pattern_pair with a pattern_id or remove_match is deprecated and will be"
|
||||
" removed in a future version. Use add_pattern with a type and MatchAction instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
action = MatchAction.REMOVE if remove_match else MatchAction.KEEP
|
||||
return self.add_pattern(
|
||||
type=pattern_id,
|
||||
start_pattern=start_pattern,
|
||||
end_pattern=end_pattern,
|
||||
action=action,
|
||||
)
|
||||
|
||||
def on_pattern_match(
|
||||
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
) -> "PatternPairAggregator":
|
||||
"""Register a handler for when a pattern pair is matched.
|
||||
|
||||
The handler will be called whenever a complete match for the
|
||||
specified type is found in the text.
|
||||
specified pattern ID is found in the text.
|
||||
|
||||
Args:
|
||||
type: The type of the pattern pair to trigger the handler.
|
||||
pattern_id: ID of the pattern pair to match.
|
||||
handler: Async function to call when pattern is matched.
|
||||
The function should accept a PatternMatch object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._handlers[type] = handler
|
||||
self._handlers[pattern_id] = handler
|
||||
return self
|
||||
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[str, bool]:
|
||||
"""Process all complete pattern pairs in the text.
|
||||
|
||||
Searches for all complete pattern pairs in the text, calls the
|
||||
@@ -228,19 +137,19 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
text: The text to process for pattern matches.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_matches, processed_text) where:
|
||||
Tuple of (processed_text, was_modified) where:
|
||||
|
||||
- all_matches is a list of all pattern matches found. Note: There really should only ever be 1.
|
||||
- processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text.
|
||||
- processed_text is the text after processing patterns
|
||||
- was_modified indicates whether any changes were made
|
||||
"""
|
||||
all_matches = []
|
||||
processed_text = text
|
||||
modified = False
|
||||
|
||||
for type, pattern_info in self._patterns.items():
|
||||
for pattern_id, pattern_info in self._patterns.items():
|
||||
# Escape special regex characters in the patterns
|
||||
start = re.escape(pattern_info["start"])
|
||||
end = re.escape(pattern_info["end"])
|
||||
action = pattern_info["action"]
|
||||
remove_match = pattern_info["remove_match"]
|
||||
|
||||
# Create regex to match from start pattern to end pattern
|
||||
# The .*? is non-greedy to handle nested patterns
|
||||
@@ -256,25 +165,24 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
# Create pattern match object
|
||||
pattern_match = PatternMatch(
|
||||
content=content.strip(), type=type, full_match=full_match
|
||||
pattern_id=pattern_id, full_match=full_match, content=content
|
||||
)
|
||||
|
||||
# Call the appropriate handler if registered
|
||||
if type in self._handlers:
|
||||
if pattern_id in self._handlers:
|
||||
try:
|
||||
await self._handlers[type](pattern_match)
|
||||
await self._handlers[pattern_id](pattern_match)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pattern handler for {type}: {e}")
|
||||
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
|
||||
|
||||
# Remove the pattern from the text if configured
|
||||
if action == MatchAction.REMOVE:
|
||||
if remove_match:
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
else:
|
||||
all_matches.append(pattern_match)
|
||||
modified = True
|
||||
|
||||
return all_matches, processed_text
|
||||
return processed_text, modified
|
||||
|
||||
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
|
||||
def _has_incomplete_patterns(self, text: str) -> bool:
|
||||
"""Check if text contains incomplete pattern pairs.
|
||||
|
||||
Determines whether the text contains any start patterns without
|
||||
@@ -284,10 +192,9 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
text: The text to check for incomplete patterns.
|
||||
|
||||
Returns:
|
||||
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
|
||||
or None if no patterns are found or all patterns are complete.
|
||||
True if there are incomplete patterns, False otherwise.
|
||||
"""
|
||||
for type, pattern_info in self._patterns.items():
|
||||
for pattern_id, pattern_info in self._patterns.items():
|
||||
start = pattern_info["start"]
|
||||
end = pattern_info["end"]
|
||||
|
||||
@@ -296,16 +203,12 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
end_count = text.count(end)
|
||||
|
||||
# If there are more starts than ends, we have incomplete patterns
|
||||
# Again, this is written generically but there only ever should
|
||||
# be one pattern active at a time, so the counts should be 0 or 1.
|
||||
# Which is why we base the return on the first found.
|
||||
if start_count > end_count:
|
||||
start_index = text.find(start)
|
||||
return [start_index, pattern_info]
|
||||
return True
|
||||
|
||||
return None
|
||||
return False
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[PatternMatch]:
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
"""Aggregate text and process pattern pairs.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete pattern
|
||||
@@ -324,36 +227,16 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._text += text
|
||||
|
||||
# Process any complete patterns in the buffer
|
||||
patterns, processed_text = await self._process_complete_patterns(self._text)
|
||||
processed_text, modified = await self._process_complete_patterns(self._text)
|
||||
|
||||
self._text = processed_text
|
||||
|
||||
if len(patterns) > 0:
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
return patterns[0]
|
||||
# Only update the buffer if modifications were made
|
||||
if modified:
|
||||
self._text = processed_text
|
||||
|
||||
# Check if we have incomplete patterns
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, return None
|
||||
if (
|
||||
pattern_start[0] == 0
|
||||
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
|
||||
):
|
||||
return None
|
||||
# Otherwise, strip the text up to the start pattern and return it
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
return PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
)
|
||||
if self._has_incomplete_patterns(self._text):
|
||||
# Still waiting for complete patterns
|
||||
return None
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
@@ -361,9 +244,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
)
|
||||
return result
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
@@ -14,7 +14,7 @@ text processing scenarios.
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
|
||||
|
||||
class SimpleTextAggregator(BaseTextAggregator):
|
||||
@@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
self._text = ""
|
||||
|
||||
@property
|
||||
def text(self) -> Aggregation:
|
||||
def text(self) -> str:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Returns:
|
||||
The text that has been accumulated in the buffer.
|
||||
"""
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
return self._text
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
"""Aggregate text and return completed sentences.
|
||||
|
||||
Adds the new text to the buffer and checks for end-of-sentence markers.
|
||||
@@ -64,9 +64,7 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
result = self._text[:eos_end_marker]
|
||||
self._text = self._text[eos_end_marker:]
|
||||
|
||||
if result:
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
return None
|
||||
return result
|
||||
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruptions by clearing the text buffer.
|
||||
|
||||
@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
|
||||
|
||||
class SkipTagsAggregator(BaseTextAggregator):
|
||||
@@ -43,15 +43,15 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
self._current_tag_index: int = 0
|
||||
|
||||
@property
|
||||
def text(self) -> Aggregation:
|
||||
def text(self) -> str:
|
||||
"""Get the currently buffered text.
|
||||
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
"""
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
return self._text
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
"""Aggregate text while respecting tag boundaries.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete
|
||||
@@ -63,9 +63,8 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
text: New text to add to the buffer.
|
||||
|
||||
Returns:
|
||||
An Aggregation object containing text up to a sentence boundary and
|
||||
marked as SENTENCE type or None if more text is needed to complete a
|
||||
sentence or close tags.
|
||||
Processed text up to a sentence boundary (when not within tags),
|
||||
or None if more text is needed to complete a sentence or close tags.
|
||||
"""
|
||||
# Add new text to buffer
|
||||
self._text += text
|
||||
@@ -81,7 +80,7 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
return result
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
@@ -1005,53 +1005,3 @@ class TestLLMAssistantAggregator(
|
||||
) -> Optional[LLMAssistantAggregatorParams]:
|
||||
kwargs.pop("expect_stripped_words", None)
|
||||
return LLMAssistantAggregatorParams(**kwargs) if kwargs else None
|
||||
|
||||
async def test_multiple_text_mixed(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=self.create_assistant_aggregator_params(expect_stripped_words=False)
|
||||
)
|
||||
|
||||
# The newer LLMAssistantAggregator expects TextFrames to declare
|
||||
# when they include inter-frame spaces.
|
||||
def make_text_frame(text: str, includes_spaces: bool) -> TextFrame:
|
||||
frame = TextFrame(text=text)
|
||||
frame.includes_inter_frame_spaces = includes_spaces
|
||||
return frame
|
||||
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
make_text_frame("Hello ", includes_spaces=True),
|
||||
make_text_frame("Pipecat. ", includes_spaces=True),
|
||||
make_text_frame("Here's some", includes_spaces=True),
|
||||
make_text_frame(
|
||||
" code:", includes_spaces=True
|
||||
), # Validates ending includes_inter_frame_spaces run with no space
|
||||
make_text_frame("```python\nprint('Hello, World!')\n```", includes_spaces=False),
|
||||
make_text_frame(
|
||||
"```javascript\nconsole.log('Hello, World!');\n```", includes_spaces=False
|
||||
),
|
||||
make_text_frame(
|
||||
" And some more: ", includes_spaces=True
|
||||
), # Validates starting includes_inter_frame_spaces run with a space and ending it with no space
|
||||
make_text_frame("```html\n<div>Hello, World!</div>\n```", includes_spaces=False),
|
||||
make_text_frame(
|
||||
"Hope that ", includes_spaces=True
|
||||
), # Validates starting includes_inter_frame_spaces run with no space
|
||||
make_text_frame("helps!", includes_spaces=True),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(
|
||||
context,
|
||||
0,
|
||||
"Hello Pipecat. Here's some code: ```python\nprint('Hello, World!')\n``` ```javascript\nconsole.log('Hello, World!');\n``` And some more: ```html\n<div>Hello, World!</div>\n``` Hope that helps!",
|
||||
)
|
||||
|
||||
@@ -7,42 +7,30 @@
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
|
||||
|
||||
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.aggregator = PatternPairAggregator()
|
||||
self.test_handler = AsyncMock()
|
||||
self.code_handler = AsyncMock()
|
||||
|
||||
# Add a test pattern
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="test_pattern",
|
||||
start_pattern="<test>",
|
||||
end_pattern="</test>",
|
||||
)
|
||||
self.aggregator.add_pattern(
|
||||
type="code_pattern",
|
||||
start_pattern="<code>",
|
||||
end_pattern="</code>",
|
||||
action=MatchAction.AGGREGATE,
|
||||
remove_match=True,
|
||||
)
|
||||
|
||||
# Register the mock handler
|
||||
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
|
||||
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
|
||||
|
||||
async def test_pattern_match_and_removal(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
self.assertEqual(self.aggregator.text, "Hello <test>pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</test>!")
|
||||
@@ -51,50 +39,20 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.type, "test_pattern")
|
||||
self.assertEqual(call_args.pattern_id, "test_pattern")
|
||||
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
self.assertEqual(call_args.content, "pattern content")
|
||||
|
||||
# The exclamation point should be treated as a sentence boundary,
|
||||
# so the result should include just text up to and including "!"
|
||||
self.assertEqual(result.text, "Hello !")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Next sentence should be processed separately. Spaces around the sentence
|
||||
# should be stripped in the returned Aggregation.
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
self.assertEqual(result.text, "This is another sentence.")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_pattern_match_and_aggregate(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Here is code <code>pattern")
|
||||
self.assertEqual(result.text, "Here is code")
|
||||
self.assertEqual(self.aggregator.text.text, "<code>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "code_pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</code>")
|
||||
|
||||
# Verify the handler was called with correct PatternMatch object
|
||||
self.code_handler.assert_called_once()
|
||||
call_args = self.code_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.type, "code_pattern")
|
||||
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
self.assertEqual(result.text, "pattern content")
|
||||
self.assertEqual(result.type, "code_pattern")
|
||||
self.assertEqual(result, "Hello !")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
self.assertEqual(result.text, "This is another sentence.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(result, " This is another sentence.")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
|
||||
async def test_incomplete_pattern(self):
|
||||
# Add text with incomplete pattern
|
||||
@@ -107,30 +65,26 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_not_called()
|
||||
|
||||
# Buffer should contain the incomplete text
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
self.assertEqual(self.aggregator.text, "Hello <test>pattern content")
|
||||
|
||||
# Reset and confirm buffer is cleared
|
||||
await self.aggregator.reset()
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
|
||||
async def test_multiple_patterns(self):
|
||||
# Set up multiple patterns and handlers
|
||||
voice_handler = AsyncMock()
|
||||
emphasis_handler = AsyncMock()
|
||||
|
||||
self.aggregator.add_pattern(
|
||||
type="voice",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
action=MatchAction.REMOVE,
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
|
||||
)
|
||||
|
||||
self.aggregator.add_pattern(
|
||||
type="emphasis",
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="emphasis",
|
||||
start_pattern="<em>",
|
||||
end_pattern="</em>",
|
||||
action=MatchAction.KEEP, # Keep emphasis tags
|
||||
remove_match=False, # Keep emphasis tags
|
||||
)
|
||||
|
||||
self.aggregator.on_pattern_match("voice", voice_handler)
|
||||
@@ -143,19 +97,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Both handlers should be called with correct data
|
||||
voice_handler.assert_called_once()
|
||||
voice_match = voice_handler.call_args[0][0]
|
||||
self.assertEqual(voice_match.type, "voice")
|
||||
self.assertEqual(voice_match.text, "female")
|
||||
self.assertEqual(voice_match.pattern_id, "voice")
|
||||
self.assertEqual(voice_match.content, "female")
|
||||
|
||||
emphasis_handler.assert_called_once()
|
||||
emphasis_match = emphasis_handler.call_args[0][0]
|
||||
self.assertEqual(emphasis_match.type, "emphasis")
|
||||
self.assertEqual(emphasis_match.text, "very")
|
||||
self.assertEqual(emphasis_match.pattern_id, "emphasis")
|
||||
self.assertEqual(emphasis_match.content, "very")
|
||||
|
||||
# Voice pattern should be removed, emphasis pattern should remain
|
||||
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
|
||||
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
|
||||
|
||||
# Buffer should be empty
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
|
||||
async def test_handle_interruption(self):
|
||||
# Start with incomplete pattern
|
||||
@@ -166,7 +120,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
await self.aggregator.handle_interruption()
|
||||
|
||||
# Buffer should be cleared
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
|
||||
# Handler should not have been called
|
||||
self.test_handler.assert_not_called()
|
||||
@@ -184,10 +138,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Handler should be called with entire content
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
|
||||
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
|
||||
|
||||
# Pattern should be removed, resulting in text with sentences merged
|
||||
self.assertEqual(result.text, "Hello Final sentence.")
|
||||
self.assertEqual(result, "Hello Final sentence.")
|
||||
|
||||
# Buffer should be empty
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
|
||||
@@ -13,7 +13,6 @@ import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
ErrorFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
@@ -75,7 +74,6 @@ async def test_run_piper_tts_success(aiohttp_client):
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
AggregatedTextFrame,
|
||||
TTSStartedFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
@@ -123,7 +121,7 @@ async def test_run_piper_tts_error(aiohttp_client):
|
||||
TTSSpeakFrame(text="Error case."),
|
||||
]
|
||||
|
||||
expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame]
|
||||
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
|
||||
|
||||
expected_up_frames = [ErrorFrame]
|
||||
|
||||
|
||||
@@ -15,21 +15,15 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_reset_aggregations(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
assert self.aggregator.text.text == "Hello"
|
||||
assert self.aggregator.text == "Hello "
|
||||
await self.aggregator.reset()
|
||||
assert self.aggregator.text.text == ""
|
||||
assert self.aggregator.text == ""
|
||||
|
||||
async def test_simple_sentence(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
aggregate = await self.aggregator.aggregate("Pipecat!")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
assert aggregate.type == "sentence"
|
||||
assert self.aggregator.text.text == ""
|
||||
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
|
||||
assert self.aggregator.text == ""
|
||||
|
||||
async def test_multiple_sentences(self):
|
||||
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
# Aggregators should strip leading/trailing spaces when returning text
|
||||
assert self.aggregator.text.text == "How are"
|
||||
aggregate = await self.aggregator.aggregate("you?")
|
||||
assert aggregate.text == "How are you?"
|
||||
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
|
||||
assert await self.aggregator.aggregate("you?") == " How are you?"
|
||||
|
||||
@@ -18,18 +18,16 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# No tags involved, aggregate at end of sentence.
|
||||
result = await self.aggregator.aggregate("Hello Pipecat!")
|
||||
self.assertEqual(result.text, "Hello Pipecat!")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(result, "Hello Pipecat!")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
|
||||
async def test_basic_tags(self):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# Tags involved, avoid aggregation during tags.
|
||||
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(result, "My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
|
||||
async def test_streaming_tags(self):
|
||||
await self.aggregator.reset()
|
||||
@@ -37,22 +35,20 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Tags involved, stream small chunk of texts.
|
||||
result = await self.aggregator.aggregate("My email is <sp")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <sp")
|
||||
self.assertEqual(self.aggregator.text, "My email is <sp")
|
||||
|
||||
result = await self.aggregator.aggregate("ell>foo.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.")
|
||||
|
||||
result = await self.aggregator.aggregate("bar@pipecat.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.")
|
||||
|
||||
result = await self.aggregator.aggregate("ai</spe")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.ai</spe")
|
||||
|
||||
result = await self.aggregator.aggregate("ll>.")
|
||||
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
self.assertEqual(result, "My email is <spell>foo.bar@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
|
||||
@@ -11,7 +11,6 @@ from datetime import datetime, timezone
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -131,11 +130,11 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(), # Wait for StartedSpeaking to process
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="How", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="are", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="you?", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world!"),
|
||||
TTSTextFrame(text="How"),
|
||||
TTSTextFrame(text="are"),
|
||||
TTSTextFrame(text="you?"),
|
||||
SleepFrame(), # Wait for text frames to queue
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
@@ -196,9 +195,9 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="", aggregated_by=AggregationType.WORD), # Empty text
|
||||
TTSTextFrame(text=" ", aggregated_by=AggregationType.WORD), # Just whitespace
|
||||
TTSTextFrame(text="\n", aggregated_by=AggregationType.WORD), # Just newline
|
||||
TTSTextFrame(text=""), # Empty text
|
||||
TTSTextFrame(text=" "), # Just whitespace
|
||||
TTSTextFrame(text="\n"), # Just newline
|
||||
BotStoppedSpeakingFrame(),
|
||||
# Pipeline ends here; run_test will automatically send EndFrame
|
||||
]
|
||||
@@ -236,14 +235,14 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world!"),
|
||||
SleepFrame(),
|
||||
InterruptionFrame(), # User interrupts here
|
||||
SleepFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="New", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="response", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="New"),
|
||||
TTSTextFrame(text="response"),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
@@ -300,8 +299,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world"),
|
||||
# Pipeline ends here; run_test will automatically send EndFrame
|
||||
]
|
||||
|
||||
@@ -339,8 +338,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world"),
|
||||
SleepFrame(), # Ensure messages are processed
|
||||
CancelFrame(),
|
||||
]
|
||||
@@ -402,8 +401,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Assistant", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="message", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="Assistant"),
|
||||
TTSTextFrame(text="message"),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
@@ -440,7 +439,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Test the specific pattern shared
|
||||
def make_tts_text_frame(text: str) -> TTSTextFrame:
|
||||
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.WORD)
|
||||
frame = TTSTextFrame(text=text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
return frame
|
||||
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.transcript_processor import TurnAwareTranscriptProcessor
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
class TestTurnAwareTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for TurnAwareTranscriptProcessor."""
|
||||
|
||||
async def test_basic_turn_flow(self):
|
||||
"""Test basic turn start/end with user and assistant speech."""
|
||||
processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Track events
|
||||
turn_started_calls = []
|
||||
turn_ended_calls = []
|
||||
|
||||
@processor.event_handler("on_turn_started")
|
||||
async def on_turn_started(proc, turn_number):
|
||||
turn_started_calls.append(turn_number)
|
||||
|
||||
@processor.event_handler("on_turn_ended")
|
||||
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
|
||||
turn_ended_calls.append(
|
||||
{
|
||||
"turn_number": turn_number,
|
||||
"user_text": user_text,
|
||||
"assistant_text": assistant_text,
|
||||
"interrupted": interrupted,
|
||||
}
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
# Turn 1: User speaks, bot responds
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.01), # Allow transcription to process
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="Hi", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text=" there", aggregated_by=AggregationType.WORD),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1),
|
||||
]
|
||||
|
||||
await run_test(processor, frames_to_send=frames_to_send)
|
||||
|
||||
# Verify events
|
||||
self.assertEqual(
|
||||
len(turn_started_calls), 1, f"Expected 1 turn started, got {len(turn_started_calls)}"
|
||||
)
|
||||
self.assertEqual(turn_started_calls[0], 1)
|
||||
|
||||
self.assertEqual(
|
||||
len(turn_ended_calls), 1, f"Expected 1 turn ended, got {len(turn_ended_calls)}"
|
||||
)
|
||||
self.assertEqual(turn_ended_calls[0]["turn_number"], 1)
|
||||
self.assertEqual(turn_ended_calls[0]["user_text"], "Hello")
|
||||
self.assertEqual(turn_ended_calls[0]["assistant_text"], "Hi there")
|
||||
self.assertFalse(turn_ended_calls[0]["interrupted"])
|
||||
|
||||
async def test_interruption(self):
|
||||
"""Test turn ending on interruption."""
|
||||
processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Track events
|
||||
turn_ended_calls = []
|
||||
|
||||
@processor.event_handler("on_turn_ended")
|
||||
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
|
||||
turn_ended_calls.append(
|
||||
{
|
||||
"turn_number": turn_number,
|
||||
"user_text": user_text,
|
||||
"assistant_text": assistant_text,
|
||||
"interrupted": interrupted,
|
||||
}
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
# User speaks
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Tell me", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.01), # Allow transcription to process
|
||||
# Bot starts responding
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="Sure", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text=" I", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text=" can", aggregated_by=AggregationType.WORD),
|
||||
# User interrupts
|
||||
InterruptionFrame(),
|
||||
# New turn starts
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Wait", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.1),
|
||||
]
|
||||
|
||||
await run_test(processor, frames_to_send=frames_to_send)
|
||||
|
||||
# Verify first turn was interrupted
|
||||
self.assertGreaterEqual(
|
||||
len(turn_ended_calls), 1, f"Expected at least 1 turn ended, got {len(turn_ended_calls)}"
|
||||
)
|
||||
first_turn = turn_ended_calls[0]
|
||||
self.assertEqual(first_turn["user_text"], "Tell me")
|
||||
# Note: In this test flow, InterruptionFrame arrives before TTSTextFrames are processed,
|
||||
# so assistant text may be empty. In real scenarios, word timestamps ensure proper capture.
|
||||
self.assertIn(first_turn["assistant_text"], ["", "Sure I can", "Sure I can"])
|
||||
self.assertTrue(first_turn["interrupted"])
|
||||
|
||||
async def test_multiple_turns(self):
|
||||
"""Test multiple back-and-forth turns."""
|
||||
processor = TurnAwareTranscriptProcessor()
|
||||
|
||||
# Track events
|
||||
turn_started_calls = []
|
||||
turn_ended_calls = []
|
||||
|
||||
@processor.event_handler("on_turn_started")
|
||||
async def on_turn_started(proc, turn_number):
|
||||
turn_started_calls.append(turn_number)
|
||||
|
||||
@processor.event_handler("on_turn_ended")
|
||||
async def on_turn_ended(proc, turn_number, user_text, assistant_text, interrupted):
|
||||
turn_ended_calls.append(
|
||||
{
|
||||
"turn_number": turn_number,
|
||||
"user_text": user_text,
|
||||
"assistant_text": assistant_text,
|
||||
}
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
# Turn 1
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hi", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.01), # Allow transcription to process
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.05),
|
||||
# Turn 2
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="How are you", user_id="user1", timestamp=""),
|
||||
SleepFrame(sleep=0.01), # Allow transcription to process
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="I'm", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text=" good", aggregated_by=AggregationType.WORD),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1),
|
||||
]
|
||||
|
||||
await run_test(processor, frames_to_send=frames_to_send)
|
||||
|
||||
# Verify multiple turns
|
||||
self.assertEqual(
|
||||
len(turn_started_calls), 2, f"Expected 2 turns started, got {len(turn_started_calls)}"
|
||||
)
|
||||
self.assertEqual(turn_started_calls, [1, 2])
|
||||
|
||||
self.assertEqual(
|
||||
len(turn_ended_calls), 2, f"Expected 2 turns ended, got {len(turn_ended_calls)}"
|
||||
)
|
||||
self.assertEqual(turn_ended_calls[0]["turn_number"], 1)
|
||||
self.assertEqual(turn_ended_calls[0]["user_text"], "Hi")
|
||||
self.assertEqual(turn_ended_calls[0]["assistant_text"], "Hello")
|
||||
|
||||
self.assertEqual(turn_ended_calls[1]["turn_number"], 2)
|
||||
self.assertEqual(turn_ended_calls[1]["user_text"], "How are you")
|
||||
self.assertEqual(turn_ended_calls[1]["assistant_text"], "I'm good")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
113
uv.lock
generated
113
uv.lock
generated
@@ -45,20 +45,20 @@ sdist = { url = "https://files.pythonhosted.org/packages/99/83/bf38b95d98c67b8eb
|
||||
|
||||
[[package]]
|
||||
name = "aioboto3"
|
||||
version = "15.5.0"
|
||||
version = "15.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiobotocore", extra = ["boto3"] },
|
||||
{ name = "aiofiles" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/01/92e9ab00f36e2899315f49eefcd5b4685fbb19016c7f19a9edf06da80bb0/aioboto3-15.5.0.tar.gz", hash = "sha256:ea8d8787d315594842fbfcf2c4dce3bac2ad61be275bc8584b2ce9a3402a6979", size = 255069, upload-time = "2025-10-30T13:37:16.122Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/80/d0/ed107e16551ba1b93ddcca9a6bf79580450945268a8bc396530687b3189f/aioboto3-15.0.0.tar.gz", hash = "sha256:dce40b701d1f8e0886dc874d27cd9799b8bf6b32d63743f57e7bef7e4a562756", size = 225278, upload-time = "2025-06-26T16:30:48.967Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/3e/e8f5b665bca646d43b916763c901e00a07e40f7746c9128bdc912a089424/aioboto3-15.5.0-py3-none-any.whl", hash = "sha256:cc880c4d6a8481dd7e05da89f41c384dbd841454fc1998ae25ca9c39201437a6", size = 35913, upload-time = "2025-10-30T13:37:14.549Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/95/d69c744f408e5e4592fe53ed98fc244dd13b83d84cf1f83b2499d98bfcc9/aioboto3-15.0.0-py3-none-any.whl", hash = "sha256:9cf54b3627c8b34bb82eaf43ab327e7027e37f92b1e10dd5cfe343cd512568d0", size = 35785, upload-time = "2025-06-26T16:30:47.444Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiobotocore"
|
||||
version = "2.25.1"
|
||||
version = "2.23.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiohttp" },
|
||||
@@ -69,9 +69,9 @@ dependencies = [
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/62/94/2e4ec48cf1abb89971cb2612d86f979a6240520f0a659b53a43116d344dc/aiobotocore-2.25.1.tar.gz", hash = "sha256:ea9be739bfd7ece8864f072ec99bb9ed5c7e78ebb2b0b15f29781fbe02daedbc", size = 120560, upload-time = "2025-10-28T22:33:21.787Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9d/25/4b06ea1214ddf020a28df27dc7136ac9dfaf87929d51e6f6044dd350ed67/aiobotocore-2.23.0.tar.gz", hash = "sha256:0333931365a6c7053aee292fe6ef50c74690c4ae06bb019afdf706cb6f2f5e32", size = 115825, upload-time = "2025-06-12T23:46:38.055Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/95/2a/d275ec4ce5cd0096665043995a7d76f5d0524853c76a3d04656de49f8808/aiobotocore-2.25.1-py3-none-any.whl", hash = "sha256:eb6daebe3cbef5b39a0bb2a97cffbe9c7cb46b2fcc399ad141f369f3c2134b1f", size = 86039, upload-time = "2025-10-28T22:33:19.949Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/43/ccf9b29669cdb09fd4bfc0a8effeb2973b22a0f3c3be4142d0b485975d11/aiobotocore-2.23.0-py3-none-any.whl", hash = "sha256:8202cebbf147804a083a02bc282fbfda873bfdd0065fd34b64784acb7757b66e", size = 84161, upload-time = "2025-06-12T23:46:36.305Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -419,30 +419,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "aws-sdk-bedrock-runtime"
|
||||
version = "0.2.0"
|
||||
version = "0.1.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "smithy-aws-core", extra = ["eventstream", "json"], marker = "python_full_version >= '3.12'" },
|
||||
{ name = "smithy-core", marker = "python_full_version >= '3.12'" },
|
||||
{ name = "smithy-http", extra = ["awscrt"], marker = "python_full_version >= '3.12'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/db/94/f2451bb09c106e5690bbb88fc366637cdcec942b352ed9bb788804c877e0/aws_sdk_bedrock_runtime-0.2.0.tar.gz", hash = "sha256:8de52dd4492e74c73244d4b41a52304e1db368814a10e49dbbf8f4e8e412cd0e", size = 88156, upload-time = "2025-11-22T00:35:44.978Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1d/78/48574454b3cac869df67665e4a403ebfc3abfcfba2c2ff01ccfd67d55f8f/aws_sdk_bedrock_runtime-0.1.1.tar.gz", hash = "sha256:c896f99e675c3a1ab600633a07b785f3dc9fe8ab94f640b1f992b63da2dfc784", size = 82446, upload-time = "2025-10-21T20:25:25.845Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/6b/07fbddd31dd6e38c967fe088b5e91a7cc3a2bc0f645f18b4e5d45bc03f1f/aws_sdk_bedrock_runtime-0.2.0-py3-none-any.whl", hash = "sha256:19594de50a52d199d73efca153c0a2328bd781827715a6e012d50b11085236cc", size = 79875, upload-time = "2025-11-22T00:35:44.092Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-sdk-sagemaker-runtime-http2"
|
||||
version = "0.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "smithy-aws-core", extra = ["eventstream", "json"], marker = "python_full_version >= '3.12'" },
|
||||
{ name = "smithy-core", marker = "python_full_version >= '3.12'" },
|
||||
{ name = "smithy-http", extra = ["awscrt"], marker = "python_full_version >= '3.12'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6e/ca/00f9c55887fc0f3fa345995dd871d40ff81473ab1591e56b4b4483d99d00/aws_sdk_sagemaker_runtime_http2-0.1.0.tar.gz", hash = "sha256:5077ec0c4440495b15004bbf04e27bc0bc137f1f8950d32195c6b45d7788d837", size = 20863, upload-time = "2025-11-22T00:20:56.358Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/24/2e2f727c51c20f4625cd19364d9421dbd7c893fe2b53a46eb0caaf6263a2/aws_sdk_sagemaker_runtime_http2-0.1.0-py3-none-any.whl", hash = "sha256:1aebb728ba6c6d14e58e29ecf89b51f7abbe8786d34144f8a7d59a419e80bd2f", size = 21911, upload-time = "2025-11-22T00:20:55.054Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/83/07/62c0b70223d178c138f29124ac2f7973a6ba803abc7735b6a01a85217f3d/aws_sdk_bedrock_runtime-0.1.1-py3-none-any.whl", hash = "sha256:c0336b377b2112cf88197d3d44302fbeb3efb1101989fa49ae55e78f49cfe345", size = 74954, upload-time = "2025-10-21T20:25:24.973Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -620,30 +606,30 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.40.61"
|
||||
version = "1.38.27"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "botocore" },
|
||||
{ name = "jmespath" },
|
||||
{ name = "s3transfer" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ed/f9/6ef8feb52c3cce5ec3967a535a6114b57ac7949fd166b0f3090c2b06e4e5/boto3-1.40.61.tar.gz", hash = "sha256:d6c56277251adf6c2bdd25249feae625abe4966831676689ff23b4694dea5b12", size = 111535, upload-time = "2025-10-28T19:26:57.247Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e7/96/fc74d8521d2369dd8c412438401ff12e1350a1cd3eab5c758ed3dd5e5f82/boto3-1.38.27.tar.gz", hash = "sha256:94bd7fdd92d5701b362d4df100d21e28f8307a67ff56b6a8b0398119cf22f859", size = 111875, upload-time = "2025-05-30T19:32:41.352Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/61/24/3bf865b07d15fea85b63504856e137029b6acbc73762496064219cdb265d/boto3-1.40.61-py3-none-any.whl", hash = "sha256:6b9c57b2a922b5d8c17766e29ed792586a818098efe84def27c8f582b33f898c", size = 139321, upload-time = "2025-10-28T19:26:55.007Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/8b/b2361188bd1e293eede1bc165e2461d390394f71ec0c8c21211c8dabf62c/boto3-1.38.27-py3-none-any.whl", hash = "sha256:95f5fe688795303a8a15e8b7e7f255cadab35eae459d00cc281a4fd77252ea80", size = 139938, upload-time = "2025-05-30T19:32:38.006Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.40.61"
|
||||
version = "1.38.27"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jmespath" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/28/a3/81d3a47c2dbfd76f185d3b894f2ad01a75096c006a2dd91f237dca182188/botocore-1.40.61.tar.gz", hash = "sha256:a2487ad69b090f9cccd64cf07c7021cd80ee9c0655ad974f87045b02f3ef52cd", size = 14393956, upload-time = "2025-10-28T19:26:46.108Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/36/5e/67899214ad57f7f26af5bd776ac5eb583dc4ecf5c1e52e2cbfdc200e487a/botocore-1.38.27.tar.gz", hash = "sha256:9788f7efe974328a38cbade64cc0b1e67d27944b899f88cb786ae362973133b6", size = 13919963, upload-time = "2025-05-30T19:32:29.657Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/38/c5/f6ce561004db45f0b847c2cd9b19c67c6bf348a82018a48cb718be6b58b0/botocore-1.40.61-py3-none-any.whl", hash = "sha256:17ebae412692fd4824f99cde0f08d50126dc97954008e5ba2b522eb049238aa7", size = 14055973, upload-time = "2025-10-28T19:26:42.15Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/83/a753562020b69fa90cebc39e8af2c753b24dcdc74bee8355ee3f6cefdf34/botocore-1.38.27-py3-none-any.whl", hash = "sha256:a785d5e9a5eda88ad6ab9ed8b87d1f2ac409d0226bba6ff801c55359e94d91a8", size = 13580545, upload-time = "2025-05-30T19:32:26.712Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1305,13 +1291,13 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "daily-python"
|
||||
version = "0.22.0"
|
||||
version = "0.21.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/db/d6c311ba760123a9987a9ae291171c9a6af11ee4dbefdb661d65e2ac13a2/daily_python-0.22.0-cp37-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2ef7591a7929c5d9e5ea78329ea049d2f313bd3d2d289f5f4ecce4bb3799c3d0", size = 13526264, upload-time = "2025-11-20T05:52:04.134Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/47/f1f6d893e7aab4b2e3d3b20d0dd8fbf31c7a71597d274aae1d288e36fac3/daily_python-0.22.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:0fbc8467c471ce536dc6214a24cc28cbc38ff61113b1714e09d0eafc2741fc5a", size = 12041400, upload-time = "2025-11-20T05:52:06.419Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/c9/8f26944cd55ece2ab9c076fae5c1fcf4fdc8639ea6f2b861566d26ad9e00/daily_python-0.22.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:b8aa9531b0fa2852b41935697d184be86318eaee3b35f49e0b5714e53cdb524a", size = 14147194, upload-time = "2025-11-20T05:52:08.4Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/02/d86f4cee39bcdb112d83e2cf12345d7a974cbad5dafb350788148644f16b/daily_python-0.22.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:8f49fceee92830aaf53a65053f332504b2cc62af79c38edb9bce8707c9fa4b0c", size = 14654605, upload-time = "2025-11-20T05:52:10.639Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/11/99590f8b7aad077f3f9b5b59d39b010aee0bd01b14dece8ae1e93d8080e7/daily_python-0.21.0-cp37-abi3-macosx_10_15_x86_64.whl", hash = "sha256:bdec96417825181559769bb2258ae688d1215949a1878336194e36fb452274a8", size = 13277066, upload-time = "2025-10-29T00:20:49.523Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/db/8c57f1a1b713ba3393584ac2be32d8074d3022a2c2c17c28eb4cd2aa3629/daily_python-0.21.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:18677fa1415a0dc48b891cdf2fb8fe9dabc70e1b019d5aaa3d0699ccc8d187c9", size = 11644908, upload-time = "2025-10-29T00:20:52.106Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/64/b6/b03f2f58a367d6ef4bb728715471542fdfa68afa8a177670139c3a2aadb7/daily_python-0.21.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:97eb97352fe74227061b678e330b8befcfa4c694feb6eb2b09fe6eacec00ad6d", size = 13652356, upload-time = "2025-10-29T00:20:54.813Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/76/bde65f6f8d4c1679dc6c185fa37dae9223f6ddb4b7ced728ef46504956f7/daily_python-0.21.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:68c3e36f609fc2fce79e4d17ecf1021eadd836506db6c5125f95c682bcf3612a", size = 14304643, upload-time = "2025-10-29T00:20:57.194Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4522,7 +4508,6 @@ langchain = [
|
||||
livekit = [
|
||||
{ name = "livekit" },
|
||||
{ name = "livekit-api" },
|
||||
{ name = "pyjwt" },
|
||||
{ name = "tenacity" },
|
||||
]
|
||||
lmnt = [
|
||||
@@ -4584,9 +4569,6 @@ runner = [
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
sagemaker = [
|
||||
{ name = "aws-sdk-sagemaker-runtime-http2", marker = "python_full_version >= '3.12'" },
|
||||
]
|
||||
sarvam = [
|
||||
{ name = "sarvamai" },
|
||||
{ name = "websockets" },
|
||||
@@ -4666,18 +4648,17 @@ docs = [
|
||||
requires-dist = [
|
||||
{ name = "accelerate", marker = "extra == 'moondream'", specifier = "~=1.10.0" },
|
||||
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.1.0" },
|
||||
{ name = "aioboto3", marker = "extra == 'aws'", specifier = "~=15.5.0" },
|
||||
{ name = "aioboto3", marker = "extra == 'aws'", specifier = "~=15.0.0" },
|
||||
{ name = "aiofiles", specifier = ">=24.1.0,<25" },
|
||||
{ name = "aiohttp", specifier = ">=3.11.12,<4" },
|
||||
{ name = "aiortc", marker = "extra == 'webrtc'", specifier = ">=1.13.0,<2" },
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.49.0" },
|
||||
{ name = "audioop-lts", marker = "python_full_version >= '3.13'", specifier = "~=0.2.1" },
|
||||
{ name = "aws-sdk-bedrock-runtime", marker = "python_full_version >= '3.12' and extra == 'aws-nova-sonic'", specifier = "~=0.2.0" },
|
||||
{ name = "aws-sdk-sagemaker-runtime-http2", marker = "python_full_version >= '3.12' and extra == 'sagemaker'" },
|
||||
{ name = "aws-sdk-bedrock-runtime", marker = "python_full_version >= '3.12' and extra == 'aws-nova-sonic'", specifier = "~=0.1.1" },
|
||||
{ name = "azure-cognitiveservices-speech", marker = "extra == 'azure'", specifier = "~=1.42.0" },
|
||||
{ name = "cartesia", marker = "extra == 'cartesia'", specifier = "~=2.0.3" },
|
||||
{ name = "coremltools", marker = "extra == 'local-smart-turn'", specifier = ">=8.0" },
|
||||
{ name = "daily-python", marker = "extra == 'daily'", specifier = "~=0.22.0" },
|
||||
{ name = "daily-python", marker = "extra == 'daily'", specifier = "~=0.21.0" },
|
||||
{ name = "deepgram-sdk", marker = "extra == 'deepgram'", specifier = "~=4.7.0" },
|
||||
{ name = "docstring-parser", specifier = "~=0.16" },
|
||||
{ name = "einops", marker = "extra == 'moondream'", specifier = "~=0.8.0" },
|
||||
@@ -4740,14 +4721,13 @@ requires-dist = [
|
||||
{ name = "pyaudio", marker = "extra == 'local'", specifier = "~=0.2.14" },
|
||||
{ name = "pydantic", specifier = ">=2.10.6,<3" },
|
||||
{ name = "pygobject", marker = "extra == 'gstreamer'", specifier = "~=3.50.0" },
|
||||
{ name = "pyjwt", marker = "extra == 'livekit'", specifier = ">=2.10.1" },
|
||||
{ name = "pyloudnorm", specifier = "~=0.1.1" },
|
||||
{ name = "python-dotenv", marker = "extra == 'runner'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "pyvips", extras = ["binary"], marker = "extra == 'moondream'", specifier = "~=3.0.0" },
|
||||
{ name = "resampy", specifier = "~=0.4.3" },
|
||||
{ name = "sarvamai", marker = "extra == 'sarvam'", specifier = "==0.1.21" },
|
||||
{ name = "sentry-sdk", marker = "extra == 'sentry'", specifier = ">=2.28.0,<3" },
|
||||
{ name = "simli-ai", marker = "extra == 'simli'", specifier = "~=1.0.3" },
|
||||
{ name = "simli-ai", marker = "extra == 'simli'", specifier = "~=0.1.25" },
|
||||
{ name = "soundfile", marker = "extra == 'soundfile'", specifier = "~=0.13.1" },
|
||||
{ name = "soxr", specifier = "~=0.5.0" },
|
||||
{ name = "speechmatics-rt", marker = "extra == 'speechmatics'", specifier = ">=0.5.0" },
|
||||
@@ -4765,7 +4745,7 @@ requires-dist = [
|
||||
{ name = "wait-for2", marker = "python_full_version < '3.12'", specifier = ">=0.4.1" },
|
||||
{ name = "websockets", marker = "extra == 'websockets-base'", specifier = ">=13.1,<16.0" },
|
||||
]
|
||||
provides-extras = ["aic", "anthropic", "assemblyai", "asyncai", "aws", "aws-nova-sonic", "azure", "cartesia", "cerebras", "daily", "deepgram", "deepseek", "elevenlabs", "fal", "fireworks", "fish", "gladia", "google", "grok", "groq", "gstreamer", "heygen", "hume", "inworld", "koala", "krisp", "langchain", "livekit", "lmnt", "local", "local-smart-turn", "local-smart-turn-v3", "mcp", "mem0", "mistral", "mlx-whisper", "moondream", "neuphonic", "nim", "noisereduce", "openai", "openpipe", "openrouter", "perplexity", "playht", "qwen", "remote-smart-turn", "rime", "riva", "runner", "sagemaker", "sambanova", "sarvam", "sentry", "silero", "simli", "soniox", "soundfile", "speechmatics", "strands", "tavus", "together", "tracing", "ultravox", "webrtc", "websocket", "websockets-base", "whisper"]
|
||||
provides-extras = ["aic", "anthropic", "assemblyai", "asyncai", "aws", "aws-nova-sonic", "azure", "cartesia", "cerebras", "deepseek", "daily", "deepgram", "elevenlabs", "fal", "fireworks", "fish", "gladia", "google", "grok", "groq", "gstreamer", "heygen", "hume", "inworld", "krisp", "koala", "langchain", "livekit", "lmnt", "local", "mcp", "mem0", "mistral", "mlx-whisper", "moondream", "nim", "neuphonic", "noisereduce", "openai", "openpipe", "openrouter", "perplexity", "playht", "qwen", "rime", "riva", "runner", "sambanova", "sarvam", "sentry", "local-smart-turn", "local-smart-turn-v3", "remote-smart-turn", "silero", "simli", "soniox", "soundfile", "speechmatics", "strands", "tavus", "together", "tracing", "ultravox", "webrtc", "websocket", "websockets-base", "whisper"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
@@ -6222,14 +6202,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "s3transfer"
|
||||
version = "0.14.0"
|
||||
version = "0.13.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "botocore" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/62/74/8d69dcb7a9efe8baa2046891735e5dfe433ad558ae23d9e3c14c633d1d58/s3transfer-0.14.0.tar.gz", hash = "sha256:eff12264e7c8b4985074ccce27a3b38a485bb7f7422cc8046fee9be4983e4125", size = 151547, upload-time = "2025-09-09T19:23:31.089Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6d/05/d52bf1e65044b4e5e27d4e63e8d1579dbdec54fce685908ae09bc3720030/s3transfer-0.13.1.tar.gz", hash = "sha256:c3fdba22ba1bd367922f27ec8032d6a1cf5f10c934fb5d68cf60fd5a23d936cf", size = 150589, upload-time = "2025-07-18T19:22:42.31Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/48/f0/ae7ca09223a81a1d890b2557186ea015f6e0502e9b8cb8e1813f1d8cfa4e/s3transfer-0.14.0-py3-none-any.whl", hash = "sha256:ea3b790c7077558ed1f02a3072fb3cb992bbbd253392f4b6e9e8976941c7d456", size = 85712, upload-time = "2025-09-09T19:23:30.041Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/4f/d073e09df851cfa251ef7840007d04db3293a0482ce607d2b993926089be/s3transfer-0.13.1-py3-none-any.whl", hash = "sha256:a981aa7429be23fe6dfc13e80e4020057cbab622b08c0315288758d67cabc724", size = 85308, upload-time = "2025-07-18T19:22:40.947Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6516,19 +6496,18 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "simli-ai"
|
||||
version = "1.0.3"
|
||||
version = "0.1.25"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiortc" },
|
||||
{ name = "av" },
|
||||
{ name = "httpx" },
|
||||
{ name = "livekit" },
|
||||
{ name = "numpy" },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/81/03/b0b3e12c68fd3f9c57f6afeee67841349e4866b88760f413357af3043ae4/simli_ai-1.0.3.tar.gz", hash = "sha256:e96b0621a1dbd9582b2ae3d51eefd4995983b49c1f1061eb9239707b15a1ee27", size = 13350, upload-time = "2025-11-13T12:22:32.514Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/64/6a/b28f90baf76f6a60865985f6233ff44abc72d45b66b76658bff3961e20a7/simli_ai-0.1.25.tar.gz", hash = "sha256:7a00b3426dc26a6a421641072c3e49014b7950c621cf4544152f35c58d13fcff", size = 13182, upload-time = "2025-11-06T16:27:08.862Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/d1/dc382ba529de0d2d51f35e9bfd20b41d8f5c96404a3aa24bae97a5a5e51f/simli_ai-1.0.3-py3-none-any.whl", hash = "sha256:ffafa7540aa28833e207be8f3b199367c7f500dac1a8ba0108395bfb7d8362bc", size = 13863, upload-time = "2025-11-13T12:22:31.218Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/57/ae1032fd88214ea4ee6d3028c817c12a999eb90a67766bbab31e9819385a/simli_ai-0.1.25-py3-none-any.whl", hash = "sha256:7d01f65321dc9052f25e15d0463af6a20a86c6d37d9a7b3a2c4b01cbec0a54ed", size = 13651, upload-time = "2025-11-06T16:27:07.765Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6542,16 +6521,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "smithy-aws-core"
|
||||
version = "0.2.0"
|
||||
version = "0.1.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aws-sdk-signers", marker = "python_full_version >= '3.12'" },
|
||||
{ name = "smithy-core", marker = "python_full_version >= '3.12'" },
|
||||
{ name = "smithy-http", marker = "python_full_version >= '3.12'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c1/c8/5970c869527972b23a1733de3993d54283c84a2340e84acdd48a11aa0ff4/smithy_aws_core-0.2.0.tar.gz", hash = "sha256:dfa1ecd311d6f0a16f48c86d793085e2a0a33a46de897d129dd1f142a43bf7f6", size = 11344, upload-time = "2025-11-21T18:33:01.928Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/56/d3/f847e0fd36b95aa36ce3a4c9ce1a08e16b2aa9a56b71714045c9c924e282/smithy_aws_core-0.1.1.tar.gz", hash = "sha256:78dfd7040fc2bc72b6af293096642fc9a7bfd2db28ddbdf7c4110535eab9d662", size = 11196, upload-time = "2025-10-21T20:21:18.648Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/25/739c0005a6cb4effbc2d37fe23590660b508fe314200f4acf94410a8f315/smithy_aws_core-0.2.0-py3-none-any.whl", hash = "sha256:d112082ef77758e1977f8694cf690ac35c76570c12a6590fccd5da085a3ce507", size = 18966, upload-time = "2025-11-21T18:33:00.812Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/04/87cb06f0f6d664b5cffdef6d4042dd52c11c138436084d30ffdaa3543031/smithy_aws_core-0.1.1-py3-none-any.whl", hash = "sha256:0d1634f276c2999dc2a04fafef63b9d28309de50d939d1d49df952773a7063c4", size = 18963, upload-time = "2025-10-21T20:21:17.692Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -6564,35 +6543,35 @@ json = [
|
||||
|
||||
[[package]]
|
||||
name = "smithy-aws-event-stream"
|
||||
version = "0.2.0"
|
||||
version = "0.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "smithy-core", marker = "python_full_version >= '3.12'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/31/90/78283c21484f8cf9862982e53bc2769b784910735fb5fb2400a17bfb5fdd/smithy_aws_event_stream-0.2.0.tar.gz", hash = "sha256:99700a11346e7ab1435ff2e53e6f6d60a1e857f2b2ee1941d40b54270adf3323", size = 12278, upload-time = "2025-11-21T18:33:03.79Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/49/26/8ff24194efed60b2df18f610ea05fa2a4c6546858b80a0a51335a4943b9b/smithy_aws_event_stream-0.1.0.tar.gz", hash = "sha256:6634691a3bf5d4801a2c29f0761db2dc4771f3ae43cdee50c10d4b4bb2f86475", size = 12216, upload-time = "2025-09-29T19:37:14.659Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/f5/08b997eee81b55150496ce565f0e03c72d0c80e5b218170bdeae7c46a5a4/smithy_aws_event_stream-0.2.0-py3-none-any.whl", hash = "sha256:679a0c7d944e67d3a55d287541b3ca1e61f9d6a62e13401367dcc034e75aa55d", size = 15567, upload-time = "2025-11-21T18:33:02.711Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/c4/2b63d31af58fc359577e5515bf730348a235f2f2fa10e17af8640495c81c/smithy_aws_event_stream-0.1.0-py3-none-any.whl", hash = "sha256:17a7300a85cb90df4c6c23f895ca6343361fa419203c3cf80019edd7d3b5f036", size = 15581, upload-time = "2025-09-29T19:37:13.589Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smithy-core"
|
||||
version = "0.2.0"
|
||||
version = "0.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c7/f6/140f0be9331dd7cd8fa012b3ca4735df39a1a81d03eea89728f997249116/smithy_core-0.2.0.tar.gz", hash = "sha256:05c3e3309df5dcb9cf53e241bd57a96510e4575186443ea157db9dbb59b6c85e", size = 50334, upload-time = "2025-11-21T18:33:05.697Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b9/8d/16028d03456071d21de7591f1e1e6a1cc81b2389e53ef8663dbf59caf9cd/smithy_core-0.1.0.tar.gz", hash = "sha256:b159b8905264e1e4c613eab9f74cec0b2f5b8119c42fbadddb4da0a8ed8050e9", size = 48415, upload-time = "2025-09-29T19:37:16.873Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/16/e3/d0defa2acf50b91625fe15e3ddb0c8e41ff64363a1f4cd9b8f19ae2ec0c6/smithy_core-0.2.0-py3-none-any.whl", hash = "sha256:db4620da3497abb60f79ac1d8a738d3eac46d7e820bfb50c777c36e932915239", size = 64777, upload-time = "2025-11-21T18:33:04.591Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/5b/563cb2beadcfa40597b0c3ff3f2d42e21f065b14782c4ba9cb41a44b745f/smithy_core-0.1.0-py3-none-any.whl", hash = "sha256:cb44e9355fb89e89f2c6ba6a1d59c5db4f2f7282c72d31d9307b6202d66cd0fa", size = 62895, upload-time = "2025-09-29T19:37:15.917Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smithy-http"
|
||||
version = "0.3.0"
|
||||
version = "0.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "smithy-core", marker = "python_full_version >= '3.12'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1c/c7/4d8be56e897f99f3b6ffcdf52ba00a468febc939fca85b90f1c122450830/smithy_http-0.3.0.tar.gz", hash = "sha256:55dcc3af315eee6863d2f3f58ada1d9cb4bcc3a57faac10a1b21d4a93722f520", size = 28674, upload-time = "2025-11-21T18:33:07.387Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3c/1c/44e99a7dfb8c39bf0c3d998accdf4573a7a3488863b90f21af260cec2d45/smithy_http-0.2.0.tar.gz", hash = "sha256:2382562fa9af326be455f14b18615f16ffe9db756e51b2a4ca0d23e1b881cff8", size = 26729, upload-time = "2025-10-21T20:21:06.146Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2d/e5/59ae79ecdc9a935ad10512c581b3054ebb1afd90498ecc8afaf141dbc22b/smithy_http-0.3.0-py3-none-any.whl", hash = "sha256:972924304febd77c7134a7cffab83ce3b48423ff966dcc1f257e2c0d58fa9b18", size = 40520, upload-time = "2025-11-21T18:33:06.312Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/e2/d475fad81ac74ec0e145cb6d72afe5ecde4e2358bd632c2fd5d3f4bc87dc/smithy_http-0.2.0-py3-none-any.whl", hash = "sha256:49ee2402d7737798d70f99f491fbfb2a5767283ae562e21b6f86e3fd14f3e3e0", size = 37328, upload-time = "2025-10-21T20:21:05.362Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -6602,15 +6581,15 @@ awscrt = [
|
||||
|
||||
[[package]]
|
||||
name = "smithy-json"
|
||||
version = "0.2.0"
|
||||
version = "0.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "ijson", marker = "python_full_version >= '3.12'" },
|
||||
{ name = "smithy-core", marker = "python_full_version >= '3.12'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/89/cf/e319a2a299b27bc0addf46ee3d4b9c25ec0817e3a0507b2b7a33eddc19f1/smithy_json-0.2.0.tar.gz", hash = "sha256:0946066fdda15d6a579dfdd4b61a547ab915eb057bd176fc2bc17d01dc789499", size = 7157, upload-time = "2025-11-21T18:33:08.968Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e2/5b/0ecb10007475e1b8faca3bbff1be2fc6edb3ea12ffc5e939e2249be95325/smithy_json-0.1.0.tar.gz", hash = "sha256:84fb48e445b87d850c240d837702c16b259ea53bad76b655ac1bbd8094d48912", size = 7086, upload-time = "2025-09-29T19:37:20.432Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2e/b1/33012ac5b2e5940a00b6e1ccc313330e6f8692152a151f72a398cd6be0e0/smithy_json-0.2.0-py3-none-any.whl", hash = "sha256:5018a4e61731afa3094a02d737d4f956dbf270c271410c089045a17d86fc3b3b", size = 9911, upload-time = "2025-11-21T18:33:08.267Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/62/95/e11c04e56aae12b62e38c49000004a1dc598a64dc207018c08448efde322/smithy_json-0.1.0-py3-none-any.whl", hash = "sha256:80ff64734dccdabf1ba6a2908555b97e60f62c07c3a27df48e421ee058413cb9", size = 9914, upload-time = "2025-09-29T19:37:19.459Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user