Compare commits

...

23 Commits

Author SHA1 Message Date
mattie ruth backman
e8640d84ae test fix now that we send an aggregated text frame for non word-by-word tts services 2025-11-14 17:13:08 -05:00
mattie ruth backman
23e4e29999 CHANGELOG fixes 2025-11-14 13:57:49 -05:00
mattie ruth backman
713b488bb6 Final PR Feedback changes 2025-11-14 13:54:20 -05:00
mattie ruth backman
71b87fd420 add transformers to initialization args 2025-11-14 13:54:20 -05:00
mattie ruth backman
3f269f9834 Add backwards compatibility for add_pattern_pair 2025-11-14 13:54:20 -05:00
mattie ruth backman
4c698777f3 PR Feedback 2025-11-14 13:54:20 -05:00
mattie ruth backman
5ca04ad741 CHANGELOG updates 2025-11-14 13:54:20 -05:00
mattie ruth backman
9a3902a82c Introducing a new processor: LLMTextProcessor
This new processor wraps an aggregator that can be overridden for the purposes
of customizing how the llm output gets categorized and handled in the pipeline.

Along with this, we are deprecating the ability to override the default
aggregator in the TTS to encourage use of the LLMTextProcessor in cases where
custome aggregation is needed.

This PR also:
- Introduces TTSService.transform_aggregation_type():
  This function provides the ability to provide callbacks to the TTS to
  transform text based on its aggregated type prior to sending the text to the
  underlying TTS service. This makes it possible to do things like introduce
  TTS-specific tags for spelling or emotion or change the pronunciation of
  something on the fly.
- Introduces to the RTVIObserver:
  - new init field skip_aggregator_types: A way to provide a list of aggregation
    types that should not be included in bot-output (or tts-text) messages
  - transform_aggregation_type(): Same as with TTSService, this allows you
    to provide a callback to transform text being sent as bot-output before
    it gets sent.
2025-11-14 13:54:20 -05:00
mattie ruth backman
8ab0c92681 Rename AggregatedLLMTextFrame to AggregatedTextFrame and made built-in types an enum 2025-11-14 13:54:20 -05:00
mattie ruth backman
124f147a37 CHANGELOG improvements 2025-11-14 13:54:18 -05:00
mattie ruth backman
ed808a9246 Fix new test and str version of PatternMatch 2025-11-14 13:53:23 -05:00
mattie ruth backman
e9de9daf8c Update PatternPairAggregator patterns to replace pattern_id with type to simplify the API 2025-11-14 13:53:23 -05:00
mattie ruth backman
82b9c4f0b6 various PR Review fixes:
1. Added support for turning off bot-output messages with the bot_output_enabled flag
2. Cleaned up logic and comments around TTSService:_push_tts_frames to hopefully make
   it easier to understand
3. Other minor cleanup
2025-11-14 13:53:23 -05:00
mattie ruth backman
5dfe20be91 Update Changelog 2025-11-14 13:53:22 -05:00
mattie ruth backman
0d2c5286fa Support customization over the way the assistant aggregator aggregates LLMTextFrames when tts_skip is on 2025-11-14 13:51:45 -05:00
mattie ruth backman
29417ba44d Move aggregation logic when skip_tts is on to the assistant aggregator 2025-11-14 13:51:45 -05:00
mattie ruth backman
bc6a9cac26 Add append_to_context boolean field to TextFrames
This allows any given TextFrame to be marked in a way such that it does not get
added to the context.

