From 79ac6969736d3e41a83efabab27eecefb1b202bc Mon Sep 17 00:00:00 2001
From: Pedro Moreira
Date: Tue, 4 Feb 2025 13:51:33 -0300
Subject: [PATCH] Add support for Piper TTS

---
 src/pipecat/services/piper.py | 121 ++++++++++++++++++++++++++++++++++
 test-requirements.txt         |   1 +
 tests/test_piper_tts.py       | 104 +++++++++++++++++++++++++++++
 3 files changed, 226 insertions(+)
 create mode 100644 src/pipecat/services/piper.py
 create mode 100644 tests/test_piper_tts.py

diff --git a/src/pipecat/services/piper.py b/src/pipecat/services/piper.py
new file mode 100644
index 000000000..ecb831eb7
--- /dev/null
+++ b/src/pipecat/services/piper.py
@@ -0,0 +1,121 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import AsyncGenerator
+
+import aiohttp
+from loguru import logger
+
+from pipecat.frames.frames import (
+    ErrorFrame,
+    Frame,
+    TTSAudioRawFrame,
+    TTSStartedFrame,
+    TTSStoppedFrame,
+)
+from pipecat.services.ai_services import TTSService
+
+# This assumes a Piper HTTP server is already running:
+# https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md
+
+# Bytes buffered before an audio frame is emitted. This is one second of audio at
+# 24000 Hz if samples are 16-bit mono (sample width assumed -- TODO confirm).
+_FRAME_BYTES = 48000
+
+
+class PiperTTSService(TTSService):
+    """Piper TTS service implementation.
+
+    Provides integration with Piper's TTS server.
+    """
+
+    def __init__(
+        self,
+        *,
+        base_url: str,
+        aiohttp_session: aiohttp.ClientSession | None = None,
+        sample_rate: int = 24000,
+        **kwargs,
+    ):
+        """Initialize the PiperTTSService class instance.
+
+        Args:
+            base_url (str): Base URL of the Piper TTS server. Trailing slashes are removed.
+            aiohttp_session (aiohttp.ClientSession, optional): Session to use for requests.
+                Defaults to None, in which case a new session is created. This class never
+                closes the session, so prefer passing one whose lifetime you manage.
+            sample_rate (int, optional): Sample rate in Hz. Defaults to 24000.
+            **kwargs (dict): Additional keyword arguments forwarded to TTSService.
+        """
+        super().__init__(sample_rate=sample_rate, **kwargs)
+        if aiohttp_session is None:
+            aiohttp_session = aiohttp.ClientSession()
+
+        if base_url.endswith("/"):
+            logger.warning("Base URL ends with a slash, this is not allowed.")
+            base_url = base_url.rstrip("/")
+
+        self._settings = {"base_url": base_url}
+        # Placeholder voice name: the request built in run_tts() sends only the text.
+        self.set_voice("voice_id")
+        self._aiohttp_session = aiohttp_session
+
+    def can_generate_metrics(self) -> bool:
+        """Return True; run_tts() reports TTFB and usage metrics."""
+        return True
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        """Request speech for ``text`` from the Piper server and stream it as frames.
+
+        Args:
+            text (str): Text to synthesize.
+
+        Yields:
+            Frame: A TTSStartedFrame, TTSAudioRawFrames and a final TTSStoppedFrame, or a
+                single ErrorFrame when the server does not answer with status 200.
+        """
+        logger.debug(f"Generating TTS: [{text}]")
+
+        # Periods and asterisks are removed from the text sent to the server. Passing the
+        # text through ``params`` makes aiohttp percent-encode it, so characters such as
+        # "&", "#" and "+" reach the server instead of corrupting the query string.
+        url = self._settings["base_url"] + "/"
+        params = {"text": text.replace(".", "").replace("*", "")}
+
+        await self.start_ttfb_metrics()
+
+        async with self._aiohttp_session.get(url, params=params) as r:
+            if r.status != 200:
+                error = await r.text()
+                logger.error(f"{self} error getting audio (status: {r.status}, error: {error})")
+                yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {error})")
+                return
+
+            await self.start_tts_usage_metrics(text)
+
+            yield TTSStartedFrame()
+
+            buffer = bytearray()
+            async for chunk in r.content.iter_chunked(1024):
+                if not chunk:
+                    continue
+
+                await self.stop_ttfb_metrics()
+                buffer.extend(chunk)
+
+                # Emit fixed-size frames as soon as enough audio has accumulated.
+                while len(buffer) >= _FRAME_BYTES:
+                    # Hand the frame an immutable copy, then drop the emitted prefix in
+                    # place rather than rebuilding the remainder of the buffer.
+                    audio = bytes(buffer[:_FRAME_BYTES])
+                    del buffer[:_FRAME_BYTES]
+                    yield TTSAudioRawFrame(audio, self._sample_rate, 1)
+
+            # Flush whatever is left once the response body is exhausted.
+            if buffer:
+                yield TTSAudioRawFrame(bytes(buffer), self._sample_rate, 1)
+
+            yield TTSStoppedFrame()
diff --git a/test-requirements.txt b/test-requirements.txt
index 36e64060b..aae1c2dbb 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -21,6 +21,7 @@ pyaudio~=0.2.14
 pydantic~=2.8.2
 pyloudnorm~=0.1.1
 pyht~=0.1.4
