Merge pull request #1130 from pedro-a-n-moreira/piper-tts
Add support for Piper TTS
This commit is contained in:
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for a new TTS service, `PiperTTSService`.
|
||||
(see https://github.com/rhasspy/piper/)
|
||||
|
||||
- It is now possible to tell whether `UserStartedSpeakingFrame` or
|
||||
`UserStoppedSpeakingFrame` have been generated because of emulation frames.
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ pre-commit~=4.0.1
|
||||
pyright~=1.1.397
|
||||
pytest~=8.3.4
|
||||
pytest-asyncio~=0.25.3
|
||||
pytest-aiohttp==1.1.0
|
||||
ruff~=0.11.1
|
||||
setuptools~=70.0.0
|
||||
setuptools_scm~=8.1.0
|
||||
|
||||
@@ -90,3 +90,6 @@ ASSEMBLYAI_API_KEY=...
|
||||
|
||||
# OpenRouter
|
||||
OPENROUTER_API_KEY=...
|
||||
|
||||
# Piper
|
||||
PIPER_BASE_URL=...
|
||||
57
examples/foundational/01-say-one-thing-piper.py
Normal file
57
examples/foundational/01-say-one-thing-piper.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import EndFrame, TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.services.piper import PiperTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, _) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url, None, "Say One Thing", DailyParams(audio_out_enabled=True)
|
||||
)
|
||||
|
||||
tts = PiperTTSService(
|
||||
base_url=os.getenv("PIPER_BASE_URL"), aiohttp_session=session, sample_rate=24000
|
||||
)
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
task = PipelineTask(Pipeline([tts, transport.output()]))
|
||||
|
||||
# Register an event handler so we can play the audio when the
|
||||
# participant joins.
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await task.queue_frames(
|
||||
[TTSSpeakFrame(f"Hello there, how are you today ?"), EndFrame()]
|
||||
)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
103
src/pipecat/services/piper.py
Normal file
103
src/pipecat/services/piper.py
Normal file
@@ -0,0 +1,103 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
|
||||
# This assumes a running TTS service running: https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md
|
||||
class PiperTTSService(TTSService):
|
||||
"""Piper TTS service implementation.
|
||||
|
||||
Provides integration with Piper's TTS server.
|
||||
|
||||
Args:
|
||||
base_url: API base URL
|
||||
aiohttp_session: aiohttp ClientSession
|
||||
sample_rate: Output sample rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
# When using Piper, the sample rate of the generated audio depends on the
|
||||
# voice model being used.
|
||||
sample_rate: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
if base_url.endswith("/"):
|
||||
logger.warning("Base URL ends with a slash, this is not allowed.")
|
||||
base_url = base_url[:-1]
|
||||
|
||||
self._base_url = base_url
|
||||
self._session = aiohttp_session
|
||||
self._settings = {"base_url": base_url}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Piper API.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech
|
||||
|
||||
Yields:
|
||||
Frames containing audio data and status information
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
headers = {
|
||||
"Content-Type": "text/plain",
|
||||
}
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._session.post(self._base_url, data=text, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
eror = await response.text()
|
||||
logger.error(
|
||||
f"{self} error getting audio (status: {response.status}, error: {eror})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {response.status}, error: {eror})"
|
||||
)
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Process the streaming response
|
||||
CHUNK_SIZE = 1024
|
||||
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
|
||||
# remove wav header if present
|
||||
if chunk.startswith(b"RIFF"):
|
||||
chunk = chunk[44:]
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
finally:
|
||||
logger.debug(f"{self}: Finished TTS [{text}]")
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
132
tests/test_piper_tts.py
Normal file
132
tests/test_piper_tts.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for PiperTTSService."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.services.piper import PiperTTSService
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_piper_tts_success(aiohttp_client):
|
||||
"""Test successful TTS generation with chunked audio data.
|
||||
|
||||
Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame.
|
||||
"""
|
||||
|
||||
async def handler(request):
|
||||
# The service expects a /?text= param
|
||||
# Here we're just returning dummy chunked bytes to simulate an audio response
|
||||
text_query = request.rel_url.query.get("text", "")
|
||||
print(f"Mock server received text param: {text_query}")
|
||||
|
||||
# Prepare a StreamResponse with chunked data
|
||||
resp = web.StreamResponse(
|
||||
status=200,
|
||||
reason="OK",
|
||||
headers={"Content-Type": "audio/raw"},
|
||||
)
|
||||
await resp.prepare(request)
|
||||
|
||||
# Write out some chunked byte data
|
||||
# In reality, you’d return WAV data or similar
|
||||
data_chunk_1 = b"\x00\x01\x02\x03" * 1024 # 4096 bytes, 04 TTSAudioRawFrame
|
||||
data_chunk_2 = b"\x04\x05\x06\x07" * 1024 # another chunk
|
||||
await resp.write(data_chunk_1)
|
||||
await asyncio.sleep(0.01) # simulate async chunk delay
|
||||
await resp.write(data_chunk_2)
|
||||
await resp.write_eof()
|
||||
|
||||
return resp
|
||||
|
||||
# Create an aiohttp test server
|
||||
app = web.Application()
|
||||
app.router.add_post("/", handler)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Remove trailing slash if present in the test URL
|
||||
base_url = str(client.make_url("")).rstrip("/")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Instantiate PiperTTSService with our mock server
|
||||
tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000)
|
||||
|
||||
frames_to_send = [
|
||||
TTSSpeakFrame(text="Hello world."),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
TTSStartedFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
]
|
||||
|
||||
frames_received = await run_test(
|
||||
tts_service,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
down_frames = frames_received[0]
|
||||
audio_frames = [f for f in down_frames if isinstance(f, TTSAudioRawFrame)]
|
||||
for a_frame in audio_frames:
|
||||
assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_piper_tts_error(aiohttp_client):
|
||||
"""Test how the service handles a non-200 response from the server.
|
||||
|
||||
Expects an ErrorFrame to be returned.
|
||||
"""
|
||||
|
||||
async def handler(_request):
|
||||
# Return an error status for any request
|
||||
return web.Response(status=404, text="Not found")
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/", handler)
|
||||
client = await aiohttp_client(app)
|
||||
base_url = str(client.make_url("")).rstrip("/")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000)
|
||||
|
||||
frames_to_send = [
|
||||
TTSSpeakFrame(text="Error case."),
|
||||
]
|
||||
|
||||
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
|
||||
|
||||
expected_up_frames = [ErrorFrame]
|
||||
|
||||
frames_received = await run_test(
|
||||
tts_service,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
up_frames = frames_received[1]
|
||||
|
||||
assert isinstance(up_frames[0], ErrorFrame), "Must receive an ErrorFrame for 404"
|
||||
assert "status: 404" in up_frames[0].error, (
|
||||
"ErrorFrame should contain details about the 404"
|
||||
)
|
||||
Reference in New Issue
Block a user