Specifically, this fixes a problem with the new AggregatedTextFrames where we
need to send LLM text both in an aggregated form as well as word-by-word but
avoid duplicating the text in the context.
2025-11-14 13:51:45 -05:00
mattie ruth backman
8a90decbc0 codepilot review fixes 2025-11-14 13:51:45 -05:00
mattie ruth backman
ccca6e8d81 Make the PatternPair action an Enum 2025-11-14 13:51:45 -05:00
mattie ruth backman
e6dc1a510d Introduce AggregatedLLMTextFrame to allow a separation of TTSTextFrame, indicating a spoken frame vs other aggregated, non-spoken frames 2025-11-14 13:51:45 -05:00
mattie ruth backman
69945c5e0d Various fixes:
1. Fixed pattern_pair_aggregator to support various ways of handling
   pattern matches (remove, keep and just trigger a callback, or
   aggregate
2. Fixed ivr_navigator use of pattern_pair_aggregator
3. Test fixes -- Tests now pass
2025-11-14 13:51:45 -05:00
mattie ruth backman
5c8635570d test fixes 2025-11-14 13:51:45 -05:00
mattie ruth backman
fe9aa3383e Adding support for new bot-output RTVI Message:
1. TTSTextFrames now include metadata about whether the text was spoken
   or not along with a type string to describe what the text represents:
   ex. "sentence", "word", "custom aggregation"
2. Expanded how aggregators work so that the aggregate method returns
   aggregated text along with the type of aggregation used to create it
3. Deprecated the RTVI bot-transcription event in lieu of...
4. Introduced support for a new bot-output event. This event is meant
   to be the one stop shop for communicating what the bot actually "says".
   It is based off TTSTextFrames to communicate both sentence by sentence
   (or whatever aggregation is used) as well as word by word. In addition,
   it will include LLMTextFrames, aggregated by sentence when tts is
   turned off (i.e. skip_tts is true).

Resolves pipecat-ai/pipecat-client-web#158
2025-11-14 13:51:45 -05:00
25 changed files with 1018 additions and 195 deletions

View File

@@ -16,6 +16,82 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
services that subclass `TTSService` can indicate whether the text in the
`TTSTextFrame`s they push already contain any necessary inter-frame spaces.
- Introduced new `AggregatedTextFrame` type to support representing a best effort of
the perceived llm output whether or not it is processed by the TTS. This new frame
type includes the field `aggregated_by` to represent the conceptual format by which
the given text is aggregated. `TTSTextFrame`s now inherit from `AggregatedTextFrame`.
With this inheritance, an observer can watch for `AggregatedTextFrame`s to accumlate
the perceived output and determine whether or not the text was spoken based on if that
frame is also a `TTSTextFrame`. (See bullet below on new `bot-output` which takes
advantage of this)
- Introduced `LLMTextProcessor`: A new processor meant to allow customization for how
LLMTextFrames should be aggregated and considered. It's purpose is to turn
`LLMTextFrame`s into `AggregatedTextFrame`s. By default, a TTSService will still
aggregate `LLMTextFrame`s by sentence for the service to consume. However, if you
wish to override how the llm text is aggregated, you should no longer override the
TTS's internal aggregator, but instead, insert this processor between your LLM and
TTS in the pipeline.
- New `bot-output` RTVI message to represent what the bot actually "says".
- The `RTVIObserver` now emits `bot-output` messages based off the new `AggregatedTextFrame`s
(`bot-tts-text` and `bot-llm-text` are still supported and generated, but `bot-transcript` is
now deprecated in lieu of this new, more thorough, message).
- The new `RTVIBotOutputMessage` includes the fields:
- `spoken`: A boolean indicating whether the text was spoken by TTS
- `aggregated_by`: A string representing how the text was aggregated ("sentence", "word",
"my custom aggregation")
- Introduced new fields to `RTVIObserver` to support the new `bot-output` messaging:
- `bot_output_enabled`: Defaults to True. Set to false to disable bot-output messages.
- `skip_aggregator_types`: Defaults to `None`. Set to a list of strings that match
aggregation types that should not be included in bot-output messages. (Ex. `credit_card`)
- Introduced new methods, `add_text_transformer()` and `remove_text_transformer()`, to `RTVIObserver` to support providing (and subsequently removing)
callbacks for various types of aggregations (or all aggregations with `*`) that can modify the
text before being sent as a `bot-output` or `tts-text` message. (Think obscuring the credit card
or inserting extra detail the client might want that the context doesn't need.)
- Updated the base aggregator type:
- Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom
aggregation")
- **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
- **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
(instead of `Optional[str]`). To update:
```
aggregation = myAggregator.aggregate(text)
if (aggregation):
print(f"successfully aggregated text: {aggregation.text}") // instead of {aggregation}
```
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
produce/consume `Aggregation` objects.
- Augmented the `PatternPairAggregator`:
- Introduced a new, preferred version of `add_pattern` to support a new option for treating a
match as a separate aggregation returned from `aggregate()`. This replaces the now
deprecated `add_pattern_pair` method and you provide a `MatchAction` in lieu of the `remove_match` field.
- `MatchAction` enum: `REMOVE`, `KEEP`, `AGGREGATE`, allowing customization for how
a match should be handled.
- `REMOVE`: The text along with its delimiters will be removed from the streaming text.
Sentence aggregation will continue on as if this text did not exist.
- `KEEP`: The delimiters will be removed, but the content between them will be kept.
Sentence aggregation will continue on with the internal text included.
- `AGGREGATE`: The delimiters will be removed and the content between will be treated
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
will be returned. Then the aggregation will continue on sentence matching after
the closing delimiter is found. The content between the delimiters is not
aggregated by sentence. It is aggregated as one single block of text.
- `PatternMatch` now extends `Aggregation` and provides richer info to handlers.
- **BREAKING**: The `PatternMatch` type returned to handlers registered via `on_pattern_match`
has been updated to subclass from the new `Aggregation` type, which means that `content`
has been replaced with `text` and `pattern_id` has been replaced with `type`:
```
async dev on_match_tag(match: PatternMatch):
pattern = match.type # instead of match.pattern_id
text = match.text # instead of match.content
```
### Changed
- Updated all STT and TTS services to use consistent error handling pattern with
@@ -33,11 +109,42 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated language mappings for the Google and Gemini TTS services to match
official documentation.
- `TextFrame` new field `append_to_context` used to indicate if the encompassing
text should be added to the LLM context (by the LLM assistant aggregator). It
defaults to `True`.
- TTS flow respects aggregation metadata
- `TTSService` accepts a new `skip_aggregator_types` to avoid speaking certain aggregation types
(now determined/returned by the aggregator)
- TTS services push `AggregatedTextFrame` in addition to `TTSTextFrame`s when either an
aggregation occurs that should not be spoken or when the TTS service supports word-by-word
timestamping. In the latter case, the `TTSService` preliminarily generates an
`AggregatedTextFrame`, aggregated by sentence to generate the full sentence content as early
as possible.
- Introduced a new methods, `add_text_transformer()` and `remove_text_transformer()`:
These functions introduce the ability to provide (and subsequently remove) callbacks to the TTS to transform text based on
its aggregated type prior to sending the text to the underlying TTS service. This makes it
possible to do things like introduce TTS-specific tags for spelling or emotion or change the
pronunciation of something on the fly.
### Deprecated
- The `api_key` parameter in `GeminiTTSService` is deprecated. Use
`credentials` or `credentials_path` instead for Google Cloud authentication.
- The RTVI `bot-transcription` event is deprecated in favor of the new `bot-output`
message which is the canonical representation of bot output (spoken or not). The code
still emits a transcription message for backwards compatibility while transition occurs.
- The TTS constructor field, `text_aggregator` is deprecated in favor of the new
`LLMTextProcessor`. TTSServices still have an internal aggregator for support of default
behavior, but if you want to override the aggregation behavior, you should use the new
processor.
- Deprecated `add_pattern_pair` in the `PatternPairAggregator` which takes a `pattern_id`
and `remove_match` field in favor of the new `add_pattern` method which takes a `type` and an
`action`
### Fixed
- Fixed subtle issue of assistant context messages ending up with double spaces

View File

@@ -62,7 +62,11 @@ from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
load_dotenv(override=True)
@@ -106,16 +110,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
pattern_aggregator = PatternPairAggregator()
# Add pattern for voice switching
pattern_aggregator.add_pattern_pair(
pattern_id="voice_tag",
pattern_aggregator.add_pattern(
type="voice",
start_pattern="<voice>",
end_pattern="</voice>",
remove_match=True,
action=MatchAction.REMOVE, # Remove tags from final text
)
# Register handler for voice switching
async def on_voice_tag(match: PatternMatch):
voice_name = match.content.strip().lower()
voice_name = match.text.strip().lower()
if voice_name in VOICE_IDS:
# First flush any existing audio to finish the current context
await tts.flush_audio()
@@ -125,7 +129,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
else:
logger.warning(f"Unknown voice: {voice_name}")
pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
pattern_aggregator.on_pattern_match("voice", on_voice_tag)
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

View File

@@ -31,7 +31,11 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import LLMService
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
class IVRStatus(Enum):
@@ -114,15 +118,15 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", remove_match=True)
self._aggregator.add_pattern("dtmf", "<dtmf>", "</dtmf>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", remove_match=True)
self._aggregator.add_pattern("mode", "<mode>", "</mode>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", remove_match=True)
self._aggregator.add_pattern("ivr", "<ivr>", "</ivr>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -159,7 +163,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing DTMF content.
"""
value = match.content
value = match.text
logger.debug(f"DTMF detected: {value}")
try:
@@ -180,7 +184,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing IVR status content.
"""
status = match.content
status = match.text
logger.trace(f"IVR status detected: {status}")
# Convert string to enum, with validation
@@ -211,7 +215,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing mode content.
"""
mode = match.content
mode = match.text
logger.debug(f"Mode detected: {mode}")
if mode == "conversation":
await self._handle_conversation()

View File

@@ -12,6 +12,7 @@ and LLM processing.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
@@ -337,11 +338,14 @@ class TextFrame(DataFrame):
# mandatory fields of theirs to have defaults to preserve
# non-default-before-default argument order)
includes_inter_frame_spaces: bool = field(init=False)
# Whether this text frame should be appended to the LLM context.
append_to_context: bool = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.includes_inter_frame_spaces = False
self.append_to_context = True
def __str__(self):
pts = format_pts(self.pts)
@@ -355,8 +359,32 @@ class LLMTextFrame(TextFrame):
pass
class AggregationType(str, Enum):
"""Built-in aggregation strings."""
SENTENCE = "sentence"
WORD = "word"
def __str__(self):
return self.value
@dataclass
class TTSTextFrame(TextFrame):
class AggregatedTextFrame(TextFrame):
"""Text frame representing an aggregation of TextFrames.
This frame contains multiple TextFrames aggregated together for processing
or output along with a field to indicate how they are aggregated.
Parameters:
aggregated_by: Method used to aggregate the text frames.
"""
aggregated_by: AggregationType | str
@dataclass
class TTSTextFrame(AggregatedTextFrame):
"""Text frame generated by Text-to-Speech services."""
pass

View File

@@ -1001,7 +1001,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started:
if not self._started or not frame.append_to_context:
return
if self._params.expect_stripped_words:

View File

@@ -814,7 +814,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started:
if not self._started or not frame.append_to_context:
return
# Make sure we really have text (spaces count, too!)

View File

@@ -0,0 +1,106 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""LLM text processor module for processing and aggregating raw LLM output text.
This processor will convert LLMTextFrames into AggregatedTextFrames based on the
configured text aggregator. Using the customizable aggregator, it provides
functionality to handle or manipulate LLM text frames before they are sent to other
components such as TTS services or context aggregators. It can be used to pre-aggregate
and categorize, modify, or filter direct output tokens from the LLM.
"""
from typing import Optional
from pipecat.frames.frames import (
AggregatedTextFrame,
EndFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMTextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
class LLMTextProcessor(FrameProcessor):
"""A processor for handling or manipulating LLM text frames before they are processed further.
This processor will convert LLMTextFrames into AggregatedTextFrames based on the configured
text aggregator. Using the customizable aggregator, it provides functionality to handle or
manipulate LLM text frames before they are sent to other components such as TTS services or
context aggregators. It can be used to pre-aggregate and categorize, modify, or filter direct
output tokens from the LLM.
"""
def __init__(self, *, text_aggregator: Optional[BaseTextAggregator] = None, **kwargs):
"""Initialize the LLM text processor.
Args:
text_aggregator: An optional text aggregator to use for processing LLM text frames. By
default, a SimpleTextAggregator aggregating by sentence will be used.
**kwargs: Additional arguments passed to parent class.
TODO: Allow transformations per aggregation type or all (and deprecate the TTS filters).
"""
super().__init__(**kwargs)
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process an LLMTextFrames using the aggregator to generate AggregatedTextFrames.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
await self._handle_interruption(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, LLMTextFrame):
await self._handle_llm_text(frame)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_llm_end(frame.skip_tts)
await self.push_frame(frame, direction)
elif isinstance(frame, EndFrame):
await self._handle_llm_end()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
async def _handle_interruption(self, _):
"""Handle interruptions by resetting the text aggregator."""
await self._text_aggregator.handle_interruption()
async def reset(self):
"""Reset the internal state of the text processor and its aggregator."""
await self._text_aggregator.reset()
async def _handle_llm_text(self, in_frame: LLMTextFrame):
aggregation = await self._text_aggregator.aggregate(in_frame.text)
if aggregation:
out_frame = AggregatedTextFrame(
text=aggregation.text,
aggregated_by=aggregation.type,
)
out_frame.skip_tts = in_frame.skip_tts
await self.push_frame(out_frame)
async def _handle_llm_end(self, skip_tts: bool = False):
# Flush any remaining aggregated text at the end of the LLM response
aggregation = self._text_aggregator.text
await self._text_aggregator.reset()
text = aggregation.text.strip()
if text:
out_frame = AggregatedTextFrame(
text=text,
aggregated_by=aggregation.type,
)
out_frame.skip_tts = skip_tts
await self.push_frame(out_frame)

View File

@@ -24,6 +24,7 @@ from typing import (
Literal,
Mapping,
Optional,
Tuple,
Union,
)
@@ -32,6 +33,8 @@ from pydantic import BaseModel, Field, PrivateAttr, ValidationError
from pipecat.audio.utils import calculate_audio_volume
from pipecat.frames.frames import (
AggregatedTextFrame,
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -704,6 +707,29 @@ class RTVITextMessageData(BaseModel):
text: str
class RTVIBotOutputMessageData(RTVITextMessageData):
"""Data for bot output RTVI messages.
Extends RTVITextMessageData to include metadata about the output.
"""
spoken: bool = False # Indicates if the text has been spoken by TTS
aggregated_by: AggregationType | str
# Indicates what form the text is in (e.g., by word, sentence, etc.)
class RTVIBotOutputMessage(BaseModel):
"""Message containing bot output text.
An event meant to holistically represent what the bot is outputting,
along with metadata about the output and if it has been spoken.
"""
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
type: Literal["bot-output"] = "bot-output"
data: RTVIBotOutputMessageData
class RTVIBotTranscriptionMessage(BaseModel):
"""Message containing bot transcription text.
@@ -896,6 +922,7 @@ class RTVIObserverParams:
Parameter `errors_enabled` is deprecated. Error messages are always enabled.
Parameters:
bot_output_enabled: Indicates if bot output messages should be sent.
bot_llm_enabled: Indicates if the bot's LLM messages should be sent.
bot_tts_enabled: Indicates if the bot's TTS messages should be sent.
bot_speaking_enabled: Indicates if the bot's started/stopped speaking messages should be sent.
@@ -907,9 +934,17 @@ class RTVIObserverParams:
metrics_enabled: Indicates if metrics messages should be sent.
system_logs_enabled: Indicates if system logs should be sent.
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
Note: if using this to avoid sending secure information, be sure to also disable
bot_llm_enabled to avoid leaking through LLM messages.
bot_output_transforms: A list of callables to transform text before just before sending it
to TTS. Each callable takes the aggregated text and its type, and returns the
transformed text. To register, provide a list of tuples of
(aggregation_type | '*', transform_function).
audio_level_period_secs: How often audio levels should be sent if enabled.
"""
bot_output_enabled: bool = True
bot_llm_enabled: bool = True
bot_tts_enabled: bool = True
bot_speaking_enabled: bool = True
@@ -921,6 +956,15 @@ class RTVIObserverParams:
metrics_enabled: bool = True
system_logs_enabled: bool = False
errors_enabled: Optional[bool] = None
skip_aggregator_types: Optional[List[AggregationType | str]] = None
bot_output_transforms: Optional[
List[
Tuple[
AggregationType | str,
Callable[[str, AggregationType | str], Awaitable[str]],
]
]
] = None
audio_level_period_secs: float = 0.15
@@ -973,8 +1017,45 @@ class RTVIObserver(BaseObserver):
DeprecationWarning,
)
self._aggregation_transforms: List[
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
] = self._params.bot_output_transforms or []
def add_bot_output_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Transform text for a specific aggregation type before sending as Bot Output or TTS.
Args:
transform_function: The function to apply for transformation. This function should take
the text and aggregation type as input and return the transformed text.
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
aggregation_type: The type of aggregation to transform. This value defaults to "*" to
handle all text before sending to the client.
"""
self._aggregation_transforms.append((aggregation_type, transform_function))
def remove_bot_output_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Remove a text transformer for a specific aggregation type.
Args:
transform_function: The function to remove.
aggregation_type: The type of aggregation to remove the transformer for.
"""
self._aggregation_transforms = [
(agg_type, func)
for agg_type, func in self._aggregation_transforms
if not (agg_type == aggregation_type and func == transform_function)
]
async def _logger_sink(self, message):
"""Logger sink so we cna send system logs to RTVI clients."""
"""Logger sink so we can send system logs to RTVI clients."""
message = RTVISystemLogMessage(data=RTVITextMessageData(text=message))
await self.send_rtvi_message(message)
@@ -1048,12 +1129,15 @@ class RTVIObserver(BaseObserver):
await self.send_rtvi_message(RTVIBotTTSStartedMessage())
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
await self.send_rtvi_message(RTVIBotTTSStoppedMessage())
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
if isinstance(src, BaseOutputTransport):
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
await self.send_rtvi_message(message)
else:
elif isinstance(frame, AggregatedTextFrame) and (
self._params.bot_output_enabled or self._params.bot_tts_enabled
):
if isinstance(frame, TTSTextFrame) and not isinstance(src, BaseOutputTransport):
# This check is to make sure we handle the frame when it has gone
# through the transport and has correct timing.
mark_as_seen = False
else:
await self._handle_aggregated_llm_text(frame)
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
await self._handle_metrics(frame)
elif isinstance(frame, RTVIServerMessageFrame):
@@ -1084,15 +1168,6 @@ class RTVIObserver(BaseObserver):
if mark_as_seen:
self._frames_seen.add(frame.id)
async def _push_bot_transcription(self):
"""Push accumulated bot transcription as a message."""
if len(self._bot_transcription) > 0:
message = RTVIBotTranscriptionMessage(
data=RTVITextMessageData(text=self._bot_transcription)
)
await self.send_rtvi_message(message)
self._bot_transcription = ""
async def _handle_interruptions(self, frame: Frame):
"""Handle user speaking interruption frames."""
message = None
@@ -1115,14 +1190,45 @@ class RTVIObserver(BaseObserver):
if message:
await self.send_rtvi_message(message)
async def _handle_aggregated_llm_text(self, frame: AggregatedTextFrame):
"""Handle aggregated LLM text output frames."""
# Skip certain aggregator types if configured to do so.
if (
self._params.skip_aggregator_types
and frame.aggregated_by in self._params.skip_aggregator_types
):
return
text = frame.text
type = frame.aggregated_by
for aggregation_type, transform in self._aggregation_transforms:
if aggregation_type == type or aggregation_type == "*":
text = await transform(text, type)
isTTS = isinstance(frame, TTSTextFrame)
if self._params.bot_output_enabled:
message = RTVIBotOutputMessage(
data=RTVIBotOutputMessageData(text=text, spoken=isTTS, aggregated_by=type)
)
await self.send_rtvi_message(message)
if isTTS and self._params.bot_tts_enabled:
tts_message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=text))
await self.send_rtvi_message(tts_message)
async def _handle_llm_text_frame(self, frame: LLMTextFrame):
"""Handle LLM text output frames."""
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
await self.send_rtvi_message(message)
# TODO (mrkb): Remove all this logic when we fully deprecate bot-transcription messages.
self._bot_transcription += frame.text
if match_endofsentence(self._bot_transcription):
await self._push_bot_transcription()
if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0:
await self.send_rtvi_message(
RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription))
)
self._bot_transcription = ""
async def _handle_user_transcriptions(self, frame: Frame):
"""Handle user transcription frames."""
@@ -1248,7 +1354,7 @@ class RTVIProcessor(FrameProcessor):
# Default to 0.3.0 which is the last version before actually having a
# "client-version".
self._client_version = [0, 3, 0]
self._skip_tts: bool = False # Keep in sync with llm_service.py
self._llm_skip_tts: bool = False # Keep in sync with llm_service.py's configuration.
self._registered_actions: Dict[str, RTVIAction] = {}
self._registered_services: Dict[str, RTVIService] = {}
@@ -1441,7 +1547,7 @@ class RTVIProcessor(FrameProcessor):
elif isinstance(frame, RTVIActionFrame):
await self._action_queue.put(frame)
elif isinstance(frame, LLMConfigureOutputFrame):
self._skip_tts = frame.skip_tts
self._llm_skip_tts = frame.skip_tts
await self.push_frame(frame, direction)
# Other frames
else:
@@ -1697,9 +1803,9 @@ class RTVIProcessor(FrameProcessor):
opts = data.options if data.options is not None else RTVISendTextOptions()
if opts.run_immediately:
await self.interrupt_bot()
cur_skip_tts = self._skip_tts
cur_llm_skip_tts = self._llm_skip_tts
should_skip_tts = not opts.audio_response
toggle_skip_tts = cur_skip_tts != should_skip_tts
toggle_skip_tts = cur_llm_skip_tts != should_skip_tts
if toggle_skip_tts:
output_frame = LLMConfigureOutputFrame(skip_tts=should_skip_tts)
await self.push_frame(output_frame)
@@ -1709,7 +1815,7 @@ class RTVIProcessor(FrameProcessor):
)
await self.push_frame(text_frame)
if toggle_skip_tts:
output_frame = LLMConfigureOutputFrame(skip_tts=cur_skip_tts)
output_frame = LLMConfigureOutputFrame(skip_tts=cur_llm_skip_tts)
await self.push_frame(output_frame)
async def _handle_update_context(self, data: RTVIAppendToContextData):

