Merge branch 'main' into aiortc_example

This commit is contained in:
Filipi Fuchter
2025-03-27 17:50:46 -03:00
24 changed files with 764 additions and 70 deletions

View File

@@ -9,10 +9,59 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- ElevenLabs TTS services now support a sample rate of 8000.
- Added support to `ProtobufFrameSerializer` to send the messages from `TransportMessageFrame` and `TransportMessageUrgentFrame`.
- Added support for a new TTS service, `PiperTTSService`.
(see https://github.com/rhasspy/piper/)
- It is now possible to tell whether `UserStartedSpeakingFrame` or
`UserStoppedSpeakingFrame` have been generated because of emulation frames.
### Fixed
- Fixed an issue that would cause `SegmentedSTTService` based services
(e.g. `OpenAISTTService`) to try to transcribe non-spoken audio, causing
invalid transcriptions.
- Fixed an issue where `GoogleTTSService` was emitting two `TTSStoppedFrames`.
## [0.0.61] - 2025-03-26
### Added
- Added a new frame, `LLMSetToolChoiceFrame`, which provides a mechanism
for modifying the `tool_choice` in the context.
- Added `GroqTTSService` which provides text-to-speech functionality using
Groq's API.
- Added support in `DailyTransport` for updating remote participants'
`canReceive` permission via the `update_remote_participants()` method, by
bumping the daily-python dependency to >= 0.16.0.
- ElevenLabs TTS services now support a sample rate of 8000.
- Added support for `instructions` in `OpenAITTSService`.
- Added support for `base_url` in `OpenAIImageGenService` and
`OpenAITTSService`.
### Fixed
- Fixed an issue in `RTVIObserver` that prevented handling of Google LLM
context messages. The observer now processes both OpenAI-style and
Google-style contexts.
- Fixed an issue in Daily involving switching virtual devices, by bumping the
daily-python dependency to >= 0.16.1.
- Fixed a `GoogleAssistantContextAggregator` issue where function calls
placeholders where not being updated when then function call result was
different from a string.
- Fixed an issue that would cause `LLMAssistantContextAggregator` to block
processing more frames while processing a function call result.
- Fixed an issue where the `RTVIObserver` would report two bot started and
stopped speaking events for each bot turn.

View File

@@ -55,17 +55,17 @@ pip install "pipecat-ai[option,...]"
### Available services
| Category | Services | Install Command Example |
| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | `pip install "pipecat-ai[openai]"` |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | `pip install "pipecat-ai[cartesia]"` |
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | `pip install "pipecat-ai[google]"` |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | `pip install "pipecat-ai[daily]"` |
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` |
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | `pip install "pipecat-ai[moondream]"` |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` |
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` |
| Category | Services | Install Command Example |
| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | `pip install "pipecat-ai[openai]"` |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | `pip install "pipecat-ai[cartesia]"` |
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | `pip install "pipecat-ai[google]"` |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | `pip install "pipecat-ai[daily]"` |
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` |
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | `pip install "pipecat-ai[moondream]"` |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` |
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` |
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)

View File

@@ -6,6 +6,7 @@ pre-commit~=4.0.1
pyright~=1.1.397
pytest~=8.3.4
pytest-asyncio~=0.25.3
pytest-aiohttp==1.1.0
ruff~=0.11.1
setuptools~=70.0.0
setuptools_scm~=8.1.0

View File

@@ -90,3 +90,6 @@ ASSEMBLYAI_API_KEY=...
# OpenRouter
OPENROUTER_API_KEY=...
# Piper
PIPER_BASE_URL=...

View File

@@ -0,0 +1,57 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.frames.frames import EndFrame, TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.services.piper import PiperTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, _) = await configure(session)
transport = DailyTransport(
room_url, None, "Say One Thing", DailyParams(audio_out_enabled=True)
)
tts = PiperTTSService(
base_url=os.getenv("PIPER_BASE_URL"), aiohttp_session=session, sample_rate=24000
)
runner = PipelineRunner()
task = PipelineTask(Pipeline([tts, transport.output()]))
# Register an event handler so we can play the audio when the
# participant joins.
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await task.queue_frames(
[TTSSpeakFrame(f"Hello there, how are you today ?"), EndFrame()]
)
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -48,7 +48,7 @@ async def main():
tts = PlayHTTTSService(
user_id=os.getenv("PLAYHT_USER_ID"),
api_key=os.getenv("PLAYHT_API_KEY"),
voice_url="s3://voice-cloning-zero-shot/d9ff78ba-d016-47f6-b0ef-dd630f59414e/female-cs/manifest.json",
voice_url="s3://voice-cloning-zero-shot/e46b4027-b38d-4d24-b292-38fbca2be0ef/original/manifest.json",
params=PlayHTTTSService.InputParams(language=Language.EN),
)

