services: add a generic mechanism to produce TTSStoppedFrames

This commit is contained in:
Sharvil Nanavati
2024-08-22 05:28:02 +00:00
parent f4fd7b7028
commit 8a39d3f4eb
2 changed files with 31 additions and 1 deletions

View File

@@ -8,7 +8,8 @@ import io
import wave
from abc import abstractmethod
from typing import AsyncGenerator
from asyncio import Task, sleep
from typing import AsyncGenerator, Optional
from pipecat.frames.frames import (
AudioRawFrame,
@@ -20,6 +21,8 @@ from pipecat.frames.frames import (
StartFrame,
StartInterruptionFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSVoiceUpdateFrame,
TextFrame,
UserImageRequestFrame,
@@ -156,10 +159,17 @@ class TTSService(AIService):
aggregate_sentences: bool = True,
# if True, subclass is responsible for pushing TextFrames and LLMFullResponseEndFrames
push_text_frames: bool = True,
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
push_stop_frames: bool = False,
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
stop_frame_timeout_s: float = 0.8,
**kwargs):
super().__init__(**kwargs)
self._aggregate_sentences: bool = aggregate_sentences
self._push_text_frames: bool = push_text_frames
self._push_stop_frames: bool = push_stop_frames
self._stop_frame_timeout_s: float = stop_frame_timeout_s
self._stop_frame_task: Optional[Task] = None
self._current_sentence: str = ""
@abstractmethod
@@ -227,6 +237,22 @@ class TTSService(AIService):
else:
await self.push_frame(frame, direction)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
await super().push_frame(frame, direction)
if isinstance(frame, AudioRawFrame) and self._stop_frame_task is not None:
# Reschedule timeout task if it was already running
self._stop_frame_task.cancel()
self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler())
elif isinstance(frame, TTSStartedFrame) and self._push_stop_frames:
# Start timeout task if necessary
self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler())
async def _stop_frame_handler(self):
await sleep(self._stop_frame_timeout_s)
await self.push_frame(TTSStoppedFrame())
self._stop_frame_task = None
class STTService(AIService):
"""STTService is a base class for speech-to-text services."""

View File

@@ -49,6 +49,10 @@ class LmntTTSService(TTSService):
**kwargs):
super().__init__(**kwargs)
# Let TTSService produce TTSStoppedFrames after a short delay of
# no activity.
self._push_stop_frames = True
self._api_key = api_key
self._voice_id = voice_id
self._output_format = {