introduce AudioContextWordTTSService

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-14 11:18:24 -08:00
parent af66a43056
commit cacb07f4c2
2 changed files with 114 additions and 2 deletions

View File

@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added new `AudioContextWordTTSService`. This is a TTS base class for TTS
services that handling multiple separate audio requests.
- Added new frames `EmulateUserStartedSpeakingFrame` and
`EmulateUserStoppedSpeakingFrame` which can be used to emulated VAD behavior
without VAD being present or not being triggered.

View File

@@ -419,7 +419,7 @@ class WordTTSService(TTSService):
async def start(self, frame: StartFrame):
await super().start(frame)
await self._create_words_task()
self._create_words_task()
async def stop(self, frame: EndFrame):
await super().stop(frame)
@@ -439,7 +439,7 @@ class WordTTSService(TTSService):
await super()._handle_interruption(frame, direction)
self.reset_word_timestamps()
async def _create_words_task(self):
def _create_words_task(self):
self._words_task = self.create_task(self._words_task_handler())
async def _stop_words_task(self):
@@ -469,6 +469,115 @@ class WordTTSService(TTSService):
self._words_queue.task_done()
class AudioContextWordTTSService(WordTTSService):
"""This services allow us to send multiple TTS request to the services. Each
request could be multiple sentences long which are grouped by context. For
this to work, the TTS service needs to support handling multiple requests at
once (i.e. multiple simultaneous contexts).
The audio received from the TTS will be played in context order. That is, if
we requested audio for a context "A" and then audio for context "B", the
audio from context ID "A" will be played first.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._contexts_queue = asyncio.Queue()
self._contexts: Dict[str, asyncio.Queue] = {}
self._audio_context_task = None
async def create_audio_context(self, context_id: str):
"""Create a new audio context."""
await self._contexts_queue.put(context_id)
self._contexts[context_id] = asyncio.Queue()
logger.trace(f"{self} created audio context {context_id}")
async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame):
"""Append audio to an existing context."""
if self.audio_context_available(context_id):
logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
await self._contexts[context_id].put(frame)
else:
logger.warning(f"{self} unable to append audio to context {context_id}")
async def remove_audio_context(self, context_id: str):
"""Remove an existing audio context."""
if self.audio_context_available(context_id):
# We just mark the audio context for deletion by appending
# None. Once we reach None while handling audio we know we can
# safely remove the context.
logger.trace(f"{self} marking audio context {context_id} for deletion")
await self._contexts[context_id].put(None)
else:
logger.warning(f"{self} unable to remove context {context_id}")
def audio_context_available(self, context_id: str) -> bool:
"""Checks whether the given audio context is registered."""
return context_id in self._contexts
async def start(self, frame: StartFrame):
await super().start(frame)
self._create_audio_context_task()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._stop_audio_context_task()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._stop_audio_context_task()
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self._stop_audio_context_task()
self._create_audio_context_task()
def _create_audio_context_task(self):
self._contexts_queue = asyncio.Queue()
self._contexts: Dict[str, asyncio.Queue] = {}
self._audio_context_task = self.create_task(self._audio_context_task_handler())
async def _stop_audio_context_task(self):
if self._audio_context_task:
await self.cancel_task(self._audio_context_task)
self._audio_context_task = None
async def _audio_context_task_handler(self):
"""In this task we process audio contexts in order."""
while True:
context_id = await self._contexts_queue.get()
# Process the audio context until the context doesn't have more
# audio available (i.e. we find None).
await self._handle_audio_context(context_id)
# We just finished processing the context, so we can safely remove it.
del self._contexts[context_id]
self._contexts_queue.task_done()
# Append some silence between sentences.
silence = b"\x00" * self.sample_rate
frame = TTSAudioRawFrame(audio=silence, sample_rate=self.sample_rate, num_channels=1)
await self.push_frame(frame)
async def _handle_audio_context(self, context_id: str):
# If we don't receive any audio during this time, we consider the context finished.
AUDIO_CONTEXT_TIMEOUT = 3.0
queue = self._contexts[context_id]
running = True
while running:
try:
frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
if frame:
await self.push_frame(frame)
running = frame is not None
except asyncio.TimeoutError:
# We didn't get audio, so let's consider this context finished.
logger.trace(f"{self} time out on audio context {context_id}")
break
class STTService(AIService):
"""STTService is a base class for speech-to-text services."""