From cbccbcd9e7d6e4c5b1587e0eb8aac3f81174f19e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 19 May 2025 09:55:16 -0700 Subject: [PATCH] BaseTextFilter: make functions async --- CHANGELOG.md | 3 ++ .../server/bot_phone_local.py | 6 +-- .../server/bot_phone_twilio.py | 6 +-- src/pipecat/services/tts_service.py | 8 ++-- src/pipecat/utils/text/base_text_filter.py | 8 ++-- .../utils/text/markdown_text_filter.py | 8 ++-- tests/test_markdown_text_filter.py | 40 ++++++++++--------- 7 files changed, 43 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61f0e2c1c..d8db0d65d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- `BaseTextFilter` methods `filter()`, `update_settings()`, + `handle_interruption()` and `reset_interruption()` are now async. + - `BaseTextAggregator` methods `aggregate()`, `handle_interruption()` and `reset()` are now async. diff --git a/examples/word-wrangler-gemini-live/server/bot_phone_local.py b/examples/word-wrangler-gemini-live/server/bot_phone_local.py index b3fb791b9..84f15ea30 100644 --- a/examples/word-wrangler-gemini-live/server/bot_phone_local.py +++ b/examples/word-wrangler-gemini-live/server/bot_phone_local.py @@ -188,7 +188,7 @@ class HostResponseTextFilter(BaseTextFilter): # No settings to update for this filter pass - def filter(self, text: str) -> str: + async def filter(self, text: str) -> str: # Remove case and whitespace for comparison clean_text = text.strip().upper() @@ -198,10 +198,10 @@ class HostResponseTextFilter(BaseTextFilter): return text - def handle_interruption(self): + async def handle_interruption(self): self._interrupted = True - def reset_interruption(self): + async def reset_interruption(self): self._interrupted = False diff --git a/examples/word-wrangler-gemini-live/server/bot_phone_twilio.py b/examples/word-wrangler-gemini-live/server/bot_phone_twilio.py index ab2783f6b..2af8b2d4a 100644 --- a/examples/word-wrangler-gemini-live/server/bot_phone_twilio.py +++ b/examples/word-wrangler-gemini-live/server/bot_phone_twilio.py @@ -178,7 +178,7 @@ class HostResponseTextFilter(BaseTextFilter): # No settings to update for this filter pass - def filter(self, text: str) -> str: + async def filter(self, text: str) -> str: # Remove case and whitespace for comparison clean_text = text.strip().upper() @@ -188,10 +188,10 @@ class HostResponseTextFilter(BaseTextFilter): return text - def handle_interruption(self): + async def handle_interruption(self): self._interrupted = True - def reset_interruption(self): + async def reset_interruption(self): self._interrupted = False diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index 5e751067c..0bdcd0d1c 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -157,7 +157,7 @@ class TTSService(AIService): self.set_voice(value) elif key == "text_filter": for filter in self._text_filters: - filter.update_settings(value) + await filter.update_settings(value) else: logger.warning(f"Unknown setting for TTS service: {key}") @@ -236,7 +236,7 @@ class TTSService(AIService): self._processing_text = False await self._text_aggregator.handle_interruption() for filter in self._text_filters: - filter.handle_interruption() + await filter.handle_interruption() async def _maybe_pause_frame_processing(self): if self._processing_text and self._pause_frame_processing: @@ -274,8 +274,8 @@ class TTSService(AIService): # Process all filter. for filter in self._text_filters: - filter.reset_interruption() - text = filter.filter(text) + await filter.reset_interruption() + text = await filter.filter(text) if text: await self.process_generator(self.run_tts(text)) diff --git a/src/pipecat/utils/text/base_text_filter.py b/src/pipecat/utils/text/base_text_filter.py index 0bedb7d7f..787a1a9da 100644 --- a/src/pipecat/utils/text/base_text_filter.py +++ b/src/pipecat/utils/text/base_text_filter.py @@ -10,17 +10,17 @@ from typing import Any, Mapping class BaseTextFilter(ABC): @abstractmethod - def update_settings(self, settings: Mapping[str, Any]): + async def update_settings(self, settings: Mapping[str, Any]): pass @abstractmethod - def filter(self, text: str) -> str: + async def filter(self, text: str) -> str: pass @abstractmethod - def handle_interruption(self): + async def handle_interruption(self): pass @abstractmethod - def reset_interruption(self): + async def reset_interruption(self): pass diff --git a/src/pipecat/utils/text/markdown_text_filter.py b/src/pipecat/utils/text/markdown_text_filter.py index c5ace159e..6f5e16bd0 100644 --- a/src/pipecat/utils/text/markdown_text_filter.py +++ b/src/pipecat/utils/text/markdown_text_filter.py @@ -33,12 +33,12 @@ class MarkdownTextFilter(BaseTextFilter): self._in_table = False self._interrupted = False - def update_settings(self, settings: Mapping[str, Any]): + async def update_settings(self, settings: Mapping[str, Any]): for key, value in settings.items(): if hasattr(self._settings, key): setattr(self._settings, key, value) - def filter(self, text: str) -> str: + async def filter(self, text: str) -> str: if self._settings.enable_text_filter: # Remove newlines and replace with a space only when there's no text before or after filtered_text = re.sub(r"^\s*\n", " ", text, flags=re.MULTILINE) @@ -104,12 +104,12 @@ class MarkdownTextFilter(BaseTextFilter): else: return text - def handle_interruption(self): + async def handle_interruption(self): self._interrupted = True self._in_code_block = False self._in_table = False - def reset_interruption(self): + async def reset_interruption(self): self._interrupted = False # diff --git a/tests/test_markdown_text_filter.py b/tests/test_markdown_text_filter.py index c8cae97cc..a82a85811 100644 --- a/tests/test_markdown_text_filter.py +++ b/tests/test_markdown_text_filter.py @@ -30,7 +30,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): Some inline code here """ - result = self.filter.filter(input_text) + result = await self.filter.filter(input_text) self.assertEqual(result.strip(), expected_text.strip()) async def test_space_preservation(self): @@ -45,7 +45,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): ] for text in input_text: - result = self.filter.filter(text) + result = await self.filter.filter(text) self.assertEqual( len(result), len(text), f"Space preservation failed for: '{text}'\nGot: '{result}'" ) @@ -71,7 +71,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): } for input_text, expected in test_cases.items(): - result = self.filter.filter(input_text) + result = await self.filter.filter(input_text) self.assertEqual( result, expected, @@ -88,7 +88,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): 2. Second item 3. Third item with bold""" - result = self.filter.filter(input_text) + result = await self.filter.filter(input_text) self.assertEqual( result.strip(), expected.strip(), @@ -106,7 +106,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): } for input_text, expected in test_cases.items(): - result = self.filter.filter(input_text) + result = await self.filter.filter(input_text) self.assertEqual(result, expected, f"HTML entity conversion failed for: '{input_text}'") async def test_asterisk_removal(self): @@ -120,7 +120,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): } for input_text, expected in test_cases.items(): - result = self.filter.filter(input_text) + result = await self.filter.filter(input_text) self.assertEqual(result, expected, f"Asterisk removal failed for: '{input_text}'") async def test_newline_handling(self): @@ -132,7 +132,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): } for input_text, expected in test_cases.items(): - result = self.filter.filter(input_text) + result = await self.filter.filter(input_text) self.assertEqual( result, expected, f"Newline handling failed for:\n{input_text}\nGot:\n{result}" ) @@ -148,7 +148,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): } for input_text, expected in test_cases.items(): - result = self.filter.filter(input_text) + result = await self.filter.filter(input_text) self.assertEqual( result, expected, @@ -166,7 +166,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): } for input_text, expected in test_cases.items(): - result = self.filter.filter(input_text) + result = await self.filter.filter(input_text) self.assertEqual(result, expected, f"Inline code handling failed for: '{input_text}'") async def test_simple_table_removal(self): @@ -177,7 +177,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): expected = "" - result = filter.filter(input_text) + result = await filter.filter(input_text) self.assertEqual( result.strip(), expected.strip(), @@ -198,15 +198,15 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): # Test with text filtering disabled text_with_markdown = "**bold** and *italic* with `code`" self.assertEqual( - filter.filter(text_with_markdown), + await filter.filter(text_with_markdown), text_with_markdown, "Disabled filter should not modify text", ) # Enable just text filtering - filter.update_settings({"enable_text_filter": True}) + await filter.update_settings({"enable_text_filter": True}) self.assertEqual( - filter.filter(text_with_markdown), + await filter.filter(text_with_markdown), "bold and italic with code", "Enabled filter should remove markdown", ) @@ -217,14 +217,18 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase): # Initial state - formatting should be removed input_text = "**bold** and *italic*" - self.assertEqual(filter.filter(input_text), "bold and italic") + self.assertEqual(await filter.filter(input_text), "bold and italic") # Disable text filtering - filter.update_settings({"enable_text_filter": False}) - self.assertEqual(filter.filter(input_text), input_text, "Text filtering should be disabled") + await filter.update_settings({"enable_text_filter": False}) + self.assertEqual( + await filter.filter(input_text), input_text, "Text filtering should be disabled" + ) # Re-enable text filtering - filter.update_settings({"enable_text_filter": True}) + await filter.update_settings({"enable_text_filter": True}) self.assertEqual( - filter.filter(input_text), "bold and italic", "Text filtering should be re-enabled" + await filter.filter(input_text), + "bold and italic", + "Text filtering should be re-enabled", )