diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 6f48f79f7..4af13003d 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -359,6 +359,9 @@ class LLMTextFrame(TextFrame): class TTSTextFrame(TextFrame): """Text frame generated by Text-to-Speech services.""" + aggregated_by: Literal["sentence", "word"] | str + spoken: Optional[bool] = True # Whether this text has been spoken by TTS + pass diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index f04cbd395..18418081f 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -704,6 +704,29 @@ class RTVITextMessageData(BaseModel): text: str +class RTVIBotOutputMessageData(RTVITextMessageData): + """Data for bot output RTVI messages. + + Extends RTVITextMessageData to include metadata about the output. + """ + + spoken: bool = True # Indicates if the text has been spoken by TTS + aggregated_by: Optional[Literal["word", "sentence"] | str] = None + # Indicates what form the text is in (e.g., by word, sentence, etc.) + + +class RTVIBotOutputMessage(BaseModel): + """Message containing bot output text. + + An event meant to wholistically represent what the bot is outputting, + along with metadata about the output and if it has been spoken. + """ + + label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL + type: Literal["bot-output"] = "bot-output" + data: RTVIBotOutputMessageData + + class RTVIBotTranscriptionMessage(BaseModel): """Message containing bot transcription text. @@ -960,6 +983,8 @@ class RTVIObserver(BaseObserver): self._last_user_audio_level = 0 self._last_bot_audio_level = 0 + self._skip_tts = None + if self._params.system_logs_enabled: self._system_logger_id = logger.add(self._logger_sink) @@ -1050,8 +1075,7 @@ class RTVIObserver(BaseObserver): await self.send_rtvi_message(RTVIBotTTSStoppedMessage()) elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled: if isinstance(src, BaseOutputTransport): - message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text)) - await self.send_rtvi_message(message) + await self._handle_tts_text_frame(frame) else: mark_as_seen = False elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled: @@ -1115,14 +1139,63 @@ class RTVIObserver(BaseObserver): if message: await self.send_rtvi_message(message) + async def _handle_tts_text_frame(self, frame: TTSTextFrame): + """Handle TTS text output frames.""" + # send the tts-text message + message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text)) + await self.send_rtvi_message(message) + # send the bot-output message + message = RTVIBotOutputMessage( + data=RTVIBotOutputMessageData( + text=frame.text, spoken=frame.spoken, aggregated_by=frame.aggregated_by + ) + ) + await self.send_rtvi_message(message) + async def _handle_llm_text_frame(self, frame: LLMTextFrame): """Handle LLM text output frames.""" message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text)) await self.send_rtvi_message(message) + # initialize skip_tts on first LLMTextFrame + if self._skip_tts is None: + self._skip_tts = frame.skip_tts + + messages = [] + should_reset_transcription = False self._bot_transcription += frame.text - if match_endofsentence(self._bot_transcription): - await self._push_bot_transcription() + + if not frame.skip_tts and self._skip_tts: + # We just switched from skipping TTS to not skipping TTS. + # Send and reset any existing transcription. + if len(self._bot_transcription) > 0: + message.append( + RTVIBotOutputMessage( + data=RTVIBotOutputMessageData( + text=self._bot_transcription, spoken=False, aggregated_by="sentence" + ) + ) + ) + should_reset_transcription = True + + if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0: + messages.append( + RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription)) + ) + if frame.skip_tts: + messages.append( + RTVIBotOutputMessage( + data=RTVIBotOutputMessageData( + text=self._bot_transcription, spoken=False, aggregated_by="sentence" + ) + ) + ) + should_reset_transcription = True + + for msg in messages: + await self.send_rtvi_message(msg) + if should_reset_transcription: + self._bot_transcription = "" async def _handle_user_transcriptions(self, frame: Frame): """Handle user transcription frames.""" diff --git a/src/pipecat/services/aws/nova_sonic/llm.py b/src/pipecat/services/aws/nova_sonic/llm.py index 2572b03cb..1a50fc22f 100644 --- a/src/pipecat/services/aws/nova_sonic/llm.py +++ b/src/pipecat/services/aws/nova_sonic/llm.py @@ -1027,7 +1027,7 @@ class AWSNovaSonicLLMService(LLMService): logger.debug(f"Assistant response text added: {text}") # Report the text of the assistant response. - frame = TTSTextFrame(text) + frame = TTSTextFrame(text, aggregated_by="sentence", spoken=True) frame.includes_inter_frame_spaces = True await self.push_frame(frame) @@ -1062,7 +1062,9 @@ class AWSNovaSonicLLMService(LLMService): # TTSTextFrame would be ignored otherwise (the interruption frame # would have cleared the assistant aggregator state). await self.push_frame(LLMFullResponseStartFrame()) - frame = TTSTextFrame(self._assistant_text_buffer) + frame = TTSTextFrame( + self._assistant_text_buffer, aggregated_by="sentence", spoken=True + ) frame.includes_inter_frame_spaces = True await self.push_frame(frame) self._may_need_repush_assistant_text = False diff --git a/src/pipecat/services/google/gemini_live/llm.py b/src/pipecat/services/google/gemini_live/llm.py index 11632968e..ed28298ea 100644 --- a/src/pipecat/services/google/gemini_live/llm.py +++ b/src/pipecat/services/google/gemini_live/llm.py @@ -1646,7 +1646,7 @@ class GeminiLiveLLMService(LLMService): await self.push_frame(TTSStartedFrame()) await self.push_frame(LLMFullResponseStartFrame()) - frame = TTSTextFrame(text=text) + frame = TTSTextFrame(text=text, aggregated_by="sentence") # Gemini Live text already includes any necessary inter-chunk spaces frame.includes_inter_frame_spaces = True diff --git a/src/pipecat/services/openai/realtime/llm.py b/src/pipecat/services/openai/realtime/llm.py index f66e6e8e1..1d29908ba 100644 --- a/src/pipecat/services/openai/realtime/llm.py +++ b/src/pipecat/services/openai/realtime/llm.py @@ -686,7 +686,7 @@ class OpenAIRealtimeLLMService(LLMService): # We receive audio transcript deltas (as opposed to text deltas) when # the output modality is "audio" (the default) if evt.delta: - frame = TTSTextFrame(evt.delta) + frame = TTSTextFrame(evt.delta, aggregated_by="sentence") # OpenAI Realtime text already includes any necessary inter-chunk spaces frame.includes_inter_frame_spaces = True await self.push_frame(frame) diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index af0600882..d67e58cf8 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -652,7 +652,7 @@ class OpenAIRealtimeBetaLLMService(LLMService): async def _handle_evt_audio_transcript_delta(self, evt): if evt.delta: await self.push_frame(LLMTextFrame(evt.delta)) - await self.push_frame(TTSTextFrame(evt.delta)) + await self.push_frame(TTSTextFrame(evt.delta, aggregated_by="sentence", spoken=True)) async def _handle_evt_speech_started(self, evt): await self._truncate_current_audio_response() diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index 29c54f497..fa7956c81 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -101,6 +101,8 @@ class TTSService(AIService): sample_rate: Optional[int] = None, # Text aggregator to aggregate incoming tokens and decide when to push to the TTS. text_aggregator: Optional[BaseTextAggregator] = None, + # Types of text aggregations that should not be spoken. + skip_aggregator_types: Optional[List[str]] = [], # Text filter executed after text has been aggregated. text_filters: Optional[Sequence[BaseTextFilter]] = None, text_filter: Optional[BaseTextFilter] = None, @@ -120,6 +122,7 @@ class TTSService(AIService): pause_frame_processing: Whether to pause frame processing during audio generation. sample_rate: Output sample rate for generated audio. text_aggregator: Custom text aggregator for processing incoming text. + skip_aggregator_types: List of aggregation types that should not be spoken. text_filters: Sequence of text filters to apply after aggregation. text_filter: Single text filter (deprecated, use text_filters). @@ -142,6 +145,7 @@ class TTSService(AIService): self._voice_id: str = "" self._settings: Dict[str, Any] = {} self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator() + self._skip_aggregator_types: List[str] = skip_aggregator_types or [] self._text_filters: Sequence[BaseTextFilter] = text_filters or [] self._transport_destination: Optional[str] = transport_destination self._tracing_enabled: bool = False @@ -368,10 +372,14 @@ class TTSService(AIService): # pause to avoid audio overlapping. await self._maybe_pause_frame_processing() - sentence = self._text_aggregator.text + aggregate = self._text_aggregator.text await self._text_aggregator.reset() self._processing_text = False - await self._push_tts_frames(sentence) + await self._push_tts_frames( + text=aggregate.text, + should_speak=aggregate.type not in self._skip_aggregator_types, + aggregated_by=aggregate.type, + ) if isinstance(frame, LLMFullResponseEndFrame): if self._push_text_frames: await self.push_frame(frame, direction) @@ -380,7 +388,7 @@ class TTSService(AIService): elif isinstance(frame, TTSSpeakFrame): # Store if we were processing text or not so we can set it back. processing_text = self._processing_text - await self._push_tts_frames(frame.text) + await self._push_tts_frames(frame.text, should_speak=True, aggregated_by="word") # We pause processing incoming frames because we are sending data to # the TTS. We pause to avoid audio overlapping. await self._maybe_pause_frame_processing() @@ -472,42 +480,51 @@ class TTSService(AIService): text: Optional[str] = None if not self._aggregate_sentences: text = frame.text + should_speak = True + aggregated_by = "token" else: - text = await self._text_aggregator.aggregate(frame.text) + aggregate = await self._text_aggregator.aggregate(frame.text) + if aggregate: + text = aggregate.text + should_speak = aggregate.type not in self._skip_aggregator_types + aggregated_by = aggregate.type if text: - await self._push_tts_frames(text) + logger.trace(f"Pushing TTS frames for text: {text}, {should_speak}, {aggregated_by}") + await self._push_tts_frames(text, should_speak, aggregated_by) - async def _push_tts_frames(self, text: str): - # Remove leading newlines only - text = text.lstrip("\n") + async def _push_tts_frames(self, text: str, should_speak: bool, aggregated_by: str): + if should_speak: + # Remove leading newlines only + text = text.lstrip("\n") - # Don't send only whitespace. This causes problems for some TTS models. But also don't - # strip all whitespace, as whitespace can influence prosody. - if not text.strip(): - return + # Don't send only whitespace. This causes problems for some TTS models. But also don't + # strip all whitespace, as whitespace can influence prosody. + if not text.strip(): + return - # This is just a flag that indicates if we sent something to the TTS - # service. It will be cleared if we sent text because of a TTSSpeakFrame - # or when we received an LLMFullResponseEndFrame - self._processing_text = True + # This is just a flag that indicates if we sent something to the TTS + # service. It will be cleared if we sent text because of a TTSSpeakFrame + # or when we received an LLMFullResponseEndFrame + self._processing_text = True - await self.start_processing_metrics() + await self.start_processing_metrics() - # Process all filter. - for filter in self._text_filters: - await filter.reset_interruption() - text = await filter.filter(text) + # Process all filter. + for filter in self._text_filters: + await filter.reset_interruption() + text = await filter.filter(text) - if text: - await self.process_generator(self.run_tts(text)) + if text: + await self.push_frame(TTSTextFrame(text, spoken=True, aggregated_by=aggregated_by)) + await self.process_generator(self.run_tts(text)) - await self.stop_processing_metrics() + await self.stop_processing_metrics() - if self._push_text_frames: + if self._push_text_frames or not should_speak: # We send the original text after the audio. This way, if we are # interrupted, the text is not added to the assistant context. - frame = TTSTextFrame(text) + frame = TTSTextFrame(text, spoken=should_speak, aggregated_by=aggregated_by) frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces await self.push_frame(frame) @@ -635,7 +652,7 @@ class WordTTSService(TTSService): frame = TTSStoppedFrame() frame.pts = last_pts else: - frame = TTSTextFrame(word) + frame = TTSTextFrame(word, spoken=True, aggregated_by="word") frame.pts = self._initial_word_timestamp + timestamp if frame: last_pts = frame.pts diff --git a/src/pipecat/utils/text/base_text_aggregator.py b/src/pipecat/utils/text/base_text_aggregator.py index 27e50fff5..5a5196920 100644 --- a/src/pipecat/utils/text/base_text_aggregator.py +++ b/src/pipecat/utils/text/base_text_aggregator.py @@ -12,9 +12,38 @@ aggregated text should be sent for speech synthesis. """ from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Optional +@dataclass +class Aggregation: + """Data class representing aggregated text and its type. + + An Aggregation object is created whenever a stream of text is aggregated by + a text aggregator. It contains the aggregated text and a type indicating + the nature of the aggregation. + """ + + def __init__(self, text: str, type: str): + """Initialize an aggregation instance. + + Args: + text: The aggregated text content. + type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation'). + """ + self.text = text + self.type = type + + def __str__(self) -> str: + """Return a string representation of the aggregation. + + Returns: + A descriptive string showing the type and text of the aggregation. + """ + return f"Aggregation by {self.type}: {self.text}" + + class BaseTextAggregator(ABC): """Base class for text aggregators in the Pipecat framework. @@ -30,7 +59,7 @@ class BaseTextAggregator(ABC): @property @abstractmethod - def text(self) -> str: + def text(self) -> Aggregation: """Get the currently aggregated text. Subclasses must implement this property to return the text that has @@ -42,12 +71,13 @@ class BaseTextAggregator(ABC): pass @abstractmethod - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate the specified text with the currently accumulated text. This method should be implemented to define how the new text contributes - to the aggregation process. It returns the updated aggregated text if - it's ready to be processed, or None otherwise. + to the aggregation process. It returns the aggregated text and a string + describing how it was aggregated if it's ready to be processed, + or None otherwise. Subclasses should implement their specific logic for: diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py index ac074f2de..fe32f5b51 100644 --- a/src/pipecat/utils/text/pattern_pair_aggregator.py +++ b/src/pipecat/utils/text/pattern_pair_aggregator.py @@ -12,15 +12,15 @@ support for custom handlers and configurable pattern removal. """ import re -from typing import Awaitable, Callable, Optional, Tuple +from typing import Awaitable, Callable, List, Optional, Tuple from loguru import logger from pipecat.utils.string import match_endofsentence -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator -class PatternMatch: +class PatternMatch(Aggregation): """Represents a matched pattern pair with its content. A PatternMatch object is created when a complete pattern pair is found @@ -29,17 +29,19 @@ class PatternMatch: content between the patterns. """ - def __init__(self, pattern_id: str, full_match: str, content: str): + def __init__(self, pattern_id: str, full_match: str, content: str, type: str): """Initialize a pattern match. Args: pattern_id: The identifier of the matched pattern pair. full_match: The complete text including start and end patterns. content: The text content between the start and end patterns. + type: The type of aggregation the matched content represents + (e.g., 'code', 'speaker', 'custom'). """ + super().__init__(text=content, type=type) self.pattern_id = pattern_id self.full_match = full_match - self.content = content def __str__(self) -> str: """Return a string representation of the pattern match. @@ -47,7 +49,7 @@ class PatternMatch: Returns: A descriptive string showing the pattern ID and content. """ - return f"PatternMatch(id={self.pattern_id}, content={self.content})" + return f"PatternMatch(id={self.pattern_id}, content={self.text}, full_match={self.full_match}, type={self.type})" class PatternPairAggregator(BaseTextAggregator): @@ -64,7 +66,7 @@ class PatternPairAggregator(BaseTextAggregator): boundaries. """ - def __init__(self): + def __init__(self, **kwargs): """Initialize the pattern pair aggregator. Creates an empty aggregator with no patterns or handlers registered. @@ -75,16 +77,24 @@ class PatternPairAggregator(BaseTextAggregator): self._handlers = {} @property - def text(self) -> str: - """Get the currently buffered text. + def text(self) -> Aggregation: + """Get the currently aggregated text. Returns: - The current text buffer content that hasn't been processed yet. + The text that has been accumulated in the buffer. """ - return self._text + start, curtype = self._match_start_of_pattern(self._text) + if curtype: + return Aggregation(self._text, curtype) + return Aggregation(self._text, "sentence") def add_pattern_pair( - self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True + self, + pattern_id: str, + start_pattern: str, + end_pattern: str, + type: str, + remove_match: bool = True, ) -> "PatternPairAggregator": """Add a pattern pair to detect in the text. @@ -96,7 +106,9 @@ class PatternPairAggregator(BaseTextAggregator): pattern_id: Unique identifier for this pattern pair. start_pattern: Pattern that marks the beginning of content. end_pattern: Pattern that marks the end of content. - remove_match: Whether to remove the matched content from the text. + type: The type of aggregation the matched content represents + (e.g., 'code', 'speaker', 'custom'). + remove_match: Whether to remove the matched content from the text returned. Returns: Self for method chaining. @@ -104,6 +116,7 @@ class PatternPairAggregator(BaseTextAggregator): self._patterns[pattern_id] = { "start": start_pattern, "end": end_pattern, + "type": type, "remove_match": remove_match, } return self @@ -127,7 +140,7 @@ class PatternPairAggregator(BaseTextAggregator): self._handlers[pattern_id] = handler return self - async def _process_complete_patterns(self, text: str) -> Tuple[str, bool]: + async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]: """Process all complete pattern pairs in the text. Searches for all complete pattern pairs in the text, calls the @@ -137,19 +150,20 @@ class PatternPairAggregator(BaseTextAggregator): text: The text to process for pattern matches. Returns: - Tuple of (processed_text, was_modified) where: + Tuple of (all_matches, processed_text) where: - - processed_text is the text after processing patterns - - was_modified indicates whether any changes were made + - all_matches is a list of all pattern matches found. Note: There really should only ever be 1. + - processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text. """ + all_matches = [] processed_text = text - modified = False for pattern_id, pattern_info in self._patterns.items(): # Escape special regex characters in the patterns start = re.escape(pattern_info["start"]) end = re.escape(pattern_info["end"]) remove_match = pattern_info["remove_match"] + match_type = pattern_info["type"] # Create regex to match from start pattern to end pattern # The .*? is non-greedy to handle nested patterns @@ -165,7 +179,7 @@ class PatternPairAggregator(BaseTextAggregator): # Create pattern match object pattern_match = PatternMatch( - pattern_id=pattern_id, full_match=full_match, content=content + pattern_id=pattern_id, full_match=full_match, content=content, type=match_type ) # Call the appropriate handler if registered @@ -178,11 +192,13 @@ class PatternPairAggregator(BaseTextAggregator): # Remove the pattern from the text if configured if remove_match: processed_text = processed_text.replace(full_match, "", 1) - modified = True + # modified = True + else: + all_matches.append(pattern_match) - return processed_text, modified + return all_matches, processed_text - def _has_incomplete_patterns(self, text: str) -> bool: + def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, str]]: """Check if text contains incomplete pattern pairs. Determines whether the text contains any start patterns without @@ -192,7 +208,8 @@ class PatternPairAggregator(BaseTextAggregator): text: The text to check for incomplete patterns. Returns: - True if there are incomplete patterns, False otherwise. + A tuple of (start_index, type) if an incomplete pattern is found, + or None if no patterns are found or all patterns are complete. """ for pattern_id, pattern_info in self._patterns.items(): start = pattern_info["start"] @@ -203,12 +220,16 @@ class PatternPairAggregator(BaseTextAggregator): end_count = text.count(end) # If there are more starts than ends, we have incomplete patterns + # Again, this is written generically but there only ever should + # be one pattern active at a time, so the counts should be 0 or 1. + # Which is why we base the return on the first found. if start_count > end_count: - return True + start_index = text.find(start) + return [start_index, pattern_info["type"]] - return False + return None, None - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[PatternMatch]: """Aggregate text and process pattern pairs. This method adds the new text to the buffer, processes any complete pattern @@ -227,16 +248,28 @@ class PatternPairAggregator(BaseTextAggregator): self._text += text # Process any complete patterns in the buffer - processed_text, modified = await self._process_complete_patterns(self._text) + patterns, processed_text = await self._process_complete_patterns(self._text) - # Only update the buffer if modifications were made - if modified: - self._text = processed_text + self._text = processed_text + + # + if len(patterns) > 0: + if len(patterns) > 1: + logger.warning( + f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned." + ) + self._text = "" + return patterns[0] # Check if we have incomplete patterns - if self._has_incomplete_patterns(self._text): + start, curtype = self._match_start_of_pattern(self._text) + if start is not None: # Still waiting for complete patterns - return None + if start == 0: + return None + result = self._text[:start] + self._text = self._text[start:] + return PatternMatch(f"_sentence", result, result, "sentence") # Find sentence boundary if no incomplete patterns eos_marker = match_endofsentence(self._text) @@ -244,7 +277,7 @@ class PatternPairAggregator(BaseTextAggregator): # Extract text up to the sentence boundary result = self._text[:eos_marker] self._text = self._text[eos_marker:] - return result + return PatternMatch(f"_sentence", result, result, "sentence") # No complete sentence found yet return None diff --git a/src/pipecat/utils/text/simple_text_aggregator.py b/src/pipecat/utils/text/simple_text_aggregator.py index f9eb7d83a..16d6aef06 100644 --- a/src/pipecat/utils/text/simple_text_aggregator.py +++ b/src/pipecat/utils/text/simple_text_aggregator.py @@ -14,7 +14,7 @@ text processing scenarios. from typing import Optional from pipecat.utils.string import match_endofsentence -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator class SimpleTextAggregator(BaseTextAggregator): @@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator): self._text = "" @property - def text(self) -> str: + def text(self) -> Aggregation: """Get the currently aggregated text. Returns: The text that has been accumulated in the buffer. """ - return self._text + return Aggregation(self._text, "sentence") - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate text and return completed sentences. Adds the new text to the buffer and checks for end-of-sentence markers. @@ -64,7 +64,7 @@ class SimpleTextAggregator(BaseTextAggregator): result = self._text[:eos_end_marker] self._text = self._text[eos_end_marker:] - return result + return Aggregation(result, "sentence") if result else None async def handle_interruption(self): """Handle interruptions by clearing the text buffer. diff --git a/src/pipecat/utils/text/skip_tags_aggregator.py b/src/pipecat/utils/text/skip_tags_aggregator.py index 6f6f8455c..da4933f2e 100644 --- a/src/pipecat/utils/text/skip_tags_aggregator.py +++ b/src/pipecat/utils/text/skip_tags_aggregator.py @@ -14,7 +14,7 @@ as a unit regardless of internal punctuation. from typing import Optional, Sequence from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator class SkipTagsAggregator(BaseTextAggregator): @@ -49,9 +49,9 @@ class SkipTagsAggregator(BaseTextAggregator): Returns: The current text buffer content that hasn't been processed yet. """ - return self._text + return Aggregation(self._text, "sentence") - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate text while respecting tag boundaries. This method adds the new text to the buffer, processes any complete @@ -80,7 +80,7 @@ class SkipTagsAggregator(BaseTextAggregator): # Extract text up to the sentence boundary result = self._text[:eos_marker] self._text = self._text[eos_marker:] - return result + return Aggregation(result, "sentence") # No complete sentence found yet return None diff --git a/tests/test_transcript_processor.py b/tests/test_transcript_processor.py index 19366086c..be58f061e 100644 --- a/tests/test_transcript_processor.py +++ b/tests/test_transcript_processor.py @@ -130,11 +130,11 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), # Wait for StartedSpeaking to process - TTSTextFrame(text="Hello"), - TTSTextFrame(text="world!"), - TTSTextFrame(text="How"), - TTSTextFrame(text="are"), - TTSTextFrame(text="you?"), + TTSTextFrame(text="Hello", aggregated_by="word"), + TTSTextFrame(text="world!", aggregated_by="word"), + TTSTextFrame(text="How", aggregated_by="word"), + TTSTextFrame(text="are", aggregated_by="word"), + TTSTextFrame(text="you?", aggregated_by="word"), SleepFrame(), # Wait for text frames to queue BotStoppedSpeakingFrame(), ] @@ -195,9 +195,9 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text=""), # Empty text - TTSTextFrame(text=" "), # Just whitespace - TTSTextFrame(text="\n"), # Just newline + TTSTextFrame(text="", aggregated_by="word"), # Empty text + TTSTextFrame(text=" ", aggregated_by="word"), # Just whitespace + TTSTextFrame(text="\n", aggregated_by="word"), # Just newline BotStoppedSpeakingFrame(), # Pipeline ends here; run_test will automatically send EndFrame ] @@ -235,14 +235,14 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text="Hello"), - TTSTextFrame(text="world!"), + TTSTextFrame(text="Hello", aggregated_by="word"), + TTSTextFrame(text="world!", aggregated_by="word"), SleepFrame(), InterruptionFrame(), # User interrupts here SleepFrame(), BotStartedSpeakingFrame(), - TTSTextFrame(text="New"), - TTSTextFrame(text="response"), + TTSTextFrame(text="New", aggregated_by="word"), + TTSTextFrame(text="response", aggregated_by="word"), SleepFrame(), BotStoppedSpeakingFrame(), ] @@ -299,8 +299,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text="Hello"), - TTSTextFrame(text="world"), + TTSTextFrame(text="Hello", aggregated_by="word"), + TTSTextFrame(text="world", aggregated_by="word"), # Pipeline ends here; run_test will automatically send EndFrame ] @@ -338,8 +338,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text="Hello"), - TTSTextFrame(text="world"), + TTSTextFrame(text="Hello", aggregated_by="word"), + TTSTextFrame(text="world", aggregated_by="word"), SleepFrame(), # Ensure messages are processed CancelFrame(), ] @@ -401,8 +401,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send = [ BotStartedSpeakingFrame(), SleepFrame(), - TTSTextFrame(text="Assistant"), - TTSTextFrame(text="message"), + TTSTextFrame(text="Assistant", aggregated_by="word"), + TTSTextFrame(text="message", aggregated_by="word"), BotStoppedSpeakingFrame(), ] @@ -439,7 +439,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): # Test the specific pattern shared def make_tts_text_frame(text: str) -> TTSTextFrame: - frame = TTSTextFrame(text=text) + frame = TTSTextFrame(text=text, aggregated_by="word") frame.includes_inter_frame_spaces = True return frame