Compare commits

...

1 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
dcf7e454f6 services(tts): don't wait for BotStoppedSpeakingFrame to resume input queue 2024-12-02 19:00:47 -08:00
3 changed files with 31 additions and 15 deletions

View File

@@ -14,7 +14,6 @@ from loguru import logger
from pydantic.main import BaseModel
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
@@ -259,18 +258,25 @@ class CartesiaTTSService(WordTTSService):
except Exception as e:
logger.error(f"{self} exception: {e}")
# Push `frame` through the parent class's push_frame(), then resume this
# service's frame processing if the frame just pushed was an
# LLMFullResponseEndFrame.
#
# Args:
#     frame: The frame to push.
#     direction: Direction to push in; defaults to FrameDirection.DOWNSTREAM.
#
# NOTE(review): this text looks like a rendered commit diff with indentation
# and +/- markers stripped -- confirm the real nesting against the source file.
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
# The frame is always forwarded first; the resume check happens afterwards.
await super().push_frame(frame, direction)
# We generate LLMFullResponseEndFrame after we have received all the
# audio from the service which means we can resume processing frames.
if isinstance(frame, LLMFullResponseEndFrame):
# Counterpart of the pause_processing_frames() calls made in process_frame().
await self.resume_processing_frames()
# Let the parent class process `frame`, then pause or resume this service's
# frame processing depending on the frame type.
#
# Args:
#     frame: The incoming frame.
#     direction: The direction the frame is travelling in.
#
# NOTE(review): this span looks like a diff hunk with +/- markers stripped, so
# old and new lines are interleaved. The comment lines below come in
# near-duplicate pairs (old wording, then new wording) -- confirm which lines
# survive in the real file before relying on this text.
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# If we received a TTSSpeakFrame and the LLM response included text (it
# If we received a TTSSpeakFrame or the LLM response included text (it
# might be that it's only a function calling response) we pause
# processing more frames until we receive a BotStoppedSpeakingFrame.
# processing more frames until we have generated LLMFullResponseEndFrame
# (see push_frame()).
if isinstance(frame, TTSSpeakFrame):
await self.pause_processing_frames()
# Only pause on the end-of-response frame when self._context_id is truthy;
# presumably that means a TTS context is active -- verify where it is set.
elif isinstance(frame, LLMFullResponseEndFrame) and self._context_id:
await self.pause_processing_frames()
# NOTE(review): the commit message says resuming should no longer wait for
# BotStoppedSpeakingFrame, so this branch is presumably a removed (pre-change)
# line of the diff -- confirm.
elif isinstance(frame, BotStoppedSpeakingFrame):
await self.resume_processing_frames()
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")

View File

