BaseTextFilter: make functions async

This commit is contained in:
Aleix Conchillo Flaqué
2025-05-19 09:55:16 -07:00
parent 54b1d7fcc1
commit cbccbcd9e7
7 changed files with 43 additions and 36 deletions

View File

@@ -68,6 +68,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- `BaseTextFilter` methods `filter()`, `update_settings()`,
`handle_interruption()` and `reset_interruption()` are now async.
- `BaseTextAggregator` methods `aggregate()`, `handle_interruption()` and
`reset()` are now async.

View File

@@ -188,7 +188,7 @@ class HostResponseTextFilter(BaseTextFilter):
# No settings to update for this filter
pass
def filter(self, text: str) -> str:
async def filter(self, text: str) -> str:
# Remove case and whitespace for comparison
clean_text = text.strip().upper()
@@ -198,10 +198,10 @@ class HostResponseTextFilter(BaseTextFilter):
return text
def handle_interruption(self):
async def handle_interruption(self):
self._interrupted = True
def reset_interruption(self):
async def reset_interruption(self):
self._interrupted = False

View File

@@ -178,7 +178,7 @@ class HostResponseTextFilter(BaseTextFilter):
# No settings to update for this filter
pass
def filter(self, text: str) -> str:
async def filter(self, text: str) -> str:
# Remove case and whitespace for comparison
clean_text = text.strip().upper()
@@ -188,10 +188,10 @@ class HostResponseTextFilter(BaseTextFilter):
return text
def handle_interruption(self):
async def handle_interruption(self):
self._interrupted = True
def reset_interruption(self):
async def reset_interruption(self):
self._interrupted = False

View File

@@ -157,7 +157,7 @@ class TTSService(AIService):
self.set_voice(value)
elif key == "text_filter":
for filter in self._text_filters:
filter.update_settings(value)
await filter.update_settings(value)
else:
logger.warning(f"Unknown setting for TTS service: {key}")
@@ -236,7 +236,7 @@ class TTSService(AIService):
self._processing_text = False
await self._text_aggregator.handle_interruption()
for filter in self._text_filters:
filter.handle_interruption()
await filter.handle_interruption()
async def _maybe_pause_frame_processing(self):
if self._processing_text and self._pause_frame_processing:
@@ -274,8 +274,8 @@ class TTSService(AIService):
# Process all filter.
for filter in self._text_filters:
filter.reset_interruption()
text = filter.filter(text)
await filter.reset_interruption()
text = await filter.filter(text)
if text:
await self.process_generator(self.run_tts(text))

View File

@@ -10,17 +10,17 @@ from typing import Any, Mapping
class BaseTextFilter(ABC):
@abstractmethod
def update_settings(self, settings: Mapping[str, Any]):
async def update_settings(self, settings: Mapping[str, Any]):
pass
@abstractmethod
def filter(self, text: str) -> str:
async def filter(self, text: str) -> str:
pass
@abstractmethod
def handle_interruption(self):
async def handle_interruption(self):
pass
@abstractmethod
def reset_interruption(self):
async def reset_interruption(self):
pass

View File

@@ -33,12 +33,12 @@ class MarkdownTextFilter(BaseTextFilter):
self._in_table = False
self._interrupted = False
def update_settings(self, settings: Mapping[str, Any]):
async def update_settings(self, settings: Mapping[str, Any]):
for key, value in settings.items():
if hasattr(self._settings, key):
setattr(self._settings, key, value)
def filter(self, text: str) -> str:
async def filter(self, text: str) -> str:
if self._settings.enable_text_filter:
# Remove newlines and replace with a space only when there's no text before or after
filtered_text = re.sub(r"^\s*\n", " ", text, flags=re.MULTILINE)
@@ -104,12 +104,12 @@ class MarkdownTextFilter(BaseTextFilter):
else:
return text
def handle_interruption(self):
async def handle_interruption(self):
self._interrupted = True
self._in_code_block = False
self._in_table = False
def reset_interruption(self):
async def reset_interruption(self):
self._interrupted = False
#

