Compare commits

3 Commits

Author SHA1 Message Date
James Hush
1b28fc8e8e Fix: Ensure EndFrame propagates through AIService before stop()
This fix addresses a critical bug where EndFrame (and potentially other
system frames) would trigger the stop() method in AIService but never
be pushed downstream to subsequent processors, causing pipelines to hang.

The issue occurred because AIService.process_frame() would call stop(frame)
for EndFrame without first pushing it downstream. This meant that downstream
processors never received the shutdown signal, leaving the pipeline in a
waiting state.

The fix ensures EndFrame is pushed downstream BEFORE calling stop(), following
the same pattern used by RTVIProcessor and properly-implemented processors.
This guarantees that:
1. Downstream processors receive the EndFrame for proper cleanup
2. The stop() method can then safely perform service-specific cleanup
3. The ordering prevents race conditions during shutdown

This bug affected all AI services inheriting from AIService that didn't
override process_frame() to explicitly handle EndFrame, including scenarios
with TTS services, LLM services, and other AI service implementations.

Fixes pipeline hangs during graceful shutdown when EndFrame is sent.
2025-11-17 11:09:42 +01:00
Mark Backman
35ff44b799 Merge pull request #3059 from pipecat-ai/mb/remove-llm-tracing-fallback 2025-11-14 14:07:40 -05:00
Mark Backman
d01876ee60 Remove fallbacks in traced_llm 2025-11-14 12:13:49 -05:00
28 changed files with 204 additions and 1034 deletions

View File

@@ -16,82 +16,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
services that subclass `TTSService` can indicate whether the text in the
`TTSTextFrame`s they push already contain any necessary inter-frame spaces.
- Introduced new `AggregatedTextFrame` type to support representing a best effort of
the perceived llm output whether or not it is processed by the TTS. This new frame
type includes the field `aggregated_by` to represent the conceptual format by which
the given text is aggregated. `TTSTextFrame`s now inherit from `AggregatedTextFrame`.
With this inheritance, an observer can watch for `AggregatedTextFrame`s to accumulate
the perceived output and determine whether or not the text was spoken based on if that
frame is also a `TTSTextFrame`. (See bullet below on new `bot-output` which takes
advantage of this)
- Introduced `LLMTextProcessor`: A new processor meant to allow customization for how
LLMTextFrames should be aggregated and considered. Its purpose is to turn
`LLMTextFrame`s into `AggregatedTextFrame`s. By default, a TTSService will still
aggregate `LLMTextFrame`s by sentence for the service to consume. However, if you
wish to override how the llm text is aggregated, you should no longer override the
TTS's internal aggregator, but instead, insert this processor between your LLM and
TTS in the pipeline.
- New `bot-output` RTVI message to represent what the bot actually "says".
- The `RTVIObserver` now emits `bot-output` messages based off the new `AggregatedTextFrame`s
(`bot-tts-text` and `bot-llm-text` are still supported and generated, but `bot-transcription` is
now deprecated in favor of this new, more thorough message).
- The new `RTVIBotOutputMessage` includes the fields:
- `spoken`: A boolean indicating whether the text was spoken by TTS
- `aggregated_by`: A string representing how the text was aggregated ("sentence", "word",
"my custom aggregation")
- Introduced new fields to `RTVIObserver` to support the new `bot-output` messaging:
- `bot_output_enabled`: Defaults to True. Set to false to disable bot-output messages.
- `skip_aggregator_types`: Defaults to `None`. Set to a list of strings that match
aggregation types that should not be included in bot-output messages. (Ex. `credit_card`)
- Introduced new methods, `add_text_transformer()` and `remove_text_transformer()`, to `RTVIObserver` to support providing (and subsequently removing)
callbacks for various types of aggregations (or all aggregations with `*`) that can modify the
text before being sent as a `bot-output` or `tts-text` message. (Think obscuring the credit card
or inserting extra detail the client might want that the context doesn't need.)
- Updated the base aggregator type:
- Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom
aggregation")
- **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
- **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
(instead of `Optional[str]`). To update:
```
aggregation = myAggregator.aggregate(text)
if aggregation:
    print(f"successfully aggregated text: {aggregation.text}")  # instead of {aggregation}
```
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
produce/consume `Aggregation` objects.
- Augmented the `PatternPairAggregator`:
- Introduced a new, preferred version of `add_pattern` to support a new option for treating a
match as a separate aggregation returned from `aggregate()`. This replaces the now-deprecated
`add_pattern_pair` method; you provide a `MatchAction` in place of the `remove_match` field.
- `MatchAction` enum: `REMOVE`, `KEEP`, `AGGREGATE`, allowing customization for how
a match should be handled.
- `REMOVE`: The text along with its delimiters will be removed from the streaming text.
Sentence aggregation will continue on as if this text did not exist.
- `KEEP`: The delimiters will be removed, but the content between them will be kept.
Sentence aggregation will continue on with the internal text included.
- `AGGREGATE`: The delimiters will be removed and the content between will be treated
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
will be returned. Then the aggregation will continue on sentence matching after
the closing delimiter is found. The content between the delimiters is not
aggregated by sentence. It is aggregated as one single block of text.
- `PatternMatch` now extends `Aggregation` and provides richer info to handlers.
- **BREAKING**: The `PatternMatch` type returned to handlers registered via `on_pattern_match`
has been updated to subclass from the new `Aggregation` type, which means that `content`
has been replaced with `text` and `pattern_id` has been replaced with `type`:
```
async def on_match_tag(match: PatternMatch):
    pattern = match.type  # instead of match.pattern_id
    text = match.text  # instead of match.content
```
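
The three `MatchAction` behaviours described in the changelog entries above can be illustrated on a complete string. This is a minimal sketch and not pipecat's streaming implementation: `split_by_action` is a hypothetical helper, it handles a single delimiter pair, it works on a whole string instead of streamed tokens, and it labels all surrounding text as `"sentence"` without doing real sentence detection. The `Aggregation` dataclass only mirrors the two fields named in the changelog (`text` and `type`).

```python
from dataclasses import dataclass
from enum import Enum
from typing import List


class MatchAction(Enum):
    REMOVE = "remove"
    KEEP = "keep"
    AGGREGATE = "aggregate"


@dataclass
class Aggregation:
    text: str
    type: str


def split_by_action(
    text: str, start: str, end: str, pattern_type: str, action: MatchAction
) -> List[Aggregation]:
    """Hypothetical, non-streaming illustration of the three actions."""
    begin = text.find(start)
    stop = text.find(end, begin + len(start)) if begin != -1 else -1
    if begin == -1 or stop == -1:
        # No complete delimiter pair: pass the text through unchanged.
        return [Aggregation(text, "sentence")]

    before = text[:begin]
    content = text[begin + len(start):stop]
    after = text[stop + len(end):]

    if action is MatchAction.REMOVE:
        # Delimiters and content disappear from the text.
        return [Aggregation(before + after, "sentence")]
    if action is MatchAction.KEEP:
        # Only the delimiters disappear; the content stays inline.
        return [Aggregation(before + content + after, "sentence")]
    # AGGREGATE: text before the match is returned early, the content becomes
    # its own aggregation, and normal aggregation resumes afterwards.
    parts = [
        Aggregation(before, "sentence"),
        Aggregation(content, pattern_type),
        Aggregation(after, "sentence"),
    ]
    return [part for part in parts if part.text.strip()]
```

With `AGGREGATE`, the content between the delimiters comes back as a single block carrying the pattern's own type, which is what would let a consumer skip or transform that block by type.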
### Changed
- Updated all STT and TTS services to use consistent error handling pattern with
@@ -109,42 +33,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated language mappings for the Google and Gemini TTS services to match
official documentation.
- `TextFrame` new field `append_to_context` used to indicate if the encompassing
text should be added to the LLM context (by the LLM assistant aggregator). It
defaults to `True`.
- TTS flow respects aggregation metadata
- `TTSService` accepts a new `skip_aggregator_types` to avoid speaking certain aggregation types
(now determined/returned by the aggregator)
- TTS services push `AggregatedTextFrame` in addition to `TTSTextFrame`s when either an
aggregation occurs that should not be spoken or when the TTS service supports word-by-word
timestamping. In the latter case, the `TTSService` preliminarily generates an
`AggregatedTextFrame`, aggregated by sentence to generate the full sentence content as early
as possible.
- Introduced new methods, `add_text_transformer()` and `remove_text_transformer()`:
These functions introduce the ability to provide (and subsequently remove) callbacks to the TTS to transform text based on
its aggregated type prior to sending the text to the underlying TTS service. This makes it
possible to do things like introduce TTS-specific tags for spelling or emotion or change the
pronunciation of something on the fly.
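
The transformer callbacks described above could be applied roughly as follows. This is a sketch under assumptions: `TransformRegistry` is a hypothetical class, and the callback signature `(text, aggregation_type) -> awaitable str` is borrowed from the `RTVIObserver` code in this comparison. Whether `TTSService` uses exactly that signature is not confirmed here.

```python
import asyncio
from typing import Awaitable, Callable, List, Tuple

Transform = Callable[[str, str], Awaitable[str]]


class TransformRegistry:
    """Hypothetical holder for (aggregation_type, callback) pairs."""

    def __init__(self) -> None:
        self._transforms: List[Tuple[str, Transform]] = []

    def add(self, fn: Transform, aggregation_type: str = "*") -> None:
        self._transforms.append((aggregation_type, fn))

    def remove(self, fn: Transform, aggregation_type: str = "*") -> None:
        self._transforms = [
            (registered_type, registered_fn)
            for registered_type, registered_fn in self._transforms
            if not (registered_type == aggregation_type and registered_fn == fn)
        ]

    async def apply(self, text: str, aggregation_type: str) -> str:
        # Run every callback registered for this type, or for all types ("*"),
        # in registration order.
        for registered_type, fn in self._transforms:
            if registered_type in (aggregation_type, "*"):
                text = await fn(text, aggregation_type)
        return text


async def mask_card(text: str, aggregation_type: str) -> str:
    # Keep only the last four characters visible.
    return "*" * max(len(text) - 4, 0) + text[-4:]


async def shout(text: str, aggregation_type: str) -> str:
    return text.upper()
```

Registering `mask_card` for a `credit_card` type and `shout` for `*` would mask card numbers and upper-case everything else, with callbacks applied in the order they were added.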
### Deprecated
- The `api_key` parameter in `GeminiTTSService` is deprecated. Use
`credentials` or `credentials_path` instead for Google Cloud authentication.
- The RTVI `bot-transcription` event is deprecated in favor of the new `bot-output`
message which is the canonical representation of bot output (spoken or not). The code
still emits a transcription message for backwards compatibility while the transition occurs.
- The TTS constructor field, `text_aggregator` is deprecated in favor of the new
`LLMTextProcessor`. TTSServices still have an internal aggregator for support of default
behavior, but if you want to override the aggregation behavior, you should use the new
processor.
- Deprecated `add_pattern_pair` in the `PatternPairAggregator` which takes a `pattern_id`
and `remove_match` field in favor of the new `add_pattern` method which takes a `type` and an
`action`
### Fixed
- Fixed subtle issue of assistant context messages ending up with double spaces

View File

@@ -62,11 +62,7 @@ from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
load_dotenv(override=True)
@@ -110,16 +106,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
pattern_aggregator = PatternPairAggregator()
# Add pattern for voice switching
pattern_aggregator.add_pattern(
type="voice",
pattern_aggregator.add_pattern_pair(
pattern_id="voice_tag",
start_pattern="<voice>",
end_pattern="</voice>",
action=MatchAction.REMOVE, # Remove tags from final text
remove_match=True,
)
# Register handler for voice switching
async def on_voice_tag(match: PatternMatch):
voice_name = match.text.strip().lower()
voice_name = match.content.strip().lower()
if voice_name in VOICE_IDS:
# First flush any existing audio to finish the current context
await tts.flush_audio()
@@ -129,7 +125,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
else:
logger.warning(f"Unknown voice: {voice_name}")
pattern_aggregator.on_pattern_match("voice", on_voice_tag)
pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

View File

@@ -31,11 +31,7 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import LLMService
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
class IVRStatus(Enum):
@@ -118,15 +114,15 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
self._aggregator.add_pattern("dtmf", "<dtmf>", "</dtmf>", action=MatchAction.REMOVE)
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", remove_match=True)
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
self._aggregator.add_pattern("mode", "<mode>", "</mode>", action=MatchAction.REMOVE)
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", remove_match=True)
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
self._aggregator.add_pattern("ivr", "<ivr>", "</ivr>", action=MatchAction.REMOVE)
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", remove_match=True)
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -163,7 +159,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing DTMF content.
"""
value = match.text
value = match.content
logger.debug(f"DTMF detected: {value}")
try:
@@ -184,7 +180,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing IVR status content.
"""
status = match.text
status = match.content
logger.trace(f"IVR status detected: {status}")
# Convert string to enum, with validation
@@ -215,7 +211,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing mode content.
"""
mode = match.text
mode = match.content
logger.debug(f"Mode detected: {mode}")
if mode == "conversation":
await self._handle_conversation()

View File

@@ -12,7 +12,6 @@ and LLM processing.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
@@ -338,14 +337,11 @@ class TextFrame(DataFrame):
# mandatory fields of theirs to have defaults to preserve
# non-default-before-default argument order)
includes_inter_frame_spaces: bool = field(init=False)
# Whether this text frame should be appended to the LLM context.
append_to_context: bool = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.includes_inter_frame_spaces = False
self.append_to_context = True
def __str__(self):
pts = format_pts(self.pts)
@@ -359,32 +355,8 @@ class LLMTextFrame(TextFrame):
pass
class AggregationType(str, Enum):
"""Built-in aggregation strings."""
SENTENCE = "sentence"
WORD = "word"
def __str__(self):
return self.value
@dataclass
class AggregatedTextFrame(TextFrame):
"""Text frame representing an aggregation of TextFrames.
This frame contains multiple TextFrames aggregated together for processing
or output along with a field to indicate how they are aggregated.
Parameters:
aggregated_by: Method used to aggregate the text frames.
"""
aggregated_by: AggregationType | str
@dataclass
class TTSTextFrame(AggregatedTextFrame):
class TTSTextFrame(TextFrame):
"""Text frame generated by Text-to-Speech services."""
pass
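
On the side of this comparison where `TTSTextFrame` subclasses `AggregatedTextFrame`, an observer can tell spoken from unspoken text with two `isinstance` checks. The sketch below uses simplified stand-in dataclasses, not pipecat's real frame classes (which carry more fields), and `describe` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class TextFrame:
    text: str


@dataclass
class AggregatedTextFrame(TextFrame):
    aggregated_by: str


@dataclass
class TTSTextFrame(AggregatedTextFrame):
    pass


def describe(frames: List[TextFrame]) -> List[Tuple[str, str, bool]]:
    """Return (text, aggregated_by, spoken) for every aggregated frame."""
    result = []
    for frame in frames:
        if isinstance(frame, AggregatedTextFrame):
            # A frame counts as spoken only if it is also a TTSTextFrame.
            spoken = isinstance(frame, TTSTextFrame)
            result.append((frame.text, frame.aggregated_by, spoken))
    return result
```

Plain `TextFrame`s are ignored, so the caller only sees text that has already been grouped by some aggregation type.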

View File

@@ -1001,7 +1001,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started or not frame.append_to_context:
if not self._started:
return
if self._params.expect_stripped_words:

View File

@@ -814,7 +814,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started or not frame.append_to_context:
if not self._started:
return
# Make sure we really have text (spaces count, too!)

View File

@@ -1,106 +0,0 @@
#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""LLM text processor module for processing and aggregating raw LLM output text.
This processor will convert LLMTextFrames into AggregatedTextFrames based on the
configured text aggregator. Using the customizable aggregator, it provides
functionality to handle or manipulate LLM text frames before they are sent to other
components such as TTS services or context aggregators. It can be used to pre-aggregate
and categorize, modify, or filter direct output tokens from the LLM.
"""
from typing import Optional
from pipecat.frames.frames import (
AggregatedTextFrame,
EndFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMTextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
class LLMTextProcessor(FrameProcessor):
"""A processor for handling or manipulating LLM text frames before they are processed further.
This processor will convert LLMTextFrames into AggregatedTextFrames based on the configured
text aggregator. Using the customizable aggregator, it provides functionality to handle or
manipulate LLM text frames before they are sent to other components such as TTS services or
context aggregators. It can be used to pre-aggregate and categorize, modify, or filter direct
output tokens from the LLM.
"""
def __init__(self, *, text_aggregator: Optional[BaseTextAggregator] = None, **kwargs):
"""Initialize the LLM text processor.
Args:
text_aggregator: An optional text aggregator to use for processing LLM text frames. By
default, a SimpleTextAggregator aggregating by sentence will be used.
**kwargs: Additional arguments passed to parent class.
TODO: Allow transformations per aggregation type or all (and deprecate the TTS filters).
"""
super().__init__(**kwargs)
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process LLMTextFrames using the aggregator to generate AggregatedTextFrames.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
await self._handle_interruption(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, LLMTextFrame):
await self._handle_llm_text(frame)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_llm_end(frame.skip_tts)
await self.push_frame(frame, direction)
elif isinstance(frame, EndFrame):
await self._handle_llm_end()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
async def _handle_interruption(self, _):
"""Handle interruptions by resetting the text aggregator."""
await self._text_aggregator.handle_interruption()
async def reset(self):
"""Reset the internal state of the text processor and its aggregator."""
await self._text_aggregator.reset()
async def _handle_llm_text(self, in_frame: LLMTextFrame):
aggregation = await self._text_aggregator.aggregate(in_frame.text)
if aggregation:
out_frame = AggregatedTextFrame(
text=aggregation.text,
aggregated_by=aggregation.type,
)
out_frame.skip_tts = in_frame.skip_tts
await self.push_frame(out_frame)
async def _handle_llm_end(self, skip_tts: bool = False):
# Flush any remaining aggregated text at the end of the LLM response
aggregation = self._text_aggregator.text
await self._text_aggregator.reset()
text = aggregation.text.strip()
if text:
out_frame = AggregatedTextFrame(
text=text,
aggregated_by=aggregation.type,
)
out_frame.skip_tts = skip_tts
await self.push_frame(out_frame)
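
The two paths in `LLMTextProcessor` (emit an aggregation as soon as one completes, then flush whatever is left when the response ends) can be sketched without pipecat. `NaiveSentenceAggregator` is a hypothetical stand-in, not `SimpleTextAggregator`: it is synchronous and treats any buffer ending in `.`, `!` or `?` as a finished sentence, which is an assumption made only for illustration.

```python
from typing import List, Optional


class NaiveSentenceAggregator:
    """Hypothetical stand-in for a text aggregator."""

    def __init__(self) -> None:
        self._buffer = ""

    def aggregate(self, token: str) -> Optional[str]:
        # Accumulate tokens and return the buffer once it looks like a sentence.
        self._buffer += token
        if self._buffer.rstrip().endswith((".", "!", "?")):
            sentence, self._buffer = self._buffer, ""
            return sentence
        return None

    def flush(self) -> str:
        # Return any leftover text and reset, as done at the end of a response.
        remaining, self._buffer = self._buffer.strip(), ""
        return remaining


def run(tokens: List[str]) -> List[str]:
    aggregator = NaiveSentenceAggregator()
    emitted = []
    for token in tokens:
        sentence = aggregator.aggregate(token)
        if sentence:
            emitted.append(sentence)
    tail = aggregator.flush()
    if tail:
        emitted.append(tail)
    return emitted
```

Without the final flush, a response that ends mid-sentence would leave its last words stuck in the buffer.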

View File

@@ -24,7 +24,6 @@ from typing import (
Literal,
Mapping,
Optional,
Tuple,
Union,
)
@@ -33,8 +32,6 @@ from pydantic import BaseModel, Field, PrivateAttr, ValidationError
from pipecat.audio.utils import calculate_audio_volume
from pipecat.frames.frames import (
AggregatedTextFrame,
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -707,29 +704,6 @@ class RTVITextMessageData(BaseModel):
text: str
class RTVIBotOutputMessageData(RTVITextMessageData):
"""Data for bot output RTVI messages.
Extends RTVITextMessageData to include metadata about the output.
"""
spoken: bool = False # Indicates if the text has been spoken by TTS
aggregated_by: AggregationType | str
# Indicates what form the text is in (e.g., by word, sentence, etc.)
class RTVIBotOutputMessage(BaseModel):
"""Message containing bot output text.
An event meant to holistically represent what the bot is outputting,
along with metadata about the output and if it has been spoken.
"""
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
type: Literal["bot-output"] = "bot-output"
data: RTVIBotOutputMessageData
class RTVIBotTranscriptionMessage(BaseModel):
"""Message containing bot transcription text.
@@ -922,7 +896,6 @@ class RTVIObserverParams:
Parameter `errors_enabled` is deprecated. Error messages are always enabled.
Parameters:
bot_output_enabled: Indicates if bot output messages should be sent.
bot_llm_enabled: Indicates if the bot's LLM messages should be sent.
bot_tts_enabled: Indicates if the bot's TTS messages should be sent.
bot_speaking_enabled: Indicates if the bot's started/stopped speaking messages should be sent.
@@ -934,17 +907,9 @@ class RTVIObserverParams:
metrics_enabled: Indicates if metrics messages should be sent.
system_logs_enabled: Indicates if system logs should be sent.
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
Note: if using this to avoid sending secure information, be sure to also disable
bot_llm_enabled to avoid leaking through LLM messages.
bot_output_transforms: A list of callables to transform text just before sending it
to TTS. Each callable takes the aggregated text and its type, and returns the
transformed text. To register, provide a list of tuples of
(aggregation_type | '*', transform_function).
audio_level_period_secs: How often audio levels should be sent if enabled.
"""
bot_output_enabled: bool = True
bot_llm_enabled: bool = True
bot_tts_enabled: bool = True
bot_speaking_enabled: bool = True
@@ -956,15 +921,6 @@ class RTVIObserverParams:
metrics_enabled: bool = True
system_logs_enabled: bool = False
errors_enabled: Optional[bool] = None
skip_aggregator_types: Optional[List[AggregationType | str]] = None
bot_output_transforms: Optional[
List[
Tuple[
AggregationType | str,
Callable[[str, AggregationType | str], Awaitable[str]],
]
]
] = None
audio_level_period_secs: float = 0.15
@@ -1017,45 +973,8 @@ class RTVIObserver(BaseObserver):
DeprecationWarning,
)
self._aggregation_transforms: List[
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
] = self._params.bot_output_transforms or []
def add_bot_output_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Transform text for a specific aggregation type before sending as Bot Output or TTS.
Args:
transform_function: The function to apply for transformation. This function should take
the text and aggregation type as input and return the transformed text.
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
aggregation_type: The type of aggregation to transform. This value defaults to "*" to
handle all text before sending to the client.
"""
self._aggregation_transforms.append((aggregation_type, transform_function))
def remove_bot_output_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Remove a text transformer for a specific aggregation type.
Args:
transform_function: The function to remove.
aggregation_type: The type of aggregation to remove the transformer for.
"""
self._aggregation_transforms = [
(agg_type, func)
for agg_type, func in self._aggregation_transforms
if not (agg_type == aggregation_type and func == transform_function)
]
async def _logger_sink(self, message):
"""Logger sink so we can send system logs to RTVI clients."""
"""Logger sink so we cna send system logs to RTVI clients."""
message = RTVISystemLogMessage(data=RTVITextMessageData(text=message))
await self.send_rtvi_message(message)
@@ -1129,15 +1048,12 @@ class RTVIObserver(BaseObserver):
await self.send_rtvi_message(RTVIBotTTSStartedMessage())
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
await self.send_rtvi_message(RTVIBotTTSStoppedMessage())
elif isinstance(frame, AggregatedTextFrame) and (
self._params.bot_output_enabled or self._params.bot_tts_enabled
):
if isinstance(frame, TTSTextFrame) and not isinstance(src, BaseOutputTransport):
# This check is to make sure we handle the frame when it has gone
# through the transport and has correct timing.
mark_as_seen = False
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
if isinstance(src, BaseOutputTransport):
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
await self.send_rtvi_message(message)
else:
await self._handle_aggregated_llm_text(frame)
mark_as_seen = False
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
await self._handle_metrics(frame)
elif isinstance(frame, RTVIServerMessageFrame):
@@ -1168,6 +1084,15 @@ class RTVIObserver(BaseObserver):
if mark_as_seen:
self._frames_seen.add(frame.id)
async def _push_bot_transcription(self):
"""Push accumulated bot transcription as a message."""
if len(self._bot_transcription) > 0:
message = RTVIBotTranscriptionMessage(
data=RTVITextMessageData(text=self._bot_transcription)
)
await self.send_rtvi_message(message)
self._bot_transcription = ""
async def _handle_interruptions(self, frame: Frame):
"""Handle user speaking interruption frames."""
message = None
@@ -1190,45 +1115,14 @@ class RTVIObserver(BaseObserver):
if message:
await self.send_rtvi_message(message)
async def _handle_aggregated_llm_text(self, frame: AggregatedTextFrame):
"""Handle aggregated LLM text output frames."""
# Skip certain aggregator types if configured to do so.
if (
self._params.skip_aggregator_types
and frame.aggregated_by in self._params.skip_aggregator_types
):
return
text = frame.text
type = frame.aggregated_by
for aggregation_type, transform in self._aggregation_transforms:
if aggregation_type == type or aggregation_type == "*":
text = await transform(text, type)
isTTS = isinstance(frame, TTSTextFrame)
if self._params.bot_output_enabled:
message = RTVIBotOutputMessage(
data=RTVIBotOutputMessageData(text=text, spoken=isTTS, aggregated_by=type)
)
await self.send_rtvi_message(message)
if isTTS and self._params.bot_tts_enabled:
tts_message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=text))
await self.send_rtvi_message(tts_message)
async def _handle_llm_text_frame(self, frame: LLMTextFrame):
"""Handle LLM text output frames."""
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
await self.send_rtvi_message(message)
# TODO (mrkb): Remove all this logic when we fully deprecate bot-transcription messages.
self._bot_transcription += frame.text
if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0:
await self.send_rtvi_message(
RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription))
)
self._bot_transcription = ""
if match_endofsentence(self._bot_transcription):
await self._push_bot_transcription()
async def _handle_user_transcriptions(self, frame: Frame):
"""Handle user transcription frames."""
@@ -1354,7 +1248,7 @@ class RTVIProcessor(FrameProcessor):
# Default to 0.3.0 which is the last version before actually having a
# "client-version".
self._client_version = [0, 3, 0]
self._llm_skip_tts: bool = False # Keep in sync with llm_service.py's configuration.
self._skip_tts: bool = False # Keep in sync with llm_service.py
self._registered_actions: Dict[str, RTVIAction] = {}
self._registered_services: Dict[str, RTVIService] = {}
@@ -1547,7 +1441,7 @@ class RTVIProcessor(FrameProcessor):
elif isinstance(frame, RTVIActionFrame):
await self._action_queue.put(frame)
elif isinstance(frame, LLMConfigureOutputFrame):
self._llm_skip_tts = frame.skip_tts
self._skip_tts = frame.skip_tts
await self.push_frame(frame, direction)
# Other frames
else:
@@ -1803,9 +1697,9 @@ class RTVIProcessor(FrameProcessor):
opts = data.options if data.options is not None else RTVISendTextOptions()
if opts.run_immediately:
await self.interrupt_bot()
cur_llm_skip_tts = self._llm_skip_tts
cur_skip_tts = self._skip_tts
should_skip_tts = not opts.audio_response
toggle_skip_tts = cur_llm_skip_tts != should_skip_tts
toggle_skip_tts = cur_skip_tts != should_skip_tts
if toggle_skip_tts:
output_frame = LLMConfigureOutputFrame(skip_tts=should_skip_tts)
await self.push_frame(output_frame)
@@ -1815,7 +1709,7 @@ class RTVIProcessor(FrameProcessor):
)
await self.push_frame(text_frame)
if toggle_skip_tts:
output_frame = LLMConfigureOutputFrame(skip_tts=cur_llm_skip_tts)
output_frame = LLMConfigureOutputFrame(skip_tts=cur_skip_tts)
await self.push_frame(output_frame)
async def _handle_update_context(self, data: RTVIAppendToContextData):
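
The save, toggle and restore pattern around `skip_tts` in the hunk above can be expressed as a pure function. This is a sketch only: `ConfigureOutput` and `Text` are hypothetical stand-ins for `LLMConfigureOutputFrame` and the text frame pushed by the processor, and `frames_for_send_text` is not a pipecat function.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class ConfigureOutput:
    skip_tts: bool


@dataclass
class Text:
    text: str


def frames_for_send_text(
    text: str, current_skip_tts: bool, audio_response: bool
) -> List[Union[ConfigureOutput, Text]]:
    """List the frames to push, in order, for one send-text request."""
    should_skip_tts = not audio_response
    toggle = current_skip_tts != should_skip_tts
    frames: List[Union[ConfigureOutput, Text]] = []
    if toggle:
        # Temporarily switch the output mode for this message.
        frames.append(ConfigureOutput(skip_tts=should_skip_tts))
    frames.append(Text(text))
    if toggle:
        # Restore the mode that was active before the request.
        frames.append(ConfigureOutput(skip_tts=current_skip_tts))
    return frames
```

When the requested mode already matches the current one, only the text frame is produced.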

View File

@@ -152,6 +152,9 @@ class AIService(FrameProcessor):
elif isinstance(frame, CancelFrame):
await self.cancel(frame)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() may wait on tasks to
# finish and downstream processors need to receive the EndFrame.
await self.push_frame(frame, direction)
await self.stop(frame)
async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
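
The ordering introduced by this hunk can be demonstrated with a toy two-processor chain. This is a minimal sketch under stated assumptions: `Processor` and `Service` are hypothetical classes that call each other directly, whereas pipecat's real `FrameProcessor` queues frames, so only the relative order of "forward" and "stop" is being illustrated.

```python
import asyncio
from typing import List, Optional


class EndFrame:
    pass


class Processor:
    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log
        self.next: Optional["Processor"] = None

    async def push_frame(self, frame: object) -> None:
        if self.next:
            await self.next.process_frame(frame)

    async def process_frame(self, frame: object) -> None:
        self.log.append(f"{self.name} received {type(frame).__name__}")
        await self.push_frame(frame)


class Service(Processor):
    async def stop(self, frame: EndFrame) -> None:
        self.log.append(f"{self.name} stopped")

    async def process_frame(self, frame: object) -> None:
        if isinstance(frame, EndFrame):
            # Forward first so downstream processors see the shutdown signal,
            # then run service-specific cleanup.
            await self.push_frame(frame)
            await self.stop(frame)
        else:
            await self.push_frame(frame)


async def main() -> List[str]:
    log: List[str] = []
    service = Service("service", log)
    service.next = Processor("sink", log)
    await service.process_frame(EndFrame())
    return log
```

If `stop()` were called without the preceding `push_frame()`, the sink in this toy chain would never log the `EndFrame`, which corresponds to the hang described in the commit message.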

View File

@@ -27,7 +27,6 @@ from pydantic import BaseModel, Field
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter, Role
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -1028,7 +1027,7 @@ class AWSNovaSonicLLMService(LLMService):
logger.debug(f"Assistant response text added: {text}")
# Report the text of the assistant response.
frame = TTSTextFrame(text, aggregated_by=AggregationType.SENTENCE)
frame = TTSTextFrame(text)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
@@ -1063,9 +1062,7 @@ class AWSNovaSonicLLMService(LLMService):
# TTSTextFrame would be ignored otherwise (the interruption frame
# would have cleared the assistant aggregator state).
await self.push_frame(LLMFullResponseStartFrame())
frame = TTSTextFrame(
self._assistant_text_buffer, aggregated_by=AggregationType.SENTENCE
)
frame = TTSTextFrame(self._assistant_text_buffer)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
self._may_need_repush_assistant_text = False

View File

@@ -10,8 +10,7 @@ import base64
import json
import uuid
import warnings
from enum import Enum
from typing import AsyncGenerator, List, Literal, Optional
from typing import AsyncGenerator, List, Literal, Optional, Union
from loguru import logger
from pydantic import BaseModel, Field
@@ -126,72 +125,6 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
class CartesiaEmotion(str, Enum):
"""Predefined Emotions supported by Cartesia."""
# Primary emotions supported by Cartesia
NEUTRAL = "neutral"
ANGRY = "angry"
EXCITED = "excited"
CONTENT = "content"
SAD = "sad"
SCARED = "scared"
# Additional emotions supported by Cartesia
HAPPY = "happy"
ENTHUSIASTIC = "enthusiastic"
ELATED = "elated"
EUPHORIC = "euphoric"
TRIUMPHANT = "triumphant"
AMAZED = "amazed"
SURPRISED = "surprised"
FLIRTATIOUS = "flirtatious"
JOKING_COMEDIC = "joking/comedic"
CURIOUS = "curious"
PEACEFUL = "peaceful"
SERENE = "serene"
CALM = "calm"
GRATEFUL = "grateful"
AFFECTIONATE = "affectionate"
TRUST = "trust"
SYMPATHETIC = "sympathetic"
ANTICIPATION = "anticipation"
MYSTERIOUS = "mysterious"
MAD = "mad"
OUTRAGED = "outraged"
FRUSTRATED = "frustrated"
AGITATED = "agitated"
THREATENED = "threatened"
DISGUSTED = "disgusted"
CONTEMPT = "contempt"
ENVIOUS = "envious"
SARCASTIC = "sarcastic"
IRONIC = "ironic"
DEJECTED = "dejected"
MELANCHOLIC = "melancholic"
DISAPPOINTED = "disappointed"
HURT = "hurt"
GUILTY = "guilty"
BORED = "bored"
TIRED = "tired"
REJECTED = "rejected"
NOSTALGIC = "nostalgic"
WISTFUL = "wistful"
APOLOGETIC = "apologetic"
HESITANT = "hesitant"
INSECURE = "insecure"
CONFUSED = "confused"
RESIGNED = "resigned"
ANXIOUS = "anxious"
PANICKED = "panicked"
ALARMED = "alarmed"
PROUD = "proud"
CONFIDENT = "confident"
DISTANT = "distant"
SKEPTICAL = "skeptical"
CONTEMPLATIVE = "contemplative"
DETERMINED = "determined"
class CartesiaTTSService(AudioContextWordTTSService):
"""Cartesia TTS service with WebSocket streaming and word timestamps.
@@ -249,10 +182,6 @@ class CartesiaTTSService(AudioContextWordTTSService):
container: Audio container format.
params: Additional input parameters for voice customization.
text_aggregator: Custom text aggregator for processing input text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
aggregate_sentences: Whether to aggregate sentences within the TTSService.
**kwargs: Additional arguments passed to the parent service.
"""
@@ -271,18 +200,10 @@ class CartesiaTTSService(AudioContextWordTTSService):
push_text_frames=False,
pause_frame_processing=True,
sample_rate=sample_rate,
text_aggregator=text_aggregator,
text_aggregator=text_aggregator or SkipTagsAggregator([("<spell>", "</spell>")]),
**kwargs,
)
if not text_aggregator:
# Always skip tags added for spelled-out text
# Note: This is primarily to support backwards compatibility.
# The preferred way of taking advantage of Cartesia SSML Tags is
# to use an LLMTextProcessor and/or a text_transformer to identify
# and insert these tags for the purpose of the TTS service alone.
self._text_aggregator = SkipTagsAggregator([("<spell>", "</spell>")])
params = params or CartesiaTTSService.InputParams()
self._api_key = api_key
@@ -336,27 +257,6 @@ class CartesiaTTSService(AudioContextWordTTSService):
"""
return language_to_cartesia_language(language)
# A set of Cartesia-specific helpers for text transformations
def SPELL(text: str) -> str:
"""Wrap text in Cartesia spell tag."""
return f"<spell>{text}</spell>"
def EMOTION_TAG(emotion: CartesiaEmotion) -> str:
"""Convenience method to create an emotion tag."""
return f'<emotion value="{emotion}" />'
def PAUSE_TAG(seconds: float) -> str:
"""Convenience method to create a pause tag."""
return f'<break time="{seconds}s" />'
def VOLUME_TAG(volume: float) -> str:
"""Convenience method to create a volume tag."""
return f'<volume ratio="{volume}" />'
def SPEED_TAG(speed: float) -> str:
"""Convenience method to create a speed tag."""
return f'<speed ratio="{speed}" />'
def _is_cjk_language(self, language: str) -> bool:
"""Check if the given language is CJK (Chinese, Japanese, Korean).
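The helpers removed in this hunk build Cartesia tags as plain strings. Below is a minimal standalone sketch of the same idea. The `Emotion` enum and the lowercase function names are illustrative only and not part of pipecat. The sketch reads `.value` explicitly, since formatting a `str`-mixin enum member inside an f-string may yield the qualified member name on newer Python versions unless `__str__` is overridden.

```python
from enum import Enum


class Emotion(str, Enum):
    """Hypothetical subset of the emotion values listed in the diff."""

    NEUTRAL = "neutral"
    EXCITED = "excited"


def spell(text: str) -> str:
    """Wrap text in a spell tag so it is read out character by character."""
    return f"<spell>{text}</spell>"


def emotion_tag(emotion: Emotion) -> str:
    """Build an emotion tag; .value avoids depending on Enum formatting rules."""
    return f'<emotion value="{emotion.value}" />'


def pause_tag(seconds: float) -> str:
    """Build a pause tag with the duration expressed in seconds."""
    return f'<break time="{seconds}s" />'


# Compose a tagged utterance from the individual helpers.
tagged = emotion_tag(Emotion.EXCITED) + "Your code is " + spell("A1B2") + pause_tag(0.5)
```

Keeping these as module-level functions (rather than functions in a class body without `self`) makes them callable without an instance.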

View File

@@ -27,7 +27,6 @@ from pydantic import BaseModel, Field
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -1647,7 +1646,7 @@ class GeminiLiveLLMService(LLMService):
await self.push_frame(TTSStartedFrame())
await self.push_frame(LLMFullResponseStartFrame())
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.SENTENCE)
frame = TTSTextFrame(text=text)
# Gemini Live text already includes any necessary inter-chunk spaces
frame.includes_inter_frame_spaces = True

View File

@@ -19,7 +19,6 @@ from pipecat.adapters.services.open_ai_realtime_adapter import (
OpenAIRealtimeLLMAdapter,
)
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -687,7 +686,7 @@ class OpenAIRealtimeLLMService(LLMService):
# We receive audio transcript deltas (as opposed to text deltas) when
# the output modality is "audio" (the default)
if evt.delta:
frame = TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE)
frame = TTSTextFrame(evt.delta)
# OpenAI Realtime text already includes any necessary inter-chunk spaces
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
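The `includes_inter_frame_spaces` flag set above tells consumers whether the text chunks already carry their own spacing. As a rough illustration of how a downstream consumer might honor it, here is a sketch using a hypothetical `TextChunk` stand-in (not pipecat's `TTSTextFrame`) and a hypothetical `join_chunks` helper.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class TextChunk:
    """Hypothetical stand-in for a text frame carrying the spacing flag."""

    text: str
    includes_inter_frame_spaces: bool = False


def join_chunks(chunks: Iterable[TextChunk]) -> str:
    """Concatenate chunks, adding a space only when the chunk does not supply one."""
    out = ""
    for chunk in chunks:
        if out and not chunk.includes_inter_frame_spaces:
            out += " "
        out += chunk.text
    return out
```

With the flag set, deltas such as `"Hello"` and `" there"` are concatenated verbatim. Without it, a separator is inserted between chunks.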

View File

@@ -17,7 +17,6 @@ from loguru import logger
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -653,7 +652,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _handle_evt_audio_transcript_delta(self, evt):
if evt.delta:
await self.push_frame(LLMTextFrame(evt.delta))
await self.push_frame(TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE))
await self.push_frame(TTSTextFrame(evt.delta))
async def _handle_evt_speech_started(self, evt):
await self._truncate_current_audio_response()

View File

@@ -113,10 +113,6 @@ class RimeTTSService(AudioContextWordTTSService):
sample_rate: Audio sample rate in Hz.
params: Additional configuration parameters.
text_aggregator: Custom text aggregator for processing input text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
aggregate_sentences: Whether to aggregate sentences within the TTSService.
**kwargs: Additional arguments passed to parent class.
"""
@@ -127,17 +123,10 @@ class RimeTTSService(AudioContextWordTTSService):
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=sample_rate,
text_aggregator=text_aggregator or SkipTagsAggregator([("spell(", ")")]),
**kwargs,
)
if not text_aggregator:
# Always skip tags added for spelled-out text
# Note: This is primarily to support backwards compatibility.
# The preferred way of taking advantage of Rime spelling is
# to use an LLMTextProcessor and/or a text_transformer to identify
# and insert these tags for the purpose of the TTS service alone.
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
params = params or RimeTTSService.InputParams()
# Store service configuration
@@ -163,7 +152,6 @@ class RimeTTSService(AudioContextWordTTSService):
self._context_id = None # Tracks current turn
self._receive_task = None
self._cumulative_time = 0 # Accumulates time across messages
self._extra_msg_fields = {} # Extra fields for next message
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -193,31 +181,6 @@ class RimeTTSService(AudioContextWordTTSService):
self._model = model
await super().set_model(model)
# A set of Rime-specific helpers for text transformations
def SPELL(text: str) -> str:
"""Wrap text in Rime spell function."""
return f"spell({text})"
def PAUSE_TAG(seconds: float) -> str:
"""Convenience method to create a pause tag."""
return f"<{seconds * 1000}>"
def PRONOUNCE(self, text: str, word: str, phoneme: str) -> str:
"""Convenience method to support Rime's custom pronunciations feature.
https://docs.rime.ai/api-reference/custom-pronunciation
"""
self._extra_msg_fields["phonemizeBetweenBrackets"] = True
return text.replace(word, f"{phoneme}")
def INLINE_SPEED(self, text: str, speed: float) -> str:
"""Convenience method to support inline speeds."""
if not self._extra_msg_fields:
self._extra_msg_fields = {}
speed_vals = self._extra_msg_fields.get("inlineSpeedAlpha", "").split(",")
self._extra_msg_fields["inlineSpeedAlpha"] = ",".join(speed_vals + [str(speed)])
return f"[{text}]"
async def _update_settings(self, settings: Mapping[str, Any]):
"""Update service settings and reconnect if voice changed."""
prev_voice = self._voice_id
@@ -230,11 +193,7 @@ class RimeTTSService(AudioContextWordTTSService):
def _build_msg(self, text: str = "") -> dict:
"""Build JSON message for Rime API."""
msg = {"text": text, "contextId": self._context_id}
if self._extra_msg_fields:
msg |= self._extra_msg_fields
self._extra_msg_fields = {}
return msg
return {"text": text, "contextId": self._context_id}
def _build_clear_msg(self) -> dict:
"""Build clear operation message."""
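The removed `_build_msg` variant merged one-shot extra fields into the next outgoing message and then cleared them. A minimal sketch of that pattern, assuming a hypothetical `MessageBuilder` class and `set_extra` method that do not exist in pipecat:

```python
from typing import Any, Dict, Optional


class MessageBuilder:
    """Hypothetical reduction of the one-shot extra-field logic."""

    def __init__(self, context_id: Optional[str] = None) -> None:
        self._context_id = context_id
        self._extra_msg_fields: Dict[str, Any] = {}

    def set_extra(self, key: str, value: Any) -> None:
        """Queue a field that should be attached to the next message only."""
        self._extra_msg_fields[key] = value

    def build_msg(self, text: str = "") -> Dict[str, Any]:
        """Build the message, merging queued fields and then clearing them."""
        msg: Dict[str, Any] = {"text": text, "contextId": self._context_id}
        if self._extra_msg_fields:
            msg.update(self._extra_msg_fields)
            self._extra_msg_fields = {}
        return msg
```

Because the queued fields are reset after each build, a setting such as `phonemizeBetweenBrackets` applies to a single message and does not leak into later ones.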

View File

@@ -12,8 +12,6 @@ from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Mapping,
@@ -25,8 +23,6 @@ from typing import (
from loguru import logger
from pipecat.frames.frames import (
AggregatedTextFrame,
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -105,16 +101,6 @@ class TTSService(AIService):
sample_rate: Optional[int] = None,
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
text_aggregator: Optional[BaseTextAggregator] = None,
# Types of text aggregations that should not be spoken.
skip_aggregator_types: Optional[List[str]] = [],
# A list of callables to transform text just before sending it to TTS.
# Each callable takes the aggregated text and its type, and returns the transformed text.
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
text_transforms: Optional[
List[
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
]
] = None,
# Text filter executed after text has been aggregated.
text_filters: Optional[Sequence[BaseTextFilter]] = None,
text_filter: Optional[BaseTextFilter] = None,
@@ -134,16 +120,6 @@ class TTSService(AIService):
pause_frame_processing: Whether to pause frame processing during audio generation.
sample_rate: Output sample rate for generated audio.
text_aggregator: Custom text aggregator for processing incoming text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
skip_aggregator_types: List of aggregation types that should not be spoken.
text_transforms: A list of callables to transform text just before sending it
to TTS. Each callable takes the aggregated text and its type, and returns the
transformed text. To register, provide a list of tuples of
(aggregation_type | '*', transform_function).
text_filters: Sequence of text filters to apply after aggregation.
text_filter: Single text filter (deprecated, use text_filters).
@@ -166,21 +142,6 @@ class TTSService(AIService):
self._voice_id: str = ""
self._settings: Dict[str, Any] = {}
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
if text_aggregator:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter 'text_aggregator' is deprecated. Use an LLMTextProcessor before the TTSService for custom text aggregation.",
DeprecationWarning,
)
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
self._text_transforms: List[
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
] = text_transforms or []
# TODO: Deprecate _text_filters when added to LLMTextProcessor
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
self._transport_destination: Optional[str] = transport_destination
self._tracing_enabled: bool = False
@@ -337,39 +298,6 @@ class TTSService(AIService):
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
def add_text_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Transform text for a specific aggregation type.
Args:
transform_function: The function to apply for transformation. This function should take
the text and aggregation type as input and return the transformed text.
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
aggregation_type: The type of aggregation to transform. This value defaults to "*" indicating
the function should handle all text before sending to TTS.
"""
self._text_transforms.append((aggregation_type, transform_function))
def remove_text_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Remove a text transformer for a specific aggregation type.
Args:
transform_function: The function to remove.
aggregation_type: The type of aggregation to remove the transformer for.
"""
self._text_transforms = [
(agg_type, func)
for agg_type, func in self._text_transforms
if not (agg_type == aggregation_type and func == transform_function)
]
async def _update_settings(self, settings: Mapping[str, Any]):
for key, value in settings.items():
if key in self._settings:
@@ -425,8 +353,6 @@ class TTSService(AIService):
and frame.skip_tts
):
await self.push_frame(frame, direction)
elif isinstance(frame, AggregatedTextFrame):
await self._push_tts_frames(frame)
elif (
isinstance(frame, TextFrame)
and not isinstance(frame, InterimTranscriptionFrame)
@@ -442,10 +368,10 @@ class TTSService(AIService):
# pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
aggregate = self._text_aggregator.text
sentence = self._text_aggregator.text
await self._text_aggregator.reset()
self._processing_text = False
await self._push_tts_frames(AggregatedTextFrame(aggregate.text, aggregate.type))
await self._push_tts_frames(sentence)
if isinstance(frame, LLMFullResponseEndFrame):
if self._push_text_frames:
await self.push_frame(frame, direction)
@@ -454,7 +380,7 @@ class TTSService(AIService):
elif isinstance(frame, TTSSpeakFrame):
# Store if we were processing text or not so we can set it back.
processing_text = self._processing_text
await self._push_tts_frames(AggregatedTextFrame(frame.text, AggregationType.SENTENCE))
await self._push_tts_frames(frame.text)
# We pause processing incoming frames because we are sending data to
# the TTS. We pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
@@ -546,24 +472,13 @@ class TTSService(AIService):
text: Optional[str] = None
if not self._aggregate_sentences:
text = frame.text
aggregated_by = "token"
else:
aggregate = await self._text_aggregator.aggregate(frame.text)
if aggregate:
text = aggregate.text
aggregated_by = aggregate.type
text = await self._text_aggregator.aggregate(frame.text)
if text:
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(AggregatedTextFrame(text, aggregated_by))
async def _push_tts_frames(self, src_frame: AggregatedTextFrame):
type = src_frame.aggregated_by
text = src_frame.text
if type in self._skip_aggregator_types:
await self.push_frame(src_frame)
return
await self._push_tts_frames(text)
async def _push_tts_frames(self, text: str):
# Remove leading newlines only
text = text.lstrip("\n")
@@ -584,39 +499,15 @@ class TTSService(AIService):
await filter.reset_interruption()
text = await filter.filter(text)
if not text.strip():
await self.stop_processing_metrics()
return
# To support use cases that may want to know the text before it's spoken, we
# push the AggregatedTextFrame version before transforming and sending to TTS.
# However, we do not want to add this text to the assistant context until it
# is spoken, so we set append_to_context to False.
src_frame.append_to_context = False
await self.push_frame(src_frame)
# Note: Text transformations are meant to only affect the text sent to the TTS for
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
# services that support word-level timestamps, this CAN affect the resulting context
# since the TTSTextFrames are generated from the TTS output stream
transformed_text = text
for aggregation_type, transform in self._text_transforms:
if aggregation_type == type or aggregation_type == "*":
transformed_text = await transform(transformed_text, type)
await self.process_generator(self.run_tts(transformed_text))
if text:
await self.process_generator(self.run_tts(text))
await self.stop_processing_metrics()
if self._push_text_frames:
# In TTS services that support word timestamps, the TTSTextFrames
# are pushed as words are spoken. However, in the case where the TTS service
# does not support word timestamps (i.e. _push_text_frames is True), we send
# the original (non-transformed) text after the TTS generation has completed.
# This way, if we are interrupted, the text is not added to the assistant
# context and the context that IS added does not include TTS-specific tags
# or transformations.
frame = TTSTextFrame(text, aggregated_by=type)
# We send the original text after the audio. This way, if we are
# interrupted, the text is not added to the assistant context.
frame = TTSTextFrame(text)
frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces
await self.push_frame(frame)
@@ -744,7 +635,7 @@ class WordTTSService(TTSService):
frame = TTSStoppedFrame()
frame.pts = last_pts
else:
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
frame = TTSTextFrame(word)
frame.pts = self._initial_word_timestamp + timestamp
if frame:
last_pts = frame.pts
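The removed `text_transforms` mechanism applied each registered async callable whose aggregation type matched the text's type or the wildcard `"*"`, in registration order. A self-contained sketch of that dispatch loop follows. `apply_transforms`, `spell_at_sign` and `shout` are illustrative names, not pipecat APIs.

```python
import asyncio
from typing import Awaitable, Callable, List, Tuple

# A transform takes (text, aggregation_type) and returns the transformed text.
Transform = Callable[[str, str], Awaitable[str]]


async def apply_transforms(
    text: str, aggregation_type: str, transforms: List[Tuple[str, Transform]]
) -> str:
    """Run every transform registered for this type or for the "*" wildcard."""
    result = text
    for registered_type, transform in transforms:
        if registered_type == aggregation_type or registered_type == "*":
            result = await transform(result, aggregation_type)
    return result


async def spell_at_sign(text: str, aggregation_type: str) -> str:
    """Replace '@' with a spoken form for the TTS engine."""
    return text.replace("@", " at ")


async def shout(text: str, aggregation_type: str) -> str:
    """Upper-case the text; registered only for the 'code' type below."""
    return text.upper()


transforms = [("*", spell_at_sign), ("code", shout)]
spoken = asyncio.run(apply_transforms("mail me@example.com", "sentence", transforms))
```

Here `shout` is skipped for `"sentence"` text because it is registered only for the `"code"` type, while the wildcard transform always runs.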

View File

@@ -203,16 +203,8 @@ async def run_test(
if not isinstance(frame, EndFrame) or not send_end_frame:
received_down_frames.append(frame)
down_frames_printed = "["
for frame in received_down_frames:
down_frames_printed += f"{frame.__class__.__name__}, "
down_frames_printed += "]"
expected_frames_printed = "["
for frame in expected_down_frames:
expected_frames_printed += f"{frame.__name__}, "
expected_frames_printed += "]"
print("received DOWN frames =", down_frames_printed)
print("expected DOWN frames =", expected_frames_printed)
print("received DOWN frames =", received_down_frames)
print("expected DOWN frames =", expected_down_frames)
assert len(received_down_frames) == len(expected_down_frames)

View File

@@ -12,46 +12,9 @@ aggregated text should be sent for speech synthesis.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class AggregationType(str, Enum):
"""Built-in aggregation strings."""
SENTENCE = "sentence"
WORD = "word"
def __str__(self):
return self.value
@dataclass
class Aggregation:
"""Data class representing aggregated text and its type.
An Aggregation object is created whenever a stream of text is aggregated by
a text aggregator. It contains the aggregated text and a type indicating
the nature of the aggregation.
Parameters:
text: The aggregated text content.
type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
"""
text: str
type: str
def __str__(self) -> str:
"""Return a string representation of the aggregation.
Returns:
A descriptive string showing the type and text of the aggregation.
"""
return f"Aggregation by {self.type}: {self.text}"
class BaseTextAggregator(ABC):
"""Base class for text aggregators in the Pipecat framework.
@@ -67,7 +30,7 @@ class BaseTextAggregator(ABC):
@property
@abstractmethod
def text(self) -> Aggregation:
def text(self) -> str:
"""Get the currently aggregated text.
Subclasses must implement this property to return the text that has
@@ -79,13 +42,12 @@ class BaseTextAggregator(ABC):
pass
@abstractmethod
async def aggregate(self, text: str) -> Optional[Aggregation]:
async def aggregate(self, text: str) -> Optional[str]:
"""Aggregate the specified text with the currently accumulated text.
This method should be implemented to define how the new text contributes
to the aggregation process. It returns the aggregated text and a string
describing how it was aggregated if it's ready to be processed,
or None otherwise.
to the aggregation process. It returns the updated aggregated text if
it's ready to be processed, or None otherwise.
Subclasses should implement their specific logic for:

View File

@@ -8,41 +8,19 @@
This module provides an aggregator that identifies and processes content between
pattern pairs (like XML tags or custom delimiters) in streaming text, with
support for custom handlers and configurable actions for when a pattern is found.
support for custom handlers and configurable pattern removal.
"""
import re
from enum import Enum
from typing import Awaitable, Callable, List, Optional, Tuple
from typing import Awaitable, Callable, Optional, Tuple
from loguru import logger
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
class MatchAction(Enum):
"""Actions to take when a pattern pair is matched.
Parameters:
REMOVE: The text along with its delimiters will be removed from the streaming text.
Sentence aggregation will continue on as if this text did not exist.
KEEP: The delimiters will be removed, but the content between them will be kept.
Sentence aggregation will continue on with the internal text included.
AGGREGATE: The delimiters will be removed and the content between will be treated
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
will be returned. Then the aggregation will continue on sentence matching after
the closing delimiter is found. The content between the delimiters is not
aggregated by sentence. It is aggregated as one single block of text.
"""
REMOVE = "remove"
KEEP = "keep"
AGGREGATE = "aggregate"
class PatternMatch(Aggregation):
class PatternMatch:
"""Represents a matched pattern pair with its content.
A PatternMatch object is created when a complete pattern pair is found
@@ -51,25 +29,25 @@ class PatternMatch(Aggregation):
content between the patterns.
"""
def __init__(self, content: str, type: str, full_match: str):
def __init__(self, pattern_id: str, full_match: str, content: str):
"""Initialize a pattern match.
Args:
type: The type of the matched pattern pair. It should be representative
of the content type (e.g., 'sentence', 'code', 'speaker', 'custom').
pattern_id: The identifier of the matched pattern pair.
full_match: The complete text including start and end patterns.
content: The text content between the start and end patterns.
"""
super().__init__(text=content, type=type)
self.pattern_id = pattern_id
self.full_match = full_match
self.content = content
def __str__(self) -> str:
"""Return a string representation of the pattern match.
Returns:
A descriptive string showing the pattern type and content.
A descriptive string showing the pattern ID and content.
"""
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
return f"PatternMatch(id={self.pattern_id}, content={self.content})"
class PatternPairAggregator(BaseTextAggregator):
@@ -77,21 +55,16 @@ class PatternPairAggregator(BaseTextAggregator):
This aggregator buffers text until it can identify complete pattern pairs
(defined by start and end patterns), processes the content between these
patterns using registered handlers. By default, its aggregation method
returns text at sentence boundaries and removes the content found between
any matched patterns. However, matched patterns can also be configured to
be returned as a separate aggregation object containing the content between
their start and end patterns or left in, so that only the delimiters are
removed and a callback can be triggered.
This aggregator is particularly useful for processing structured content in
streaming text, such as XML tags, markdown formatting, or custom delimiters.
patterns using registered handlers, and returns text at sentence boundaries.
It's particularly useful for processing structured content in streaming text,
such as XML tags, markdown formatting, or custom delimiters.
The aggregator ensures that patterns spanning multiple text chunks are
correctly identified.
correctly identified and handles cases where patterns contain sentence
boundaries.
"""
def __init__(self, **kwargs):
def __init__(self):
"""Initialize the pattern pair aggregator.
Creates an empty aggregator with no patterns or handlers registered.
@@ -102,23 +75,16 @@ class PatternPairAggregator(BaseTextAggregator):
self._handlers = {}
@property
def text(self) -> Aggregation:
"""Get the currently aggregated text.
def text(self) -> str:
"""Get the currently buffered text.
Returns:
The text that has been accumulated in the buffer.
The current text buffer content that hasn't been processed yet.
"""
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start:
return Aggregation(self._text, pattern_start[1].get("type", AggregationType.SENTENCE))
return Aggregation(self._text, AggregationType.SENTENCE)
return self._text
def add_pattern(
self,
type: str,
start_pattern: str,
end_pattern: str,
action: MatchAction = MatchAction.REMOVE,
def add_pattern_pair(
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -127,94 +93,41 @@ class PatternPairAggregator(BaseTextAggregator):
the end pattern, and treat the content between them as a match.
Args:
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
those are reserved for the default behavior.
pattern_id: Unique identifier for this pattern pair.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
action: What to do when a complete pattern is matched:
- MatchAction.REMOVE: Remove the matched pattern from the text.
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
normal text. This allows you to register handlers for
the pattern without affecting the aggregation logic.
- MatchAction.AGGREGATE: Return the matched pattern as a separate
aggregation object.
remove_match: Whether to remove the matched content from the text.
Returns:
Self for method chaining.
"""
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
raise ValueError(
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
)
self._patterns[type] = {
self._patterns[pattern_id] = {
"start": start_pattern,
"end": end_pattern,
"type": type,
"action": action,
"remove_match": remove_match,
}
return self
def add_pattern_pair(
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
):
"""Add a pattern pair to detect in the text.
.. deprecated:: 0.0.95
This function is deprecated and will be removed in a future version.
Use `add_pattern` with a type and MatchAction instead.
This method calls `add_pattern` setting type with the provided pattern_id and action
to either MatchAction.REMOVE or MatchAction.KEEP based on `remove_match`.
Args:
pattern_id: Identifier for this pattern pair. Should be unique and ideally descriptive.
(e.g., 'code', 'speaker', 'custom'). pattern_id can not be 'sentence' or 'word'
as those are reserved for the default behavior.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
remove_match: If True, the matched pattern will be removed from the text. (Same as MatchAction.REMOVE)
If False, it will be kept and treated as normal text. (Same as MatchAction.KEEP)
"""
import warnings
with warnings.catch_warnings():
warnings.simplefilter("once")
warnings.warn(
"add_pattern_pair with a pattern_id or remove_match is deprecated and will be"
" removed in a future version. Use add_pattern with a type and MatchAction instead",
DeprecationWarning,
stacklevel=2,
)
action = MatchAction.REMOVE if remove_match else MatchAction.KEEP
return self.add_pattern(
type=pattern_id,
start_pattern=start_pattern,
end_pattern=end_pattern,
action=action,
)
def on_pattern_match(
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
) -> "PatternPairAggregator":
"""Register a handler for when a pattern pair is matched.
The handler will be called whenever a complete match for the
specified type is found in the text.
specified pattern ID is found in the text.
Args:
type: The type of the pattern pair to trigger the handler.
pattern_id: ID of the pattern pair to match.
handler: Async function to call when pattern is matched.
The function should accept a PatternMatch object.
Returns:
Self for method chaining.
"""
self._handlers[type] = handler
self._handlers[pattern_id] = handler
return self
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
async def _process_complete_patterns(self, text: str) -> Tuple[str, bool]:
"""Process all complete pattern pairs in the text.
Searches for all complete pattern pairs in the text, calls the
@@ -224,19 +137,19 @@ class PatternPairAggregator(BaseTextAggregator):
text: The text to process for pattern matches.
Returns:
Tuple of (all_matches, processed_text) where:
Tuple of (processed_text, was_modified) where:
- all_matches is a list of all pattern matches found. Note: There really should only ever be 1.
- processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text.
- processed_text is the text after processing patterns
- was_modified indicates whether any changes were made
"""
all_matches = []
processed_text = text
modified = False
for type, pattern_info in self._patterns.items():
for pattern_id, pattern_info in self._patterns.items():
# Escape special regex characters in the patterns
start = re.escape(pattern_info["start"])
end = re.escape(pattern_info["end"])
action = pattern_info["action"]
remove_match = pattern_info["remove_match"]
# Create regex to match from start pattern to end pattern
# The .*? is non-greedy so each match stops at the nearest end pattern
@@ -251,24 +164,25 @@ class PatternPairAggregator(BaseTextAggregator):
full_match = match.group(0) # Full match including patterns
# Create pattern match object
pattern_match = PatternMatch(content=content, type=type, full_match=full_match)
pattern_match = PatternMatch(
pattern_id=pattern_id, full_match=full_match, content=content
)
# Call the appropriate handler if registered
if type in self._handlers:
if pattern_id in self._handlers:
try:
await self._handlers[type](pattern_match)
await self._handlers[pattern_id](pattern_match)
except Exception as e:
logger.error(f"Error in pattern handler for {type}: {e}")
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
# Remove the pattern from the text if configured
if action == MatchAction.REMOVE:
if remove_match:
processed_text = processed_text.replace(full_match, "", 1)
else:
all_matches.append(pattern_match)
modified = True
return all_matches, processed_text
return processed_text, modified
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
def _has_incomplete_patterns(self, text: str) -> bool:
"""Check if text contains incomplete pattern pairs.
Determines whether the text contains any start patterns without
@@ -278,10 +192,9 @@ class PatternPairAggregator(BaseTextAggregator):
text: The text to check for incomplete patterns.
Returns:
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
or None if no patterns are found or all patterns are complete.
True if there are incomplete patterns, False otherwise.
"""
for type, pattern_info in self._patterns.items():
for pattern_id, pattern_info in self._patterns.items():
start = pattern_info["start"]
end = pattern_info["end"]
@@ -290,16 +203,12 @@ class PatternPairAggregator(BaseTextAggregator):
end_count = text.count(end)
# If there are more starts than ends, we have incomplete patterns
# Again, this is written generically but there only ever should
# be one pattern active at a time, so the counts should be 0 or 1.
# Which is why we base the return on the first found.
if start_count > end_count:
start_index = text.find(start)
return [start_index, pattern_info]
return True
return None
return False
async def aggregate(self, text: str) -> Optional[PatternMatch]:
async def aggregate(self, text: str) -> Optional[str]:
"""Aggregate text and process pattern pairs.
This method adds the new text to the buffer, processes any complete pattern
@@ -318,34 +227,16 @@ class PatternPairAggregator(BaseTextAggregator):
self._text += text
# Process any complete patterns in the buffer
patterns, processed_text = await self._process_complete_patterns(self._text)
processed_text, modified = await self._process_complete_patterns(self._text)
self._text = processed_text
if len(patterns) > 0:
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
return patterns[0]
# Only update the buffer if modifications were made
if modified:
self._text = processed_text
# Check if we have incomplete patterns
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, return None
if (
pattern_start[0] == 0
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
):
return None
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
if self._has_incomplete_patterns(self._text):
# Still waiting for complete patterns
return None
# Find sentence boundary if no incomplete patterns
eos_marker = match_endofsentence(self._text)
@@ -353,7 +244,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
return result
# No complete sentence found yet
return None
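The simplified incomplete-pattern check in this hunk can be sketched in isolation. This is a minimal standalone illustration, not the library's code: the free function `has_incomplete_patterns` and the plain-dict pattern registry are assumptions made for the example, mirroring the `start`/`end` keys visible in the diff.

```python
from typing import Dict


def has_incomplete_patterns(text: str, patterns: Dict[str, Dict[str, str]]) -> bool:
    """Return True if any start delimiter lacks a matching end delimiter."""
    for pattern_info in patterns.values():
        start_count = text.count(pattern_info["start"])
        end_count = text.count(pattern_info["end"])
        # More starts than ends means at least one pattern is still open.
        if start_count > end_count:
            return True
    return False


# Hypothetical registry shaped like the aggregator's internal pattern table.
patterns = {"test_pattern": {"start": "<test>", "end": "</test>"}}
open_result = has_incomplete_patterns("Hello <test>pattern", patterns)
closed_result = has_incomplete_patterns("Hello <test>pattern content</test>!", patterns)
```

Returning a plain boolean, as the new code does, is enough when the caller only needs to decide whether to keep buffering; the older tuple return carried the start index so that text before the pattern could be flushed early.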


@@ -14,7 +14,7 @@ text processing scenarios.
from typing import Optional
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
class SimpleTextAggregator(BaseTextAggregator):
@@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator):
self._text = ""
@property
def text(self) -> Aggregation:
def text(self) -> str:
"""Get the currently aggregated text.
Returns:
The text that has been accumulated in the buffer.
"""
return Aggregation(self._text, AggregationType.SENTENCE)
return self._text
async def aggregate(self, text: str) -> Optional[Aggregation]:
async def aggregate(self, text: str) -> Optional[str]:
"""Aggregate text and return completed sentences.
Adds the new text to the buffer and checks for end-of-sentence markers.
@@ -64,7 +64,7 @@ class SimpleTextAggregator(BaseTextAggregator):
result = self._text[:eos_end_marker]
self._text = self._text[eos_end_marker:]
return Aggregation(result, AggregationType.SENTENCE) if result else None
return result
async def handle_interruption(self):
"""Handle interruptions by clearing the text buffer.


@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
from typing import Optional, Sequence
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
class SkipTagsAggregator(BaseTextAggregator):
@@ -49,9 +49,9 @@ class SkipTagsAggregator(BaseTextAggregator):
Returns:
The current text buffer content that hasn't been processed yet.
"""
return Aggregation(self._text, AggregationType.SENTENCE)
return self._text
async def aggregate(self, text: str) -> Optional[Aggregation]:
async def aggregate(self, text: str) -> Optional[str]:
"""Aggregate text while respecting tag boundaries.
This method adds the new text to the buffer, processes any complete
@@ -80,7 +80,7 @@ class SkipTagsAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return Aggregation(result, AggregationType.SENTENCE)
return result
# No complete sentence found yet
return None


@@ -23,7 +23,7 @@ if TYPE_CHECKING:
from opentelemetry import context as context_api
from opentelemetry import trace
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.utils.tracing.service_attributes import (
add_gemini_live_span_attributes,
@@ -399,11 +399,6 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
if hasattr(self, "get_llm_adapter"):
adapter = self.get_llm_adapter()
messages = adapter.get_messages_for_logging(context)
elif hasattr(context, "get_messages"):
# Fallback for unknown context types
messages = context.get_messages()
elif hasattr(context, "messages"):
messages = context.messages
# Serialize messages if available
if messages:
@@ -424,15 +419,10 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
if hasattr(self, "get_llm_adapter") and hasattr(context, "tools"):
adapter = self.get_llm_adapter()
tools = adapter.from_standard_tools(context.tools)
elif hasattr(context, "tools"):
# Fallback for unknown context types
tools = context.tools
# Serialize and count tools if available
# Check if tools is not None and not NOT_GIVEN (using attribute check as fallback)
if tools is not None and not (
hasattr(tools, "__name__") and tools.__name__ == "NOT_GIVEN"
):
# Check if tools is not None and not NOT_GIVEN
if tools is not None and tools is not NOT_GIVEN:
serialized_tools = json.dumps(tools)
tool_count = len(tools) if isinstance(tools, list) else 1
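The switch from a name-based check to an identity check against `NOT_GIVEN` can be illustrated with a small sketch. The sentinel class and the helper `serialize_tools` below are hypothetical stand-ins written for this example; the real `NOT_GIVEN` is the one imported from `pipecat.processors.aggregators.llm_context` in the hunk above.

```python
import json
from typing import Any, Optional, Tuple


class _NotGiven:
    """Hypothetical sentinel type standing in for the imported NOT_GIVEN."""

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = _NotGiven()


def serialize_tools(tools: Any) -> Optional[Tuple[str, int]]:
    """Return (serialized_tools, tool_count), or None when no tools were supplied."""
    # Identity comparison is exact: only the sentinel object itself is skipped.
    if tools is not None and tools is not NOT_GIVEN:
        serialized_tools = json.dumps(tools)
        tool_count = len(tools) if isinstance(tools, list) else 1
        return serialized_tools, tool_count
    return None


skipped = serialize_tools(NOT_GIVEN)
counted = serialize_tools([{"name": "get_weather"}])
```

Comparing with `is not` avoids the false positives a `__name__` lookup could produce for any unrelated object that happens to carry the same name.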


@@ -7,42 +7,30 @@
import unittest
from unittest.mock import AsyncMock
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.aggregator = PatternPairAggregator()
self.test_handler = AsyncMock()
self.code_handler = AsyncMock()
# Add a test pattern
self.aggregator.add_pattern_pair(
pattern_id="test_pattern",
start_pattern="<test>",
end_pattern="</test>",
)
self.aggregator.add_pattern(
type="code_pattern",
start_pattern="<code>",
end_pattern="</code>",
action=MatchAction.AGGREGATE,
remove_match=True,
)
# Register the mock handler
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
async def test_pattern_match_and_removal(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Hello <test>pattern")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.type, "test_pattern")
self.assertEqual(self.aggregator.text, "Hello <test>pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</test>!")
@@ -51,49 +39,20 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.type, "test_pattern")
self.assertEqual(call_args.pattern_id, "test_pattern")
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
self.assertEqual(call_args.text, "pattern content")
self.assertEqual(call_args.content, "pattern content")
# The exclamation point should be treated as a sentence boundary,
# so the result should include just text up to and including "!"
self.assertEqual(result.text, "Hello !")
self.assertEqual(result.type, "sentence")
self.assertEqual(result, "Hello !")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
self.assertEqual(result.text, " This is another sentence.")
self.assertEqual(result, " This is another sentence.")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text.text, "")
async def test_pattern_match_and_aggregate(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Here is code <code>pattern")
self.assertEqual(result.text, "Here is code ")
self.assertEqual(self.aggregator.text.text, "<code>pattern")
self.assertEqual(self.aggregator.text.type, "code_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</code>")
# Verify the handler was called with correct PatternMatch object
self.code_handler.assert_called_once()
call_args = self.code_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.type, "code_pattern")
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
self.assertEqual(call_args.text, "pattern content")
self.assertEqual(result.text, "pattern content")
self.assertEqual(result.type, "code_pattern")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
self.assertEqual(result.text, " This is another sentence.")
self.assertEqual(result.type, "sentence")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text, "")
async def test_incomplete_pattern(self):
# Add text with incomplete pattern
@@ -106,30 +65,26 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_not_called()
# Buffer should contain the incomplete text
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
self.assertEqual(self.aggregator.text.type, "test_pattern")
self.assertEqual(self.aggregator.text, "Hello <test>pattern content")
# Reset and confirm buffer is cleared
await self.aggregator.reset()
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text, "")
async def test_multiple_patterns(self):
# Set up multiple patterns and handlers
voice_handler = AsyncMock()
emphasis_handler = AsyncMock()
self.aggregator.add_pattern(
type="voice",
start_pattern="<voice>",
end_pattern="</voice>",
action=MatchAction.REMOVE,
self.aggregator.add_pattern_pair(
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
)
self.aggregator.add_pattern(
type="emphasis",
self.aggregator.add_pattern_pair(
pattern_id="emphasis",
start_pattern="<em>",
end_pattern="</em>",
action=MatchAction.KEEP, # Keep emphasis tags
remove_match=False, # Keep emphasis tags
)
self.aggregator.on_pattern_match("voice", voice_handler)
@@ -142,19 +97,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Both handlers should be called with correct data
voice_handler.assert_called_once()
voice_match = voice_handler.call_args[0][0]
self.assertEqual(voice_match.type, "voice")
self.assertEqual(voice_match.text, "female")
self.assertEqual(voice_match.pattern_id, "voice")
self.assertEqual(voice_match.content, "female")
emphasis_handler.assert_called_once()
emphasis_match = emphasis_handler.call_args[0][0]
self.assertEqual(emphasis_match.type, "emphasis")
self.assertEqual(emphasis_match.text, "very")
self.assertEqual(emphasis_match.pattern_id, "emphasis")
self.assertEqual(emphasis_match.content, "very")
# Voice pattern should be removed, emphasis pattern should remain
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
# Buffer should be empty
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text, "")
async def test_handle_interruption(self):
# Start with incomplete pattern
@@ -165,7 +120,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
await self.aggregator.handle_interruption()
# Buffer should be cleared
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text, "")
# Handler should not have been called
self.test_handler.assert_not_called()
@@ -183,10 +138,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Handler should be called with entire content
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
# Pattern should be removed, resulting in text with sentences merged
self.assertEqual(result.text, "Hello Final sentence.")
self.assertEqual(result, "Hello Final sentence.")
# Buffer should be empty
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text, "")
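The behaviour these tests exercise (handlers are called for each complete pattern, and the match is removed only when `remove_match` is true) can be sketched as a standalone function. This is an illustrative approximation, assuming a regex-based search; the function name, the dict layout, and the callback signature are inventions for the example and not the aggregator's actual internals. It returns a `(processed_text, modified)` pair, matching the shape used in the updated `aggregate()`.

```python
import re
from typing import Callable, Dict, List, Tuple


def process_complete_patterns(
    text: str,
    patterns: Dict[str, dict],
    on_match: Callable[[str, str, str], None],
) -> Tuple[str, bool]:
    """Call on_match for each complete pattern; strip matches flagged for removal."""
    modified = False
    for pattern_id, info in patterns.items():
        regex = re.compile(
            re.escape(info["start"]) + r"(.*?)" + re.escape(info["end"]), re.DOTALL
        )
        for match in regex.finditer(text):
            # Arguments: pattern_id, inner content, full match including delimiters.
            on_match(pattern_id, match.group(1), match.group(0))
        if info["remove_match"]:
            text, count = regex.subn("", text)
            modified = modified or count > 0
    return text, modified


# Hypothetical pattern table mirroring the "voice" and "emphasis" test setup.
patterns = {
    "voice": {"start": "<voice>", "end": "</voice>", "remove_match": True},
    "emphasis": {"start": "<em>", "end": "</em>", "remove_match": False},
}
seen: List[Tuple[str, str]] = []
processed, modified = process_complete_patterns(
    "Hello <voice>female</voice>I am <em>very</em> excited to meet you!",
    patterns,
    lambda pattern_id, content, full_match: seen.append((pattern_id, content)),
)
```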


@@ -13,7 +13,6 @@ import pytest
from aiohttp import web
from pipecat.frames.frames import (
AggregatedTextFrame,
ErrorFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
@@ -75,7 +74,6 @@ async def test_run_piper_tts_success(aiohttp_client):
]
expected_returned_frames = [
AggregatedTextFrame,
TTSStartedFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
@@ -123,7 +121,7 @@ async def test_run_piper_tts_error(aiohttp_client):
TTSSpeakFrame(text="Error case."),
]
expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame]
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
expected_up_frames = [ErrorFrame]


@@ -15,20 +15,15 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reset_aggregations(self):
assert await self.aggregator.aggregate("Hello ") == None
assert self.aggregator.text.text == "Hello "
assert self.aggregator.text == "Hello "
await self.aggregator.reset()
assert self.aggregator.text.text == ""
assert self.aggregator.text == ""
async def test_simple_sentence(self):
assert await self.aggregator.aggregate("Hello ") == None
aggregate = await self.aggregator.aggregate("Pipecat!")
assert aggregate.text == "Hello Pipecat!"
assert aggregate.type == "sentence"
assert self.aggregator.text.text == ""
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
assert self.aggregator.text == ""
async def test_multiple_sentences(self):
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
assert aggregate.text == "Hello Pipecat!"
assert self.aggregator.text.text == " How are "
aggregate = await self.aggregator.aggregate("you?")
assert aggregate.text == " How are you?"
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
assert await self.aggregator.aggregate("you?") == " How are you?"
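The string-returning contract checked by these tests can be reproduced with a minimal aggregator sketch. Everything here is an assumption for illustration: `SketchSentenceAggregator` is not the library class, and `find_end_of_sentence` is a deliberately naive regex stand-in for `match_endofsentence`, which handles far more cases.

```python
import asyncio
import re
from typing import Optional

# Naive stand-in for pipecat's match_endofsentence: first '.', '!' or '?'.
_EOS = re.compile(r"[.!?]")


def find_end_of_sentence(text: str) -> int:
    """Return the index just past the first sentence terminator, or 0 if none."""
    match = _EOS.search(text)
    return match.end() if match else 0


class SketchSentenceAggregator:
    """Buffers text and returns one completed sentence at a time as a plain str."""

    def __init__(self) -> None:
        self._text = ""

    @property
    def text(self) -> str:
        return self._text

    async def aggregate(self, text: str) -> Optional[str]:
        self._text += text
        eos = find_end_of_sentence(self._text)
        if not eos:
            return None  # No complete sentence yet; keep buffering.
        result = self._text[:eos]
        self._text = self._text[eos:]
        return result


async def demo():
    agg = SketchSentenceAggregator()
    first = await agg.aggregate("Hello ")
    second = await agg.aggregate("Pipecat!")
    return first, second, agg.text


demo_result = asyncio.run(demo())
```

With plain strings as the return type, callers compare results directly, as the updated assertions above do, instead of unpacking `.text` and `.type` from a wrapper object.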


@@ -18,18 +18,16 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
# No tags involved, aggregate at end of sentence.
result = await self.aggregator.aggregate("Hello Pipecat!")
self.assertEqual(result.text, "Hello Pipecat!")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(result, "Hello Pipecat!")
self.assertEqual(self.aggregator.text, "")
async def test_basic_tags(self):
await self.aggregator.reset()
# Tags involved, avoid aggregation during tags.
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(result, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text, "")
async def test_streaming_tags(self):
await self.aggregator.reset()
@@ -37,22 +35,20 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
# Tags involved, stream small chunk of texts.
result = await self.aggregator.aggregate("My email is <sp")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "My email is <sp")
self.assertEqual(self.aggregator.text, "My email is <sp")
result = await self.aggregator.aggregate("ell>foo.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
self.assertEqual(self.aggregator.text, "My email is <spell>foo.")
result = await self.aggregator.aggregate("bar@pipecat.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.")
result = await self.aggregator.aggregate("ai</spe")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
self.assertEqual(self.aggregator.text.type, "sentence")
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.ai</spe")
result = await self.aggregator.aggregate("ll>.")
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text.type, "sentence")
self.assertEqual(result, "My email is <spell>foo.bar@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text, "")


@@ -11,7 +11,6 @@ from datetime import datetime, timezone
from typing import List, Tuple, cast
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -131,11 +130,11 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(), # Wait for StartedSpeaking to process
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="How", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="are", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="you?", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world!"),
TTSTextFrame(text="How"),
TTSTextFrame(text="are"),
TTSTextFrame(text="you?"),
SleepFrame(), # Wait for text frames to queue
BotStoppedSpeakingFrame(),
]
@@ -196,9 +195,9 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="", aggregated_by=AggregationType.WORD), # Empty text
TTSTextFrame(text=" ", aggregated_by=AggregationType.WORD), # Just whitespace
TTSTextFrame(text="\n", aggregated_by=AggregationType.WORD), # Just newline
TTSTextFrame(text=""), # Empty text
TTSTextFrame(text=" "), # Just whitespace
TTSTextFrame(text="\n"), # Just newline
BotStoppedSpeakingFrame(),
# Pipeline ends here; run_test will automatically send EndFrame
]
@@ -236,14 +235,14 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world!"),
SleepFrame(),
InterruptionFrame(), # User interrupts here
SleepFrame(),
BotStartedSpeakingFrame(),
TTSTextFrame(text="New", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="response", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="New"),
TTSTextFrame(text="response"),
SleepFrame(),
BotStoppedSpeakingFrame(),
]
@@ -300,8 +299,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world"),
# Pipeline ends here; run_test will automatically send EndFrame
]
@@ -339,8 +338,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world"),
SleepFrame(), # Ensure messages are processed
CancelFrame(),
]
@@ -402,8 +401,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Assistant", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="message", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="Assistant"),
TTSTextFrame(text="message"),
BotStoppedSpeakingFrame(),
]
@@ -440,7 +439,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
# Test the specific pattern shared
def make_tts_text_frame(text: str) -> TTSTextFrame:
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.WORD)
frame = TTSTextFrame(text=text)
frame.includes_inter_frame_spaces = True
return frame

uv.lock generated

@@ -36,12 +36,12 @@ wheels = [
[[package]]
name = "aic-sdk"
version = "1.0.2"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/51/90/b02e853e863c303f8456c689b42ac24ad403b781adc9642d0a91ed4bed7e/aic_sdk-1.0.2.tar.gz", hash = "sha256:239097dd3aaa8a8a0fd7542b75d2510cb34144caec796370639b7c636acbc56e", size = 32059, upload-time = "2025-08-24T09:20:03.9Z" }
sdist = { url = "https://files.pythonhosted.org/packages/99/83/bf38b95d98c67b8ebc574fb4a4f23c07a3740b51992d7522976173d30b98/aic_sdk-1.1.0.tar.gz", hash = "sha256:04e08df695581c8cb4db8acca20e73815e9f449e7bd08e0162fd55518c727963", size = 34954, upload-time = "2025-11-11T20:45:24.25Z" }
[[package]]
name = "aioboto3"
@@ -4647,7 +4647,7 @@ docs = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'moondream'", specifier = "~=1.10.0" },
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.0.1" },
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.1.0" },
{ name = "aioboto3", marker = "extra == 'aws'", specifier = "~=15.0.0" },
{ name = "aiofiles", specifier = ">=24.1.0,<25" },
{ name = "aiohttp", specifier = ">=3.11.12,<4" },