View File

@@ -0,0 +1,101 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.groq import GroqLLMService, GroqSTTService, GroqTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
# transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
),
)
stt = GroqSTTService(api_key=os.getenv("GROQ_API_KEY"))
llm = GroqLLMService(api_key=os.getenv("GROQ_API_KEY"), model="llama-3.3-70b-versatile")
tts = GroqTTSService(api_key=os.getenv("GROQ_API_KEY"))
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
),
)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_participant_left")
async def on_participant_left(transport, participant, reason):
await task.cancel()
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -48,7 +48,7 @@ cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ]
neuphonic = [ "pyneuphonic~=1.5.13", "websockets~=13.1" ]
cerebras = []
deepseek = []
daily = [ "daily-python~=0.15.0" ]
daily = [ "daily-python~=0.16.1" ]
deepgram = [ "deepgram-sdk~=3.8.0" ]
elevenlabs = [ "websockets~=13.1" ]
fal = [ "fal-client~=0.5.9" ]
@@ -56,7 +56,7 @@ fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ]
grok = []
groq = []
groq = [ "groq~=0.20.0" ]
gstreamer = [ "pygobject~=3.50.0" ]
fireworks = []
krisp = [ "pipecat-ai-krisp~=0.3.0" ]

View File

@@ -35,10 +35,15 @@ message TranscriptionFrame {
string timestamp = 5;
}
message MessageFrame {
string data = 1;
}
message Frame {
oneof frame {
TextFrame text = 1;
AudioRawFrame audio = 2;
TranscriptionFrame transcription = 3;
MessageFrame message = 4;
}
}

View File