@@ -13,7 +13,6 @@ from loguru import logger
from pydantic import BaseModel, model_validator
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
Frame,
@@ -262,23 +261,28 @@ class ElevenLabsTTSService(WordTTSService):
# Push `frame` through the parent class's push_frame(), then update local
# state based on the type of the frame that was just pushed:
#   - TTSStoppedFrame or StartInterruptionFrame: clear self._started.
#   - TTSStoppedFrame: pass two (text, 0) entries to add_word_timestamps().
#   - LLMFullResponseEndFrame: resume this service's frame processing.
#
# Args:
#     frame: The frame to push.
#     direction: Direction to push in; defaults to FrameDirection.DOWNSTREAM.
#
# NOTE(review): this text looks like a rendered commit diff with indentation
# stripped, so whether the TTSStoppedFrame check is nested inside the first
# `if` cannot be told from here -- confirm against the source file.
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
# self._started is the same flag process_frame() reads to decide whether to pause.
self._started = False
if isinstance(frame, TTSStoppedFrame):
# Presumably "LLMFullResponseEndFrame" and "Reset" are marker entries rather
# than spoken words -- verify against add_word_timestamps().
await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)])
# We generate LLMFullResponseEndFrame after we have received all the
# audio from the service which means we can resume processing frames.
if isinstance(frame, LLMFullResponseEndFrame):
await self.resume_processing_frames()
# Let the parent class process `frame`, then pause or resume this service's
# frame processing depending on the frame type.
#
# Args:
#     frame: The incoming frame.
#     direction: The direction the frame is travelling in.
#
# NOTE(review): this span looks like a diff hunk with +/- markers stripped, so
# old and new lines are interleaved. The comment lines below come in
# near-duplicate pairs (old wording, then new wording) -- confirm which lines
# survive in the real file before relying on this text.
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# If we received a TTSSpeakFrame and the LLM response included text (it
# If we received a TTSSpeakFrame or the LLM response included text (it
# might be that it's only a function calling response) we pause
# processing more frames until we receive a BotStoppedSpeakingFrame.
# processing more frames until we have generated LLMFullResponseEndFrame
# (see push_frame()).
if isinstance(frame, TTSSpeakFrame):
await self.pause_processing_frames()
# Only pause on the end-of-response frame when self._started is truthy;
# push_frame() clears that flag on TTSStoppedFrame / StartInterruptionFrame.
elif isinstance(frame, LLMFullResponseEndFrame) and self._started:
await self.pause_processing_frames()
# NOTE(review): the commit message says resuming should no longer wait for
# BotStoppedSpeakingFrame, so this branch is presumably a removed (pre-change)
# line of the diff -- confirm.
elif isinstance(frame, BotStoppedSpeakingFrame):
await self.resume_processing_frames()
async def _connect(self):
try:

View File

@@ -17,7 +17,6 @@ from loguru import logger
from pydantic.main import BaseModel
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
@@ -235,18 +234,25 @@ class PlayHTTTSService(TTSService):
except Exception as e:
logger.error(f"{self} exception in receive task: {e}")
# Push `frame` through the parent class's push_frame(), then resume this
# service's frame processing if the frame just pushed was an
# LLMFullResponseEndFrame.
#
# Args:
#     frame: The frame to push.
#     direction: Direction to push in; defaults to FrameDirection.DOWNSTREAM.
#
# NOTE(review): this text looks like a rendered commit diff with indentation
# and +/- markers stripped -- confirm the real nesting against the source file.
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
# The frame is always forwarded first; the resume check happens afterwards.
await super().push_frame(frame, direction)
# We generate LLMFullResponseEndFrame after we have received all the
# audio from the service which means we can resume processing frames.
if isinstance(frame, LLMFullResponseEndFrame):
# Counterpart of the pause_processing_frames() calls made in process_frame().
await self.resume_processing_frames()
# Let the parent class process `frame`, then pause or resume this service's
# frame processing depending on the frame type.
#
# Args:
#     frame: The incoming frame.
#     direction: The direction the frame is travelling in.
#
# NOTE(review): this span looks like a diff hunk with +/- markers stripped, so
# old and new lines are interleaved. The comment lines below come in
# near-duplicate pairs (old wording, then new wording) -- confirm which lines
# survive in the real file before relying on this text.
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# If we received a TTSSpeakFrame and the LLM response included text (it
# If we received a TTSSpeakFrame or the LLM response included text (it
# might be that it's only a function calling response) we pause
# processing more frames until we receive a BotStoppedSpeakingFrame.
# processing more frames until we have generated LLMFullResponseEndFrame
# (see push_frame()).
if isinstance(frame, TTSSpeakFrame):
await self.pause_processing_frames()
# Only pause on the end-of-response frame when self._request_id is truthy;
# presumably that means a TTS request is in flight -- verify where it is set.
elif isinstance(frame, LLMFullResponseEndFrame) and self._request_id:
await self.pause_processing_frames()
# NOTE(review): the commit message says resuming should no longer wait for
# BotStoppedSpeakingFrame, so this branch is presumably a removed (pre-change)
# line of the diff -- confirm.
elif isinstance(frame, BotStoppedSpeakingFrame):
await self.resume_processing_frames()
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")