From cacb07f4c209869314213885acbb39d48d051f4e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Fri, 14 Feb 2025 11:18:24 -0800
Subject: [PATCH] introduce AudioContextWordTTSService

---
 CHANGELOG.md                        |   3 +
 src/pipecat/services/ai_services.py | 113 +++++++++++++++++++++++++++-
 2 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 90c80da6e..85848f22b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added new `AudioContextWordTTSService`. This is a TTS base class for TTS
+  services that handle multiple separate audio requests.
+
 - Added new frames `EmulateUserStartedSpeakingFrame` and
   `EmulateUserStoppedSpeakingFrame` which can be used to emulated VAD behavior
   without VAD being present or not being triggered.
diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py
index ac1c4582d..0fce2482a 100644
--- a/src/pipecat/services/ai_services.py
+++ b/src/pipecat/services/ai_services.py
@@ -419,7 +419,7 @@ class WordTTSService(TTSService):
 
     async def start(self, frame: StartFrame):
         await super().start(frame)
-        await self._create_words_task()
+        self._create_words_task()
 
     async def stop(self, frame: EndFrame):
         await super().stop(frame)
@@ -439,7 +439,7 @@ class WordTTSService(TTSService):
         await super()._handle_interruption(frame, direction)
         self.reset_word_timestamps()
 
-    async def _create_words_task(self):
+    def _create_words_task(self):
         self._words_task = self.create_task(self._words_task_handler())
 
     async def _stop_words_task(self):
@@ -469,6 +469,115 @@ class WordTTSService(TTSService):
             self._words_queue.task_done()
 
 
+class AudioContextWordTTSService(WordTTSService):
+    """This service allows us to send multiple TTS requests to the service. Each
+    request could be multiple sentences long which are grouped by context. For
+    this to work, the TTS service needs to support handling multiple requests at
+    once (i.e. multiple simultaneous contexts).
+
+    The audio received from the TTS will be played in context order. That is, if
+    we requested audio for a context "A" and then audio for context "B", the
+    audio from context ID "A" will be played first.
+
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._contexts_queue = asyncio.Queue()
+        self._contexts: Dict[str, asyncio.Queue] = {}
+        self._audio_context_task = None
+
+    async def create_audio_context(self, context_id: str):
+        """Create a new audio context."""
+        await self._contexts_queue.put(context_id)
+        self._contexts[context_id] = asyncio.Queue()
+        logger.trace(f"{self} created audio context {context_id}")
+
+    async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame):
+        """Append audio to an existing context."""
+        if self.audio_context_available(context_id):
+            logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
+            await self._contexts[context_id].put(frame)
+        else:
+            logger.warning(f"{self} unable to append audio to context {context_id}")
+
+    async def remove_audio_context(self, context_id: str):
+        """Remove an existing audio context."""
+        if self.audio_context_available(context_id):
+            # We just mark the audio context for deletion by appending
+            # None. Once we reach None while handling audio we know we can
+            # safely remove the context.
+            logger.trace(f"{self} marking audio context {context_id} for deletion")
+            await self._contexts[context_id].put(None)
+        else:
+            logger.warning(f"{self} unable to remove context {context_id}")
+
+    def audio_context_available(self, context_id: str) -> bool:
+        """Checks whether the given audio context is registered."""
+        return context_id in self._contexts
+
+    async def start(self, frame: StartFrame):
+        await super().start(frame)
+        self._create_audio_context_task()
+
+    async def stop(self, frame: EndFrame):
+        await super().stop(frame)
+        await self._stop_audio_context_task()
+
+    async def cancel(self, frame: CancelFrame):
+        await super().cancel(frame)
+        await self._stop_audio_context_task()
+
+    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
+        await super()._handle_interruption(frame, direction)
+        await self._stop_audio_context_task()
+        self._create_audio_context_task()
+
+    def _create_audio_context_task(self):
+        self._contexts_queue = asyncio.Queue()
+        self._contexts: Dict[str, asyncio.Queue] = {}
+        self._audio_context_task = self.create_task(self._audio_context_task_handler())
+
+    async def _stop_audio_context_task(self):
+        if self._audio_context_task:
+            await self.cancel_task(self._audio_context_task)
+            self._audio_context_task = None
+
+    async def _audio_context_task_handler(self):
+        """In this task we process audio contexts in order."""
+        while True:
+            context_id = await self._contexts_queue.get()
+
+            # Process the audio context until the context doesn't have more
+            # audio available (i.e. we find None).
+            await self._handle_audio_context(context_id)
+
+            # We just finished processing the context, so we can safely remove it.
+            del self._contexts[context_id]
+            self._contexts_queue.task_done()
+
+            # Append some silence after each context.
+            silence = b"\x00" * self.sample_rate
+            frame = TTSAudioRawFrame(audio=silence, sample_rate=self.sample_rate, num_channels=1)
+            await self.push_frame(frame)
+
+    async def _handle_audio_context(self, context_id: str):
+        # If we don't receive any audio during this time, we consider the context finished.
+        AUDIO_CONTEXT_TIMEOUT = 3.0
+        queue = self._contexts[context_id]
+        running = True
+        while running:
+            try:
+                frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
+                if frame:
+                    await self.push_frame(frame)
+                running = frame is not None
+            except asyncio.TimeoutError:
+                # We didn't get audio, so let's consider this context finished.
+                logger.trace(f"{self} time out on audio context {context_id}")
+                break
+
+
 class STTService(AIService):
     """STTService is a base class for speech-to-text services."""
 