@@ -363,6 +363,13 @@ class LLMSetToolsFrame(DataFrame):
tools: List[dict]
@dataclass
class LLMSetToolChoiceFrame(DataFrame):
"""A frame containing a tool choice for an LLM to use for function calling."""
tool_choice: Literal["none", "auto", "required"] | dict
@dataclass
class LLMEnablePromptCachingFrame(DataFrame):
"""A frame to enable/disable prompt caching in certain LLMs."""
@@ -384,7 +391,7 @@ class FunctionCallResultFrame(DataFrame):
function_name: str
tool_call_id: str
arguments: str
arguments: Any
result: Any
properties: Optional[FunctionCallResultProperties] = None
@@ -555,14 +562,14 @@ class UserStartedSpeakingFrame(SystemFrame):
"""
pass
emulated: bool = False
@dataclass
class UserStoppedSpeakingFrame(SystemFrame):
"""Emitted by the VAD to indicate that a user stopped speaking."""
pass
emulated: bool = False
@dataclass
@@ -633,8 +640,8 @@ class FunctionCallInProgressFrame(SystemFrame):
function_name: str
tool_call_id: str
arguments: str
cancel_on_interruption: bool
arguments: Any
cancel_on_interruption: bool = False
@dataclass

View File

@@ -1,12 +1,22 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: frames.proto
# Protobuf Python Version: 4.25.1
# Protobuf Python Version: 5.27.2
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
27,
2,
'',
'frames.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -14,19 +24,21 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"}\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\x12\x10\n\x03pts\x18\x06 \x01(\x04H\x00\x88\x01\x01\x42\x06\n\x04_pts\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x93\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"}\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\x12\x10\n\x03pts\x18\x06 \x01(\x04H\x00\x88\x01\x01\x42\x06\n\x04_pts\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x1c\n\x0cMessageFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\t\"\xbd\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x12(\n\x07message\x18\x04 \x01(\x0b\x32\x15.pipecat.MessageFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TEXTFRAME']._serialized_start=25
_globals['_TEXTFRAME']._serialized_end=76
_globals['_AUDIORAWFRAME']._serialized_start=78
_globals['_AUDIORAWFRAME']._serialized_end=203
_globals['_TRANSCRIPTIONFRAME']._serialized_start=205
_globals['_TRANSCRIPTIONFRAME']._serialized_end=301
_globals['_FRAME']._serialized_start=304
_globals['_FRAME']._serialized_end=451
_globals['_MESSAGEFRAME']._serialized_start=303
_globals['_MESSAGEFRAME']._serialized_end=331
_globals['_FRAME']._serialized_start=334
_globals['_FRAME']._serialized_end=523
# @@protoc_insertion_point(module_scope)

View File

@@ -6,7 +6,7 @@
import asyncio
from abc import abstractmethod
from typing import Dict, List
from typing import Dict, List, Literal, Set
from loguru import logger
@@ -26,6 +26,7 @@ from pipecat.frames.frames import (
LLMMessagesAppendFrame,
LLMMessagesFrame,
LLMMessagesUpdateFrame,
LLMSetToolChoiceFrame,
LLMSetToolsFrame,
LLMTextFrame,
OpenAILLMContextAssistantTimestampFrame,
@@ -140,6 +141,11 @@ class BaseLLMResponseAggregator(FrameProcessor):
"""Set LLM tools to be used in the current conversation."""
pass
@abstractmethod
def set_tool_choice(self, tool_choice):
"""Set the tool choice. This should modify the LLM context."""
pass
@abstractmethod
def reset(self):
"""Reset the internals of this aggregator. This should not modify the
@@ -204,6 +210,9 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
def set_tools(self, tools: List):
self._context.set_tools(tools)
def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
self._context.set_tool_choice(tool_choice)
def reset(self):
self._aggregation = ""
@@ -240,7 +249,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
self._waiting_for_aggregation = False
async def handle_aggregation(self, aggregation: str):
self._context.add_message({"role": self.role, "content": self._aggregation})
self._context.add_message({"role": self.role, "content": aggregation})
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -274,17 +283,21 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
self.set_messages(frame.messages)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
self.set_tool_choice(frame.tool_choice)
else:
await self.push_frame(frame, direction)
async def push_aggregation(self):
if len(self._aggregation) > 0:
await self.handle_aggregation(self._aggregation)
aggregation = self._aggregation
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
self.reset()
await self.handle_aggregation(aggregation)
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
@@ -297,10 +310,16 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
async def _cancel(self, frame: CancelFrame):
await self._cancel_aggregation_task()
async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
self._user_speaking = True
self._waiting_for_aggregation = True
# If we get a non-emulated UserStartedSpeakingFrame but we are in the
# middle of emulating VAD, let's stop emulating VAD (i.e. don't send the
# EmulateUserStoppedSpeakingFrame).
if not frame.emulated and self._emulating_vad:
self._emulating_vad = False
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
self._user_speaking = False
# We just stopped speaking. Let's see if there's some aggregation to
@@ -380,6 +399,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
self._started = 0
self._function_calls_in_progress: Dict[str, FunctionCallInProgressFrame] = {}
self._context_updated_tasks: Set[asyncio.Task] = set()
async def handle_aggregation(self, aggregation: str):
self._context.add_message({"role": "assistant", "content": aggregation})
@@ -414,6 +434,8 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
self.set_messages(frame.messages)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
self.set_tool_choice(frame.tool_choice)
elif isinstance(frame, FunctionCallInProgressFrame):
await self._handle_function_call_in_progress(frame)
elif isinstance(frame, FunctionCallResultFrame):
@@ -486,10 +508,14 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
if run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call
# result is added to the context
# Call the `on_context_updated` callback once the function call result
# is added to the context. Also, run this in a separate task to make
# sure we don't block the pipeline.
if properties and properties.on_context_updated:
await properties.on_context_updated()
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
task = self.create_task(properties.on_context_updated(), task_name)
self._context_updated_tasks.add(task)
task.add_done_callback(self._context_updated_task_finished)
async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
logger.debug(
@@ -535,6 +561,13 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
else:
self._aggregation += frame.text
def _context_updated_task_finished(self, task: asyncio.Task):
self._context_updated_tasks.discard(task)
# The task is finished so this should exit immediately. We need to do
# this because otherwise the task manager would report a dangling task
# if we don't remove it.
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
class LLMUserResponseAggregator(LLMUserContextAggregator):
def __init__(self, messages: List[dict] = [], **kwargs):

View File

@@ -147,10 +147,13 @@ class FrameProcessor(BaseObject):
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
def create_task(self, coroutine: Coroutine) -> asyncio.Task:
def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task:
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
name = f"{self}::{coroutine.cr_code.co_name}"
if name:
name = f"{self}::{name}"
else:
name = f"{self}::{coroutine.cr_code.co_name}"
return self._task_manager.create_task(coroutine, name)
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None):

