Compare commits
23 Commits
aleix/dont
...
bot-output
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8640d84ae | ||
|
|
23e4e29999 | ||
|
|
713b488bb6 | ||
|
|
71b87fd420 | ||
|
|
3f269f9834 | ||
|
|
4c698777f3 | ||
|
|
5ca04ad741 | ||
|
|
9a3902a82c | ||
|
|
8ab0c92681 | ||
|
|
124f147a37 | ||
|
|
ed808a9246 | ||
|
|
e9de9daf8c | ||
|
|
82b9c4f0b6 | ||
|
|
5dfe20be91 | ||
|
|
0d2c5286fa | ||
|
|
29417ba44d | ||
|
|
bc6a9cac26 | ||
|
|
8a90decbc0 | ||
|
|
ccca6e8d81 | ||
|
|
e6dc1a510d | ||
|
|
69945c5e0d | ||
|
|
5c8635570d | ||
|
|
fe9aa3383e |
107
CHANGELOG.md
107
CHANGELOG.md
@@ -16,6 +16,82 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
services that subclass `TTSService` can indicate whether the text in the
|
||||
`TTSTextFrame`s they push already contain any necessary inter-frame spaces.
|
||||
|
||||
- Introduced new `AggregatedTextFrame` type to support representing a best effort of
|
||||
the perceived llm output whether or not it is processed by the TTS. This new frame
|
||||
type includes the field `aggregated_by` to represent the conceptual format by which
|
||||
the given text is aggregated. `TTSTextFrame`s now inherit from `AggregatedTextFrame`.
|
||||
With this inheritance, an observer can watch for `AggregatedTextFrame`s to accumlate
|
||||
the perceived output and determine whether or not the text was spoken based on if that
|
||||
frame is also a `TTSTextFrame`. (See bullet below on new `bot-output` which takes
|
||||
advantage of this)
|
||||
|
||||
- Introduced `LLMTextProcessor`: A new processor meant to allow customization for how
|
||||
LLMTextFrames should be aggregated and considered. It's purpose is to turn
|
||||
`LLMTextFrame`s into `AggregatedTextFrame`s. By default, a TTSService will still
|
||||
aggregate `LLMTextFrame`s by sentence for the service to consume. However, if you
|
||||
wish to override how the llm text is aggregated, you should no longer override the
|
||||
TTS's internal aggregator, but instead, insert this processor between your LLM and
|
||||
TTS in the pipeline.
|
||||
|
||||
- New `bot-output` RTVI message to represent what the bot actually "says".
|
||||
- The `RTVIObserver` now emits `bot-output` messages based off the new `AggregatedTextFrame`s
|
||||
(`bot-tts-text` and `bot-llm-text` are still supported and generated, but `bot-transcript` is
|
||||
now deprecated in lieu of this new, more thorough, message).
|
||||
- The new `RTVIBotOutputMessage` includes the fields:
|
||||
- `spoken`: A boolean indicating whether the text was spoken by TTS
|
||||
- `aggregated_by`: A string representing how the text was aggregated ("sentence", "word",
|
||||
"my custom aggregation")
|
||||
- Introduced new fields to `RTVIObserver` to support the new `bot-output` messaging:
|
||||
- `bot_output_enabled`: Defaults to True. Set to false to disable bot-output messages.
|
||||
- `skip_aggregator_types`: Defaults to `None`. Set to a list of strings that match
|
||||
aggregation types that should not be included in bot-output messages. (Ex. `credit_card`)
|
||||
- Introduced new methods, `add_text_transformer()` and `remove_text_transformer()`, to `RTVIObserver` to support providing (and subsequently removing)
|
||||
callbacks for various types of aggregations (or all aggregations with `*`) that can modify the
|
||||
text before being sent as a `bot-output` or `tts-text` message. (Think obscuring the credit card
|
||||
or inserting extra detail the client might want that the context doesn't need.)
|
||||
|
||||
- Updated the base aggregator type:
|
||||
- Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
|
||||
a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom
|
||||
aggregation")
|
||||
- **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
|
||||
To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
|
||||
- **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
|
||||
(instead of `Optional[str]`). To update:
|
||||
```
|
||||
aggregation = myAggregator.aggregate(text)
|
||||
if (aggregation):
|
||||
print(f"successfully aggregated text: {aggregation.text}") // instead of {aggregation}
|
||||
```
|
||||
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
|
||||
produce/consume `Aggregation` objects.
|
||||
|
||||
- Augmented the `PatternPairAggregator`:
|
||||
- Introduced a new, preferred version of `add_pattern` to support a new option for treating a
|
||||
match as a separate aggregation returned from `aggregate()`. This replaces the now
|
||||
deprecated `add_pattern_pair` method and you provide a `MatchAction` in lieu of the `remove_match` field.
|
||||
- `MatchAction` enum: `REMOVE`, `KEEP`, `AGGREGATE`, allowing customization for how
|
||||
a match should be handled.
|
||||
- `REMOVE`: The text along with its delimiters will be removed from the streaming text.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
- `KEEP`: The delimiters will be removed, but the content between them will be kept.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
- `AGGREGATE`: The delimiters will be removed and the content between will be treated
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
- `PatternMatch` now extends `Aggregation` and provides richer info to handlers.
|
||||
- **BREAKING**: The `PatternMatch` type returned to handlers registered via `on_pattern_match`
|
||||
has been updated to subclass from the new `Aggregation` type, which means that `content`
|
||||
has been replaced with `text` and `pattern_id` has been replaced with `type`:
|
||||
```
|
||||
async dev on_match_tag(match: PatternMatch):
|
||||
pattern = match.type # instead of match.pattern_id
|
||||
text = match.text # instead of match.content
|
||||
```
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated all STT and TTS services to use consistent error handling pattern with
|
||||
@@ -33,11 +109,42 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Updated language mappings for the Google and Gemini TTS services to match
|
||||
official documentation.
|
||||
|
||||
- `TextFrame` new field `append_to_context` used to indicate if the encompassing
|
||||
text should be added to the LLM context (by the LLM assistant aggregator). It
|
||||
defaults to `True`.
|
||||
|
||||
- TTS flow respects aggregation metadata
|
||||
- `TTSService` accepts a new `skip_aggregator_types` to avoid speaking certain aggregation types
|
||||
(now determined/returned by the aggregator)
|
||||
- TTS services push `AggregatedTextFrame` in addition to `TTSTextFrame`s when either an
|
||||
aggregation occurs that should not be spoken or when the TTS service supports word-by-word
|
||||
timestamping. In the latter case, the `TTSService` preliminarily generates an
|
||||
`AggregatedTextFrame`, aggregated by sentence to generate the full sentence content as early
|
||||
as possible.
|
||||
- Introduced a new methods, `add_text_transformer()` and `remove_text_transformer()`:
|
||||
These functions introduce the ability to provide (and subsequently remove) callbacks to the TTS to transform text based on
|
||||
its aggregated type prior to sending the text to the underlying TTS service. This makes it
|
||||
possible to do things like introduce TTS-specific tags for spelling or emotion or change the
|
||||
pronunciation of something on the fly.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- The `api_key` parameter in `GeminiTTSService` is deprecated. Use
|
||||
`credentials` or `credentials_path` instead for Google Cloud authentication.
|
||||
|
||||
- The RTVI `bot-transcription` event is deprecated in favor of the new `bot-output`
|
||||
message which is the canonical representation of bot output (spoken or not). The code
|
||||
still emits a transcription message for backwards compatibility while transition occurs.
|
||||
|
||||
- The TTS constructor field, `text_aggregator` is deprecated in favor of the new
|
||||
`LLMTextProcessor`. TTSServices still have an internal aggregator for support of default
|
||||
behavior, but if you want to override the aggregation behavior, you should use the new
|
||||
processor.
|
||||
|
||||
- Deprecated `add_pattern_pair` in the `PatternPairAggregator` which takes a `pattern_id`
|
||||
and `remove_match` field in favor of the new `add_pattern` method which takes a `type` and an
|
||||
`action`
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed subtle issue of assistant context messages ending up with double spaces
|
||||
|
||||
@@ -62,7 +62,11 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -106,16 +110,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
pattern_aggregator = PatternPairAggregator()
|
||||
|
||||
# Add pattern for voice switching
|
||||
pattern_aggregator.add_pattern_pair(
|
||||
pattern_id="voice_tag",
|
||||
pattern_aggregator.add_pattern(
|
||||
type="voice",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
remove_match=True,
|
||||
action=MatchAction.REMOVE, # Remove tags from final text
|
||||
)
|
||||
|
||||
# Register handler for voice switching
|
||||
async def on_voice_tag(match: PatternMatch):
|
||||
voice_name = match.content.strip().lower()
|
||||
voice_name = match.text.strip().lower()
|
||||
if voice_name in VOICE_IDS:
|
||||
# First flush any existing audio to finish the current context
|
||||
await tts.flush_audio()
|
||||
@@ -125,7 +129,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
else:
|
||||
logger.warning(f"Unknown voice: {voice_name}")
|
||||
|
||||
pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
|
||||
pattern_aggregator.on_pattern_match("voice", on_voice_tag)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
|
||||
@@ -31,7 +31,11 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
|
||||
|
||||
class IVRStatus(Enum):
|
||||
@@ -114,15 +118,15 @@ class IVRProcessor(FrameProcessor):
|
||||
def _setup_xml_patterns(self):
|
||||
"""Set up XML pattern detection and handlers."""
|
||||
# Register DTMF pattern
|
||||
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", remove_match=True)
|
||||
self._aggregator.add_pattern("dtmf", "<dtmf>", "</dtmf>", action=MatchAction.REMOVE)
|
||||
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
|
||||
|
||||
# Register mode pattern
|
||||
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", remove_match=True)
|
||||
self._aggregator.add_pattern("mode", "<mode>", "</mode>", action=MatchAction.REMOVE)
|
||||
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
|
||||
|
||||
# Register IVR pattern
|
||||
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", remove_match=True)
|
||||
self._aggregator.add_pattern("ivr", "<ivr>", "</ivr>", action=MatchAction.REMOVE)
|
||||
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -159,7 +163,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing DTMF content.
|
||||
"""
|
||||
value = match.content
|
||||
value = match.text
|
||||
logger.debug(f"DTMF detected: {value}")
|
||||
|
||||
try:
|
||||
@@ -180,7 +184,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing IVR status content.
|
||||
"""
|
||||
status = match.content
|
||||
status = match.text
|
||||
logger.trace(f"IVR status detected: {status}")
|
||||
|
||||
# Convert string to enum, with validation
|
||||
@@ -211,7 +215,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing mode content.
|
||||
"""
|
||||
mode = match.content
|
||||
mode = match.text
|
||||
logger.debug(f"Mode detected: {mode}")
|
||||
if mode == "conversation":
|
||||
await self._handle_conversation()
|
||||
|
||||
@@ -12,6 +12,7 @@ and LLM processing.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -337,11 +338,14 @@ class TextFrame(DataFrame):
|
||||
# mandatory fields of theirs to have defaults to preserve
|
||||
# non-default-before-default argument order)
|
||||
includes_inter_frame_spaces: bool = field(init=False)
|
||||
# Whether this text frame should be appended to the LLM context.
|
||||
append_to_context: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
self.includes_inter_frame_spaces = False
|
||||
self.append_to_context = True
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
@@ -355,8 +359,32 @@ class LLMTextFrame(TextFrame):
|
||||
pass
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSTextFrame(TextFrame):
|
||||
class AggregatedTextFrame(TextFrame):
|
||||
"""Text frame representing an aggregation of TextFrames.
|
||||
|
||||
This frame contains multiple TextFrames aggregated together for processing
|
||||
or output along with a field to indicate how they are aggregated.
|
||||
|
||||
Parameters:
|
||||
aggregated_by: Method used to aggregate the text frames.
|
||||
"""
|
||||
|
||||
aggregated_by: AggregationType | str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSTextFrame(AggregatedTextFrame):
|
||||
"""Text frame generated by Text-to-Speech services."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1001,7 +1001,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started:
|
||||
if not self._started or not frame.append_to_context:
|
||||
return
|
||||
|
||||
if self._params.expect_stripped_words:
|
||||
|
||||
@@ -814,7 +814,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started:
|
||||
if not self._started or not frame.append_to_context:
|
||||
return
|
||||
|
||||
# Make sure we really have text (spaces count, too!)
|
||||
|
||||
106
src/pipecat/processors/aggregators/llm_text_processor.py
Normal file
106
src/pipecat/processors/aggregators/llm_text_processor.py
Normal file
@@ -0,0 +1,106 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""LLM text processor module for processing and aggregating raw LLM output text.
|
||||
|
||||
This processor will convert LLMTextFrames into AggregatedTextFrames based on the
|
||||
configured text aggregator. Using the customizable aggregator, it provides
|
||||
functionality to handle or manipulate LLM text frames before they are sent to other
|
||||
components such as TTS services or context aggregators. It can be used to pre-aggregate
|
||||
and categorize, modify, or filter direct output tokens from the LLM.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
class LLMTextProcessor(FrameProcessor):
|
||||
"""A processor for handling or manipulating LLM text frames before they are processed further.
|
||||
|
||||
This processor will convert LLMTextFrames into AggregatedTextFrames based on the configured
|
||||
text aggregator. Using the customizable aggregator, it provides functionality to handle or
|
||||
manipulate LLM text frames before they are sent to other components such as TTS services or
|
||||
context aggregators. It can be used to pre-aggregate and categorize, modify, or filter direct
|
||||
output tokens from the LLM.
|
||||
"""
|
||||
|
||||
def __init__(self, *, text_aggregator: Optional[BaseTextAggregator] = None, **kwargs):
|
||||
"""Initialize the LLM text processor.
|
||||
|
||||
Args:
|
||||
text_aggregator: An optional text aggregator to use for processing LLM text frames. By
|
||||
default, a SimpleTextAggregator aggregating by sentence will be used.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
|
||||
TODO: Allow transformations per aggregation type or all (and deprecate the TTS filters).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process an LLMTextFrames using the aggregator to generate AggregatedTextFrames.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
await self._handle_llm_text(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._handle_llm_end(frame.skip_tts)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._handle_llm_end()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_interruption(self, _):
|
||||
"""Handle interruptions by resetting the text aggregator."""
|
||||
await self._text_aggregator.handle_interruption()
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the internal state of the text processor and its aggregator."""
|
||||
await self._text_aggregator.reset()
|
||||
|
||||
async def _handle_llm_text(self, in_frame: LLMTextFrame):
|
||||
aggregation = await self._text_aggregator.aggregate(in_frame.text)
|
||||
if aggregation:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=aggregation.text,
|
||||
aggregated_by=aggregation.type,
|
||||
)
|
||||
out_frame.skip_tts = in_frame.skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
async def _handle_llm_end(self, skip_tts: bool = False):
|
||||
# Flush any remaining aggregated text at the end of the LLM response
|
||||
aggregation = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
text = aggregation.text.strip()
|
||||
if text:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=text,
|
||||
aggregated_by=aggregation.type,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
@@ -24,6 +24,7 @@ from typing import (
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
@@ -32,6 +33,8 @@ from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
|
||||
from pipecat.audio.utils import calculate_audio_volume
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -704,6 +707,29 @@ class RTVITextMessageData(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RTVIBotOutputMessageData(RTVITextMessageData):
|
||||
"""Data for bot output RTVI messages.
|
||||
|
||||
Extends RTVITextMessageData to include metadata about the output.
|
||||
"""
|
||||
|
||||
spoken: bool = False # Indicates if the text has been spoken by TTS
|
||||
aggregated_by: AggregationType | str
|
||||
# Indicates what form the text is in (e.g., by word, sentence, etc.)
|
||||
|
||||
|
||||
class RTVIBotOutputMessage(BaseModel):
|
||||
"""Message containing bot output text.
|
||||
|
||||
An event meant to holistically represent what the bot is outputting,
|
||||
along with metadata about the output and if it has been spoken.
|
||||
"""
|
||||
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-output"] = "bot-output"
|
||||
data: RTVIBotOutputMessageData
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessage(BaseModel):
|
||||
"""Message containing bot transcription text.
|
||||
|
||||
@@ -896,6 +922,7 @@ class RTVIObserverParams:
|
||||
Parameter `errors_enabled` is deprecated. Error messages are always enabled.
|
||||
|
||||
Parameters:
|
||||
bot_output_enabled: Indicates if bot output messages should be sent.
|
||||
bot_llm_enabled: Indicates if the bot's LLM messages should be sent.
|
||||
bot_tts_enabled: Indicates if the bot's TTS messages should be sent.
|
||||
bot_speaking_enabled: Indicates if the bot's started/stopped speaking messages should be sent.
|
||||
@@ -907,9 +934,17 @@ class RTVIObserverParams:
|
||||
metrics_enabled: Indicates if metrics messages should be sent.
|
||||
system_logs_enabled: Indicates if system logs should be sent.
|
||||
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
|
||||
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
|
||||
Note: if using this to avoid sending secure information, be sure to also disable
|
||||
bot_llm_enabled to avoid leaking through LLM messages.
|
||||
bot_output_transforms: A list of callables to transform text before just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
(aggregation_type | '*', transform_function).
|
||||
audio_level_period_secs: How often audio levels should be sent if enabled.
|
||||
"""
|
||||
|
||||
bot_output_enabled: bool = True
|
||||
bot_llm_enabled: bool = True
|
||||
bot_tts_enabled: bool = True
|
||||
bot_speaking_enabled: bool = True
|
||||
@@ -921,6 +956,15 @@ class RTVIObserverParams:
|
||||
metrics_enabled: bool = True
|
||||
system_logs_enabled: bool = False
|
||||
errors_enabled: Optional[bool] = None
|
||||
skip_aggregator_types: Optional[List[AggregationType | str]] = None
|
||||
bot_output_transforms: Optional[
|
||||
List[
|
||||
Tuple[
|
||||
AggregationType | str,
|
||||
Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
]
|
||||
]
|
||||
] = None
|
||||
audio_level_period_secs: float = 0.15
|
||||
|
||||
|
||||
@@ -973,8 +1017,45 @@ class RTVIObserver(BaseObserver):
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._aggregation_transforms: List[
|
||||
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
|
||||
] = self._params.bot_output_transforms or []
|
||||
|
||||
def add_bot_output_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Transform text for a specific aggregation type before sending as Bot Output or TTS.
|
||||
|
||||
Args:
|
||||
transform_function: The function to apply for transformation. This function should take
|
||||
the text and aggregation type as input and return the transformed text.
|
||||
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
|
||||
aggregation_type: The type of aggregation to transform. This value defaults to "*" to
|
||||
handle all text before sending to the client.
|
||||
"""
|
||||
self._aggregation_transforms.append((aggregation_type, transform_function))
|
||||
|
||||
def remove_bot_output_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Remove a text transformer for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to remove.
|
||||
aggregation_type: The type of aggregation to remove the transformer for.
|
||||
"""
|
||||
self._aggregation_transforms = [
|
||||
(agg_type, func)
|
||||
for agg_type, func in self._aggregation_transforms
|
||||
if not (agg_type == aggregation_type and func == transform_function)
|
||||
]
|
||||
|
||||
async def _logger_sink(self, message):
|
||||
"""Logger sink so we cna send system logs to RTVI clients."""
|
||||
"""Logger sink so we can send system logs to RTVI clients."""
|
||||
message = RTVISystemLogMessage(data=RTVITextMessageData(text=message))
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
@@ -1048,12 +1129,15 @@ class RTVIObserver(BaseObserver):
|
||||
await self.send_rtvi_message(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
|
||||
await self.send_rtvi_message(RTVIBotTTSStoppedMessage())
|
||||
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
|
||||
if isinstance(src, BaseOutputTransport):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.send_rtvi_message(message)
|
||||
else:
|
||||
elif isinstance(frame, AggregatedTextFrame) and (
|
||||
self._params.bot_output_enabled or self._params.bot_tts_enabled
|
||||
):
|
||||
if isinstance(frame, TTSTextFrame) and not isinstance(src, BaseOutputTransport):
|
||||
# This check is to make sure we handle the frame when it has gone
|
||||
# through the transport and has correct timing.
|
||||
mark_as_seen = False
|
||||
else:
|
||||
await self._handle_aggregated_llm_text(frame)
|
||||
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
|
||||
await self._handle_metrics(frame)
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
@@ -1084,15 +1168,6 @@ class RTVIObserver(BaseObserver):
|
||||
if mark_as_seen:
|
||||
self._frames_seen.add(frame.id)
|
||||
|
||||
async def _push_bot_transcription(self):
|
||||
"""Push accumulated bot transcription as a message."""
|
||||
if len(self._bot_transcription) > 0:
|
||||
message = RTVIBotTranscriptionMessage(
|
||||
data=RTVITextMessageData(text=self._bot_transcription)
|
||||
)
|
||||
await self.send_rtvi_message(message)
|
||||
self._bot_transcription = ""
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
"""Handle user speaking interruption frames."""
|
||||
message = None
|
||||
@@ -1115,14 +1190,45 @@ class RTVIObserver(BaseObserver):
|
||||
if message:
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
async def _handle_aggregated_llm_text(self, frame: AggregatedTextFrame):
|
||||
"""Handle aggregated LLM text output frames."""
|
||||
# Skip certain aggregator types if configured to do so.
|
||||
if (
|
||||
self._params.skip_aggregator_types
|
||||
and frame.aggregated_by in self._params.skip_aggregator_types
|
||||
):
|
||||
return
|
||||
|
||||
text = frame.text
|
||||
type = frame.aggregated_by
|
||||
for aggregation_type, transform in self._aggregation_transforms:
|
||||
if aggregation_type == type or aggregation_type == "*":
|
||||
text = await transform(text, type)
|
||||
|
||||
isTTS = isinstance(frame, TTSTextFrame)
|
||||
if self._params.bot_output_enabled:
|
||||
message = RTVIBotOutputMessage(
|
||||
data=RTVIBotOutputMessageData(text=text, spoken=isTTS, aggregated_by=type)
|
||||
)
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
if isTTS and self._params.bot_tts_enabled:
|
||||
tts_message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self.send_rtvi_message(tts_message)
|
||||
|
||||
async def _handle_llm_text_frame(self, frame: LLMTextFrame):
|
||||
"""Handle LLM text output frames."""
|
||||
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
# TODO (mrkb): Remove all this logic when we fully deprecate bot-transcription messages.
|
||||
self._bot_transcription += frame.text
|
||||
if match_endofsentence(self._bot_transcription):
|
||||
await self._push_bot_transcription()
|
||||
|
||||
if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0:
|
||||
await self.send_rtvi_message(
|
||||
RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription))
|
||||
)
|
||||
self._bot_transcription = ""
|
||||
|
||||
async def _handle_user_transcriptions(self, frame: Frame):
|
||||
"""Handle user transcription frames."""
|
||||
@@ -1248,7 +1354,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
# Default to 0.3.0 which is the last version before actually having a
|
||||
# "client-version".
|
||||
self._client_version = [0, 3, 0]
|
||||
self._skip_tts: bool = False # Keep in sync with llm_service.py
|
||||
self._llm_skip_tts: bool = False # Keep in sync with llm_service.py's configuration.
|
||||
|
||||
self._registered_actions: Dict[str, RTVIAction] = {}
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
@@ -1441,7 +1547,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
elif isinstance(frame, RTVIActionFrame):
|
||||
await self._action_queue.put(frame)
|
||||
elif isinstance(frame, LLMConfigureOutputFrame):
|
||||
self._skip_tts = frame.skip_tts
|
||||
self._llm_skip_tts = frame.skip_tts
|
||||
await self.push_frame(frame, direction)
|
||||
# Other frames
|
||||
else:
|
||||
@@ -1697,9 +1803,9 @@ class RTVIProcessor(FrameProcessor):
|
||||
opts = data.options if data.options is not None else RTVISendTextOptions()
|
||||
if opts.run_immediately:
|
||||
await self.interrupt_bot()
|
||||
cur_skip_tts = self._skip_tts
|
||||
cur_llm_skip_tts = self._llm_skip_tts
|
||||
should_skip_tts = not opts.audio_response
|
||||
toggle_skip_tts = cur_skip_tts != should_skip_tts
|
||||
toggle_skip_tts = cur_llm_skip_tts != should_skip_tts
|
||||
if toggle_skip_tts:
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=should_skip_tts)
|
||||
await self.push_frame(output_frame)
|
||||
@@ -1709,7 +1815,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
)
|
||||
await self.push_frame(text_frame)
|
||||
if toggle_skip_tts:
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=cur_skip_tts)
|
||||
output_frame = LLMConfigureOutputFrame(skip_tts=cur_llm_skip_tts)
|
||||
await self.push_frame(output_frame)
|
||||
|
||||
async def _handle_update_context(self, data: RTVIAppendToContextData):
|
||||
|
||||
@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter, Role
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -1027,7 +1028,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
logger.debug(f"Assistant response text added: {text}")
|
||||
|
||||
# Report the text of the assistant response.
|
||||
frame = TTSTextFrame(text)
|
||||
frame = TTSTextFrame(text, aggregated_by=AggregationType.SENTENCE)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -1062,7 +1063,9 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# TTSTextFrame would be ignored otherwise (the interruption frame
|
||||
# would have cleared the assistant aggregator state).
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
frame = TTSTextFrame(self._assistant_text_buffer)
|
||||
frame = TTSTextFrame(
|
||||
self._assistant_text_buffer, aggregated_by=AggregationType.SENTENCE
|
||||
)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
self._may_need_repush_assistant_text = False
|
||||
|
||||
@@ -10,7 +10,8 @@ import base64
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import AsyncGenerator, List, Literal, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -125,6 +126,72 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
|
||||
|
||||
|
||||
class CartesiaEmotion(str, Enum):
|
||||
"""Predefined Emotions supported by Cartesia."""
|
||||
|
||||
# Primary emotions supported by Cartesia
|
||||
NEUTRAL = "neutral"
|
||||
ANGRY = "angry"
|
||||
EXCITED = "excited"
|
||||
CONTENT = "content"
|
||||
SAD = "sad"
|
||||
SCARED = "scared"
|
||||
# Additional emotions supported by Cartesia
|
||||
HAPPY = "happy"
|
||||
ENTHUSIASTIC = "enthusiastic"
|
||||
ELATED = "elated"
|
||||
EUPHORIC = "euphoric"
|
||||
TRIUMPHANT = "triumphant"
|
||||
AMAZED = "amazed"
|
||||
SURPRISED = "surprised"
|
||||
FLIRTATIOUS = "flirtatious"
|
||||
JOKING_COMEDIC = "joking/comedic"
|
||||
CURIOUS = "curious"
|
||||
PEACEFUL = "peaceful"
|
||||
SERENE = "serene"
|
||||
CALM = "calm"
|
||||
GRATEFUL = "grateful"
|
||||
AFFECTIONATE = "affectionate"
|
||||
TRUST = "trust"
|
||||
SYMPATHETIC = "sympathetic"
|
||||
ANTICIPATION = "anticipation"
|
||||
MYSTERIOUS = "mysterious"
|
||||
MAD = "mad"
|
||||
OUTRAGED = "outraged"
|
||||
FRUSTRATED = "frustrated"
|
||||
AGITATED = "agitated"
|
||||
THREATENED = "threatened"
|
||||
DISGUSTED = "disgusted"
|
||||
CONTEMPT = "contempt"
|
||||
ENVIOUS = "envious"
|
||||
SARCASTIC = "sarcastic"
|
||||
IRONIC = "ironic"
|
||||
DEJECTED = "dejected"
|
||||
MELANCHOLIC = "melancholic"
|
||||
DISAPPOINTED = "disappointed"
|
||||
HURT = "hurt"
|
||||
GUILTY = "guilty"
|
||||
BORED = "bored"
|
||||
TIRED = "tired"
|
||||
REJECTED = "rejected"
|
||||
NOSTALGIC = "nostalgic"
|
||||
WISTFUL = "wistful"
|
||||
APOLOGETIC = "apologetic"
|
||||
HESITANT = "hesitant"
|
||||
INSECURE = "insecure"
|
||||
CONFUSED = "confused"
|
||||
RESIGNED = "resigned"
|
||||
ANXIOUS = "anxious"
|
||||
PANICKED = "panicked"
|
||||
ALARMED = "alarmed"
|
||||
PROUD = "proud"
|
||||
CONFIDENT = "confident"
|
||||
DISTANT = "distant"
|
||||
SKEPTICAL = "skeptical"
|
||||
CONTEMPLATIVE = "contemplative"
|
||||
DETERMINED = "determined"
|
||||
|
||||
|
||||
class CartesiaTTSService(AudioContextWordTTSService):
|
||||
"""Cartesia TTS service with WebSocket streaming and word timestamps.
|
||||
|
||||
@@ -182,6 +249,10 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
container: Audio container format.
|
||||
params: Additional input parameters for voice customization.
|
||||
text_aggregator: Custom text aggregator for processing input text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
@@ -200,10 +271,18 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
text_aggregator=text_aggregator or SkipTagsAggregator([("<spell>", "</spell>")]),
|
||||
text_aggregator=text_aggregator,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not text_aggregator:
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Cartesia SSML Tags is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("<spell>", "</spell>")])
|
||||
|
||||
params = params or CartesiaTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
@@ -257,6 +336,27 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
"""
|
||||
return language_to_cartesia_language(language)
|
||||
|
||||
# A set of Cartesia-specific helpers for text transformations
|
||||
def SPELL(text: str) -> str:
|
||||
"""Wrap text in Cartesia spell tag."""
|
||||
return f"<spell>{text}</spell>"
|
||||
|
||||
def EMOTION_TAG(emotion: CartesiaEmotion) -> str:
|
||||
"""Convenience method to create an emotion tag."""
|
||||
return f'<emotion value="{emotion}" />'
|
||||
|
||||
def PAUSE_TAG(seconds: float) -> str:
|
||||
"""Convenience method to create a pause tag."""
|
||||
return f'<break time="{seconds}s" />'
|
||||
|
||||
def VOLUME_TAG(volume: float) -> str:
|
||||
"""Convenience method to create a volume tag."""
|
||||
return f'<volume ratio="{volume}" />'
|
||||
|
||||
def SPEED_TAG(speed: float) -> str:
|
||||
"""Convenience method to create a speed tag."""
|
||||
return f'<speed ratio="{speed}" />'
|
||||
|
||||
def _is_cjk_language(self, language: str) -> bool:
|
||||
"""Check if the given language is CJK (Chinese, Japanese, Korean).
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -1646,7 +1647,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
frame = TTSTextFrame(text=text)
|
||||
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.SENTENCE)
|
||||
# Gemini Live text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from pipecat.adapters.services.open_ai_realtime_adapter import (
|
||||
OpenAIRealtimeLLMAdapter,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -686,7 +687,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# We receive audio transcript deltas (as opposed to text deltas) when
|
||||
# the output modality is "audio" (the default)
|
||||
if evt.delta:
|
||||
frame = TTSTextFrame(evt.delta)
|
||||
frame = TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE)
|
||||
# OpenAI Realtime text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -17,6 +17,7 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -652,7 +653,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(LLMTextFrame(evt.delta))
|
||||
await self.push_frame(TTSTextFrame(evt.delta))
|
||||
await self.push_frame(TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE))
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
|
||||
@@ -113,6 +113,10 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
params: Additional configuration parameters.
|
||||
text_aggregator: Custom text aggregator for processing input text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
@@ -123,10 +127,17 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
text_aggregator=text_aggregator or SkipTagsAggregator([("spell(", ")")]),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not text_aggregator:
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Rime spelling is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
|
||||
|
||||
params = params or RimeTTSService.InputParams()
|
||||
|
||||
# Store service configuration
|
||||
@@ -152,6 +163,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
self._context_id = None # Tracks current turn
|
||||
self._receive_task = None
|
||||
self._cumulative_time = 0 # Accumulates time across messages
|
||||
self._extra_msg_fields = {} # Extra fields for next message
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -181,6 +193,31 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
self._model = model
|
||||
await super().set_model(model)
|
||||
|
||||
# A set of Rime-specific helpers for text transformations
|
||||
def SPELL(text: str) -> str:
|
||||
"""Wrap text in Rime spell function."""
|
||||
return f"spell({text})"
|
||||
|
||||
def PAUSE_TAG(seconds: float) -> str:
|
||||
"""Convenience method to create a pause tag."""
|
||||
return f"<{seconds * 1000}>"
|
||||
|
||||
def PRONOUNCE(self, text: str, word: str, phoneme: str) -> str:
|
||||
"""Convenience method to support Rime's custom pronunciations feature.
|
||||
|
||||
https://docs.rime.ai/api-reference/custom-pronunciation
|
||||
"""
|
||||
self._extra_msg_fields["phonemizeBetweenBrackets"] = True
|
||||
return text.replace(word, f"{phoneme}")
|
||||
|
||||
def INLINE_SPEED(self, text: str, speed: float) -> str:
|
||||
"""Convenience method to support inline speeds."""
|
||||
if not self._extra_msg_fields:
|
||||
self._extra_msg_fields = {}
|
||||
speed_vals = self._extra_msg_fields.get("inlineSpeedAlpha", "").split(",")
|
||||
self._extra_msg_fields["inlineSpeedAlpha"] = ",".join(speed_vals + [str(speed)])
|
||||
return f"[{text}]"
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
prev_voice = self._voice_id
|
||||
@@ -193,7 +230,11 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
"""Build JSON message for Rime API."""
|
||||
return {"text": text, "contextId": self._context_id}
|
||||
msg = {"text": text, "contextId": self._context_id}
|
||||
if self._extra_msg_fields:
|
||||
msg |= self._extra_msg_fields
|
||||
self._extra_msg_fields = {}
|
||||
return msg
|
||||
|
||||
def _build_clear_msg(self) -> dict:
|
||||
"""Build clear operation message."""
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
@@ -23,6 +25,8 @@ from typing import (
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -101,6 +105,16 @@ class TTSService(AIService):
|
||||
sample_rate: Optional[int] = None,
|
||||
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
# Types of text aggregations that should not be spoken.
|
||||
skip_aggregator_types: Optional[List[str]] = [],
|
||||
# A list of callables to transform text before just before sending it to TTS.
|
||||
# Each callable takes the aggregated text and its type, and returns the transformed text.
|
||||
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
|
||||
text_transforms: Optional[
|
||||
List[
|
||||
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
|
||||
]
|
||||
] = None,
|
||||
# Text filter executed after text has been aggregated.
|
||||
text_filters: Optional[Sequence[BaseTextFilter]] = None,
|
||||
text_filter: Optional[BaseTextFilter] = None,
|
||||
@@ -120,6 +134,16 @@ class TTSService(AIService):
|
||||
pause_frame_processing: Whether to pause frame processing during audio generation.
|
||||
sample_rate: Output sample rate for generated audio.
|
||||
text_aggregator: Custom text aggregator for processing incoming text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
skip_aggregator_types: List of aggregation types that should not be spoken.
|
||||
text_transforms: A list of callables to transform text before just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
(aggregation_type | '*', transform_function).
|
||||
|
||||
text_filters: Sequence of text filters to apply after aggregation.
|
||||
text_filter: Single text filter (deprecated, use text_filters).
|
||||
|
||||
@@ -142,6 +166,21 @@ class TTSService(AIService):
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
if text_aggregator:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'text_aggregator' is deprecated. Use an LLMTextProcessor before the TTSService for custom text aggregation.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
|
||||
self._text_transforms: List[
|
||||
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
|
||||
] = text_transforms or []
|
||||
# TODO: Deprecate _text_filters when added to LLMTextProcessor
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
self._tracing_enabled: bool = False
|
||||
@@ -298,6 +337,39 @@ class TTSService(AIService):
|
||||
await self.cancel_task(self._stop_frame_task)
|
||||
self._stop_frame_task = None
|
||||
|
||||
def add_text_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Transform text for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to apply for transformation. This function should take
|
||||
the text and aggregation type as input and return the transformed text.
|
||||
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
|
||||
aggregation_type: The type of aggregation to transform. This value defaults to "*" indicating
|
||||
the function should handle all text before sending to TTS.
|
||||
"""
|
||||
self._text_transforms.append((aggregation_type, transform_function))
|
||||
|
||||
def remove_text_transformer(
|
||||
self,
|
||||
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
|
||||
aggregation_type: AggregationType | str = "*",
|
||||
):
|
||||
"""Remove a text transformer for a specific aggregation type.
|
||||
|
||||
Args:
|
||||
transform_function: The function to remove.
|
||||
aggregation_type: The type of aggregation to remove the transformer for.
|
||||
"""
|
||||
self._text_transforms = [
|
||||
(agg_type, func)
|
||||
for agg_type, func in self._text_transforms
|
||||
if not (agg_type == aggregation_type and func == transform_function)
|
||||
]
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
@@ -353,6 +425,8 @@ class TTSService(AIService):
|
||||
and frame.skip_tts
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AggregatedTextFrame):
|
||||
await self._push_tts_frames(frame)
|
||||
elif (
|
||||
isinstance(frame, TextFrame)
|
||||
and not isinstance(frame, InterimTranscriptionFrame)
|
||||
@@ -368,10 +442,10 @@ class TTSService(AIService):
|
||||
# pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
|
||||
sentence = self._text_aggregator.text
|
||||
aggregate = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
self._processing_text = False
|
||||
await self._push_tts_frames(sentence)
|
||||
await self._push_tts_frames(AggregatedTextFrame(aggregate.text, aggregate.type))
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -380,7 +454,7 @@ class TTSService(AIService):
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
# Store if we were processing text or not so we can set it back.
|
||||
processing_text = self._processing_text
|
||||
await self._push_tts_frames(frame.text)
|
||||
await self._push_tts_frames(AggregatedTextFrame(frame.text, AggregationType.SENTENCE))
|
||||
# We pause processing incoming frames because we are sending data to
|
||||
# the TTS. We pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
@@ -472,13 +546,24 @@ class TTSService(AIService):
|
||||
text: Optional[str] = None
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
aggregated_by = "token"
|
||||
else:
|
||||
text = await self._text_aggregator.aggregate(frame.text)
|
||||
aggregate = await self._text_aggregator.aggregate(frame.text)
|
||||
if aggregate:
|
||||
text = aggregate.text
|
||||
aggregated_by = aggregate.type
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(AggregatedTextFrame(text, aggregated_by))
|
||||
|
||||
async def _push_tts_frames(self, src_frame: AggregatedTextFrame):
|
||||
type = src_frame.aggregated_by
|
||||
text = src_frame.text
|
||||
if type in self._skip_aggregator_types:
|
||||
await self.push_frame(src_frame)
|
||||
return
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
# Remove leading newlines only
|
||||
text = text.lstrip("\n")
|
||||
|
||||
@@ -499,15 +584,39 @@ class TTSService(AIService):
|
||||
await filter.reset_interruption()
|
||||
text = await filter.filter(text)
|
||||
|
||||
if text:
|
||||
await self.process_generator(self.run_tts(text))
|
||||
if not text.strip():
|
||||
await self.stop_processing_metrics()
|
||||
return
|
||||
|
||||
# To support use cases that may want to know the text before it's spoken, we
|
||||
# push the AggregatedTextFrame version before transforming and sending to TTS.
|
||||
# However, we do not want to add this text to the assistant context until it
|
||||
# is spoken, so we set append_to_context to False.
|
||||
src_frame.append_to_context = False
|
||||
await self.push_frame(src_frame)
|
||||
|
||||
# Note: Text transformations are meant to only affect the text sent to the TTS for
|
||||
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
|
||||
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
|
||||
# services that support word-level timestamps, this CAN affect the resulting context
|
||||
# since the TTSTextFrames are generated from the TTS output stream
|
||||
transformed_text = text
|
||||
for aggregation_type, transform in self._text_transforms:
|
||||
if aggregation_type == type or aggregation_type == "*":
|
||||
transformed_text = await transform(transformed_text, type)
|
||||
await self.process_generator(self.run_tts(transformed_text))
|
||||
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
if self._push_text_frames:
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
frame = TTSTextFrame(text)
|
||||
# In TTS services that support word timestamps, the TTSTextFrames
|
||||
# are pushed as words are spoken. However, in the case where the TTS service
|
||||
# does not support word timestamps (i.e. _push_text_frames is True), we send
|
||||
# the original (non-transformed) text after the TTS generation has completed.
|
||||
# This way, if we are interrupted, the text is not added to the assistant
|
||||
# context and the context that IS added does not include TTS-specific tags
|
||||
# or transformations.
|
||||
frame = TTSTextFrame(text, aggregated_by=type)
|
||||
frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -635,7 +744,7 @@ class WordTTSService(TTSService):
|
||||
frame = TTSStoppedFrame()
|
||||
frame.pts = last_pts
|
||||
else:
|
||||
frame = TTSTextFrame(word)
|
||||
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
|
||||
frame.pts = self._initial_word_timestamp + timestamp
|
||||
if frame:
|
||||
last_pts = frame.pts
|
||||
|
||||
@@ -203,8 +203,16 @@ async def run_test(
|
||||
if not isinstance(frame, EndFrame) or not send_end_frame:
|
||||
received_down_frames.append(frame)
|
||||
|
||||
print("received DOWN frames =", received_down_frames)
|
||||
print("expected DOWN frames =", expected_down_frames)
|
||||
down_frames_printed = "["
|
||||
for frame in received_down_frames:
|
||||
down_frames_printed += f"{frame.__class__.__name__}, "
|
||||
down_frames_printed += "]"
|
||||
expected_frames_printed = "["
|
||||
for frame in expected_down_frames:
|
||||
expected_frames_printed += f"{frame.__name__}, "
|
||||
expected_frames_printed += "]"
|
||||
print("received DOWN frames =", down_frames_printed)
|
||||
print("expected DOWN frames =", expected_frames_printed)
|
||||
|
||||
assert len(received_down_frames) == len(expected_down_frames)
|
||||
|
||||
|
||||
@@ -12,9 +12,46 @@ aggregated text should be sent for speech synthesis.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aggregation:
|
||||
"""Data class representing aggregated text and its type.
|
||||
|
||||
An Aggregation object is created whenever a stream of text is aggregated by
|
||||
a text aggregator. It contains the aggregated text and a type indicating
|
||||
the nature of the aggregation.
|
||||
|
||||
Parameters:
|
||||
text: The aggregated text content.
|
||||
type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
|
||||
"""
|
||||
|
||||
text: str
|
||||
type: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the aggregation.
|
||||
|
||||
Returns:
|
||||
A descriptive string showing the type and text of the aggregation.
|
||||
"""
|
||||
return f"Aggregation by {self.type}: {self.text}"
|
||||
|
||||
|
||||
class BaseTextAggregator(ABC):
|
||||
"""Base class for text aggregators in the Pipecat framework.
|
||||
|
||||
@@ -30,7 +67,7 @@ class BaseTextAggregator(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def text(self) -> str:
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Subclasses must implement this property to return the text that has
|
||||
@@ -42,12 +79,13 @@ class BaseTextAggregator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate the specified text with the currently accumulated text.
|
||||
|
||||
This method should be implemented to define how the new text contributes
|
||||
to the aggregation process. It returns the updated aggregated text if
|
||||
it's ready to be processed, or None otherwise.
|
||||
to the aggregation process. It returns the aggregated text and a string
|
||||
describing how it was aggregated if it's ready to be processed,
|
||||
or None otherwise.
|
||||
|
||||
Subclasses should implement their specific logic for:
|
||||
|
||||
|
||||
@@ -8,19 +8,41 @@
|
||||
|
||||
This module provides an aggregator that identifies and processes content between
|
||||
pattern pairs (like XML tags or custom delimiters) in streaming text, with
|
||||
support for custom handlers and configurable pattern removal.
|
||||
support for custom handlers and configurable actions for when a pattern is found.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Awaitable, Callable, Optional, Tuple
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class PatternMatch:
|
||||
class MatchAction(Enum):
|
||||
"""Actions to take when a pattern pair is matched.
|
||||
|
||||
Parameters:
|
||||
REMOVE: The text along with its delimiters will be removed from the streaming text.
|
||||
Sentence aggregation will continue on as if this text did not exist.
|
||||
KEEP: The delimiters will be removed, but the content between them will be kept.
|
||||
Sentence aggregation will continue on with the internal text included.
|
||||
AGGREGATE: The delimiters will be removed and the content between will be treated
|
||||
as a separate aggregation. Any text before the start of the pattern will be
|
||||
returned early, whether or not a complete sentence was found. Then the pattern
|
||||
will be returned. Then the aggregation will continue on sentence matching after
|
||||
the closing delimiter is found. The content between the delimiters is not
|
||||
aggregated by sentence. It is aggregated as one single block of text.
|
||||
"""
|
||||
|
||||
REMOVE = "remove"
|
||||
KEEP = "keep"
|
||||
AGGREGATE = "aggregate"
|
||||
|
||||
|
||||
class PatternMatch(Aggregation):
|
||||
"""Represents a matched pattern pair with its content.
|
||||
|
||||
A PatternMatch object is created when a complete pattern pair is found
|
||||
@@ -29,25 +51,25 @@ class PatternMatch:
|
||||
content between the patterns.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern_id: str, full_match: str, content: str):
|
||||
def __init__(self, content: str, type: str, full_match: str):
|
||||
"""Initialize a pattern match.
|
||||
|
||||
Args:
|
||||
pattern_id: The identifier of the matched pattern pair.
|
||||
type: The type of the matched pattern pair. It should be representative
|
||||
of the content type (e.g., 'sentence', 'code', 'speaker', 'custom').
|
||||
full_match: The complete text including start and end patterns.
|
||||
content: The text content between the start and end patterns.
|
||||
"""
|
||||
self.pattern_id = pattern_id
|
||||
super().__init__(text=content, type=type)
|
||||
self.full_match = full_match
|
||||
self.content = content
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the pattern match.
|
||||
|
||||
Returns:
|
||||
A descriptive string showing the pattern ID and content.
|
||||
A descriptive string showing the pattern type and content.
|
||||
"""
|
||||
return f"PatternMatch(id={self.pattern_id}, content={self.content})"
|
||||
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
|
||||
|
||||
|
||||
class PatternPairAggregator(BaseTextAggregator):
|
||||
@@ -55,16 +77,21 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
This aggregator buffers text until it can identify complete pattern pairs
|
||||
(defined by start and end patterns), processes the content between these
|
||||
patterns using registered handlers, and returns text at sentence boundaries.
|
||||
It's particularly useful for processing structured content in streaming text,
|
||||
such as XML tags, markdown formatting, or custom delimiters.
|
||||
patterns using registered handlers. By default, its aggregation method
|
||||
returns text at sentence boundaries, and remove the content found between
|
||||
any matched patterns. However, matched patterns can also be configured to
|
||||
returned as a separate aggregation object containing the content between
|
||||
their start and end patterns or left in, so that only the delimiters are
|
||||
removed and a callback can be triggered.
|
||||
|
||||
This aggregator is particularly useful for processing structured content in
|
||||
streaming text, such as XML tags, markdown formatting, or custom delimiters.
|
||||
|
||||
The aggregator ensures that patterns spanning multiple text chunks are
|
||||
correctly identified and handles cases where patterns contain sentence
|
||||
boundaries.
|
||||
correctly identified.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the pattern pair aggregator.
|
||||
|
||||
Creates an empty aggregator with no patterns or handlers registered.
|
||||
@@ -75,16 +102,23 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._handlers = {}
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Get the currently buffered text.
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
The text that has been accumulated in the buffer.
|
||||
"""
|
||||
return self._text
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start:
|
||||
return Aggregation(self._text, pattern_start[1].get("type", AggregationType.SENTENCE))
|
||||
return Aggregation(self._text, AggregationType.SENTENCE)
|
||||
|
||||
def add_pattern_pair(
|
||||
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
|
||||
def add_pattern(
|
||||
self,
|
||||
type: str,
|
||||
start_pattern: str,
|
||||
end_pattern: str,
|
||||
action: MatchAction = MatchAction.REMOVE,
|
||||
) -> "PatternPairAggregator":
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
@@ -93,41 +127,94 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
the end pattern, and treat the content between them as a match.
|
||||
|
||||
Args:
|
||||
pattern_id: Unique identifier for this pattern pair.
|
||||
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
|
||||
those are reserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
remove_match: Whether to remove the matched content from the text.
|
||||
action: What to do when a complete pattern is matched:
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
|
||||
normal text. This allows you to register handlers for
|
||||
the pattern without affecting the aggregation logic.
|
||||
- MatchAction.AGGREGATE: Return the matched pattern as a separate
|
||||
aggregation object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._patterns[pattern_id] = {
|
||||
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
|
||||
raise ValueError(
|
||||
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
|
||||
)
|
||||
self._patterns[type] = {
|
||||
"start": start_pattern,
|
||||
"end": end_pattern,
|
||||
"remove_match": remove_match,
|
||||
"type": type,
|
||||
"action": action,
|
||||
}
|
||||
return self
|
||||
|
||||
def add_pattern_pair(
|
||||
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
|
||||
):
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use `add_pattern` with a type and MatchAction instead.
|
||||
|
||||
This method calls `add_pattern` setting type with the provided pattern_id and action
|
||||
to either MatchAction.REMOVE or MatchAction.KEEP based on `remove_match`.
|
||||
|
||||
Args:
|
||||
pattern_id: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). pattern_id can not be 'sentence' or 'word'
|
||||
as those arereserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
remove_match: If True, the matched pattern will be removed from the text. (Same as MatchAction.REMOVE)
|
||||
If False, it will be kept and treated as normal text. (Same as MatchAction.KEEP)
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once")
|
||||
warnings.warn(
|
||||
"add_pattern_pair with a pattern_id or remove_match is deprecated and will be"
|
||||
" removed in a future version. Use add_pattern with a type and MatchAction instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
action = MatchAction.REMOVE if remove_match else MatchAction.KEEP
|
||||
return self.add_pattern(
|
||||
type=pattern_id,
|
||||
start_pattern=start_pattern,
|
||||
end_pattern=end_pattern,
|
||||
action=action,
|
||||
)
|
||||
|
||||
def on_pattern_match(
|
||||
self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
) -> "PatternPairAggregator":
|
||||
"""Register a handler for when a pattern pair is matched.
|
||||
|
||||
The handler will be called whenever a complete match for the
|
||||
specified pattern ID is found in the text.
|
||||
specified type is found in the text.
|
||||
|
||||
Args:
|
||||
pattern_id: ID of the pattern pair to match.
|
||||
type: The type of the pattern pair to trigger the handler.
|
||||
handler: Async function to call when pattern is matched.
|
||||
The function should accept a PatternMatch object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._handlers[pattern_id] = handler
|
||||
self._handlers[type] = handler
|
||||
return self
|
||||
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[str, bool]:
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
|
||||
"""Process all complete pattern pairs in the text.
|
||||
|
||||
Searches for all complete pattern pairs in the text, calls the
|
||||
@@ -137,19 +224,19 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
text: The text to process for pattern matches.
|
||||
|
||||
Returns:
|
||||
Tuple of (processed_text, was_modified) where:
|
||||
Tuple of (all_matches, processed_text) where:
|
||||
|
||||
- processed_text is the text after processing patterns
|
||||
- was_modified indicates whether any changes were made
|
||||
- all_matches is a list of all pattern matches found. Note: There really should only ever be 1.
|
||||
- processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text.
|
||||
"""
|
||||
all_matches = []
|
||||
processed_text = text
|
||||
modified = False
|
||||
|
||||
for pattern_id, pattern_info in self._patterns.items():
|
||||
for type, pattern_info in self._patterns.items():
|
||||
# Escape special regex characters in the patterns
|
||||
start = re.escape(pattern_info["start"])
|
||||
end = re.escape(pattern_info["end"])
|
||||
remove_match = pattern_info["remove_match"]
|
||||
action = pattern_info["action"]
|
||||
|
||||
# Create regex to match from start pattern to end pattern
|
||||
# The .*? is non-greedy to handle nested patterns
|
||||
@@ -164,25 +251,24 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
full_match = match.group(0) # Full match including patterns
|
||||
|
||||
# Create pattern match object
|
||||
pattern_match = PatternMatch(
|
||||
pattern_id=pattern_id, full_match=full_match, content=content
|
||||
)
|
||||
pattern_match = PatternMatch(content=content, type=type, full_match=full_match)
|
||||
|
||||
# Call the appropriate handler if registered
|
||||
if pattern_id in self._handlers:
|
||||
if type in self._handlers:
|
||||
try:
|
||||
await self._handlers[pattern_id](pattern_match)
|
||||
await self._handlers[type](pattern_match)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
|
||||
logger.error(f"Error in pattern handler for {type}: {e}")
|
||||
|
||||
# Remove the pattern from the text if configured
|
||||
if remove_match:
|
||||
if action == MatchAction.REMOVE:
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
modified = True
|
||||
else:
|
||||
all_matches.append(pattern_match)
|
||||
|
||||
return processed_text, modified
|
||||
return all_matches, processed_text
|
||||
|
||||
def _has_incomplete_patterns(self, text: str) -> bool:
|
||||
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
|
||||
"""Check if text contains incomplete pattern pairs.
|
||||
|
||||
Determines whether the text contains any start patterns without
|
||||
@@ -192,9 +278,10 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
text: The text to check for incomplete patterns.
|
||||
|
||||
Returns:
|
||||
True if there are incomplete patterns, False otherwise.
|
||||
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
|
||||
or None if no patterns are found or all patterns are complete.
|
||||
"""
|
||||
for pattern_id, pattern_info in self._patterns.items():
|
||||
for type, pattern_info in self._patterns.items():
|
||||
start = pattern_info["start"]
|
||||
end = pattern_info["end"]
|
||||
|
||||
@@ -203,12 +290,16 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
end_count = text.count(end)
|
||||
|
||||
# If there are more starts than ends, we have incomplete patterns
|
||||
# Again, this is written generically but there only ever should
|
||||
# be one pattern active at a time, so the counts should be 0 or 1.
|
||||
# Which is why we base the return on the first found.
|
||||
if start_count > end_count:
|
||||
return True
|
||||
start_index = text.find(start)
|
||||
return [start_index, pattern_info]
|
||||
|
||||
return False
|
||||
return None
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[PatternMatch]:
|
||||
"""Aggregate text and process pattern pairs.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete pattern
|
||||
@@ -227,16 +318,34 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._text += text
|
||||
|
||||
# Process any complete patterns in the buffer
|
||||
processed_text, modified = await self._process_complete_patterns(self._text)
|
||||
patterns, processed_text = await self._process_complete_patterns(self._text)
|
||||
|
||||
# Only update the buffer if modifications were made
|
||||
if modified:
|
||||
self._text = processed_text
|
||||
self._text = processed_text
|
||||
|
||||
if len(patterns) > 0:
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
return patterns[0]
|
||||
|
||||
# Check if we have incomplete patterns
|
||||
if self._has_incomplete_patterns(self._text):
|
||||
# Still waiting for complete patterns
|
||||
return None
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, return None
|
||||
if (
|
||||
pattern_start[0] == 0
|
||||
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
|
||||
):
|
||||
return None
|
||||
# Otherwise, strip the text up to the start pattern and return it
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
@@ -244,7 +353,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return result
|
||||
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
@@ -14,7 +14,7 @@ text processing scenarios.
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class SimpleTextAggregator(BaseTextAggregator):
|
||||
@@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
self._text = ""
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Returns:
|
||||
The text that has been accumulated in the buffer.
|
||||
"""
|
||||
return self._text
|
||||
return Aggregation(self._text, AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text and return completed sentences.
|
||||
|
||||
Adds the new text to the buffer and checks for end-of-sentence markers.
|
||||
@@ -64,7 +64,7 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
result = self._text[:eos_end_marker]
|
||||
self._text = self._text[eos_end_marker:]
|
||||
|
||||
return result
|
||||
return Aggregation(result, AggregationType.SENTENCE) if result else None
|
||||
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruptions by clearing the text buffer.
|
||||
|
||||
@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class SkipTagsAggregator(BaseTextAggregator):
|
||||
@@ -49,9 +49,9 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
"""
|
||||
return self._text
|
||||
return Aggregation(self._text, AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text while respecting tag boundaries.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete
|
||||
@@ -80,7 +80,7 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return result
|
||||
return Aggregation(result, AggregationType.SENTENCE)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
@@ -7,30 +7,42 @@
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
|
||||
|
||||
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.aggregator = PatternPairAggregator()
|
||||
self.test_handler = AsyncMock()
|
||||
self.code_handler = AsyncMock()
|
||||
|
||||
# Add a test pattern
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="test_pattern",
|
||||
start_pattern="<test>",
|
||||
end_pattern="</test>",
|
||||
remove_match=True,
|
||||
)
|
||||
self.aggregator.add_pattern(
|
||||
type="code_pattern",
|
||||
start_pattern="<code>",
|
||||
end_pattern="</code>",
|
||||
action=MatchAction.AGGREGATE,
|
||||
)
|
||||
|
||||
# Register the mock handler
|
||||
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
|
||||
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
|
||||
|
||||
async def test_pattern_match_and_removal(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</test>!")
|
||||
@@ -39,20 +51,49 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.pattern_id, "test_pattern")
|
||||
self.assertEqual(call_args.type, "test_pattern")
|
||||
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
|
||||
self.assertEqual(call_args.content, "pattern content")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
|
||||
# The exclamation point should be treated as a sentence boundary,
|
||||
# so the result should include just text up to and including "!"
|
||||
self.assertEqual(result, "Hello !")
|
||||
self.assertEqual(result.text, "Hello !")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
self.assertEqual(result, " This is another sentence.")
|
||||
self.assertEqual(result.text, " This is another sentence.")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_pattern_match_and_aggregate(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Here is code <code>pattern")
|
||||
self.assertEqual(result.text, "Here is code ")
|
||||
self.assertEqual(self.aggregator.text.text, "<code>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "code_pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</code>")
|
||||
|
||||
# Verify the handler was called with correct PatternMatch object
|
||||
self.code_handler.assert_called_once()
|
||||
call_args = self.code_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.type, "code_pattern")
|
||||
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
self.assertEqual(result.text, "pattern content")
|
||||
self.assertEqual(result.type, "code_pattern")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
self.assertEqual(result.text, " This is another sentence.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_incomplete_pattern(self):
|
||||
# Add text with incomplete pattern
|
||||
@@ -65,26 +106,30 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_not_called()
|
||||
|
||||
# Buffer should contain the incomplete text
|
||||
self.assertEqual(self.aggregator.text, "Hello <test>pattern content")
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
|
||||
# Reset and confirm buffer is cleared
|
||||
await self.aggregator.reset()
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_multiple_patterns(self):
|
||||
# Set up multiple patterns and handlers
|
||||
voice_handler = AsyncMock()
|
||||
emphasis_handler = AsyncMock()
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
|
||||
self.aggregator.add_pattern(
|
||||
type="voice",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
action=MatchAction.REMOVE,
|
||||
)
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="emphasis",
|
||||
self.aggregator.add_pattern(
|
||||
type="emphasis",
|
||||
start_pattern="<em>",
|
||||
end_pattern="</em>",
|
||||
remove_match=False, # Keep emphasis tags
|
||||
action=MatchAction.KEEP, # Keep emphasis tags
|
||||
)
|
||||
|
||||
self.aggregator.on_pattern_match("voice", voice_handler)
|
||||
@@ -97,19 +142,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Both handlers should be called with correct data
|
||||
voice_handler.assert_called_once()
|
||||
voice_match = voice_handler.call_args[0][0]
|
||||
self.assertEqual(voice_match.pattern_id, "voice")
|
||||
self.assertEqual(voice_match.content, "female")
|
||||
self.assertEqual(voice_match.type, "voice")
|
||||
self.assertEqual(voice_match.text, "female")
|
||||
|
||||
emphasis_handler.assert_called_once()
|
||||
emphasis_match = emphasis_handler.call_args[0][0]
|
||||
self.assertEqual(emphasis_match.pattern_id, "emphasis")
|
||||
self.assertEqual(emphasis_match.content, "very")
|
||||
self.assertEqual(emphasis_match.type, "emphasis")
|
||||
self.assertEqual(emphasis_match.text, "very")
|
||||
|
||||
# Voice pattern should be removed, emphasis pattern should remain
|
||||
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
|
||||
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
|
||||
|
||||
# Buffer should be empty
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_handle_interruption(self):
|
||||
# Start with incomplete pattern
|
||||
@@ -120,7 +165,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
await self.aggregator.handle_interruption()
|
||||
|
||||
# Buffer should be cleared
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
# Handler should not have been called
|
||||
self.test_handler.assert_not_called()
|
||||
@@ -138,10 +183,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Handler should be called with entire content
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
|
||||
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
|
||||
|
||||
# Pattern should be removed, resulting in text with sentences merged
|
||||
self.assertEqual(result, "Hello Final sentence.")
|
||||
self.assertEqual(result.text, "Hello Final sentence.")
|
||||
|
||||
# Buffer should be empty
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
@@ -13,6 +13,7 @@ import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
ErrorFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
@@ -74,6 +75,7 @@ async def test_run_piper_tts_success(aiohttp_client):
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
AggregatedTextFrame,
|
||||
TTSStartedFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
@@ -121,7 +123,7 @@ async def test_run_piper_tts_error(aiohttp_client):
|
||||
TTSSpeakFrame(text="Error case."),
|
||||
]
|
||||
|
||||
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
|
||||
expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame]
|
||||
|
||||
expected_up_frames = [ErrorFrame]
|
||||
|
||||
|
||||
@@ -15,15 +15,20 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_reset_aggregations(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
assert self.aggregator.text == "Hello "
|
||||
assert self.aggregator.text.text == "Hello "
|
||||
await self.aggregator.reset()
|
||||
assert self.aggregator.text == ""
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_simple_sentence(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
|
||||
assert self.aggregator.text == ""
|
||||
aggregate = await self.aggregator.aggregate("Pipecat!")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
assert aggregate.type == "sentence"
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_multiple_sentences(self):
|
||||
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
|
||||
assert await self.aggregator.aggregate("you?") == " How are you?"
|
||||
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
assert self.aggregator.text.text == " How are "
|
||||
aggregate = await self.aggregator.aggregate("you?")
|
||||
assert aggregate.text == " How are you?"
|
||||
|
||||
@@ -18,16 +18,18 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# No tags involved, aggregate at end of sentence.
|
||||
result = await self.aggregator.aggregate("Hello Pipecat!")
|
||||
self.assertEqual(result, "Hello Pipecat!")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(result.text, "Hello Pipecat!")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_basic_tags(self):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# Tags involved, avoid aggregation during tags.
|
||||
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(result, "My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_streaming_tags(self):
|
||||
await self.aggregator.reset()
|
||||
@@ -35,20 +37,22 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Tags involved, stream small chunk of texts.
|
||||
result = await self.aggregator.aggregate("My email is <sp")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "My email is <sp")
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <sp")
|
||||
|
||||
result = await self.aggregator.aggregate("ell>foo.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.")
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
|
||||
|
||||
result = await self.aggregator.aggregate("bar@pipecat.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.")
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
|
||||
|
||||
result = await self.aggregator.aggregate("ai</spe")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.ai</spe")
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
result = await self.aggregator.aggregate("ll>.")
|
||||
self.assertEqual(result, "My email is <spell>foo.bar@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
@@ -11,6 +11,7 @@ from datetime import datetime, timezone
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -130,11 +131,11 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(), # Wait for StartedSpeaking to process
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world!"),
|
||||
TTSTextFrame(text="How"),
|
||||
TTSTextFrame(text="are"),
|
||||
TTSTextFrame(text="you?"),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="How", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="are", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="you?", aggregated_by=AggregationType.WORD),
|
||||
SleepFrame(), # Wait for text frames to queue
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
@@ -195,9 +196,9 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text=""), # Empty text
|
||||
TTSTextFrame(text=" "), # Just whitespace
|
||||
TTSTextFrame(text="\n"), # Just newline
|
||||
TTSTextFrame(text="", aggregated_by=AggregationType.WORD), # Empty text
|
||||
TTSTextFrame(text=" ", aggregated_by=AggregationType.WORD), # Just whitespace
|
||||
TTSTextFrame(text="\n", aggregated_by=AggregationType.WORD), # Just newline
|
||||
BotStoppedSpeakingFrame(),
|
||||
# Pipeline ends here; run_test will automatically send EndFrame
|
||||
]
|
||||
@@ -235,14 +236,14 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world!"),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
|
||||
SleepFrame(),
|
||||
InterruptionFrame(), # User interrupts here
|
||||
SleepFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="New"),
|
||||
TTSTextFrame(text="response"),
|
||||
TTSTextFrame(text="New", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="response", aggregated_by=AggregationType.WORD),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
@@ -299,8 +300,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world"),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
|
||||
# Pipeline ends here; run_test will automatically send EndFrame
|
||||
]
|
||||
|
||||
@@ -338,8 +339,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="world"),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
|
||||
SleepFrame(), # Ensure messages are processed
|
||||
CancelFrame(),
|
||||
]
|
||||
@@ -401,8 +402,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Assistant"),
|
||||
TTSTextFrame(text="message"),
|
||||
TTSTextFrame(text="Assistant", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="message", aggregated_by=AggregationType.WORD),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
@@ -439,7 +440,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Test the specific pattern shared
|
||||
def make_tts_text_frame(text: str) -> TTSTextFrame:
|
||||
frame = TTSTextFrame(text=text)
|
||||
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.WORD)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
return frame
|
||||
|
||||
|
||||
Reference in New Issue
Block a user