View File

@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter, Role
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -1027,7 +1028,7 @@ class AWSNovaSonicLLMService(LLMService):
logger.debug(f"Assistant response text added: {text}")
# Report the text of the assistant response.
frame = TTSTextFrame(text)
frame = TTSTextFrame(text, aggregated_by=AggregationType.SENTENCE)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
@@ -1062,7 +1063,9 @@ class AWSNovaSonicLLMService(LLMService):
# TTSTextFrame would be ignored otherwise (the interruption frame
# would have cleared the assistant aggregator state).
await self.push_frame(LLMFullResponseStartFrame())
frame = TTSTextFrame(self._assistant_text_buffer)
frame = TTSTextFrame(
self._assistant_text_buffer, aggregated_by=AggregationType.SENTENCE
)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
self._may_need_repush_assistant_text = False

View File

@@ -10,7 +10,8 @@ import base64
import json
import uuid
import warnings
from typing import AsyncGenerator, List, Literal, Optional, Union
from enum import Enum
from typing import AsyncGenerator, List, Literal, Optional
from loguru import logger
from pydantic import BaseModel, Field
@@ -125,6 +126,72 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
class CartesiaEmotion(str, Enum):
"""Predefined Emotions supported by Cartesia."""
# Primary emotions supported by Cartesia
NEUTRAL = "neutral"
ANGRY = "angry"
EXCITED = "excited"
CONTENT = "content"
SAD = "sad"
SCARED = "scared"
# Additional emotions supported by Cartesia
HAPPY = "happy"
ENTHUSIASTIC = "enthusiastic"
ELATED = "elated"
EUPHORIC = "euphoric"
TRIUMPHANT = "triumphant"
AMAZED = "amazed"
SURPRISED = "surprised"
FLIRTATIOUS = "flirtatious"
JOKING_COMEDIC = "joking/comedic"
CURIOUS = "curious"
PEACEFUL = "peaceful"
SERENE = "serene"
CALM = "calm"
GRATEFUL = "grateful"
AFFECTIONATE = "affectionate"
TRUST = "trust"
SYMPATHETIC = "sympathetic"
ANTICIPATION = "anticipation"
MYSTERIOUS = "mysterious"
MAD = "mad"
OUTRAGED = "outraged"
FRUSTRATED = "frustrated"
AGITATED = "agitated"
THREATENED = "threatened"
DISGUSTED = "disgusted"
CONTEMPT = "contempt"
ENVIOUS = "envious"
SARCASTIC = "sarcastic"
IRONIC = "ironic"
DEJECTED = "dejected"
MELANCHOLIC = "melancholic"
DISAPPOINTED = "disappointed"
HURT = "hurt"
GUILTY = "guilty"
BORED = "bored"
TIRED = "tired"
REJECTED = "rejected"
NOSTALGIC = "nostalgic"
WISTFUL = "wistful"
APOLOGETIC = "apologetic"
HESITANT = "hesitant"
INSECURE = "insecure"
CONFUSED = "confused"
RESIGNED = "resigned"
ANXIOUS = "anxious"
PANICKED = "panicked"
ALARMED = "alarmed"
PROUD = "proud"
CONFIDENT = "confident"
DISTANT = "distant"
SKEPTICAL = "skeptical"
CONTEMPLATIVE = "contemplative"
DETERMINED = "determined"
class CartesiaTTSService(AudioContextWordTTSService):
"""Cartesia TTS service with WebSocket streaming and word timestamps.
@@ -182,6 +249,10 @@ class CartesiaTTSService(AudioContextWordTTSService):
container: Audio container format.
params: Additional input parameters for voice customization.
text_aggregator: Custom text aggregator for processing input text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
aggregate_sentences: Whether to aggregate sentences within the TTSService.
**kwargs: Additional arguments passed to the parent service.
"""
@@ -200,10 +271,18 @@ class CartesiaTTSService(AudioContextWordTTSService):
push_text_frames=False,
pause_frame_processing=True,
sample_rate=sample_rate,
text_aggregator=text_aggregator or SkipTagsAggregator([("<spell>", "</spell>")]),
text_aggregator=text_aggregator,
**kwargs,
)
if not text_aggregator:
# Always skip tags added for spelled-out text
# Note: This is primarily to support backwards compatibility.
# The preferred way of taking advantage of Cartesia SSML Tags is
# to use an LLMTextProcessor and/or a text_transformer to identify
# and insert these tags for the purpose of the TTS service alone.
self._text_aggregator = SkipTagsAggregator([("<spell>", "</spell>")])
params = params or CartesiaTTSService.InputParams()
self._api_key = api_key
@@ -257,6 +336,27 @@ class CartesiaTTSService(AudioContextWordTTSService):
"""
return language_to_cartesia_language(language)
# A set of Cartesia-specific helpers for text transformations
def SPELL(text: str) -> str:
"""Wrap text in Cartesia spell tag."""
return f"<spell>{text}</spell>"
def EMOTION_TAG(emotion: CartesiaEmotion) -> str:
"""Convenience method to create an emotion tag."""
return f'<emotion value="{emotion}" />'
def PAUSE_TAG(seconds: float) -> str:
"""Convenience method to create a pause tag."""
return f'<break time="{seconds}s" />'
def VOLUME_TAG(volume: float) -> str:
"""Convenience method to create a volume tag."""
return f'<volume ratio="{volume}" />'
def SPEED_TAG(speed: float) -> str:
"""Convenience method to create a speed tag."""
return f'<speed ratio="{speed}" />'
def _is_cjk_language(self, language: str) -> bool:
"""Check if the given language is CJK (Chinese, Japanese, Korean).

