Merge pull request #1130 from pedro-a-n-moreira/piper-tts

Add support for Piper TTS
This commit is contained in:
Filipi da Silva Fuchter
2025-03-27 08:08:05 -03:00
committed by GitHub
6 changed files with 299 additions and 0 deletions

View File

@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added support for a new TTS service, `PiperTTSService`.
(see https://github.com/rhasspy/piper/)
- It is now possible to tell whether `UserStartedSpeakingFrame` or
`UserStoppedSpeakingFrame` have been generated because of emulation frames.

View File

@@ -6,6 +6,7 @@ pre-commit~=4.0.1
pyright~=1.1.397
pytest~=8.3.4
pytest-asyncio~=0.25.3
pytest-aiohttp==1.1.0
ruff~=0.11.1
setuptools~=70.0.0
setuptools_scm~=8.1.0

View File

@@ -90,3 +90,6 @@ ASSEMBLYAI_API_KEY=...
# OpenRouter
OPENROUTER_API_KEY=...
# Piper
PIPER_BASE_URL=...

View File

@@ -0,0 +1,57 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.frames.frames import EndFrame, TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.services.piper import PiperTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, _) = await configure(session)
transport = DailyTransport(
room_url, None, "Say One Thing", DailyParams(audio_out_enabled=True)
)
tts = PiperTTSService(
base_url=os.getenv("PIPER_BASE_URL"), aiohttp_session=session, sample_rate=24000
)
runner = PipelineRunner()
task = PipelineTask(Pipeline([tts, transport.output()]))
# Register an event handler so we can play the audio when the
# participant joins.
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await task.queue_frames(
[TTSSpeakFrame(f"Hello there, how are you today ?"), EndFrame()]
)
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,103 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.ai_services import TTSService
# This assumes a running TTS service running: https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md
class PiperTTSService(TTSService):
"""Piper TTS service implementation.
Provides integration with Piper's TTS server.
Args:
base_url: API base URL
aiohttp_session: aiohttp ClientSession
sample_rate: Output sample rate
"""
def __init__(
self,
*,
base_url: str,
aiohttp_session: aiohttp.ClientSession,
# When using Piper, the sample rate of the generated audio depends on the
# voice model being used.
sample_rate: Optional[int] = None,
**kwargs,
):
super().__init__(sample_rate=sample_rate, **kwargs)
if base_url.endswith("/"):
logger.warning("Base URL ends with a slash, this is not allowed.")
base_url = base_url[:-1]
self._base_url = base_url
self._session = aiohttp_session
self._settings = {"base_url": base_url}
def can_generate_metrics(self) -> bool:
return True
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Piper API.
Args:
text: The text to convert to speech
Yields:
Frames containing audio data and status information
"""
logger.debug(f"{self}: Generating TTS [{text}]")
headers = {
"Content-Type": "text/plain",
}
try:
await self.start_ttfb_metrics()
async with self._session.post(self._base_url, data=text, headers=headers) as response:
if response.status != 200:
eror = await response.text()
logger.error(
f"{self} error getting audio (status: {response.status}, error: {eror})"
)
yield ErrorFrame(
f"Error getting audio (status: {response.status}, error: {eror})"
)
return
await self.start_tts_usage_metrics(text)
# Process the streaming response
CHUNK_SIZE = 1024
yield TTSStartedFrame()
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
# remove wav header if present
if chunk.startswith(b"RIFF"):
chunk = chunk[44:]
if len(chunk) > 0:
await self.stop_ttfb_metrics()
yield TTSAudioRawFrame(chunk, self.sample_rate, 1)
except Exception as e:
logger.error(f"Error in run_tts: {e}")
yield ErrorFrame(error=str(e))
finally:
logger.debug(f"{self}: Finished TTS [{text}]")
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

132
tests/test_piper_tts.py Normal file
View File

@@ -0,0 +1,132 @@
"""Tests for PiperTTSService."""
import asyncio
import aiohttp
import pytest
from aiohttp import web
from pipecat.frames.frames import (
ErrorFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSTextFrame,
)
from pipecat.services.piper import PiperTTSService
from pipecat.tests.utils import run_test
@pytest.mark.asyncio
async def test_run_piper_tts_success(aiohttp_client):
"""Test successful TTS generation with chunked audio data.
Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame.
"""
async def handler(request):
# The service expects a /?text= param
# Here we're just returning dummy chunked bytes to simulate an audio response
text_query = request.rel_url.query.get("text", "")
print(f"Mock server received text param: {text_query}")
# Prepare a StreamResponse with chunked data
resp = web.StreamResponse(
status=200,
reason="OK",
headers={"Content-Type": "audio/raw"},
)
await resp.prepare(request)
# Write out some chunked byte data
# In reality, youd return WAV data or similar
data_chunk_1 = b"\x00\x01\x02\x03" * 1024 # 4096 bytes, 04 TTSAudioRawFrame
data_chunk_2 = b"\x04\x05\x06\x07" * 1024 # another chunk
await resp.write(data_chunk_1)
await asyncio.sleep(0.01) # simulate async chunk delay
await resp.write(data_chunk_2)
await resp.write_eof()
return resp
# Create an aiohttp test server
app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)
# Remove trailing slash if present in the test URL
base_url = str(client.make_url("")).rstrip("/")
async with aiohttp.ClientSession() as session:
# Instantiate PiperTTSService with our mock server
tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000)
frames_to_send = [
TTSSpeakFrame(text="Hello world."),
]
expected_returned_frames = [
TTSStartedFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
TTSTextFrame,
]
frames_received = await run_test(
tts_service,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
down_frames = frames_received[0]
audio_frames = [f for f in down_frames if isinstance(f, TTSAudioRawFrame)]
for a_frame in audio_frames:
assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"
@pytest.mark.asyncio
async def test_run_piper_tts_error(aiohttp_client):
"""Test how the service handles a non-200 response from the server.
Expects an ErrorFrame to be returned.
"""
async def handler(_request):
# Return an error status for any request
return web.Response(status=404, text="Not found")
app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)
base_url = str(client.make_url("")).rstrip("/")
async with aiohttp.ClientSession() as session:
tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000)
frames_to_send = [
TTSSpeakFrame(text="Error case."),
]
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
expected_up_frames = [ErrorFrame]
frames_received = await run_test(
tts_service,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
up_frames = frames_received[1]
assert isinstance(up_frames[0], ErrorFrame), "Must receive an ErrorFrame for 404"
assert "status: 404" in up_frames[0].error, (
"ErrorFrame should contain details about the 404"
)