View File

@@ -30,7 +30,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
Some inline code here
"""
result = self.filter.filter(input_text)
result = await self.filter.filter(input_text)
self.assertEqual(result.strip(), expected_text.strip())
async def test_space_preservation(self):
@@ -45,7 +45,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
]
for text in input_text:
result = self.filter.filter(text)
result = await self.filter.filter(text)
self.assertEqual(
len(result), len(text), f"Space preservation failed for: '{text}'\nGot: '{result}'"
)
@@ -71,7 +71,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
}
for input_text, expected in test_cases.items():
result = self.filter.filter(input_text)
result = await self.filter.filter(input_text)
self.assertEqual(
result,
expected,
@@ -88,7 +88,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
2. Second item
3. Third item with bold"""
result = self.filter.filter(input_text)
result = await self.filter.filter(input_text)
self.assertEqual(
result.strip(),
expected.strip(),
@@ -106,7 +106,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
}
for input_text, expected in test_cases.items():
result = self.filter.filter(input_text)
result = await self.filter.filter(input_text)
self.assertEqual(result, expected, f"HTML entity conversion failed for: '{input_text}'")
async def test_asterisk_removal(self):
@@ -120,7 +120,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
}
for input_text, expected in test_cases.items():
result = self.filter.filter(input_text)
result = await self.filter.filter(input_text)
self.assertEqual(result, expected, f"Asterisk removal failed for: '{input_text}'")
async def test_newline_handling(self):
@@ -132,7 +132,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
}
for input_text, expected in test_cases.items():
result = self.filter.filter(input_text)
result = await self.filter.filter(input_text)
self.assertEqual(
result, expected, f"Newline handling failed for:\n{input_text}\nGot:\n{result}"
)
@@ -148,7 +148,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
}
for input_text, expected in test_cases.items():
result = self.filter.filter(input_text)
result = await self.filter.filter(input_text)
self.assertEqual(
result,
expected,
@@ -166,7 +166,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
}
for input_text, expected in test_cases.items():
result = self.filter.filter(input_text)
result = await self.filter.filter(input_text)
self.assertEqual(result, expected, f"Inline code handling failed for: '{input_text}'")
async def test_simple_table_removal(self):
@@ -177,7 +177,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
expected = ""
result = filter.filter(input_text)
result = await filter.filter(input_text)
self.assertEqual(
result.strip(),
expected.strip(),
@@ -198,15 +198,15 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
# Test with text filtering disabled
text_with_markdown = "**bold** and *italic* with `code`"
self.assertEqual(
filter.filter(text_with_markdown),
await filter.filter(text_with_markdown),
text_with_markdown,
"Disabled filter should not modify text",
)
# Enable just text filtering
filter.update_settings({"enable_text_filter": True})
await filter.update_settings({"enable_text_filter": True})
self.assertEqual(
filter.filter(text_with_markdown),
await filter.filter(text_with_markdown),
"bold and italic with code",
"Enabled filter should remove markdown",
)
@@ -217,14 +217,18 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
# Initial state - formatting should be removed
input_text = "**bold** and *italic*"
self.assertEqual(filter.filter(input_text), "bold and italic")
self.assertEqual(await filter.filter(input_text), "bold and italic")
# Disable text filtering
filter.update_settings({"enable_text_filter": False})
self.assertEqual(filter.filter(input_text), input_text, "Text filtering should be disabled")
await filter.update_settings({"enable_text_filter": False})
self.assertEqual(
await filter.filter(input_text), input_text, "Text filtering should be disabled"
)
# Re-enable text filtering
filter.update_settings({"enable_text_filter": True})
await filter.update_settings({"enable_text_filter": True})
self.assertEqual(
filter.filter(input_text), "bold and italic", "Text filtering should be re-enabled"
await filter.filter(input_text),
"bold and italic",
"Text filtering should be re-enabled",
)