View File

@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -1646,7 +1647,7 @@ class GeminiLiveLLMService(LLMService):
await self.push_frame(TTSStartedFrame())
await self.push_frame(LLMFullResponseStartFrame())
frame = TTSTextFrame(text=text)
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.SENTENCE)
# Gemini Live text already includes any necessary inter-chunk spaces
frame.includes_inter_frame_spaces = True

View File

@@ -19,6 +19,7 @@ from pipecat.adapters.services.open_ai_realtime_adapter import (
OpenAIRealtimeLLMAdapter,
)
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -686,7 +687,7 @@ class OpenAIRealtimeLLMService(LLMService):
# We receive audio transcript deltas (as opposed to text deltas) when
# the output modality is "audio" (the default)
if evt.delta:
frame = TTSTextFrame(evt.delta)
frame = TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE)
# OpenAI Realtime text already includes any necessary inter-chunk spaces
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)

View File

@@ -17,6 +17,7 @@ from loguru import logger
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -652,7 +653,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _handle_evt_audio_transcript_delta(self, evt):
if evt.delta:
await self.push_frame(LLMTextFrame(evt.delta))
await self.push_frame(TTSTextFrame(evt.delta))
await self.push_frame(TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE))
async def _handle_evt_speech_started(self, evt):
await self._truncate_current_audio_response()

