Add support for Piper TTS

This commit is contained in:
Pedro Moreira
2025-02-04 13:51:33 -03:00
parent cc54255c41
commit 79ac696973
3 changed files with 205 additions and 0 deletions

View File

@@ -0,0 +1,103 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import AsyncGenerator
import aiohttp
from loguru import logger
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.ai_services import TTSService
# This assumes a Piper TTS HTTP server is already running; see https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md
class PiperTTSService(TTSService):
    """Piper TTS service implementation.

    Provides integration with Piper's TTS server. Each call to ``run_tts``
    issues ``GET <base_url>/?text=...`` and re-chunks the streamed response
    body into ``TTSAudioRawFrame`` frames.
    """

    # Number of buffered response bytes emitted per audio frame.
    # NOTE(review): the original comment described this as "at least 0.5
    # seconds of audio data at 24000 Hz"; the real duration depends on the
    # sample width of the server's output, which is not established here.
    _FRAME_BYTES = 48000

    # Read size used when iterating over the HTTP response body.
    _READ_BYTES = 1024

    def __init__(
        self,
        *,
        base_url: str,
        aiohttp_session: aiohttp.ClientSession | None = None,
        sample_rate: int = 24000,
        **kwargs,
    ):
        """Initialize the PiperTTSService class instance.

        Args:
            base_url (str): Base URL of the Piper TTS server (should not end
                with a slash; trailing slashes are stripped with a warning).
            aiohttp_session (aiohttp.ClientSession, optional): Optional aiohttp
                session to use for requests. Defaults to None, in which case a
                new session is created.
            sample_rate (int, optional): Sample rate in Hz. Defaults to 24000.
            **kwargs (dict): Additional keyword arguments forwarded to
                ``TTSService``.
        """
        super().__init__(sample_rate=sample_rate, **kwargs)

        if aiohttp_session is None:
            # NOTE(review): a session created here is never closed by this
            # class; callers that care about clean shutdown should pass their
            # own session and close it themselves.
            aiohttp_session = aiohttp.ClientSession()

        if base_url.endswith("/"):
            logger.warning("Base URL ends with a slash, this is not allowed.")
            base_url = base_url.rstrip("/")

        self._settings = {"base_url": base_url}
        # NOTE(review): "voice_id" looks like a placeholder value — confirm
        # whether a real voice identifier should be configurable here.
        self.set_voice("voice_id")
        self._aiohttp_session = aiohttp_session

    def can_generate_metrics(self) -> bool:
        """Report that this service emits metrics (always ``True``)."""
        return True

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Synthesize ``text`` via the Piper server and stream audio frames.

        Args:
            text (str): Text to synthesize. Periods and asterisks are removed
                before the text is sent to the server.

        Yields:
            Frame: On a 200 response, a ``TTSStartedFrame``, then
            ``TTSAudioRawFrame`` frames, then a ``TTSStoppedFrame``. On any
            other status, a single ``ErrorFrame`` describing the failure.
        """
        logger.debug(f"Generating TTS: [{text}]")

        url = self._settings["base_url"] + "/"
        # Send the text via ``params`` so it is percent-encoded; concatenating
        # it into the URL breaks on characters such as "&", "#", "+" or "=".
        query = {"text": text.replace(".", "").replace("*", "")}

        await self.start_ttfb_metrics()

        async with self._aiohttp_session.get(url, params=query) as response:
            if response.status != 200:
                # Use a separate name so the ``text`` argument is not shadowed.
                error = await response.text()
                logger.error(
                    f"{self} error getting audio (status: {response.status}, error: {error})"
                )
                yield ErrorFrame(
                    f"Error getting audio (status: {response.status}, error: {error})"
                )
                return

            await self.start_tts_usage_metrics(text)

            yield TTSStartedFrame()

            buffer = bytearray()
            async for chunk in response.content.iter_chunked(self._READ_BYTES):
                if not chunk:
                    continue

                await self.stop_ttfb_metrics()
                buffer.extend(chunk)

                # Emit fixed-size frames while enough data is buffered.
                while len(buffer) >= self._FRAME_BYTES:
                    # Copy to immutable bytes, then drop the consumed prefix
                    # in place instead of rebuilding the whole buffer.
                    audio = bytes(buffer[: self._FRAME_BYTES])
                    del buffer[: self._FRAME_BYTES]
                    yield TTSAudioRawFrame(audio, self._sample_rate, 1)

            # Flush whatever is left once the response body is exhausted.
            if buffer:
                yield TTSAudioRawFrame(bytes(buffer), self._sample_rate, 1)

            yield TTSStoppedFrame()

View File

@@ -21,6 +21,7 @@ pyaudio~=0.2.14
pydantic~=2.8.2
pyloudnorm~=0.1.1
pyht~=0.1.4
pytest-aiohttp==1.1.0
python-dotenv~=1.0.1
silero-vad~=5.1
soxr~=0.5.0

101
tests/test_piper_tts.py Normal file
View File

@@ -0,0 +1,101 @@
"""Tests for PiperTTSService."""
import asyncio
import pytest
from aiohttp import web
from pipecat.frames.frames import (
ErrorFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.piper import PiperTTSService
@pytest.mark.asyncio
async def test_run_piper_tts_success(aiohttp_client):
    """Test successful TTS generation with chunked audio data.

    Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame,
    that the mock server received the expected ``text`` query parameter, and
    that every byte streamed by the server ends up in an audio frame.
    """
    received_texts = []

    async def handler(request):
        # The service sends the text as a /?text= query parameter; record it
        # so the test can assert on what was actually requested.
        received_texts.append(request.rel_url.query.get("text", ""))

        # Prepare a StreamResponse with chunked data
        resp = web.StreamResponse(
            status=200,
            reason="OK",
            headers={"Content-Type": "audio/raw"},
        )
        await resp.prepare(request)

        # Write out some chunked byte data.
        # In reality, you'd return WAV data or similar.
        data_chunk_1 = b"\x00\x01\x02\x03" * 12000  # 48000 bytes
        data_chunk_2 = b"\x04\x05\x06\x07" * 6000  # 24000 bytes
        await resp.write(data_chunk_1)
        await asyncio.sleep(0.01)  # simulate async chunk delay
        await resp.write(data_chunk_2)
        await resp.write_eof()
        return resp

    # Create an aiohttp test server
    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    # Remove trailing slash if present in the test URL
    base_url = str(client.make_url("")).rstrip("/")

    # Reuse the test client's session so the service does not create its own
    # ClientSession, which nothing in this test would ever close.
    tts_service = PiperTTSService(base_url=base_url, aiohttp_session=client.session)

    # Collect frames from the generator
    frames = [frame async for frame in tts_service.run_tts("Hello world.")]

    # The service strips periods before sending the text.
    assert received_texts == ["Hello world"], "Server should receive the cleaned text once"

    # Ensure we received frames in the expected order/types
    assert len(frames) >= 3, "Expecting at least TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame"
    assert isinstance(frames[0], TTSStartedFrame), "First frame must be TTSStartedFrame"
    assert isinstance(frames[-1], TTSStoppedFrame), "Last frame must be TTSStoppedFrame"

    # Check we have at least one TTSAudioRawFrame
    audio_frames = [f for f in frames if isinstance(f, TTSAudioRawFrame)]
    assert len(audio_frames) > 0, "Should have received at least one TTSAudioRawFrame"
    for a_frame in audio_frames:
        assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"

    # No audio may be dropped while re-chunking: 48000 + 24000 bytes were sent.
    assert sum(len(f.audio) for f in audio_frames) == 72000, "All streamed bytes must be emitted"
@pytest.mark.asyncio
async def test_run_piper_tts_error(aiohttp_client):
    """Test how the service handles a non-200 response from the server.

    Expects exactly one ErrorFrame, carrying the HTTP status, to be yielded.
    """

    async def handler(_request):
        # Return an error status for any request
        return web.Response(status=404, text="Not found")

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)
    base_url = str(client.make_url("")).rstrip("/")

    # Reuse the test client's session so the service does not create its own
    # ClientSession, which nothing in this test would ever close.
    tts_service = PiperTTSService(base_url=base_url, aiohttp_session=client.session)

    frames = [frame async for frame in tts_service.run_tts("Error case.")]

    assert len(frames) == 1, "Should only receive a single ErrorFrame"
    assert isinstance(frames[0], ErrorFrame), "Must receive an ErrorFrame for 404"
    assert "status: 404" in frames[0].error, "ErrorFrame should contain details about the 404"