Improve usability of Deepgram TTS: use Deepgram client, remove aiohttp
This commit is contained in:
@@ -5,10 +5,14 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -22,12 +26,6 @@ from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -52,9 +50,7 @@ async def main():
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = DeepgramTTSService(
|
||||
aiohttp_session=session, api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en"
|
||||
)
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en")
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -15,27 +16,25 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import STTService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# See .env.example for Deepgram configuration needed
|
||||
try:
|
||||
from deepgram import (
|
||||
AsyncListenWebSocketClient,
|
||||
DeepgramClient,
|
||||
DeepgramClientOptions,
|
||||
LiveTranscriptionEvents,
|
||||
LiveOptions,
|
||||
LiveResultResponse,
|
||||
LiveTranscriptionEvents,
|
||||
SpeakOptions,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -50,9 +49,7 @@ class DeepgramTTSService(TTSService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
voice: str = "aura-helios-en",
|
||||
base_url: str = "https://api.deepgram.com/v1/speak",
|
||||
sample_rate: int = 16000,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
@@ -60,11 +57,9 @@ class DeepgramTTSService(TTSService):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice = voice
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self._sample_rate = sample_rate
|
||||
self._encoding = encoding
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._deepgram_client = DeepgramClient(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -76,43 +71,45 @@ class DeepgramTTSService(TTSService):
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
base_url = self._base_url
|
||||
request_url = f"{base_url}?model={self._voice}&encoding={self._encoding}&container=none&sample_rate={self._sample_rate}"
|
||||
headers = {"authorization": f"token {self._api_key}"}
|
||||
body = {"text": text}
|
||||
options = SpeakOptions(
|
||||
model=self._voice,
|
||||
encoding=self._encoding,
|
||||
sample_rate=self._sample_rate,
|
||||
container="none",
|
||||
)
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
async with self._aiohttp_session.post(request_url, headers=headers, json=body) as r:
|
||||
if r.status != 200:
|
||||
response_text = await r.text()
|
||||
# If we get a "Bad Request: Input is unutterable", just print out a debug log.
|
||||
# All other unsuccessful requests should emit an error frame. If not specifically
|
||||
# handled by the running PipelineTask, the ErrorFrame will cancel the task.
|
||||
if "unutterable" in response_text:
|
||||
logger.debug(f"Unutterable text: [{text}]")
|
||||
return
|
||||
|
||||
logger.error(
|
||||
f"{self} error getting audio (status: {r.status}, error: {response_text})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {r.status}, error: {response_text})"
|
||||
)
|
||||
return
|
||||
response = await asyncio.to_thread(
|
||||
self._deepgram_client.speak.v("1").stream, {"text": text}, options
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
|
||||
# The response.stream_memory is already a BytesIO object
|
||||
audio_buffer = response.stream_memory
|
||||
|
||||
if audio_buffer is None:
|
||||
raise ValueError("No audio data received from Deepgram")
|
||||
|
||||
# Read and yield the audio data in chunks
|
||||
audio_buffer.seek(0) # Ensure we're at the start of the buffer
|
||||
chunk_size = 8192 # Use a fixed buffer size
|
||||
while True:
|
||||
await self.stop_ttfb_metrics()
|
||||
chunk = audio_buffer.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
frame = TTSAudioRawFrame(audio=chunk, sample_rate=self._sample_rate, num_channels=1)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
async for data in r.content:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=data, sample_rate=self._sample_rate, num_channels=1
|
||||
)
|
||||
yield frame
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
|
||||
|
||||
class DeepgramSTTService(STTService):
|
||||
|
||||
Reference in New Issue
Block a user