Compare commits
29 Commits
mb/fix-pip...bot-output
| Author | SHA1 | Date | |
|---|---|---|---|
| | e8640d84ae | | |
| | 23e4e29999 | | |
| | 713b488bb6 | | |
| | 71b87fd420 | | |
| | 3f269f9834 | | |
| | 4c698777f3 | | |
| | 5ca04ad741 | | |
| | 9a3902a82c | | |
| | 8ab0c92681 | | |
| | 124f147a37 | | |
| | ed808a9246 | | |
| | e9de9daf8c | | |
| | 82b9c4f0b6 | | |
| | 5dfe20be91 | | |
| | 0d2c5286fa | | |
| | 29417ba44d | | |
| | bc6a9cac26 | | |
| | 8a90decbc0 | | |
| | ccca6e8d81 | | |
| | e6dc1a510d | | |
| | 69945c5e0d | | |
| | 5c8635570d | | |
| | fe9aa3383e | | |
| | d1116d149e | | |
| | 74a0e8c88d | | |
| | fbbad27d37 | | |
| | 2fab3e2286 | | |
| | a7b2052b38 | | |
| | 3c76917c1e | | |
116
CHANGELOG.md
@@ -16,8 +16,87 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
  services that subclass `TTSService` can indicate whether the text in the
  `TTSTextFrame`s they push already contain any necessary inter-frame spaces.

- Introduced new `AggregatedTextFrame` type to support representing a best effort of
  the perceived LLM output whether or not it is processed by the TTS. This new frame
  type includes the field `aggregated_by` to represent the conceptual format by which
  the given text is aggregated. `TTSTextFrame`s now inherit from `AggregatedTextFrame`.
  With this inheritance, an observer can watch for `AggregatedTextFrame`s to accumulate
  the perceived output and determine whether or not the text was spoken based on whether
  that frame is also a `TTSTextFrame`. (See the bullet below on the new `bot-output`
  message, which takes advantage of this.)

- Introduced `LLMTextProcessor`: A new processor meant to allow customization of how
  `LLMTextFrame`s should be aggregated and considered. Its purpose is to turn
  `LLMTextFrame`s into `AggregatedTextFrame`s. By default, a `TTSService` will still
  aggregate `LLMTextFrame`s by sentence for the service to consume. However, if you
  wish to override how the LLM text is aggregated, you should no longer override the
  TTS's internal aggregator, but instead insert this processor between your LLM and
  TTS in the pipeline.

- New `bot-output` RTVI message to represent what the bot actually "says".
  - The `RTVIObserver` now emits `bot-output` messages based on the new `AggregatedTextFrame`s
    (`bot-tts-text` and `bot-llm-text` are still supported and generated, but `bot-transcription` is
    now deprecated in favor of this new, more thorough message).
  - The new `RTVIBotOutputMessage` includes the fields:
    - `spoken`: A boolean indicating whether the text was spoken by TTS
    - `aggregated_by`: A string representing how the text was aggregated ("sentence", "word",
      "my custom aggregation")
  - Introduced new fields to `RTVIObserver` to support the new `bot-output` messaging:
    - `bot_output_enabled`: Defaults to `True`. Set to `False` to disable `bot-output` messages.
    - `skip_aggregator_types`: Defaults to `None`. Set to a list of strings that match
      aggregation types that should not be included in `bot-output` messages. (Ex. `credit_card`)
  - Introduced new methods, `add_text_transformer()` and `remove_text_transformer()`, to `RTVIObserver` to support providing (and subsequently removing)
    callbacks for various types of aggregations (or all aggregations with `*`) that can modify the
    text before being sent as a `bot-output` or `tts-text` message. (Think obscuring the credit card
    or inserting extra detail the client might want that the context doesn't need.)

- Updated the base aggregator type:
  - Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
    a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom
    aggregation")
  - **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
    To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
  - **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
    (instead of `Optional[str]`). To update:
    ```
    aggregation = await myAggregator.aggregate(text)
    if aggregation:
        print(f"successfully aggregated text: {aggregation.text}")  # instead of {aggregation}
    ```
  - `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
    produce/consume `Aggregation` objects.

- Augmented the `PatternPairAggregator`:
  - Introduced a new, preferred method, `add_pattern`, to support a new option for treating a
    match as a separate aggregation returned from `aggregate()`. This replaces the now
    deprecated `add_pattern_pair` method; you provide a `MatchAction` in lieu of the `remove_match` field.
  - `MatchAction` enum: `REMOVE`, `KEEP`, `AGGREGATE`, allowing customization for how
    a match should be handled.
    - `REMOVE`: The text along with its delimiters will be removed from the streaming text.
      Sentence aggregation will continue on as if this text did not exist.
    - `KEEP`: The delimiters will be removed, but the content between them will be kept.
      Sentence aggregation will continue on with the internal text included.
    - `AGGREGATE`: The delimiters will be removed and the content between will be treated
      as a separate aggregation. Any text before the start of the pattern will be
      returned early, whether or not a complete sentence was found. Then the pattern
      will be returned. Then the aggregation will continue on sentence matching after
      the closing delimiter is found. The content between the delimiters is not
      aggregated by sentence. It is aggregated as one single block of text.
  - `PatternMatch` now extends `Aggregation` and provides richer info to handlers.
  - **BREAKING**: The `PatternMatch` type returned to handlers registered via `on_pattern_match`
    has been updated to subclass from the new `Aggregation` type, which means that `content`
    has been replaced with `text` and `pattern_id` has been replaced with `type`:
    ```
    async def on_match_tag(match: PatternMatch):
        pattern = match.type  # instead of match.pattern_id
        text = match.text  # instead of match.content
    ```

### Changed

- Updated all STT and TTS services to use consistent error handling pattern with
  `push_error()` method for better pipeline error event integration.

- Added Hindi support for Rime TTS services.

- Updated `GeminiTTSService` to use Google Cloud Text-to-Speech streaming API
@@ -30,11 +109,42 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated language mappings for the Google and Gemini TTS services to match
  official documentation.

- `TextFrame` new field `append_to_context` used to indicate if the encompassing
  text should be added to the LLM context (by the LLM assistant aggregator). It
  defaults to `True`.

- TTS flow respects aggregation metadata
  - `TTSService` accepts a new `skip_aggregator_types` to avoid speaking certain aggregation types
    (now determined/returned by the aggregator)
  - TTS services push `AggregatedTextFrame` in addition to `TTSTextFrame`s when either an
    aggregation occurs that should not be spoken or when the TTS service supports word-by-word
    timestamping. In the latter case, the `TTSService` preliminarily generates an
    `AggregatedTextFrame`, aggregated by sentence to generate the full sentence content as early
    as possible.
  - Introduced new methods, `add_text_transformer()` and `remove_text_transformer()`:
    These functions introduce the ability to provide (and subsequently remove) callbacks to the TTS to transform text based on
    its aggregated type prior to sending the text to the underlying TTS service. This makes it
    possible to do things like introduce TTS-specific tags for spelling or emotion or change the
    pronunciation of something on the fly.

### Deprecated

- The `api_key` parameter in `GeminiTTSService` is deprecated. Use
  `credentials` or `credentials_path` instead for Google Cloud authentication.

- The RTVI `bot-transcription` event is deprecated in favor of the new `bot-output`
  message, which is the canonical representation of bot output (spoken or not). The code
  still emits a transcription message for backwards compatibility while the transition occurs.

- The TTS constructor field `text_aggregator` is deprecated in favor of the new
  `LLMTextProcessor`. TTSServices still have an internal aggregator for support of default
  behavior, but if you want to override the aggregation behavior, you should use the new
  processor.

- Deprecated `add_pattern_pair` in the `PatternPairAggregator`, which takes a `pattern_id`
  and `remove_match` field, in favor of the new `add_pattern` method, which takes a `type` and an
  `action`.

### Fixed

- Fixed subtle issue of assistant context messages ending up with double spaces
@@ -51,6 +161,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Prevented `HeyGenVideoService` from automatically disconnecting after 5 minutes.

### Added

- Added ai-coustics integrated VAD (`AICVADAnalyzer`) with `AICFilter` factory and
  example wiring; leverages the enhancement model for robust detection with no
  ONNX dependency or added processing complexity.

## [0.0.94] - 2025-11-10

### Changed

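The three `MatchAction` behaviours described in the CHANGELOG entries above can be illustrated with a small, non-streaming sketch. Everything in it is a hypothetical stand-in written for illustration: the `MatchAction` and `Aggregation` definitions only mirror the names used in the changelog, `split_sentences` and `apply_action` are not pipecat functions, and the real `PatternPairAggregator` works incrementally on streamed text rather than on a complete string.

```python
import re
from dataclasses import dataclass
from enum import Enum
from typing import List


class MatchAction(Enum):
    """Hypothetical stand-in mirroring the names listed in the changelog."""

    REMOVE = "remove"
    KEEP = "keep"
    AGGREGATE = "aggregate"


@dataclass
class Aggregation:
    """Hypothetical stand-in: aggregated text plus a string naming its type."""

    text: str
    type: str


def split_sentences(text: str) -> List[Aggregation]:
    """Naive splitter: collapse whitespace, then break after '.', '!' or '?'."""
    collapsed = " ".join(text.split())
    parts = re.split(r"(?<=[.!?])\s+", collapsed)
    return [Aggregation(p, "sentence") for p in parts if p]


def apply_action(
    text: str, start: str, end: str, pattern_type: str, action: MatchAction
) -> List[Aggregation]:
    """Apply one start/end pattern to a complete string (at most one match assumed)."""
    begin = text.find(start)
    close = text.find(end, begin + len(start)) if begin != -1 else -1
    if close == -1:
        # No complete match: plain sentence aggregation.
        return split_sentences(text)

    before = text[:begin]
    inner = text[begin + len(start) : close]
    after = text[close + len(end) :]

    if action is MatchAction.REMOVE:
        # Delimiters and content vanish; sentences form as if they never existed.
        return split_sentences(before + after)
    if action is MatchAction.KEEP:
        # Only the delimiters vanish; the content joins the surrounding sentence.
        return split_sentences(before + inner + after)

    # AGGREGATE: flush whatever precedes the match (even a partial sentence),
    # emit the content as one block under its own type, then resume by sentence.
    result = split_sentences(before)
    result.append(Aggregation(inner, pattern_type))
    result.extend(split_sentences(after))
    return result
```

For the input `"Your card is <card>1234 5678</card> on file. Thanks!"`, `REMOVE` and `KEEP` both yield two sentence aggregations (without and with the card number), while `AGGREGATE` yields four items, one of which has type `card`.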
@@ -15,7 +15,6 @@ from loguru import logger
from pipecat.audio.filters.aic_filter import AICFilter
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
@@ -48,7 +47,7 @@ def _create_aic_filter() -> AICFilter:
    return AICFilter(
        license_key=license_key,
        enhancement_level=1.0,
        enhancement_level=0.5,
    )

@@ -56,27 +55,33 @@ def _create_aic_filter() -> AICFilter:
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
    "daily": lambda: DailyParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
        audio_in_filter=_create_aic_filter(),
    ),
    "twilio": lambda: FastAPIWebsocketParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
        audio_in_filter=_create_aic_filter(),
    ),
    "webrtc": lambda: TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
        turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
        audio_in_filter=_create_aic_filter(),
    ),
    "daily": lambda: (
        lambda aic: DailyParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
            turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
            audio_in_filter=aic,
        )
    )(_create_aic_filter()),
    "twilio": lambda: (
        lambda aic: FastAPIWebsocketParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
            turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
            audio_in_filter=aic,
        )
    )(_create_aic_filter()),
    "webrtc": lambda: (
        lambda aic: TransportParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
            turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
            audio_in_filter=aic,
        )
    )(_create_aic_filter()),
}

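The new `transport_params` entries wrap each params constructor in an immediately-invoked inner lambda so that the VAD analyzer and `audio_in_filter` refer to the same filter instance. A minimal sketch of why that matters, using hypothetical stand-ins (`FakeFilter` and `create_filter` are illustrative names, not pipecat or AIC SDK APIs):

```python
class FakeFilter:
    """Hypothetical stand-in for a filter that hands out analyzers bound to itself."""

    def create_vad_analyzer(self):
        # The analyzer remembers which filter created it.
        return {"bound_to": self}


def create_filter() -> FakeFilter:
    """Hypothetical factory; every call builds a fresh filter."""
    return FakeFilter()


# Two separate factory calls: the analyzer is bound to a different filter
# than the one actually installed as the input filter.
separate = lambda: {
    "vad_analyzer": create_filter().create_vad_analyzer(),
    "audio_in_filter": create_filter(),
}

# Immediately-invoked inner lambda: one filter is created per outer call
# and shared by both fields.
shared = lambda: (
    lambda aic: {
        "vad_analyzer": aic.create_vad_analyzer(),
        "audio_in_filter": aic,
    }
)(create_filter())
```

The outer lambda still defers construction until a transport is selected; only the inner call binds the single instance to a name that both keyword arguments can use.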
@@ -62,7 +62,11 @@ from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
    MatchAction,
    PatternMatch,
    PatternPairAggregator,
)

load_dotenv(override=True)

@@ -106,16 +110,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    pattern_aggregator = PatternPairAggregator()

    # Add pattern for voice switching
    pattern_aggregator.add_pattern_pair(
        pattern_id="voice_tag",
    pattern_aggregator.add_pattern(
        type="voice",
        start_pattern="<voice>",
        end_pattern="</voice>",
        remove_match=True,
        action=MatchAction.REMOVE,  # Remove tags from final text
    )

    # Register handler for voice switching
    async def on_voice_tag(match: PatternMatch):
        voice_name = match.content.strip().lower()
        voice_name = match.text.strip().lower()
        if voice_name in VOICE_IDS:
            # First flush any existing audio to finish the current context
            await tts.flush_audio()
@@ -125,7 +129,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
        else:
            logger.warning(f"Unknown voice: {voice_name}")

    pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
    pattern_aggregator.on_pattern_match("voice", on_voice_tag)

    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

@@ -45,7 +45,7 @@ Source = "https://github.com/pipecat-ai/pipecat"
Website = "https://pipecat.ai"

[project.optional-dependencies]
aic = [ "aic-sdk~=1.0.1" ]
aic = [ "aic-sdk~=1.1.0" ]
anthropic = [ "anthropic~=0.49.0" ]
assemblyai = [ "pipecat-ai[websockets-base]" ]
asyncai = [ "pipecat-ai[websockets-base]" ]

@@ -68,6 +68,58 @@ class AICFilter(BaseAudioFilter):
        # Model will be created in start() since the API now requires sample_rate
        self._aic = None

    def get_vad_factory(self):
        """Return a zero-arg factory that will create the VAD once the model exists.

        Returns:
            A zero-argument callable that, when invoked, returns an initialized
            VoiceActivityDetector bound to the underlying AIC model. Raises a
            RuntimeError if the model has not been initialized (i.e. start()
            has not been called successfully).
        """

        def _factory():
            if self._aic is None:
                raise RuntimeError("AIC model not initialized yet. Call start(sample_rate) first.")
            return self._aic.create_vad()

        return _factory

    def create_vad_analyzer(
        self,
        *,
        lookback_buffer_size: Optional[float] = None,
        sensitivity: Optional[float] = None,
    ):
        """Return an analyzer that will lazily instantiate the AIC VAD when ready.

        AIC VAD parameters:
        - lookback_buffer_size:
            Number of window-length audio buffers used as a lookback buffer.
            Higher values increase prediction stability but add latency.
            Range: 1.0 .. 20.0, Default (SDK): 6.0
        - sensitivity:
            Energy threshold sensitivity. Energy threshold = 10 ** (-sensitivity).
            Range: 1.0 .. 15.0, Default (SDK): 6.0

        Args:
            lookback_buffer_size: Optional lookback buffer size to configure on the VAD.
                Range: 1.0 .. 20.0. If None, SDK default is used.
            sensitivity: Optional sensitivity (energy threshold) to configure on the VAD.
                Range: 1.0 .. 15.0. If None, SDK default is used.

        Returns:
            A lazily-initialized AICVADAnalyzer that will bind to the VAD backend
            once the filter's model has been created (after start(sample_rate)).
        """
        from pipecat.audio.vad.aic_vad import AICVADAnalyzer

        return AICVADAnalyzer(
            vad_factory=self.get_vad_factory(),
            lookback_buffer_size=lookback_buffer_size,
            sensitivity=sensitivity,
        )

    async def start(self, sample_rate: int):
        """Initialize the filter with the transport's sample rate.

@@ -185,7 +237,7 @@ class AICFilter(BaseAudioFilter):
            )

            # Process planar in-place; returns ndarray (same shape)
            out_f32 = self._aic.process(block_f32)
            out_f32 = await self._aic.process_async(block_f32)

            # Convert back to int16 bytes, planar layout
            out_i16 = np.clip(out_f32 * 32768.0, -32768, 32767).astype(np.int16)

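The `create_vad_analyzer` docstring above gives the relation `Energy threshold = 10 ** (-sensitivity)` with a range of 1.0 to 15.0 and an SDK default of 6.0. The arithmetic can be sketched as follows; `energy_threshold` is a hypothetical helper written for illustration (not part of pipecat or the AIC SDK), and the range check simply mirrors the documented bounds.

```python
import math

# Bounds and default as quoted in the docstring above.
SENSITIVITY_RANGE = (1.0, 15.0)
DEFAULT_SENSITIVITY = 6.0


def energy_threshold(sensitivity: float = DEFAULT_SENSITIVITY) -> float:
    """Return 10 ** (-sensitivity), rejecting values outside the documented range."""
    low, high = SENSITIVITY_RANGE
    if not low <= sensitivity <= high:
        raise ValueError(f"sensitivity must be within {low} .. {high}, got {sensitivity}")
    return 10.0 ** (-sensitivity)


# The default maps to a threshold of about one millionth.
assert math.isclose(energy_threshold(), 1e-6)

# A larger sensitivity value yields a smaller energy threshold.
assert energy_threshold(8.0) < energy_threshold(4.0)
```

Read this way, raising `sensitivity` lowers the energy a signal must exceed to count as speech.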
158
src/pipecat/audio/vad/aic_vad.py
Normal file
@@ -0,0 +1,158 @@
"""AIC-integrated VAD analyzer that lazily binds to the AIC SDK backend.

This analyzer queries the backend's is_speech_detected() and maps it to a float
confidence (1.0/0.0). It uses 10 ms windows based on the sample rate and applies
optional AIC VAD parameters (lookback_buffer_size, sensitivity) when available.
"""

from typing import Any, Callable, Optional

from loguru import logger

from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams

try:
    from aic import AICVadParameter
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use the AIC filter, you need to `pip install pipecat-ai[aic]`.")
    raise Exception(f"Missing module: {e}")


class AICVADAnalyzer(VADAnalyzer):
    """VAD analyzer that lazily instantiates the AIC VoiceActivityDetector via a factory.

    The analyzer can be constructed before the AIC Model exists. Once the filter has
    started and the Model is available, the provided factory will succeed and the
    backend VAD will be created. Updates then use 10 ms windows:
    num_frames_required() returns the number of frames in 10 ms of audio, and
    confidence is derived from the backend's boolean is_speech_detected() state.

    AIC VAD runtime parameters:
        - lookback_buffer_size:
            Controls the lookback buffer size used by the VAD, i.e. the number of
            window-length audio buffers used as a lookback buffer. Larger values improve
            stability but increase latency.
            Range: 1.0 .. 20.0
            Default (SDK): 6.0
        - sensitivity:
            Controls the energy threshold sensitivity. Higher values lower the energy
            threshold, making the detector more sensitive (less energy is required to
            count as speech).
            Range: 1.0 .. 15.0
            Formula: Energy threshold = 10 ** (-sensitivity)
            Default (SDK): 6.0
    """

    def __init__(
        self,
        *,
        vad_factory: Optional[Callable[[], Any]] = None,
        lookback_buffer_size: Optional[float] = None,
        sensitivity: Optional[float] = None,
    ):
        """Create an AIC VAD analyzer.

        Args:
            vad_factory:
                Zero-arg callable that returns an initialized AIC VoiceActivityDetector.
                This may raise until the filter's Model has been created; the analyzer
                will retry on set_sample_rate/first use.
            lookback_buffer_size:
                Optional override for AIC VAD lookback buffer size.
                Range: 1.0 .. 20.0. Larger values increase stability at the cost of latency.
                If None, the SDK default (6.0) is used.
            sensitivity:
                Optional override for AIC VAD sensitivity (energy threshold).
                Range: 1.0 .. 15.0. Energy threshold = 10 ** (-sensitivity).
                If None, the SDK default (6.0) is used.
        """
        # Use fixed VAD parameters for AIC: no user override
        fixed_params = VADParams(confidence=0.5, start_secs=0.0, stop_secs=0.0, min_volume=0.0)
        super().__init__(sample_rate=None, params=fixed_params)
        self._vad_factory = vad_factory
        self._backend_vad: Optional[Any] = None
        self._pending_lookback: Optional[float] = lookback_buffer_size
        self._pending_sensitivity: Optional[float] = sensitivity

    def bind_vad_factory(self, vad_factory: Callable[[], Any]):
        """Attach or replace the factory post-construction."""
        self._vad_factory = vad_factory
        self._ensure_backend_initialized()

    def _apply_backend_params(self):
        """Apply optional AIC VAD parameters if available."""
        if self._backend_vad is None or AICVadParameter is None:
            return
        try:
            if self._pending_lookback is not None:
                self._backend_vad.set_parameter(
                    AICVadParameter.LOOKBACK_BUFFER_SIZE, float(self._pending_lookback)
                )
            if self._pending_sensitivity is not None:
                self._backend_vad.set_parameter(
                    AICVadParameter.SENSITIVITY, float(self._pending_sensitivity)
                )
        except Exception as e:  # noqa: BLE001
            logger.debug(f"AIC VAD parameter application deferred/failed: {e}")

    def _ensure_backend_initialized(self):
        if self._backend_vad is not None:
            return
        if not self._vad_factory:
            return
        try:
            self._backend_vad = self._vad_factory()
            self._apply_backend_params()
            # With backend ready, recompute internal frame sizing
            super().set_params(self._params)
            logger.debug("AIC VAD backend initialized in analyzer.")
        except Exception as e:  # noqa: BLE001
            # Filter may not be started yet; try again later
            logger.debug(f"Deferring AIC VAD backend initialization: {e}")

    def set_sample_rate(self, sample_rate: int):
        """Set the sample rate for audio processing.

        Args:
            sample_rate: Audio sample rate in Hz.
        """
        # Set rate and attempt backend initialization once we know SR
        self._sample_rate = self._init_sample_rate or sample_rate
        self._ensure_backend_initialized()
        # Ensure params are initialized even if backend not ready yet
        try:
            super().set_params(self._params)
        except Exception:
            pass

    def num_frames_required(self) -> int:
        """Get the number of audio frames required for analysis.

        Returns:
            Number of frames needed for VAD processing.
        """
        # Use 10 ms windows based on sample rate
        return int(self.sample_rate * 0.01) if self.sample_rate > 0 else 160

    def voice_confidence(self, buffer: bytes) -> float:
        """Calculate voice activity confidence for the given audio buffer.

        Args:
            buffer: Audio buffer to analyze.

        Returns:
            Voice confidence score is 0.0 or 1.0.
        """
        # Ensure backend exists (filter might have started since last call)
        self._ensure_backend_initialized()
        if self._backend_vad is None:
            return 0.0

        # We do not need to analyze 'buffer' here since the model's VAD is updated
        # as part of the enhancement pipeline. Simply query the boolean and map it.
        try:
            is_speech = self._backend_vad.is_speech_detected()
            return 1.0 if is_speech else 0.0
        except Exception as e:  # noqa: BLE001
            logger.error(f"AIC VAD inference error: {e}")
            return 0.0
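The lazy binding used by `AICVADAnalyzer` (construct the analyzer first, bind to the backend once the filter has started) can be reduced to a short sketch. All names below (`FakeModelOwner`, `LazyAnalyzer`, `frames_per_window`) are hypothetical stand-ins for illustration; they do not reproduce the AIC SDK or the pipecat base classes, and the speech flag is passed in directly instead of being queried from a backend.

```python
from typing import Any, Callable, Optional


class FakeModelOwner:
    """Hypothetical stand-in for the filter: its backend exists only after start()."""

    def __init__(self) -> None:
        self._backend: Optional[Any] = None

    def start(self) -> None:
        self._backend = object()

    def get_factory(self) -> Callable[[], Any]:
        def _factory() -> Any:
            if self._backend is None:
                raise RuntimeError("backend not initialized yet")
            return self._backend

        return _factory


class LazyAnalyzer:
    """Retries the factory on every use until the backend becomes available."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self._factory = factory
        self._backend: Optional[Any] = None

    def _ensure_backend(self) -> bool:
        if self._backend is None:
            try:
                self._backend = self._factory()
            except RuntimeError:
                return False  # owner not started yet; try again on the next call
        return True

    def confidence(self, is_speech: bool) -> float:
        # Until the backend is bound, report "no speech".
        if not self._ensure_backend():
            return 0.0
        return 1.0 if is_speech else 0.0


def frames_per_window(sample_rate: int, window_secs: float = 0.01) -> int:
    """Frames in one 10 ms window, falling back to 160 when the rate is unknown."""
    return int(sample_rate * window_secs) if sample_rate > 0 else 160
```

The analyzer can therefore be handed to a transport before the filter starts; it reports 0.0 until the first call after `start()` succeeds in binding.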
@@ -31,7 +31,11 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import LLMService
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
    MatchAction,
    PatternMatch,
    PatternPairAggregator,
)


class IVRStatus(Enum):
@@ -114,15 +118,15 @@ class IVRProcessor(FrameProcessor):
    def _setup_xml_patterns(self):
        """Set up XML pattern detection and handlers."""
        # Register DTMF pattern
        self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", remove_match=True)
        self._aggregator.add_pattern("dtmf", "<dtmf>", "</dtmf>", action=MatchAction.REMOVE)
        self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)

        # Register mode pattern
        self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", remove_match=True)
        self._aggregator.add_pattern("mode", "<mode>", "</mode>", action=MatchAction.REMOVE)
        self._aggregator.on_pattern_match("mode", self._handle_mode_action)

        # Register IVR pattern
        self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", remove_match=True)
        self._aggregator.add_pattern("ivr", "<ivr>", "</ivr>", action=MatchAction.REMOVE)
        self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -159,7 +163,7 @@ class IVRProcessor(FrameProcessor):
        Args:
            match: The pattern match containing DTMF content.
        """
        value = match.content
        value = match.text
        logger.debug(f"DTMF detected: {value}")

        try:
@@ -180,7 +184,7 @@ class IVRProcessor(FrameProcessor):
        Args:
            match: The pattern match containing IVR status content.
        """
        status = match.content
        status = match.text
        logger.trace(f"IVR status detected: {status}")

        # Convert string to enum, with validation
@@ -211,7 +215,7 @@ class IVRProcessor(FrameProcessor):
        Args:
            match: The pattern match containing mode content.
        """
        mode = match.content
        mode = match.text
        logger.debug(f"Mode detected: {mode}")
        if mode == "conversation":
            await self._handle_conversation()

@@ -12,6 +12,7 @@ and LLM processing.
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
@@ -337,11 +338,14 @@ class TextFrame(DataFrame):
    # mandatory fields of theirs to have defaults to preserve
    # non-default-before-default argument order)
    includes_inter_frame_spaces: bool = field(init=False)
    # Whether this text frame should be appended to the LLM context.
    append_to_context: bool = field(init=False)

    def __post_init__(self):
        super().__post_init__()
        self.skip_tts = False
        self.includes_inter_frame_spaces = False
        self.append_to_context = True

    def __str__(self):
        pts = format_pts(self.pts)
@@ -355,8 +359,32 @@ class LLMTextFrame(TextFrame):
    pass


class AggregationType(str, Enum):
    """Built-in aggregation strings."""

    SENTENCE = "sentence"
    WORD = "word"

    def __str__(self):
        return self.value


@dataclass
class TTSTextFrame(TextFrame):
class AggregatedTextFrame(TextFrame):
    """Text frame representing an aggregation of TextFrames.

    This frame contains multiple TextFrames aggregated together for processing
    or output along with a field to indicate how they are aggregated.

    Parameters:
        aggregated_by: Method used to aggregate the text frames.
    """

    aggregated_by: AggregationType | str


@dataclass
class TTSTextFrame(AggregatedTextFrame):
    """Text frame generated by Text-to-Speech services."""

    pass

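The frame changes above lean on two standard-library behaviours: a dataclass field declared with `field(init=False)` is left out of `__init__`, so a subclass can still add a required field after it, and a `str`-based `Enum` compares equal to its plain string value. The sketch below shows both, plus the `append_to_context` gate that the assistant aggregators apply. The classes are simplified stand-ins that reuse the names from the diff but are not the pipecat definitions, and `append_if_allowed` is a hypothetical helper.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Union


class AggregationType(str, Enum):
    """Simplified stand-in for the built-in aggregation strings."""

    SENTENCE = "sentence"
    WORD = "word"

    def __str__(self) -> str:
        return self.value


@dataclass
class TextFrame:
    """Simplified stand-in: one required field plus a flag set after init."""

    text: str
    # Excluded from __init__, so it does not count as a "default" argument.
    append_to_context: bool = field(init=False)

    def __post_init__(self) -> None:
        self.append_to_context = True


@dataclass
class AggregatedTextFrame(TextFrame):
    # A required field is legal here because the base class added no
    # defaulted __init__ arguments.
    aggregated_by: Union[AggregationType, str]


def append_if_allowed(context: List[str], frame: TextFrame) -> None:
    """Hypothetical helper mirroring the `not frame.append_to_context` early return."""
    if not frame.append_to_context:
        return
    context.append(frame.text)


context: List[str] = []
spoken = AggregatedTextFrame(text="Hello.", aggregated_by=AggregationType.SENTENCE)
hidden = AggregatedTextFrame(text="4111 1111", aggregated_by="credit_card")
hidden.append_to_context = False  # opt this frame out of the LLM context

append_if_allowed(context, spoken)
append_if_allowed(context, hidden)
```

After these calls `context` holds only the first frame's text, and `spoken.aggregated_by == "sentence"` evaluates to true because the enum member is itself a `str`.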
@@ -1001,7 +1001,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
        await self.push_aggregation()

    async def _handle_text(self, frame: TextFrame):
        if not self._started:
        if not self._started or not frame.append_to_context:
            return

        if self._params.expect_stripped_words:

@@ -814,7 +814,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
        await self.push_aggregation()

    async def _handle_text(self, frame: TextFrame):
        if not self._started:
        if not self._started or not frame.append_to_context:
            return

        # Make sure we really have text (spaces count, too!)

106
src/pipecat/processors/aggregators/llm_text_processor.py
Normal file
@@ -0,0 +1,106 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""LLM text processor module for processing and aggregating raw LLM output text.

This processor will convert LLMTextFrames into AggregatedTextFrames based on the
configured text aggregator. Using the customizable aggregator, it provides
functionality to handle or manipulate LLM text frames before they are sent to other
components such as TTS services or context aggregators. It can be used to pre-aggregate
and categorize, modify, or filter direct output tokens from the LLM.
"""

from typing import Optional

from pipecat.frames.frames import (
    AggregatedTextFrame,
    EndFrame,
    Frame,
    InterruptionFrame,
    LLMFullResponseEndFrame,
    LLMTextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator


class LLMTextProcessor(FrameProcessor):
    """A processor for handling or manipulating LLM text frames before they are processed further.

    This processor will convert LLMTextFrames into AggregatedTextFrames based on the configured
    text aggregator. Using the customizable aggregator, it provides functionality to handle or
    manipulate LLM text frames before they are sent to other components such as TTS services or
    context aggregators. It can be used to pre-aggregate and categorize, modify, or filter direct
    output tokens from the LLM.
    """

    def __init__(self, *, text_aggregator: Optional[BaseTextAggregator] = None, **kwargs):
        """Initialize the LLM text processor.

        Args:
            text_aggregator: An optional text aggregator to use for processing LLM text frames. By
                default, a SimpleTextAggregator aggregating by sentence will be used.
            **kwargs: Additional arguments passed to parent class.

        TODO: Allow transformations per aggregation type or all (and deprecate the TTS filters).
        """
        super().__init__(**kwargs)
        self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process LLMTextFrames using the aggregator to generate AggregatedTextFrames.

        Args:
            frame: The frame to process.
            direction: The direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, InterruptionFrame):
            await self._handle_interruption(frame)
            await self.push_frame(frame, direction)
        elif isinstance(frame, LLMTextFrame):
            await self._handle_llm_text(frame)
        elif isinstance(frame, LLMFullResponseEndFrame):
            await self._handle_llm_end(frame.skip_tts)
            await self.push_frame(frame, direction)
        elif isinstance(frame, EndFrame):
            await self._handle_llm_end()
            await self.push_frame(frame, direction)
        else:
            await self.push_frame(frame, direction)

    async def _handle_interruption(self, _):
        """Handle interruptions by resetting the text aggregator."""
        await self._text_aggregator.handle_interruption()

    async def reset(self):
        """Reset the internal state of the text processor and its aggregator."""
        await self._text_aggregator.reset()

    async def _handle_llm_text(self, in_frame: LLMTextFrame):
        aggregation = await self._text_aggregator.aggregate(in_frame.text)
        if aggregation:
            out_frame = AggregatedTextFrame(
                text=aggregation.text,
                aggregated_by=aggregation.type,
            )
            out_frame.skip_tts = in_frame.skip_tts
            await self.push_frame(out_frame)

    async def _handle_llm_end(self, skip_tts: bool = False):
        # Flush any remaining aggregated text at the end of the LLM response
        aggregation = self._text_aggregator.text
        await self._text_aggregator.reset()
        text = aggregation.text.strip()
|
||||
if text:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=text,
|
||||
aggregated_by=aggregation.type,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
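The aggregate-then-flush flow that `LLMTextProcessor` follows (emit an aggregation when the aggregator reports one, then flush whatever remains when the response ends) can be sketched without pipecat. This is a minimal illustration, not the library's implementation: `Aggregation`, `SentenceAggregator`, `flush`, and `run` are hypothetical names, and the sentence detection is a deliberately naive regex.

```python
import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Aggregation:
    """Hypothetical stand-in for an aggregator result: the text plus its type."""

    text: str
    type: str


class SentenceAggregator:
    """Toy aggregator (assumption): buffers tokens until a sentence terminator."""

    def __init__(self) -> None:
        self._buffer = ""

    def aggregate(self, token: str) -> Optional[Aggregation]:
        # Append the token, then look for ., ! or ? followed by whitespace or end.
        self._buffer += token
        match = re.search(r"[.!?](\s|$)", self._buffer)
        if not match:
            return None
        end = match.end()
        sentence, self._buffer = self._buffer[:end], self._buffer[end:]
        return Aggregation(sentence.strip(), "sentence")

    def flush(self) -> Optional[Aggregation]:
        # Mirrors the end-of-response handling: emit leftover text, reset state.
        text, self._buffer = self._buffer.strip(), ""
        return Aggregation(text, "sentence") if text else None


def run(tokens: List[str]) -> List[Aggregation]:
    """Feed streamed tokens through the aggregator and flush at the end."""
    aggregator = SentenceAggregator()
    out: List[Aggregation] = []
    for token in tokens:
        result = aggregator.aggregate(token)
        if result:
            out.append(result)
    tail = aggregator.flush()
    if tail:
        out.append(tail)
    return out
```

Calling `run(["Hello", " there", ".", " How", " are", " you"])` yields one aggregation when the period arrives and a second from the final flush, which is the case the end-of-response handler exists for.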
@@ -24,6 +24,7 @@ from typing import (
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
@@ -32,6 +33,8 @@ from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
|
||||
from pipecat.audio.utils import calculate_audio_volume
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -704,6 +707,29 @@ class RTVITextMessageData(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RTVIBotOutputMessageData(RTVITextMessageData):
|
||||
"""Data for bot output RTVI messages.
|
||||
|
||||
Extends RTVITextMessageData to include metadata about the output.
|
||||
"""
|
||||
|
||||
spoken: bool = False # Indicates if the text has been spoken by TTS
|
||||
aggregated_by: AggregationType | str
|
||||
# Indicates what form the text is in (e.g., by word, sentence, etc.)
|
||||
|
||||
|
||||
class RTVIBotOutputMessage(BaseModel):
|
||||
"""Message containing bot output text.
|
||||
|
||||
An event meant to holistically represent what the bot is outputting,
|
||||
along with metadata about the output and if it has been spoken.
|
||||
"""
|
||||
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-output"] = "bot-output"
|
||||
data: RTVIBotOutputMessageData
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessage(BaseModel):
|
||||
"""Message containing bot transcription text.
|
||||
|
||||
@@ -896,6 +922,7 @@ class RTVIObserverParams:
|
||||
Parameter `errors_enabled` is deprecated. Error messages are always enabled.
|
||||
|
||||
Parameters:
|
||||
bot_output_enabled: Indicates if bot output messages should be sent.
|
||||
bot_llm_enabled: Indicates if the bot's LLM messages should be sent.
|
||||
bot_tts_enabled: Indicates if the bot's TTS messages should be sent.
|
||||
bot_speaking_enabled: Indicates if the bot's started/stopped speaking messages should be sent.
|
||||
@@ -907,9 +934,17 @@ class RTVIObserverParams:
|
||||
metrics_enabled: Indicates if metrics messages should be sent.
|
||||
system_logs_enabled: Indicates if system logs should be sent.
|
||||
errors_enabled: [Deprecated] Indicates if error messages should be sent.
|
||||
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
bot_output_transforms: A list of callables to transform text just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
(aggregation_type | '*', transform_function).
|
||||
audio_level_period_secs: How often audio levels should be sent if enabled.
|
||||
"""
|
||||
|
||||
bot_output_enabled: bool = True
|
||||
bot_llm_enabled: bool = True
|
||||
bot_tts_enabled: bool = True
|
||||
bot_speaking_enabled: bool = True
|
||||
@@ -921,6 +956,15 @@ class RTVIObserverParams:
|
||||
metrics_enabled: bool = True
|
||||
system_logs_enabled: bool = False
|
||||
errors_enabled: Optional[bool] = None
|
||||
skip_aggregator_types: Optional[List[AggregationType | str]] = None
|
||||
bot_output_transforms: Optional[
|
||||
List[
|
||||
Tuple[
|
||||
AggregationType | str,
|
||||
Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
]
|
||||
]
|
||||
] = None
|
||||
audio_level_period_secs: float = 0.15
|
||||
|
||||
|
||||
@@ -973,8 +1017,45 @@ class RTVIObserver(BaseObserver):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._aggregation_transforms: List[
|
||||
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
|
||||
] = self._params.bot_output_transforms or []
|
||||
|
||||
def add_bot_output_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Transform text for a specific aggregation type before sending as Bot Output or TTS.
|
||||
|
||||
Args:
|
||||
transform_function: The function to apply for transformation. This function should take
|
||||
the text and aggregation type as input and return the transformed text.
|
||||
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
|
||||
aggregation_type: The type of aggregation to transform. This value defaults to "*" to
|
||||
handle all text before sending to the client.
|
||||
"""
|
||||
self._aggregation_transforms.append((aggregation_type, transform_function))
|
||||
|
||||
def remove_bot_output_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Remove a text transformer for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to remove.
|
||||
aggregation_type: The type of aggregation to remove the transformer for.
|
||||
"""
|
||||
self._aggregation_transforms = [
|
||||
(agg_type, func)
|
||||
for agg_type, func in self._aggregation_transforms
|
||||
if not (agg_type == aggregation_type and func == transform_function)
|
||||
]
|
||||
|
||||
async def _logger_sink(self, message):
|
||||
"""Logger sink so we cna send system logs to RTVI clients."""
|
||||
"""Logger sink so we can send system logs to RTVI clients."""
|
||||
message = RTVISystemLogMessage(data=RTVITextMessageData(text=message))
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
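The transformer registration above keeps a list of `(aggregation_type, transform_function)` tuples, with `"*"` matching every type. The sketch below reproduces that bookkeeping in isolation so the ordering and wildcard behaviour are easier to see. `TransformRegistry`, `mask_digits`, `shout`, and `demo` are hypothetical names for illustration only; the real methods live on `RTVIObserver`.

```python
import asyncio
from typing import Awaitable, Callable, List, Tuple

# Same callable shape as in the diff: (text, aggregation_type) -> awaitable str.
Transform = Callable[[str, str], Awaitable[str]]


class TransformRegistry:
    """Hypothetical container mirroring the add/remove/apply logic."""

    def __init__(self) -> None:
        self._transforms: List[Tuple[str, Transform]] = []

    def add(self, func: Transform, aggregation_type: str = "*") -> None:
        self._transforms.append((aggregation_type, func))

    def remove(self, func: Transform, aggregation_type: str = "*") -> None:
        # Only the exact (type, function) pair is removed.
        self._transforms = [
            (t, f)
            for t, f in self._transforms
            if not (t == aggregation_type and f == func)
        ]

    async def apply(self, text: str, aggregation_type: str) -> str:
        # Transforms run in registration order; "*" applies to every type.
        for registered_type, func in self._transforms:
            if registered_type in (aggregation_type, "*"):
                text = await func(text, aggregation_type)
        return text


async def mask_digits(text: str, aggregation_type: str) -> str:
    return "".join("#" if ch.isdigit() else ch for ch in text)


async def shout(text: str, aggregation_type: str) -> str:
    return text.upper()


async def demo() -> Tuple[str, str, str]:
    registry = TransformRegistry()
    registry.add(mask_digits)        # wildcard: applies to all types
    registry.add(shout, "sentence")  # applies to "sentence" only
    a = await registry.apply("pin 1234", "sentence")
    b = await registry.apply("pin 1234", "word")
    registry.remove(mask_digits)
    c = await registry.apply("pin 1234", "word")
    return a, b, c
```

Because removal matches on both the type and the function, a transform registered under `"*"` has to be removed with `"*"` as well (the default here).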
@@ -1048,12 +1129,15 @@ class RTVIObserver(BaseObserver):
|
||||
await self.send_rtvi_message(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
|
||||
await self.send_rtvi_message(RTVIBotTTSStoppedMessage())
|
||||
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
|
||||
if isinstance(src, BaseOutputTransport):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.send_rtvi_message(message)
|
||||
else:
|
||||
elif isinstance(frame, AggregatedTextFrame) and (
|
||||
self._params.bot_output_enabled or self._params.bot_tts_enabled
|
||||
):
|
||||
if isinstance(frame, TTSTextFrame) and not isinstance(src, BaseOutputTransport):
|
||||
# This check is to make sure we handle the frame when it has gone
|
||||
# through the transport and has correct timing.
|
||||
mark_as_seen = False
|
||||
else:
|
||||
await self._handle_aggregated_llm_text(frame)
|
||||
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
|
||||
await self._handle_metrics(frame)
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
@@ -1084,15 +1168,6 @@ class RTVIObserver(BaseObserver):
|
||||
if mark_as_seen:
|
||||
self._frames_seen.add(frame.id)
|
||||
|
||||
async def _push_bot_transcription(self):
|
||||
"""Push accumulated bot transcription as a message."""
|
||||
if len(self._bot_transcription) > 0:
|
||||
message = RTVIBotTranscriptionMessage(
|
||||
data=RTVITextMessageData(text=self._bot_transcription)
|
||||
)
|
||||
await self.send_rtvi_message(message)
|
||||
self._bot_transcription = ""
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
"""Handle user speaking interruption frames."""
|
||||
message = None
|
||||
@@ -1115,14 +1190,45 @@ class RTVIObserver(BaseObserver):
|
||||
if message:
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
async def _handle_aggregated_llm_text(self, frame: AggregatedTextFrame):
|
||||
"""Handle aggregated LLM text output frames."""
|
||||
# Skip certain aggregator types if configured to do so.
|
||||
if (
|
||||
self._params.skip_aggregator_types
|
||||
and frame.aggregated_by in self._params.skip_aggregator_types
|
||||
):
|
||||
return
|
||||
|
||||
text = frame.text
|
||||
type = frame.aggregated_by
|
||||
for aggregation_type, transform in self._aggregation_transforms:
|
||||
if aggregation_type == type or aggregation_type == "*":
|
||||
text = await transform(text, type)
|
||||
|
||||
isTTS = isinstance(frame, TTSTextFrame)
|
||||
if self._params.bot_output_enabled:
|
||||
message = RTVIBotOutputMessage(
|
||||
data=RTVIBotOutputMessageData(text=text, spoken=isTTS, aggregated_by=type)
|
||||
)
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
if isTTS and self._params.bot_tts_enabled:
|
||||
tts_message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self.send_rtvi_message(tts_message)
|
||||
|
||||
async def _handle_llm_text_frame(self, frame: LLMTextFrame):
|
||||
"""Handle LLM text output frames."""
|
||||
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
# TODO (mrkb): Remove all this logic when we fully deprecate bot-transcription messages.
|
||||
self._bot_transcription += frame.text
|
||||
if match_endofsentence(self._bot_transcription):
|
||||
await self._push_bot_transcription()
|
||||
|
||||
if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0:
|
||||
await self.send_rtvi_message(
|
||||
RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription))
|
||||
)
|
||||
self._bot_transcription = ""
|
||||
|
||||
async def _handle_user_transcriptions(self, frame: Frame):
|
||||
"""Handle user transcription frames."""
|
||||
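The decision logic in `_handle_aggregated_llm_text` can be summarised as a pure function: skipped aggregation types produce nothing, every other frame produces a `bot-output` message when enabled, and spoken (TTS) frames additionally produce a TTS text message. The following is a rough sketch under assumptions: `Params` and `messages_for` are hypothetical, plain dicts stand in for the RTVI message models, and the `"bot-tts-text"` type string is assumed rather than taken from this diff.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Params:
    """Hypothetical subset of the observer parameters used by the handler."""

    bot_output_enabled: bool = True
    bot_tts_enabled: bool = True
    skip_aggregator_types: Optional[List[str]] = None


def messages_for(
    text: str, aggregated_by: str, spoken: bool, params: Params
) -> List[Dict[str, Any]]:
    """Return the messages the observer would send for one aggregated frame."""
    # Skipped types are dropped before any message is built.
    if params.skip_aggregator_types and aggregated_by in params.skip_aggregator_types:
        return []

    out: List[Dict[str, Any]] = []
    if params.bot_output_enabled:
        out.append(
            {
                "type": "bot-output",
                "data": {"text": text, "spoken": spoken, "aggregated_by": aggregated_by},
            }
        )
    # Only frames that actually went through TTS also yield a TTS text message.
    if spoken and params.bot_tts_enabled:
        out.append({"type": "bot-tts-text", "data": {"text": text}})  # assumed type name
    return out
```

A client that only listens for `bot-output` can therefore use the `spoken` flag to tell rendered-only text apart from text that was also voiced.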
@@ -1248,7 +1354,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
# Default to 0.3.0 which is the last version before actually having a
|
||||
# "client-version".
|
||||
self._client_version = [0, 3, 0]
|
||||
self._skip_tts: bool = False # Keep in sync with llm_service.py
|
||||
self._llm_skip_tts: bool = False # Keep in sync with llm_service.py's configuration.
|
||||
|
||||
self._registered_actions: Dict[str, RTVIAction] = {}
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
@@ -1441,7 +1547,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
elif isinstance(frame, RTVIActionFrame):
|
||||
await self._action_queue.put(frame)
|
||||
elif isinstance(frame, LLMConfigureOutputFrame):
|
||||
self._skip_tts = frame.skip_tts
|
||||
self._llm_skip_tts = frame.skip_tts
|
||||
await self.push_frame(frame, direction)
|
||||
# Other frames
|
||||
else:
|
||||
@@ -1697,9 +1803,9 @@ class RTVIProcessor(FrameProcessor):
|
||||
opts = data.options if data.options is not None else RTVISendTextOptions()
|
||||
if opts.run_immediately:
|
||||
await self.interrupt_bot()
|
||||
cur_skip_tts = self._skip_tts
|
||||
cur_llm_skip_tts = self._llm_skip_tts
|
||||
should_skip_tts = not opts.audio_response
|
||||
toggle_skip_tts = cur_skip_tts != should_skip_tts
|
||||
toggle_skip_tts = cur_llm_skip_tts != should_skip_tts
|
||||
if toggle_skip_tts:
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=should_skip_tts)
|
||||
await self.push_frame(output_frame)
|
||||
@@ -1709,7 +1815,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
)
|
||||
await self.push_frame(text_frame)
|
||||
if toggle_skip_tts:
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=cur_skip_tts)
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=cur_llm_skip_tts)
|
||||
await self.push_frame(output_frame)
|
||||
|
||||
async def _handle_update_context(self, data: RTVIAppendToContextData):
|
||||
|
||||
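The send-text handling above temporarily flips the LLM's `skip_tts` setting only when the requested mode differs from the current one, and restores it afterwards. A small sketch of that ordering, assuming frames can be represented as simple tuples (`frames_for_send_text` and the tuple labels are hypothetical, not pipecat APIs):

```python
from typing import List, Tuple, Union

FrameSketch = Tuple[str, Union[str, bool]]


def frames_for_send_text(
    text: str, audio_response: bool, cur_llm_skip_tts: bool
) -> List[FrameSketch]:
    """Return the frames, in push order, for one send-text request."""
    should_skip_tts = not audio_response
    toggle_skip_tts = cur_llm_skip_tts != should_skip_tts

    frames: List[FrameSketch] = []
    if toggle_skip_tts:
        # Switch the output mode just for this message.
        frames.append(("configure-output", should_skip_tts))
    frames.append(("text", text))
    if toggle_skip_tts:
        # Restore whatever the LLM was configured with before.
        frames.append(("configure-output", cur_llm_skip_tts))
    return frames
```

When the current setting already matches the request, only the text frame is pushed, so no redundant configuration frames travel through the pipeline.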
@@ -21,6 +21,7 @@ from pipecat import __version__ as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -205,8 +206,9 @@ class AssemblyAISTTService(STTService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to AssemblyAI: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
self._connected = False
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -231,7 +233,8 @@ class AssemblyAISTTService(STTService):
|
||||
logger.warning("Timed out waiting for termination message from server")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during termination handshake: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
@@ -239,7 +242,8 @@ class AssemblyAISTTService(STTService):
|
||||
await self._websocket.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
finally:
|
||||
self._websocket = None
|
||||
@@ -258,11 +262,13 @@ class AssemblyAISTTService(STTService):
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing WebSocket message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in receive handler: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
|
||||
"""Parse a raw message into the appropriate message type."""
|
||||
@@ -291,7 +297,8 @@ class AssemblyAISTTService(STTService):
|
||||
elif isinstance(parsed_message, TerminationMessage):
|
||||
await self._handle_termination(parsed_message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _handle_termination(self, message: TerminationMessage):
|
||||
"""Handle termination message."""
|
||||
|
||||
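The service changes in this part of the diff converge on one error-reporting shape: log `"{self} exception: {e}"`, then push an `ErrorFrame` built with the `error=` keyword. The sketch below shows that shape with stand-ins so it runs on its own; `ErrorFrame`, `DemoService`, and `_open_connection` here are simplified hypothetical versions, and the standard `logging` module replaces loguru.

```python
import asyncio
import logging
from dataclasses import dataclass
from typing import List

logger = logging.getLogger("error-pattern-sketch")


@dataclass
class ErrorFrame:
    """Simplified stand-in for the pipeline error frame."""

    error: str


class DemoService:
    """Hypothetical service that records pushed errors instead of sending them."""

    def __init__(self, name: str) -> None:
        self._name = name
        self.pushed: List[ErrorFrame] = []

    def __str__(self) -> str:
        return self._name

    async def push_error(self, frame: ErrorFrame) -> None:
        self.pushed.append(frame)

    async def _open_connection(self) -> None:
        # Simulated failure so the except branch is exercised.
        raise ConnectionError("refused")

    async def connect(self) -> bool:
        try:
            await self._open_connection()
            return True
        except Exception as e:
            # Same two steps at every call site: log, then surface the error.
            logger.error(f"{self} exception: {e}")
            await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
            return False
```

Keeping the message format identical across services means downstream consumers can rely on a single pattern when inspecting pushed errors.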
@@ -237,7 +237,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -249,7 +250,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Async")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._started = False
|
||||
@@ -297,7 +299,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['message']}"))
|
||||
else:
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
|
||||
@@ -342,7 +344,8 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -350,6 +353,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class AsyncAIHttpTTSService(TTSService):
|
||||
@@ -492,7 +496,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Async API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Async API error: {error_text}"))
|
||||
await self.push_error(ErrorFrame(error=f"Async API error: {error_text}"))
|
||||
raise Exception(f"Async API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
@@ -509,7 +513,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter, Role
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -1027,7 +1028,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
logger.debug(f"Assistant response text added: {text}")
|
||||
|
||||
# Report the text of the assistant response.
|
||||
frame = TTSTextFrame(text)
|
||||
frame = TTSTextFrame(text, aggregated_by=AggregationType.SENTENCE)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -1062,7 +1063,9 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# TTSTextFrame would be ignored otherwise (the interruption frame
|
||||
# would have cleared the assistant aggregator state).
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
frame = TTSTextFrame(self._assistant_text_buffer)
|
||||
frame = TTSTextFrame(
|
||||
self._assistant_text_buffer, aggregated_by=AggregationType.SENTENCE
|
||||
)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
self._may_need_repush_assistant_text = False
|
||||
|
||||
@@ -140,7 +140,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
return
|
||||
logger.warning("WebSocket connection not established after connect")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
await asyncio.sleep(1) # Wait before retrying
|
||||
@@ -181,8 +182,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
try:
|
||||
await self._connect()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reconnect: {e}")
|
||||
yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
return
|
||||
|
||||
# Format the audio data according to AWS event stream format
|
||||
@@ -199,13 +200,13 @@ class AWSTranscribeSTTService(STTService):
|
||||
await self._disconnect()
|
||||
# Don't yield error here - we'll retry on next frame
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio: {e}")
|
||||
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_stt: {e}")
|
||||
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
@@ -288,7 +289,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
await self._disconnect()
|
||||
raise
|
||||
|
||||
@@ -308,7 +310,8 @@ class AWSTranscribeSTTService(STTService):
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
await self._ws_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} Error closing WebSocket connection: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._ws_client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -527,9 +530,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
elif headers.get(":message-type") == "exception":
|
||||
error_msg = payload.get("Message", "Unknown error")
|
||||
logger.error(f"{self} Exception from AWS: {error_msg}")
|
||||
await self.push_frame(
|
||||
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
|
||||
)
|
||||
await self.push_frame(ErrorFrame(f"AWS Transcribe error: {error_msg}"))
|
||||
else:
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
@@ -537,5 +538,6 @@ class AWSTranscribeSTTService(STTService):
|
||||
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Unexpected error in receive loop: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
break
|
||||
|
||||
@@ -18,6 +18,7 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -111,13 +112,17 @@ class AzureSTTService(STTService):
|
||||
audio: Raw audio bytes to process.
|
||||
|
||||
Yields:
|
||||
None - actual transcription frames are pushed via callbacks.
|
||||
Frame: Either None for successful processing or ErrorFrame on failure.
|
||||
"""
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.write(audio)
|
||||
yield None
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.write(audio)
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech recognition service.
|
||||
@@ -133,17 +138,21 @@ class AzureSTTService(STTService):
|
||||
if self._audio_stream:
|
||||
return
|
||||
|
||||
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
|
||||
self._audio_stream = PushAudioInputStream(stream_format)
|
||||
try:
|
||||
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
|
||||
self._audio_stream = PushAudioInputStream(stream_format)
|
||||
|
||||
audio_config = AudioConfig(stream=self._audio_stream)
|
||||
audio_config = AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._speech_recognizer = SpeechRecognizer(
|
||||
speech_config=self._speech_config, audio_config=audio_config
|
||||
)
|
||||
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
self._speech_recognizer = SpeechRecognizer(
|
||||
speech_config=self._speech_config, audio_config=audio_config
|
||||
)
|
||||
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception during initialization: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the speech recognition service.
|
||||
|
||||
@@ -337,7 +337,7 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error_msg)
|
||||
yield ErrorFrame(error=error_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -364,13 +364,15 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error during synthesis: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class AzureHttpTTSService(AzureBaseTTSService):
|
||||
@@ -448,3 +450,4 @@ class AzureHttpTTSService(AzureBaseTTSService):
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
if cancellation_details.reason == CancellationReason.Error:
|
||||
logger.error(f"{self} error: {cancellation_details.error_details}")
|
||||
yield ErrorFrame(error=f"{self} error: {cancellation_details.error_details}")
|
||||
|
||||
@@ -20,6 +20,7 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -275,7 +276,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self}: unable to connect to Cartesia: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
@@ -284,6 +286,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -315,7 +318,9 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
await self._on_transcript(data)
|
||||
|
||||
elif data["type"] == "error":
|
||||
logger.error(f"Cartesia error: {data.get('message', 'Unknown error')}")
|
||||
error_msg = data.get("message", "Unknown error")
|
||||
logger.error(f"Cartesia error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(error=error_msg))
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
|
||||
@@ -10,7 +10,8 @@ import base64
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import AsyncGenerator, List, Literal, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -125,6 +126,72 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
     return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
 
 
+class CartesiaEmotion(str, Enum):
+    """Predefined Emotions supported by Cartesia."""
+
+    # Primary emotions supported by Cartesia
+    NEUTRAL = "neutral"
+    ANGRY = "angry"
+    EXCITED = "excited"
+    CONTENT = "content"
+    SAD = "sad"
+    SCARED = "scared"
+    # Additional emotions supported by Cartesia
+    HAPPY = "happy"
+    ENTHUSIASTIC = "enthusiastic"
+    ELATED = "elated"
+    EUPHORIC = "euphoric"
+    TRIUMPHANT = "triumphant"
+    AMAZED = "amazed"
+    SURPRISED = "surprised"
+    FLIRTATIOUS = "flirtatious"
+    JOKING_COMEDIC = "joking/comedic"
+    CURIOUS = "curious"
+    PEACEFUL = "peaceful"
+    SERENE = "serene"
+    CALM = "calm"
+    GRATEFUL = "grateful"
+    AFFECTIONATE = "affectionate"
+    TRUST = "trust"
+    SYMPATHETIC = "sympathetic"
+    ANTICIPATION = "anticipation"
+    MYSTERIOUS = "mysterious"
+    MAD = "mad"
+    OUTRAGED = "outraged"
+    FRUSTRATED = "frustrated"
+    AGITATED = "agitated"
+    THREATENED = "threatened"
+    DISGUSTED = "disgusted"
+    CONTEMPT = "contempt"
+    ENVIOUS = "envious"
+    SARCASTIC = "sarcastic"
+    IRONIC = "ironic"
+    DEJECTED = "dejected"
+    MELANCHOLIC = "melancholic"
+    DISAPPOINTED = "disappointed"
+    HURT = "hurt"
+    GUILTY = "guilty"
+    BORED = "bored"
+    TIRED = "tired"
+    REJECTED = "rejected"
+    NOSTALGIC = "nostalgic"
+    WISTFUL = "wistful"
+    APOLOGETIC = "apologetic"
+    HESITANT = "hesitant"
+    INSECURE = "insecure"
+    CONFUSED = "confused"
+    RESIGNED = "resigned"
+    ANXIOUS = "anxious"
+    PANICKED = "panicked"
+    ALARMED = "alarmed"
+    PROUD = "proud"
+    CONFIDENT = "confident"
+    DISTANT = "distant"
+    SKEPTICAL = "skeptical"
+    CONTEMPLATIVE = "contemplative"
+    DETERMINED = "determined"
+
+
 class CartesiaTTSService(AudioContextWordTTSService):
     """Cartesia TTS service with WebSocket streaming and word timestamps.
 
@@ -182,6 +249,10 @@ class CartesiaTTSService(AudioContextWordTTSService):
             container: Audio container format.
             params: Additional input parameters for voice customization.
             text_aggregator: Custom text aggregator for processing input text.
+
+                .. deprecated:: 0.0.95
+                    Use an LLMTextProcessor before the TTSService for custom text aggregation.
+
             aggregate_sentences: Whether to aggregate sentences within the TTSService.
             **kwargs: Additional arguments passed to the parent service.
         """
@@ -200,10 +271,18 @@ class CartesiaTTSService(AudioContextWordTTSService):
             push_text_frames=False,
             pause_frame_processing=True,
             sample_rate=sample_rate,
-            text_aggregator=text_aggregator or SkipTagsAggregator([("<spell>", "</spell>")]),
+            text_aggregator=text_aggregator,
             **kwargs,
         )
 
+        if not text_aggregator:
+            # Always skip tags added for spelled-out text
+            # Note: This is primarily to support backwards compatibility.
+            # The preferred way of taking advantage of Cartesia SSML Tags is
+            # to use an LLMTextProcessor and/or a text_transformer to identify
+            # and insert these tags for the purpose of the TTS service alone.
+            self._text_aggregator = SkipTagsAggregator([("<spell>", "</spell>")])
+
         params = params or CartesiaTTSService.InputParams()
 
         self._api_key = api_key
@@ -257,6 +336,27 @@ class CartesiaTTSService(AudioContextWordTTSService):
         """
         return language_to_cartesia_language(language)
 
+    # A set of Cartesia-specific helpers for text transformations
+    def SPELL(text: str) -> str:
+        """Wrap text in Cartesia spell tag."""
+        return f"<spell>{text}</spell>"
+
+    def EMOTION_TAG(emotion: CartesiaEmotion) -> str:
+        """Convenience method to create an emotion tag."""
+        return f'<emotion value="{emotion}" />'
+
+    def PAUSE_TAG(seconds: float) -> str:
+        """Convenience method to create a pause tag."""
+        return f'<break time="{seconds}s" />'
+
+    def VOLUME_TAG(volume: float) -> str:
+        """Convenience method to create a volume tag."""
+        return f'<volume ratio="{volume}" />'
+
+    def SPEED_TAG(speed: float) -> str:
+        """Convenience method to create a speed tag."""
+        return f'<speed ratio="{speed}" />'
+
     def _is_cjk_language(self, language: str) -> bool:
         """Check if the given language is CJK (Chinese, Japanese, Korean).
 
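For illustration only, the Cartesia tag helpers added in the hunk above can be tried outside the service class. What follows is a minimal standalone sketch, not pipecat code: the `Emotion` enum is a trimmed, hypothetical stand-in for `CartesiaEmotion`, and the lowercase helper names and `build_utterance` are invented for this example. It reads `emotion.value` explicitly, on the assumption that formatting a `(str, Enum)` member directly in an f-string may render the member name rather than its value on recent Python versions.

```python
from enum import Enum


class Emotion(str, Enum):
    """Trimmed, hypothetical stand-in for CartesiaEmotion."""

    NEUTRAL = "neutral"
    EXCITED = "excited"
    JOKING_COMEDIC = "joking/comedic"


def spell(text: str) -> str:
    """Wrap text in a spell tag, mirroring SPELL in the diff."""
    return f"<spell>{text}</spell>"


def emotion_tag(emotion: Emotion) -> str:
    """Build an emotion tag; .value keeps the output stable across Python versions."""
    return f'<emotion value="{emotion.value}" />'


def pause_tag(seconds: float) -> str:
    """Build a pause tag, mirroring PAUSE_TAG in the diff."""
    return f'<break time="{seconds}s" />'


def build_utterance() -> str:
    """Compose one tagged string that a TTS request could carry."""
    return (
        emotion_tag(Emotion.EXCITED)
        + "Your code is "
        + spell("A1B2")
        + pause_tag(0.5)
        + " Thanks!"
    )
```

Calling `build_utterance()` returns a single string with the emotion, spell, and break tags inline, which is the form the `SkipTagsAggregator` fallback in the constructor is meant to leave intact for `<spell>` spans.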
@@ -397,7 +497,8 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -409,7 +510,8 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
logger.debug("Disconnecting from Cartesia")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
@@ -465,7 +567,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['error']}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
|
||||
self._context_id = None
|
||||
else:
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
@@ -506,7 +608,8 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -514,6 +617,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class CartesiaHttpTTSService(TTSService):
|
||||
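The hunks above converge on one error-reporting shape: log `"{self} exception: {e}"`, then emit an `ErrorFrame` built with the `error=` keyword. A rough sketch of that shape in isolation follows. The `ErrorFrame` dataclass, `FakeTTSService`, and `collect` are hypothetical stand-ins written for this example (the real classes live in pipecat), and the standard `logging` module stands in for loguru.

```python
import asyncio
import logging
from dataclasses import dataclass
from typing import AsyncGenerator, List, Optional

logger = logging.getLogger("tts-sketch")


@dataclass
class ErrorFrame:
    """Hypothetical stand-in for pipecat's ErrorFrame."""

    error: str
    fatal: bool = False


class FakeTTSService:
    """Hypothetical service used only to show the logging/ErrorFrame pattern."""

    def __str__(self) -> str:
        return "FakeTTSService#0"

    async def _send(self, text: str) -> None:
        # Simulate a transport failure for empty input.
        if not text:
            raise ValueError("empty text")

    async def run_tts(self, text: str) -> AsyncGenerator[Optional[ErrorFrame], None]:
        try:
            await self._send(text)
            yield None
        except Exception as e:
            # Same message shape as the diff: one log line, one ErrorFrame.
            logger.error(f"{self} exception: {e}")
            yield ErrorFrame(error=f"{self} error: {e}")


async def collect(text: str) -> List[Optional[ErrorFrame]]:
    """Drain the generator so the emitted frames can be inspected."""
    return [frame async for frame in FakeTTSService().run_tts(text)]
```

Passing `error=` by keyword, as the diff does throughout, keeps call sites unambiguous if the frame's positional fields ever change.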
@@ -705,7 +809,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Cartesia API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Cartesia API error: {error_text}"))
|
||||
await self.push_error(ErrorFrame(error=f"Cartesia API error: {error_text}"))
|
||||
raise Exception(f"Cartesia API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
@@ -722,7 +826,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -191,7 +191,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# Reset state only after everything is cleaned up
|
||||
self._websocket = None
|
||||
@@ -214,7 +215,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
logger.debug("Connected to Deepgram Flux Websocket")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -233,6 +235,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -333,14 +336,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
"""
|
||||
if not self._websocket:
|
||||
logger.error("Not connected to Deepgram Flux.")
|
||||
yield ErrorFrame("Not connected to Deepgram Flux.", fatal=True)
|
||||
yield ErrorFrame("Not connected to Deepgram Flux.")
|
||||
return
|
||||
|
||||
try:
|
||||
await self._websocket.send(audio)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send audio to Flux: {e}")
|
||||
yield ErrorFrame(f"Failed to send audio to Flux: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
return
|
||||
|
||||
yield None
|
||||
@@ -417,7 +420,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
# Skip malformed messages
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
# Error will be handled inside WebsocketService->_receive_task_handler
|
||||
raise
|
||||
else:
|
||||
|
||||
@@ -256,7 +256,7 @@ class DeepgramSTTService(STTService):
|
||||
async def _on_error(self, *args, **kwargs):
|
||||
error: ErrorResponse = kwargs["error"]
|
||||
logger.warning(f"{self} connection error, will retry: {error}")
|
||||
await self.push_error(ErrorFrame(f"{error}"))
|
||||
await self.push_error(ErrorFrame(error=f"{error}"))
|
||||
await self.stop_all_metrics()
|
||||
# NOTE(aleix): we don't disconnect (i.e. call finish on the connection)
|
||||
# because this triggers more errors internally in the Deepgram SDK. So,
|
||||
|
||||
@@ -125,8 +125,8 @@ class DeepgramTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class DeepgramHttpTTSService(TTSService):
|
||||
|
||||
@@ -351,8 +351,8 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ElevenLabs STT error: {e}")
|
||||
yield ErrorFrame(f"ElevenLabs STT error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
def audio_format_from_sample_rate(sample_rate: int) -> str:
|
||||
|
||||
@@ -424,7 +424,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing context for voice settings update: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
|
||||
@@ -535,8 +536,9 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
self._websocket = None
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
@@ -551,7 +553,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from ElevenLabs")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
@@ -581,7 +584,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing context on interruption: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
self._partial_word = ""
|
||||
@@ -736,13 +740,15 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
else:
|
||||
await self._send_text(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
self._started = False
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class ElevenLabsHttpTTSService(WordTTSService):
|
||||
@@ -1085,7 +1091,8 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
logger.warning(f"Failed to parse JSON from stream: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing response: {e}", exc_info=True)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
continue
|
||||
|
||||
# After processing all chunks, emit any remaining partial word
|
||||
@@ -1109,8 +1116,8 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
self._previous_text = text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
# Let the parent class handle TTSStoppedFrame
|
||||
|
||||
@@ -290,5 +290,5 @@ class FalSTTService(SegmentedSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fal Wizper error: {e}")
|
||||
yield ErrorFrame(f"Fal Wizper error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -237,7 +237,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"Fish Audio initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -251,7 +252,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
await self._websocket.send(ormsgpack.packb(stop_message))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing websocket: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._started = False
|
||||
@@ -293,7 +295,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -329,7 +332,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
flush_message = {"event": "flush"}
|
||||
await self._get_websocket().send(ormsgpack.packb(flush_message))
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -337,5 +341,5 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(f"Error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat import __version__ as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -467,7 +468,8 @@ class GladiaSTTService(STTService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in connection handler: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._connection_active = False
|
||||
|
||||
if not self._should_reconnect:
|
||||
@@ -557,7 +559,8 @@ class GladiaSTTService(STTService):
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.debug("Connection closed during keepalive")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gladia keepalive task: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
@@ -620,7 +623,8 @@ class GladiaSTTService(STTService):
|
||||
# Expected when closing the connection
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gladia WebSocket handler: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _maybe_reconnect(self) -> bool:
|
||||
"""Handle exponential backoff reconnection logic."""
|
||||
|
||||
@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -1174,7 +1175,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._connection_task = self.create_task(self._connection_task_handler(config=config))
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(ErrorFrame(error=f"{self} Initialization error: {e}", fatal=True))
|
||||
await self.push_error(ErrorFrame(error=f"{self} Initialization error: {e}"))
|
||||
|
||||
async def _connection_task_handler(self, config: LiveConnectConfig):
|
||||
async with self._client.aio.live.connect(model=self._model_name, config=config) as session:
|
||||
@@ -1255,9 +1256,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
f"Max consecutive failures ({MAX_CONSECUTIVE_FAILURES}) reached, "
|
||||
"treating as fatal error"
|
||||
)
|
||||
await self.push_error(
|
||||
ErrorFrame(error=f"{self} Error in receive loop: {error}", fatal=True)
|
||||
)
|
||||
await self.push_error(ErrorFrame(error=f"{self} Error in receive loop: {error}"))
|
||||
return False
|
||||
else:
|
||||
logger.info(
|
||||
@@ -1648,7 +1647,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
frame = TTSTextFrame(text=text)
|
||||
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.SENTENCE)
|
||||
# Gemini Live text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
|
||||
|
||||
@@ -774,7 +774,8 @@ class GoogleSTTService(STTService):
|
||||
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in request generator: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
raise
|
||||
|
||||
async def _stream_audio(self):
|
||||
@@ -805,14 +806,15 @@ class GoogleSTTService(STTService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} Reconnecting: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
await asyncio.sleep(1) # Brief delay before reconnecting
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming task: {e}")
|
||||
await self.push_frame(ErrorFrame(str(e)))
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process an audio chunk for STT transcription.
|
||||
@@ -900,7 +902,8 @@ class GoogleSTTService(STTService):
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Google STT responses: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
# Re-raise the exception to let it propagate (e.g. in the case of a
|
||||
# timeout, propagate to _stream_audio to reconnect)
|
||||
raise
|
||||
|
||||
@@ -746,7 +746,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -1014,7 +1014,7 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -1266,6 +1266,6 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
error_message = f"Gemini TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -13,7 +13,13 @@ from typing import AsyncGenerator, Optional
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
@@ -150,5 +156,6 @@ class GroqTTSService(TTSService):
|
||||
yield TTSAudioRawFrame(bytes, frame_rate, channels)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -225,8 +225,8 @@ class HumeTTSService(TTSService):
|
||||
self._audio_bytes = b""
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# Ensure TTFB timer is stopped even on early failures
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -374,7 +374,7 @@ class InworldTTSService(TTSService):
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Inworld API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Inworld API error: {error_text}"))
|
||||
yield ErrorFrame(error=f"Inworld API error: {error_text}")
|
||||
return
|
||||
|
||||
# ================================================================================
|
||||
@@ -402,7 +402,7 @@ class InworldTTSService(TTSService):
|
||||
# ================================================================================
|
||||
# Log any unexpected errors and notify the pipeline
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# ================================================================================
|
||||
# STEP 8: CLEANUP AND COMPLETION
|
||||
@@ -517,7 +517,7 @@ class InworldTTSService(TTSService):
|
||||
# Extract the base64-encoded audio content from response
|
||||
if "audioContent" not in response_data:
|
||||
logger.error("No audioContent in Inworld API response")
|
||||
await self.push_error(ErrorFrame("No audioContent in response"))
|
||||
await self.push_error(ErrorFrame(error="No audioContent in response"))
|
||||
return
|
||||
|
||||
# ================================================================================
|
||||
|
||||
@@ -223,7 +223,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -239,7 +240,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
# await self._websocket.send(json.dumps({"eof": True}))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -276,7 +278,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
logger.error(f"{self} error: {msg['error']}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['error']}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
@@ -309,7 +311,8 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(json.dumps({"flush": True}))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -317,3 +320,4 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -347,8 +347,8 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"MiniMax TTS error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -294,7 +294,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -307,7 +308,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Neuphonic")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -372,7 +374,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -380,6 +383,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class NeuphonicHttpTTSService(TTSService):
|
||||
@@ -582,7 +586,8 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
yield TTSAudioRawFrame(audio_bytes, self.sample_rate, 1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SSE message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
# Don't yield error frame for individual message failures
|
||||
continue
|
||||
|
||||
@@ -590,8 +595,8 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
logger.debug("TTS generation cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=f"Neuphonic TTS error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -19,6 +19,7 @@ from pipecat.adapters.services.open_ai_realtime_adapter import (
|
||||
OpenAIRealtimeLLMAdapter,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -478,7 +479,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}", fatal=True))
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}"))
|
||||
|
||||
async def _update_settings(self):
|
||||
settings = self._session_properties
|
||||
@@ -667,9 +668,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
self._current_assistant_response = None
|
||||
# error handling
|
||||
if evt.response.status == "failed":
|
||||
await self.push_error(
|
||||
ErrorFrame(error=evt.response.status_details["error"]["message"], fatal=True)
|
||||
)
|
||||
await self.push_error(ErrorFrame(error=evt.response.status_details["error"]["message"]))
|
||||
return
|
||||
# response content
|
||||
for item in evt.response.output:
|
||||
@@ -688,7 +687,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# We receive audio transcript deltas (as opposed to text deltas) when
|
||||
# the output modality is "audio" (the default)
|
||||
if evt.delta:
|
||||
frame = TTSTextFrame(evt.delta)
|
||||
frame = TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE)
|
||||
# OpenAI Realtime text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
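The `aggregated_by=AggregationType.SENTENCE` argument ties into the changelog entry in this compare view, which describes `TTSTextFrame` inheriting from `AggregatedTextFrame` so that an observer can tell spoken text from unspoken text. A small sketch of that idea follows, under stated assumptions: the classes below are simplified stand-ins rather than pipecat's definitions, `SENTENCE` is the only enum member visible in the diff (its string value is assumed), and `summarize` and `join_text` are hypothetical helpers.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Iterable, List, Tuple


class AggregationType(str, Enum):
    """Stand-in enum; only SENTENCE appears in the diff, its value is assumed."""

    SENTENCE = "sentence"


@dataclass
class AggregatedTextFrame:
    """Simplified stand-in: text plus the format it was aggregated by."""

    text: str
    aggregated_by: AggregationType
    includes_inter_frame_spaces: bool = False


@dataclass
class TTSTextFrame(AggregatedTextFrame):
    """Simplified stand-in: an aggregated frame that was also spoken."""


def summarize(frames: Iterable[AggregatedTextFrame]) -> List[Tuple[str, bool]]:
    """Pair each frame's text with whether it was spoken (is a TTSTextFrame)."""
    return [(f.text, isinstance(f, TTSTextFrame)) for f in frames]


def join_text(frames: Iterable[AggregatedTextFrame]) -> str:
    """Concatenate text, adding a space only when the frame does not carry its own."""
    out = ""
    for f in frames:
        if out and not f.includes_inter_frame_spaces:
            out += " "
        out += f.text
    return out
```

Here the `isinstance` check plays the role the changelog assigns to the inheritance relationship, and `includes_inter_frame_spaces` is set after construction in the same way the diff sets it on realtime transcript deltas.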
@@ -763,7 +762,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_error(self, evt):
|
||||
# Errors are fatal to this connection. Send an ErrorFrame.
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}", fatal=True))
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}"))
|
||||
|
||||
#
|
||||
# state and client events for the current conversation
|
||||
|
||||
@@ -199,7 +199,7 @@ class OpenAITTSService(TTSService):
|
||||
f"{self} error getting audio (status: {r.status_code}, error: {error})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {r.status_code}, error: {error})"
|
||||
error=f"Error getting audio (status: {r.status_code}, error: {error})"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -216,3 +216,4 @@ class OpenAITTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
except BadRequestError as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -17,6 +17,7 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -454,7 +455,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}", fatal=True))
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}"))
|
||||
|
||||
async def _update_settings(self):
|
||||
settings = self._session_properties
|
||||
@@ -627,9 +628,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
self._current_assistant_response = None
|
||||
# error handling
|
||||
if evt.response.status == "failed":
|
||||
await self.push_error(
|
||||
ErrorFrame(error=evt.response.status_details["error"]["message"], fatal=True)
|
||||
)
|
||||
await self.push_error(ErrorFrame(error=evt.response.status_details["error"]["message"]))
|
||||
return
|
||||
# response content
|
||||
for item in evt.response.output:
|
||||
@@ -654,7 +653,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(LLMTextFrame(evt.delta))
|
||||
await self.push_frame(TTSTextFrame(evt.delta))
|
||||
await self.push_frame(TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE))
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
@@ -687,7 +686,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_error(self, evt):
|
||||
# Errors are fatal to this connection. Send an ErrorFrame.
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}", fatal=True))
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}"))
|
||||
|
||||
async def _handle_assistant_output(self, output):
|
||||
# We haven't seen intermixed audio and function_call items in the same response. But let's
|
||||
|
||||
@@ -101,7 +101,7 @@ class PiperTTSService(TTSService):
|
||||
f"{self} error getting audio (status: {response.status}, error: {error})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {response.status}, error: {error})"
|
||||
error=f"Error getting audio (status: {response.status}, error: {error})"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -117,8 +117,8 @@ class PiperTTSService(TTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
logger.debug(f"{self}: Finished TTS [{text}]")
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -266,7 +266,8 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -279,7 +280,8 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from PlayHT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._websocket = None
|
||||
@@ -350,7 +352,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
self._request_id = None
|
||||
elif "error" in msg:
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['error']}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['error']}"))
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
|
||||
@@ -392,7 +394,8 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(json.dumps(tts_command))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -402,8 +405,8 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
yield ErrorFrame(f"{self} error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class PlayHTHttpTTSService(TTSService):
|
||||
@@ -623,7 +626,8 @@ class PlayHTHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
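The hunks above converge on one error-reporting shape: log the exception, emit an `ErrorFrame` built with the `error=` keyword, and still emit a stop frame from `finally`. A minimal sketch of that shape follows. `ErrorFrame`, `StoppedFrame`, `run_tts`, and `collect` here are hypothetical stand-ins written for illustration, not the pipecat classes.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, List, Union


@dataclass
class ErrorFrame:
    """Hypothetical stand-in: an error message plus a non-fatal default."""

    error: str
    fatal: bool = False


@dataclass
class StoppedFrame:
    """Hypothetical stand-in for a 'TTS stopped' marker frame."""


Frame = Union[bytes, ErrorFrame, StoppedFrame]


async def run_tts(text: str, fail: bool = False) -> AsyncGenerator[Frame, None]:
    """Yield audio, or an ErrorFrame on failure, and always a StoppedFrame."""
    try:
        if fail:
            raise RuntimeError("boom")
        yield text.encode()
    except Exception as e:
        # Keyword form, matching the `ErrorFrame(error=...)` calls in the diff.
        yield ErrorFrame(error=f"FakeTTS error: {e}")
    finally:
        yield StoppedFrame()


async def collect(gen: AsyncGenerator[Frame, None]) -> List[Frame]:
    """Drain an async generator into a list."""
    return [frame async for frame in gen]
```

Yielding from `finally` only behaves well when the consumer drains the generator fully, as `collect` does here; closing the generator early would raise instead.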
|
||||
@@ -113,6 +113,10 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
params: Additional configuration parameters.
|
||||
text_aggregator: Custom text aggregator for processing input text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
@@ -123,10 +127,17 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
text_aggregator=text_aggregator or SkipTagsAggregator([("spell(", ")")]),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not text_aggregator:
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Rime spelling is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
|
||||
|
||||
params = params or RimeTTSService.InputParams()
|
||||
|
||||
# Store service configuration
|
||||
@@ -152,6 +163,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
self._context_id = None # Tracks current turn
|
||||
self._receive_task = None
|
||||
self._cumulative_time = 0 # Accumulates time across messages
|
||||
self._extra_msg_fields = {} # Extra fields for next message
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -181,6 +193,31 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
self._model = model
|
||||
await super().set_model(model)
|
||||
|
||||
# A set of Rime-specific helpers for text transformations
|
||||
def SPELL(text: str) -> str:
|
||||
"""Wrap text in Rime spell function."""
|
||||
return f"spell({text})"
|
||||
|
||||
def PAUSE_TAG(seconds: float) -> str:
|
||||
"""Convenience method to create a pause tag."""
|
||||
return f"<{seconds * 1000}>"
|
||||
|
||||
def PRONOUNCE(self, text: str, word: str, phoneme: str) -> str:
|
||||
"""Convenience method to support Rime's custom pronunciations feature.
|
||||
|
||||
https://docs.rime.ai/api-reference/custom-pronunciation
|
||||
"""
|
||||
self._extra_msg_fields["phonemizeBetweenBrackets"] = True
|
||||
return text.replace(word, f"{phoneme}")
|
||||
|
||||
def INLINE_SPEED(self, text: str, speed: float) -> str:
|
||||
"""Convenience method to support inline speeds."""
|
||||
if not self._extra_msg_fields:
|
||||
self._extra_msg_fields = {}
|
||||
speed_vals = self._extra_msg_fields.get("inlineSpeedAlpha", "").split(",")
|
||||
self._extra_msg_fields["inlineSpeedAlpha"] = ",".join(speed_vals + [str(speed)])
|
||||
return f"[{text}]"
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
prev_voice = self._voice_id
|
||||
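The Rime helpers added in this hunk are small string builders, so their behavior can be checked in isolation. The sketch below re-creates them as free functions with hypothetical lowercase names (`spell`, `pause_tag`, `add_inline_speed`) and an explicit `extra_fields` dict in place of `self._extra_msg_fields`. One assumption is labeled in the code: empty entries are filtered before joining, because `"".split(",")` returns `[""]` in Python and would otherwise put a leading comma on the first recorded speed.

```python
from typing import Dict


def spell(text: str) -> str:
    """Wrap text in a spell() tag, mirroring SPELL in the diff."""
    return f"spell({text})"


def pause_tag(seconds: float) -> str:
    """Build a pause tag in milliseconds, mirroring PAUSE_TAG's f-string."""
    return f"<{seconds * 1000}>"


def add_inline_speed(extra_fields: Dict[str, str], text: str, speed: float) -> str:
    """Record a speed for the next message and bracket the text.

    Assumption (not in the diff): empty entries are dropped before joining so
    the first call does not produce a leading comma.
    """
    existing = [v for v in extra_fields.get("inlineSpeedAlpha", "").split(",") if v]
    extra_fields["inlineSpeedAlpha"] = ",".join(existing + [str(speed)])
    return f"[{text}]"
```

Note that `pause_tag(0.5)` produces `<500.0>` because the multiplication yields a float; the diff does not say whether the service expects an integer millisecond value.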
@@ -193,7 +230,11 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
"""Build JSON message for Rime API."""
|
||||
return {"text": text, "contextId": self._context_id}
|
||||
msg = {"text": text, "contextId": self._context_id}
|
||||
if self._extra_msg_fields:
|
||||
msg |= self._extra_msg_fields
|
||||
self._extra_msg_fields = {}
|
||||
return msg
|
||||
|
||||
def _build_clear_msg(self) -> dict:
|
||||
"""Build clear operation message."""
|
||||
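The reworked `_build_msg` treats the extra fields as one-shot: they are merged into the next outgoing message and then cleared. A minimal sketch of that behavior, using a hypothetical `MessageBuilder` class rather than the service itself:

```python
from typing import Any, Dict, Optional


class MessageBuilder:
    """Hypothetical stand-in for the service's message-building state."""

    def __init__(self, context_id: Optional[str] = None) -> None:
        self.context_id = context_id
        self.extra_msg_fields: Dict[str, Any] = {}

    def build_msg(self, text: str = "") -> Dict[str, Any]:
        """Build a message, consuming any pending extra fields."""
        msg: Dict[str, Any] = {"text": text, "contextId": self.context_id}
        if self.extra_msg_fields:
            msg |= self.extra_msg_fields  # in-place dict merge, Python 3.9+
            self.extra_msg_fields = {}  # extras apply to one message only
        return msg
```

Because the extras are cleared after the first build, a helper that sets a field must run before the message it is meant to affect is built.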
@@ -259,7 +300,8 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -271,7 +313,8 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.send(json.dumps(self._build_eos_msg()))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
@@ -367,7 +410,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {msg['message']}"))
|
||||
self._context_id = None
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -409,7 +452,8 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -417,6 +461,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class RimeHttpTTSService(TTSService):
|
||||
@@ -574,8 +619,8 @@ class RimeHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"Rime TTS error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -659,8 +659,8 @@ class RivaSegmentedSTTService(SegmentedSTTService):
|
||||
yield ErrorFrame(f"Unexpected Riva response format: {str(ae)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Riva Canary ASR error: {e}")
|
||||
yield ErrorFrame(f"Riva Canary ASR error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class ParakeetSTTService(RivaSTTService):
|
||||
|
||||
@@ -23,6 +23,7 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -165,6 +166,7 @@ class RivaTTSService(TTSService):
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -264,7 +264,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Sarvam API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Sarvam API error: {error_text}"))
|
||||
await self.push_error(ErrorFrame(error=f"Sarvam API error: {error_text}"))
|
||||
return
|
||||
|
||||
response_data = await response.json()
|
||||
@@ -274,7 +274,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
# Decode base64 audio data
|
||||
if "audios" not in response_data or not response_data["audios"]:
|
||||
logger.error("No audio data received from Sarvam API")
|
||||
await self.push_error(ErrorFrame("No audio data received"))
|
||||
await self.push_error(ErrorFrame(error="No audio data received"))
|
||||
return
|
||||
|
||||
# Get the first audio (there should be only one for single text input)
|
||||
@@ -296,7 +296,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
@@ -578,7 +578,8 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
# Reset state only after everything is cleaned up
|
||||
self._started = False
|
||||
@@ -602,7 +603,8 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -618,8 +620,8 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._websocket.send(json.dumps(config_message))
|
||||
logger.debug("Configuration sent successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send config: {str(e)}")
|
||||
await self.push_frame(ErrorFrame(f"Failed to send config: {str(e)}"))
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
@@ -632,6 +634,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -661,7 +664,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
if "too long" in error_msg.lower() or "timeout" in error_msg.lower():
|
||||
logger.warning("Connection timeout detected, service may need restart")
|
||||
|
||||
await self.push_frame(ErrorFrame(f"TTS Error: {error_msg}"))
|
||||
await self.push_frame(ErrorFrame(error=f"TTS Error: {error_msg}"))
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Handle keepalive messages to maintain WebSocket connection."""
|
||||
@@ -717,7 +720,8 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -725,3 +729,4 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -327,8 +327,8 @@ class SonioxSTTService(STTService):
|
||||
# Expected when closing the connection
|
||||
logger.debug("WebSocket connection closed, keepalive task stopped.")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error (_keepalive_task_handler): {e}")
|
||||
await self.push_error(ErrorFrame(f"{self} error (_keepalive_task_handler): {e}"))
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
if not self._websocket:
|
||||
@@ -409,7 +409,7 @@ class SonioxSTTService(STTService):
|
||||
)
|
||||
await self.push_error(
|
||||
ErrorFrame(
|
||||
f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
|
||||
error=f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -425,5 +425,5 @@ class SonioxSTTService(STTService):
|
||||
# Expected when closing the connection.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error: {e}")
|
||||
await self.push_error(ErrorFrame(f"{self} error: {e}"))
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
@@ -467,8 +467,8 @@ class SpeechmaticsSTTService(STTService):
|
||||
await self._client.send_audio(audio)
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"Speechmatics error: {e}")
|
||||
yield ErrorFrame(f"Speechmatics error: {e}", fatal=False)
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
def update_params(
|
||||
@@ -514,6 +514,8 @@ class SpeechmaticsSTTService(STTService):
|
||||
self._client.send_message(payload), self.get_event_loop()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
raise RuntimeError(f"error sending message to STT: {e}")
|
||||
|
||||
async def _connect(self) -> None:
|
||||
@@ -579,7 +581,8 @@ class SpeechmaticsSTTService(STTService):
|
||||
logger.debug(f"{self} Connected to Speechmatics STT service")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Error connecting to Speechmatics: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
self._client = None
|
||||
|
||||
async def _disconnect(self) -> None:
|
||||
@@ -593,7 +596,8 @@ class SpeechmaticsSTTService(STTService):
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"{self} Timeout while closing Speechmatics client connection")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Error closing Speechmatics client: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
@@ -23,6 +25,8 @@ from typing import (
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -101,6 +105,16 @@ class TTSService(AIService):
|
||||
sample_rate: Optional[int] = None,
|
||||
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
# Types of text aggregations that should not be spoken.
|
||||
skip_aggregator_types: Optional[List[str]] = None,
|
||||
# A list of callables to transform text just before sending it to TTS.
|
||||
# Each callable takes the aggregated text and its type, and returns the transformed text.
|
||||
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
|
||||
text_transforms: Optional[
|
||||
List[
|
||||
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
|
||||
]
|
||||
] = None,
|
||||
# Text filter executed after text has been aggregated.
|
||||
text_filters: Optional[Sequence[BaseTextFilter]] = None,
|
||||
text_filter: Optional[BaseTextFilter] = None,
|
||||
@@ -120,6 +134,16 @@ class TTSService(AIService):
|
||||
pause_frame_processing: Whether to pause frame processing during audio generation.
|
||||
sample_rate: Output sample rate for generated audio.
|
||||
text_aggregator: Custom text aggregator for processing incoming text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
skip_aggregator_types: List of aggregation types that should not be spoken.
|
||||
text_transforms: A list of callables to transform text just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
(aggregation_type | '*', transform_function).
|
||||
|
||||
text_filters: Sequence of text filters to apply after aggregation.
|
||||
text_filter: Single text filter (deprecated, use text_filters).
|
||||
|
||||
@@ -142,6 +166,21 @@ class TTSService(AIService):
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
if text_aggregator:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'text_aggregator' is deprecated. Use an LLMTextProcessor before the TTSService for custom text aggregation.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
|
||||
self._text_transforms: List[
|
||||
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
|
||||
] = text_transforms or []
|
||||
# TODO: Deprecate _text_filters when added to LLMTextProcessor
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
self._tracing_enabled: bool = False
|
||||
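The deprecation path in this hunk forces the warning to be shown by wrapping `warnings.warn` in `catch_warnings()` with `simplefilter("always")`. A small sketch of the same pattern outside the service, where `resolve_aggregator` and the `"default-aggregator"` placeholder are hypothetical names used only for illustration:

```python
import warnings
from typing import Optional


def resolve_aggregator(text_aggregator: Optional[object] = None) -> object:
    """Return the supplied aggregator, warning that the parameter is deprecated."""
    if text_aggregator is not None:
        with warnings.catch_warnings():
            # Temporarily show the warning even if a filter would suppress it.
            warnings.simplefilter("always")
            warnings.warn(
                "Parameter 'text_aggregator' is deprecated.",
                DeprecationWarning,
                stacklevel=2,  # assumption: attribute the warning to the caller
            )
        return text_aggregator
    return "default-aggregator"  # placeholder for a default aggregator object
```

The `stacklevel=2` argument is an addition in this sketch; the diff calls `warnings.warn` without it.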
@@ -298,6 +337,39 @@ class TTSService(AIService):
|
||||
await self.cancel_task(self._stop_frame_task)
|
||||
self._stop_frame_task = None
|
||||
|
||||
def add_text_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Register a text transformer for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to apply for transformation. This function should take
|
||||
the text and aggregation type as input and return the transformed text.
|
||||
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
|
||||
aggregation_type: The type of aggregation to transform. This value defaults to "*" indicating
|
||||
the function should handle all text before sending to TTS.
|
||||
"""
|
||||
self._text_transforms.append((aggregation_type, transform_function))
|
||||
|
||||
def remove_text_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Remove a text transformer for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to remove.
|
||||
aggregation_type: The type of aggregation to remove the transformer for.
|
||||
"""
|
||||
self._text_transforms = [
|
||||
(agg_type, func)
|
||||
for agg_type, func in self._text_transforms
|
||||
if not (agg_type == aggregation_type and func == transform_function)
|
||||
]
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
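The transformer registry added here is a list of `(aggregation_type, async function)` pairs, where `"*"` matches every type and transforms run in registration order. The sketch below reproduces that logic in a hypothetical `TransformRegistry` class; the two sample transforms are illustrative only.

```python
import asyncio
from typing import Awaitable, Callable, List, Tuple

Transform = Callable[[str, str], Awaitable[str]]


class TransformRegistry:
    """Hypothetical stand-in for the transformer list kept by the TTS service."""

    def __init__(self) -> None:
        self._transforms: List[Tuple[str, Transform]] = []

    def add(self, fn: Transform, aggregation_type: str = "*") -> None:
        self._transforms.append((aggregation_type, fn))

    def remove(self, fn: Transform, aggregation_type: str = "*") -> None:
        # Drop only entries matching both the type and the function.
        self._transforms = [
            (t, f) for t, f in self._transforms if not (t == aggregation_type and f == fn)
        ]

    async def apply(self, text: str, aggregation_type: str) -> str:
        for registered_type, fn in self._transforms:
            if registered_type == aggregation_type or registered_type == "*":
                text = await fn(text, aggregation_type)
        return text


async def replace_at(text: str, aggregation_type: str) -> str:
    """Sample transform: spell out '@' for speech."""
    return text.replace("@", " at ")


async def upper(text: str, aggregation_type: str) -> str:
    """Sample transform: upper-case the text."""
    return text.upper()
```

Since each transform receives the output of the previous one, registration order matters when two transforms touch the same characters.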
@@ -353,6 +425,8 @@ class TTSService(AIService):
|
||||
and frame.skip_tts
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AggregatedTextFrame):
|
||||
await self._push_tts_frames(frame)
|
||||
elif (
|
||||
isinstance(frame, TextFrame)
|
||||
and not isinstance(frame, InterimTranscriptionFrame)
|
||||
@@ -368,10 +442,10 @@ class TTSService(AIService):
|
||||
# pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
|
||||
sentence = self._text_aggregator.text
|
||||
aggregate = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
self._processing_text = False
|
||||
await self._push_tts_frames(sentence)
|
||||
await self._push_tts_frames(AggregatedTextFrame(aggregate.text, aggregate.type))
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -380,7 +454,7 @@ class TTSService(AIService):
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
# Store if we were processing text or not so we can set it back.
|
||||
processing_text = self._processing_text
|
||||
await self._push_tts_frames(frame.text)
|
||||
await self._push_tts_frames(AggregatedTextFrame(frame.text, AggregationType.SENTENCE))
|
||||
# We pause processing incoming frames because we are sending data to
|
||||
# the TTS. We pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
@@ -472,13 +546,24 @@ class TTSService(AIService):
|
||||
text: Optional[str] = None
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
aggregated_by = "token"
|
||||
else:
|
||||
text = await self._text_aggregator.aggregate(frame.text)
|
||||
aggregate = await self._text_aggregator.aggregate(frame.text)
|
||||
if aggregate:
|
||||
text = aggregate.text
|
||||
aggregated_by = aggregate.type
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(AggregatedTextFrame(text, aggregated_by))
|
||||
|
||||
async def _push_tts_frames(self, src_frame: AggregatedTextFrame):
|
||||
type = src_frame.aggregated_by
|
||||
text = src_frame.text
|
||||
if type in self._skip_aggregator_types:
|
||||
await self.push_frame(src_frame)
|
||||
return
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
# Remove leading newlines only
|
||||
text = text.lstrip("\n")
|
||||
|
||||
@@ -499,15 +584,39 @@ class TTSService(AIService):
|
||||
await filter.reset_interruption()
|
||||
text = await filter.filter(text)
|
||||
|
||||
if text:
|
||||
await self.process_generator(self.run_tts(text))
|
||||
if not text.strip():
|
||||
await self.stop_processing_metrics()
|
||||
return
|
||||
|
||||
# To support use cases that may want to know the text before it's spoken, we
|
||||
# push the AggregatedTextFrame version before transforming and sending to TTS.
|
||||
# However, we do not want to add this text to the assistant context until it
|
||||
# is spoken, so we set append_to_context to False.
|
||||
src_frame.append_to_context = False
|
||||
await self.push_frame(src_frame)
|
||||
|
||||
# Note: Text transformations are meant to only affect the text sent to the TTS for
|
||||
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
|
||||
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
|
||||
# services that support word-level timestamps, this CAN affect the resulting context
|
||||
# since the TTSTextFrames are generated from the TTS output stream
|
||||
transformed_text = text
|
||||
for aggregation_type, transform in self._text_transforms:
|
||||
if aggregation_type == type or aggregation_type == "*":
|
||||
transformed_text = await transform(transformed_text, type)
|
||||
await self.process_generator(self.run_tts(transformed_text))
|
||||
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
if self._push_text_frames:
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
frame = TTSTextFrame(text)
|
||||
# In TTS services that support word timestamps, the TTSTextFrames
|
||||
# are pushed as words are spoken. However, in the case where the TTS service
|
||||
# does not support word timestamps (i.e. _push_text_frames is True), we send
|
||||
# the original (non-transformed) text after the TTS generation has completed.
|
||||
# This way, if we are interrupted, the text is not added to the assistant
|
||||
# context and the context that IS added does not include TTS-specific tags
|
||||
# or transformations.
|
||||
frame = TTSTextFrame(text, aggregated_by=type)
|
||||
frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces
|
||||
await self.push_frame(frame)
|
||||
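The new `_push_tts_frames` separates what is forwarded downstream from what is actually sent to the TTS: skipped aggregation types are forwarded without being spoken, and whitespace-only text is dropped. A simplified sketch of that routing, with text filters and transforms left out; `Aggregate`, `SpeechPlanner`, and the `"code"` type are hypothetical names chosen for illustration.

```python
import asyncio
from dataclasses import dataclass, field
from typing import List


@dataclass
class Aggregate:
    """Hypothetical stand-in for an aggregated text frame."""

    text: str
    aggregated_by: str


@dataclass
class SpeechPlanner:
    skip_types: List[str] = field(default_factory=list)
    pushed: List[str] = field(default_factory=list)  # forwarded downstream
    spoken: List[str] = field(default_factory=list)  # sent to the TTS

    async def handle(self, agg: Aggregate) -> None:
        if agg.aggregated_by in self.skip_types:
            self.pushed.append(agg.text)  # forwarded, never spoken
            return
        text = agg.text.lstrip("\n")  # remove leading newlines only
        if not text.strip():
            return  # nothing speakable
        self.pushed.append(agg.text)
        self.spoken.append(text)
```

For example, a planner created with `skip_types=["code"]` forwards a `"code"` aggregate untouched while a `"sentence"` aggregate is both forwarded and spoken.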
|
||||
@@ -635,7 +744,7 @@ class WordTTSService(TTSService):
|
||||
frame = TTSStoppedFrame()
|
||||
frame.pts = last_pts
|
||||
else:
|
||||
frame = TTSTextFrame(word)
|
||||
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
|
||||
frame.pts = self._initial_word_timestamp + timestamp
|
||||
if frame:
|
||||
last_pts = frame.pts
|
||||
|
||||
@@ -246,7 +246,8 @@ class UltravoxSTTService(AIService):
|
||||
|
||||
logger.info("Model warm-up completed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Model warm-up failed: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
|
||||
def _generate_silent_audio(self, sample_rate=16000, duration_sec=1.0):
|
||||
"""Generate silent audio as a numpy array.
|
||||
@@ -376,7 +377,7 @@ class UltravoxSTTService(AIService):
|
||||
if arr.size > 0: # Check if array is not empty
|
||||
audio_arrays.append(arr)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing bytes audio frame: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
# Handle numpy array data
|
||||
elif isinstance(f.audio, np.ndarray):
|
||||
if f.audio.size > 0: # Check if array is not empty
|
||||
@@ -436,14 +437,14 @@ class UltravoxSTTService(AIService):
|
||||
yield LLMFullResponseEndFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating text from model: {e}")
|
||||
yield ErrorFrame(f"Error generating text: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
else:
|
||||
logger.warning("No model available for text generation")
|
||||
logger.error("No model available for text generation")
|
||||
yield ErrorFrame("No model available for text generation")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio buffer: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -94,7 +94,7 @@ class WebsocketService(ABC):
|
||||
if self._reconnect_on_error:
|
||||
retry_count += 1
|
||||
if retry_count >= MAX_RETRIES:
|
||||
await report_error(ErrorFrame(message, fatal=True))
|
||||
await report_error(ErrorFrame(message))
|
||||
break
|
||||
|
||||
logger.warning(f"{self} connection error, will retry: {e}")
|
||||
|
||||
@@ -226,8 +226,8 @@ class BaseWhisperSTTService(SegmentedSTTService):
|
||||
logger.warning("Received empty transcription from API")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during transcription: {e}")
|
||||
yield ErrorFrame(f"Error during transcription: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
"""Transcribe audio data to text.
|
||||
|
||||
@@ -428,5 +428,5 @@ class WhisperSTTServiceMLX(WhisperSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"MLX Whisper transcription error: {e}")
|
||||
yield ErrorFrame(f"MLX Whisper transcription error: {str(e)}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -146,7 +146,7 @@ class XTTSService(TTSService):
|
||||
)
|
||||
await self.push_error(
|
||||
ErrorFrame(
|
||||
f"Error error getting studio speakers (status: {r.status}, error: {text})"
|
||||
error=f"Error getting studio speakers (status: {r.status}, error: {text})"
|
||||
)
|
||||
)
|
||||
return
|
||||
@@ -187,7 +187,7 @@ class XTTSService(TTSService):
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
logger.error(f"{self} error getting audio (status: {r.status}, error: {text})")
|
||||
yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {text})")
|
||||
yield ErrorFrame(error=f"Error getting audio (status: {r.status}, error: {text})")
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
@@ -203,8 +203,16 @@ async def run_test(
|
||||
if not isinstance(frame, EndFrame) or not send_end_frame:
|
||||
received_down_frames.append(frame)
|
||||
|
||||
print("received DOWN frames =", received_down_frames)
|
||||
print("expected DOWN frames =", expected_down_frames)
|
||||
down_frames_printed = "["
|
||||
for frame in received_down_frames:
|
||||
down_frames_printed += f"{frame.__class__.__name__}, "
|
||||
down_frames_printed += "]"
|
||||
expected_frames_printed = "["
|
||||
for frame in expected_down_frames:
|
||||
expected_frames_printed += f"{frame.__name__}, "
|
||||
expected_frames_printed += "]"
|
||||
print("received DOWN frames =", down_frames_printed)
|
||||
print("expected DOWN frames =", expected_frames_printed)
|
||||
|
||||
assert len(received_down_frames) == len(expected_down_frames)
|
||||
|
||||
|
||||
@@ -12,9 +12,46 @@ aggregated text should be sent for speech synthesis.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aggregation:
|
||||
"""Data class representing aggregated text and its type.
|
||||
|
||||
An Aggregation object is created whenever a stream of text is aggregated by
|
||||
a text aggregator. It contains the aggregated text and a type indicating
|
||||
the nature of the aggregation.
|
||||
|
||||
Parameters:
|
||||
text: The aggregated text content.
|
||||
type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
|
||||
"""
|
||||
|
||||
text: str
|
||||
type: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the aggregation.
|
||||
|
||||
Returns:
|
||||
A descriptive string showing the type and text of the aggregation.
|
||||
"""
|
||||
return f"Aggregation by {self.type}: {self.text}"
|
||||
|
||||
|
||||
class BaseTextAggregator(ABC):
|
||||
"""Base class for text aggregators in the Pipecat framework.
|
||||
|
||||
@@ -30,7 +67,7 @@ class BaseTextAggregator(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def text(self) -> str:
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Subclasses must implement this property to return the text that has
|
||||
@@ -42,12 +79,13 @@ class BaseTextAggregator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate the specified text with the currently accumulated text.
|
||||
|
||||
This method should be implemented to define how the new text contributes
|
||||
to the aggregation process. It returns the updated aggregated text if
|
||||
it's ready to be processed, or None otherwise.
|
||||
to the aggregation process. It returns the aggregated text and a string
|
||||
describing how it was aggregated if it's ready to be processed,
|
||||
or None otherwise.
|
||||
|
||||
Subclasses should implement their specific logic for:
|
||||
|
||||
|
||||
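The `AggregationType` and `Aggregation` definitions in the hunk above can be exercised on their own. The block below is a standalone copy of those two definitions (it does not import pipecat), with a hypothetical custom type string `"code"` added for illustration. It sketches why the `str` mixin matters: enum members compare equal to plain strings, so callers can pass either form.

```python
from dataclasses import dataclass
from enum import Enum


class AggregationType(str, Enum):
    """Standalone copy of the built-in aggregation strings from the diff."""

    SENTENCE = "sentence"
    WORD = "word"

    def __str__(self):
        return self.value


@dataclass
class Aggregation:
    """Standalone copy of the aggregated-text container from the diff."""

    text: str
    type: str

    def __str__(self) -> str:
        return f"Aggregation by {self.type}: {self.text}"


# A built-in type and a custom type ("code" is a hypothetical example).
sentence = Aggregation("Hello there.", AggregationType.SENTENCE)
custom = Aggregation("x = 1", "code")

# Because of the str mixin, members compare equal to plain strings.
is_sentence = sentence.type == "sentence"
word_is_reserved = "word" in [AggregationType.SENTENCE, AggregationType.WORD]
code_is_reserved = custom.type in [AggregationType.SENTENCE, AggregationType.WORD]

# The overridden __str__ makes the enum render as its value in f-strings.
rendered = str(sentence)
```

The same membership test is what lets a reserved-name check accept a plain `"sentence"` string as well as the enum member.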
@@ -8,19 +8,41 @@
|
||||
|
||||
This module provides an aggregator that identifies and processes content between
|
||||
pattern pairs (like XML tags or custom delimiters) in streaming text, with
|
||||
support for custom handlers and configurable pattern removal.
|
||||
support for custom handlers and configurable actions for when a pattern is found.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Awaitable, Callable, Optional, Tuple
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class PatternMatch:
|
||||
class MatchAction(Enum):
|
||||
"""Actions to take when a pattern pair is matched.
|
||||
|
||||
Parameters:
|
||||
REMOVE: The text along with its delimiters will be removed from the streaming text.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
KEEP: The matched text, including its delimiters, will be left in the streaming text.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
AGGREGATE: The delimiters will be removed and the content between will be treated
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
"""
|
||||
|
||||
REMOVE = "remove"
|
||||
KEEP = "keep"
|
||||
AGGREGATE = "aggregate"
|
||||
|
||||
|
||||
class PatternMatch(Aggregation):
|
||||
"""Represents a matched pattern pair with its content.
|
||||
|
||||
A PatternMatch object is created when a complete pattern pair is found
|
||||
@@ -29,25 +51,25 @@ class PatternMatch:
|
||||
content between the patterns.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern_id: str, full_match: str, content: str):
|
||||
def __init__(self, content: str, type: str, full_match: str):
|
||||
"""Initialize a pattern match.
|
||||
|
||||
Args:
|
||||
pattern_id: The identifier of the matched pattern pair.
|
||||
type: The type of the matched pattern pair. It should be representative
|
||||
of the content type (e.g., 'sentence', 'code', 'speaker', 'custom').
|
||||
full_match: The complete text including start and end patterns.
|
||||
content: The text content between the start and end patterns.
|
||||
"""
|
||||
self.pattern_id = pattern_id
|
||||
super().__init__(text=content, type=type)
|
||||
self.full_match = full_match
|
||||
self.content = content
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the pattern match.
|
||||
|
||||
Returns:
|
||||
A descriptive string showing the pattern ID and content.
|
||||
A descriptive string showing the pattern type and content.
|
||||
"""
|
||||
return f"PatternMatch(id={self.pattern_id}, content={self.content})"
|
||||
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
|
||||
|
||||
|
||||
class PatternPairAggregator(BaseTextAggregator):
|
||||
@@ -55,16 +77,21 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
This aggregator buffers text until it can identify complete pattern pairs
|
||||
(defined by start and end patterns), processes the content between these
|
||||
patterns using registered handlers, and returns text at sentence boundaries.
|
||||
It's particularly useful for processing structured content in streaming text,
|
||||
such as XML tags, markdown formatting, or custom delimiters.
|
||||
patterns using registered handlers. By default, its aggregation method
|
||||
returns text at sentence boundaries and removes the content found between
|
||||
any matched patterns. However, matched patterns can also be configured to
|
||||
be returned as a separate aggregation object containing the content between
|
||||
their start and end patterns, or left in place so that a callback can be
|
||||
triggered without altering the text.
|
||||
|
||||
This aggregator is particularly useful for processing structured content in
|
||||
streaming text, such as XML tags, markdown formatting, or custom delimiters.
|
||||
|
||||
The aggregator ensures that patterns spanning multiple text chunks are
|
||||
correctly identified and handles cases where patterns contain sentence
|
||||
boundaries.
|
||||
correctly identified.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the pattern pair aggregator.
|
||||
|
||||
Creates an empty aggregator with no patterns or handlers registered.
|
||||
@@ -75,16 +102,23 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._handlers = {}
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Get the currently buffered text.
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
The text that has been accumulated in the buffer.
|
||||
"""
|
||||
return self._text
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start:
|
||||
return Aggregation(self._text, pattern_start[1].get("type", AggregationType.SENTENCE))
|
||||
return Aggregation(self._text, AggregationType.SENTENCE)
|
||||
|
||||
def add_pattern_pair(
|
||||
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
|
||||
def add_pattern(
|
||||
self,
|
||||
type: str,
|
||||
start_pattern: str,
|
||||
end_pattern: str,
|
||||
action: MatchAction = MatchAction.REMOVE,
|
||||
) -> "PatternPairAggregator":
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
@@ -93,41 +127,94 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
the end pattern, and treat the content between them as a match.
|
||||
|
||||
Args:
|
||||
pattern_id: Unique identifier for this pattern pair.
|
||||
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). The type cannot be 'sentence' or 'word' as
|
||||
those are reserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
remove_match: Whether to remove the matched content from the text.
|
||||
action: What to do when a complete pattern is matched:
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
|
||||
normal text. This allows you to register handlers for
|
||||
the pattern without affecting the aggregation logic.
|
||||
- MatchAction.AGGREGATE: Return the matched pattern as a separate
|
||||
aggregation object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._patterns[pattern_id] = {
|
||||
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
|
||||
raise ValueError(
|
||||
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
|
||||
)
|
||||
self._patterns[type] = {
|
||||
"start": start_pattern,
|
||||
"end": end_pattern,
|
||||
"remove_match": remove_match,
|
||||
"type": type,
|
||||
"action": action,
|
||||
}
|
||||
return self
|
||||
|
||||
def add_pattern_pair(
|
||||
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
|
||||
):
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use `add_pattern` with a type and MatchAction instead.
|
||||
|
||||
This method calls `add_pattern`, setting type to the provided pattern_id and action
|
||||
to either MatchAction.REMOVE or MatchAction.KEEP based on `remove_match`.
|
||||
|
||||
Args:
|
||||
pattern_id: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). pattern_id cannot be 'sentence' or 'word'
|
||||
as those are reserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
remove_match: If True, the matched pattern will be removed from the text. (Same as MatchAction.REMOVE)
|
||||
If False, it will be kept and treated as normal text. (Same as MatchAction.KEEP)
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once")
|
||||
warnings.warn(
|
||||
"add_pattern_pair with a pattern_id or remove_match is deprecated and will be"
|
||||
" removed in a future version. Use add_pattern with a type and MatchAction instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
action = MatchAction.REMOVE if remove_match else MatchAction.KEEP
|
||||
return self.add_pattern(
|
||||
type=pattern_id,
|
||||
start_pattern=start_pattern,
|
||||
end_pattern=end_pattern,
|
||||
action=action,
|
||||
)
|
||||
|
||||
def on_pattern_match(
|
||||
self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
) -> "PatternPairAggregator":
|
||||
"""Register a handler for when a pattern pair is matched.
|
||||
|
||||
The handler will be called whenever a complete match for the
|
||||
specified pattern ID is found in the text.
|
||||
specified type is found in the text.
|
||||
|
||||
Args:
|
||||
pattern_id: ID of the pattern pair to match.
|
||||
type: The type of the pattern pair to trigger the handler.
|
||||
handler: Async function to call when pattern is matched.
|
||||
The function should accept a PatternMatch object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._handlers[pattern_id] = handler
|
||||
self._handlers[type] = handler
|
||||
return self
|
||||
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[str, bool]:
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
|
||||
"""Process all complete pattern pairs in the text.
|
||||
|
||||
Searches for all complete pattern pairs in the text, calls the
|
||||
@@ -137,19 +224,19 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
text: The text to process for pattern matches.
|
||||
|
||||
Returns:
|
||||
Tuple of (processed_text, was_modified) where:
|
||||
Tuple of (all_matches, processed_text) where:
|
||||
|
||||
- processed_text is the text after processing patterns
|
||||
- was_modified indicates whether any changes were made
|
||||
- all_matches is a list of all pattern matches found. Note: There really should only ever be 1.
|
||||
- processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text.
|
||||
"""
|
||||
all_matches = []
|
||||
processed_text = text
|
||||
modified = False
|
||||
|
||||
for pattern_id, pattern_info in self._patterns.items():
|
||||
for type, pattern_info in self._patterns.items():
|
||||
# Escape special regex characters in the patterns
|
||||
start = re.escape(pattern_info["start"])
|
||||
end = re.escape(pattern_info["end"])
|
||||
remove_match = pattern_info["remove_match"]
|
||||
action = pattern_info["action"]
|
||||
|
||||
# Create regex to match from start pattern to end pattern
|
||||
# The .*? is non-greedy to handle nested patterns
|
||||
@@ -164,25 +251,24 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
full_match = match.group(0) # Full match including patterns
|
||||
|
||||
# Create pattern match object
|
||||
pattern_match = PatternMatch(
|
||||
pattern_id=pattern_id, full_match=full_match, content=content
|
||||
)
|
||||
pattern_match = PatternMatch(content=content, type=type, full_match=full_match)
|
||||
|
||||
# Call the appropriate handler if registered
|
||||
if pattern_id in self._handlers:
|
||||
if type in self._handlers:
|
||||
try:
|
||||
await self._handlers[pattern_id](pattern_match)
|
||||
await self._handlers[type](pattern_match)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
|
||||
logger.error(f"Error in pattern handler for {type}: {e}")
|
||||
|
||||
# Remove the pattern from the text if configured
|
||||
if remove_match:
|
||||
if action == MatchAction.REMOVE:
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
modified = True
|
||||
else:
|
||||
all_matches.append(pattern_match)
|
||||
|
||||
return processed_text, modified
|
||||
return all_matches, processed_text
|
||||
|
||||
def _has_incomplete_patterns(self, text: str) -> bool:
|
||||
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
|
||||
"""Check if text contains incomplete pattern pairs.
|
||||
|
||||
Determines whether the text contains any start patterns without
|
||||
@@ -192,9 +278,10 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
text: The text to check for incomplete patterns.
|
||||
|
||||
Returns:
|
||||
True if there are incomplete patterns, False otherwise.
|
||||
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
|
||||
or None if no patterns are found or all patterns are complete.
|
||||
"""
|
||||
for pattern_id, pattern_info in self._patterns.items():
|
||||
for type, pattern_info in self._patterns.items():
|
||||
start = pattern_info["start"]
|
||||
end = pattern_info["end"]
|
||||
|
||||
@@ -203,12 +290,16 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
end_count = text.count(end)
|
||||
|
||||
# If there are more starts than ends, we have incomplete patterns
|
||||
# Again, this is written generically, but only one pattern should ever
|
||||
# be active at a time, so the counts should be 0 or 1, which is why
|
||||
# the return value is based on the first one found.
|
||||
if start_count > end_count:
|
||||
return True
|
||||
start_index = text.find(start)
|
||||
return (start_index, pattern_info)
|
||||
|
||||
return False
|
||||
return None
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[PatternMatch]:
|
||||
"""Aggregate text and process pattern pairs.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete pattern
|
||||
@@ -227,16 +318,34 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._text += text
|
||||
|
||||
# Process any complete patterns in the buffer
|
||||
processed_text, modified = await self._process_complete_patterns(self._text)
|
||||
patterns, processed_text = await self._process_complete_patterns(self._text)
|
||||
|
||||
# Only update the buffer if modifications were made
|
||||
if modified:
|
||||
self._text = processed_text
|
||||
self._text = processed_text
|
||||
|
||||
if len(patterns) > 0:
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
return patterns[0]
|
||||
|
||||
# Check if we have incomplete patterns
|
||||
if self._has_incomplete_patterns(self._text):
|
||||
# Still waiting for complete patterns
|
||||
return None
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, return None
|
||||
if (
|
||||
pattern_start[0] == 0
|
||||
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
|
||||
):
|
||||
return None
|
||||
# Otherwise, strip the text up to the start pattern and return it
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
@@ -244,7 +353,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return result
|
||||
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
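The three `MatchAction` values are easiest to compare side by side. The block below is a simplified, non-streaming sketch of how the code in this diff appears to treat a single complete pattern pair. It is not the aggregator itself: `apply_action` and its parameters are hypothetical names, sentence splitting is ignored, handlers are not called, and the use of `re.DOTALL` is an assumption.

```python
import re
from enum import Enum
from typing import List, Tuple


class MatchAction(Enum):
    """Standalone copy of the action enum from the diff."""

    REMOVE = "remove"
    KEEP = "keep"
    AGGREGATE = "aggregate"


def apply_action(
    text: str, pattern_type: str, start: str, end: str, action: MatchAction
) -> List[Tuple[str, str]]:
    """Return (type, text) chunks for one complete pattern pair in text."""
    # Escape the delimiters and match the content between them non-greedily.
    regex = re.compile(f"{re.escape(start)}(.*?){re.escape(end)}", re.DOTALL)
    match = regex.search(text)

    # KEEP (or no match): the text passes through untouched.
    if match is None or action == MatchAction.KEEP:
        return [("sentence", text)]

    before, after = text[: match.start()], text[match.end() :]

    # REMOVE: drop the delimiters and everything between them.
    if action == MatchAction.REMOVE:
        return [("sentence", before + after)]

    # AGGREGATE: text before the pattern is emitted early, then the content
    # between the delimiters as its own chunk, then the remaining text.
    chunks = []
    if before:
        chunks.append(("sentence", before))
    chunks.append((pattern_type, match.group(1)))
    if after:
        chunks.append(("sentence", after))
    return chunks


sample = "Here is code <code>x = 1</code> Done."
removed = apply_action(sample, "code", "<code>", "</code>", MatchAction.REMOVE)
kept = apply_action(sample, "code", "<code>", "</code>", MatchAction.KEEP)
aggregated = apply_action(sample, "code", "<code>", "</code>", MatchAction.AGGREGATE)
```

In the real aggregator these chunks would be produced incrementally across several `aggregate()` calls as text streams in, rather than in one list.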
@@ -14,7 +14,7 @@ text processing scenarios.
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class SimpleTextAggregator(BaseTextAggregator):
|
||||
@@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
self._text = ""
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Returns:
|
||||
The text that has been accumulated in the buffer.
|
||||
"""
|
||||
return self._text
|
||||
return Aggregation(self._text, AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text and return completed sentences.
|
||||
|
||||
Adds the new text to the buffer and checks for end-of-sentence markers.
|
||||
@@ -64,7 +64,7 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
result = self._text[:eos_end_marker]
|
||||
self._text = self._text[eos_end_marker:]
|
||||
|
||||
return result
|
||||
return Aggregation(result, AggregationType.SENTENCE) if result else None
|
||||
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruptions by clearing the text buffer.
|
||||
|
||||
@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class SkipTagsAggregator(BaseTextAggregator):
|
||||
@@ -49,9 +49,9 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
"""
|
||||
return self._text
|
||||
return Aggregation(self._text, AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text while respecting tag boundaries.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete
|
||||
@@ -80,7 +80,7 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return result
|
||||
return Aggregation(result, AggregationType.SENTENCE)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
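The aggregator hunks above all make the same contract change: the `text` property and `aggregate()` now return an `Aggregation` rather than a bare string. A minimal sketch of a custom aggregator written against that contract might look like the following. `WordAggregator` is a hypothetical class, it does not subclass pipecat's `BaseTextAggregator`, and `Aggregation` is redeclared locally so the block runs on its own.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Aggregation:
    """Local stand-in for the Aggregation data class from the diff."""

    text: str
    type: str


class WordAggregator:
    """Hypothetical aggregator that emits one whitespace-terminated word per call."""

    def __init__(self):
        self._text = ""

    @property
    def text(self) -> Aggregation:
        # The buffer is wrapped in an Aggregation, mirroring the new property type.
        return Aggregation(self._text, "word")

    async def aggregate(self, text: str) -> Optional[Aggregation]:
        self._text += text
        stripped = self._text.lstrip()
        for i, ch in enumerate(stripped):
            if ch.isspace():
                # A complete word is available: return it and keep the rest buffered.
                word, self._text = stripped[:i], stripped[i:]
                return Aggregation(word, "word")
        # No complete word yet.
        return None


async def demo():
    agg = WordAggregator()
    results = [await agg.aggregate(chunk) for chunk in ["Hel", "lo wor", "ld "]]
    return results, agg.text


results, remaining = asyncio.run(demo())
```

Callers that previously compared the return value to a string would read `result.text` and `result.type` instead, which is the adjustment the test changes below make.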
@@ -7,30 +7,42 @@
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
|
||||
|
||||
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.aggregator = PatternPairAggregator()
|
||||
self.test_handler = AsyncMock()
|
||||
self.code_handler = AsyncMock()
|
||||
|
||||
# Add a test pattern
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="test_pattern",
|
||||
start_pattern="<test>",
|
||||
end_pattern="</test>",
|
||||
remove_match=True,
|
||||
)
|
||||
self.aggregator.add_pattern(
|
||||
type="code_pattern",
|
||||
start_pattern="<code>",
|
||||
end_pattern="</code>",
|
||||
action=MatchAction.AGGREGATE,
|
||||
)
|
||||
|
||||
# Register the mock handler
|
||||
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
|
||||
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
|
||||
|
||||
async def test_pattern_match_and_removal(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</test>!")
|
||||
@@ -39,20 +51,49 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.pattern_id, "test_pattern")
|
||||
self.assertEqual(call_args.type, "test_pattern")
|
||||
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
|
||||
self.assertEqual(call_args.content, "pattern content")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
|
||||
# The exclamation point should be treated as a sentence boundary,
|
||||
# so the result should include just text up to and including "!"
|
||||
self.assertEqual(result, "Hello !")
|
||||
self.assertEqual(result.text, "Hello !")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
self.assertEqual(result, " This is another sentence.")
|
||||
self.assertEqual(result.text, " This is another sentence.")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_pattern_match_and_aggregate(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Here is code <code>pattern")
|
||||
self.assertEqual(result.text, "Here is code ")
|
||||
self.assertEqual(self.aggregator.text.text, "<code>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "code_pattern")
|
||||
|
||||
# Second part completes the pattern
|
||||
result = await self.aggregator.aggregate(" content</code>")
|
||||
|
||||
# Verify the handler was called with correct PatternMatch object
|
||||
self.code_handler.assert_called_once()
|
||||
call_args = self.code_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.type, "code_pattern")
|
||||
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
self.assertEqual(result.text, "pattern content")
|
||||
self.assertEqual(result.type, "code_pattern")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
self.assertEqual(result.text, " This is another sentence.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_incomplete_pattern(self):
|
||||
# Add text with incomplete pattern
|
||||
@@ -65,26 +106,30 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_not_called()
|
||||
|
||||
# Buffer should contain the incomplete text
|
||||
self.assertEqual(self.aggregator.text, "Hello <test>pattern content")
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
|
||||
# Reset and confirm buffer is cleared
|
||||
await self.aggregator.reset()
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_multiple_patterns(self):
|
||||
# Set up multiple patterns and handlers
|
||||
voice_handler = AsyncMock()
|
||||
emphasis_handler = AsyncMock()
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
|
||||
self.aggregator.add_pattern(
|
||||
type="voice",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
action=MatchAction.REMOVE,
|
||||
)
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="emphasis",
|
||||
self.aggregator.add_pattern(
|
||||
type="emphasis",
|
||||
start_pattern="<em>",
|
||||
end_pattern="</em>",
|
||||
remove_match=False, # Keep emphasis tags
|
||||
action=MatchAction.KEEP, # Keep emphasis tags
|
||||
)
|
||||
|
||||
self.aggregator.on_pattern_match("voice", voice_handler)
|
@@ -97,19 +142,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
 # Both handlers should be called with correct data
 voice_handler.assert_called_once()
 voice_match = voice_handler.call_args[0][0]
-self.assertEqual(voice_match.pattern_id, "voice")
-self.assertEqual(voice_match.content, "female")
+self.assertEqual(voice_match.type, "voice")
+self.assertEqual(voice_match.text, "female")

 emphasis_handler.assert_called_once()
 emphasis_match = emphasis_handler.call_args[0][0]
-self.assertEqual(emphasis_match.pattern_id, "emphasis")
-self.assertEqual(emphasis_match.content, "very")
+self.assertEqual(emphasis_match.type, "emphasis")
+self.assertEqual(emphasis_match.text, "very")

 # Voice pattern should be removed, emphasis pattern should remain
-self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
+self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")

 # Buffer should be empty
-self.assertEqual(self.aggregator.text, "")
+self.assertEqual(self.aggregator.text.text, "")

 async def test_handle_interruption(self):
 # Start with incomplete pattern
@@ -120,7 +165,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
 await self.aggregator.handle_interruption()

 # Buffer should be cleared
-self.assertEqual(self.aggregator.text, "")
+self.assertEqual(self.aggregator.text.text, "")

 # Handler should not have been called
 self.test_handler.assert_not_called()
@@ -138,10 +183,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
 # Handler should be called with entire content
 self.test_handler.assert_called_once()
 call_args = self.test_handler.call_args[0][0]
-self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
+self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")

 # Pattern should be removed, resulting in text with sentences merged
-self.assertEqual(result, "Hello Final sentence.")
+self.assertEqual(result.text, "Hello Final sentence.")

 # Buffer should be empty
-self.assertEqual(self.aggregator.text, "")
+self.assertEqual(self.aggregator.text.text, "")
@@ -13,6 +13,7 @@ import pytest
 from aiohttp import web

 from pipecat.frames.frames import (
+AggregatedTextFrame,
 ErrorFrame,
 TTSAudioRawFrame,
 TTSSpeakFrame,
@@ -74,6 +75,7 @@ async def test_run_piper_tts_success(aiohttp_client):
 ]

 expected_returned_frames = [
+AggregatedTextFrame,
 TTSStartedFrame,
 TTSAudioRawFrame,
 TTSAudioRawFrame,
@@ -121,7 +123,7 @@ async def test_run_piper_tts_error(aiohttp_client):
 TTSSpeakFrame(text="Error case."),
 ]

-expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
+expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame]

 expected_up_frames = [ErrorFrame]

@@ -15,15 +15,20 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):

 async def test_reset_aggregations(self):
 assert await self.aggregator.aggregate("Hello ") == None
-assert self.aggregator.text == "Hello "
+assert self.aggregator.text.text == "Hello "
 await self.aggregator.reset()
-assert self.aggregator.text == ""
+assert self.aggregator.text.text == ""

 async def test_simple_sentence(self):
 assert await self.aggregator.aggregate("Hello ") == None
-assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
-assert self.aggregator.text == ""
+aggregate = await self.aggregator.aggregate("Pipecat!")
+assert aggregate.text == "Hello Pipecat!"
+assert aggregate.type == "sentence"
+assert self.aggregator.text.text == ""

 async def test_multiple_sentences(self):
-assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
-assert await self.aggregator.aggregate("you?") == " How are you?"
+aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
+assert aggregate.text == "Hello Pipecat!"
+assert self.aggregator.text.text == " How are "
+aggregate = await self.aggregator.aggregate("you?")
+assert aggregate.text == " How are you?"
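Note on the hunk above: the updated assertions imply that `aggregate()` now returns an object with `.text` and `.type` fields (or `None` while a sentence is still incomplete), and that the aggregator's `text` property exposes the pending buffer in the same shape. The sketch below is only an illustration of that contract, not the project's implementation. `Aggregation` and `SketchSentenceAggregator` are hypothetical names, and the sentence detection is a deliberately naive regex that would split on any `.`, `!` or `?`.

```python
import asyncio
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class Aggregation:
    """Hypothetical stand-in for the object the tests read .text and .type from."""

    text: str
    type: str


class SketchSentenceAggregator:
    """Illustrative aggregator; assumed behavior inferred from the test assertions."""

    def __init__(self) -> None:
        self._buffer = ""

    @property
    def text(self) -> Aggregation:
        # Pending (not yet returned) text, wrapped in the same shape as results.
        return Aggregation(text=self._buffer, type="sentence")

    async def aggregate(self, chunk: str) -> Optional[Aggregation]:
        self._buffer += chunk
        # Naive end-of-sentence detection: first ., ! or ? in the buffer.
        match = re.search(r"[.!?]", self._buffer)
        if match is None:
            return None
        end = match.end()
        sentence, self._buffer = self._buffer[:end], self._buffer[end:]
        return Aggregation(text=sentence, type="sentence")

    async def reset(self) -> None:
        self._buffer = ""


async def main() -> None:
    aggregator = SketchSentenceAggregator()
    first = await aggregator.aggregate("Hello Pipecat! How are ")
    second = await aggregator.aggregate("you?")
    print(repr(first.text), repr(second.text))


if __name__ == "__main__":
    asyncio.run(main())
```

Under these assumptions the leftover text after a completed sentence stays in the buffer with its leading space, which matches the `" How are "` and `" How are you?"` expectations in `test_multiple_sentences`.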
@@ -18,16 +18,18 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):

 # No tags involved, aggregate at end of sentence.
 result = await self.aggregator.aggregate("Hello Pipecat!")
-self.assertEqual(result, "Hello Pipecat!")
-self.assertEqual(self.aggregator.text, "")
+self.assertEqual(result.text, "Hello Pipecat!")
+self.assertEqual(result.type, "sentence")
+self.assertEqual(self.aggregator.text.text, "")

 async def test_basic_tags(self):
 await self.aggregator.reset()

 # Tags involved, avoid aggregation during tags.
 result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
-self.assertEqual(result, "My email is <spell>foo@pipecat.ai</spell>.")
-self.assertEqual(self.aggregator.text, "")
+self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
+self.assertEqual(result.type, "sentence")
+self.assertEqual(self.aggregator.text.text, "")

 async def test_streaming_tags(self):
 await self.aggregator.reset()
@@ -35,20 +37,22 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
 # Tags involved, stream small chunk of texts.
 result = await self.aggregator.aggregate("My email is <sp")
 self.assertIsNone(result)
-self.assertEqual(self.aggregator.text, "My email is <sp")
+self.assertEqual(self.aggregator.text.text, "My email is <sp")

 result = await self.aggregator.aggregate("ell>foo.")
 self.assertIsNone(result)
-self.assertEqual(self.aggregator.text, "My email is <spell>foo.")
+self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")

 result = await self.aggregator.aggregate("bar@pipecat.")
 self.assertIsNone(result)
-self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.")
+self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")

 result = await self.aggregator.aggregate("ai</spe")
 self.assertIsNone(result)
-self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.ai</spe")
+self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
+self.assertEqual(self.aggregator.text.type, "sentence")

 result = await self.aggregator.aggregate("ll>.")
-self.assertEqual(result, "My email is <spell>foo.bar@pipecat.ai</spell>.")
-self.assertEqual(self.aggregator.text, "")
+self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
+self.assertEqual(self.aggregator.text.text, "")
+self.assertEqual(self.aggregator.text.type, "sentence")
@@ -11,6 +11,7 @@ from datetime import datetime, timezone
 from typing import List, Tuple, cast

 from pipecat.frames.frames import (
+AggregationType,
 BotStartedSpeakingFrame,
 BotStoppedSpeakingFrame,
 CancelFrame,
@@ -130,11 +131,11 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
 frames_to_send = [
 BotStartedSpeakingFrame(),
 SleepFrame(), # Wait for StartedSpeaking to process
-TTSTextFrame(text="Hello"),
-TTSTextFrame(text="world!"),
-TTSTextFrame(text="How"),
-TTSTextFrame(text="are"),
-TTSTextFrame(text="you?"),
+TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="How", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="are", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="you?", aggregated_by=AggregationType.WORD),
 SleepFrame(), # Wait for text frames to queue
 BotStoppedSpeakingFrame(),
 ]
@@ -195,9 +196,9 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
 frames_to_send = [
 BotStartedSpeakingFrame(),
 SleepFrame(),
-TTSTextFrame(text=""), # Empty text
-TTSTextFrame(text=" "), # Just whitespace
-TTSTextFrame(text="\n"), # Just newline
+TTSTextFrame(text="", aggregated_by=AggregationType.WORD), # Empty text
+TTSTextFrame(text=" ", aggregated_by=AggregationType.WORD), # Just whitespace
+TTSTextFrame(text="\n", aggregated_by=AggregationType.WORD), # Just newline
 BotStoppedSpeakingFrame(),
 # Pipeline ends here; run_test will automatically send EndFrame
 ]
@@ -235,14 +236,14 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
 frames_to_send = [
 BotStartedSpeakingFrame(),
 SleepFrame(),
-TTSTextFrame(text="Hello"),
-TTSTextFrame(text="world!"),
+TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
 SleepFrame(),
 InterruptionFrame(), # User interrupts here
 SleepFrame(),
 BotStartedSpeakingFrame(),
-TTSTextFrame(text="New"),
-TTSTextFrame(text="response"),
+TTSTextFrame(text="New", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="response", aggregated_by=AggregationType.WORD),
 SleepFrame(),
 BotStoppedSpeakingFrame(),
 ]
@@ -299,8 +300,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
 frames_to_send = [
 BotStartedSpeakingFrame(),
 SleepFrame(),
-TTSTextFrame(text="Hello"),
-TTSTextFrame(text="world"),
+TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
 # Pipeline ends here; run_test will automatically send EndFrame
 ]

@@ -338,8 +339,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
 frames_to_send = [
 BotStartedSpeakingFrame(),
 SleepFrame(),
-TTSTextFrame(text="Hello"),
-TTSTextFrame(text="world"),
+TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
 SleepFrame(), # Ensure messages are processed
 CancelFrame(),
 ]
@@ -401,8 +402,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
 frames_to_send = [
 BotStartedSpeakingFrame(),
 SleepFrame(),
-TTSTextFrame(text="Assistant"),
-TTSTextFrame(text="message"),
+TTSTextFrame(text="Assistant", aggregated_by=AggregationType.WORD),
+TTSTextFrame(text="message", aggregated_by=AggregationType.WORD),
 BotStoppedSpeakingFrame(),
 ]

@@ -439,7 +440,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):

 # Test the specific pattern shared
 def make_tts_text_frame(text: str) -> TTSTextFrame:
-frame = TTSTextFrame(text=text)
+frame = TTSTextFrame(text=text, aggregated_by=AggregationType.WORD)
 frame.includes_inter_frame_spaces = True
 return frame
