Add Fish Audio TTS service

This commit is contained in:
Mark Backman
2024-12-20 14:41:57 -05:00
parent 503eddf7d6
commit dac4468ca1
3 changed files with 334 additions and 0 deletions

View File

@@ -0,0 +1,99 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.fish import FishAudioTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
tts = FishAudioTTSService(
api_key=os.getenv("FISH_API_KEY"),
model="4ce7e917cedd4bc2bb2e6ff3a46acaa1", # Barack Obama
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMMessagesFrame(messages)])
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -50,6 +50,7 @@ daily = [ "daily-python~=0.14.0" ]
deepgram = [ "deepgram-sdk~=3.7.7" ]
elevenlabs = [ "websockets~=13.1" ]
fal = [ "fal-client~=0.4.1" ]
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.21.1" ]
grok = [ "openai~=1.57.2" ]

View File

@@ -0,0 +1,234 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import uuid
from typing import AsyncGenerator, Literal, Optional
from loguru import logger
from pydantic import BaseModel
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService
from pipecat.transcriptions.language import Language
try:
import ormsgpack
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`. Also, set `FISH_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
# FishAudio supports various output formats
FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]
class FishAudioTTSService(TTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
latency: Optional[str] = "normal" # "normal" or "balanced"
prosody_speed: Optional[float] = 1.0 # Speech speed (0.5-2.0)
prosody_volume: Optional[int] = 0 # Volume adjustment in dB
def __init__(
self,
*,
api_key: str,
model: str, # This is the reference_id
output_format: FishAudioOutputFormat = "pcm",
sample_rate: int = 24000,
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(sample_rate=sample_rate, **kwargs)
self._api_key = api_key
self._base_url = "wss://api.fish.audio/v1/tts/live"
self._websocket = None
self._receive_task = None
self._request_id = None
self._started = False
self._settings = {
"sample_rate": sample_rate,
"latency": params.latency,
"format": output_format,
"prosody": {
"speed": params.prosody_speed,
"volume": params.prosody_volume,
},
"reference_id": model,
}
self.set_model_name(model)
def can_generate_metrics(self) -> bool:
return True
async def set_model(self, model: str):
self._settings["reference_id"] = model
await super().set_model(model)
logger.info(f"Switching TTS model to: [{model}]")
async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
async def _disconnect(self):
await self._disconnect_websocket()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
async def _connect_websocket(self):
try:
logger.debug("Connecting to Fish Audio")
headers = {"Authorization": f"Bearer {self._api_key}"}
self._websocket = await websockets.connect(self._base_url, extra_headers=headers)
# Send initial start message with ormsgpack
start_message = {"event": "start", "request": {"text": "", **self._settings}}
await self._websocket.send(ormsgpack.packb(start_message))
logger.debug("Sent start message to Fish Audio")
except Exception as e:
logger.error(f"Fish Audio initialization error: {e}")
self._websocket = None
async def _disconnect_websocket(self):
try:
await self.stop_all_metrics()
if self._websocket:
logger.debug("Disconnecting from Fish Audio")
# Send stop event with ormsgpack
stop_message = {"event": "stop"}
await self._websocket.send(ormsgpack.packb(stop_message))
await self._websocket.close()
self._websocket = None
self._request_id = None
self._started = False
except Exception as e:
logger.error(f"Error closing websocket: {e}")
def _get_websocket(self):
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")
async def _receive_messages(self):
async for message in self._get_websocket():
try:
if isinstance(message, bytes):
msg = ormsgpack.unpackb(message)
if isinstance(msg, dict):
event = msg.get("event")
print(f"Received event: {event}")
if event == "audio":
await self.stop_ttfb_metrics()
audio_data = msg.get("audio")
# Only process larger chunks to remove msgpack overhead
if audio_data and len(audio_data) > 1024:
frame = TTSAudioRawFrame(
audio_data, self._settings["sample_rate"], 1
)
await self.push_frame(frame)
continue
except Exception as e:
logger.error(f"Error processing message: {e}")
async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"Fish Audio reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()
async def _receive_task_handler(self):
while True:
try:
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
message = f"Fish Audio error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TTSSpeakFrame):
await self.pause_processing_frames()
elif isinstance(frame, LLMFullResponseEndFrame) and self._request_id:
await self.pause_processing_frames()
elif isinstance(frame, BotStoppedSpeakingFrame):
await self.resume_processing_frames()
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
self._request_id = None
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating Fish TTS: [{text}]")
try:
if not self._websocket or self._websocket.closed:
await self._connect()
if not self._request_id:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._request_id = str(uuid.uuid4())
text_message = {
"event": "text",
"text": text,
}
await self._get_websocket().send(ormsgpack.packb(text_message))
await self.start_tts_usage_metrics(text)
yield None
except Exception as e:
logger.error(f"Error generating TTS: {e}")
yield ErrorFrame(f"Error: {str(e)}")