Merge branch 'main' into aiortc_example
This commit is contained in:
51
CHANGELOG.md
51
CHANGELOG.md
@@ -9,10 +9,59 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- ElevenLabs TTS services now support a sample rate of 8000.
|
||||
- Added support to `ProtobufFrameSerializer` to send the messages from `TransportMessageFrame` and `TransportMessageUrgentFrame`.
|
||||
|
||||
- Added support for a new TTS service, `PiperTTSService`.
|
||||
(see https://github.com/rhasspy/piper/)
|
||||
|
||||
- It is now possible to tell whether `UserStartedSpeakingFrame` or
|
||||
`UserStoppedSpeakingFrame` have been generated because of emulation frames.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue that would cause `SegmentedSTTService` based services
|
||||
(e.g. `OpenAISTTService`) to try to transcribe non-spoken audio, causing
|
||||
invalid transcriptions.
|
||||
|
||||
- Fixed an issue where `GoogleTTSService` was emitting two `TTSStoppedFrames`.
|
||||
|
||||
## [0.0.61] - 2025-03-26
|
||||
|
||||
### Added
|
||||
|
||||
- Added a new frame, `LLMSetToolChoiceFrame`, which provides a mechanism
|
||||
for modifying the `tool_choice` in the context.
|
||||
|
||||
- Added `GroqTTSService` which provides text-to-speech functionality using
|
||||
Groq's API.
|
||||
|
||||
- Added support in `DailyTransport` for updating remote participants'
|
||||
`canReceive` permission via the `update_remote_participants()` method, by
|
||||
bumping the daily-python dependency to >= 0.16.0.
|
||||
|
||||
- ElevenLabs TTS services now support a sample rate of 8000.
|
||||
|
||||
- Added support for `instructions` in `OpenAITTSService`.
|
||||
|
||||
- Added support for `base_url` in `OpenAIImageGenService` and
|
||||
`OpenAITTSService`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `RTVIObserver` that prevented handling of Google LLM
|
||||
context messages. The observer now processes both OpenAI-style and
|
||||
Google-style contexts.
|
||||
|
||||
- Fixed an issue in Daily involving switching virtual devices, by bumping the
|
||||
daily-python dependency to >= 0.16.1.
|
||||
|
||||
- Fixed a `GoogleAssistantContextAggregator` issue where function calls
|
||||
placeholders where not being updated when then function call result was
|
||||
different from a string.
|
||||
|
||||
- Fixed an issue that would cause `LLMAssistantContextAggregator` to block
|
||||
processing more frames while processing a function call result.
|
||||
|
||||
- Fixed an issue where the `RTVIObserver` would report two bot started and
|
||||
stopped speaking events for each bot turn.
|
||||
|
||||
|
||||
22
README.md
22
README.md
@@ -55,17 +55,17 @@ pip install "pipecat-ai[option,...]"
|
||||
|
||||
### Available services
|
||||
|
||||
| Category | Services | Install Command Example |
|
||||
| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | `pip install "pipecat-ai[openai]"` |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | `pip install "pipecat-ai[cartesia]"` |
|
||||
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | `pip install "pipecat-ai[google]"` |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | `pip install "pipecat-ai[daily]"` |
|
||||
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | `pip install "pipecat-ai[moondream]"` |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` |
|
||||
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` |
|
||||
| Category | Services | Install Command Example |
|
||||
| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | `pip install "pipecat-ai[openai]"` |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | `pip install "pipecat-ai[cartesia]"` |
|
||||
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | `pip install "pipecat-ai[google]"` |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | `pip install "pipecat-ai[daily]"` |
|
||||
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | `pip install "pipecat-ai[moondream]"` |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` |
|
||||
| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` |
|
||||
|
||||
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ pre-commit~=4.0.1
|
||||
pyright~=1.1.397
|
||||
pytest~=8.3.4
|
||||
pytest-asyncio~=0.25.3
|
||||
pytest-aiohttp==1.1.0
|
||||
ruff~=0.11.1
|
||||
setuptools~=70.0.0
|
||||
setuptools_scm~=8.1.0
|
||||
|
||||
@@ -90,3 +90,6 @@ ASSEMBLYAI_API_KEY=...
|
||||
|
||||
# OpenRouter
|
||||
OPENROUTER_API_KEY=...
|
||||
|
||||
# Piper
|
||||
PIPER_BASE_URL=...
|
||||
57
examples/foundational/01-say-one-thing-piper.py
Normal file
57
examples/foundational/01-say-one-thing-piper.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import EndFrame, TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.services.piper import PiperTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, _) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url, None, "Say One Thing", DailyParams(audio_out_enabled=True)
|
||||
)
|
||||
|
||||
tts = PiperTTSService(
|
||||
base_url=os.getenv("PIPER_BASE_URL"), aiohttp_session=session, sample_rate=24000
|
||||
)
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
task = PipelineTask(Pipeline([tts, transport.output()]))
|
||||
|
||||
# Register an event handler so we can play the audio when the
|
||||
# participant joins.
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await task.queue_frames(
|
||||
[TTSSpeakFrame(f"Hello there, how are you today ?"), EndFrame()]
|
||||
)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -48,7 +48,7 @@ async def main():
|
||||
tts = PlayHTTTSService(
|
||||
user_id=os.getenv("PLAYHT_USER_ID"),
|
||||
api_key=os.getenv("PLAYHT_API_KEY"),
|
||||
voice_url="s3://voice-cloning-zero-shot/d9ff78ba-d016-47f6-b0ef-dd630f59414e/female-cs/manifest.json",
|
||||
voice_url="s3://voice-cloning-zero-shot/e46b4027-b38d-4d24-b292-38fbca2be0ef/original/manifest.json",
|
||||
params=PlayHTTTSService.InputParams(language=Language.EN),
|
||||
)
|
||||
|
||||
|
||||
101
examples/foundational/07y-interruptible-groq.py
Normal file
101
examples/foundational/07y-interruptible-groq.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.groq import GroqLLMService, GroqSTTService, GroqTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
# transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_audio_passthrough=True,
|
||||
),
|
||||
)
|
||||
|
||||
stt = GroqSTTService(api_key=os.getenv("GROQ_API_KEY"))
|
||||
|
||||
llm = GroqLLMService(api_key=os.getenv("GROQ_API_KEY"), model="llama-3.3-70b-versatile")
|
||||
|
||||
tts = GroqTTSService(api_key=os.getenv("GROQ_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -48,7 +48,7 @@ cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ]
|
||||
neuphonic = [ "pyneuphonic~=1.5.13", "websockets~=13.1" ]
|
||||
cerebras = []
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.15.0" ]
|
||||
daily = [ "daily-python~=0.16.1" ]
|
||||
deepgram = [ "deepgram-sdk~=3.8.0" ]
|
||||
elevenlabs = [ "websockets~=13.1" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
@@ -56,7 +56,7 @@ fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
|
||||
gladia = [ "websockets~=13.1" ]
|
||||
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ]
|
||||
grok = []
|
||||
groq = []
|
||||
groq = [ "groq~=0.20.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
fireworks = []
|
||||
krisp = [ "pipecat-ai-krisp~=0.3.0" ]
|
||||
|
||||
@@ -35,10 +35,15 @@ message TranscriptionFrame {
|
||||
string timestamp = 5;
|
||||
}
|
||||
|
||||
message MessageFrame {
|
||||
string data = 1;
|
||||
}
|
||||
|
||||
message Frame {
|
||||
oneof frame {
|
||||
TextFrame text = 1;
|
||||
AudioRawFrame audio = 2;
|
||||
TranscriptionFrame transcription = 3;
|
||||
MessageFrame message = 4;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,6 +363,13 @@ class LLMSetToolsFrame(DataFrame):
|
||||
tools: List[dict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSetToolChoiceFrame(DataFrame):
|
||||
"""A frame containing a tool choice for an LLM to use for function calling."""
|
||||
|
||||
tool_choice: Literal["none", "auto", "required"] | dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMEnablePromptCachingFrame(DataFrame):
|
||||
"""A frame to enable/disable prompt caching in certain LLMs."""
|
||||
@@ -384,7 +391,7 @@ class FunctionCallResultFrame(DataFrame):
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: str
|
||||
arguments: Any
|
||||
result: Any
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
@@ -555,14 +562,14 @@ class UserStartedSpeakingFrame(SystemFrame):
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
emulated: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserStoppedSpeakingFrame(SystemFrame):
|
||||
"""Emitted by the VAD to indicate that a user stopped speaking."""
|
||||
|
||||
pass
|
||||
emulated: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -633,8 +640,8 @@ class FunctionCallInProgressFrame(SystemFrame):
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: str
|
||||
cancel_on_interruption: bool
|
||||
arguments: Any
|
||||
cancel_on_interruption: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,12 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: frames.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
# Protobuf Python Version: 5.27.2
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
27,
|
||||
2,
|
||||
'',
|
||||
'frames.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
@@ -14,19 +24,21 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"}\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\x12\x10\n\x03pts\x18\x06 \x01(\x04H\x00\x88\x01\x01\x42\x06\n\x04_pts\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x93\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"}\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\x12\x10\n\x03pts\x18\x06 \x01(\x04H\x00\x88\x01\x01\x42\x06\n\x04_pts\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x1c\n\x0cMessageFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\t\"\xbd\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x12(\n\x07message\x18\x04 \x01(\x0b\x32\x15.pipecat.MessageFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
DESCRIPTOR._options = None
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TEXTFRAME']._serialized_start=25
|
||||
_globals['_TEXTFRAME']._serialized_end=76
|
||||
_globals['_AUDIORAWFRAME']._serialized_start=78
|
||||
_globals['_AUDIORAWFRAME']._serialized_end=203
|
||||
_globals['_TRANSCRIPTIONFRAME']._serialized_start=205
|
||||
_globals['_TRANSCRIPTIONFRAME']._serialized_end=301
|
||||
_globals['_FRAME']._serialized_start=304
|
||||
_globals['_FRAME']._serialized_end=451
|
||||
_globals['_MESSAGEFRAME']._serialized_start=303
|
||||
_globals['_MESSAGEFRAME']._serialized_end=331
|
||||
_globals['_FRAME']._serialized_start=334
|
||||
_globals['_FRAME']._serialized_end=523
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Literal, Set
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -26,6 +26,7 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
@@ -140,6 +141,11 @@ class BaseLLMResponseAggregator(FrameProcessor):
|
||||
"""Set LLM tools to be used in the current conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_tool_choice(self, tool_choice):
|
||||
"""Set the tool choice. This should modify the LLM context."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
"""Reset the internals of this aggregator. This should not modify the
|
||||
@@ -204,6 +210,9 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
def set_tools(self, tools: List):
|
||||
self._context.set_tools(tools)
|
||||
|
||||
def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
|
||||
self._context.set_tool_choice(tool_choice)
|
||||
|
||||
def reset(self):
|
||||
self._aggregation = ""
|
||||
|
||||
@@ -240,7 +249,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
self._waiting_for_aggregation = False
|
||||
|
||||
async def handle_aggregation(self, aggregation: str):
|
||||
self._context.add_message({"role": self.role, "content": self._aggregation})
|
||||
self._context.add_message({"role": self.role, "content": aggregation})
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -274,17 +283,21 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
self.set_messages(frame.messages)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
self.set_tool_choice(frame.tool_choice)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
await self.handle_aggregation(self._aggregation)
|
||||
aggregation = self._aggregation
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self.reset()
|
||||
|
||||
await self.handle_aggregation(aggregation)
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -297,10 +310,16 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
await self._cancel_aggregation_task()
|
||||
|
||||
async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
|
||||
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
|
||||
self._user_speaking = True
|
||||
self._waiting_for_aggregation = True
|
||||
|
||||
# If we get a non-emulated UserStartedSpeakingFrame but we are in the
|
||||
# middle of emulating VAD, let's stop emulating VAD (i.e. don't send the
|
||||
# EmulateUserStoppedSpeakingFrame).
|
||||
if not frame.emulated and self._emulating_vad:
|
||||
self._emulating_vad = False
|
||||
|
||||
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
|
||||
self._user_speaking = False
|
||||
# We just stopped speaking. Let's see if there's some aggregation to
|
||||
@@ -380,6 +399,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
self._started = 0
|
||||
self._function_calls_in_progress: Dict[str, FunctionCallInProgressFrame] = {}
|
||||
self._context_updated_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
async def handle_aggregation(self, aggregation: str):
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
@@ -414,6 +434,8 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
self.set_messages(frame.messages)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
self.set_tool_choice(frame.tool_choice)
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
await self._handle_function_call_in_progress(frame)
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
@@ -486,10 +508,14 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
if run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call
|
||||
# result is added to the context
|
||||
# Call the `on_context_updated` callback once the function call result
|
||||
# is added to the context. Also, run this in a separate task to make
|
||||
# sure we don't block the pipeline.
|
||||
if properties and properties.on_context_updated:
|
||||
await properties.on_context_updated()
|
||||
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
|
||||
task = self.create_task(properties.on_context_updated(), task_name)
|
||||
self._context_updated_tasks.add(task)
|
||||
task.add_done_callback(self._context_updated_task_finished)
|
||||
|
||||
async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
logger.debug(
|
||||
@@ -535,6 +561,13 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
|
||||
def _context_updated_task_finished(self, task: asyncio.Task):
|
||||
self._context_updated_tasks.discard(task)
|
||||
# The task is finished so this should exit immediately. We need to do
|
||||
# this because otherwise the task manager would report a dangling task
|
||||
# if we don't remove it.
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
|
||||
|
||||
|
||||
class LLMUserResponseAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, messages: List[dict] = [], **kwargs):
|
||||
|
||||
@@ -147,10 +147,13 @@ class FrameProcessor(BaseObject):
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
def create_task(self, coroutine: Coroutine) -> asyncio.Task:
|
||||
def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task:
|
||||
if not self._task_manager:
|
||||
raise Exception(f"{self} TaskManager is still not initialized.")
|
||||
name = f"{self}::{coroutine.cr_code.co_name}"
|
||||
if name:
|
||||
name = f"{self}::{name}"
|
||||
else:
|
||||
name = f"{self}::{coroutine.cr_code.co_name}"
|
||||
return self._task_manager.create_task(coroutine, name)
|
||||
|
||||
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None):
|
||||
|
||||
@@ -540,10 +540,23 @@ class RTVIObserver(BaseObserver):
|
||||
await self.push_transport_message_urgent(message)
|
||||
|
||||
async def _handle_context(self, frame: OpenAILLMContextFrame):
|
||||
"""Process LLM context frames to extract user messages for the RTVI client."""
|
||||
try:
|
||||
messages = frame.context.messages
|
||||
if len(messages) > 0:
|
||||
message = messages[-1]
|
||||
if not messages:
|
||||
return
|
||||
|
||||
message = messages[-1]
|
||||
|
||||
# Handle Google LLM format (protobuf objects with attributes)
|
||||
if hasattr(message, "role") and message.role == "user" and hasattr(message, "parts"):
|
||||
text = "".join(part.text for part in message.parts if hasattr(part, "text"))
|
||||
if text:
|
||||
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self.push_transport_message_urgent(rtvi_message)
|
||||
|
||||
# Handle OpenAI format (original implementation)
|
||||
elif isinstance(message, dict):
|
||||
if message["role"] == "user":
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
@@ -552,7 +565,8 @@ class RTVIObserver(BaseObserver):
|
||||
text = content
|
||||
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self.push_transport_message_urgent(rtvi_message)
|
||||
except TypeError as e:
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Caught an error while trying to handle context: {e}")
|
||||
|
||||
async def _handle_metrics(self, frame: MetricsFrame):
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -15,15 +16,24 @@ from pipecat.frames.frames import (
|
||||
OutputAudioRawFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
|
||||
|
||||
# Data class for converting transport messages into Protobuf format.
|
||||
@dataclasses.dataclass
|
||||
class MessageFrame:
|
||||
data: str
|
||||
|
||||
|
||||
class ProtobufFrameSerializer(FrameSerializer):
|
||||
SERIALIZABLE_TYPES = {
|
||||
TextFrame: "text",
|
||||
OutputAudioRawFrame: "audio",
|
||||
TranscriptionFrame: "transcription",
|
||||
MessageFrame: "message",
|
||||
}
|
||||
SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()}
|
||||
|
||||
@@ -42,6 +52,12 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
return FrameSerializerType.BINARY
|
||||
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
# Wrapping this messages as a JSONFrame to send
|
||||
if isinstance(frame, (TransportMessageFrame, TransportMessageUrgentFrame)):
|
||||
frame = MessageFrame(
|
||||
data=json.dumps(frame.message),
|
||||
)
|
||||
|
||||
proto_frame = frame_protos.Frame()
|
||||
if type(frame) not in self.SERIALIZABLE_TYPES:
|
||||
logger.warning(f"Frame type {type(frame)} is not serializable")
|
||||
|
||||
@@ -369,7 +369,7 @@ class LLMService(AIService):
|
||||
if tuple_to_remove:
|
||||
self._function_call_tasks.discard(tuple_to_remove)
|
||||
# The task is finished so this should exit immediately. We need to
|
||||
# do this because otherwise the task manager would have a dangling
|
||||
# do this because otherwise the task manager would report a dangling
|
||||
# task if we don't remove it.
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
|
||||
|
||||
@@ -1048,9 +1048,14 @@ class SegmentedSTTService(STTService):
|
||||
await self._handle_user_stopped_speaking(frame)
|
||||
|
||||
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
|
||||
if frame.emulated:
|
||||
return
|
||||
self._user_speaking = True
|
||||
|
||||
async def _handle_user_stopped_speaking(self, frame: UserStoppedSpeakingFrame):
|
||||
if frame.emulated:
|
||||
return
|
||||
|
||||
self._user_speaking = False
|
||||
|
||||
content = io.BytesIO()
|
||||
@@ -1068,7 +1073,7 @@ class SegmentedSTTService(STTService):
|
||||
self._audio_buffer.clear()
|
||||
|
||||
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
|
||||
# If the user is speaking the audio buffer will keep growin.
|
||||
# If the user is speaking the audio buffer will keep growing.
|
||||
self._audio_buffer += frame.audio
|
||||
|
||||
# If the user is not speaking we keep just a little bit of audio.
|
||||
|
||||
@@ -725,7 +725,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: str
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if message["role"] == "user":
|
||||
|
||||
@@ -601,13 +601,8 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
if frame.result:
|
||||
if not isinstance(frame.result, str):
|
||||
return
|
||||
|
||||
response = {"response": frame.result}
|
||||
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, response
|
||||
frame.function_name, frame.tool_call_id, frame.result
|
||||
)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
@@ -626,7 +621,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
if message.role == "user":
|
||||
for part in message.parts:
|
||||
if part.function_response and part.function_response.id == tool_call_id:
|
||||
part.function_response.response = {"response": result}
|
||||
part.function_response.response = {"value": json.dumps(result)}
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
await self._update_function_call_result(
|
||||
@@ -1348,6 +1343,7 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes the VertexLLMService.
|
||||
|
||||
Args:
|
||||
credentials (Optional[str]): JSON string of service account credentials.
|
||||
credentials_path (Optional[str]): Path to the service account JSON file.
|
||||
@@ -1371,9 +1367,11 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
@staticmethod
|
||||
def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
|
||||
"""Retrieves an authentication token using Google service account credentials.
|
||||
|
||||
Args:
|
||||
credentials (Optional[str]): JSON string of service account credentials.
|
||||
credentials_path (Optional[str]): Path to the service account JSON file.
|
||||
|
||||
Returns:
|
||||
str: OAuth token for API authentication.
|
||||
"""
|
||||
@@ -1562,8 +1560,6 @@ class GoogleTTSService(TTSService):
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
finally:
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
|
||||
class GoogleImageGenService(ImageGenService):
|
||||
|
||||
@@ -5,14 +5,26 @@
|
||||
#
|
||||
|
||||
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
from groq import AsyncGroq
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Groq, you need to `pip install pipecat-ai[groq]`. Also, set a `GROQ_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GroqLLMService(OpenAILLMService):
|
||||
"""A service for interacting with Groq's API using the OpenAI-compatible interface.
|
||||
@@ -98,3 +110,68 @@ class GroqSTTService(BaseWhisperSTTService):
|
||||
kwargs["temperature"] = self._temperature
|
||||
|
||||
return await self._client.audio.transcriptions.create(**kwargs)
|
||||
|
||||
|
||||
class GroqTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[float] = 1.0
|
||||
seed: Optional[int] = None
|
||||
|
||||
GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
output_format: str = "wav",
|
||||
params: InputParams = InputParams(),
|
||||
model_name: str = "playai-tts",
|
||||
voice_id: str = "Celeste-PlayAI",
|
||||
sample_rate: Optional[int] = GROQ_SAMPLE_RATE,
|
||||
**kwargs,
|
||||
):
|
||||
if sample_rate != self.GROQ_SAMPLE_RATE:
|
||||
logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ")
|
||||
super().__init__(
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._model_name = model_name
|
||||
self._output_format = output_format
|
||||
self._voice_id = voice_id
|
||||
self._params = params
|
||||
|
||||
self._client = AsyncGroq(api_key=self._api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
measuring_ttfb = True
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
response = await self._client.audio.speech.create(
|
||||
model=self._model_name,
|
||||
voice=self._voice_id,
|
||||
response_format=self._output_format,
|
||||
input=text,
|
||||
)
|
||||
|
||||
async for data in response.iter_bytes():
|
||||
if measuring_ttfb:
|
||||
await self.stop_ttfb_metrics()
|
||||
measuring_ttfb = False
|
||||
# remove wav header if present
|
||||
if data.startswith(b"RIFF"):
|
||||
data = data[44:]
|
||||
if len(data) == 0:
|
||||
continue
|
||||
yield TTSAudioRawFrame(data, self.sample_rate, 1)
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -391,6 +391,7 @@ class OpenAIImageGenService(ImageGenService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
model: str = "dall-e-3",
|
||||
@@ -398,7 +399,7 @@ class OpenAIImageGenService(ImageGenService):
|
||||
super().__init__()
|
||||
self.set_model_name(model)
|
||||
self._image_size = image_size
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -501,9 +502,11 @@ class OpenAITTSService(TTSService):
|
||||
self,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
voice: str = "alloy",
|
||||
model: str = "gpt-4o-mini-tts",
|
||||
sample_rate: Optional[int] = None,
|
||||
instructions: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
|
||||
@@ -515,8 +518,8 @@ class OpenAITTSService(TTSService):
|
||||
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice)
|
||||
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
self._instructions = instructions
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -538,11 +541,17 @@ class OpenAITTSService(TTSService):
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Setup extra body parameters
|
||||
extra_body = {}
|
||||
if self._instructions:
|
||||
extra_body["instructions"] = self._instructions
|
||||
|
||||
async with self._client.audio.speech.with_streaming_response.create(
|
||||
input=text or " ", # Text must contain at least one character
|
||||
model=self.model_name,
|
||||
voice=VALID_VOICES[self._voice_id],
|
||||
response_format="pcm",
|
||||
extra_body=extra_body,
|
||||
) as r:
|
||||
if r.status_code != 200:
|
||||
error = await r.text()
|
||||
@@ -613,7 +622,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: str
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if (
|
||||
|
||||
103
src/pipecat/services/piper.py
Normal file
103
src/pipecat/services/piper.py
Normal file
@@ -0,0 +1,103 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
|
||||
# This assumes a running TTS service running: https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md
|
||||
class PiperTTSService(TTSService):
|
||||
"""Piper TTS service implementation.
|
||||
|
||||
Provides integration with Piper's TTS server.
|
||||
|
||||
Args:
|
||||
base_url: API base URL
|
||||
aiohttp_session: aiohttp ClientSession
|
||||
sample_rate: Output sample rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
# When using Piper, the sample rate of the generated audio depends on the
|
||||
# voice model being used.
|
||||
sample_rate: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
if base_url.endswith("/"):
|
||||
logger.warning("Base URL ends with a slash, this is not allowed.")
|
||||
base_url = base_url[:-1]
|
||||
|
||||
self._base_url = base_url
|
||||
self._session = aiohttp_session
|
||||
self._settings = {"base_url": base_url}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Piper API.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech
|
||||
|
||||
Yields:
|
||||
Frames containing audio data and status information
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
headers = {
|
||||
"Content-Type": "text/plain",
|
||||
}
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._session.post(self._base_url, data=text, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
eror = await response.text()
|
||||
logger.error(
|
||||
f"{self} error getting audio (status: {response.status}, error: {eror})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {response.status}, error: {eror})"
|
||||
)
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Process the streaming response
|
||||
CHUNK_SIZE = 1024
|
||||
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
|
||||
# remove wav header if present
|
||||
if chunk.startswith(b"RIFF"):
|
||||
chunk = chunk[44:]
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
finally:
|
||||
logger.debug(f"{self}: Finished TTS [{text}]")
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
@@ -117,10 +117,10 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self._handle_bot_interruption(frame)
|
||||
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
|
||||
logger.debug("Emulating user started speaking")
|
||||
await self._handle_user_interruption(UserStartedSpeakingFrame())
|
||||
await self._handle_user_interruption(UserStartedSpeakingFrame(emulated=True))
|
||||
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
|
||||
logger.debug("Emulating user stopped speaking")
|
||||
await self._handle_user_interruption(UserStoppedSpeakingFrame())
|
||||
await self._handle_user_interruption(UserStoppedSpeakingFrame(emulated=True))
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -4,13 +4,18 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from typing import Any
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallResultProperties,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -21,10 +26,7 @@ from pipecat.frames.frames import (
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -423,6 +425,9 @@ class BaseTestAssistantContextAggreagator:
|
||||
):
|
||||
assert context.messages[index]["content"] == content
|
||||
|
||||
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: str):
|
||||
assert json.loads(context.messages[index]["content"]) == content
|
||||
|
||||
async def test_empty(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
@@ -556,9 +561,76 @@ class BaseTestAssistantContextAggreagator:
|
||||
self.check_message_multi_content(context, 0, 0, "Hello Pipecat.")
|
||||
self.check_message_multi_content(context, 0, 1, "How are you?")
|
||||
|
||||
async def test_function_call(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={"location": "Los Angeles"},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
SleepFrame(),
|
||||
FunctionCallResultFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={"location": "Los Angeles"},
|
||||
result={"conditions": "Sunny"},
|
||||
),
|
||||
]
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_function_call_result(context, -1, {"conditions": "Sunny"})
|
||||
|
||||
async def test_function_call_on_context_updated(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context_updated = False
|
||||
|
||||
async def on_context_updated():
|
||||
nonlocal context_updated
|
||||
context_updated = True
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={"location": "Los Angeles"},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
SleepFrame(),
|
||||
FunctionCallResultFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={"location": "Los Angeles"},
|
||||
result={"conditions": "Sunny"},
|
||||
properties=FunctionCallResultProperties(on_context_updated=on_context_updated),
|
||||
),
|
||||
SleepFrame(),
|
||||
]
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_function_call_result(context, -1, {"conditions": "Sunny"})
|
||||
assert context_updated
|
||||
|
||||
|
||||
#
|
||||
# LLMUserContextAggregator, LLMAssistantContextAggregator
|
||||
# LLMUserContextAggregator
|
||||
#
|
||||
|
||||
|
||||
@@ -567,14 +639,6 @@ class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.Isola
|
||||
AGGREGATOR_CLASS = LLMUserContextAggregator
|
||||
|
||||
|
||||
class TestLLMAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = LLMAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
|
||||
#
|
||||
# OpenAI
|
||||
#
|
||||
@@ -626,6 +690,9 @@ class TestAnthropicAssistantContextAggregator(
|
||||
messages = context.messages[content_index]
|
||||
assert messages["content"][index]["text"] == content
|
||||
|
||||
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
|
||||
assert context.messages[index]["content"][0]["content"] == json.dumps(content)
|
||||
|
||||
|
||||
#
|
||||
# Google
|
||||
@@ -665,3 +732,7 @@ class TestGoogleAssistantContextAggregator(
|
||||
):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
assert obj["parts"][0]["text"] == content
|
||||
|
||||
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)
|
||||
|
||||
132
tests/test_piper_tts.py
Normal file
132
tests/test_piper_tts.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for PiperTTSService."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.services.piper import PiperTTSService
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_piper_tts_success(aiohttp_client):
|
||||
"""Test successful TTS generation with chunked audio data.
|
||||
|
||||
Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame.
|
||||
"""
|
||||
|
||||
async def handler(request):
|
||||
# The service expects a /?text= param
|
||||
# Here we're just returning dummy chunked bytes to simulate an audio response
|
||||
text_query = request.rel_url.query.get("text", "")
|
||||
print(f"Mock server received text param: {text_query}")
|
||||
|
||||
# Prepare a StreamResponse with chunked data
|
||||
resp = web.StreamResponse(
|
||||
status=200,
|
||||
reason="OK",
|
||||
headers={"Content-Type": "audio/raw"},
|
||||
)
|
||||
await resp.prepare(request)
|
||||
|
||||
# Write out some chunked byte data
|
||||
# In reality, you’d return WAV data or similar
|
||||
data_chunk_1 = b"\x00\x01\x02\x03" * 1024 # 4096 bytes, 04 TTSAudioRawFrame
|
||||
data_chunk_2 = b"\x04\x05\x06\x07" * 1024 # another chunk
|
||||
await resp.write(data_chunk_1)
|
||||
await asyncio.sleep(0.01) # simulate async chunk delay
|
||||
await resp.write(data_chunk_2)
|
||||
await resp.write_eof()
|
||||
|
||||
return resp
|
||||
|
||||
# Create an aiohttp test server
|
||||
app = web.Application()
|
||||
app.router.add_post("/", handler)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Remove trailing slash if present in the test URL
|
||||
base_url = str(client.make_url("")).rstrip("/")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Instantiate PiperTTSService with our mock server
|
||||
tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000)
|
||||
|
||||
frames_to_send = [
|
||||
TTSSpeakFrame(text="Hello world."),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
TTSStartedFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
]
|
||||
|
||||
frames_received = await run_test(
|
||||
tts_service,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
down_frames = frames_received[0]
|
||||
audio_frames = [f for f in down_frames if isinstance(f, TTSAudioRawFrame)]
|
||||
for a_frame in audio_frames:
|
||||
assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_piper_tts_error(aiohttp_client):
|
||||
"""Test how the service handles a non-200 response from the server.
|
||||
|
||||
Expects an ErrorFrame to be returned.
|
||||
"""
|
||||
|
||||
async def handler(_request):
|
||||
# Return an error status for any request
|
||||
return web.Response(status=404, text="Not found")
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/", handler)
|
||||
client = await aiohttp_client(app)
|
||||
base_url = str(client.make_url("")).rstrip("/")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000)
|
||||
|
||||
frames_to_send = [
|
||||
TTSSpeakFrame(text="Error case."),
|
||||
]
|
||||
|
||||
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
|
||||
|
||||
expected_up_frames = [ErrorFrame]
|
||||
|
||||
frames_received = await run_test(
|
||||
tts_service,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
up_frames = frames_received[1]
|
||||
|
||||
assert isinstance(up_frames[0], ErrorFrame), "Must receive an ErrorFrame for 404"
|
||||
assert "status: 404" in up_frames[0].error, (
|
||||
"ErrorFrame should contain details about the 404"
|
||||
)
|
||||
Reference in New Issue
Block a user