From fe9aa3383e111dfb2fd0884a66a1cd231175d356 Mon Sep 17 00:00:00 2001 From: mattie ruth backman Date: Tue, 21 Oct 2025 12:16:01 -0400 Subject: [PATCH] Adding support for new bot-output RTVI Message: 1. TTSTextFrames now include metadata about whether the text was spoken or not along with a type string to describe what the text represents: ex. "sentence", "word", "custom aggregation" 2. Expanded how aggregators work so that the aggregate method returns aggregated text along with the type of aggregation used to create it 3. Deprecated the RTVI bot-transcription event in lieu of... 4. Introduced support for a new bot-output event. This event is meant to be the one stop shop for communicating what the bot actually "says". It is based off TTSTextFrames to communicate both sentence by sentence (or whatever aggregation is used) as well as word by word. In addition, it will include LLMTextFrames, aggregated by sentence when tts is turned off (i.e. skip_tts is true). Resolves pipecat-ai/pipecat-client-web#158 --- src/pipecat/frames/frames.py | 3 + src/pipecat/processors/frameworks/rtvi.py | 81 ++++++++++++++- src/pipecat/services/aws/nova_sonic/llm.py | 6 +- .../services/google/gemini_live/llm.py | 2 +- src/pipecat/services/openai/realtime/llm.py | 2 +- .../services/openai_realtime_beta/openai.py | 2 +- src/pipecat/services/tts_service.py | 71 ++++++++----- .../utils/text/base_text_aggregator.py | 38 ++++++- .../utils/text/pattern_pair_aggregator.py | 99 ++++++++++++------- .../utils/text/simple_text_aggregator.py | 10 +- .../utils/text/skip_tags_aggregator.py | 8 +- tests/test_transcript_processor.py | 38 +++---- 12 files changed, 259 insertions(+), 101 deletions(-) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 6f48f79f7..4af13003d 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -359,6 +359,9 @@ class LLMTextFrame(TextFrame): class TTSTextFrame(TextFrame): """Text frame generated by Text-to-Speech 
services.""" + aggregated_by: Literal["sentence", "word"] | str + spoken: Optional[bool] = True # Whether this text has been spoken by TTS + pass diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index f04cbd395..18418081f 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -704,6 +704,29 @@ class RTVITextMessageData(BaseModel): text: str +class RTVIBotOutputMessageData(RTVITextMessageData): + """Data for bot output RTVI messages. + + Extends RTVITextMessageData to include metadata about the output. + """ + + spoken: bool = True # Indicates if the text has been spoken by TTS + aggregated_by: Optional[Literal["word", "sentence"] | str] = None + # Indicates what form the text is in (e.g., by word, sentence, etc.) + + +class RTVIBotOutputMessage(BaseModel): + """Message containing bot output text. + + An event meant to holistically represent what the bot is outputting, + along with metadata about the output and whether it has been spoken. + """ + + label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL + type: Literal["bot-output"] = "bot-output" + data: RTVIBotOutputMessageData + + class RTVIBotTranscriptionMessage(BaseModel): """Message containing bot transcription text. 
@@ -960,6 +983,8 @@ class RTVIObserver(BaseObserver): self._last_user_audio_level = 0 self._last_bot_audio_level = 0 + self._skip_tts = None + if self._params.system_logs_enabled: self._system_logger_id = logger.add(self._logger_sink) @@ -1050,8 +1075,7 @@ class RTVIObserver(BaseObserver): await self.send_rtvi_message(RTVIBotTTSStoppedMessage()) elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled: if isinstance(src, BaseOutputTransport): - message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text)) - await self.send_rtvi_message(message) + await self._handle_tts_text_frame(frame) else: mark_as_seen = False elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled: @@ -1115,14 +1139,63 @@ class RTVIObserver(BaseObserver): if message: await self.send_rtvi_message(message) + async def _handle_tts_text_frame(self, frame: TTSTextFrame): + """Handle TTS text output frames.""" + # send the tts-text message + message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text)) + await self.send_rtvi_message(message) + # send the bot-output message + message = RTVIBotOutputMessage( + data=RTVIBotOutputMessageData( + text=frame.text, spoken=frame.spoken, aggregated_by=frame.aggregated_by + ) + ) + await self.send_rtvi_message(message) + async def _handle_llm_text_frame(self, frame: LLMTextFrame): """Handle LLM text output frames.""" message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text)) await self.send_rtvi_message(message) + # initialize skip_tts on first LLMTextFrame + if self._skip_tts is None: + self._skip_tts = frame.skip_tts + + messages = [] + should_reset_transcription = False self._bot_transcription += frame.text - if match_endofsentence(self._bot_transcription): - await self._push_bot_transcription() + + if not frame.skip_tts and self._skip_tts: + # We just switched from skipping TTS to not skipping TTS. + # Send and reset any existing transcription. 
+ if len(self._bot_transcription) > 0: + messages.append( + RTVIBotOutputMessage( + data=RTVIBotOutputMessageData( + text=self._bot_transcription, spoken=False, aggregated_by="sentence" + ) + ) + ) + should_reset_transcription = True + + if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0: + messages.append( + RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription)) + ) + if frame.skip_tts: + messages.append( + RTVIBotOutputMessage( + data=RTVIBotOutputMessageData( + text=self._bot_transcription, spoken=False, aggregated_by="sentence" + ) + ) + ) + should_reset_transcription = True + + for msg in messages: + await self.send_rtvi_message(msg) + if should_reset_transcription: + self._bot_transcription = "" async def _handle_user_transcriptions(self, frame: Frame): """Handle user transcription frames.""" diff --git a/src/pipecat/services/aws/nova_sonic/llm.py b/src/pipecat/services/aws/nova_sonic/llm.py index 2572b03cb..1a50fc22f 100644 --- a/src/pipecat/services/aws/nova_sonic/llm.py +++ b/src/pipecat/services/aws/nova_sonic/llm.py @@ -1027,7 +1027,7 @@ class AWSNovaSonicLLMService(LLMService): logger.debug(f"Assistant response text added: {text}") # Report the text of the assistant response. - frame = TTSTextFrame(text) + frame = TTSTextFrame(text, aggregated_by="sentence", spoken=True) frame.includes_inter_frame_spaces = True await self.push_frame(frame) @@ -1062,7 +1062,9 @@ class AWSNovaSonicLLMService(LLMService): # TTSTextFrame would be ignored otherwise (the interruption frame # would have cleared the assistant aggregator state). 
await self.push_frame(LLMFullResponseStartFrame()) - frame = TTSTextFrame(self._assistant_text_buffer) + frame = TTSTextFrame( + self._assistant_text_buffer, aggregated_by="sentence", spoken=True + ) frame.includes_inter_frame_spaces = True await self.push_frame(frame) self._may_need_repush_assistant_text = False diff --git a/src/pipecat/services/google/gemini_live/llm.py b/src/pipecat/services/google/gemini_live/llm.py index 11632968e..ed28298ea 100644 --- a/src/pipecat/services/google/gemini_live/llm.py +++ b/src/pipecat/services/google/gemini_live/llm.py @@ -1646,7 +1646,7 @@ class GeminiLiveLLMService(LLMService): await self.push_frame(TTSStartedFrame()) await self.push_frame(LLMFullResponseStartFrame()) - frame = TTSTextFrame(text=text) + frame = TTSTextFrame(text=text, aggregated_by="sentence") # Gemini Live text already includes any necessary inter-chunk spaces frame.includes_inter_frame_spaces = True diff --git a/src/pipecat/services/openai/realtime/llm.py b/src/pipecat/services/openai/realtime/llm.py index f66e6e8e1..1d29908ba 100644 --- a/src/pipecat/services/openai/realtime/llm.py +++ b/src/pipecat/services/openai/realtime/llm.py @@ -686,7 +686,7 @@ class OpenAIRealtimeLLMService(LLMService): # We receive audio transcript deltas (as opposed to text deltas) when # the output modality is "audio" (the default) if evt.delta: - frame = TTSTextFrame(evt.delta) + frame = TTSTextFrame(evt.delta, aggregated_by="sentence") # OpenAI Realtime text already includes any necessary inter-chunk spaces frame.includes_inter_frame_spaces = True await self.push_frame(frame) diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index af0600882..d67e58cf8 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -652,7 +652,7 @@ class OpenAIRealtimeBetaLLMService(LLMService): async def _handle_evt_audio_transcript_delta(self, evt): if evt.delta: 
await self.push_frame(LLMTextFrame(evt.delta)) - await self.push_frame(TTSTextFrame(evt.delta)) + await self.push_frame(TTSTextFrame(evt.delta, aggregated_by="sentence", spoken=True)) async def _handle_evt_speech_started(self, evt): await self._truncate_current_audio_response() diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index 29c54f497..fa7956c81 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -101,6 +101,8 @@ class TTSService(AIService): sample_rate: Optional[int] = None, # Text aggregator to aggregate incoming tokens and decide when to push to the TTS. text_aggregator: Optional[BaseTextAggregator] = None, + # Types of text aggregations that should not be spoken. + skip_aggregator_types: Optional[List[str]] = None, # Text filter executed after text has been aggregated. text_filters: Optional[Sequence[BaseTextFilter]] = None, text_filter: Optional[BaseTextFilter] = None, @@ -120,6 +122,7 @@ class TTSService(AIService): pause_frame_processing: Whether to pause frame processing during audio generation. sample_rate: Output sample rate for generated audio. text_aggregator: Custom text aggregator for processing incoming text. + skip_aggregator_types: List of aggregation types that should not be spoken. text_filters: Sequence of text filters to apply after aggregation. text_filter: Single text filter (deprecated, use text_filters). @@ -142,6 +145,7 @@ class TTSService(AIService): self._voice_id: str = "" self._settings: Dict[str, Any] = {} self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator() + self._skip_aggregator_types: List[str] = skip_aggregator_types or [] self._text_filters: Sequence[BaseTextFilter] = text_filters or [] self._transport_destination: Optional[str] = transport_destination self._tracing_enabled: bool = False @@ -368,10 +372,14 @@ class TTSService(AIService): # pause to avoid audio overlapping. 
await self._maybe_pause_frame_processing() - sentence = self._text_aggregator.text + aggregate = self._text_aggregator.text await self._text_aggregator.reset() self._processing_text = False - await self._push_tts_frames(sentence) + await self._push_tts_frames( + text=aggregate.text, + should_speak=aggregate.type not in self._skip_aggregator_types, + aggregated_by=aggregate.type, + ) if isinstance(frame, LLMFullResponseEndFrame): if self._push_text_frames: await self.push_frame(frame, direction) @@ -380,7 +388,7 @@ class TTSService(AIService): elif isinstance(frame, TTSSpeakFrame): # Store if we were processing text or not so we can set it back. processing_text = self._processing_text - await self._push_tts_frames(frame.text) + await self._push_tts_frames(frame.text, should_speak=True, aggregated_by="word") # We pause processing incoming frames because we are sending data to # the TTS. We pause to avoid audio overlapping. await self._maybe_pause_frame_processing() @@ -472,42 +480,51 @@ class TTSService(AIService): text: Optional[str] = None if not self._aggregate_sentences: text = frame.text + should_speak = True + aggregated_by = "token" else: - text = await self._text_aggregator.aggregate(frame.text) + aggregate = await self._text_aggregator.aggregate(frame.text) + if aggregate: + text = aggregate.text + should_speak = aggregate.type not in self._skip_aggregator_types + aggregated_by = aggregate.type if text: - await self._push_tts_frames(text) + logger.trace(f"Pushing TTS frames for text: {text}, {should_speak}, {aggregated_by}") + await self._push_tts_frames(text, should_speak, aggregated_by) - async def _push_tts_frames(self, text: str): - # Remove leading newlines only - text = text.lstrip("\n") + async def _push_tts_frames(self, text: str, should_speak: bool, aggregated_by: str): + if should_speak: + # Remove leading newlines only + text = text.lstrip("\n") - # Don't send only whitespace. This causes problems for some TTS models. 
But also don't - # strip all whitespace, as whitespace can influence prosody. - if not text.strip(): - return + # Don't send only whitespace. This causes problems for some TTS models. But also don't + # strip all whitespace, as whitespace can influence prosody. + if not text.strip(): + return - # This is just a flag that indicates if we sent something to the TTS - # service. It will be cleared if we sent text because of a TTSSpeakFrame - # or when we received an LLMFullResponseEndFrame - self._processing_text = True + # This is just a flag that indicates if we sent something to the TTS + # service. It will be cleared if we sent text because of a TTSSpeakFrame + # or when we received an LLMFullResponseEndFrame + self._processing_text = True - await self.start_processing_metrics() + await self.start_processing_metrics() - # Process all filter. - for filter in self._text_filters: - await filter.reset_interruption() - text = await filter.filter(text) + # Process all filter. + for filter in self._text_filters: + await filter.reset_interruption() + text = await filter.filter(text) - if text: - await self.process_generator(self.run_tts(text)) + if text: + await self.push_frame(TTSTextFrame(text, spoken=True, aggregated_by=aggregated_by)) + await self.process_generator(self.run_tts(text)) - await self.stop_processing_metrics() + await self.stop_processing_metrics() - if self._push_text_frames: + if self._push_text_frames or not should_speak: # We send the original text after the audio. This way, if we are # interrupted, the text is not added to the assistant context. 
- frame = TTSTextFrame(text) + frame = TTSTextFrame(text, spoken=should_speak, aggregated_by=aggregated_by) frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces await self.push_frame(frame) @@ -635,7 +652,7 @@ class WordTTSService(TTSService): frame = TTSStoppedFrame() frame.pts = last_pts else: - frame = TTSTextFrame(word) + frame = TTSTextFrame(word, spoken=True, aggregated_by="word") frame.pts = self._initial_word_timestamp + timestamp if frame: last_pts = frame.pts diff --git a/src/pipecat/utils/text/base_text_aggregator.py b/src/pipecat/utils/text/base_text_aggregator.py index 27e50fff5..5a5196920 100644 --- a/src/pipecat/utils/text/base_text_aggregator.py +++ b/src/pipecat/utils/text/base_text_aggregator.py @@ -12,9 +12,38 @@ aggregated text should be sent for speech synthesis. """ from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Optional +@dataclass +class Aggregation: + """Data class representing aggregated text and its type. + + An Aggregation object is created whenever a stream of text is aggregated by + a text aggregator. It contains the aggregated text and a type indicating + the nature of the aggregation. + """ + + def __init__(self, text: str, type: str): + """Initialize an aggregation instance. + + Args: + text: The aggregated text content. + type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation'). + """ + self.text = text + self.type = type + + def __str__(self) -> str: + """Return a string representation of the aggregation. + + Returns: + A descriptive string showing the type and text of the aggregation. + """ + return f"Aggregation by {self.type}: {self.text}" + + class BaseTextAggregator(ABC): """Base class for text aggregators in the Pipecat framework. @@ -30,7 +59,7 @@ class BaseTextAggregator(ABC): @property @abstractmethod - def text(self) -> str: + def text(self) -> Aggregation: """Get the currently aggregated text. 
Subclasses must implement this property to return the text that has @@ -42,12 +71,13 @@ class BaseTextAggregator(ABC): pass @abstractmethod - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate the specified text with the currently accumulated text. This method should be implemented to define how the new text contributes - to the aggregation process. It returns the updated aggregated text if - it's ready to be processed, or None otherwise. + to the aggregation process. It returns the aggregated text and a string + describing how it was aggregated if it's ready to be processed, + or None otherwise. Subclasses should implement their specific logic for: diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py index ac074f2de..fe32f5b51 100644 --- a/src/pipecat/utils/text/pattern_pair_aggregator.py +++ b/src/pipecat/utils/text/pattern_pair_aggregator.py @@ -12,15 +12,15 @@ support for custom handlers and configurable pattern removal. """ import re -from typing import Awaitable, Callable, Optional, Tuple +from typing import Awaitable, Callable, List, Optional, Tuple from loguru import logger from pipecat.utils.string import match_endofsentence -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator -class PatternMatch: +class PatternMatch(Aggregation): """Represents a matched pattern pair with its content. A PatternMatch object is created when a complete pattern pair is found @@ -29,17 +29,19 @@ class PatternMatch: content between the patterns. """ - def __init__(self, pattern_id: str, full_match: str, content: str): + def __init__(self, pattern_id: str, full_match: str, content: str, type: str): """Initialize a pattern match. Args: pattern_id: The identifier of the matched pattern pair. 
full_match: The complete text including start and end patterns. content: The text content between the start and end patterns. + type: The type of aggregation the matched content represents + (e.g., 'code', 'speaker', 'custom'). """ + super().__init__(text=content, type=type) self.pattern_id = pattern_id self.full_match = full_match - self.content = content def __str__(self) -> str: """Return a string representation of the pattern match. @@ -47,7 +49,7 @@ class PatternMatch: Returns: A descriptive string showing the pattern ID and content. """ - return f"PatternMatch(id={self.pattern_id}, content={self.content})" + return f"PatternMatch(id={self.pattern_id}, content={self.text}, full_match={self.full_match}, type={self.type})" class PatternPairAggregator(BaseTextAggregator): @@ -64,7 +66,7 @@ class PatternPairAggregator(BaseTextAggregator): boundaries. """ - def __init__(self): + def __init__(self, **kwargs): """Initialize the pattern pair aggregator. Creates an empty aggregator with no patterns or handlers registered. @@ -75,16 +77,24 @@ class PatternPairAggregator(BaseTextAggregator): self._handlers = {} @property - def text(self) -> str: - """Get the currently buffered text. + def text(self) -> Aggregation: + """Get the currently aggregated text. Returns: - The current text buffer content that hasn't been processed yet. + The text that has been accumulated in the buffer. """ - return self._text + start, curtype = self._match_start_of_pattern(self._text) + if curtype: + return Aggregation(self._text, curtype) + return Aggregation(self._text, "sentence") def add_pattern_pair( - self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True + self, + pattern_id: str, + start_pattern: str, + end_pattern: str, + type: str, + remove_match: bool = True, ) -> "PatternPairAggregator": """Add a pattern pair to detect in the text. @@ -96,7 +106,9 @@ class PatternPairAggregator(BaseTextAggregator): pattern_id: Unique identifier for this pattern pair. 
start_pattern: Pattern that marks the beginning of content. end_pattern: Pattern that marks the end of content. - remove_match: Whether to remove the matched content from the text. + type: The type of aggregation the matched content represents + (e.g., 'code', 'speaker', 'custom'). + remove_match: Whether to remove the matched content from the text returned. Returns: Self for method chaining. @@ -104,6 +116,7 @@ class PatternPairAggregator(BaseTextAggregator): self._patterns[pattern_id] = { "start": start_pattern, "end": end_pattern, + "type": type, "remove_match": remove_match, } return self @@ -127,7 +140,7 @@ class PatternPairAggregator(BaseTextAggregator): self._handlers[pattern_id] = handler return self - async def _process_complete_patterns(self, text: str) -> Tuple[str, bool]: + async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]: """Process all complete pattern pairs in the text. Searches for all complete pattern pairs in the text, calls the @@ -137,19 +150,20 @@ class PatternPairAggregator(BaseTextAggregator): text: The text to process for pattern matches. Returns: - Tuple of (processed_text, was_modified) where: + Tuple of (all_matches, processed_text) where: - - processed_text is the text after processing patterns - - was_modified indicates whether any changes were made + - all_matches is a list of all pattern matches found. Note: There really should only ever be 1. + - processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text. """ + all_matches = [] processed_text = text - modified = False for pattern_id, pattern_info in self._patterns.items(): # Escape special regex characters in the patterns start = re.escape(pattern_info["start"]) end = re.escape(pattern_info["end"]) remove_match = pattern_info["remove_match"] + match_type = pattern_info["type"] # Create regex to match from start pattern to end pattern # The .*? 
is non-greedy to handle nested patterns @@ -165,7 +179,7 @@ class PatternPairAggregator(BaseTextAggregator): # Create pattern match object pattern_match = PatternMatch( - pattern_id=pattern_id, full_match=full_match, content=content + pattern_id=pattern_id, full_match=full_match, content=content, type=match_type ) # Call the appropriate handler if registered @@ -178,11 +192,13 @@ class PatternPairAggregator(BaseTextAggregator): # Remove the pattern from the text if configured if remove_match: processed_text = processed_text.replace(full_match, "", 1) - modified = True + # modified = True + else: + all_matches.append(pattern_match) - return processed_text, modified + return all_matches, processed_text - def _has_incomplete_patterns(self, text: str) -> bool: + def _match_start_of_pattern(self, text: str) -> Tuple[Optional[int], Optional[str]]: """Check if text contains incomplete pattern pairs. Determines whether the text contains any start patterns without @@ -192,7 +208,8 @@ class PatternPairAggregator(BaseTextAggregator): text: The text to check for incomplete patterns. Returns: - True if there are incomplete patterns, False otherwise. + A tuple of (start_index, type) if an incomplete pattern is found, + or (None, None) if no patterns are found or all patterns are complete. """ for pattern_id, pattern_info in self._patterns.items(): start = pattern_info["start"] @@ -203,12 +220,16 @@ class PatternPairAggregator(BaseTextAggregator): end_count = text.count(end) # If there are more starts than ends, we have incomplete patterns + # Again, this is written generically but there only ever should + # be one pattern active at a time, so the counts should be 0 or 1. + # Which is why we base the return on the first found. 
if start_count > end_count: - return True + start_index = text.find(start) + return start_index, pattern_info["type"] - return False + return None, None - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[PatternMatch]: """Aggregate text and process pattern pairs. This method adds the new text to the buffer, processes any complete pattern @@ -227,16 +248,28 @@ class PatternPairAggregator(BaseTextAggregator): self._text += text # Process any complete patterns in the buffer - processed_text, modified = await self._process_complete_patterns(self._text) + patterns, processed_text = await self._process_complete_patterns(self._text) - # Only update the buffer if modifications were made - if modified: - self._text = processed_text + self._text = processed_text + + # + if len(patterns) > 0: + if len(patterns) > 1: + logger.warning( + f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned." 
+ ) + self._text = "" + return patterns[0] # Check if we have incomplete patterns - if self._has_incomplete_patterns(self._text): + start, curtype = self._match_start_of_pattern(self._text) + if start is not None: # Still waiting for complete patterns - return None + if start == 0: + return None + result = self._text[:start] + self._text = self._text[start:] + return PatternMatch(f"_sentence", result, result, "sentence") # Find sentence boundary if no incomplete patterns eos_marker = match_endofsentence(self._text) @@ -244,7 +277,7 @@ class PatternPairAggregator(BaseTextAggregator): # Extract text up to the sentence boundary result = self._text[:eos_marker] self._text = self._text[eos_marker:] - return result + return PatternMatch(f"_sentence", result, result, "sentence") # No complete sentence found yet return None diff --git a/src/pipecat/utils/text/simple_text_aggregator.py b/src/pipecat/utils/text/simple_text_aggregator.py index f9eb7d83a..16d6aef06 100644 --- a/src/pipecat/utils/text/simple_text_aggregator.py +++ b/src/pipecat/utils/text/simple_text_aggregator.py @@ -14,7 +14,7 @@ text processing scenarios. from typing import Optional from pipecat.utils.string import match_endofsentence -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator class SimpleTextAggregator(BaseTextAggregator): @@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator): self._text = "" @property - def text(self) -> str: + def text(self) -> Aggregation: """Get the currently aggregated text. Returns: The text that has been accumulated in the buffer. """ - return self._text + return Aggregation(self._text, "sentence") - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate text and return completed sentences. Adds the new text to the buffer and checks for end-of-sentence markers. 
@@ -64,7 +64,7 @@ class SimpleTextAggregator(BaseTextAggregator): result = self._text[:eos_end_marker] self._text = self._text[eos_end_marker:] - return result + return Aggregation(result, "sentence") if result else None async def handle_interruption(self): """Handle interruptions by clearing the text buffer. diff --git a/src/pipecat/utils/text/skip_tags_aggregator.py b/src/pipecat/utils/text/skip_tags_aggregator.py index 6f6f8455c..da4933f2e 100644 --- a/src/pipecat/utils/text/skip_tags_aggregator.py +++ b/src/pipecat/utils/text/skip_tags_aggregator.py @@ -14,7 +14,7 @@ as a unit regardless of internal punctuation. from typing import Optional, Sequence from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator class SkipTagsAggregator(BaseTextAggregator): @@ -49,9 +49,9 @@ class SkipTagsAggregator(BaseTextAggregator): Returns: The current text buffer content that hasn't been processed yet. """ - return self._text + return Aggregation(self._text, "sentence") - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate text while respecting tag boundaries. 
This method adds the new text to the buffer, processes any complete @@ -80,7 +80,7 @@ class SkipTagsAggregator(BaseTextAggregator): # Extract text up to the sentence boundary result = self._text[:eos_marker] self._text = self._text[eos_marker:] - return result + return Aggregation(result, "sentence") # No complete sentence found yet return None diff --git a/tests/test_transcript_processor.py b/tests/test_transcript_processor.py index 19366086c..be58f061e 100644 --- a/tests/test_transcript_processor.py +++ b/tests/test_transcript_processor.py @@ -130,11 +130,11 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), # Wait for StartedSpeaking to process - TTSTextFrame(text="Hello"), - TTSTextFrame(text="world!"), - TTSTextFrame(text="How"), - TTSTextFrame(text="are"), - TTSTextFrame(text="you?"), + TTSTextFrame(text="Hello", aggregated_by="word"), + TTSTextFrame(text="world!", aggregated_by="word"), + TTSTextFrame(text="How", aggregated_by="word"), + TTSTextFrame(text="are", aggregated_by="word"), + TTSTextFrame(text="you?", aggregated_by="word"), SleepFrame(), # Wait for text frames to queue BotStoppedSpeakingFrame(), ] @@ -195,9 +195,9 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text=""), # Empty text - TTSTextFrame(text=" "), # Just whitespace - TTSTextFrame(text="\n"), # Just newline + TTSTextFrame(text="", aggregated_by="word"), # Empty text + TTSTextFrame(text=" ", aggregated_by="word"), # Just whitespace + TTSTextFrame(text="\n", aggregated_by="word"), # Just newline BotStoppedSpeakingFrame(), # Pipeline ends here; run_test will automatically send EndFrame ] @@ -235,14 +235,14 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text="Hello"), - TTSTextFrame(text="world!"), + 
TTSTextFrame(text="Hello", aggregated_by="word"), + TTSTextFrame(text="world!", aggregated_by="word"), SleepFrame(), InterruptionFrame(), # User interrupts here SleepFrame(), BotStartedSpeakingFrame(), - TTSTextFrame(text="New"), - TTSTextFrame(text="response"), + TTSTextFrame(text="New", aggregated_by="word"), + TTSTextFrame(text="response", aggregated_by="word"), SleepFrame(), BotStoppedSpeakingFrame(), ] @@ -299,8 +299,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text="Hello"), - TTSTextFrame(text="world"), + TTSTextFrame(text="Hello", aggregated_by="word"), + TTSTextFrame(text="world", aggregated_by="word"), # Pipeline ends here; run_test will automatically send EndFrame ] @@ -338,8 +338,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text="Hello"), - TTSTextFrame(text="world"), + TTSTextFrame(text="Hello", aggregated_by="word"), + TTSTextFrame(text="world", aggregated_by="word"), SleepFrame(), # Ensure messages are processed CancelFrame(), ] @@ -401,8 +401,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text="Assistant"), - TTSTextFrame(text="message"), + TTSTextFrame(text="Assistant", aggregated_by="word"), + TTSTextFrame(text="message", aggregated_by="word"), BotStoppedSpeakingFrame(), ] @@ -439,7 +439,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): # Test the specific pattern shared def make_tts_text_frame(text: str) -> TTSTextFrame: - frame = TTSTextFrame(text=text) + frame = TTSTextFrame(text=text, aggregated_by="word") frame.includes_inter_frame_spaces = True return frame