From 73da8c1910fc7ae37f3f1ac6bf8c50d89aabd927 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Wed, 25 Sep 2024 22:40:36 -0400 Subject: [PATCH] Improve usability of Deepgram TTS: use Deepgram client, remove aiohttp --- .../07c-interruptible-deepgram.py | 16 ++-- src/pipecat/services/deepgram.py | 79 +++++++++---------- 2 files changed, 44 insertions(+), 51 deletions(-) diff --git a/examples/foundational/07c-interruptible-deepgram.py b/examples/foundational/07c-interruptible-deepgram.py index 41bef8a47..fc33c246f 100644 --- a/examples/foundational/07c-interruptible-deepgram.py +++ b/examples/foundational/07c-interruptible-deepgram.py @@ -5,10 +5,14 @@ # import asyncio -import aiohttp import os import sys +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + from pipecat.frames.frames import LLMMessagesFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -22,12 +26,6 @@ from pipecat.services.openai import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport from pipecat.vad.silero import SileroVADAnalyzer -from runner import configure - -from loguru import logger - -from dotenv import load_dotenv - load_dotenv(override=True) logger.remove(0) @@ -52,9 +50,7 @@ async def main(): stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) - tts = DeepgramTTSService( - aiohttp_session=session, api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en" - ) + tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en") llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram.py index 914bc2ec2..6929e66e5 100644 --- a/src/pipecat/services/deepgram.py +++ b/src/pipecat/services/deepgram.py @@ -4,10 +4,11 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import aiohttp - +import asyncio from typing import AsyncGenerator +from loguru import logger + from pipecat.frames.frames import ( CancelFrame, EndFrame, @@ -15,27 +16,25 @@ from pipecat.frames.frames import ( Frame, InterimTranscriptionFrame, StartFrame, + TranscriptionFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, - TranscriptionFrame, ) from pipecat.services.ai_services import STTService, TTSService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 -from loguru import logger - - # See .env.example for Deepgram configuration needed try: from deepgram import ( AsyncListenWebSocketClient, DeepgramClient, DeepgramClientOptions, - LiveTranscriptionEvents, LiveOptions, LiveResultResponse, + LiveTranscriptionEvents, + SpeakOptions, ) except ModuleNotFoundError as e: logger.error(f"Exception: {e}") @@ -50,9 +49,7 @@ class DeepgramTTSService(TTSService): self, *, api_key: str, - aiohttp_session: aiohttp.ClientSession, voice: str = "aura-helios-en", - base_url: str = "https://api.deepgram.com/v1/speak", sample_rate: int = 16000, encoding: str = "linear16", **kwargs, @@ -60,11 +57,9 @@ class DeepgramTTSService(TTSService): super().__init__(**kwargs) self._voice = voice - self._api_key = api_key - self._base_url = base_url self._sample_rate = sample_rate self._encoding = encoding - self._aiohttp_session = aiohttp_session + self._deepgram_client = DeepgramClient(api_key=api_key) def can_generate_metrics(self) -> bool: return True @@ -76,43 +71,45 @@ class DeepgramTTSService(TTSService): async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]") - base_url = self._base_url - request_url = f"{base_url}?model={self._voice}&encoding={self._encoding}&container=none&sample_rate={self._sample_rate}" - headers = {"authorization": f"token {self._api_key}"} - body = {"text": text} + options = SpeakOptions( + model=self._voice, + encoding=self._encoding, + sample_rate=self._sample_rate, + container="none", + ) try: await self.start_ttfb_metrics() - async with self._aiohttp_session.post(request_url, headers=headers, json=body) as r: - if r.status != 200: - response_text = await r.text() - # If we get a a "Bad Request: Input is unutterable", just print out a debug log. - # All other unsuccesful requests should emit an error frame. If not specifically - # handled by the running PipelineTask, the ErrorFrame will cancel the task. - if "unutterable" in response_text: - logger.debug(f"Unutterable text: [{text}]") - return - logger.error( - f"{self} error getting audio (status: {r.status}, error: {response_text})" - ) - yield ErrorFrame( - f"Error getting audio (status: {r.status}, error: {response_text})" - ) - return + response = await asyncio.to_thread( + self._deepgram_client.speak.v("1").stream, {"text": text}, options + ) - await self.start_tts_usage_metrics(text) + await self.start_tts_usage_metrics(text) + await self.push_frame(TTSStartedFrame()) + + # The response.stream_memory is already a BytesIO object + audio_buffer = response.stream_memory + + if audio_buffer is None: + raise ValueError("No audio data received from Deepgram") + + # Read and yield the audio data in chunks + audio_buffer.seek(0) # Ensure we're at the start of the buffer + chunk_size = 8192 # Use a fixed buffer size + while True: + await self.stop_ttfb_metrics() + chunk = audio_buffer.read(chunk_size) + if not chunk: + break + frame = TTSAudioRawFrame(audio=chunk, sample_rate=self._sample_rate, num_channels=1) + yield frame + + await self.push_frame(TTSStoppedFrame()) - await self.push_frame(TTSStartedFrame()) - async for data in r.content: - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( - audio=data, sample_rate=self._sample_rate, num_channels=1 - ) - yield frame - await self.push_frame(TTSStoppedFrame()) except Exception as e: logger.exception(f"{self} exception: {e}") + yield ErrorFrame(f"Error getting audio: {str(e)}") class DeepgramSTTService(STTService):