View File

@@ -540,10 +540,23 @@ class RTVIObserver(BaseObserver):
await self.push_transport_message_urgent(message)
async def _handle_context(self, frame: OpenAILLMContextFrame):
"""Process LLM context frames to extract user messages for the RTVI client."""
try:
messages = frame.context.messages
if len(messages) > 0:
message = messages[-1]
if not messages:
return
message = messages[-1]
# Handle Google LLM format (protobuf objects with attributes)
if hasattr(message, "role") and message.role == "user" and hasattr(message, "parts"):
text = "".join(part.text for part in message.parts if hasattr(part, "text"))
if text:
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
await self.push_transport_message_urgent(rtvi_message)
# Handle OpenAI format (original implementation)
elif isinstance(message, dict):
if message["role"] == "user":
content = message["content"]
if isinstance(content, list):
@@ -552,7 +565,8 @@ class RTVIObserver(BaseObserver):
text = content
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
await self.push_transport_message_urgent(rtvi_message)
except TypeError as e:
except Exception as e:
logger.warning(f"Caught an error while trying to handle context: {e}")
async def _handle_metrics(self, frame: MetricsFrame):

View File

@@ -5,6 +5,7 @@
#
import dataclasses
import json
from loguru import logger
@@ -15,15 +16,24 @@ from pipecat.frames.frames import (
OutputAudioRawFrame,
TextFrame,
TranscriptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
)
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
# Data class for converting transport messages into Protobuf format.
@dataclasses.dataclass
class MessageFrame:
data: str
class ProtobufFrameSerializer(FrameSerializer):
SERIALIZABLE_TYPES = {
TextFrame: "text",
OutputAudioRawFrame: "audio",
TranscriptionFrame: "transcription",
MessageFrame: "message",
}
SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()}
@@ -42,6 +52,12 @@ class ProtobufFrameSerializer(FrameSerializer):
return FrameSerializerType.BINARY
async def serialize(self, frame: Frame) -> str | bytes | None:
# Wrapping this messages as a JSONFrame to send
if isinstance(frame, (TransportMessageFrame, TransportMessageUrgentFrame)):
frame = MessageFrame(
data=json.dumps(frame.message),
)
proto_frame = frame_protos.Frame()
if type(frame) not in self.SERIALIZABLE_TYPES:
logger.warning(f"Frame type {type(frame)} is not serializable")

View File

@@ -369,7 +369,7 @@ class LLMService(AIService):
if tuple_to_remove:
self._function_call_tasks.discard(tuple_to_remove)
# The task is finished so this should exit immediately. We need to
# do this because otherwise the task manager would have a dangling
# do this because otherwise the task manager would report a dangling
# task if we don't remove it.
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
@@ -1048,9 +1048,14 @@ class SegmentedSTTService(STTService):
await self._handle_user_stopped_speaking(frame)
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
if frame.emulated:
return
self._user_speaking = True
async def _handle_user_stopped_speaking(self, frame: UserStoppedSpeakingFrame):
if frame.emulated:
return
self._user_speaking = False
content = io.BytesIO()
@@ -1068,7 +1073,7 @@ class SegmentedSTTService(STTService):
self._audio_buffer.clear()
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
# If the user is speaking the audio buffer will keep growin.
# If the user is speaking the audio buffer will keep growing.
self._audio_buffer += frame.audio
# If the user is not speaking we keep just a little bit of audio.

View File

@@ -725,7 +725,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
)
async def _update_function_call_result(
self, function_name: str, tool_call_id: str, result: str
self, function_name: str, tool_call_id: str, result: Any
):
for message in self._context.messages:
if message["role"] == "user":

View File

