diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c06240bc..ef26f0266 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,59 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- ElevenLabs TTS services now support a sample rate of 8000. +- Added support to `ProtobufFrameSerializer` to send the messages from `TransportMessageFrame` and `TransportMessageUrgentFrame`. + +- Added support for a new TTS service, `PiperTTSService`. + (see https://github.com/rhasspy/piper/) + +- It is now possible to tell whether `UserStartedSpeakingFrame` or + `UserStoppedSpeakingFrame` have been generated because of emulation frames. ### Fixed +- Fixed an issue that would cause `SegmentedSTTService` based services + (e.g. `OpenAISTTService`) to try to transcribe non-spoken audio, causing + invalid transcriptions. + +- Fixed an issue where `GoogleTTSService` was emitting two `TTSStoppedFrames`. + +## [0.0.61] - 2025-03-26 + +### Added + +- Added a new frame, `LLMSetToolChoiceFrame`, which provides a mechanism + for modifying the `tool_choice` in the context. + +- Added `GroqTTSService` which provides text-to-speech functionality using + Groq's API. + +- Added support in `DailyTransport` for updating remote participants' + `canReceive` permission via the `update_remote_participants()` method, by + bumping the daily-python dependency to >= 0.16.0. + +- ElevenLabs TTS services now support a sample rate of 8000. + +- Added support for `instructions` in `OpenAITTSService`. + +- Added support for `base_url` in `OpenAIImageGenService` and + `OpenAITTSService`. + +### Fixed + +- Fixed an issue in `RTVIObserver` that prevented handling of Google LLM + context messages. The observer now processes both OpenAI-style and + Google-style contexts. + +- Fixed an issue in Daily involving switching virtual devices, by bumping the + daily-python dependency to >= 0.16.1. + +- Fixed a `GoogleAssistantContextAggregator` issue where function calls + placeholders where not being updated when then function call result was + different from a string. + +- Fixed an issue that would cause `LLMAssistantContextAggregator` to block + processing more frames while processing a function call result. + - Fixed an issue where the `RTVIObserver` would report two bot started and stopped speaking events for each bot turn. diff --git a/README.md b/README.md index f28005618..c49478d27 100644 --- a/README.md +++ b/README.md @@ -55,17 +55,17 @@ pip install "pipecat-ai[option,...]" ### Available services -| Category | Services | Install Command Example | -| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- | -| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` | -| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | `pip install "pipecat-ai[openai]"` | -| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | `pip install "pipecat-ai[cartesia]"` | -| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | `pip install "pipecat-ai[google]"` | -| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | `pip install "pipecat-ai[daily]"` | -| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` | -| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | `pip install "pipecat-ai[moondream]"` | -| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` | -| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` | +| Category | Services | Install Command Example | +| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- | +| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` | +| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | `pip install "pipecat-ai[openai]"` | +| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | `pip install "pipecat-ai[cartesia]"` | +| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | `pip install "pipecat-ai[google]"` | +| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | `pip install "pipecat-ai[daily]"` | +| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` | +| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | `pip install "pipecat-ai[moondream]"` | +| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` | +| Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` | 📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services) diff --git a/dev-requirements.txt b/dev-requirements.txt index e65c2755c..af1c35721 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -6,6 +6,7 @@ pre-commit~=4.0.1 pyright~=1.1.397 pytest~=8.3.4 pytest-asyncio~=0.25.3 +pytest-aiohttp==1.1.0 ruff~=0.11.1 setuptools~=70.0.0 setuptools_scm~=8.1.0 diff --git a/dot-env.template b/dot-env.template index 2da20fc0b..f0b5bdc0f 100644 --- a/dot-env.template +++ b/dot-env.template @@ -90,3 +90,6 @@ ASSEMBLYAI_API_KEY=... # OpenRouter OPENROUTER_API_KEY=... + +# Piper +PIPER_BASE_URL=... \ No newline at end of file diff --git a/examples/foundational/01-say-one-thing-piper.py b/examples/foundational/01-say-one-thing-piper.py new file mode 100644 index 000000000..256447c23 --- /dev/null +++ b/examples/foundational/01-say-one-thing-piper.py @@ -0,0 +1,57 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.frames.frames import EndFrame, TTSSpeakFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineTask +from pipecat.services.piper import PiperTTSService +from pipecat.transports.services.daily import DailyParams, DailyTransport + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, _) = await configure(session) + + transport = DailyTransport( + room_url, None, "Say One Thing", DailyParams(audio_out_enabled=True) + ) + + tts = PiperTTSService( + base_url=os.getenv("PIPER_BASE_URL"), aiohttp_session=session, sample_rate=24000 + ) + + runner = PipelineRunner() + + task = PipelineTask(Pipeline([tts, transport.output()])) + + # Register an event handler so we can play the audio when the + # participant joins. + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await task.queue_frames( + [TTSSpeakFrame(f"Hello there, how are you today ?"), EndFrame()] + ) + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/foundational/07e-interruptible-playht.py b/examples/foundational/07e-interruptible-playht.py index 5ccb96c15..c402c09db 100644 --- a/examples/foundational/07e-interruptible-playht.py +++ b/examples/foundational/07e-interruptible-playht.py @@ -48,7 +48,7 @@ async def main(): tts = PlayHTTTSService( user_id=os.getenv("PLAYHT_USER_ID"), api_key=os.getenv("PLAYHT_API_KEY"), - voice_url="s3://voice-cloning-zero-shot/d9ff78ba-d016-47f6-b0ef-dd630f59414e/female-cs/manifest.json", + voice_url="s3://voice-cloning-zero-shot/e46b4027-b38d-4d24-b292-38fbca2be0ef/original/manifest.json", params=PlayHTTTSService.InputParams(language=Language.EN), ) diff --git a/examples/foundational/07y-interruptible-groq.py b/examples/foundational/07y-interruptible-groq.py new file mode 100644 index 000000000..9e5719c21 --- /dev/null +++ b/examples/foundational/07y-interruptible-groq.py @@ -0,0 +1,101 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.groq import GroqLLMService, GroqSTTService, GroqTTSService +from pipecat.transports.services.daily import DailyParams, DailyTransport + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + # transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + vad_audio_passthrough=True, + ), + ) + + stt = GroqSTTService(api_key=os.getenv("GROQ_API_KEY")) + + llm = GroqLLMService(api_key=os.getenv("GROQ_API_KEY"), model="llama-3.3-70b-versatile") + + tts = GroqTTSService(api_key=os.getenv("GROQ_API_KEY")) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + context_aggregator.assistant(), # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + ), + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + messages.append({"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + @transport.event_handler("on_participant_left") + async def on_participant_left(transport, participant, reason): + await task.cancel() + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 0d1e46ed9..0f2e993ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ] neuphonic = [ "pyneuphonic~=1.5.13", "websockets~=13.1" ] cerebras = [] deepseek = [] -daily = [ "daily-python~=0.15.0" ] +daily = [ "daily-python~=0.16.1" ] deepgram = [ "deepgram-sdk~=3.8.0" ] elevenlabs = [ "websockets~=13.1" ] fal = [ "fal-client~=0.5.9" ] @@ -56,7 +56,7 @@ fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ] gladia = [ "websockets~=13.1" ] google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ] grok = [] -groq = [] +groq = [ "groq~=0.20.0" ] gstreamer = [ "pygobject~=3.50.0" ] fireworks = [] krisp = [ "pipecat-ai-krisp~=0.3.0" ] diff --git a/src/pipecat/frames/frames.proto b/src/pipecat/frames/frames.proto index 98dc014db..ebdb16fcc 100644 --- a/src/pipecat/frames/frames.proto +++ b/src/pipecat/frames/frames.proto @@ -35,10 +35,15 @@ message TranscriptionFrame { string timestamp = 5; } +message MessageFrame { + string data = 1; +} + message Frame { oneof frame { TextFrame text = 1; AudioRawFrame audio = 2; TranscriptionFrame transcription = 3; + MessageFrame message = 4; } } diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index c2a79461f..30d8622d9 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -363,6 +363,13 @@ class LLMSetToolsFrame(DataFrame): tools: List[dict] +@dataclass +class LLMSetToolChoiceFrame(DataFrame): + """A frame containing a tool choice for an LLM to use for function calling.""" + + tool_choice: Literal["none", "auto", "required"] | dict + + @dataclass class LLMEnablePromptCachingFrame(DataFrame): """A frame to enable/disable prompt caching in certain LLMs.""" @@ -384,7 +391,7 @@ class FunctionCallResultFrame(DataFrame): function_name: str tool_call_id: str - arguments: str + arguments: Any result: Any properties: Optional[FunctionCallResultProperties] = None @@ -555,14 +562,14 @@ class UserStartedSpeakingFrame(SystemFrame): """ - pass + emulated: bool = False @dataclass class UserStoppedSpeakingFrame(SystemFrame): """Emitted by the VAD to indicate that a user stopped speaking.""" - pass + emulated: bool = False @dataclass @@ -633,8 +640,8 @@ class FunctionCallInProgressFrame(SystemFrame): function_name: str tool_call_id: str - arguments: str - cancel_on_interruption: bool + arguments: Any + cancel_on_interruption: bool = False @dataclass diff --git a/src/pipecat/frames/protobufs/frames_pb2.py b/src/pipecat/frames/protobufs/frames_pb2.py index d58bc8baa..7884c6ccc 100644 --- a/src/pipecat/frames/protobufs/frames_pb2.py +++ b/src/pipecat/frames/protobufs/frames_pb2.py @@ -1,12 +1,22 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: frames.proto -# Protobuf Python Version: 4.25.1 +# Protobuf Python Version: 5.27.2 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 27, + 2, + '', + 'frames.proto' +) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -14,19 +24,21 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"}\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\x12\x10\n\x03pts\x18\x06 \x01(\x04H\x00\x88\x01\x01\x42\x06\n\x04_pts\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x93\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"}\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\x12\x10\n\x03pts\x18\x06 \x01(\x04H\x00\x88\x01\x01\x42\x06\n\x04_pts\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x1c\n\x0cMessageFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\t\"\xbd\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x12(\n\x07message\x18\x04 \x01(\x0b\x32\x15.pipecat.MessageFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None _globals['_TEXTFRAME']._serialized_start=25 _globals['_TEXTFRAME']._serialized_end=76 _globals['_AUDIORAWFRAME']._serialized_start=78 _globals['_AUDIORAWFRAME']._serialized_end=203 _globals['_TRANSCRIPTIONFRAME']._serialized_start=205 _globals['_TRANSCRIPTIONFRAME']._serialized_end=301 - _globals['_FRAME']._serialized_start=304 - _globals['_FRAME']._serialized_end=451 + _globals['_MESSAGEFRAME']._serialized_start=303 + _globals['_MESSAGEFRAME']._serialized_end=331 + _globals['_FRAME']._serialized_start=334 + _globals['_FRAME']._serialized_end=523 # @@protoc_insertion_point(module_scope) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 75435a214..e40ad266b 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -6,7 +6,7 @@ import asyncio from abc import abstractmethod -from typing import Dict, List +from typing import Dict, List, Literal, Set from loguru import logger @@ -26,6 +26,7 @@ from pipecat.frames.frames import ( LLMMessagesAppendFrame, LLMMessagesFrame, LLMMessagesUpdateFrame, + LLMSetToolChoiceFrame, LLMSetToolsFrame, LLMTextFrame, OpenAILLMContextAssistantTimestampFrame, @@ -140,6 +141,11 @@ class BaseLLMResponseAggregator(FrameProcessor): """Set LLM tools to be used in the current conversation.""" pass + @abstractmethod + def set_tool_choice(self, tool_choice): + """Set the tool choice. This should modify the LLM context.""" + pass + @abstractmethod def reset(self): """Reset the internals of this aggregator. This should not modify the @@ -204,6 +210,9 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator): def set_tools(self, tools: List): self._context.set_tools(tools) + def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict): + self._context.set_tool_choice(tool_choice) + def reset(self): self._aggregation = "" @@ -240,7 +249,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): self._waiting_for_aggregation = False async def handle_aggregation(self, aggregation: str): - self._context.add_message({"role": self.role, "content": self._aggregation}) + self._context.add_message({"role": self.role, "content": aggregation}) async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -274,17 +283,21 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): self.set_messages(frame.messages) elif isinstance(frame, LLMSetToolsFrame): self.set_tools(frame.tools) + elif isinstance(frame, LLMSetToolChoiceFrame): + self.set_tool_choice(frame.tool_choice) else: await self.push_frame(frame, direction) async def push_aggregation(self): if len(self._aggregation) > 0: - await self.handle_aggregation(self._aggregation) + aggregation = self._aggregation # Reset the aggregation. Reset it before pushing it down, otherwise # if the tasks gets cancelled we won't be able to clear things up. self.reset() + await self.handle_aggregation(aggregation) + frame = OpenAILLMContextFrame(self._context) await self.push_frame(frame) @@ -297,10 +310,16 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): async def _cancel(self, frame: CancelFrame): await self._cancel_aggregation_task() - async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame): + async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame): self._user_speaking = True self._waiting_for_aggregation = True + # If we get a non-emulated UserStartedSpeakingFrame but we are in the + # middle of emulating VAD, let's stop emulating VAD (i.e. don't send the + # EmulateUserStoppedSpeakingFrame). + if not frame.emulated and self._emulating_vad: + self._emulating_vad = False + async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame): self._user_speaking = False # We just stopped speaking. Let's see if there's some aggregation to @@ -380,6 +399,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): self._started = 0 self._function_calls_in_progress: Dict[str, FunctionCallInProgressFrame] = {} + self._context_updated_tasks: Set[asyncio.Task] = set() async def handle_aggregation(self, aggregation: str): self._context.add_message({"role": "assistant", "content": aggregation}) @@ -414,6 +434,8 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): self.set_messages(frame.messages) elif isinstance(frame, LLMSetToolsFrame): self.set_tools(frame.tools) + elif isinstance(frame, LLMSetToolChoiceFrame): + self.set_tool_choice(frame.tool_choice) elif isinstance(frame, FunctionCallInProgressFrame): await self._handle_function_call_in_progress(frame) elif isinstance(frame, FunctionCallResultFrame): @@ -486,10 +508,14 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): if run_llm: await self.push_context_frame(FrameDirection.UPSTREAM) - # Emit the on_context_updated callback once the function call - # result is added to the context + # Call the `on_context_updated` callback once the function call result + # is added to the context. Also, run this in a separate task to make + # sure we don't block the pipeline. if properties and properties.on_context_updated: - await properties.on_context_updated() + task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated" + task = self.create_task(properties.on_context_updated(), task_name) + self._context_updated_tasks.add(task) + task.add_done_callback(self._context_updated_task_finished) async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame): logger.debug( @@ -535,6 +561,13 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): else: self._aggregation += frame.text + def _context_updated_task_finished(self, task: asyncio.Task): + self._context_updated_tasks.discard(task) + # The task is finished so this should exit immediately. We need to do + # this because otherwise the task manager would report a dangling task + # if we don't remove it. + asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop()) + class LLMUserResponseAggregator(LLMUserContextAggregator): def __init__(self, messages: List[dict] = [], **kwargs): diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 847cdf175..590698e7f 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -147,10 +147,13 @@ class FrameProcessor(BaseObject): await self.stop_ttfb_metrics() await self.stop_processing_metrics() - def create_task(self, coroutine: Coroutine) -> asyncio.Task: + def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task: if not self._task_manager: raise Exception(f"{self} TaskManager is still not initialized.") - name = f"{self}::{coroutine.cr_code.co_name}" + if name: + name = f"{self}::{name}" + else: + name = f"{self}::{coroutine.cr_code.co_name}" return self._task_manager.create_task(coroutine, name) async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None): diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index f782e6ea8..eec07a29f 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -540,10 +540,23 @@ class RTVIObserver(BaseObserver): await self.push_transport_message_urgent(message) async def _handle_context(self, frame: OpenAILLMContextFrame): + """Process LLM context frames to extract user messages for the RTVI client.""" try: messages = frame.context.messages - if len(messages) > 0: - message = messages[-1] + if not messages: + return + + message = messages[-1] + + # Handle Google LLM format (protobuf objects with attributes) + if hasattr(message, "role") and message.role == "user" and hasattr(message, "parts"): + text = "".join(part.text for part in message.parts if hasattr(part, "text")) + if text: + rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text)) + await self.push_transport_message_urgent(rtvi_message) + + # Handle OpenAI format (original implementation) + elif isinstance(message, dict): if message["role"] == "user": content = message["content"] if isinstance(content, list): @@ -552,7 +565,8 @@ class RTVIObserver(BaseObserver): text = content rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text)) await self.push_transport_message_urgent(rtvi_message) - except TypeError as e: + + except Exception as e: logger.warning(f"Caught an error while trying to handle context: {e}") async def _handle_metrics(self, frame: MetricsFrame): diff --git a/src/pipecat/serializers/protobuf.py b/src/pipecat/serializers/protobuf.py index 125f2037f..c3b6d86af 100644 --- a/src/pipecat/serializers/protobuf.py +++ b/src/pipecat/serializers/protobuf.py @@ -5,6 +5,7 @@ # import dataclasses +import json from loguru import logger @@ -15,15 +16,24 @@ from pipecat.frames.frames import ( OutputAudioRawFrame, TextFrame, TranscriptionFrame, + TransportMessageFrame, + TransportMessageUrgentFrame, ) from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType +# Data class for converting transport messages into Protobuf format. +@dataclasses.dataclass +class MessageFrame: + data: str + + class ProtobufFrameSerializer(FrameSerializer): SERIALIZABLE_TYPES = { TextFrame: "text", OutputAudioRawFrame: "audio", TranscriptionFrame: "transcription", + MessageFrame: "message", } SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()} @@ -42,6 +52,12 @@ class ProtobufFrameSerializer(FrameSerializer): return FrameSerializerType.BINARY async def serialize(self, frame: Frame) -> str | bytes | None: + # Wrapping this messages as a JSONFrame to send + if isinstance(frame, (TransportMessageFrame, TransportMessageUrgentFrame)): + frame = MessageFrame( + data=json.dumps(frame.message), + ) + proto_frame = frame_protos.Frame() if type(frame) not in self.SERIALIZABLE_TYPES: logger.warning(f"Frame type {type(frame)} is not serializable") diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 9f9804e65..97aad0d40 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -369,7 +369,7 @@ class LLMService(AIService): if tuple_to_remove: self._function_call_tasks.discard(tuple_to_remove) # The task is finished so this should exit immediately. We need to - # do this because otherwise the task manager would have a dangling + # do this because otherwise the task manager would report a dangling # task if we don't remove it. asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop()) @@ -1048,9 +1048,14 @@ class SegmentedSTTService(STTService): await self._handle_user_stopped_speaking(frame) async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame): + if frame.emulated: + return self._user_speaking = True async def _handle_user_stopped_speaking(self, frame: UserStoppedSpeakingFrame): + if frame.emulated: + return + self._user_speaking = False content = io.BytesIO() @@ -1068,7 +1073,7 @@ class SegmentedSTTService(STTService): self._audio_buffer.clear() async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection): - # If the user is speaking the audio buffer will keep growin. + # If the user is speaking the audio buffer will keep growing. self._audio_buffer += frame.audio # If the user is not speaking we keep just a little bit of audio. diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 6a95d04e2..3e369075a 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -725,7 +725,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): ) async def _update_function_call_result( - self, function_name: str, tool_call_id: str, result: str + self, function_name: str, tool_call_id: str, result: Any ): for message in self._context.messages: if message["role"] == "user": diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index bfddce46d..95c5a1edb 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -601,13 +601,8 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): async def handle_function_call_result(self, frame: FunctionCallResultFrame): if frame.result: - if not isinstance(frame.result, str): - return - - response = {"response": frame.result} - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, response + frame.function_name, frame.tool_call_id, frame.result ) else: await self._update_function_call_result( @@ -626,7 +621,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): if message.role == "user": for part in message.parts: if part.function_response and part.function_response.id == tool_call_id: - part.function_response.response = {"response": result} + part.function_response.response = {"value": json.dumps(result)} async def handle_user_image_frame(self, frame: UserImageRawFrame): await self._update_function_call_result( @@ -1348,6 +1343,7 @@ class GoogleVertexLLMService(OpenAILLMService): **kwargs, ): """Initializes the VertexLLMService. + Args: credentials (Optional[str]): JSON string of service account credentials. credentials_path (Optional[str]): Path to the service account JSON file. @@ -1371,9 +1367,11 @@ class GoogleVertexLLMService(OpenAILLMService): @staticmethod def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str: """Retrieves an authentication token using Google service account credentials. + Args: credentials (Optional[str]): JSON string of service account credentials. credentials_path (Optional[str]): Path to the service account JSON file. + Returns: str: OAuth token for API authentication. """ @@ -1562,8 +1560,6 @@ class GoogleTTSService(TTSService): logger.exception(f"{self} error generating TTS: {e}") error_message = f"TTS generation error: {str(e)}" yield ErrorFrame(error=error_message) - finally: - yield TTSStoppedFrame() class GoogleImageGenService(ImageGenService): diff --git a/src/pipecat/services/groq.py b/src/pipecat/services/groq.py index 66cc9357f..bf0304df2 100644 --- a/src/pipecat/services/groq.py +++ b/src/pipecat/services/groq.py @@ -5,14 +5,26 @@ # -from typing import Optional +from typing import AsyncGenerator, Optional from loguru import logger +from pydantic import BaseModel +from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame +from pipecat.services.ai_services import TTSService from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription from pipecat.services.openai import OpenAILLMService from pipecat.transcriptions.language import Language +try: + from groq import AsyncGroq +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Groq, you need to `pip install pipecat-ai[groq]`. Also, set a `GROQ_API_KEY` environment variable." + ) + raise Exception(f"Missing module: {e}") + class GroqLLMService(OpenAILLMService): """A service for interacting with Groq's API using the OpenAI-compatible interface. @@ -98,3 +110,68 @@ class GroqSTTService(BaseWhisperSTTService): kwargs["temperature"] = self._temperature return await self._client.audio.transcriptions.create(**kwargs) + + +class GroqTTSService(TTSService): + class InputParams(BaseModel): + language: Optional[Language] = Language.EN + speed: Optional[float] = 1.0 + seed: Optional[int] = None + + GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate + + def __init__( + self, + *, + api_key: str, + output_format: str = "wav", + params: InputParams = InputParams(), + model_name: str = "playai-tts", + voice_id: str = "Celeste-PlayAI", + sample_rate: Optional[int] = GROQ_SAMPLE_RATE, + **kwargs, + ): + if sample_rate != self.GROQ_SAMPLE_RATE: + logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ") + super().__init__( + pause_frame_processing=True, + sample_rate=sample_rate, + **kwargs, + ) + + self._api_key = api_key + self._model_name = model_name + self._output_format = output_format + self._voice_id = voice_id + self._params = params + + self._client = AsyncGroq(api_key=self._api_key) + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"{self}: Generating TTS [{text}]") + measuring_ttfb = True + await self.start_ttfb_metrics() + yield TTSStartedFrame() + + response = await self._client.audio.speech.create( + model=self._model_name, + voice=self._voice_id, + response_format=self._output_format, + input=text, + ) + + async for data in response.iter_bytes(): + if measuring_ttfb: + await self.stop_ttfb_metrics() + measuring_ttfb = False + # remove wav header if present + if data.startswith(b"RIFF"): + data = data[44:] + if len(data) == 0: + continue + yield TTSAudioRawFrame(data, self.sample_rate, 1) + + yield TTSStoppedFrame() diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index ff7bc0442..3f85d917c 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -391,6 +391,7 @@ class OpenAIImageGenService(ImageGenService): self, *, api_key: str, + base_url: Optional[str] = None, aiohttp_session: aiohttp.ClientSession, image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], model: str = "dall-e-3", @@ -398,7 +399,7 @@ class OpenAIImageGenService(ImageGenService): super().__init__() self.set_model_name(model) self._image_size = image_size - self._client = AsyncOpenAI(api_key=api_key) + self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) self._aiohttp_session = aiohttp_session async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: @@ -501,9 +502,11 @@ class OpenAITTSService(TTSService): self, *, api_key: Optional[str] = None, + base_url: Optional[str] = None, voice: str = "alloy", model: str = "gpt-4o-mini-tts", sample_rate: Optional[int] = None, + instructions: Optional[str] = None, **kwargs, ): if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE: @@ -515,8 +518,8 @@ class OpenAITTSService(TTSService): self.set_model_name(model) self.set_voice(voice) - - self._client = AsyncOpenAI(api_key=api_key) + self._instructions = instructions + self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) def can_generate_metrics(self) -> bool: return True @@ -538,11 +541,17 @@ class OpenAITTSService(TTSService): try: await self.start_ttfb_metrics() + # Setup extra body parameters + extra_body = {} + if self._instructions: + extra_body["instructions"] = self._instructions + async with self._client.audio.speech.with_streaming_response.create( input=text or " ", # Text must contain at least one character model=self.model_name, voice=VALID_VOICES[self._voice_id], response_format="pcm", + extra_body=extra_body, ) as r: if r.status_code != 200: error = await r.text() @@ -613,7 +622,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): ) async def _update_function_call_result( - self, function_name: str, tool_call_id: str, result: str + self, function_name: str, tool_call_id: str, result: Any ): for message in self._context.messages: if ( diff --git a/src/pipecat/services/piper.py b/src/pipecat/services/piper.py new file mode 100644 index 000000000..12a936889 --- /dev/null +++ b/src/pipecat/services/piper.py @@ -0,0 +1,103 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import AsyncGenerator, Optional + +import aiohttp +from loguru import logger + +from pipecat.frames.frames import ( + ErrorFrame, + Frame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.services.ai_services import TTSService + + +# This assumes a running TTS service running: https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md +class PiperTTSService(TTSService): + """Piper TTS service implementation. + + Provides integration with Piper's TTS server. + + Args: + base_url: API base URL + aiohttp_session: aiohttp ClientSession + sample_rate: Output sample rate + """ + + def __init__( + self, + *, + base_url: str, + aiohttp_session: aiohttp.ClientSession, + # When using Piper, the sample rate of the generated audio depends on the + # voice model being used. + sample_rate: Optional[int] = None, + **kwargs, + ): + super().__init__(sample_rate=sample_rate, **kwargs) + + if base_url.endswith("/"): + logger.warning("Base URL ends with a slash, this is not allowed.") + base_url = base_url[:-1] + + self._base_url = base_url + self._session = aiohttp_session + self._settings = {"base_url": base_url} + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + """Generate speech from text using Piper API. + + Args: + text: The text to convert to speech + + Yields: + Frames containing audio data and status information + """ + logger.debug(f"{self}: Generating TTS [{text}]") + headers = { + "Content-Type": "text/plain", + } + try: + await self.start_ttfb_metrics() + + async with self._session.post(self._base_url, data=text, headers=headers) as response: + if response.status != 200: + eror = await response.text() + logger.error( + f"{self} error getting audio (status: {response.status}, error: {eror})" + ) + yield ErrorFrame( + f"Error getting audio (status: {response.status}, error: {eror})" + ) + return + + await self.start_tts_usage_metrics(text) + + # Process the streaming response + CHUNK_SIZE = 1024 + + yield TTSStartedFrame() + async for chunk in response.content.iter_chunked(CHUNK_SIZE): + # remove wav header if present + if chunk.startswith(b"RIFF"): + chunk = chunk[44:] + if len(chunk) > 0: + await self.stop_ttfb_metrics() + yield TTSAudioRawFrame(chunk, self.sample_rate, 1) + except Exception as e: + logger.error(f"Error in run_tts: {e}") + yield ErrorFrame(error=str(e)) + finally: + logger.debug(f"{self}: Finished TTS [{text}]") + await self.stop_ttfb_metrics() + yield TTSStoppedFrame() diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 971dfe066..26f386576 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -117,10 +117,10 @@ class BaseInputTransport(FrameProcessor): await self._handle_bot_interruption(frame) elif isinstance(frame, EmulateUserStartedSpeakingFrame): logger.debug("Emulating user started speaking") - await self._handle_user_interruption(UserStartedSpeakingFrame()) + await self._handle_user_interruption(UserStartedSpeakingFrame(emulated=True)) elif isinstance(frame, EmulateUserStoppedSpeakingFrame): logger.debug("Emulating user stopped speaking") - await self._handle_user_interruption(UserStoppedSpeakingFrame()) + await self._handle_user_interruption(UserStoppedSpeakingFrame(emulated=True)) # All other system frames elif isinstance(frame, SystemFrame): await self.push_frame(frame, direction) diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index 185725632..df00f11e4 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -4,13 +4,18 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import json import unittest +from typing import Any import google.ai.generativelanguage as glm from pipecat.frames.frames import ( EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, + FunctionCallResultProperties, InterimTranscriptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, @@ -21,10 +26,7 @@ from pipecat.frames.frames import ( UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantContextAggregator, - LLMUserContextAggregator, -) +from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -423,6 +425,9 @@ class BaseTestAssistantContextAggreagator: ): assert context.messages[index]["content"] == content + def check_function_call_result(self, context: OpenAILLMContext, index: int, content: str): + assert json.loads(context.messages[index]["content"]) == content + async def test_empty(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" @@ -556,9 +561,76 @@ class BaseTestAssistantContextAggreagator: self.check_message_multi_content(context, 0, 0, "Hello Pipecat.") self.check_message_multi_content(context, 0, 1, "How are you?") + async def test_function_call(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [ + FunctionCallInProgressFrame( + function_name="get_weather", + tool_call_id="1", + arguments={"location": "Los Angeles"}, + cancel_on_interruption=False, + ), + SleepFrame(), + FunctionCallResultFrame( + function_name="get_weather", + tool_call_id="1", + arguments={"location": "Los Angeles"}, + result={"conditions": "Sunny"}, + ), + ] + expected_down_frames = [] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_function_call_result(context, -1, {"conditions": "Sunny"}) + + async def test_function_call_on_context_updated(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context_updated = False + + async def on_context_updated(): + nonlocal context_updated + context_updated = True + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [ + FunctionCallInProgressFrame( + function_name="get_weather", + tool_call_id="1", + arguments={"location": "Los Angeles"}, + cancel_on_interruption=False, + ), + SleepFrame(), + FunctionCallResultFrame( + function_name="get_weather", + tool_call_id="1", + arguments={"location": "Los Angeles"}, + result={"conditions": "Sunny"}, + properties=FunctionCallResultProperties(on_context_updated=on_context_updated), + ), + SleepFrame(), + ] + expected_down_frames = [] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_function_call_result(context, -1, {"conditions": "Sunny"}) + assert context_updated + # -# LLMUserContextAggregator, LLMAssistantContextAggregator +# LLMUserContextAggregator # @@ -567,14 +639,6 @@ class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.Isola AGGREGATOR_CLASS = LLMUserContextAggregator -class TestLLMAssistantContextAggregator( - BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = OpenAILLMContext - AGGREGATOR_CLASS = LLMAssistantContextAggregator - EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] - - # # OpenAI # @@ -626,6 +690,9 @@ class TestAnthropicAssistantContextAggregator( messages = context.messages[content_index] assert messages["content"][index]["text"] == content + def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any): + assert context.messages[index]["content"][0]["content"] == json.dumps(content) + # # Google @@ -665,3 +732,7 @@ class TestGoogleAssistantContextAggregator( ): obj = glm.Content.to_dict(context.messages[index]) assert obj["parts"][0]["text"] == content + + def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content) diff --git a/tests/test_piper_tts.py b/tests/test_piper_tts.py new file mode 100644 index 000000000..8db6bf1d1 --- /dev/null +++ b/tests/test_piper_tts.py @@ -0,0 +1,132 @@ +"""Tests for PiperTTSService.""" + +import asyncio + +import aiohttp +import pytest +from aiohttp import web + +from pipecat.frames.frames import ( + ErrorFrame, + TTSAudioRawFrame, + TTSSpeakFrame, + TTSStartedFrame, + TTSStoppedFrame, + TTSTextFrame, +) +from pipecat.services.piper import PiperTTSService +from pipecat.tests.utils import run_test + + +@pytest.mark.asyncio +async def test_run_piper_tts_success(aiohttp_client): + """Test successful TTS generation with chunked audio data. + + Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame. + """ + + async def handler(request): + # The service expects a /?text= param + # Here we're just returning dummy chunked bytes to simulate an audio response + text_query = request.rel_url.query.get("text", "") + print(f"Mock server received text param: {text_query}") + + # Prepare a StreamResponse with chunked data + resp = web.StreamResponse( + status=200, + reason="OK", + headers={"Content-Type": "audio/raw"}, + ) + await resp.prepare(request) + + # Write out some chunked byte data + # In reality, you’d return WAV data or similar + data_chunk_1 = b"\x00\x01\x02\x03" * 1024 # 4096 bytes, 04 TTSAudioRawFrame + data_chunk_2 = b"\x04\x05\x06\x07" * 1024 # another chunk + await resp.write(data_chunk_1) + await asyncio.sleep(0.01) # simulate async chunk delay + await resp.write(data_chunk_2) + await resp.write_eof() + + return resp + + # Create an aiohttp test server + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + # Remove trailing slash if present in the test URL + base_url = str(client.make_url("")).rstrip("/") + + async with aiohttp.ClientSession() as session: + # Instantiate PiperTTSService with our mock server + tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000) + + frames_to_send = [ + TTSSpeakFrame(text="Hello world."), + ] + + expected_returned_frames = [ + TTSStartedFrame, + TTSAudioRawFrame, + TTSAudioRawFrame, + TTSAudioRawFrame, + TTSAudioRawFrame, + TTSAudioRawFrame, + TTSAudioRawFrame, + TTSAudioRawFrame, + TTSAudioRawFrame, + TTSStoppedFrame, + TTSTextFrame, + ] + + frames_received = await run_test( + tts_service, + frames_to_send=frames_to_send, + expected_down_frames=expected_returned_frames, + ) + down_frames = frames_received[0] + audio_frames = [f for f in down_frames if isinstance(f, TTSAudioRawFrame)] + for a_frame in audio_frames: + assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)" + + +@pytest.mark.asyncio +async def test_run_piper_tts_error(aiohttp_client): + """Test how the service handles a non-200 response from the server. + + Expects an ErrorFrame to be returned. + """ + + async def handler(_request): + # Return an error status for any request + return web.Response(status=404, text="Not found") + + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + base_url = str(client.make_url("")).rstrip("/") + + async with aiohttp.ClientSession() as session: + tts_service = PiperTTSService(base_url=base_url, aiohttp_session=session, sample_rate=24000) + + frames_to_send = [ + TTSSpeakFrame(text="Error case."), + ] + + expected_down_frames = [TTSStoppedFrame, TTSTextFrame] + + expected_up_frames = [ErrorFrame] + + frames_received = await run_test( + tts_service, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + expected_up_frames=expected_up_frames, + ) + up_frames = frames_received[1] + + assert isinstance(up_frames[0], ErrorFrame), "Must receive an ErrorFrame for 404" + assert "status: 404" in up_frames[0].error, ( + "ErrorFrame should contain details about the 404" + )