BaseTextFilter: make functions async
This commit is contained in:
@@ -68,6 +68,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- `BaseTextFilter` methods `filter()`, `update_settings()`,
|
||||
`handle_interruption()` and `reset_interruption()` are now async.
|
||||
|
||||
- `BaseTextAggregator` methods `aggregate()`, `handle_interruption()` and
|
||||
`reset()` are now async.
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ class HostResponseTextFilter(BaseTextFilter):
|
||||
# No settings to update for this filter
|
||||
pass
|
||||
|
||||
def filter(self, text: str) -> str:
|
||||
async def filter(self, text: str) -> str:
|
||||
# Remove case and whitespace for comparison
|
||||
clean_text = text.strip().upper()
|
||||
|
||||
@@ -198,10 +198,10 @@ class HostResponseTextFilter(BaseTextFilter):
|
||||
|
||||
return text
|
||||
|
||||
def handle_interruption(self):
|
||||
async def handle_interruption(self):
|
||||
self._interrupted = True
|
||||
|
||||
def reset_interruption(self):
|
||||
async def reset_interruption(self):
|
||||
self._interrupted = False
|
||||
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ class HostResponseTextFilter(BaseTextFilter):
|
||||
# No settings to update for this filter
|
||||
pass
|
||||
|
||||
def filter(self, text: str) -> str:
|
||||
async def filter(self, text: str) -> str:
|
||||
# Remove case and whitespace for comparison
|
||||
clean_text = text.strip().upper()
|
||||
|
||||
@@ -188,10 +188,10 @@ class HostResponseTextFilter(BaseTextFilter):
|
||||
|
||||
return text
|
||||
|
||||
def handle_interruption(self):
|
||||
async def handle_interruption(self):
|
||||
self._interrupted = True
|
||||
|
||||
def reset_interruption(self):
|
||||
async def reset_interruption(self):
|
||||
self._interrupted = False
|
||||
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ class TTSService(AIService):
|
||||
self.set_voice(value)
|
||||
elif key == "text_filter":
|
||||
for filter in self._text_filters:
|
||||
filter.update_settings(value)
|
||||
await filter.update_settings(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for TTS service: {key}")
|
||||
|
||||
@@ -236,7 +236,7 @@ class TTSService(AIService):
|
||||
self._processing_text = False
|
||||
await self._text_aggregator.handle_interruption()
|
||||
for filter in self._text_filters:
|
||||
filter.handle_interruption()
|
||||
await filter.handle_interruption()
|
||||
|
||||
async def _maybe_pause_frame_processing(self):
|
||||
if self._processing_text and self._pause_frame_processing:
|
||||
@@ -274,8 +274,8 @@ class TTSService(AIService):
|
||||
|
||||
# Process all filter.
|
||||
for filter in self._text_filters:
|
||||
filter.reset_interruption()
|
||||
text = filter.filter(text)
|
||||
await filter.reset_interruption()
|
||||
text = await filter.filter(text)
|
||||
|
||||
if text:
|
||||
await self.process_generator(self.run_tts(text))
|
||||
|
||||
@@ -10,17 +10,17 @@ from typing import Any, Mapping
|
||||
|
||||
class BaseTextFilter(ABC):
|
||||
@abstractmethod
|
||||
def update_settings(self, settings: Mapping[str, Any]):
|
||||
async def update_settings(self, settings: Mapping[str, Any]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def filter(self, text: str) -> str:
|
||||
async def filter(self, text: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_interruption(self):
|
||||
async def handle_interruption(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_interruption(self):
|
||||
async def reset_interruption(self):
|
||||
pass
|
||||
|
||||
@@ -33,12 +33,12 @@ class MarkdownTextFilter(BaseTextFilter):
|
||||
self._in_table = False
|
||||
self._interrupted = False
|
||||
|
||||
def update_settings(self, settings: Mapping[str, Any]):
|
||||
async def update_settings(self, settings: Mapping[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if hasattr(self._settings, key):
|
||||
setattr(self._settings, key, value)
|
||||
|
||||
def filter(self, text: str) -> str:
|
||||
async def filter(self, text: str) -> str:
|
||||
if self._settings.enable_text_filter:
|
||||
# Remove newlines and replace with a space only when there's no text before or after
|
||||
filtered_text = re.sub(r"^\s*\n", " ", text, flags=re.MULTILINE)
|
||||
@@ -104,12 +104,12 @@ class MarkdownTextFilter(BaseTextFilter):
|
||||
else:
|
||||
return text
|
||||
|
||||
def handle_interruption(self):
|
||||
async def handle_interruption(self):
|
||||
self._interrupted = True
|
||||
self._in_code_block = False
|
||||
self._in_table = False
|
||||
|
||||
def reset_interruption(self):
|
||||
async def reset_interruption(self):
|
||||
self._interrupted = False
|
||||
|
||||
#
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
Some inline code here
|
||||
"""
|
||||
|
||||
result = self.filter.filter(input_text)
|
||||
result = await self.filter.filter(input_text)
|
||||
self.assertEqual(result.strip(), expected_text.strip())
|
||||
|
||||
async def test_space_preservation(self):
|
||||
@@ -45,7 +45,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
|
||||
for text in input_text:
|
||||
result = self.filter.filter(text)
|
||||
result = await self.filter.filter(text)
|
||||
self.assertEqual(
|
||||
len(result), len(text), f"Space preservation failed for: '{text}'\nGot: '{result}'"
|
||||
)
|
||||
@@ -71,7 +71,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
}
|
||||
|
||||
for input_text, expected in test_cases.items():
|
||||
result = self.filter.filter(input_text)
|
||||
result = await self.filter.filter(input_text)
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected,
|
||||
@@ -88,7 +88,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
2. Second item
|
||||
3. Third item with bold"""
|
||||
|
||||
result = self.filter.filter(input_text)
|
||||
result = await self.filter.filter(input_text)
|
||||
self.assertEqual(
|
||||
result.strip(),
|
||||
expected.strip(),
|
||||
@@ -106,7 +106,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
}
|
||||
|
||||
for input_text, expected in test_cases.items():
|
||||
result = self.filter.filter(input_text)
|
||||
result = await self.filter.filter(input_text)
|
||||
self.assertEqual(result, expected, f"HTML entity conversion failed for: '{input_text}'")
|
||||
|
||||
async def test_asterisk_removal(self):
|
||||
@@ -120,7 +120,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
}
|
||||
|
||||
for input_text, expected in test_cases.items():
|
||||
result = self.filter.filter(input_text)
|
||||
result = await self.filter.filter(input_text)
|
||||
self.assertEqual(result, expected, f"Asterisk removal failed for: '{input_text}'")
|
||||
|
||||
async def test_newline_handling(self):
|
||||
@@ -132,7 +132,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
}
|
||||
|
||||
for input_text, expected in test_cases.items():
|
||||
result = self.filter.filter(input_text)
|
||||
result = await self.filter.filter(input_text)
|
||||
self.assertEqual(
|
||||
result, expected, f"Newline handling failed for:\n{input_text}\nGot:\n{result}"
|
||||
)
|
||||
@@ -148,7 +148,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
}
|
||||
|
||||
for input_text, expected in test_cases.items():
|
||||
result = self.filter.filter(input_text)
|
||||
result = await self.filter.filter(input_text)
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected,
|
||||
@@ -166,7 +166,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
}
|
||||
|
||||
for input_text, expected in test_cases.items():
|
||||
result = self.filter.filter(input_text)
|
||||
result = await self.filter.filter(input_text)
|
||||
self.assertEqual(result, expected, f"Inline code handling failed for: '{input_text}'")
|
||||
|
||||
async def test_simple_table_removal(self):
|
||||
@@ -177,7 +177,7 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
expected = ""
|
||||
|
||||
result = filter.filter(input_text)
|
||||
result = await filter.filter(input_text)
|
||||
self.assertEqual(
|
||||
result.strip(),
|
||||
expected.strip(),
|
||||
@@ -198,15 +198,15 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
# Test with text filtering disabled
|
||||
text_with_markdown = "**bold** and *italic* with `code`"
|
||||
self.assertEqual(
|
||||
filter.filter(text_with_markdown),
|
||||
await filter.filter(text_with_markdown),
|
||||
text_with_markdown,
|
||||
"Disabled filter should not modify text",
|
||||
)
|
||||
|
||||
# Enable just text filtering
|
||||
filter.update_settings({"enable_text_filter": True})
|
||||
await filter.update_settings({"enable_text_filter": True})
|
||||
self.assertEqual(
|
||||
filter.filter(text_with_markdown),
|
||||
await filter.filter(text_with_markdown),
|
||||
"bold and italic with code",
|
||||
"Enabled filter should remove markdown",
|
||||
)
|
||||
@@ -217,14 +217,18 @@ class TestMarkdownTextFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Initial state - formatting should be removed
|
||||
input_text = "**bold** and *italic*"
|
||||
self.assertEqual(filter.filter(input_text), "bold and italic")
|
||||
self.assertEqual(await filter.filter(input_text), "bold and italic")
|
||||
|
||||
# Disable text filtering
|
||||
filter.update_settings({"enable_text_filter": False})
|
||||
self.assertEqual(filter.filter(input_text), input_text, "Text filtering should be disabled")
|
||||
await filter.update_settings({"enable_text_filter": False})
|
||||
self.assertEqual(
|
||||
await filter.filter(input_text), input_text, "Text filtering should be disabled"
|
||||
)
|
||||
|
||||
# Re-enable text filtering
|
||||
filter.update_settings({"enable_text_filter": True})
|
||||
await filter.update_settings({"enable_text_filter": True})
|
||||
self.assertEqual(
|
||||
filter.filter(input_text), "bold and italic", "Text filtering should be re-enabled"
|
||||
await filter.filter(input_text),
|
||||
"bold and italic",
|
||||
"Text filtering should be re-enabled",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user