+pytest-aiohttp==1.1.0
 python-dotenv~=1.0.1
 silero-vad~=5.1
 soxr~=0.5.0
diff --git a/tests/test_piper_tts.py b/tests/test_piper_tts.py
new file mode 100644
index 000000000..296de7fae
--- /dev/null
+++ b/tests/test_piper_tts.py
@@ -0,0 +1,104 @@
+"""Tests for PiperTTSService."""
+
+import asyncio
+
+import pytest
+from aiohttp import web
+
+from pipecat.frames.frames import (
+    ErrorFrame,
+    TTSAudioRawFrame,
+    TTSStartedFrame,
+    TTSStoppedFrame,
+)
+from pipecat.services.piper import PiperTTSService
+
+
+@pytest.mark.asyncio
+async def test_run_piper_tts_success(aiohttp_client):
+    """Test successful TTS generation with chunked audio data.
+
+    Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame.
+    """
+    received_texts = []
+
+    async def handler(request):
+        # The service expects a /?text= param; record it so the test can check it.
+        # Here we're just returning dummy chunked bytes to simulate an audio response
+        received_texts.append(request.rel_url.query.get("text", ""))
+
+        # Prepare a StreamResponse with chunked data
+        resp = web.StreamResponse(
+            status=200,
+            reason="OK",
+            headers={"Content-Type": "audio/raw"},
+        )
+        await resp.prepare(request)
+
+        # Write out some chunked byte data
+        # In reality, you’d return WAV data or similar
+        data_chunk_1 = b"\x00\x01\x02\x03" * 12000  # 48000 bytes
+        data_chunk_2 = b"\x04\x05\x06\x07" * 6000  # another chunk
+        await resp.write(data_chunk_1)
+        await asyncio.sleep(0.01)  # simulate async chunk delay
+        await resp.write(data_chunk_2)
+        await resp.write_eof()
+
+        return resp
+
+    # Create an aiohttp test server
+    app = web.Application()
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)
+
+    # Remove trailing slash if present in the test URL
+    base_url = str(client.make_url("")).rstrip("/")
+
+    # Reuse the test client's session so it is closed when the fixture is torn down
+    tts_service = PiperTTSService(base_url=base_url, aiohttp_session=client.session)
+
+    # Collect frames from the generator
+    frames = []
+    async for frame in tts_service.run_tts("Hello & welcome."):
+        frames.append(frame)
+
+    # The text must arrive intact (minus the stripped period), "&" included
+    assert received_texts == ["Hello & welcome"], "Query text should be URL-encoded"
+
+    # Ensure we received frames in the expected order/types
+    assert len(frames) >= 3, "Expecting at least TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame"
+    assert isinstance(frames[0], TTSStartedFrame), "First frame must be TTSStartedFrame"
+    assert isinstance(frames[-1], TTSStoppedFrame), "Last frame must be TTSStoppedFrame"
+
+    # Check we have at least one TTSAudioRawFrame
+    audio_frames = [f for f in frames if isinstance(f, TTSAudioRawFrame)]
+    assert len(audio_frames) > 0, "Should have received at least one TTSAudioRawFrame"
+    for a_frame in audio_frames:
+        assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"
+
+
+@pytest.mark.asyncio
+async def test_run_piper_tts_error(aiohttp_client):
+    """Test how the service handles a non-200 response from the server.
+
+    Expects an ErrorFrame to be returned.
+    """
+
+    async def handler(_request):
+        # Return an error status for any request
+        return web.Response(status=404, text="Not found")
+
+    app = web.Application()
+    app.router.add_get("/", handler)
+    client = await aiohttp_client(app)
+    base_url = str(client.make_url("")).rstrip("/")
+
+    tts_service = PiperTTSService(base_url=base_url, aiohttp_session=client.session)
+
+    frames = []
+    async for frame in tts_service.run_tts("Error case."):
+        frames.append(frame)
+
+    assert len(frames) == 1, "Should only receive a single ErrorFrame"
+    assert isinstance(frames[0], ErrorFrame), "Must receive an ErrorFrame for 404"
+    assert "status: 404" in frames[0].error, "ErrorFrame should contain details about the 404"