Compare commits

...

1 Commits

Author SHA1 Message Date
Mark Backman
65ce5d0457 Add SelfClosingTagAggregator, example 45, and unit tests 2025-09-10 14:54:37 -04:00
4 changed files with 458 additions and 0 deletions

View File

@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `SelfClosingTagAggregator` to handle incomplete self-closing tags during
streaming (e.g., prevents splitting `<break time="0.1s"/>` when received as
`<break time="0.` + `1s"/>`).
- Added video streaming support to `LiveKitTransport`.
- Added `OpenAIRealtimeLLMService` and `AzureRealtimeLLMService` which provide

View File

@@ -0,0 +1,122 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.text.self_closing_tag_aggregator import SelfClosingTagAggregator
load_dotenv(override=True)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
text_aggregator=SelfClosingTagAggregator(["break"]), # Handle Cartesia break tags
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
messages = [
{
"role": "system",
"content": """You are a helpful assistant in a voice call. Speak naturally and clearly. When sharing phone numbers, add natural pauses between number groups using <break time="0.2s"/> - for example: "You can reach us at 8-0-0<break time="0.2s"/>5-5-5<break time="0.2s"/>1-2-3-4". Use shorter pauses <break time="0.1s"/> when listing items or giving step-by-step instructions to help listeners follow along.""",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -0,0 +1,162 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Self-closing tag aggregator for handling XML-style self-closing tags.
This module provides a generic text aggregator that can handle any self-closing
XML-style tags (e.g., <break time="0.1s"/>, <pause duration="500ms"/>, etc.)
that should prevent sentence boundary detection when incomplete during streaming.
"""
import re
from typing import List, Optional
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
class SelfClosingTagAggregator(BaseTextAggregator):
r"""Aggregator that handles self-closing XML-style tags during streaming.
This aggregator is designed to handle any self-closing tags that might appear
in streaming text and could be split inappropriately when incomplete.
It prevents sentence boundary detection only when tags are incomplete.
The aggregator works by:
1. Detecting incomplete self-closing tags during streaming (e.g., '<break time="0.')
2. Buffering text until all tags are complete (e.g., '<break time="0.1s"/>')
3. Applying normal sentence boundary detection once all tags are complete
4. Supporting configurable tag patterns for different use cases
Example usage::
# For Cartesia break tags
aggregator = SelfClosingTagAggregator(['break'])
# For multiple tag types
aggregator = SelfClosingTagAggregator(['break', 'pause', 'emphasis'])
# For custom patterns
aggregator = SelfClosingTagAggregator(
patterns=[r'<break\\s+time="[^"]*"\\s*/?>', r'<pause\\s+duration="[^"]*"\\s*/>']
)
"""
def __init__(
self,
tags: Optional[List[str]] = None,
patterns: Optional[List[str]] = None,
):
"""Initialize the self-closing tag aggregator.
Args:
tags: List of tag names to handle (e.g., ['break', 'pause']).
Will generate patterns like <break .../>
patterns: List of custom regex patterns for complete tags.
Takes precedence over tags if provided.
Raises:
ValueError: If neither tags nor patterns are provided.
"""
self._text = ""
if patterns:
# Use custom patterns
self._complete_patterns = [re.compile(pattern) for pattern in patterns]
# Generate incomplete patterns from complete ones
self._incomplete_patterns = []
for pattern in patterns:
# Convert complete pattern to incomplete by making the closing part optional
# This is a simple heuristic - for complex patterns, users should provide both
incomplete = pattern.replace(r"\s*/?>", r"[^>]*$").replace(r"/>", r"[^>]*$")
self._incomplete_patterns.append(re.compile(incomplete))
elif tags:
# Generate patterns from tag names
self._complete_patterns = []
self._incomplete_patterns = []
for tag_name in tags:
# Pattern for complete self-closing tags: <tagname .../>
complete_pattern = rf"<{re.escape(tag_name)}\s+[^>]*\s*/?>"
self._complete_patterns.append(re.compile(complete_pattern))
# Pattern for incomplete tags: <tagname ... (without closing)
incomplete_pattern = rf"<{re.escape(tag_name)}\s+[^>]*$"
self._incomplete_patterns.append(re.compile(incomplete_pattern))
else:
raise ValueError("Must provide either 'tags' or 'patterns' parameter")
@property
def text(self) -> str:
"""Get the currently buffered text.
Returns:
The current text buffer content that hasn't been processed yet.
"""
return self._text
def _has_incomplete_tags(self, text: str) -> bool:
"""Check if the text ends with incomplete self-closing tags.
Args:
text: The text to check.
Returns:
True if the text ends with any incomplete tag patterns.
"""
for pattern in self._incomplete_patterns:
if pattern.search(text):
return True
return False
async def aggregate(self, text: str) -> Optional[str]:
"""Aggregate text while being aware of self-closing tags.
This method adds the new text to the buffer and checks for sentence
boundaries. If tags are incomplete, it continues buffering until
they are complete. Once all tags are complete, normal sentence
detection applies.
Args:
text: New text to add to the buffer.
Returns:
Processed text up to a sentence boundary, or None if incomplete
tags are present and more text is needed.
"""
# Add new text to buffer
self._text += text
# Check if we have incomplete tags - if so, keep buffering
if self._has_incomplete_tags(self._text):
return None
# No incomplete tags, use normal sentence detection
eos_marker = match_endofsentence(self._text)
if eos_marker:
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return result
# No sentence boundary found yet
return None
async def handle_interruption(self):
"""Handle interruptions by clearing the buffer.
Called when an interruption occurs in the processing pipeline,
to reset the state and discard any partially aggregated text.
"""
self._text = ""
async def reset(self):
"""Clear the internally aggregated text.
Resets the aggregator to its initial state, discarding any
buffered text.
"""
self._text = ""

View File

@@ -0,0 +1,170 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.utils.text.self_closing_tag_aggregator import SelfClosingTagAggregator
class TestSelfClosingTagAggregator(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.aggregator = SelfClosingTagAggregator(["break"])
async def test_no_tags(self):
await self.aggregator.reset()
# No tags involved, aggregate at end of sentence.
result = await self.aggregator.aggregate("Hello Pipecat!")
self.assertEqual(result, "Hello Pipecat!")
self.assertEqual(self.aggregator.text, "")
async def test_complete_tags(self):
await self.aggregator.reset()
# Complete tags, should aggregate normally.
result = await self.aggregator.aggregate('Call us at <break time="0.1s"/>now.')
self.assertEqual(result, 'Call us at <break time="0.1s"/>now.')
self.assertEqual(self.aggregator.text, "")
async def test_incomplete_tag_single_chunk(self):
await self.aggregator.reset()
# Incomplete tag in single chunk, should buffer.
result = await self.aggregator.aggregate('Hello <break time="0.')
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, 'Hello <break time="0.')
async def test_multiple_tag_types(self):
# Test with multiple tag types
multi_aggregator = SelfClosingTagAggregator(["break", "pause", "voice"])
await multi_aggregator.reset()
result = await multi_aggregator.aggregate('Hello <voice name="alice"/>world.')
self.assertEqual(result, 'Hello <voice name="alice"/>world.')
self.assertEqual(multi_aggregator.text, "")
# Test incomplete with multiple types
await multi_aggregator.reset()
result = await multi_aggregator.aggregate('Say <pause duration="500')
self.assertIsNone(result)
self.assertEqual(multi_aggregator.text, 'Say <pause duration="500')
result = await multi_aggregator.aggregate('ms"/>this slowly.')
self.assertEqual(result, 'Say <pause duration="500ms"/>this slowly.')
self.assertEqual(multi_aggregator.text, "")
async def test_custom_patterns(self):
# Test with custom regex patterns
pattern_aggregator = SelfClosingTagAggregator(
patterns=[r'<break\s+time="[^"]*"\s*/?>', r'<voice\s+name="[^"]*"\s*/>']
)
await pattern_aggregator.reset()
# Complete custom pattern
result = await pattern_aggregator.aggregate('Test <break time="1.5s"/> custom.')
self.assertEqual(result, 'Test <break time="1.5s"/> custom.')
self.assertEqual(pattern_aggregator.text, "")
# Incomplete custom pattern
await pattern_aggregator.reset()
result = await pattern_aggregator.aggregate('Hello <voice name="bob')
self.assertIsNone(result)
self.assertEqual(pattern_aggregator.text, 'Hello <voice name="bob')
result = await pattern_aggregator.aggregate('"/>there.')
self.assertEqual(result, 'Hello <voice name="bob"/>there.')
self.assertEqual(pattern_aggregator.text, "")
async def test_sentence_boundaries_with_complete_tags(self):
await self.aggregator.reset()
# Multiple sentences with complete tags should split appropriately
result = await self.aggregator.aggregate(
'First <break time="1s"/> sentence. Second sentence.'
)
self.assertEqual(result, 'First <break time="1s"/> sentence.')
self.assertEqual(self.aggregator.text, " Second sentence.")
# Adding empty string should trigger processing of remaining complete sentence
result = await self.aggregator.aggregate("")
self.assertEqual(result, " Second sentence.")
self.assertEqual(self.aggregator.text, "")
async def test_no_sentence_ending(self):
await self.aggregator.reset()
# Text without sentence ending should buffer
result = await self.aggregator.aggregate('Hello <break time="1s"/> world')
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, 'Hello <break time="1s"/> world')
async def test_initialization_errors(self):
# Test that initialization requires either tags or patterns
with self.assertRaises(ValueError) as context:
SelfClosingTagAggregator()
self.assertIn("Must provide either 'tags' or 'patterns' parameter", str(context.exception))
# Test that both tags and patterns work
tag_aggregator = SelfClosingTagAggregator(["test"])
self.assertIsNotNone(tag_aggregator)
pattern_aggregator = SelfClosingTagAggregator(patterns=[r"<test\s*/>"])
self.assertIsNotNone(pattern_aggregator)
async def test_handle_interruption(self):
await self.aggregator.reset()
# Add some text to buffer
result = await self.aggregator.aggregate('Buffered <break time="0.')
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, 'Buffered <break time="0.')
# Handle interruption should clear buffer
await self.aggregator.handle_interruption()
self.assertEqual(self.aggregator.text, "")
async def test_reset(self):
await self.aggregator.reset()
# Add some text to buffer
result = await self.aggregator.aggregate('Some <break time="0.')
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, 'Some <break time="0.')
# Reset should clear buffer
await self.aggregator.reset()
self.assertEqual(self.aggregator.text, "")
async def test_property_access(self):
await self.aggregator.reset()
# Test that text property works
self.assertEqual(self.aggregator.text, "")
await self.aggregator.aggregate('Test <break time="0.')
self.assertEqual(self.aggregator.text, 'Test <break time="0.')
async def test_malformed_tags_ignored(self):
await self.aggregator.reset()
# Malformed tags (not matching pattern) should be ignored
result = await self.aggregator.aggregate('Test <break_time="0.1s"/> normal.')
self.assertEqual(result, 'Test <break_time="0.1s"/> normal.')
self.assertEqual(self.aggregator.text, "")
async def test_edge_case_empty_strings(self):
await self.aggregator.reset()
# Empty string should not cause issues
result = await self.aggregator.aggregate("")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "")
# Add real content after empty string
result = await self.aggregator.aggregate("Hello world.")
self.assertEqual(result, "Hello world.")
self.assertEqual(self.aggregator.text, "")