Add support for Piper TTS
This commit is contained in:
103
src/pipecat/services/piper.py
Normal file
103
src/pipecat/services/piper.py
Normal file
@@ -0,0 +1,103 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
# This assumes a running TTS service running: https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md
|
||||
|
||||
|
||||
class PiperTTSService(TTSService):
|
||||
"""Piper TTS service implementation.
|
||||
|
||||
Provides integration with Piper's TTS server.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str,
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
sample_rate: int = 24000,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the PiperTTSService class instance.
|
||||
|
||||
Args:
|
||||
base_url (str): Base URL of the Piper TTS server (should not end with a slash).
|
||||
aiohttp_session (aiohttp.ClientSession, optional): Optional aiohttp session to use for requests. Defaults to None.
|
||||
sample_rate (int, optional): Sample rate in Hz. Defaults to 24000.
|
||||
**kwargs (dict): Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
if not aiohttp_session:
|
||||
aiohttp_session = aiohttp.ClientSession()
|
||||
|
||||
if base_url.endswith("/"):
|
||||
logger.warning("Base URL ends with a slash, this is not allowed.")
|
||||
base_url = base_url[:-1]
|
||||
|
||||
self._settings = {"base_url": base_url}
|
||||
self.set_voice("voice_id")
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
url = self._settings["base_url"] + "/?text=" + text.replace(".", "").replace("*", "")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._aiohttp_session.get(url) as r:
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
logger.error(f"{self} error getting audio (status: {r.status}, error: {text})")
|
||||
yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {text})")
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
buffer = bytearray()
|
||||
async for chunk in r.content.iter_chunked(1024):
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
# Append new chunk to the buffer.
|
||||
buffer.extend(chunk)
|
||||
|
||||
# Check if buffer has enough data for processing.
|
||||
while (
|
||||
len(buffer) >= 48000
|
||||
): # Assuming at least 0.5 seconds of audio data at 24000 Hz
|
||||
# Process the buffer up to a safe size for resampling.
|
||||
process_data = buffer[:48000]
|
||||
# Remove processed data from buffer.
|
||||
buffer = buffer[48000:]
|
||||
|
||||
frame = TTSAudioRawFrame(process_data, self._sample_rate, 1)
|
||||
yield frame
|
||||
|
||||
# Process any remaining data in the buffer.
|
||||
if len(buffer) > 0:
|
||||
frame = TTSAudioRawFrame(buffer, self._sample_rate, 1)
|
||||
yield frame
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
@@ -21,6 +21,7 @@ pyaudio~=0.2.14
|
||||
pydantic~=2.8.2
|
||||
pyloudnorm~=0.1.1
|
||||
pyht~=0.1.4
|
||||
pytest-aiohttp==1.1.0
|
||||
python-dotenv~=1.0.1
|
||||
silero-vad~=5.1
|
||||
soxr~=0.5.0
|
||||
|
||||
101
tests/test_piper_tts.py
Normal file
101
tests/test_piper_tts.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Tests for PiperTTSService."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.piper import PiperTTSService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_piper_tts_success(aiohttp_client):
|
||||
"""Test successful TTS generation with chunked audio data.
|
||||
|
||||
Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame.
|
||||
"""
|
||||
|
||||
async def handler(request):
|
||||
# The service expects a /?text= param
|
||||
# Here we're just returning dummy chunked bytes to simulate an audio response
|
||||
text_query = request.rel_url.query.get("text", "")
|
||||
print(f"Mock server received text param: {text_query}")
|
||||
|
||||
# Prepare a StreamResponse with chunked data
|
||||
resp = web.StreamResponse(
|
||||
status=200,
|
||||
reason="OK",
|
||||
headers={"Content-Type": "audio/raw"},
|
||||
)
|
||||
await resp.prepare(request)
|
||||
|
||||
# Write out some chunked byte data
|
||||
# In reality, you’d return WAV data or similar
|
||||
data_chunk_1 = b"\x00\x01\x02\x03" * 12000 # 48000 bytes
|
||||
data_chunk_2 = b"\x04\x05\x06\x07" * 6000 # another chunk
|
||||
await resp.write(data_chunk_1)
|
||||
await asyncio.sleep(0.01) # simulate async chunk delay
|
||||
await resp.write(data_chunk_2)
|
||||
await resp.write_eof()
|
||||
|
||||
return resp
|
||||
|
||||
# Create an aiohttp test server
|
||||
app = web.Application()
|
||||
app.router.add_get("/", handler)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Remove trailing slash if present in the test URL
|
||||
base_url = str(client.make_url("")).rstrip("/")
|
||||
|
||||
# Instantiate PiperTTSService with our mock server
|
||||
tts_service = PiperTTSService(base_url=base_url)
|
||||
|
||||
# Collect frames from the generator
|
||||
frames = []
|
||||
async for frame in tts_service.run_tts("Hello world."):
|
||||
frames.append(frame)
|
||||
|
||||
# Ensure we received frames in the expected order/types
|
||||
assert len(frames) >= 3, "Expecting at least TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame"
|
||||
assert isinstance(frames[0], TTSStartedFrame), "First frame must be TTSStartedFrame"
|
||||
assert isinstance(frames[-1], TTSStoppedFrame), "Last frame must be TTSStoppedFrame"
|
||||
|
||||
# Check we have at least one TTSAudioRawFrame
|
||||
audio_frames = [f for f in frames if isinstance(f, TTSAudioRawFrame)]
|
||||
assert len(audio_frames) > 0, "Should have received at least one TTSAudioRawFrame"
|
||||
for a_frame in audio_frames:
|
||||
assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_piper_tts_error(aiohttp_client):
|
||||
"""Test how the service handles a non-200 response from the server.
|
||||
|
||||
Expects an ErrorFrame to be returned.
|
||||
"""
|
||||
|
||||
async def handler(_request):
|
||||
# Return an error status for any request
|
||||
return web.Response(status=404, text="Not found")
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/", handler)
|
||||
client = await aiohttp_client(app)
|
||||
base_url = str(client.make_url("")).rstrip("/")
|
||||
|
||||
tts_service = PiperTTSService(base_url=base_url)
|
||||
|
||||
frames = []
|
||||
async for frame in tts_service.run_tts("Error case."):
|
||||
frames.append(frame)
|
||||
|
||||
assert len(frames) == 1, "Should only receive a single ErrorFrame"
|
||||
assert isinstance(frames[0], ErrorFrame), "Must receive an ErrorFrame for 404"
|
||||
assert "status: 404" in frames[0].error, "ErrorFrame should contain details about the 404"
|
||||
Reference in New Issue
Block a user