@@ -601,13 +601,8 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
if frame.result:
if not isinstance(frame.result, str):
return
response = {"response": frame.result}
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, response
frame.function_name, frame.tool_call_id, frame.result
)
else:
await self._update_function_call_result(
@@ -626,7 +621,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
if message.role == "user":
for part in message.parts:
if part.function_response and part.function_response.id == tool_call_id:
part.function_response.response = {"response": result}
part.function_response.response = {"value": json.dumps(result)}
async def handle_user_image_frame(self, frame: UserImageRawFrame):
await self._update_function_call_result(
@@ -1348,6 +1343,7 @@ class GoogleVertexLLMService(OpenAILLMService):
**kwargs,
):
"""Initializes the VertexLLMService.
Args:
credentials (Optional[str]): JSON string of service account credentials.
credentials_path (Optional[str]): Path to the service account JSON file.
@@ -1371,9 +1367,11 @@ class GoogleVertexLLMService(OpenAILLMService):
@staticmethod
def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
"""Retrieves an authentication token using Google service account credentials.
Args:
credentials (Optional[str]): JSON string of service account credentials.
credentials_path (Optional[str]): Path to the service account JSON file.
Returns:
str: OAuth token for API authentication.
"""
@@ -1562,8 +1560,6 @@ class GoogleTTSService(TTSService):
logger.exception(f"{self} error generating TTS: {e}")
error_message = f"TTS generation error: {str(e)}"
yield ErrorFrame(error=error_message)
finally:
yield TTSStoppedFrame()
class GoogleImageGenService(ImageGenService):

View File

@@ -5,14 +5,26 @@
#
from typing import Optional
from typing import AsyncGenerator, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
from pipecat.services.ai_services import TTSService
from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
from pipecat.services.openai import OpenAILLMService
from pipecat.transcriptions.language import Language
try:
from groq import AsyncGroq
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Groq, you need to `pip install pipecat-ai[groq]`. Also, set a `GROQ_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
class GroqLLMService(OpenAILLMService):
"""A service for interacting with Groq's API using the OpenAI-compatible interface.
@@ -98,3 +110,68 @@ class GroqSTTService(BaseWhisperSTTService):
kwargs["temperature"] = self._temperature
return await self._client.audio.transcriptions.create(**kwargs)
class GroqTTSService(TTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
speed: Optional[float] = 1.0
seed: Optional[int] = None
GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate
def __init__(
self,
*,
api_key: str,
output_format: str = "wav",
params: InputParams = InputParams(),
model_name: str = "playai-tts",
voice_id: str = "Celeste-PlayAI",
sample_rate: Optional[int] = GROQ_SAMPLE_RATE,
**kwargs,
):
if sample_rate != self.GROQ_SAMPLE_RATE:
logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ")
super().__init__(
pause_frame_processing=True,
sample_rate=sample_rate,
**kwargs,
)
self._api_key = api_key
self._model_name = model_name
self._output_format = output_format
self._voice_id = voice_id
self._params = params
self._client = AsyncGroq(api_key=self._api_key)
def can_generate_metrics(self) -> bool:
return True
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"{self}: Generating TTS [{text}]")
measuring_ttfb = True
await self.start_ttfb_metrics()
yield TTSStartedFrame()
response = await self._client.audio.speech.create(
model=self._model_name,
voice=self._voice_id,
response_format=self._output_format,
input=text,
)
async for data in response.iter_bytes():
if measuring_ttfb:
await self.stop_ttfb_metrics()
measuring_ttfb = False
# remove wav header if present
if data.startswith(b"RIFF"):
data = data[44:]
if len(data) == 0:
continue
yield TTSAudioRawFrame(data, self.sample_rate, 1)
yield TTSStoppedFrame()

View File

@@ -391,6 +391,7 @@ class OpenAIImageGenService(ImageGenService):
self,
*,
api_key: str,
base_url: Optional[str] = None,
aiohttp_session: aiohttp.ClientSession,
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
model: str = "dall-e-3",
@@ -398,7 +399,7 @@ class OpenAIImageGenService(ImageGenService):
super().__init__()
self.set_model_name(model)
self._image_size = image_size
self._client = AsyncOpenAI(api_key=api_key)
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
self._aiohttp_session = aiohttp_session
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
@@ -501,9 +502,11 @@ class OpenAITTSService(TTSService):
self,
*,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
voice: str = "alloy",
model: str = "gpt-4o-mini-tts",
sample_rate: Optional[int] = None,
instructions: Optional[str] = None,
**kwargs,
):
if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
@@ -515,8 +518,8 @@ class OpenAITTSService(TTSService):
self.set_model_name(model)
self.set_voice(voice)
self._client = AsyncOpenAI(api_key=api_key)
self._instructions = instructions
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
def can_generate_metrics(self) -> bool:
return True
@@ -538,11 +541,17 @@ class OpenAITTSService(TTSService):
try:
await self.start_ttfb_metrics()
# Setup extra body parameters
extra_body = {}
if self._instructions:
extra_body["instructions"] = self._instructions
async with self._client.audio.speech.with_streaming_response.create(
input=text or " ", # Text must contain at least one character
model=self.model_name,
voice=VALID_VOICES[self._voice_id],
response_format="pcm",
extra_body=extra_body,
) as r:
if r.status_code != 200:
error = await r.text()
@@ -613,7 +622,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
)
async def _update_function_call_result(
self, function_name: str, tool_call_id: str, result: str
self, function_name: str, tool_call_id: str, result: Any
):
for message in self._context.messages:
if (

View File

@@ -0,0 +1,103 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.ai_services import TTSService
# This assumes a running TTS service running: https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md
class PiperTTSService(TTSService):
"""Piper TTS service implementation.
Provides integration with Piper's TTS server.
Args:
base_url: API base URL
aiohttp_session: aiohttp ClientSession
sample_rate: Output sample rate
"""
def __init__(
self,
*,
base_url: str,
aiohttp_session: aiohttp.ClientSession,
# When using Piper, the sample rate of the generated audio depends on the
# voice model being used.
sample_rate: Optional[int] = None,
**kwargs,
):
super().__init__(sample_rate=sample_rate, **kwargs)
if base_url.endswith("/"):
logger.warning("Base URL ends with a slash, this is not allowed.")
base_url = base_url[:-1]
self._base_url = base_url
self._session = aiohttp_session
self._settings = {"base_url": base_url}
def can_generate_metrics(self) -> bool:
return True
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Piper API.
Args:
text: The text to convert to speech
Yields:
Frames containing audio data and status information
"""
logger.debug(f"{self}: Generating TTS [{text}]")
headers = {
"Content-Type": "text/plain",
}
try:
await self.start_ttfb_metrics()
async with self._session.post(self._base_url, data=text, headers=headers) as response:
if response.status != 200:
eror = await response.text()
logger.error(
f"{self} error getting audio (status: {response.status}, error: {eror})"
)
yield ErrorFrame(
f"Error getting audio (status: {response.status}, error: {eror})"
)
return
await self.start_tts_usage_metrics(text)
# Process the streaming response
CHUNK_SIZE = 1024
yield TTSStartedFrame()
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
# remove wav header if present
if chunk.startswith(b"RIFF"):
chunk = chunk[44:]
if len(chunk) > 0:
await self.stop_ttfb_metrics()
yield TTSAudioRawFrame(chunk, self.sample_rate, 1)
except Exception as e:
logger.error(f"Error in run_tts: {e}")
yield ErrorFrame(error=str(e))
finally:
logger.debug(f"{self}: Finished TTS [{text}]")
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

View File

@@ -117,10 +117,10 @@ class BaseInputTransport(FrameProcessor):
await self._handle_bot_interruption(frame)
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
logger.debug("Emulating user started speaking")
await self._handle_user_interruption(UserStartedSpeakingFrame())
await self._handle_user_interruption(UserStartedSpeakingFrame(emulated=True))
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
logger.debug("Emulating user stopped speaking")
await self._handle_user_interruption(UserStoppedSpeakingFrame())
await self._handle_user_interruption(UserStoppedSpeakingFrame(emulated=True))
# All other system frames
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)

View File

@@ -4,13 +4,18 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import json
import unittest
from typing import Any
import google.ai.generativelanguage as glm
from pipecat.frames.frames import (
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
FunctionCallResultProperties,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@@ -21,10 +26,7 @@ from pipecat.frames.frames import (
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
@@ -423,6 +425,9 @@ class BaseTestAssistantContextAggreagator:
):
assert context.messages[index]["content"] == content
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: str):
assert json.loads(context.messages[index]["content"]) == content
async def test_empty(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
@@ -556,9 +561,76 @@ class BaseTestAssistantContextAggreagator:
self.check_message_multi_content(context, 0, 0, "Hello Pipecat.")
self.check_message_multi_content(context, 0, 1, "How are you?")
async def test_function_call(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
FunctionCallInProgressFrame(
function_name="get_weather",
tool_call_id="1",
arguments={"location": "Los Angeles"},
cancel_on_interruption=False,
),
SleepFrame(),
FunctionCallResultFrame(
function_name="get_weather",
tool_call_id="1",
arguments={"location": "Los Angeles"},
result={"conditions": "Sunny"},
),
]
expected_down_frames = []
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_function_call_result(context, -1, {"conditions": "Sunny"})
async def test_function_call_on_context_updated(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context_updated = False
async def on_context_updated():
nonlocal context_updated
context_updated = True
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
FunctionCallInProgressFrame(
function_name="get_weather",
tool_call_id="1",
arguments={"location": "Los Angeles"},
cancel_on_interruption=False,
),
SleepFrame(),
FunctionCallResultFrame(
function_name="get_weather",
tool_call_id="1",
arguments={"location": "Los Angeles"},
result={"conditions": "Sunny"},
properties=FunctionCallResultProperties(on_context_updated=on_context_updated),
),
SleepFrame(),
]
expected_down_frames = []
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_function_call_result(context, -1, {"conditions": "Sunny"})
assert context_updated
#
# LLMUserContextAggregator, LLMAssistantContextAggregator
# LLMUserContextAggregator
#
@@ -567,14 +639,6 @@ class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.Isola
AGGREGATOR_CLASS = LLMUserContextAggregator
class TestLLMAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = LLMAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
#
# OpenAI
#
@@ -626,6 +690,9 @@ class TestAnthropicAssistantContextAggregator(
messages = context.messages[content_index]
assert messages["content"][index]["text"] == content
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
assert context.messages[index]["content"][0]["content"] == json.dumps(content)
#
# Google
@@ -665,3 +732,7 @@ class TestGoogleAssistantContextAggregator(
):
obj = glm.Content.to_dict(context.messages[index])
assert obj["parts"][0]["text"] == content
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
obj = glm.Content.to_dict(context.messages[index])
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)

132
tests/test_piper_tts.py Normal file
View File

@@ -0,0 +1,132 @@
"""Tests for PiperTTSService."""
import asyncio
import aiohttp
import pytest
from aiohttp import web
from pipecat.frames.frames import (
ErrorFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSTextFrame,
)
from pipecat.services.piper import PiperTTSService
from pipecat.tests.utils import run_test
@pytest.mark.asyncio
async def test_run_piper_tts_success(aiohttp_client):
"""Test successful TTS generation with chunked audio data.
Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame.
"""
async def handler(request):
# The service expects a /?text= param
# Here we're just returning dummy chunked bytes to simulate an audio response
text_query = request.rel_url.query.get("text", "")
print(f"Mock server received text param: {text_query}")
# Prepare a StreamResponse with chunked data
resp = web.StreamResponse(
status=200,
reason="OK",
headers={"Content-Type": "audio/raw"},
)
await resp.prepare(request)
# Write out some chunked byte data
# In reality, youd return WAV data or similar
data_chunk_1 = b"\x00\x01\x02\x03" * 1024 # 4096 bytes, 04 TTSAudioRawFrame
data_chunk_2 = b"\x04\x05\x06\x07" * 1024 # another chunk
await resp.write(data_chunk_1)
await asyncio.sleep(0.01) # simulate async chunk delay
await resp.write(data_chunk_2)
await resp.write_eof()
return resp
# Create an aiohttp test server
app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)
# Remove trailing slash if present in the test URL
base_url = str(client.make_url("")).rstrip("/")
async with aiohttp.ClientSession() as session:
# Instantiate PiperTTSService with our mock server
tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000)
frames_to_send = [
TTSSpeakFrame(text="Hello world."),
]
expected_returned_frames = [
TTSStartedFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
TTSTextFrame,
]
frames_received = await run_test(
tts_service,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
down_frames = frames_received[0]
audio_frames = [f for f in down_frames if isinstance(f, TTSAudioRawFrame)]
for a_frame in audio_frames:
assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"
@pytest.mark.asyncio
async def test_run_piper_tts_error(aiohttp_client):
"""Test how the service handles a non-200 response from the server.
Expects an ErrorFrame to be returned.
"""
async def handler(_request):
# Return an error status for any request
return web.Response(status=404, text="Not found")
app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)
base_url = str(client.make_url("")).rstrip("/")
async with aiohttp.ClientSession() as session:
tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000)
frames_to_send = [
TTSSpeakFrame(text="Error case."),
]
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
expected_up_frames = [ErrorFrame]
frames_received = await run_test(
tts_service,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
up_frames = frames_received[1]
assert isinstance(up_frames[0], ErrorFrame), "Must receive an ErrorFrame for 404"
assert "status: 404" in up_frames[0].error, (
"ErrorFrame should contain details about the 404"
)