Add Fish Audio TTS service
This commit is contained in:
99
examples/foundational/07t-interruptible-fish.py
Normal file
99
examples/foundational/07t-interruptible-fish.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.fish import FishAudioTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = FishAudioTTSService(
|
||||
api_key=os.getenv("FISH_API_KEY"),
|
||||
model="4ce7e917cedd4bc2bb2e6ff3a46acaa1", # Barack Obama
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -50,6 +50,7 @@ daily = [ "daily-python~=0.14.0" ]
|
||||
deepgram = [ "deepgram-sdk~=3.7.7" ]
|
||||
elevenlabs = [ "websockets~=13.1" ]
|
||||
fal = [ "fal-client~=0.4.1" ]
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
|
||||
gladia = [ "websockets~=13.1" ]
|
||||
google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.21.1" ]
|
||||
grok = [ "openai~=1.57.2" ]
|
||||
|
||||
234
src/pipecat/services/fish.py
Normal file
234
src/pipecat/services/fish.py
Normal file
@@ -0,0 +1,234 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import ormsgpack
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`. Also, set `FISH_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
# FishAudio supports various output formats
|
||||
FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]
|
||||
|
||||
|
||||
class FishAudioTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
latency: Optional[str] = "normal" # "normal" or "balanced"
|
||||
prosody_speed: Optional[float] = 1.0 # Speech speed (0.5-2.0)
|
||||
prosody_volume: Optional[int] = 0 # Volume adjustment in dB
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str, # This is the reference_id
|
||||
output_format: FishAudioOutputFormat = "pcm",
|
||||
sample_rate: int = 24000,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = "wss://api.fish.audio/v1/tts/live"
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._request_id = None
|
||||
self._started = False
|
||||
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"latency": params.latency,
|
||||
"format": output_format,
|
||||
"prosody": {
|
||||
"speed": params.prosody_speed,
|
||||
"volume": params.prosody_volume,
|
||||
},
|
||||
"reference_id": model,
|
||||
}
|
||||
|
||||
self.set_model_name(model)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
self._settings["reference_id"] = model
|
||||
await super().set_model(model)
|
||||
logger.info(f"Switching TTS model to: [{model}]")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
|
||||
async def _disconnect(self):
|
||||
await self._disconnect_websocket()
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
await self._receive_task
|
||||
self._receive_task = None
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
logger.debug("Connecting to Fish Audio")
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
self._websocket = await websockets.connect(self._base_url, extra_headers=headers)
|
||||
|
||||
# Send initial start message with ormsgpack
|
||||
start_message = {"event": "start", "request": {"text": "", **self._settings}}
|
||||
await self._websocket.send(ormsgpack.packb(start_message))
|
||||
logger.debug("Sent start message to Fish Audio")
|
||||
except Exception as e:
|
||||
logger.error(f"Fish Audio initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Fish Audio")
|
||||
# Send stop event with ormsgpack
|
||||
stop_message = {"event": "stop"}
|
||||
await self._websocket.send(ormsgpack.packb(stop_message))
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
self._request_id = None
|
||||
self._started = False
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing websocket: {e}")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _receive_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
if isinstance(message, bytes):
|
||||
msg = ormsgpack.unpackb(message)
|
||||
if isinstance(msg, dict):
|
||||
event = msg.get("event")
|
||||
print(f"Received event: {event}")
|
||||
if event == "audio":
|
||||
await self.stop_ttfb_metrics()
|
||||
audio_data = msg.get("audio")
|
||||
# Only process larger chunks to remove msgpack overhead
|
||||
if audio_data and len(audio_data) > 1024:
|
||||
frame = TTSAudioRawFrame(
|
||||
audio_data, self._settings["sample_rate"], 1
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
|
||||
async def _reconnect_websocket(self, retry_state: RetryCallState):
|
||||
logger.warning(f"Fish Audio reconnecting (attempt: {retry_state.attempt_number})")
|
||||
await self._disconnect_websocket()
|
||||
await self._connect_websocket()
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
async for attempt in AsyncRetrying(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
before_sleep=self._reconnect_websocket,
|
||||
reraise=True,
|
||||
):
|
||||
with attempt:
|
||||
await self._receive_messages()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
message = f"Fish Audio error receiving messages: {e}"
|
||||
logger.error(message)
|
||||
await self.push_error(ErrorFrame(message, fatal=True))
|
||||
break
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
await self.pause_processing_frames()
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) and self._request_id:
|
||||
await self.pause_processing_frames()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self.resume_processing_frames()
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
self._request_id = None
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating Fish TTS: [{text}]")
|
||||
try:
|
||||
if not self._websocket or self._websocket.closed:
|
||||
await self._connect()
|
||||
|
||||
if not self._request_id:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._request_id = str(uuid.uuid4())
|
||||
|
||||
text_message = {
|
||||
"event": "text",
|
||||
"text": text,
|
||||
}
|
||||
await self._get_websocket().send(ormsgpack.packb(text_message))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(f"Error: {str(e)}")
|
||||
Reference in New Issue
Block a user