View File

@@ -113,6 +113,10 @@ class RimeTTSService(AudioContextWordTTSService):
sample_rate: Audio sample rate in Hz.
params: Additional configuration parameters.
text_aggregator: Custom text aggregator for processing input text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
aggregate_sentences: Whether to aggregate sentences within the TTSService.
**kwargs: Additional arguments passed to parent class.
"""
@@ -123,10 +127,17 @@ class RimeTTSService(AudioContextWordTTSService):
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=sample_rate,
text_aggregator=text_aggregator or SkipTagsAggregator([("spell(", ")")]),
**kwargs,
)
if not text_aggregator:
# Always skip tags added for spelled-out text
# Note: This is primarily to support backwards compatibility.
# The preferred way of taking advantage of Rime spelling is
# to use an LLMTextProcessor and/or a text_transformer to identify
# and insert these tags for the purpose of the TTS service alone.
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
params = params or RimeTTSService.InputParams()
# Store service configuration
@@ -152,6 +163,7 @@ class RimeTTSService(AudioContextWordTTSService):
self._context_id = None # Tracks current turn
self._receive_task = None
self._cumulative_time = 0 # Accumulates time across messages
self._extra_msg_fields = {} # Extra fields for next message
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -181,6 +193,31 @@ class RimeTTSService(AudioContextWordTTSService):
self._model = model
await super().set_model(model)
# A set of Rime-specific helpers for text transformations
def SPELL(text: str) -> str:
"""Wrap text in Rime spell function."""
return f"spell({text})"
def PAUSE_TAG(seconds: float) -> str:
"""Convenience method to create a pause tag."""
return f"<{seconds * 1000}>"
def PRONOUNCE(self, text: str, word: str, phoneme: str) -> str:
"""Convenience method to support Rime's custom pronunciations feature.
https://docs.rime.ai/api-reference/custom-pronunciation
"""
self._extra_msg_fields["phonemizeBetweenBrackets"] = True
return text.replace(word, f"{phoneme}")
def INLINE_SPEED(self, text: str, speed: float) -> str:
"""Convenience method to support inline speeds."""
if not self._extra_msg_fields:
self._extra_msg_fields = {}
speed_vals = self._extra_msg_fields.get("inlineSpeedAlpha", "").split(",")
self._extra_msg_fields["inlineSpeedAlpha"] = ",".join(speed_vals + [str(speed)])
return f"[{text}]"
async def _update_settings(self, settings: Mapping[str, Any]):
"""Update service settings and reconnect if voice changed."""
prev_voice = self._voice_id
@@ -193,7 +230,11 @@ class RimeTTSService(AudioContextWordTTSService):
def _build_msg(self, text: str = "") -> dict:
"""Build JSON message for Rime API."""
return {"text": text, "contextId": self._context_id}
msg = {"text": text, "contextId": self._context_id}
if self._extra_msg_fields:
msg |= self._extra_msg_fields
self._extra_msg_fields = {}
return msg
def _build_clear_msg(self) -> dict:
"""Build clear operation message."""

View File

@@ -12,6 +12,8 @@ from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Mapping,
@@ -23,6 +25,8 @@ from typing import (
from loguru import logger
from pipecat.frames.frames import (
AggregatedTextFrame,
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -101,6 +105,16 @@ class TTSService(AIService):
sample_rate: Optional[int] = None,
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
text_aggregator: Optional[BaseTextAggregator] = None,
# Types of text aggregations that should not be spoken.
skip_aggregator_types: Optional[List[str]] = [],
# A list of callables to transform text before just before sending it to TTS.
# Each callable takes the aggregated text and its type, and returns the transformed text.
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
text_transforms: Optional[
List[
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
]
] = None,
# Text filter executed after text has been aggregated.
text_filters: Optional[Sequence[BaseTextFilter]] = None,
text_filter: Optional[BaseTextFilter] = None,
@@ -120,6 +134,16 @@ class TTSService(AIService):
pause_frame_processing: Whether to pause frame processing during audio generation.
sample_rate: Output sample rate for generated audio.
text_aggregator: Custom text aggregator for processing incoming text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
skip_aggregator_types: List of aggregation types that should not be spoken.
text_transforms: A list of callables to transform text before just before sending it
to TTS. Each callable takes the aggregated text and its type, and returns the
transformed text. To register, provide a list of tuples of
(aggregation_type | '*', transform_function).
text_filters: Sequence of text filters to apply after aggregation.
text_filter: Single text filter (deprecated, use text_filters).
@@ -142,6 +166,21 @@ class TTSService(AIService):
self._voice_id: str = ""
self._settings: Dict[str, Any] = {}
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
if text_aggregator:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter 'text_aggregator' is deprecated. Use an LLMTextProcessor before the TTSService for custom text aggregation.",
DeprecationWarning,
)
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
self._text_transforms: List[
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
] = text_transforms or []
# TODO: Deprecate _text_filters when added to LLMTextProcessor
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
self._transport_destination: Optional[str] = transport_destination
self._tracing_enabled: bool = False
@@ -298,6 +337,39 @@ class TTSService(AIService):
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
def add_text_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Transform text for a specific aggregation type.
Args:
transform_function: The function to apply for transformation. This function should take
the text and aggregation type as input and return the transformed text.
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
aggregation_type: The type of aggregation to transform. This value defaults to "*" indicating
the function should handle all text before sending to TTS.
"""
self._text_transforms.append((aggregation_type, transform_function))
def remove_text_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Remove a text transformer for a specific aggregation type.
Args:
transform_function: The function to remove.
aggregation_type: The type of aggregation to remove the transformer for.
"""
self._text_transforms = [
(agg_type, func)
for agg_type, func in self._text_transforms
if not (agg_type == aggregation_type and func == transform_function)
]
async def _update_settings(self, settings: Mapping[str, Any]):
for key, value in settings.items():
if key in self._settings:
@@ -353,6 +425,8 @@ class TTSService(AIService):
and frame.skip_tts
):
await self.push_frame(frame, direction)
elif isinstance(frame, AggregatedTextFrame):
await self._push_tts_frames(frame)
elif (
isinstance(frame, TextFrame)
and not isinstance(frame, InterimTranscriptionFrame)
@@ -368,10 +442,10 @@ class TTSService(AIService):
# pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
sentence = self._text_aggregator.text
aggregate = self._text_aggregator.text
await self._text_aggregator.reset()
self._processing_text = False
await self._push_tts_frames(sentence)
await self._push_tts_frames(AggregatedTextFrame(aggregate.text, aggregate.type))
if isinstance(frame, LLMFullResponseEndFrame):
if self._push_text_frames:
await self.push_frame(frame, direction)
@@ -380,7 +454,7 @@ class TTSService(AIService):
elif isinstance(frame, TTSSpeakFrame):
# Store if we were processing text or not so we can set it back.
processing_text = self._processing_text
await self._push_tts_frames(frame.text)
await self._push_tts_frames(AggregatedTextFrame(frame.text, AggregationType.SENTENCE))
# We pause processing incoming frames because we are sending data to
# the TTS. We pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
@@ -472,13 +546,24 @@ class TTSService(AIService):
text: Optional[str] = None
if not self._aggregate_sentences:
text = frame.text
aggregated_by = "token"
else:
text = await self._text_aggregator.aggregate(frame.text)
aggregate = await self._text_aggregator.aggregate(frame.text)
if aggregate:
text = aggregate.text
aggregated_by = aggregate.type
if text:
await self._push_tts_frames(text)
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(AggregatedTextFrame(text, aggregated_by))
async def _push_tts_frames(self, src_frame: AggregatedTextFrame):
type = src_frame.aggregated_by
text = src_frame.text
if type in self._skip_aggregator_types:
await self.push_frame(src_frame)
return
async def _push_tts_frames(self, text: str):
# Remove leading newlines only
text = text.lstrip("\n")
@@ -499,15 +584,39 @@ class TTSService(AIService):
await filter.reset_interruption()
text = await filter.filter(text)
if text:
await self.process_generator(self.run_tts(text))
if not text.strip():
await self.stop_processing_metrics()
return
# To support use cases that may want to know the text before it's spoken, we
# push the AggregatedTextFrame version before transforming and sending to TTS.
# However, we do not want to add this text to the assistant context until it
# is spoken, so we set append_to_context to False.
src_frame.append_to_context = False
await self.push_frame(src_frame)
# Note: Text transformations are meant to only affect the text sent to the TTS for
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
# services that support word-level timestamps, this CAN affect the resulting context
# since the TTSTextFrames are generated from the TTS output stream
transformed_text = text
for aggregation_type, transform in self._text_transforms:
if aggregation_type == type or aggregation_type == "*":
transformed_text = await transform(transformed_text, type)
await self.process_generator(self.run_tts(transformed_text))
await self.stop_processing_metrics()
if self._push_text_frames:
# We send the original text after the audio. This way, if we are
# interrupted, the text is not added to the assistant context.
frame = TTSTextFrame(text)
# In TTS services that support word timestamps, the TTSTextFrames
# are pushed as words are spoken. However, in the case where the TTS service
# does not support word timestamps (i.e. _push_text_frames is True), we send
# the original (non-transformed) text after the TTS generation has completed.
# This way, if we are interrupted, the text is not added to the assistant
# context and the context that IS added does not include TTS-specific tags
# or transformations.
frame = TTSTextFrame(text, aggregated_by=type)
frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces
await self.push_frame(frame)
@@ -635,7 +744,7 @@ class WordTTSService(TTSService):
frame = TTSStoppedFrame()
frame.pts = last_pts
else:
frame = TTSTextFrame(word)
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
frame.pts = self._initial_word_timestamp + timestamp
if frame:
last_pts = frame.pts

View File

@@ -203,8 +203,16 @@ async def run_test(
if not isinstance(frame, EndFrame) or not send_end_frame:
received_down_frames.append(frame)
print("received DOWN frames =", received_down_frames)
print("expected DOWN frames =", expected_down_frames)
down_frames_printed = "["
for frame in received_down_frames:
down_frames_printed += f"{frame.__class__.__name__}, "
down_frames_printed += "]"
expected_frames_printed = "["
for frame in expected_down_frames:
expected_frames_printed += f"{frame.__name__}, "
expected_frames_printed += "]"
print("received DOWN frames =", down_frames_printed)
print("expected DOWN frames =", expected_frames_printed)
assert len(received_down_frames) == len(expected_down_frames)

View File

@@ -12,9 +12,46 @@ aggregated text should be sent for speech synthesis.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class AggregationType(str, Enum):
"""Built-in aggregation strings."""
SENTENCE = "sentence"
WORD = "word"
def __str__(self):
return self.value
@dataclass
class Aggregation:
"""Data class representing aggregated text and its type.
An Aggregation object is created whenever a stream of text is aggregated by
a text aggregator. It contains the aggregated text and a type indicating
the nature of the aggregation.
Parameters:
text: The aggregated text content.
type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
"""
text: str
type: str
def __str__(self) -> str:
"""Return a string representation of the aggregation.
Returns:
A descriptive string showing the type and text of the aggregation.
"""
return f"Aggregation by {self.type}: {self.text}"
class BaseTextAggregator(ABC):
"""Base class for text aggregators in the Pipecat framework.
@@ -30,7 +67,7 @@ class BaseTextAggregator(ABC):
@property
@abstractmethod
def text(self) -> str:
def text(self) -> Aggregation:
"""Get the currently aggregated text.
Subclasses must implement this property to return the text that has
@@ -42,12 +79,13 @@ class BaseTextAggregator(ABC):
pass
@abstractmethod
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate the specified text with the currently accumulated text.
This method should be implemented to define how the new text contributes
to the aggregation process. It returns the updated aggregated text if
it's ready to be processed, or None otherwise.
to the aggregation process. It returns the aggregated text and a string
describing how it was aggregated if it's ready to be processed,
or None otherwise.
Subclasses should implement their specific logic for:

View File

@@ -8,19 +8,41 @@
This module provides an aggregator that identifies and processes content between
pattern pairs (like XML tags or custom delimiters) in streaming text, with
support for custom handlers and configurable pattern removal.
support for custom handlers and configurable actions for when a pattern is found.
"""
import re
from typing import Awaitable, Callable, Optional, Tuple
from enum import Enum
from typing import Awaitable, Callable, List, Optional, Tuple
from loguru import logger
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class PatternMatch:
class MatchAction(Enum):
"""Actions to take when a pattern pair is matched.
Parameters:
REMOVE: The text along with its delimiters will be removed from the streaming text.
Sentence aggregation will continue on as if this text did not exist.
KEEP: The delimiters will be removed, but the content between them will be kept.
Sentence aggregation will continue on with the internal text included.
AGGREGATE: The delimiters will be removed and the content between will be treated
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
will be returned. Then the aggregation will continue on sentence matching after
the closing delimiter is found. The content between the delimiters is not
aggregated by sentence. It is aggregated as one single block of text.
"""
REMOVE = "remove"
KEEP = "keep"
AGGREGATE = "aggregate"
class PatternMatch(Aggregation):
"""Represents a matched pattern pair with its content.
A PatternMatch object is created when a complete pattern pair is found
@@ -29,25 +51,25 @@ class PatternMatch:
content between the patterns.
"""
def __init__(self, pattern_id: str, full_match: str, content: str):
def __init__(self, content: str, type: str, full_match: str):
"""Initialize a pattern match.
Args:
pattern_id: The identifier of the matched pattern pair.
type: The type of the matched pattern pair. It should be representative
of the content type (e.g., 'sentence', 'code', 'speaker', 'custom').
full_match: The complete text including start and end patterns.
content: The text content between the start and end patterns.
"""
self.pattern_id = pattern_id
super().__init__(text=content, type=type)
self.full_match = full_match
self.content = content
def __str__(self) -> str:
"""Return a string representation of the pattern match.
Returns:
A descriptive string showing the pattern ID and content.
A descriptive string showing the pattern type and content.
"""
return f"PatternMatch(id={self.pattern_id}, content={self.content})"
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
class PatternPairAggregator(BaseTextAggregator):
@@ -55,16 +77,21 @@ class PatternPairAggregator(BaseTextAggregator):
This aggregator buffers text until it can identify complete pattern pairs
(defined by start and end patterns), processes the content between these
patterns using registered handlers, and returns text at sentence boundaries.
It's particularly useful for processing structured content in streaming text,
such as XML tags, markdown formatting, or custom delimiters.
patterns using registered handlers. By default, its aggregation method
returns text at sentence boundaries, and remove the content found between
any matched patterns. However, matched patterns can also be configured to
returned as a separate aggregation object containing the content between
their start and end patterns or left in, so that only the delimiters are
removed and a callback can be triggered.
This aggregator is particularly useful for processing structured content in
streaming text, such as XML tags, markdown formatting, or custom delimiters.
The aggregator ensures that patterns spanning multiple text chunks are
correctly identified and handles cases where patterns contain sentence
boundaries.
correctly identified.
"""
def __init__(self):
def __init__(self, **kwargs):
"""Initialize the pattern pair aggregator.
Creates an empty aggregator with no patterns or handlers registered.
@@ -75,16 +102,23 @@ class PatternPairAggregator(BaseTextAggregator):
self._handlers = {}
@property
def text(self) -> str:
"""Get the currently buffered text.
def text(self) -> Aggregation:
"""Get the currently aggregated text.
Returns:
The current text buffer content that hasn't been processed yet.
The text that has been accumulated in the buffer.
"""
return self._text
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start:
return Aggregation(self._text, pattern_start[1].get("type", AggregationType.SENTENCE))
return Aggregation(self._text, AggregationType.SENTENCE)
def add_pattern_pair(
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
def add_pattern(
self,
type: str,
start_pattern: str,
end_pattern: str,
action: MatchAction = MatchAction.REMOVE,
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -93,41 +127,94 @@ class PatternPairAggregator(BaseTextAggregator):
the end pattern, and treat the content between them as a match.
Args:
pattern_id: Unique identifier for this pattern pair.
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
those are reserved for the default behavior.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
remove_match: Whether to remove the matched content from the text.
action: What to do when a complete pattern is matched:
- MatchAction.REMOVE: Remove the matched pattern from the text.
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
normal text. This allows you to register handlers for
the pattern without affecting the aggregation logic.
- MatchAction.AGGREGATE: Return the matched pattern as a separate
aggregation object.
Returns:
Self for method chaining.
"""
self._patterns[pattern_id] = {
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
raise ValueError(
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
)
self._patterns[type] = {
"start": start_pattern,
"end": end_pattern,
"remove_match": remove_match,
"type": type,
"action": action,
}
return self
def add_pattern_pair(
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
):
"""Add a pattern pair to detect in the text.
.. deprecated:: 0.0.95
This function is deprecated and will be removed in a future version.
Use `add_pattern` with a type and MatchAction instead.
This method calls `add_pattern` setting type with the provided pattern_id and action
to either MatchAction.REMOVE or MatchAction.KEEP based on `remove_match`.
Args:
pattern_id: Identifier for this pattern pair. Should be unique and ideally descriptive.
(e.g., 'code', 'speaker', 'custom'). pattern_id can not be 'sentence' or 'word'
as those arereserved for the default behavior.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
remove_match: If True, the matched pattern will be removed from the text. (Same as MatchAction.REMOVE)
If False, it will be kept and treated as normal text. (Same as MatchAction.KEEP)
"""
import warnings
with warnings.catch_warnings():
warnings.simplefilter("once")
warnings.warn(
"add_pattern_pair with a pattern_id or remove_match is deprecated and will be"
" removed in a future version. Use add_pattern with a type and MatchAction instead",
DeprecationWarning,
stacklevel=2,
)
action = MatchAction.REMOVE if remove_match else MatchAction.KEEP
return self.add_pattern(
type=pattern_id,
start_pattern=start_pattern,
end_pattern=end_pattern,
action=action,
)
def on_pattern_match(
self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
) -> "PatternPairAggregator":
"""Register a handler for when a pattern pair is matched.
The handler will be called whenever a complete match for the
specified pattern ID is found in the text.
specified type is found in the text.
Args:
pattern_id: ID of the pattern pair to match.
type: The type of the pattern pair to trigger the handler.
handler: Async function to call when pattern is matched.
The function should accept a PatternMatch object.
Returns:
Self for method chaining.
"""
self._handlers[pattern_id] = handler
self._handlers[type] = handler
return self
async def _process_complete_patterns(self, text: str) -> Tuple[str, bool]:
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
"""Process all complete pattern pairs in the text.
Searches for all complete pattern pairs in the text, calls the
@@ -137,19 +224,19 @@ class PatternPairAggregator(BaseTextAggregator):
text: The text to process for pattern matches.
Returns:
Tuple of (processed_text, was_modified) where:
Tuple of (all_matches, processed_text) where:
- processed_text is the text after processing patterns
- was_modified indicates whether any changes were made
- all_matches is a list of all pattern matches found. Note: There really should only ever be 1.
- processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text.
"""
all_matches = []
processed_text = text
modified = False
for pattern_id, pattern_info in self._patterns.items():
for type, pattern_info in self._patterns.items():
# Escape special regex characters in the patterns
start = re.escape(pattern_info["start"])
end = re.escape(pattern_info["end"])
remove_match = pattern_info["remove_match"]
action = pattern_info["action"]
# Create regex to match from start pattern to end pattern
# The .*? is non-greedy to handle nested patterns
@@ -164,25 +251,24 @@ class PatternPairAggregator(BaseTextAggregator):
full_match = match.group(0) # Full match including patterns
# Create pattern match object
pattern_match = PatternMatch(
pattern_id=pattern_id, full_match=full_match, content=content
)
pattern_match = PatternMatch(content=content, type=type, full_match=full_match)
# Call the appropriate handler if registered
if pattern_id in self._handlers:
if type in self._handlers:
try:
await self._handlers[pattern_id](pattern_match)
await self._handlers[type](pattern_match)
except Exception as e:
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
logger.error(f"Error in pattern handler for {type}: {e}")
# Remove the pattern from the text if configured
if remove_match:
if action == MatchAction.REMOVE:
processed_text = processed_text.replace(full_match, "", 1)
modified = True
else:
all_matches.append(pattern_match)
return processed_text, modified
return all_matches, processed_text
def _has_incomplete_patterns(self, text: str) -> bool:
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
"""Check if text contains incomplete pattern pairs.
Determines whether the text contains any start patterns without
@@ -192,9 +278,10 @@ class PatternPairAggregator(BaseTextAggregator):
text: The text to check for incomplete patterns.
Returns:
True if there are incomplete patterns, False otherwise.
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
or None if no patterns are found or all patterns are complete.
"""
for pattern_id, pattern_info in self._patterns.items():
for type, pattern_info in self._patterns.items():
start = pattern_info["start"]
end = pattern_info["end"]
@@ -203,12 +290,16 @@ class PatternPairAggregator(BaseTextAggregator):
end_count = text.count(end)
# If there are more starts than ends, we have incomplete patterns
# Again, this is written generically but there only ever should
# be one pattern active at a time, so the counts should be 0 or 1.
# Which is why we base the return on the first found.
if start_count > end_count:
return True
start_index = text.find(start)
return [start_index, pattern_info]
return False
return None
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[PatternMatch]:
"""Aggregate text and process pattern pairs.
This method adds the new text to the buffer, processes any complete pattern
@@ -227,16 +318,34 @@ class PatternPairAggregator(BaseTextAggregator):
self._text += text
# Process any complete patterns in the buffer
processed_text, modified = await self._process_complete_patterns(self._text)
patterns, processed_text = await self._process_complete_patterns(self._text)
# Only update the buffer if modifications were made
if modified:
self._text = processed_text
self._text = processed_text
if len(patterns) > 0:
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
return patterns[0]
# Check if we have incomplete patterns
if self._has_incomplete_patterns(self._text):
# Still waiting for complete patterns
return None
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, return None
if (
pattern_start[0] == 0
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
):
return None
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
# Find sentence boundary if no incomplete patterns
eos_marker = match_endofsentence(self._text)
@@ -244,7 +353,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return result
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
# No complete sentence found yet
return None

View File

@@ -14,7 +14,7 @@ text processing scenarios.
from typing import Optional
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class SimpleTextAggregator(BaseTextAggregator):
@@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator):
self._text = ""
@property
def text(self) -> str:
def text(self) -> Aggregation:
"""Get the currently aggregated text.
Returns:
The text that has been accumulated in the buffer.
"""
return self._text
return Aggregation(self._text, AggregationType.SENTENCE)
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate text and return completed sentences.
Adds the new text to the buffer and checks for end-of-sentence markers.
@@ -64,7 +64,7 @@ class SimpleTextAggregator(BaseTextAggregator):
result = self._text[:eos_end_marker]
self._text = self._text[eos_end_marker:]
return result
return Aggregation(result, AggregationType.SENTENCE) if result else None
async def handle_interruption(self):
"""Handle interruptions by clearing the text buffer.

View File

@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
from typing import Optional, Sequence
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class SkipTagsAggregator(BaseTextAggregator):
@@ -49,9 +49,9 @@ class SkipTagsAggregator(BaseTextAggregator):
Returns:
The current text buffer content that hasn't been processed yet.
"""
return self._text
return Aggregation(self._text, AggregationType.SENTENCE)
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate text while respecting tag boundaries.
This method adds the new text to the buffer, processes any complete
@@ -80,7 +80,7 @@ class SkipTagsAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return result
return Aggregation(result, AggregationType.SENTENCE)
# No complete sentence found yet
return None

View File

@@ -7,30 +7,42 @@
import unittest
from unittest.mock import AsyncMock
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.aggregator = PatternPairAggregator()
self.test_handler = AsyncMock()
self.code_handler = AsyncMock()
# Add a test pattern
self.aggregator.add_pattern_pair(
pattern_id="test_pattern",
start_pattern="<test>",
end_pattern="</test>",
remove_match=True,
)
self.aggregator.add_pattern(
type="code_pattern",
start_pattern="<code>",
end_pattern="</code>",
action=MatchAction.AGGREGATE,
)
# Register the mock handler
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
async def test_pattern_match_and_removal(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Hello <test>pattern")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.type, "test_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</test>!")
@@ -39,20 +51,49 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.pattern_id, "test_pattern")
self.assertEqual(call_args.type, "test_pattern")
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
self.assertEqual(call_args.content, "pattern content")
self.assertEqual(call_args.text, "pattern content")
# The exclamation point should be treated as a sentence boundary,
# so the result should include just text up to and including "!"
self.assertEqual(result, "Hello !")
self.assertEqual(result.text, "Hello !")
self.assertEqual(result.type, "sentence")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
self.assertEqual(result, " This is another sentence.")
self.assertEqual(result.text, " This is another sentence.")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_pattern_match_and_aggregate(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Here is code <code>pattern")
self.assertEqual(result.text, "Here is code ")
self.assertEqual(self.aggregator.text.text, "<code>pattern")
self.assertEqual(self.aggregator.text.type, "code_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</code>")
# Verify the handler was called with correct PatternMatch object
self.code_handler.assert_called_once()
call_args = self.code_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.type, "code_pattern")
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
self.assertEqual(call_args.text, "pattern content")
self.assertEqual(result.text, "pattern content")
self.assertEqual(result.type, "code_pattern")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
self.assertEqual(result.text, " This is another sentence.")
self.assertEqual(result.type, "sentence")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text.text, "")
async def test_incomplete_pattern(self):
# Add text with incomplete pattern
@@ -65,26 +106,30 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_not_called()
# Buffer should contain the incomplete text
self.assertEqual(self.aggregator.text, "Hello <test>pattern content")
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
self.assertEqual(self.aggregator.text.type, "test_pattern")
# Reset and confirm buffer is cleared
await self.aggregator.reset()
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_multiple_patterns(self):
# Set up multiple patterns and handlers
voice_handler = AsyncMock()
emphasis_handler = AsyncMock()
self.aggregator.add_pattern_pair(
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
self.aggregator.add_pattern(
type="voice",
start_pattern="<voice>",
end_pattern="</voice>",
action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
pattern_id="emphasis",
self.aggregator.add_pattern(
type="emphasis",
start_pattern="<em>",
end_pattern="</em>",
remove_match=False, # Keep emphasis tags
action=MatchAction.KEEP, # Keep emphasis tags
)
self.aggregator.on_pattern_match("voice", voice_handler)
@@ -97,19 +142,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Both handlers should be called with correct data
voice_handler.assert_called_once()
voice_match = voice_handler.call_args[0][0]
self.assertEqual(voice_match.pattern_id, "voice")
self.assertEqual(voice_match.content, "female")
self.assertEqual(voice_match.type, "voice")
self.assertEqual(voice_match.text, "female")
emphasis_handler.assert_called_once()
emphasis_match = emphasis_handler.call_args[0][0]
self.assertEqual(emphasis_match.pattern_id, "emphasis")
self.assertEqual(emphasis_match.content, "very")
self.assertEqual(emphasis_match.type, "emphasis")
self.assertEqual(emphasis_match.text, "very")
# Voice pattern should be removed, emphasis pattern should remain
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
# Buffer should be empty
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_handle_interruption(self):
# Start with incomplete pattern
@@ -120,7 +165,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
await self.aggregator.handle_interruption()
# Buffer should be cleared
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
# Handler should not have been called
self.test_handler.assert_not_called()
@@ -138,10 +183,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Handler should be called with entire content
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
# Pattern should be removed, resulting in text with sentences merged
self.assertEqual(result, "Hello Final sentence.")
self.assertEqual(result.text, "Hello Final sentence.")
# Buffer should be empty
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")

View File

@@ -13,6 +13,7 @@ import pytest
from aiohttp import web
from pipecat.frames.frames import (
AggregatedTextFrame,
ErrorFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
@@ -74,6 +75,7 @@ async def test_run_piper_tts_success(aiohttp_client):
]
expected_returned_frames = [
AggregatedTextFrame,
TTSStartedFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
@@ -121,7 +123,7 @@ async def test_run_piper_tts_error(aiohttp_client):
TTSSpeakFrame(text="Error case."),
]
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame]
expected_up_frames = [ErrorFrame]

View File

@@ -15,15 +15,20 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reset_aggregations(self):
assert await self.aggregator.aggregate("Hello ") == None
assert self.aggregator.text == "Hello "
assert self.aggregator.text.text == "Hello "
await self.aggregator.reset()
assert self.aggregator.text == ""
assert self.aggregator.text.text == ""
async def test_simple_sentence(self):
assert await self.aggregator.aggregate("Hello ") == None
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
assert self.aggregator.text == ""
aggregate = await self.aggregator.aggregate("Pipecat!")
assert aggregate.text == "Hello Pipecat!"
assert aggregate.type == "sentence"
assert self.aggregator.text.text == ""
async def test_multiple_sentences(self):
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
assert await self.aggregator.aggregate("you?") == " How are you?"
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
assert aggregate.text == "Hello Pipecat!"
assert self.aggregator.text.text == " How are "
aggregate = await self.aggregator.aggregate("you?")
assert aggregate.text == " How are you?"

View File

@@ -18,16 +18,18 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
# No tags involved, aggregate at end of sentence.
result = await self.aggregator.aggregate("Hello Pipecat!")
self.assertEqual(result, "Hello Pipecat!")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "Hello Pipecat!")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
async def test_basic_tags(self):
await self.aggregator.reset()
# Tags involved, avoid aggregation during tags.
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
async def test_streaming_tags(self):
await self.aggregator.reset()
@@ -35,20 +37,22 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
# Tags involved, stream small chunk of texts.
result = await self.aggregator.aggregate("My email is <sp")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <sp")
self.assertEqual(self.aggregator.text.text, "My email is <sp")
result = await self.aggregator.aggregate("ell>foo.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
result = await self.aggregator.aggregate("bar@pipecat.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
result = await self.aggregator.aggregate("ai</spe")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.ai</spe")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
self.assertEqual(self.aggregator.text.type, "sentence")
result = await self.aggregator.aggregate("ll>.")
self.assertEqual(result, "My email is <spell>foo.bar@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text.type, "sentence")

View File

@@ -11,6 +11,7 @@ from datetime import datetime, timezone
from typing import List, Tuple, cast
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -130,11 +131,11 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(), # Wait for StartedSpeaking to process
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world!"),
TTSTextFrame(text="How"),
TTSTextFrame(text="are"),
TTSTextFrame(text="you?"),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="How", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="are", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="you?", aggregated_by=AggregationType.WORD),
SleepFrame(), # Wait for text frames to queue
BotStoppedSpeakingFrame(),
]
@@ -195,9 +196,9 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text=""), # Empty text
TTSTextFrame(text=" "), # Just whitespace
TTSTextFrame(text="\n"), # Just newline
TTSTextFrame(text="", aggregated_by=AggregationType.WORD), # Empty text
TTSTextFrame(text=" ", aggregated_by=AggregationType.WORD), # Just whitespace
TTSTextFrame(text="\n", aggregated_by=AggregationType.WORD), # Just newline
BotStoppedSpeakingFrame(),
# Pipeline ends here; run_test will automatically send EndFrame
]
@@ -235,14 +236,14 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world!"),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
SleepFrame(),
InterruptionFrame(), # User interrupts here
SleepFrame(),
BotStartedSpeakingFrame(),
TTSTextFrame(text="New"),
TTSTextFrame(text="response"),
TTSTextFrame(text="New", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="response", aggregated_by=AggregationType.WORD),
SleepFrame(),
BotStoppedSpeakingFrame(),
]
@@ -299,8 +300,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world"),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
# Pipeline ends here; run_test will automatically send EndFrame
]
@@ -338,8 +339,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world"),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
SleepFrame(), # Ensure messages are processed
CancelFrame(),
]
@@ -401,8 +402,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Assistant"),
TTSTextFrame(text="message"),
TTSTextFrame(text="Assistant", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="message", aggregated_by=AggregationType.WORD),
BotStoppedSpeakingFrame(),
]
@@ -439,7 +440,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
# Test the specific pattern shared
def make_tts_text_frame(text: str) -> TTSTextFrame:
frame = TTSTextFrame(text=text)
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.WORD)
frame.includes_inter_frame_spaces = True
return frame