From 8a39d3f4eb5e0bb8b028149c9e405ae2c80bdc8f Mon Sep 17 00:00:00 2001 From: Sharvil Nanavati Date: Thu, 22 Aug 2024 05:28:02 +0000 Subject: [PATCH] services: add a generic mechanism to produce TTSStoppedFrames --- src/pipecat/services/ai_services.py | 28 +++++++++++++++++++++++++++- src/pipecat/services/lmnt.py | 4 ++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 197a73bac..47a7d39b1 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -8,7 +8,8 @@ import io import wave from abc import abstractmethod -from typing import AsyncGenerator +from asyncio import Task, sleep +from typing import AsyncGenerator, Optional from pipecat.frames.frames import ( AudioRawFrame, @@ -20,6 +21,8 @@ from pipecat.frames.frames import ( StartFrame, StartInterruptionFrame, TTSSpeakFrame, + TTSStartedFrame, + TTSStoppedFrame, TTSVoiceUpdateFrame, TextFrame, UserImageRequestFrame, @@ -156,10 +159,17 @@ class TTSService(AIService): aggregate_sentences: bool = True, # if True, subclass is responsible for pushing TextFrames and LLMFullResponseEndFrames push_text_frames: bool = True, + # if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it + push_stop_frames: bool = False, + # if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame + stop_frame_timeout_s: float = 0.8, **kwargs): super().__init__(**kwargs) self._aggregate_sentences: bool = aggregate_sentences self._push_text_frames: bool = push_text_frames + self._push_stop_frames: bool = push_stop_frames + self._stop_frame_timeout_s: float = stop_frame_timeout_s + self._stop_frame_task: Optional[Task] = None self._current_sentence: str = "" @abstractmethod @@ -227,6 +237,22 @@ class TTSService(AIService): else: await self.push_frame(frame, direction) + async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): + await super().push_frame(frame, direction) + + if isinstance(frame, AudioRawFrame) and self._stop_frame_task is not None: + # Reschedule timeout task if it was already running + self._stop_frame_task.cancel() + self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler()) + elif isinstance(frame, TTSStartedFrame) and self._push_stop_frames: + # Start timeout task if necessary + self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler()) + + async def _stop_frame_handler(self): + await sleep(self._stop_frame_timeout_s) + await self.push_frame(TTSStoppedFrame()) + self._stop_frame_task = None + class STTService(AIService): """STTService is a base class for speech-to-text services.""" diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index e66df24ee..74daaf017 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -49,6 +49,10 @@ class LmntTTSService(TTSService): **kwargs): super().__init__(**kwargs) + # Let TTSService produce TTSStoppedFrames after a short delay of + # no activity. + self._push_stop_frames = True + self._api_key = api_key self._voice_id = voice_id self._output_format = {