Compare commits
65 Commits
hush/realt
...
aleix/smar
| Author | SHA1 | Date | |
|---|---|---|---|
| | e9d2cd6d30 | | |
| | 710eebab09 | | |
| | 532423eb4c | | |
| | bb29e50adb | | |
| | 4048d6782b | | |
| | 76d36a312b | | |
| | 2a75373c04 | | |
| | a840b0e815 | | |
| | ebcde719a6 | | |
| | 5c912927bb | | |
| | 0e55db054e | | |
| | 5967ac0d4f | | |
| | 1451483cf7 | | |
| | c14b85c12b | | |
| | 9f3c0219d7 | | |
| | ec36fef26e | | |
| | 5f1848d24b | | |
| | d6867bd12f | | |
| | 17a1f30572 | | |
| | 8e0dc1f256 | | |
| | b9100beee3 | | |
| | b8bc3d2565 | | |
| | 3213e85b7d | | |
| | de3bcd64c4 | | |
| | ad7f1eec12 | | |
| | 29310b4e92 | | |
| | 2f4d36a146 | | |
| | 6c9bb782b1 | | |
| | 010d9103d4 | | |
| | 12131eb7c5 | | |
| | 80b830322a | | |
| | 8db9d16174 | | |
| | 1c92fab1fb | | |
| | 974717d1b9 | | |
| | 59fb631390 | | |
| | 4824220260 | | |
| | 55a338614d | | |
| | f033046963 | | |
| | 6018fc068c | | |
| | d5b634301f | | |
| | a37eb1049d | | |
| | 803ea9d8bc | | |
| | 499bc25217 | | |
| | 53d403af4b | | |
| | a0a8ea1641 | | |
| | 26c68ccd7c | | |
| | fa010c8644 | | |
| | d58f398bc4 | | |
| | 11383a86a1 | | |
| | daa52ff8df | | |
| | a5f41e22f7 | | |
| | 530bb5233d | | |
| | 4a64e09f6c | | |
| | 74582bb8d5 | | |
| | 1ca2101e3a | | |
| | e80311c323 | | |
| | 2f24c422b6 | | |
| | 0d0b9fddef | | |
| | 1753cc99f4 | | |
| | 4f8b036abe | | |
| | f83c89c202 | | |
| | bb89a036e5 | | |
| | b994a03466 | | |
| | 27161f8e3b | | |
| | 8acf9a488b | | |
@@ -1,7 +1,8 @@
 repos:
-  - repo: local
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.9.7
     hooks:
-      - id: ruff-format-hook
-        name: Check ruff formatting
-        entry: sh scripts/pre-commit.sh
-        language: system
+      - id: ruff
+        language_version: python3
+        args: [ --select, I, ]
+      - id: ruff-format
69
CHANGELOG.md
@@ -9,6 +9,65 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
- Added new base class `BaseObject` which is now the base class of
|
||||||
|
`FrameProcessor`, `PipelineRunner`, `PipelineTask` and `BaseTransport`. The
|
||||||
|
new `BaseObject` adds support for event handlers.
|
||||||
|
|
||||||
|
- Added support for a unified format for specifying function calling across all
|
||||||
|
LLM services.
|
||||||
|
|
||||||
|
```python
|
||||||
|
weather_function = FunctionSchema(
|
||||||
|
name="get_current_weather",
|
||||||
|
description="Get the current weather",
|
||||||
|
properties={
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required=["location"],
|
||||||
|
)
|
||||||
|
tools = ToolsSchema(standard_tools=[weather_function])
|
||||||
|
```
|
||||||
|
|
||||||
|
- Added `speech_threshold` parameter to `GladiaSTTService`.
|
||||||
|
|
||||||
|
- Allow passing user (`user_kwargs`) and assistant (`assistant_kwargs`) context
|
||||||
|
aggregator parameters when using `create_context_aggregator()`. The values are
|
||||||
|
passed as a mapping that will then be converted to arguments.
|
||||||
|
|
||||||
|
- Added `speed` as an `InputParam` for both `ElevenLabsTTSService` and
|
||||||
|
`ElevenLabsHttpTTSService`.
|
||||||
|
|
||||||
|
- Added new `LLMFullResponseAggregator` to aggregate full LLM completions. At
|
||||||
|
every completion the `on_completion` event handler is triggered.
|
||||||
|
|
||||||
|
- Added a new frame, `RTVIServerMessageFrame`, and RTVI message
|
||||||
|
`RTVIServerMessage` which provides a generic mechanism for sending custom
|
||||||
|
messages from server to client. The `RTVIServerMessageFrame` is processed by
|
||||||
|
the `RTVIObserver` and will be delivered to the client's `onServerMessage`
|
||||||
|
callback or `ServerMessage` event.
|
||||||
|
|
||||||
|
- Added `GoogleLLMOpenAIBetaService` for Google LLM integration with an
|
||||||
|
OpenAI-compatible interface. Added foundational example
|
||||||
|
`14o-function-calling-gemini-openai-format.py`.
|
||||||
|
|
||||||
|
- Added `AzureRealtimeBetaLLMService` to support Azure's OpenAI Realtime API. Added
|
||||||
|
foundational example `19a-azure-realtime-beta.py`.
|
||||||
|
|
||||||
|
## [0.0.58] - 2025-02-26
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Added track-specific audio event `on_track_audio_data` to
|
||||||
|
`AudioBufferProcessor` for accessing separate input and output audio tracks.
|
||||||
|
|
||||||
- Pipecat version will now be logged on every application startup. This will
|
- Pipecat version will now be logged on every application startup. This will
|
||||||
help us identify what version we are running in case of any issues.
|
help us identify what version we are running in case of any issues.
|
||||||
|
|
||||||
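The unified function-calling format described in the changelog entries above can be pictured as one provider-neutral definition that is rendered into each provider's layout. The following is a minimal sketch under stated assumptions, not the library's implementation: `FunctionSchemaSketch` and `to_openai_tool` are hypothetical names, and the output layout follows the OpenAI-style tool dictionary that appears in the examples elsewhere in this diff.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FunctionSchemaSketch:
    """Hypothetical stand-in for a provider-neutral function definition."""

    name: str
    description: str
    properties: Dict[str, Any]
    required: List[str] = field(default_factory=list)

    def to_openai_tool(self) -> Dict[str, Any]:
        """Render the schema in the OpenAI chat-completions tool layout."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": self.properties,
                    "required": self.required,
                },
            },
        }


weather = FunctionSchemaSketch(
    name="get_current_weather",
    description="Get the current weather",
    properties={
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
    },
    required=["location"],
)
openai_tool = weather.to_openai_tool()
```

Keeping the definition separate from its rendering means a second method could target another provider's layout without touching the call sites that declare the function.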
@@ -45,6 +104,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
- ⚠️ `PipelineTask` now requires keyword arguments (except for the first one for
|
- ⚠️ `PipelineTask` now requires keyword arguments (except for the first one for
|
||||||
the pipeline).
|
the pipeline).
|
||||||
|
|
||||||
|
- Updated `PlayHTHttpTTSService` to take a `voice_engine` and `protocol` input
|
||||||
|
in the constructor. The previous method of providing a `voice_engine` input
|
||||||
|
that contains the engine and protocol is deprecated by PlayHT.
|
||||||
|
|
||||||
- The base `TTSService` class now strips leading newlines before sending text
|
- The base `TTSService` class now strips leading newlines before sending text
|
||||||
to the TTS provider. This change is to solve issues where some TTS providers,
|
to the TTS provider. This change is to solve issues where some TTS providers,
|
||||||
like Azure, would not output text due to newlines.
|
like Azure, would not output text due to newlines.
|
||||||
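The `TTSService` entry above describes stripping leading newlines before text reaches the TTS provider. A minimal sketch of that idea, assuming only `"\n"` characters at the start of the string need to go (the helper name `strip_leading_newlines` is hypothetical and not taken from the library):

```python
def strip_leading_newlines(text: str) -> str:
    """Drop newline characters at the start of the text, leaving the rest untouched."""
    return text.lstrip("\n")


# Newlines inside the text are preserved; only the leading ones are removed.
cleaned = strip_leading_newlines("\n\nHello there.\nHow are you?")
```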
@@ -78,6 +141,9 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))
|
|||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
- Fixed a `GoogleLLMService` issue that was causing an exception when sending inline
|
||||||
|
audio in some cases.
|
||||||
|
|
||||||
- Fixed an `AudioContextWordTTSService` issue that would cause an `EndFrame` to
|
- Fixed an `AudioContextWordTTSService` issue that would cause an `EndFrame` to
|
||||||
disconnect from the TTS service before audio from all the contexts was
|
disconnect from the TTS service before audio from all the contexts was
|
||||||
received. This affected services like Cartesia and Rime.
|
received. This affected services like Cartesia and Rime.
|
||||||
@@ -124,6 +190,9 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))
|
|||||||
|
|
||||||
- Added Gemini support to `examples/phone-chatbot`.
|
- Added Gemini support to `examples/phone-chatbot`.
|
||||||
|
|
||||||
|
- Added foundational example `34-audio-recording.py` showing how to use the
|
||||||
|
AudioBufferProcessor callbacks to save merged and track recordings.
|
||||||
|
|
||||||
## [0.0.57] - 2025-02-14
|
## [0.0.57] - 2025-02-14
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
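The changelog mentions saving merged and per-track recordings from audio callbacks. As a rough sketch of the saving step only, and assuming the callback hands over raw 16-bit PCM bytes together with a sample rate and channel count (an assumption, not something this diff confirms), the standard-library `wave` module can wrap the samples in a WAV container. The helper name `pcm_to_wav_bytes` is hypothetical.

```python
import io
import wave


def pcm_to_wav_bytes(pcm: bytes, sample_rate: int, num_channels: int) -> bytes:
    """Wrap raw 16-bit PCM samples in a WAV container and return the file bytes."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit audio
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm)
    return buffer.getvalue()


# One second of silence at 16 kHz mono, used here only as placeholder audio.
silence = b"\x00\x00" * 16000
wav_bytes = pcm_to_wav_bytes(silence, sample_rate=16000, num_channels=1)
```

Writing to an in-memory buffer keeps the sketch free of file-system side effects; the same bytes could be written to disk or uploaded.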
@@ -113,8 +113,8 @@ async def main():
                 llm,
                 tts,
                 transport.output(),
-                audio_buffer_processor,  # captures audio into a buffer
                 canonical,  # uploads audio buffer to Canonical AI for metrics
+                audio_buffer_processor,  # captures audio into a buffer
                 context_aggregator.assistant(),
             ]
         )
103
examples/foundational/07d-interruptible-elevenlabs-http.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
from runner import configure
|
||||||
|
|
||||||
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||||
|
from pipecat.pipeline.pipeline import Pipeline
|
||||||
|
from pipecat.pipeline.runner import PipelineRunner
|
||||||
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||||
|
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||||
|
from pipecat.services.elevenlabs import ElevenLabsHttpTTSService
|
||||||
|
from pipecat.services.openai import OpenAILLMService
|
||||||
|
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||||
|
|
||||||
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
logger.remove(0)
|
||||||
|
logger.add(sys.stderr, level="DEBUG")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
(room_url, token) = await configure(session)
|
||||||
|
|
||||||
|
transport = DailyTransport(
|
||||||
|
room_url,
|
||||||
|
token,
|
||||||
|
"Respond bot",
|
||||||
|
DailyParams(
|
||||||
|
audio_out_enabled=True,
|
||||||
|
transcription_enabled=True,
|
||||||
|
vad_enabled=True,
|
||||||
|
vad_analyzer=SileroVADAnalyzer(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
tts = ElevenLabsHttpTTSService(
|
||||||
|
api_key=os.getenv("ELEVENLABS_API_KEY", ""),
|
||||||
|
voice_id=os.getenv("ELEVENLABS_VOICE_ID", ""),
|
||||||
|
aiohttp_session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
context = OpenAILLMContext(messages)
|
||||||
|
context_aggregator = llm.create_context_aggregator(context)
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
[
|
||||||
|
transport.input(), # Transport user input
|
||||||
|
context_aggregator.user(), # User responses
|
||||||
|
llm, # LLM
|
||||||
|
tts, # TTS
|
||||||
|
transport.output(), # Transport bot output
|
||||||
|
context_aggregator.assistant(), # Assistant spoken responses
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = PipelineTask(
|
||||||
|
pipeline,
|
||||||
|
params=PipelineParams(
|
||||||
|
allow_interruptions=True,
|
||||||
|
enable_metrics=True,
|
||||||
|
enable_usage_metrics=True,
|
||||||
|
report_only_initial_ttfb=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@transport.event_handler("on_first_participant_joined")
|
||||||
|
async def on_first_participant_joined(transport, participant):
|
||||||
|
await transport.capture_participant_transcription(participant["id"])
|
||||||
|
# Kick off the conversation.
|
||||||
|
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||||
|
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||||
|
|
||||||
|
@transport.event_handler("on_participant_left")
|
||||||
|
async def on_participant_left(transport, participant, reason):
|
||||||
|
await task.cancel()
|
||||||
|
|
||||||
|
runner = PipelineRunner()
|
||||||
|
|
||||||
|
await runner.run(task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
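The example above attaches coroutines to named events with `@transport.event_handler(...)`. One way such a decorator could work is a per-event list of handlers that are awaited when the event fires. This is a minimal sketch under that assumption; `EventSourceSketch` and `emit` are hypothetical names and do not describe the library's internals.

```python
import asyncio
from collections import defaultdict
from typing import Awaitable, Callable, DefaultDict, List

Handler = Callable[..., Awaitable[None]]


class EventSourceSketch:
    """Hypothetical object that lets callers attach async handlers to named events."""

    def __init__(self) -> None:
        self._handlers: DefaultDict[str, List[Handler]] = defaultdict(list)

    def event_handler(self, event_name: str) -> Callable[[Handler], Handler]:
        """Decorator that registers the wrapped coroutine function for event_name."""

        def decorator(handler: Handler) -> Handler:
            self._handlers[event_name].append(handler)
            return handler

        return decorator

    async def emit(self, event_name: str, *args) -> None:
        """Await every handler registered for event_name, passing self first."""
        for handler in self._handlers[event_name]:
            await handler(self, *args)


async def demo() -> list:
    source = EventSourceSketch()
    seen = []

    @source.event_handler("on_first_participant_joined")
    async def on_first_participant_joined(transport, participant):
        seen.append(participant["id"])

    await source.emit("on_first_participant_joined", {"id": "abc"})
    return seen


joined_ids = asyncio.run(demo())
```

Passing the emitting object as the first argument mirrors the `(transport, participant)` handler signature used in the example.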
103
examples/foundational/07q-interruptible-rime-http.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
from runner import configure
|
||||||
|
|
||||||
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||||
|
from pipecat.pipeline.pipeline import Pipeline
|
||||||
|
from pipecat.pipeline.runner import PipelineRunner
|
||||||
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||||
|
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||||
|
from pipecat.services.openai import OpenAILLMService
|
||||||
|
from pipecat.services.rime import RimeHttpTTSService
|
||||||
|
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||||
|
|
||||||
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
logger.remove(0)
|
||||||
|
logger.add(sys.stderr, level="DEBUG")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
(room_url, token) = await configure(session)
|
||||||
|
|
||||||
|
transport = DailyTransport(
|
||||||
|
room_url,
|
||||||
|
token,
|
||||||
|
"Respond bot",
|
||||||
|
DailyParams(
|
||||||
|
audio_out_enabled=True,
|
||||||
|
transcription_enabled=True,
|
||||||
|
vad_enabled=True,
|
||||||
|
vad_analyzer=SileroVADAnalyzer(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
tts = RimeHttpTTSService(
|
||||||
|
api_key=os.getenv("RIME_API_KEY", ""),
|
||||||
|
voice_id="rex",
|
||||||
|
aiohttp_session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
context = OpenAILLMContext(messages)
|
||||||
|
context_aggregator = llm.create_context_aggregator(context)
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
[
|
||||||
|
transport.input(), # Transport user input
|
||||||
|
context_aggregator.user(), # User responses
|
||||||
|
llm, # LLM
|
||||||
|
tts, # TTS
|
||||||
|
transport.output(), # Transport bot output
|
||||||
|
context_aggregator.assistant(), # Assistant spoken responses
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = PipelineTask(
|
||||||
|
pipeline,
|
||||||
|
params=PipelineParams(
|
||||||
|
allow_interruptions=True,
|
||||||
|
enable_metrics=True,
|
||||||
|
enable_usage_metrics=True,
|
||||||
|
report_only_initial_ttfb=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@transport.event_handler("on_first_participant_joined")
|
||||||
|
async def on_first_participant_joined(transport, participant):
|
||||||
|
await transport.capture_participant_transcription(participant["id"])
|
||||||
|
# Kick off the conversation.
|
||||||
|
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||||
|
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||||
|
|
||||||
|
@transport.event_handler("on_participant_left")
|
||||||
|
async def on_participant_left(transport, participant, reason):
|
||||||
|
await task.cancel()
|
||||||
|
|
||||||
|
runner = PipelineRunner()
|
||||||
|
|
||||||
|
await runner.run(task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
from runner import configure
|
||||||
|
|
||||||
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||||
|
from pipecat.frames.frames import TTSSpeakFrame
|
||||||
|
from pipecat.pipeline.pipeline import Pipeline
|
||||||
|
from pipecat.pipeline.runner import PipelineRunner
|
||||||
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||||
|
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||||
|
from pipecat.services.google import GoogleLLMOpenAIBetaService
|
||||||
|
from pipecat.services.openai import OpenAILLMContext
|
||||||
|
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||||
|
|
||||||
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
logger.remove(0)
|
||||||
|
logger.add(sys.stderr, level="DEBUG")
|
||||||
|
|
||||||
|
|
||||||
|
async def start_fetch_weather(function_name, llm, context):
|
||||||
|
"""Push a frame to the LLM; this is handy when the LLM response might take a while."""
|
||||||
|
await llm.push_frame(TTSSpeakFrame("Let me check on that."))
|
||||||
|
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||||
|
await result_callback({"conditions": "nice", "temperature": "75"})
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
(room_url, token) = await configure(session)
|
||||||
|
|
||||||
|
transport = DailyTransport(
|
||||||
|
room_url,
|
||||||
|
token,
|
||||||
|
"Respond bot",
|
||||||
|
DailyParams(
|
||||||
|
audio_out_enabled=True,
|
||||||
|
transcription_enabled=True,
|
||||||
|
vad_enabled=True,
|
||||||
|
vad_analyzer=SileroVADAnalyzer(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
tts = ElevenLabsTTSService(
|
||||||
|
api_key=os.getenv("ELEVENLABS_API_KEY", ""),
|
||||||
|
voice_id=os.getenv("ELEVENLABS_VOICE_ID", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = GoogleLLMOpenAIBetaService(api_key=os.getenv("GEMINI_API_KEY"))
|
||||||
|
# Register a handler for the "get_current_weather" function. Passing a
|
||||||
|
# function_name of None instead sends all function calls to the same callback.
|
||||||
|
llm.register_function(
|
||||||
|
"get_current_weather", fetch_weather_from_api, start_callback=start_fetch_weather
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
ChatCompletionToolParam(
|
||||||
|
type="function",
|
||||||
|
function={
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "format"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Start a conversation with 'Hey there' to get the current weather.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
context = OpenAILLMContext(messages, tools)
|
||||||
|
context_aggregator = llm.create_context_aggregator(context)
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
[
|
||||||
|
transport.input(),
|
||||||
|
context_aggregator.user(),
|
||||||
|
llm,
|
||||||
|
tts,
|
||||||
|
transport.output(),
|
||||||
|
context_aggregator.assistant(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = PipelineTask(
|
||||||
|
pipeline,
|
||||||
|
params=PipelineParams(
|
||||||
|
allow_interruptions=True,
|
||||||
|
enable_metrics=True,
|
||||||
|
enable_usage_metrics=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@transport.event_handler("on_first_participant_joined")
|
||||||
|
async def on_first_participant_joined(transport, participant):
|
||||||
|
await transport.capture_participant_transcription(participant["id"])
|
||||||
|
# Kick off the conversation.
|
||||||
|
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||||
|
|
||||||
|
runner = PipelineRunner()
|
||||||
|
|
||||||
|
await runner.run(task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
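The example above registers `fetch_weather_from_api` under a function name, and its comment notes that a name of `None` routes every call to one callback. The dispatch logic might look roughly like the sketch below. `FunctionRegistrySketch` and its `call` method are hypothetical, written only to illustrate name lookup with a catch-all fallback and the six-argument handler signature seen in the example.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, Optional

Handler = Callable[..., Awaitable[None]]


class FunctionRegistrySketch:
    """Hypothetical registry mapping function names to async handlers."""

    def __init__(self) -> None:
        self._handlers: Dict[Optional[str], Handler] = {}

    def register_function(self, function_name: Optional[str], handler: Handler) -> None:
        # A key of None acts as a catch-all for names without a dedicated handler.
        self._handlers[function_name] = handler

    async def call(self, function_name: str, tool_call_id: str, args: Dict[str, Any]) -> Any:
        handler = self._handlers.get(function_name) or self._handlers.get(None)
        if handler is None:
            raise KeyError(f"No handler registered for {function_name!r}")

        results = []

        async def result_callback(result: Any) -> None:
            results.append(result)

        # llm and context are passed as None because this sketch has neither.
        await handler(function_name, tool_call_id, args, None, None, result_callback)
        return results[0] if results else None


async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
    await result_callback({"conditions": "nice", "temperature": "75"})


async def demo():
    registry = FunctionRegistrySketch()
    registry.register_function("get_current_weather", fetch_weather_from_api)
    return await registry.call("get_current_weather", "call-1", {"location": "San Francisco, CA"})


weather_result = asyncio.run(demo())
```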
179
examples/foundational/19a-azure-realtime-beta.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import websockets
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
from runner import configure
|
||||||
|
|
||||||
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||||
|
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||||
|
from pipecat.pipeline.pipeline import Pipeline
|
||||||
|
from pipecat.pipeline.runner import PipelineRunner
|
||||||
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||||
|
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||||
|
from pipecat.services.openai_realtime_beta import (
|
||||||
|
AzureRealtimeBetaLLMService,
|
||||||
|
InputAudioTranscription,
|
||||||
|
SessionProperties,
|
||||||
|
TurnDetection,
|
||||||
|
)
|
||||||
|
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||||
|
|
||||||
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
logger.remove(0)
|
||||||
|
logger.add(sys.stderr, level="DEBUG")
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||||
|
temperature = 75 if args["format"] == "fahrenheit" else 24
|
||||||
|
await result_callback(
|
||||||
|
{
|
||||||
|
"conditions": "nice",
|
||||||
|
"temperature": temperature,
|
||||||
|
"format": args["format"],
|
||||||
|
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "format"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
(room_url, token) = await configure(session)
|
||||||
|
|
||||||
|
transport = DailyTransport(
|
||||||
|
room_url,
|
||||||
|
token,
|
||||||
|
"Respond bot",
|
||||||
|
DailyParams(
|
||||||
|
audio_in_enabled=True,
|
||||||
|
audio_out_enabled=True,
|
||||||
|
transcription_enabled=False,
|
||||||
|
vad_enabled=True,
|
||||||
|
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
|
||||||
|
vad_audio_passthrough=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
session_properties = SessionProperties(
|
||||||
|
input_audio_transcription=InputAudioTranscription(),
|
||||||
|
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||||
|
# on by default
|
||||||
|
# turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||||
|
# Or set to False to disable openai turn detection and use transport VAD
|
||||||
|
# turn_detection=False,
|
||||||
|
# tools=tools,
|
||||||
|
instructions="""Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
|
||||||
|
|
||||||
|
Act like a human, but remember that you aren't a human and that you can't do human
|
||||||
|
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||||
|
playful tone.
|
||||||
|
|
||||||
|
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||||
|
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||||
|
even if you're asked about them.
|
||||||
|
-
|
||||||
|
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||||
|
unless specifically asked to elaborate on a topic.
|
||||||
|
|
||||||
|
Remember, your responses should be short. Just one or two sentences, usually.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = AzureRealtimeBetaLLMService(
|
||||||
|
api_key=os.getenv("AZURE_REALTIME_API_KEY"),
|
||||||
|
base_url=os.getenv("AZURE_REALTIME_BASE_URL"),
|
||||||
|
session_properties=session_properties,
|
||||||
|
start_audio_paused=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# you can either register a single function for all function calls, or specific functions
|
||||||
|
# llm.register_function(None, fetch_weather_from_api)
|
||||||
|
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||||
|
|
||||||
|
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||||
|
# OpenAIRealtimeBetaLLMService will convert this internally to messages that the
|
||||||
|
# openai WebSocket API can understand.
|
||||||
|
context = OpenAILLMContext(
|
||||||
|
[{"role": "user", "content": "Say hello!"}],
|
||||||
|
# [{"role": "user", "content": [{"type": "text", "text": "Say hello!"}]}],
|
||||||
|
# [
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": [
|
||||||
|
# {"type": "text", "text": "Say"},
|
||||||
|
# {"type": "text", "text": "yo what's up!"},
|
||||||
|
# ],
|
||||||
|
# }
|
||||||
|
# ],
|
||||||
|
tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
context_aggregator = llm.create_context_aggregator(context)
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
[
|
||||||
|
transport.input(), # Transport user input
|
||||||
|
context_aggregator.user(),
|
||||||
|
llm, # LLM
|
||||||
|
context_aggregator.assistant(),
|
||||||
|
transport.output(), # Transport bot output
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = PipelineTask(
|
||||||
|
pipeline,
|
||||||
|
params=PipelineParams(
|
||||||
|
allow_interruptions=True,
|
||||||
|
enable_metrics=True,
|
||||||
|
enable_usage_metrics=True,
|
||||||
|
# report_only_initial_ttfb=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@transport.event_handler("on_first_participant_joined")
|
||||||
|
async def on_first_participant_joined(transport, participant):
|
||||||
|
await transport.capture_participant_transcription(participant["id"])
|
||||||
|
# Kick off the conversation.
|
||||||
|
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||||
|
|
||||||
|
runner = PipelineRunner()
|
||||||
|
|
||||||
|
await runner.run(task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -23,7 +23,6 @@ from pipecat.frames.frames import (
     FunctionCallInProgressFrame,
     FunctionCallResultFrame,
     InputAudioRawFrame,
-    LLMFullResponseEndFrame,
     LLMFullResponseStartFrame,
     StartFrame,
     StartInterruptionFrame,
@@ -37,7 +36,7 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineParams, PipelineTask
-from pipecat.processors.aggregators.llm_response import LLMResponseAggregator
+from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator
 from pipecat.processors.aggregators.openai_llm_context import (
     OpenAILLMContext,
     OpenAILLMContextFrame,
@@ -389,7 +388,7 @@ class AudioAccumulator(FrameProcessor):
|
|||||||
)
|
)
|
||||||
self._user_speaking = False
|
self._user_speaking = False
|
||||||
context = GoogleLLMContext()
|
context = GoogleLLMContext()
|
||||||
context.add_audio_frames_message(text="Audio follows", audio_frames=self._audio_frames)
|
context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||||
await self.push_frame(OpenAILLMContextFrame(context=context))
|
await self.push_frame(OpenAILLMContextFrame(context=context))
|
||||||
elif isinstance(frame, InputAudioRawFrame):
|
elif isinstance(frame, InputAudioRawFrame):
|
||||||
# Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest
|
# Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest
|
||||||
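The comment in the hunk above says the audio buffer is treated as a ring buffer that drops the oldest data. A minimal sketch of that behaviour using `collections.deque` with a fixed `maxlen`; the class name `AudioRingBufferSketch` is hypothetical, and bounding by frame count rather than by duration is an assumption made for brevity.

```python
from collections import deque


class AudioRingBufferSketch:
    """Keeps only the most recent max_frames audio chunks."""

    def __init__(self, max_frames: int) -> None:
        # A deque with maxlen discards the oldest item automatically on append.
        self._frames = deque(maxlen=max_frames)

    def append(self, frame: bytes) -> None:
        self._frames.append(frame)

    def snapshot(self) -> list:
        """Return the buffered frames, oldest first."""
        return list(self._frames)


ring = AudioRingBufferSketch(max_frames=3)
for chunk in (b"a", b"b", b"c", b"d"):
    ring.append(chunk)
buffered = ring.snapshot()
```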
@@ -432,7 +431,11 @@ class CompletenessCheck(FrameProcessor):
|
|||||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||||
await super().process_frame(frame, direction)
|
await super().process_frame(frame, direction)
|
||||||
|
|
||||||
if isinstance(frame, UserStartedSpeakingFrame):
|
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||||
|
if self._idle_task:
|
||||||
|
await self.cancel_task(self._idle_task)
|
||||||
|
self._idle_task = None
|
||||||
|
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||||
if self._idle_task:
|
if self._idle_task:
|
||||||
await self.cancel_task(self._idle_task)
|
await self.cancel_task(self._idle_task)
|
||||||
elif isinstance(frame, TextFrame) and frame.text.startswith("YES"):
|
elif isinstance(frame, TextFrame) and frame.text.startswith("YES"):
|
||||||
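The hunk above cancels the idle task and clears the reference when an `EndFrame` or `CancelFrame` arrives. The underlying asyncio pattern can be sketched on its own as follows. `IdleWatcherSketch` is a hypothetical class; it uses plain `asyncio` task cancellation rather than the `cancel_task` helper seen in the diff, whose behaviour is not shown here.

```python
import asyncio
from typing import Optional


class IdleWatcherSketch:
    """Hypothetical holder for a background idle timer task."""

    def __init__(self) -> None:
        self._idle_task: Optional[asyncio.Task] = None
        self.timed_out = False

    async def _idle_timer(self, timeout: float) -> None:
        await asyncio.sleep(timeout)
        self.timed_out = True

    def start(self, timeout: float) -> None:
        self._idle_task = asyncio.create_task(self._idle_timer(timeout))

    async def stop(self) -> None:
        """Cancel the timer if it exists, wait for it to finish, and clear the reference."""
        if self._idle_task:
            self._idle_task.cancel()
            try:
                await self._idle_task
            except asyncio.CancelledError:
                pass
            self._idle_task = None


async def demo() -> IdleWatcherSketch:
    watcher = IdleWatcherSketch()
    watcher.start(timeout=60.0)
    await asyncio.sleep(0)  # let the timer task start running
    await watcher.stop()  # what an end-of-pipeline frame would trigger
    return watcher


watcher = asyncio.run(demo())
```

Clearing the reference after cancellation keeps a later check such as `if self._idle_task:` from acting on a task that has already finished.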
@@ -474,19 +477,11 @@ class CompletenessCheck(FrameProcessor):
|
|||||||
self._idle_task = None
|
self._idle_task = None
|
||||||
|
|
||||||
|
|
||||||
class UserAggregatorBuffer(LLMResponseAggregator):
|
class LLMAggregatorBuffer(LLMAssistantResponseAggregator):
|
||||||
"""Buffers the output of the transcription LLM. Used by the bot output gate."""
|
"""Buffers the output of the transcription LLM. Used by the bot output gate."""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(
|
super().__init__(expect_stripped_words=False)
|
||||||
messages=None,
|
|
||||||
role=None,
|
|
||||||
start_frame=LLMFullResponseStartFrame,
|
|
||||||
end_frame=LLMFullResponseEndFrame,
|
|
||||||
accumulator_frame=TextFrame,
|
|
||||||
handle_interruptions=True,
|
|
||||||
expect_stripped_words=False,
|
|
||||||
)
|
|
||||||
self._transcription = ""
|
self._transcription = ""
|
||||||
|
|
||||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||||
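The buffer class in the hunk above collects LLM output between a response-start frame and a response-end frame. The aggregation idea can be sketched without the framework as shown below. All names here (`StartMarker`, `EndMarker`, `TextChunk`, `FullResponseAggregatorSketch`) are hypothetical stand-ins, and the synchronous `on_completion` callback is a simplification of an event handler.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class StartMarker:
    """Stand-in for a frame that opens an LLM response."""


@dataclass
class EndMarker:
    """Stand-in for a frame that closes an LLM response."""


@dataclass
class TextChunk:
    """Stand-in for a frame carrying a piece of LLM text."""

    text: str


class FullResponseAggregatorSketch:
    def __init__(self, on_completion: Callable[[str], None]) -> None:
        self._on_completion = on_completion
        self._parts: Optional[List[str]] = None  # None means no response in progress

    def process(self, frame) -> None:
        if isinstance(frame, StartMarker):
            self._parts = []
        elif isinstance(frame, TextChunk) and self._parts is not None:
            self._parts.append(frame.text)
        elif isinstance(frame, EndMarker) and self._parts is not None:
            # Report the joined text once, then reset for the next response.
            self._on_completion("".join(self._parts))
            self._parts = None


completions: List[str] = []
aggregator = FullResponseAggregatorSketch(on_completion=completions.append)
for frame in (StartMarker(), TextChunk("Hello, "), TextChunk("world."), EndMarker()):
    aggregator.process(frame)
```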
@@ -544,7 +539,7 @@ class OutputGate(FrameProcessor):
         self,
         notifier: BaseNotifier,
         context: OpenAILLMContext,
-        user_transcription_buffer: "UserAggregatorBuffer",
+        llm_transcription_buffer: LLMAggregatorBuffer,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -552,7 +547,7 @@ class OutputGate(FrameProcessor):
|
|||||||
self._frames_buffer = []
|
self._frames_buffer = []
|
||||||
self._notifier = notifier
|
self._notifier = notifier
|
||||||
self._context = context
|
self._context = context
|
||||||
self._transcription_buffer = user_transcription_buffer
|
self._transcription_buffer = llm_transcription_buffer
|
||||||
self._gate_task = None
|
self._gate_task = None
|
||||||
|
|
||||||
def close_gate(self):
|
def close_gate(self):
|
||||||
@@ -699,10 +694,10 @@ async def main():
|
|||||||
|
|
||||||
conversation_audio_context_assembler = ConversationAudioContextAssembler(context=context)
|
conversation_audio_context_assembler = ConversationAudioContextAssembler(context=context)
|
||||||
|
|
||||||
user_aggregator_buffer = UserAggregatorBuffer()
|
llm_aggregator_buffer = LLMAggregatorBuffer()
|
||||||
|
|
||||||
bot_output_gate = OutputGate(
|
bot_output_gate = OutputGate(
|
||||||
notifier=notifier, context=context, user_transcription_buffer=user_aggregator_buffer
|
notifier=notifier, context=context, llm_transcription_buffer=llm_aggregator_buffer
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
@@ -723,7 +718,7 @@ async def main():
|
|||||||
],
|
],
|
||||||
[
|
[
|
||||||
tx_llm,
|
tx_llm,
|
||||||
user_aggregator_buffer,
|
llm_aggregator_buffer,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
|||||||
186
examples/foundational/34-audio-recording.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
"""Audio Recording Example with Pipecat.
|
||||||
|
|
||||||
|
This example demonstrates how to record audio from a conversation between a user and an AI assistant,
|
||||||
|
saving both merged and individual audio tracks. It showcases the AudioBufferProcessor's capabilities
|
||||||
|
to handle both combined and separate audio streams.
|
||||||
|
|
||||||
|
The example:
|
||||||
|
1. Sets up a basic conversation with an AI assistant
|
||||||
|
2. Records the entire conversation
|
||||||
|
3. Saves three separate WAV files:
|
||||||
|
- A merged recording of both participants
|
||||||
|
- Individual recording of user audio
|
||||||
|
- Individual recording of assistant audio
|
||||||
|
|
||||||
|
Example usage (run from pipecat root directory):
|
||||||
|
$ pip install "pipecat-ai[daily,openai,cartesia,silero]"
|
||||||
|
$ pip install -r dev-requirements.txt
|
||||||
|
$ python examples/foundational/34-audio-recording.py
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- OpenAI API key (for GPT-4)
|
||||||
|
- Cartesia API key (for text-to-speech)
|
||||||
|
- Daily API key (for video/audio transport)
|
||||||
|
|
||||||
|
Environment variables (.env file):
|
||||||
|
OPENAI_API_KEY=your_openai_key
|
||||||
|
CARTESIA_API_KEY=your_cartesia_key
|
||||||
|
DAILY_API_KEY=your_daily_key
|
||||||
|
|
||||||
|
The recordings will be saved in a 'recordings' directory with timestamps:
|
||||||
|
recordings/
|
||||||
|
merged_20240315_123456.wav (Combined audio)
|
||||||
|
user_20240315_123456.wav (User audio only)
|
||||||
|
bot_20240315_123456.wav (Bot audio only)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This example requires the AudioBufferProcessor with track-specific audio support,
|
||||||
|
which provides both 'on_audio_data' and 'on_track_audio_data' events for
|
||||||
|
handling merged and separate audio tracks respectively.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import wave
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import aiohttp
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
from runner import configure
|
||||||
|
|
||||||
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||||
|
from pipecat.pipeline.pipeline import Pipeline
|
||||||
|
from pipecat.pipeline.runner import PipelineRunner
|
||||||
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||||
|
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||||
|
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||||
|
from pipecat.services.cartesia import CartesiaTTSService
|
||||||
|
from pipecat.services.openai import OpenAILLMService
|
||||||
|
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||||
|
|
||||||
|
load_dotenv(override=True)
|
||||||
|
logger.remove(0)
|
||||||
|
logger.add(sys.stderr, level="DEBUG")
|
||||||
|
|
||||||
|
|
||||||
|
async def save_audio_file(audio: bytes, filename: str, sample_rate: int, num_channels: int):
|
||||||
|
"""Save audio data to a WAV file."""
|
||||||
|
if len(audio) > 0:
|
||||||
|
with io.BytesIO() as buffer:
|
||||||
|
with wave.open(buffer, "wb") as wf:
|
||||||
|
wf.setsampwidth(2)
|
||||||
|
wf.setnchannels(num_channels)
|
||||||
|
wf.setframerate(sample_rate)
|
||||||
|
wf.writeframes(audio)
|
||||||
|
async with aiofiles.open(filename, "wb") as file:
|
||||||
|
await file.write(buffer.getvalue())
|
||||||
|
logger.info(f"Audio saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
(room_url, token) = await configure(session)
|
||||||
|
|
||||||
|
transport = DailyTransport(
|
||||||
|
room_url,
|
||||||
|
token,
|
||||||
|
"Recording bot",
|
||||||
|
DailyParams(
|
||||||
|
# audio_in_enabled=True,
|
||||||
|
audio_out_enabled=True,
|
||||||
|
transcription_enabled=True,
|
||||||
|
vad_enabled=True,
|
||||||
|
vad_analyzer=SileroVADAnalyzer(),
|
||||||
|
vad_audio_passthrough=True, # Enable audio passthrough for recording
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
tts = CartesiaTTSService(
|
||||||
|
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||||
|
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4")
|
||||||
|
|
||||||
|
# Create audio buffer processor
|
||||||
|
audiobuffer = AudioBufferProcessor()
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant demonstrating audio recording capabilities. Keep your responses brief and clear.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
context = OpenAILLMContext(messages)
|
||||||
|
context_aggregator = llm.create_context_aggregator(context)
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
[
|
||||||
|
transport.input(),
|
||||||
|
context_aggregator.user(),
|
||||||
|
llm,
|
||||||
|
tts,
|
||||||
|
transport.output(),
|
||||||
|
audiobuffer, # Add audio buffer to pipeline
|
||||||
|
context_aggregator.assistant(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True))
|
||||||
|
|
||||||
|
@transport.event_handler("on_first_participant_joined")
|
||||||
|
async def on_first_participant_joined(transport, participant):
|
||||||
|
await transport.capture_participant_transcription(participant["id"])
|
||||||
|
await audiobuffer.start_recording()
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Greet the user and explain that this conversation will be recorded.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||||
|
|
||||||
|
@transport.event_handler("on_participant_left")
|
||||||
|
async def on_participant_left(transport, participant, reason):
|
||||||
|
await audiobuffer.stop_recording()
|
||||||
|
await task.cancel()
|
||||||
|
|
||||||
|
# Handler for merged audio
|
||||||
|
@audiobuffer.event_handler("on_audio_data")
|
||||||
|
async def on_audio_data(buffer, audio, sample_rate, num_channels):
|
||||||
|
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"recordings/merged_{timestamp}.wav"
|
||||||
|
os.makedirs("recordings", exist_ok=True)
|
||||||
|
await save_audio_file(audio, filename, sample_rate, num_channels)
|
||||||
|
|
||||||
|
# Handler for separate tracks
|
||||||
|
@audiobuffer.event_handler("on_track_audio_data")
|
||||||
|
async def on_track_audio_data(buffer, user_audio, bot_audio, sample_rate, num_channels):
|
||||||
|
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
os.makedirs("recordings", exist_ok=True)
|
||||||
|
|
||||||
|
# Save user audio
|
||||||
|
user_filename = f"recordings/user_{timestamp}.wav"
|
||||||
|
await save_audio_file(user_audio, user_filename, sample_rate, 1)
|
||||||
|
|
||||||
|
# Save bot audio
|
||||||
|
bot_filename = f"recordings/bot_{timestamp}.wav"
|
||||||
|
await save_audio_file(bot_audio, bot_filename, sample_rate, 1)
|
||||||
|
|
||||||
|
runner = PipelineRunner()
|
||||||
|
await runner.run(task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
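The `save_audio_file` helper in the example above wraps raw PCM in a WAV container in memory before writing it to disk. A minimal standalone sketch of that step, using only the standard library, might look like the following. The names `pcm_to_wav_bytes` and `wav_info` are hypothetical helpers introduced here for illustration and are not part of Pipecat; the sketch assumes 16-bit samples, as the example does with `setsampwidth(2)`.

```python
import io
import wave


def pcm_to_wav_bytes(audio: bytes, sample_rate: int, num_channels: int) -> bytes:
    """Wrap raw 16-bit PCM in a WAV container, entirely in memory."""
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as wf:
            wf.setsampwidth(2)  # 2 bytes per sample = 16-bit audio (assumed)
            wf.setnchannels(num_channels)
            wf.setframerate(sample_rate)
            wf.writeframes(audio)
        # The wave writer patches the header on close but leaves the
        # BytesIO object open, so the finished file can be read back here.
        return buffer.getvalue()


def wav_info(data: bytes):
    """Return (channels, sample_rate, frame_count) for WAV bytes."""
    with wave.open(io.BytesIO(data), "rb") as wf:
        return wf.getnchannels(), wf.getframerate(), wf.getnframes()


# One second of mono silence at 16 kHz.
silence = b"\x00\x00" * 16000
wav_bytes = pcm_to_wav_bytes(silence, 16000, 1)
```

Reading the header back with `wav_info(wav_bytes)` is a cheap way to check that the sample rate and channel count passed to the event handler ended up in the file.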
@@ -7,20 +7,17 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import google.ai.generativelanguage as glm
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||||
from pipecat.frames.frames import (
|
from pipecat.frames.frames import (
|
||||||
BotStoppedSpeakingFrame,
|
EndFrame,
|
||||||
EndTaskFrame,
|
EndTaskFrame,
|
||||||
Frame,
|
|
||||||
InputAudioRawFrame,
|
InputAudioRawFrame,
|
||||||
SystemFrame,
|
StopTaskFrame,
|
||||||
TranscriptionFrame,
|
TranscriptionFrame,
|
||||||
UserStartedSpeakingFrame,
|
UserStartedSpeakingFrame,
|
||||||
UserStoppedSpeakingFrame,
|
UserStoppedSpeakingFrame,
|
||||||
@@ -28,12 +25,17 @@ from pipecat.frames.frames import (
|
|||||||
from pipecat.pipeline.pipeline import Pipeline
|
from pipecat.pipeline.pipeline import Pipeline
|
||||||
from pipecat.pipeline.runner import PipelineRunner
|
from pipecat.pipeline.runner import PipelineRunner
|
||||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
|
||||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||||
from pipecat.services.ai_services import LLMService
|
from pipecat.services.ai_services import LLMService
|
||||||
|
from pipecat.services.deepgram import DeepgramSTTService
|
||||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||||
from pipecat.services.google import GoogleLLMContext, GoogleLLMService
|
from pipecat.services.google import GoogleLLMService
|
||||||
from pipecat.transports.services.daily import DailyDialinSettings, DailyParams, DailyTransport
|
from pipecat.services.google.google import GoogleLLMContext
|
||||||
|
from pipecat.transports.services.daily import (
|
||||||
|
DailyDialinSettings,
|
||||||
|
DailyParams,
|
||||||
|
DailyTransport,
|
||||||
|
)
|
||||||
|
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
|
|
||||||
@@ -44,6 +46,8 @@ logger.add(sys.stderr, level="DEBUG")
|
|||||||
daily_api_key = os.getenv("DAILY_API_KEY", "")
|
daily_api_key = os.getenv("DAILY_API_KEY", "")
|
||||||
daily_api_url = os.getenv("DAILY_API_URL", "https://api.daily.co/v1")
|
daily_api_url = os.getenv("DAILY_API_URL", "https://api.daily.co/v1")
|
||||||
|
|
||||||
|
system_message = None
|
||||||
|
|
||||||
|
|
||||||
class UserAudioCollector(FrameProcessor):
|
class UserAudioCollector(FrameProcessor):
|
||||||
"""This FrameProcessor collects audio frames in a buffer, then adds them to the
|
"""This FrameProcessor collects audio frames in a buffer, then adds them to the
|
||||||
@@ -117,7 +121,13 @@ class FunctionHandlers:
|
|||||||
self.context_switcher = context_switcher
|
self.context_switcher = context_switcher
|
||||||
|
|
||||||
async def voicemail_response(
|
async def voicemail_response(
|
||||||
self, function_name, tool_call_id, args, llm, context, result_callback
|
self,
|
||||||
|
function_name,
|
||||||
|
tool_call_id,
|
||||||
|
args,
|
||||||
|
llm: LLMService,
|
||||||
|
context,
|
||||||
|
result_callback,
|
||||||
):
|
):
|
||||||
"""Function the bot can call to leave a voicemail message."""
|
"""Function the bot can call to leave a voicemail message."""
|
||||||
message = """You are Chatbot leaving a voicemail message. Say EXACTLY this message and nothing else:
|
message = """You are Chatbot leaving a voicemail message. Say EXACTLY this message and nothing else:
|
||||||
@@ -127,62 +137,48 @@ class FunctionHandlers:
|
|||||||
After saying this message, call the terminate_call function."""
|
After saying this message, call the terminate_call function."""
|
||||||
|
|
||||||
await self.context_switcher.switch_context(system_instruction=message)
|
await self.context_switcher.switch_context(system_instruction=message)
|
||||||
|
|
||||||
await result_callback("Leaving a voicemail message")
|
await result_callback("Leaving a voicemail message")
|
||||||
|
|
||||||
async def human_conversation(
|
async def human_conversation(
|
||||||
self, function_name, tool_call_id, args, llm, context, result_callback
|
self,
|
||||||
|
function_name,
|
||||||
|
tool_call_id,
|
||||||
|
args,
|
||||||
|
llm: LLMService,
|
||||||
|
context,
|
||||||
|
result_callback,
|
||||||
):
|
):
|
||||||
"""Function the bot can call when it detects it's talking to a human."""
|
"""Function the bot can call when it detects it's talking to a human."""
|
||||||
message = """You are Chatbot talking to a human. Be friendly and helpful.
|
await llm.push_frame(StopTaskFrame(), FrameDirection.UPSTREAM)
|
||||||
|
|
||||||
Start with: "Hello! I'm a friendly chatbot. How can I help you today?"
|
|
||||||
|
|
||||||
Keep your responses brief and to the point. Listen to what the person says.
|
|
||||||
|
|
||||||
When the person indicates they're done with the conversation by saying something like:
|
|
||||||
- "Goodbye"
|
|
||||||
- "That's all"
|
|
||||||
- "I'm done"
|
|
||||||
- "Thank you, that's all I needed"
|
|
||||||
|
|
||||||
THEN say: "Thank you for chatting. Goodbye!" and call the terminate_call function."""
|
|
||||||
|
|
||||||
await self.context_switcher.switch_context(system_instruction=message)
|
|
||||||
|
|
||||||
await result_callback("Talking to the customer")
|
|
||||||
|
|
||||||
|
|
||||||
async def terminate_call(
|
async def terminate_call(
|
||||||
function_name, tool_call_id, args, llm: LLMService, context, result_callback
|
function_name,
|
||||||
|
tool_call_id,
|
||||||
|
args,
|
||||||
|
llm: LLMService,
|
||||||
|
context,
|
||||||
|
result_callback,
|
||||||
|
call_state=None,
|
||||||
):
|
):
|
||||||
"""Function the bot can call to terminate the call upon completion of the call."""
|
"""Function the bot can call to terminate the call upon completion of the call."""
|
||||||
|
if call_state:
|
||||||
await llm.queue_frame(EndTaskFrame(), FrameDirection.UPSTREAM)
|
call_state.bot_terminated_call = True
|
||||||
|
await llm.push_frame(EndTaskFrame(), FrameDirection.UPSTREAM)
|
||||||
|
|
||||||
|
|
||||||
async def main(
|
async def main(
|
||||||
room_url: str,
|
room_url: str,
|
||||||
token: str,
|
token: str,
|
||||||
callId: str,
|
callId: Optional[str],
|
||||||
callDomain: str,
|
callDomain: Optional[str],
|
||||||
detect_voicemail: bool,
|
detect_voicemail: bool,
|
||||||
dialout_number: Optional[str],
|
dialout_number: Optional[str],
|
||||||
):
|
):
|
||||||
# dialin_settings are only needed if Daily's SIP URI is used
|
|
||||||
# If you are handling this via Twilio, Telnyx, set this to None
|
|
||||||
# and handle call-forwarding when on_dialin_ready fires.
|
|
||||||
|
|
||||||
# We don't want to specify dial-in settings if we're not dialing in
|
|
||||||
dialin_settings = None
|
dialin_settings = None
|
||||||
if callId and callDomain:
|
if callId and callDomain:
|
||||||
dialin_settings = DailyDialinSettings(call_id=callId, call_domain=callDomain)
|
dialin_settings = DailyDialinSettings(call_id=callId, call_domain=callDomain)
|
||||||
|
transport_params = DailyParams(
|
||||||
transport = DailyTransport(
|
|
||||||
room_url,
|
|
||||||
token,
|
|
||||||
"Chatbot",
|
|
||||||
DailyParams(
|
|
||||||
api_url=daily_api_url,
|
api_url=daily_api_url,
|
||||||
api_key=daily_api_key,
|
api_key=daily_api_key,
|
||||||
dialin_settings=dialin_settings,
|
dialin_settings=dialin_settings,
|
||||||
@@ -192,8 +188,30 @@ async def main(
|
|||||||
vad_enabled=True,
|
vad_enabled=True,
|
||||||
vad_analyzer=SileroVADAnalyzer(),
|
vad_analyzer=SileroVADAnalyzer(),
|
||||||
vad_audio_passthrough=True,
|
vad_audio_passthrough=True,
|
||||||
# transcription_enabled=True,
|
)
|
||||||
),
|
else:
|
||||||
|
transport_params = DailyParams(
|
||||||
|
api_url=daily_api_url,
|
||||||
|
api_key=daily_api_key,
|
||||||
|
audio_in_enabled=True,
|
||||||
|
audio_out_enabled=True,
|
||||||
|
camera_out_enabled=False,
|
||||||
|
vad_enabled=True,
|
||||||
|
vad_analyzer=SileroVADAnalyzer(),
|
||||||
|
vad_audio_passthrough=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class CallState:
|
||||||
|
participant_left_early = False
|
||||||
|
bot_terminated_call = False
|
||||||
|
|
||||||
|
call_state = CallState()
|
||||||
|
|
||||||
|
transport = DailyTransport(
|
||||||
|
room_url,
|
||||||
|
token,
|
||||||
|
"Chatbot",
|
||||||
|
transport_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
tts = ElevenLabsTTSService(
|
tts = ElevenLabsTTSService(
|
||||||
@@ -201,6 +219,10 @@ async def main(
|
|||||||
voice_id=os.getenv("ELEVENLABS_VOICE_ID", ""),
|
voice_id=os.getenv("ELEVENLABS_VOICE_ID", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||||
|
|
||||||
|
### VOICEMAIL PIPELINE
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
"function_declarations": [
|
"function_declarations": [
|
||||||
@@ -222,55 +244,67 @@ async def main(
|
|||||||
|
|
||||||
system_instruction = """You are Chatbot trying to determine if this is a voicemail system or a human.
|
system_instruction = """You are Chatbot trying to determine if this is a voicemail system or a human.
|
||||||
|
|
||||||
If you hear any of these phrases (or very similar ones):
|
If you hear any of these phrases (or very similar ones):
|
||||||
- "Please leave a message after the beep"
|
- "Please leave a message after the beep"
|
||||||
- "No one is available to take your call"
|
- "No one is available to take your call"
|
||||||
- "Record your message after the tone"
|
- "Record your message after the tone"
|
||||||
- "You have reached voicemail for..."
|
- "You have reached voicemail for..."
|
||||||
- "You have reached [phone number]"
|
- "You have reached [phone number]"
|
||||||
- "[phone number] is unavailable"
|
- "[phone number] is unavailable"
|
||||||
- "The person you are trying to reach..."
|
- "The person you are trying to reach..."
|
||||||
- "The number you have dialed..."
|
- "The number you have dialed..."
|
||||||
- "Your call has been forwarded to an automated voice messaging system"
|
- "Your call has been forwarded to an automated voice messaging system"
|
||||||
|
|
||||||
Then call the function switch_to_voicemail_response.
|
Then call the function switch_to_voicemail_response.
|
||||||
|
|
||||||
If it sounds like a human (saying hello, asking questions, etc.), call the function switch_to_human_conversation.
|
If it sounds like a human (saying hello, asking questions, etc.), call the function switch_to_human_conversation.
|
||||||
|
|
||||||
DO NOT say anything until you've determined if this is a voicemail or human."""
|
DO NOT say anything until you've determined if this is a voicemail or human."""
|
||||||
|
|
||||||
llm = GoogleLLMService(
|
voicemail_detection_llm = GoogleLLMService(
|
||||||
model="models/gemini-2.0-flash-lite-preview-02-05",
|
model="models/gemini-2.0-flash-lite",
|
||||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||||
system_instruction=system_instruction,
|
system_instruction=system_instruction,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
context = GoogleLLMContext()
|
voicemail_detection_context = GoogleLLMContext()
|
||||||
context_aggregator = llm.create_context_aggregator(context)
|
voicemail_detection_context_aggregator = voicemail_detection_llm.create_context_aggregator(
|
||||||
audio_collector = UserAudioCollector(context, context_aggregator.user())
|
voicemail_detection_context
|
||||||
|
)
|
||||||
context_switcher = ContextSwitcher(llm, context_aggregator.user())
|
context_switcher = ContextSwitcher(
|
||||||
|
voicemail_detection_llm, voicemail_detection_context_aggregator.user()
|
||||||
|
)
|
||||||
handlers = FunctionHandlers(context_switcher)
|
handlers = FunctionHandlers(context_switcher)
|
||||||
|
|
||||||
llm.register_function("switch_to_voicemail_response", handlers.voicemail_response)
|
voicemail_detection_llm.register_function(
|
||||||
llm.register_function("switch_to_human_conversation", handlers.human_conversation)
|
"switch_to_voicemail_response", handlers.voicemail_response
|
||||||
llm.register_function("terminate_call", terminate_call)
|
)
|
||||||
|
voicemail_detection_llm.register_function(
|
||||||
pipeline = Pipeline(
|
"switch_to_human_conversation", handlers.human_conversation
|
||||||
[
|
)
|
||||||
transport.input(), # Transport user input
|
voicemail_detection_llm.register_function(
|
||||||
audio_collector, # Collect audio frames
|
"terminate_call",
|
||||||
context_aggregator.user(), # User responses
|
lambda *args, **kwargs: terminate_call(*args, **kwargs, call_state=call_state),
|
||||||
llm, # LLM
|
|
||||||
tts, # TTS
|
|
||||||
transport.output(), # Transport bot output
|
|
||||||
context_aggregator.assistant(), # Assistant spoken responses
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
task = PipelineTask(
|
voicemail_detection_audio_collector = UserAudioCollector(
|
||||||
pipeline,
|
voicemail_detection_context, voicemail_detection_context_aggregator.user()
|
||||||
|
)
|
||||||
|
|
||||||
|
voicemail_detection_pipeline = Pipeline(
|
||||||
|
[
|
||||||
|
transport.input(), # Transport user input
|
||||||
|
voicemail_detection_audio_collector, # Collect audio frames
|
||||||
|
voicemail_detection_context_aggregator.user(), # User responses
|
||||||
|
voicemail_detection_llm, # LLM
|
||||||
|
tts, # TTS
|
||||||
|
transport.output(), # Transport bot output
|
||||||
|
voicemail_detection_context_aggregator.assistant(), # Assistant spoken responses
|
||||||
|
]
|
||||||
|
)
|
||||||
|
voicemail_detection_pipeline_task = PipelineTask(
|
||||||
|
voicemail_detection_pipeline,
|
||||||
params=PipelineParams(allow_interruptions=True),
|
params=PipelineParams(allow_interruptions=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -305,25 +339,116 @@ DO NOT say anything until you've determined if this is a voicemail or human."""
|
|||||||
# machine to say something like 'Leave a message after the beep', or for the user to say 'Hello?'.
|
# machine to say something like 'Leave a message after the beep', or for the user to say 'Hello?'.
|
||||||
@transport.event_handler("on_first_participant_joined")
|
@transport.event_handler("on_first_participant_joined")
|
||||||
async def on_first_participant_joined(transport, participant):
|
async def on_first_participant_joined(transport, participant):
|
||||||
|
logger.debug("Detect voicemail; capturing participant transcription")
|
||||||
await transport.capture_participant_transcription(participant["id"])
|
await transport.capture_participant_transcription(participant["id"])
|
||||||
else:
|
else:
|
||||||
logger.debug("no dialout number; assuming dialin")
|
logger.debug("+++++ No dialout number; assuming dialin")
|
||||||
|
|
||||||
# Different handlers for dialin
|
# Different handlers for dialin
|
||||||
@transport.event_handler("on_first_participant_joined")
|
@transport.event_handler("on_first_participant_joined")
|
||||||
async def on_first_participant_joined(transport, participant):
|
async def on_first_participant_joined(transport, participant):
|
||||||
|
# This event is not firing for some reason
|
||||||
await transport.capture_participant_transcription(participant["id"])
|
await transport.capture_participant_transcription(participant["id"])
|
||||||
# For the dialin case, we want the bot to answer the phone and greet the user. We
|
dialin_instructions = """Always call the function switch_to_human_conversation"""
|
||||||
# can prompt the bot to speak by putting the context into the pipeline.
|
messages = [
|
||||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
{
|
||||||
|
"role": "system",
|
||||||
@transport.event_handler("on_participant_left")
|
"content": dialin_instructions,
|
||||||
async def on_participant_left(transport, participant, reason):
|
}
|
||||||
await task.cancel()
|
]
|
||||||
|
voicemail_detection_context_aggregator.user().set_messages(messages)
|
||||||
|
await voicemail_detection_pipeline_task.queue_frames(
|
||||||
|
[voicemail_detection_context_aggregator.user().get_context_frame()]
|
||||||
|
)
|
||||||
|
|
||||||
runner = PipelineRunner()
|
runner = PipelineRunner()
|
||||||
|
|
||||||
await runner.run(task)
|
@transport.event_handler("on_participant_left")
|
||||||
|
async def on_participant_left(transport, participant, reason):
|
||||||
|
call_state.participant_left_early = True
|
||||||
|
await voicemail_detection_pipeline_task.queue_frame(EndFrame())
|
||||||
|
|
||||||
|
print("!!! starting voicemail detection pipeline")
|
||||||
|
await runner.run(voicemail_detection_pipeline_task)
|
||||||
|
print("!!! Done with voicemail detection pipeline")
|
||||||
|
|
||||||
|
if call_state.participant_left_early or call_state.bot_terminated_call:
|
||||||
|
if call_state.participant_left_early:
|
||||||
|
print("!!! Participant left early; terminating call")
|
||||||
|
elif call_state.bot_terminated_call:
|
||||||
|
print("!!! Bot terminated call; not proceeding to human conversation")
|
||||||
|
return
|
||||||
|
|
||||||
|
### HUMAN CONVERSATION PIPELINE
|
||||||
|
|
||||||
|
human_conversation_system_instruction = """You are Chatbot talking to a human. Be friendly and helpful.
|
||||||
|
|
||||||
|
Start with: "Hello! I'm a friendly chatbot. How can I help you today?"
|
||||||
|
|
||||||
|
Keep your responses brief and to the point. Listen to what the person says.
|
||||||
|
|
||||||
|
When the person indicates they're done with the conversation by saying something like:
|
||||||
|
- "Goodbye"
|
||||||
|
- "That's all"
|
||||||
|
- "I'm done"
|
||||||
|
- "Thank you, that's all I needed"
|
||||||
|
|
||||||
|
THEN say: "Thank you for chatting. Goodbye!" and call the terminate_call function."""
|
||||||
|
|
||||||
|
human_conversation_llm = GoogleLLMService(
|
||||||
|
model="models/gemini-2.0-flash-001",
|
||||||
|
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||||
|
system_instruction=human_conversation_system_instruction,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
human_conversation_context = GoogleLLMContext()
|
||||||
|
|
||||||
|
human_conversation_context_aggregator = human_conversation_llm.create_context_aggregator(
|
||||||
|
human_conversation_context
|
||||||
|
)
|
||||||
|
|
||||||
|
human_conversation_llm.register_function(
|
||||||
|
"terminate_call",
|
||||||
|
lambda *args, **kwargs: terminate_call(*args, **kwargs, call_state=call_state),
|
||||||
|
)
|
||||||
|
|
||||||
|
human_conversation_pipeline = Pipeline(
|
||||||
|
[
|
||||||
|
transport.input(), # Transport user input
|
||||||
|
stt,
|
||||||
|
human_conversation_context_aggregator.user(), # User responses
|
||||||
|
human_conversation_llm, # LLM
|
||||||
|
tts, # TTS
|
||||||
|
transport.output(), # Transport bot output
|
||||||
|
human_conversation_context_aggregator.assistant(), # Assistant spoken responses
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
human_conversation_pipeline_task = PipelineTask(
|
||||||
|
human_conversation_pipeline,
|
||||||
|
params=PipelineParams(allow_interruptions=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
@transport.event_handler("on_participant_left")
|
||||||
|
async def on_participant_left(transport, participant, reason):
|
||||||
|
await voicemail_detection_pipeline_task.queue_frame(EndFrame())
|
||||||
|
await human_conversation_pipeline_task.queue_frame(EndFrame())
|
||||||
|
|
||||||
|
print("!!! starting human conversation pipeline")
|
||||||
|
human_conversation_context_aggregator.user().set_messages(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": human_conversation_system_instruction,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await human_conversation_pipeline_task.queue_frames(
|
||||||
|
[human_conversation_context_aggregator.user().get_context_frame()]
|
||||||
|
)
|
||||||
|
await runner.run(human_conversation_pipeline_task)
|
||||||
|
|
||||||
|
print("!!! Done with human conversation pipeline")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
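The change above runs two pipelines back to back and uses a shared `CallState` object, bound into `terminate_call` through a lambda, to decide whether the second pipeline should start. A framework-free sketch of that control flow, assuming nothing about Pipecat itself, could look like this. The `make_handler` and `run_phases` names and the simplified `terminate_call` signature are hypothetical and exist only for illustration.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class CallState:
    participant_left_early: bool = False
    bot_terminated_call: bool = False


async def terminate_call(function_name, args, call_state=None):
    """Simplified stand-in for the handler: record that the bot hung up."""
    if call_state:
        call_state.bot_terminated_call = True
    return f"{function_name} done"


def make_handler(call_state):
    # Bind the shared state as an extra keyword argument, so the caller
    # can keep invoking the handler with its usual arguments.
    return lambda *args, **kwargs: terminate_call(*args, **kwargs, call_state=call_state)


async def run_phases(trigger_terminate: bool):
    state = CallState()
    handler = make_handler(state)
    phases = ["detection"]  # first phase always runs
    if trigger_terminate:
        await handler("terminate_call", {})
    # Skip the second phase if the call ended during the first one.
    if state.participant_left_early or state.bot_terminated_call:
        return phases
    phases.append("conversation")
    return phases
```

`functools.partial(terminate_call, call_state=state)` would bind the argument in the same way; the lambda form mirrors what the diff uses.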
@@ -0,0 +1,134 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import os
import sys

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.ai_services import LLMService
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")

load_dotenv(override=True)


async def start_fetch_weather(function_name, llm, context):
    """Push a frame to the LLM; this is handy when the LLM response might take a while."""
    await llm.push_frame(TTSSpeakFrame("Let me check on that."))
    logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")


async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
    await result_callback({"conditions": "nice", "temperature": "75"})


class WeatherBot:
    """Generic base class for setting up and running an LLM-powered bot."""

    def __init__(self, llm: LLMService):
        """Initialize the base handler with a specific LLM."""
        self.llm = llm

    async def run(self):
        """Set up and start the processing pipeline."""
        async with aiohttp.ClientSession() as session:
            (room_url, token) = await configure(session)

            transport = DailyTransport(
                room_url,
                token,
                "Respond bot",
                DailyParams(
                    audio_out_enabled=True,
                    transcription_enabled=True,
                    vad_enabled=True,
                    vad_analyzer=SileroVADAnalyzer(),
                ),
            )

            tts = CartesiaTTSService(
                api_key=os.getenv("CARTESIA_API_KEY"),
                voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
            )

            # Register a function_name of None to get all functions
            # sent to the same callback with an additional function_name parameter.
            self.llm.register_function(
                None, fetch_weather_from_api, start_callback=start_fetch_weather
            )

            weather_function = FunctionSchema(
                name="get_current_weather",
                description="Get the current weather",
                properties={
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the user's location.",
                    },
                },
                required=["location"],
            )
            tools = ToolsSchema(standard_tools=[weather_function])

            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant who can report the weather in any location in the universe. Respond concisely. Your response will be turned into speech so use only simple words and punctuation.",
                },
                {"role": "user", "content": " Start the conversation by introducing yourself."},
            ]

            context = OpenAILLMContext(messages, tools)
            context_aggregator = self.llm.create_context_aggregator(context)

            pipeline = Pipeline(
                [
                    transport.input(),
                    context_aggregator.user(),
                    self.llm,
                    tts,
                    transport.output(),
                    context_aggregator.assistant(),
                ]
            )

            task = PipelineTask(
                pipeline,
                params=PipelineParams(
                    allow_interruptions=True,
                    enable_metrics=True,
                    enable_usage_metrics=True,
                    report_only_initial_ttfb=True,
                ),
            )

            @transport.event_handler("on_first_participant_joined")
            async def on_first_participant_joined(transport, participant):
                await transport.capture_participant_transcription(participant["id"])
                await task.queue_frames([context_aggregator.user().get_context_frame()])

            runner = PipelineRunner()
            await runner.run(task)
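The example above registers its handler under a `function_name` of `None` so that every function call reaches the same callback. As a rough illustration of that catch-all idea only, here is a minimal sketch using a hypothetical `FunctionRegistry` class; it is not pipecat's implementation, and the handler signature is simplified to two arguments.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, Optional

Handler = Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]


class FunctionRegistry:
    """Hypothetical registry for illustration; not pipecat's implementation."""

    def __init__(self) -> None:
        self._handlers: Dict[Optional[str], Handler] = {}

    def register_function(self, name: Optional[str], handler: Handler) -> None:
        # A name of None marks the handler as the catch-all.
        self._handlers[name] = handler

    async def call(self, name: str, args: Dict[str, Any]) -> Dict[str, Any]:
        # Prefer an exact match, then fall back to the catch-all under None.
        handler = self._handlers.get(name) or self._handlers.get(None)
        if handler is None:
            raise KeyError(f"No handler registered for {name!r}")
        return await handler(name, args)


async def fetch_weather(function_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
    # The catch-all receives the function name so it can tell calls apart.
    return {"function": function_name, "conditions": "nice", "temperature": "75"}


registry = FunctionRegistry()
registry.register_function(None, fetch_weather)
result = asyncio.run(registry.call("get_current_weather", {"location": "San Francisco, CA"}))
print(result)
```

Passing the function name into the handler is what makes a single callback workable: the handler can branch on it when more than one tool is declared.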
@@ -0,0 +1,126 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import sys
from typing import List

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.ai_services import LLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")

load_dotenv(override=True)


async def start_fetch_weather(function_name, llm, context):
    """Push a frame to the LLM; this is handy when the LLM response might take a while."""
    await llm.push_frame(TTSSpeakFrame("Let me check on that."))
    logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")


async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
    await result_callback({"conditions": "nice", "temperature": "75"})


class MultimodalWeatherBot:
    """Generic base class for setting up and running an LLM-powered bot."""

    def __init__(self, llm: LLMService):
        """Initialize the base handler with a specific LLM."""
        self.llm = llm

    @staticmethod
    def tools() -> ToolsSchema:
        weather_function = FunctionSchema(
            name="get_current_weather",
            description="Get the current weather",
            properties={
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the user's location.",
                },
            },
            required=["location"],
        )
        return ToolsSchema(standard_tools=[weather_function])

    async def run(self):
        """Set up and start the processing pipeline."""
        async with aiohttp.ClientSession() as session:
            (room_url, token) = await configure(session)

            transport = DailyTransport(
                room_url,
                token,
                "Respond bot",
                DailyParams(
                    audio_out_enabled=True,
                    vad_enabled=True,
                    vad_analyzer=SileroVADAnalyzer(),
                    vad_audio_passthrough=True,
                ),
            )

            # Register a function_name of None to get all functions
            # sent to the same callback with an additional function_name parameter.
            self.llm.register_function(
                None, fetch_weather_from_api, start_callback=start_fetch_weather
            )

            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant who can report the weather in any location in the universe. Respond concisely. Your response will be turned into speech so use only simple words and punctuation.",
                },
                {"role": "user", "content": " Start the conversation by introducing yourself."},
            ]

            context = OpenAILLMContext(messages, MultimodalWeatherBot.tools())
            context_aggregator = self.llm.create_context_aggregator(context)

            pipeline = Pipeline(
                [
                    transport.input(),
                    context_aggregator.user(),
                    self.llm,
                    transport.output(),
                    context_aggregator.assistant(),
                ]
            )

            task = PipelineTask(
                pipeline,
                params=PipelineParams(
                    allow_interruptions=True,
                ),
            )

            @transport.event_handler("on_first_participant_joined")
            async def on_first_participant_joined(transport, participant):
                await transport.capture_participant_transcription(participant["id"])
                await task.queue_frames([context_aggregator.user().get_context_frame()])

            runner = PipelineRunner()
            await runner.run(task)
64
examples/unified-format-function-calling/runner.py
Normal file
@@ -0,0 +1,64 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import argparse
import os
from typing import Optional

import aiohttp

from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper


async def configure(aiohttp_session: aiohttp.ClientSession):
    (url, token, _) = await configure_with_args(aiohttp_session)
    return (url, token)


async def configure_with_args(
    aiohttp_session: aiohttp.ClientSession, parser: Optional[argparse.ArgumentParser] = None
):
    if not parser:
        parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")
    parser.add_argument(
        "-u", "--url", type=str, required=False, help="URL of the Daily room to join"
    )
    parser.add_argument(
        "-k",
        "--apikey",
        type=str,
        required=False,
        help="Daily API Key (needed to create an owner token for the room)",
    )

    args, unknown = parser.parse_known_args()

    url = args.url or os.getenv("DAILY_SAMPLE_ROOM_URL")
    key = args.apikey or os.getenv("DAILY_API_KEY")

    if not url:
        raise Exception(
            "No Daily room specified. use the -u/--url option from the command line, or set DAILY_SAMPLE_ROOM_URL in your environment to specify a Daily room URL."
        )

    if not key:
        raise Exception(
            "No Daily API key specified. use the -k/--apikey option from the command line, or set DAILY_API_KEY in your environment to specify a Daily API key, available from https://dashboard.daily.co/developers."
        )

    daily_rest_helper = DailyRESTHelper(
        daily_api_key=key,
        daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
        aiohttp_session=aiohttp_session,
    )

    # Create a meeting token for the given room with an expiration 1 hour in
    # the future.
    expiry_time: float = 60 * 60

    token = await daily_rest_helper.get_token(url, expiry_time)

    return (url, token, args)
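The flag-then-environment fallback in `configure_with_args` can be exercised on its own, without a Daily account. A minimal sketch, assuming a hypothetical `resolve_daily_settings` helper (not part of this change) that takes the argument list and the environment mapping explicitly so it can be tested in isolation:

```python
import argparse
from typing import Mapping, Optional, Sequence, Tuple


def resolve_daily_settings(argv: Sequence[str], env: Mapping[str, str]) -> Tuple[str, str]:
    """Hypothetical helper: return (room_url, api_key), preferring CLI flags over env vars."""
    parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")
    parser.add_argument("-u", "--url", type=str, required=False)
    parser.add_argument("-k", "--apikey", type=str, required=False)
    # parse_known_args collects unrecognised flags instead of exiting with an error.
    args, _unknown = parser.parse_known_args(list(argv))

    url: Optional[str] = args.url or env.get("DAILY_SAMPLE_ROOM_URL")
    key: Optional[str] = args.apikey or env.get("DAILY_API_KEY")
    if not url:
        raise ValueError("No Daily room specified (use -u/--url or DAILY_SAMPLE_ROOM_URL).")
    if not key:
        raise ValueError("No Daily API key specified (use -k/--apikey or DAILY_API_KEY).")
    return url, key


# The URL comes from the command line, the key from the (fake) environment.
settings = resolve_daily_settings(
    ["-u", "https://example.daily.co/room", "--extra-flag"],
    {"DAILY_API_KEY": "env-key"},
)
print(settings)
```

Using `parse_known_args` rather than `parse_args` matters here because each bot script may define or receive flags that this helper does not know about.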
@@ -0,0 +1,29 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.anthropic import AnthropicLLMService

load_dotenv(override=True)


class AnthropicWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = AnthropicLLMService(
            api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-5-sonnet-20240620"
        )
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(AnthropicWeatherBot().run())
@@ -0,0 +1,31 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.azure import AzureLLMService

load_dotenv(override=True)


class AzureWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = AzureLLMService(
            api_key=os.getenv("AZURE_CHATGPT_API_KEY"),
            endpoint=os.getenv("AZURE_CHATGPT_ENDPOINT"),
            model=os.getenv("AZURE_CHATGPT_MODEL"),
        )
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(AzureWeatherBot().run())
@@ -0,0 +1,27 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.cerebras import CerebrasLLMService

load_dotenv(override=True)


class CerebrasWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = CerebrasLLMService(api_key=os.getenv("CEREBRAS_API_KEY"), model="llama-3.3-70b")
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(CerebrasWeatherBot().run())
@@ -0,0 +1,27 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.deepseek import DeepSeekLLMService

load_dotenv(override=True)


class DeepSeekWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = DeepSeekLLMService(api_key=os.getenv("DEEPSEEK_API_KEY"), model="deepseek-chat")
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(DeepSeekWeatherBot().run())
@@ -0,0 +1,29 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.fireworks import FireworksLLMService

load_dotenv(override=True)


class FireworksWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = FireworksLLMService(
            api_key=os.getenv("FIREWORKS_API_KEY"),
        )
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(FireworksWeatherBot().run())
@@ -0,0 +1,38 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from dotenv import load_dotenv
from multimodal_base_function_calling import MultimodalWeatherBot

from pipecat.adapters.schemas.tools_schema import AdapterType
from pipecat.services.gemini_multimodal_live import GeminiMultimodalLiveLLMService

load_dotenv(override=True)


class GeminiMultimodalWeatherBot(MultimodalWeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        search_tool = {"google_search": {}}
        tools_def = MultimodalWeatherBot.tools()
        tools_def.custom_tools = {AdapterType.GEMINI: [search_tool]}

        llm = GeminiMultimodalLiveLLMService(
            api_key=os.getenv("GOOGLE_API_KEY"),
            voice_id="Puck",
            transcribe_user_audio=True,
            transcribe_model_audio=True,
            tools=tools_def,
        )
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(GeminiMultimodalWeatherBot().run())
@@ -0,0 +1,27 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.google import GoogleLLMService

load_dotenv(override=True)


class GeminiWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"), model="gemini-2.0-flash-001")
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(GeminiWeatherBot().run())
@@ -0,0 +1,27 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.grok import GrokLLMService

load_dotenv(override=True)


class GrokWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = GrokLLMService(api_key=os.getenv("GROK_API_KEY"))
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(GrokWeatherBot().run())
@@ -0,0 +1,27 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.groq import GroqLLMService

load_dotenv(override=True)


class GroqWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = GroqLLMService(api_key=os.getenv("GROQ_API_KEY"), model="llama-3.3-70b-versatile")
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(GroqWeatherBot().run())
@@ -0,0 +1,29 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.nim import NimLLMService

load_dotenv(override=True)


class NimWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = NimLLMService(
            api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.3-70b-instruct"
        )
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(NimWeatherBot().run())
@@ -0,0 +1,43 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from dotenv import load_dotenv
from multimodal_base_function_calling import MultimodalWeatherBot

from pipecat.services.openai_realtime_beta import (
    InputAudioTranscription,
    OpenAIRealtimeBetaLLMService,
    SessionProperties,
    TurnDetection,
)

load_dotenv(override=True)


class OpenAiRealTimeWeatherBot(MultimodalWeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        session_properties = SessionProperties(
            input_audio_transcription=InputAudioTranscription(),
            # Set openai TurnDetection parameters. Not setting this at all will turn it
            # on by default
            turn_detection=TurnDetection(silence_duration_ms=1000),
        )

        llm = OpenAIRealtimeBetaLLMService(
            api_key=os.getenv("OPENAI_API_KEY"),
            session_properties=session_properties,
            start_audio_paused=False,
        )
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(OpenAiRealTimeWeatherBot().run())
@@ -0,0 +1,27 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.openai import OpenAILLMService

load_dotenv(override=True)


class OpenAiWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(OpenAiWeatherBot().run())
@@ -0,0 +1,29 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.openrouter import OpenRouterLLMService

load_dotenv(override=True)


class OpenRouterWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = OpenRouterLLMService(
            api_key=os.getenv("OPENROUTER_API_KEY"), model="openai/gpt-4o-2024-11-20"
        )
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(OpenRouterWeatherBot().run())
@@ -0,0 +1,30 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os

from base_function_calling import WeatherBot
from dotenv import load_dotenv

from pipecat.services.together import TogetherLLMService

load_dotenv(override=True)


class TogetherWeatherBot(WeatherBot):
    """Main class defining the LLM and passing it to the base handler."""

    def __init__(self):
        llm = TogetherLLMService(
            api_key=os.getenv("TOGETHER_API_KEY"),
            model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        )
        super().__init__(llm)


if __name__ == "__main__":
    asyncio.run(TogetherWeatherBot().run())
@@ -20,17 +20,14 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Artificial Intelligence"
 ]
 dependencies = [
-    "aiohttp~=3.11.11",
+    "aiohttp~=3.11.13",
     "audioop-lts~=0.2.1; python_version>='3.13'",
-    # We need an older version of `httpx` that doesn't remove the deprecated
-    # `proxies` argument. This is necessary for Azure and Anthropic clients.
-    "httpx~=0.27.2",
     "loguru~=0.7.3",
     "Markdown~=3.7",
     "numpy~=1.26.4",
     "Pillow~=11.1.0",
     "protobuf~=5.29.3",
-    "pydantic~=2.10.5",
+    "pydantic~=2.10.6",
     "pyloudnorm~=0.1.1",
     "resampy~=0.4.3",
     "soxr~=0.5.0",
@@ -42,7 +39,7 @@ Source = "https://github.com/pipecat-ai/pipecat"
 Website = "https://pipecat.ai"
 
 [project.optional-dependencies]
-anthropic = [ "anthropic~=0.45.2" ]
+anthropic = [ "anthropic~=0.47.2" ]
 assemblyai = [ "assemblyai~=0.36.0" ]
 aws = [ "boto3~=1.35.99" ]
 azure = [ "azure-cognitiveservices-speech~=1.42.0"]
@@ -50,13 +47,13 @@ canonical = [ "aiofiles~=24.1.0" ]
 cartesia = [ "cartesia~=1.3.1", "websockets~=13.1" ]
 cerebras = []
 deepseek = []
-daily = [ "daily-python~=0.14.2" ]
+daily = [ "daily-python~=0.15.0" ]
 deepgram = [ "deepgram-sdk~=3.8.0" ]
 elevenlabs = [ "websockets~=13.1" ]
 fal = [ "fal-client~=0.5.6" ]
 fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
 gladia = [ "websockets~=13.1" ]
-google = [ "google-cloud-speech~=2.31.0", "google-cloud-texttospeech~=2.25.0", "google-genai~=1.2.0", "google-generativeai~=0.8.4" ]
+google = [ "google-cloud-speech~=2.31.0", "google-cloud-texttospeech~=2.25.0", "google-genai~=1.3.0", "google-generativeai~=0.8.4" ]
 grok = []
 groq = []
 gstreamer = [ "pygobject~=3.50.0" ]
@@ -73,7 +70,7 @@ noisereduce = [ "noisereduce~=3.0.3" ]
 openai = [ "websockets~=13.1" ]
 openpipe = [ "openpipe~=4.45.0" ]
 perplexity = []
-playht = [ "pyht~=0.1.6", "websockets~=13.1" ]
+playht = [ "pyht~=0.1.12", "websockets~=13.1" ]
 rime = [ "websockets~=13.1" ]
 riva = [ "nvidia-riva-client~=2.18.0" ]
 sentry = [ "sentry-sdk~=2.20.0" ]
4
scripts/fix-ruff.sh
Executable file
@@ -0,0 +1,4 @@
ruff format src
ruff format examples
ruff format tests
ruff check --select I --fix
0
src/pipecat/adapters/__init__.py
Normal file
22
src/pipecat/adapters/base_llm_adapter.py
Normal file
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Any, List, Union, cast

from loguru import logger

from pipecat.adapters.schemas.tools_schema import ToolsSchema


class BaseLLMAdapter(ABC):
    @abstractmethod
    def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
        """Converts tools to the provider's format."""
        pass

    def from_standard_tools(self, tools: Any) -> List[Any]:
        if isinstance(tools, ToolsSchema):
            logger.debug(f"Retrieving the tools using the adapter: {type(self)}")
            return self.to_provider_tools_format(tools)
        # Fallback to return the same tools in case they are not in a standard format
        return tools

    # TODO: we can move the logic to also handle the Messages here
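`BaseLLMAdapter.from_standard_tools` converts a `ToolsSchema` through the subclass hook and passes anything else through untouched, so existing provider-native tool lists keep working. A minimal sketch of that dispatch, assuming simplified stand-in classes (the real ones live in pipecat) and a hypothetical `EnvelopeAdapter` subclass invented for illustration:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class ToolsSchema:
    """Simplified stand-in: standard tools are plain dicts here for brevity."""

    def __init__(self, standard_tools: List[Dict[str, Any]]) -> None:
        self.standard_tools = standard_tools


class BaseLLMAdapter(ABC):
    """Stand-in reproducing the dispatch logic of the class in this diff."""

    @abstractmethod
    def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
        """Convert tools to the provider's format."""

    def from_standard_tools(self, tools: Any) -> List[Any]:
        if isinstance(tools, ToolsSchema):
            return self.to_provider_tools_format(tools)
        # Anything else is assumed to be provider-native already.
        return tools


class EnvelopeAdapter(BaseLLMAdapter):
    """Hypothetical adapter wrapping each function in a {"type": "function"} envelope."""

    def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]:
        return [{"type": "function", "function": tool} for tool in tools_schema.standard_tools]


adapter = EnvelopeAdapter()

# A standard schema is converted by the subclass hook.
schema = ToolsSchema([{"name": "get_current_weather", "description": "Get the current weather"}])
converted = adapter.from_standard_tools(schema)

# A list that is already in provider format is returned unchanged.
native = [{"type": "function", "function": {"name": "already_native"}}]
passthrough = adapter.from_standard_tools(native)
print(converted, passthrough is native)
```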
0
src/pipecat/adapters/schemas/__init__.py
Normal file
55
src/pipecat/adapters/schemas/function_schema.py
Normal file
@@ -0,0 +1,55 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

from typing import Any, Dict, List


class FunctionSchema:
    def __init__(
        self, name: str, description: str, properties: Dict[str, Any], required: List[str]
    ) -> None:
        """Standardized function schema representation.

        :param name: Name of the function.
        :param description: Description of the function.
        :param properties: Dictionary defining properties types and descriptions.
        :param required: List of required parameters.
        """
        self._name = name
        self._description = description
        self._properties = properties
        self._required = required

    def to_default_dict(self) -> Dict[str, Any]:
        """Converts the function schema to a dictionary.

        :return: Dictionary representation of the function schema.
        """
        return {
            "name": self._name,
            "description": self._description,
            "parameters": {
                "type": "object",
                "properties": self._properties,
                "required": self._required,
            },
        }

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def properties(self) -> Dict[str, Any]:
        return self._properties

    @property
    def required(self) -> List[str]:
        return self._required
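Because `FunctionSchema` exposes `name`, `description`, `properties` and `required` separately, a converter can rebuild the same definition in a different layout instead of relying on `to_default_dict`. A minimal sketch, assuming a stand-in copy of the class and a hypothetical `to_input_schema_style` helper that nests the JSON schema under an `input_schema` key; it is not the adapter code from this change:

```python
from typing import Any, Dict, List


class FunctionSchema:
    """Simplified stand-in mirroring the class in this diff (not imported from pipecat)."""

    def __init__(
        self, name: str, description: str, properties: Dict[str, Any], required: List[str]
    ) -> None:
        self.name = name
        self.description = description
        self.properties = properties
        self.required = required

    def to_default_dict(self) -> Dict[str, Any]:
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": self.properties,
                "required": self.required,
            },
        }


def to_input_schema_style(function: FunctionSchema) -> Dict[str, Any]:
    """Hypothetical conversion: same fields, schema nested under `input_schema`."""
    return {
        "name": function.name,
        "description": function.description,
        "input_schema": {
            "type": "object",
            "properties": function.properties,
            "required": function.required,
        },
    }


weather_function = FunctionSchema(
    name="get_current_weather",
    description="Get the current weather",
    properties={"location": {"type": "string", "description": "The city and state"}},
    required=["location"],
)
default_layout = weather_function.to_default_dict()
alternate_layout = to_input_schema_style(weather_function)
print(default_layout["parameters"] == alternate_layout["input_schema"])
```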
43
src/pipecat/adapters/schemas/tools_schema.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterType(Enum):
|
||||||
|
GEMINI = "gemini"  # Gemini is the only service that accepts custom tools for now
|
||||||
|
|
||||||
|
|
||||||
|
class ToolsSchema:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
standard_tools: List[FunctionSchema],
|
||||||
|
custom_tools: Dict[AdapterType, List[Dict[str, Any]]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
A schema for tools that includes both standardized function schemas
|
||||||
|
and custom tools that do not follow the FunctionSchema format.
|
||||||
|
|
||||||
|
:param standard_tools: List of tools following FunctionSchema.
|
||||||
|
:param custom_tools: Dictionary mapping an AdapterType to a list of tools in a custom format (e.g., search_tool).
|
||||||
|
"""
|
||||||
|
self._standard_tools = standard_tools
|
||||||
|
self._custom_tools = custom_tools
|
||||||
|
|
||||||
|
@property
|
||||||
|
def standard_tools(self) -> List[FunctionSchema]:
|
||||||
|
return self._standard_tools
|
||||||
|
|
||||||
|
@property
|
||||||
|
def custom_tools(self) -> Dict[AdapterType, List[Dict[str, Any]]]:
|
||||||
|
return self._custom_tools
|
||||||
|
|
||||||
|
@custom_tools.setter
|
||||||
|
def custom_tools(self, value: Dict[AdapterType, List[Dict[str, Any]]]) -> None:
|
||||||
|
self._custom_tools = value
|
||||||
0
src/pipecat/adapters/services/__init__.py
Normal file
34
src/pipecat/adapters/services/anthropic_adapter.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||||
|
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||||
|
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicLLMAdapter(BaseLLMAdapter):
|
||||||
|
@staticmethod
|
||||||
|
def _to_anthropic_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"name": function.name,
|
||||||
|
"description": function.description,
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": function.properties,
|
||||||
|
"required": function.required,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
|
||||||
|
"""Converts function schemas to Anthropic's function-calling format.
|
||||||
|
|
||||||
|
:return: Anthropic formatted function call definition.
|
||||||
|
"""
|
||||||
|
|
||||||
|
functions_schema = tools_schema.standard_tools
|
||||||
|
return [self._to_anthropic_function_format(func) for func in functions_schema]
|
||||||
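Sketch of the reshaping done by `AnthropicLLMAdapter` above (not part of the diff). The adapter reads from a `FunctionSchema`; this sketch works on a plain default-format dictionary instead so that it is self-contained. The helper name `default_to_anthropic` and the sample function are hypothetical.

```python
from typing import Any, Dict


def default_to_anthropic(function: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: same output layout as _to_anthropic_function_format."""
    params = function["parameters"]
    return {
        "name": function["name"],
        "description": function["description"],
        # Anthropic's layout uses "input_schema" where the default layout uses "parameters".
        "input_schema": {
            "type": "object",
            "properties": params["properties"],
            "required": params["required"],
        },
    }


# Illustrative input in the default layout produced by to_default_dict().
default_format = {
    "name": "get_current_weather",
    "description": "Get the current weather for a location",
    "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
}

anthropic_format = default_to_anthropic(default_format)
```

Apart from the key rename, the nested schema is passed through unchanged.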
28
src/pipecat/adapters/services/gemini_adapter.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||||
|
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiLLMAdapter(BaseLLMAdapter):
|
||||||
|
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
|
||||||
|
"""Converts function schemas to Gemini's function-calling format.
|
||||||
|
|
||||||
|
:return: Gemini formatted function call definition.
|
||||||
|
"""
|
||||||
|
|
||||||
|
functions_schema = tools_schema.standard_tools
|
||||||
|
formatted_standard_tools = [
|
||||||
|
{"function_declarations": [func.to_default_dict() for func in functions_schema]}
|
||||||
|
]
|
||||||
|
custom_gemini_tools = []
|
||||||
|
if tools_schema.custom_tools:
|
||||||
|
custom_gemini_tools = tools_schema.custom_tools.get(AdapterType.GEMINI, [])
|
||||||
|
|
||||||
|
return formatted_standard_tools + custom_gemini_tools
|
||||||
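Sketch of the list that `GeminiLLMAdapter.to_provider_tools_format` builds (not part of the diff). It works on plain dictionaries rather than `ToolsSchema`, and the helper name `to_gemini_tools` and the custom tool payload are hypothetical placeholders.

```python
from typing import Any, Dict, List, Optional


def to_gemini_tools(
    standard: List[Dict[str, Any]],
    custom: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, Any]]:
    """Hypothetical helper mirroring the adapter's return value."""
    # All standard functions share a single "function_declarations" entry.
    formatted_standard_tools = [{"function_declarations": standard}]
    # Custom tools, if any, are appended as sibling entries of that wrapper.
    return formatted_standard_tools + (custom or [])


standard_functions = [
    {"name": "get_current_weather", "description": "Weather lookup", "parameters": {}},
    {"name": "get_time", "description": "Current time", "parameters": {}},
]
# Placeholder custom tool; the actual payload depends on the Gemini feature used.
custom_tools = [{"search_tool": {}}]

with_custom = to_gemini_tools(standard_functions, custom_tools)
without_standard = to_gemini_tools([])
```

Note that, as in the adapter, an empty list of standard tools still yields one wrapper entry with an empty `function_declarations` list.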
24
src/pipecat/adapters/services/open_ai_adapter.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
|
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||||
|
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAILLMAdapter(BaseLLMAdapter):
|
||||||
|
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[ChatCompletionToolParam]:
|
||||||
|
"""Converts function schemas to OpenAI's function-calling format.
|
||||||
|
|
||||||
|
:return: OpenAI formatted function call definition.
|
||||||
|
"""
|
||||||
|
functions_schema = tools_schema.standard_tools
|
||||||
|
return [
|
||||||
|
ChatCompletionToolParam(type="function", function=func.to_default_dict())
|
||||||
|
for func in functions_schema
|
||||||
|
]
|
||||||
34
src/pipecat/adapters/services/open_ai_realtime_adapter.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||||
|
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||||
|
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
|
||||||
|
@staticmethod
|
||||||
|
def _to_openai_realtime_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"name": function.name,
|
||||||
|
"description": function.description,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": function.properties,
|
||||||
|
"required": function.required,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
|
||||||
|
"""Converts function schemas to OpenAI Realtime function-calling format.
|
||||||
|
|
||||||
|
:return: OpenAI Realtime formatted function call definition.
|
||||||
|
"""
|
||||||
|
|
||||||
|
functions_schema = tools_schema.standard_tools
|
||||||
|
return [self._to_openai_realtime_function_format(func) for func in functions_schema]
|
||||||
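Sketch comparing the two OpenAI layouts produced by the adapters above (not part of the diff). `OpenAILLMAdapter` nests the default dictionary under a `function` key, while `OpenAIRealtimeLLMAdapter` emits the same fields flat next to `type`. The helper names and the sample function are hypothetical, and plain dictionaries stand in for `ChatCompletionToolParam`.

```python
from typing import Any, Dict


def to_chat_completions_tool(function: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: nested layout, as built by OpenAILLMAdapter."""
    return {"type": "function", "function": function}


def to_realtime_tool(function: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: flat layout, as built by OpenAIRealtimeLLMAdapter."""
    return {"type": "function", **function}


# Illustrative function in the default layout.
default_format = {
    "name": "get_current_weather",
    "description": "Get the current weather for a location",
    "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
}

chat_tool = to_chat_completions_tool(default_format)
realtime_tool = to_realtime_tool(default_format)
```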
0
src/pipecat/audio/turn/__init__.py
Normal file
32
src/pipecat/audio/turn/base_turn_analyzer.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class EndOfTurnState(Enum):
|
||||||
|
COMPLETE = 1
|
||||||
|
INCOMPLETE = 2
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEndOfTurnAnalyzer(ABC):
|
||||||
|
def __init__(self, *, sample_rate: Optional[int] = None):
|
||||||
|
self._init_sample_rate = sample_rate
|
||||||
|
self._sample_rate = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sample_rate(self) -> int:
|
||||||
|
return self._sample_rate
|
||||||
|
|
||||||
|
def set_sample_rate(self, sample_rate: int):
|
||||||
|
self._sample_rate = self._init_sample_rate or sample_rate
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def analyze_audio(self, buffer: bytes) -> EndOfTurnState:
|
||||||
|
pass
|
||||||
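Sketch of a subclass of the new `BaseEndOfTurnAnalyzer` (not part of the diff). The base class and enum are stand-ins copied from the hunk above. `MinDurationAnalyzer` is a hypothetical analyzer that only counts buffered bytes. It is here to show the `analyze_audio` contract and that a sample rate passed to the constructor takes precedence in `set_sample_rate`.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional


class EndOfTurnState(Enum):
    COMPLETE = 1
    INCOMPLETE = 2


class BaseEndOfTurnAnalyzer(ABC):
    """Stand-in mirroring the base class in the diff."""

    def __init__(self, *, sample_rate: Optional[int] = None):
        self._init_sample_rate = sample_rate
        self._sample_rate = 0

    @property
    def sample_rate(self) -> int:
        return self._sample_rate

    def set_sample_rate(self, sample_rate: int):
        # A rate given at construction time wins over the one passed here.
        self._sample_rate = self._init_sample_rate or sample_rate

    @abstractmethod
    def analyze_audio(self, buffer: bytes) -> EndOfTurnState:
        pass


class MinDurationAnalyzer(BaseEndOfTurnAnalyzer):
    """Hypothetical analyzer: COMPLETE once `min_secs` of 16-bit mono audio is buffered."""

    def __init__(self, *, min_secs: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self._min_secs = min_secs
        self._buffered = 0

    def analyze_audio(self, buffer: bytes) -> EndOfTurnState:
        self._buffered += len(buffer)
        needed = int(self.sample_rate * 2 * self._min_secs)  # 2 bytes per sample
        if self._buffered < needed:
            return EndOfTurnState.INCOMPLETE
        self._buffered = 0
        return EndOfTurnState.COMPLETE


analyzer = MinDurationAnalyzer(sample_rate=8000)
analyzer.set_sample_rate(16000)  # ignored: the constructor value (8000) is kept
first = analyzer.analyze_audio(b"\x00" * 8000)   # 0.5 s at 8 kHz
second = analyzer.analyze_audio(b"\x00" * 8000)  # 1.0 s in total
```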
83
src/pipecat/audio/turn/smart_turn.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
from transformers import AutoFeatureExtractor, Wav2Vec2BertForSequenceClassification
|
||||||
|
|
||||||
|
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer, EndOfTurnState
|
||||||
|
|
||||||
|
# MODEL_PATH = "model-v1"
|
||||||
|
MODEL_PATH = "pipecat-ai/smart-turn"
|
||||||
|
|
||||||
|
|
||||||
|
class SmartTurnAnalyzer(BaseEndOfTurnAnalyzer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._audio_buffer = bytearray()
|
||||||
|
|
||||||
|
logger.debug("Loading Smart Turn model...")
|
||||||
|
|
||||||
|
# Load model and processor
|
||||||
|
model = Wav2Vec2BertForSequenceClassification.from_pretrained(MODEL_PATH)
|
||||||
|
self._processor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
|
||||||
|
|
||||||
|
# Set model to evaluation mode and move to GPU if available
|
||||||
|
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self._model = model.to(self._device)
|
||||||
|
self._model.eval()
|
||||||
|
|
||||||
|
logger.debug("Loaded Smart Turn")
|
||||||
|
|
||||||
|
def analyze_audio(self, buffer: bytes) -> EndOfTurnState:
|
||||||
|
self._audio_buffer += buffer
|
||||||
|
if len(self._audio_buffer) < 16000 * 2 * 6:
|
||||||
|
return EndOfTurnState.INCOMPLETE
|
||||||
|
|
||||||
|
audio_int16 = np.frombuffer(self._audio_buffer, dtype=np.int16)
|
||||||
|
|
||||||
|
# Divide by 32768 because we have signed 16-bit data.
|
||||||
|
audio_float32 = audio_int16.astype(np.float32) / 32768.0
|
||||||
|
|
||||||
|
|
||||||
|
# Process audio
|
||||||
|
inputs = self._processor(
|
||||||
|
audio_float32,
|
||||||
|
sampling_rate=16000,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=800, # Maximum length as specified in training
|
||||||
|
return_attention_mask=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move inputs to device
|
||||||
|
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self._model(**inputs)
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# Get probabilities using softmax
|
||||||
|
probabilities = torch.nn.functional.softmax(logits, dim=1)
|
||||||
|
completion_prob = probabilities[0, 1].item() # Probability of class 1 (Complete)
|
||||||
|
|
||||||
|
# Make prediction (1 for Complete, 0 for Incomplete)
|
||||||
|
prediction = 1 if completion_prob > 0.5 else 0
|
||||||
|
|
||||||
|
state = EndOfTurnState.COMPLETE if prediction == 1 else EndOfTurnState.INCOMPLETE
|
||||||
|
|
||||||
|
if state == EndOfTurnState.COMPLETE:
|
||||||
|
self._audio_buffer = bytearray()
|
||||||
|
else:
|
||||||
|
self._audio_buffer = self._audio_buffer[len(buffer) :]
|
||||||
|
|
||||||
|
logger.debug(f"Smart Turn end-of-turn state: {state}")
|
||||||
|
|
||||||
|
return state
|
||||||
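Sketch of the arithmetic inside `SmartTurnAnalyzer.analyze_audio` (not part of the diff). It uses only the standard library in place of NumPy and PyTorch, so it illustrates the numbers and does not replace the model code. The constant names, the helper names and the sample logits are hypothetical. The 16000 Hz rate, the 2 bytes per sample, the factor 6 and the 0.5 threshold come from the hunk above.

```python
import math
from array import array
from typing import List

SAMPLE_RATE = 16000    # Hz, hard-coded in analyze_audio
BYTES_PER_SAMPLE = 2   # signed 16-bit PCM
WINDOW_SECS = 6        # the factor 6 in the buffer-length check

# Bytes that must be buffered before the model is run at all.
min_bytes = SAMPLE_RATE * BYTES_PER_SAMPLE * WINDOW_SECS


def int16_bytes_to_floats(data: bytes) -> List[float]:
    """Convert native-endian int16 PCM bytes to floats in [-1.0, 1.0)."""
    samples = array("h")
    samples.frombytes(data)
    return [sample / 32768.0 for sample in samples]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a flat list of logits."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


pcm = array("h", [-32768, 0, 16384, 32767]).tobytes()
floats = int16_bytes_to_floats(pcm)

# Made-up logits for [Incomplete, Complete]; class 1 is "Complete" per the diff.
probabilities = softmax([0.2, 1.3])
is_complete = probabilities[1] > 0.5
```

With two classes, the 0.5 threshold on the class-1 probability is equivalent to checking that the class-1 logit is the larger one.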
@@ -4,7 +4,7 @@
|
|||||||
# SPDX-License-Identifier: BSD 2-Clause License
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
#
|
#
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ class VADParams(BaseModel):
|
|||||||
min_volume: float = VAD_MIN_VOLUME
|
min_volume: float = VAD_MIN_VOLUME
|
||||||
|
|
||||||
|
|
||||||
class VADAnalyzer:
|
class VADAnalyzer(ABC):
|
||||||
def __init__(self, *, sample_rate: Optional[int] = None, params: VADParams):
|
def __init__(self, *, sample_rate: Optional[int] = None, params: VADParams):
|
||||||
self._init_sample_rate = sample_rate
|
self._init_sample_rate = sample_rate
|
||||||
self._sample_rate = 0
|
self._sample_rate = 0
|
||||||
|
|||||||
@@ -568,7 +568,8 @@ class UserStoppedSpeakingFrame(SystemFrame):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class EmulateUserStartedSpeakingFrame(SystemFrame):
|
class EmulateUserStartedSpeakingFrame(SystemFrame):
|
||||||
"""Emitted by internal processors upstream to emulate VAD behavior when a
|
"""Emitted by internal processors upstream to emulate VAD behavior when a
|
||||||
user starts speaking."""
|
user starts speaking.
|
||||||
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -576,7 +577,20 @@ class EmulateUserStartedSpeakingFrame(SystemFrame):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class EmulateUserStoppedSpeakingFrame(SystemFrame):
|
class EmulateUserStoppedSpeakingFrame(SystemFrame):
|
||||||
"""Emitted by internal processors upstream to emulate VAD behavior when a
|
"""Emitted by internal processors upstream to emulate VAD behavior when a
|
||||||
user stops speaking."""
|
user stops speaking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserEndOfTurnFrame(SystemFrame):
|
||||||
|
"""Emitted to indicate that the user's turn has ended.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -5,25 +5,14 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import AsyncIterable, Iterable
|
from typing import AsyncIterable, Iterable
|
||||||
|
|
||||||
from pipecat.frames.frames import Frame
|
from pipecat.frames.frames import Frame
|
||||||
|
from pipecat.utils.base_object import BaseObject
|
||||||
|
|
||||||
|
|
||||||
class BaseTask(ABC):
|
class BaseTask(BaseObject):
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def id(self) -> int:
|
|
||||||
"""Returns the unique identifier for this task."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def name(self) -> str:
|
|
||||||
"""Returns the name of this task."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
|
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
|
||||||
"""Sets the event loop that this task will run on."""
|
"""Sets the event loop that this task will run on."""
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ from typing import Optional
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from pipecat.pipeline.task import PipelineTask
|
from pipecat.pipeline.task import PipelineTask
|
||||||
from pipecat.utils.utils import obj_count, obj_id
|
from pipecat.utils.base_object import BaseObject
|
||||||
|
|
||||||
|
|
||||||
class PipelineRunner:
|
class PipelineRunner(BaseObject):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -24,8 +24,7 @@ class PipelineRunner:
|
|||||||
force_gc: bool = False,
|
force_gc: bool = False,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
):
|
):
|
||||||
self.id: int = obj_id()
|
super().__init__(name=name)
|
||||||
self.name: str = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
|
||||||
|
|
||||||
self._tasks = {}
|
self._tasks = {}
|
||||||
self._sig_task = None
|
self._sig_task = None
|
||||||
@@ -74,6 +73,3 @@ class PipelineRunner:
|
|||||||
collected = gc.collect()
|
collected = gc.collect()
|
||||||
logger.debug(f"Garbage collector: collected {collected} objects.")
|
logger.debug(f"Garbage collector: collected {collected} objects.")
|
||||||
logger.debug(f"Garbage collector: uncollectable objects {gc.garbage}")
|
logger.debug(f"Garbage collector: uncollectable objects {gc.garbage}")
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.name
|
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from pipecat.pipeline.base_task import BaseTask
|
|||||||
from pipecat.pipeline.task_observer import TaskObserver
|
from pipecat.pipeline.task_observer import TaskObserver
|
||||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||||
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
|
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
|
||||||
from pipecat.utils.utils import obj_count, obj_id
|
|
||||||
|
|
||||||
HEARTBEAT_SECONDS = 1.0
|
HEARTBEAT_SECONDS = 1.0
|
||||||
HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 5
|
HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 5
|
||||||
@@ -138,9 +137,7 @@ class PipelineTask(BaseTask):
|
|||||||
task_manager: Optional[BaseTaskManager] = None,
|
task_manager: Optional[BaseTaskManager] = None,
|
||||||
check_dangling_tasks: bool = True,
|
check_dangling_tasks: bool = True,
|
||||||
):
|
):
|
||||||
self._id: int = obj_id()
|
super().__init__()
|
||||||
self._name: str = f"{self.__class__.__name__}#{obj_count(self)}"
|
|
||||||
|
|
||||||
self._pipeline = pipeline
|
self._pipeline = pipeline
|
||||||
self._clock = clock
|
self._clock = clock
|
||||||
self._params = params
|
self._params = params
|
||||||
@@ -180,16 +177,6 @@ class PipelineTask(BaseTask):
|
|||||||
|
|
||||||
self._observer = TaskObserver(observers=observers, task_manager=self._task_manager)
|
self._observer = TaskObserver(observers=observers, task_manager=self._task_manager)
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self) -> int:
|
|
||||||
"""Returns the unique identifier for this task."""
|
|
||||||
return self._id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
"""Returns the name of this task."""
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def params(self) -> PipelineParams:
|
def params(self) -> PipelineParams:
|
||||||
"""Returns the pipeline parameters of this task."""
|
"""Returns the pipeline parameters of this task."""
|
||||||
@@ -434,6 +421,3 @@ class PipelineTask(BaseTask):
|
|||||||
tasks = [t.get_name() for t in self._task_manager.current_tasks()]
|
tasks = [t.get_name() for t in self._task_manager.current_tasks()]
|
||||||
if tasks:
|
if tasks:
|
||||||
logger.warning(f"Dangling tasks detected: {tasks}")
|
logger.warning(f"Dangling tasks detected: {tasks}")
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.name
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from pipecat.frames.frames import Frame
|
|||||||
from pipecat.observers.base_observer import BaseObserver
|
from pipecat.observers.base_observer import BaseObserver
|
||||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||||
from pipecat.utils.asyncio import BaseTaskManager
|
from pipecat.utils.asyncio import BaseTaskManager
|
||||||
from pipecat.utils.utils import obj_count, obj_id
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -56,20 +55,10 @@ class TaskObserver(BaseObserver):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *, observers: List[BaseObserver] = [], task_manager: BaseTaskManager):
|
def __init__(self, *, observers: List[BaseObserver] = [], task_manager: BaseTaskManager):
|
||||||
self._id: int = obj_id()
|
|
||||||
self._name: str = f"{self.__class__.__name__}#{obj_count(self)}"
|
|
||||||
self._observers = observers
|
self._observers = observers
|
||||||
self._task_manager = task_manager
|
self._task_manager = task_manager
|
||||||
self._proxies: List[Proxy] = []
|
self._proxies: List[Proxy] = []
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self) -> int:
|
|
||||||
return self._id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""Starts all proxy observer tasks."""
|
"""Starts all proxy observer tasks."""
|
||||||
self._proxies = self._create_proxies(self._observers)
|
self._proxies = self._create_proxies(self._observers)
|
||||||
@@ -100,7 +89,7 @@ class TaskObserver(BaseObserver):
|
|||||||
queue = asyncio.Queue()
|
queue = asyncio.Queue()
|
||||||
task = self._task_manager.create_task(
|
task = self._task_manager.create_task(
|
||||||
self._proxy_task_handler(queue, observer),
|
self._proxy_task_handler(queue, observer),
|
||||||
f"{self}::{observer.__class__.__name__}::_proxy_task_handler",
|
f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler",
|
||||||
)
|
)
|
||||||
proxy = Proxy(queue=queue, task=task, observer=observer)
|
proxy = Proxy(queue=queue, task=task, observer=observer)
|
||||||
proxies.append(proxy)
|
proxies.append(proxy)
|
||||||
@@ -112,6 +101,3 @@ class TaskObserver(BaseObserver):
|
|||||||
await observer.on_push_frame(
|
await observer.on_push_frame(
|
||||||
data.src, data.dst, data.frame, data.direction, data.timestamp
|
data.src, data.dst, data.frame, data.direction, data.timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.name
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
|
|||||||
LLMMessagesFrame,
|
LLMMessagesFrame,
|
||||||
LLMMessagesUpdateFrame,
|
LLMMessagesUpdateFrame,
|
||||||
LLMSetToolsFrame,
|
LLMSetToolsFrame,
|
||||||
|
LLMTextFrame,
|
||||||
StartFrame,
|
StartFrame,
|
||||||
StartInterruptionFrame,
|
StartInterruptionFrame,
|
||||||
TextFrame,
|
TextFrame,
|
||||||
@@ -36,6 +37,59 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
|||||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFullResponseAggregator(FrameProcessor):
|
||||||
|
"""This is an LLM aggregator that aggregates a full LLM completion. It
|
||||||
|
aggregates LLM text frames (tokens) received between
|
||||||
|
`LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`. Every full
|
||||||
|
completion is returned via the "on_completion" event handler:
|
||||||
|
|
||||||
|
@aggregator.event_handler("on_completion")
|
||||||
|
async def on_completion(
|
||||||
|
aggregator: LLMFullResponseAggregator,
|
||||||
|
completion: str,
|
||||||
|
completed: bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self._aggregation = ""
|
||||||
|
self._started = False
|
||||||
|
|
||||||
|
self._register_event_handler("on_completion")
|
||||||
|
|
||||||
|
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||||
|
await super().process_frame(frame, direction)
|
||||||
|
|
||||||
|
if isinstance(frame, StartInterruptionFrame):
|
||||||
|
await self._call_event_handler("on_completion", self._aggregation, False)
|
||||||
|
self._aggregation = ""
|
||||||
|
self._started = False
|
||||||
|
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||||
|
await self._handle_llm_start(frame)
|
||||||
|
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||||
|
await self._handle_llm_end(frame)
|
||||||
|
elif isinstance(frame, LLMTextFrame):
|
||||||
|
await self._handle_llm_text(frame)
|
||||||
|
|
||||||
|
await self.push_frame(frame, direction)
|
||||||
|
|
||||||
|
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||||
|
self._started = True
|
||||||
|
|
||||||
|
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
|
||||||
|
await self._call_event_handler("on_completion", self._aggregation, True)
|
||||||
|
self._started = False
|
||||||
|
self._aggregation = ""
|
||||||
|
|
||||||
|
async def _handle_llm_text(self, frame: TextFrame):
|
||||||
|
if not self._started:
|
||||||
|
return
|
||||||
|
self._aggregation += frame.text
|
||||||
|
|
||||||
|
|
||||||
class BaseLLMResponseAggregator(FrameProcessor):
|
class BaseLLMResponseAggregator(FrameProcessor):
|
||||||
"""This is the base class for all LLM response aggregators. These
|
"""This is the base class for all LLM response aggregators. These
|
||||||
aggregators process incoming frames and aggregate content until they are
|
aggregators process incoming frames and aggregate content until they are
|
||||||
|
|||||||
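Sketch of the aggregation logic of `LLMFullResponseAggregator` from the hunk above (not part of the diff). The frame classes and `CompletionCollector` are stand-ins that only reproduce the state handling in `process_frame`. The real class is a `FrameProcessor`, passes itself as the first handler argument and pushes every frame onward, and all of that is omitted here.

```python
import asyncio
from dataclasses import dataclass
from typing import List, Tuple


class LLMFullResponseStartFrame:
    """Stand-in frame marking the start of a completion."""


class LLMFullResponseEndFrame:
    """Stand-in frame marking the end of a completion."""


class StartInterruptionFrame:
    """Stand-in frame signalling an interruption."""


@dataclass
class LLMTextFrame:
    text: str


class CompletionCollector:
    """Hypothetical stand-in for the aggregator's state machine."""

    def __init__(self, on_completion):
        self._aggregation = ""
        self._started = False
        self._on_completion = on_completion

    async def process_frame(self, frame) -> None:
        if isinstance(frame, StartInterruptionFrame):
            # Report what was collected so far, flagged as not completed.
            await self._on_completion(self._aggregation, False)
            self._aggregation = ""
            self._started = False
        elif isinstance(frame, LLMFullResponseStartFrame):
            self._started = True
        elif isinstance(frame, LLMFullResponseEndFrame):
            await self._on_completion(self._aggregation, True)
            self._started = False
            self._aggregation = ""
        elif isinstance(frame, LLMTextFrame) and self._started:
            self._aggregation += frame.text


results: List[Tuple[str, bool]] = []


async def on_completion(completion: str, completed: bool) -> None:
    results.append((completion, completed))


async def main() -> None:
    collector = CompletionCollector(on_completion)
    frames = [
        LLMFullResponseStartFrame(),
        LLMTextFrame("Hello"),
        LLMTextFrame(" world"),
        LLMFullResponseEndFrame(),
        LLMFullResponseStartFrame(),
        LLMTextFrame("Interrupted"),
        StartInterruptionFrame(),
    ]
    for frame in frames:
        await collector.process_frame(frame)


asyncio.run(main())
```

As in the diff, an interruption triggers the handler even if no text has been aggregated yet, so handlers may receive an empty string with `completed=False`.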
@@ -20,6 +20,8 @@ from openai.types.chat import (
|
|||||||
)
|
)
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||||
|
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||||
from pipecat.frames.frames import (
|
from pipecat.frames.frames import (
|
||||||
AudioRawFrame,
|
AudioRawFrame,
|
||||||
Frame,
|
Frame,
|
||||||
@@ -44,13 +46,20 @@ class OpenAILLMContext:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
messages: Optional[List[ChatCompletionMessageParam]] = None,
|
messages: Optional[List[ChatCompletionMessageParam]] = None,
|
||||||
tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = NOT_GIVEN,
|
||||||
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
|
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
|
||||||
):
|
):
|
||||||
self._messages: List[ChatCompletionMessageParam] = messages if messages else []
|
self._messages: List[ChatCompletionMessageParam] = messages if messages else []
|
||||||
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
|
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
|
||||||
self._tools: List[ChatCompletionToolParam] | NotGiven = tools
|
self._tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = tools
|
||||||
self._user_image_request_context = {}
|
self._user_image_request_context = {}
|
||||||
|
self._llm_adapter: Optional[BaseLLMAdapter] = None
|
||||||
|
|
||||||
|
def get_llm_adapter(self) -> Optional[BaseLLMAdapter]:
|
||||||
|
return self._llm_adapter
|
||||||
|
|
||||||
|
def set_llm_adapter(self, llm_adapter: BaseLLMAdapter):
|
||||||
|
self._llm_adapter = llm_adapter
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
|
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
|
||||||
@@ -67,7 +76,9 @@ class OpenAILLMContext:
|
|||||||
return self._messages
|
return self._messages
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tools(self) -> List[ChatCompletionToolParam] | NotGiven:
|
def tools(self) -> List[ChatCompletionToolParam] | NotGiven | List[Any]:
|
||||||
|
if self._llm_adapter:
|
||||||
|
return self._llm_adapter.from_standard_tools(self._tools)
|
||||||
return self._tools
|
return self._tools
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -152,7 +163,7 @@ class OpenAILLMContext:
|
|||||||
def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven):
|
def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven):
|
||||||
self._tool_choice = tool_choice
|
self._tool_choice = tool_choice
|
||||||
|
|
||||||
def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN):
|
def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = NOT_GIVEN):
|
||||||
if tools != NOT_GIVEN and len(tools) == 0:
|
if tools != NOT_GIVEN and len(tools) == 0:
|
||||||
tools = NOT_GIVEN
|
tools = NOT_GIVEN
|
||||||
self._tools = tools
|
self._tools = tools
|
||||||
|
|||||||
@@ -21,20 +21,32 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
|||||||
|
|
||||||
|
|
||||||
class AudioBufferProcessor(FrameProcessor):
|
class AudioBufferProcessor(FrameProcessor):
|
||||||
"""This processor buffers audio raw frames (input and output). The mixed
|
"""Processes and buffers audio frames from both input (user) and output (bot) sources.
|
||||||
audio can be obtained by registering an "on_audio_data" event handler.
|
|
||||||
The event handler will be called every time `buffer_size` is reached.
|
|
||||||
|
|
||||||
You can provide the desired output `sample_rate` and incoming audio frames
|
This processor manages audio buffering and synchronization, providing both merged and
|
||||||
will be resampled to match it. Also, you can provide the number of channels, 1
|
track-specific audio access through event handlers. It supports various audio configurations
|
||||||
for mono and 2 for stereo. With mono audio user and bot audio will be mixed,
|
including sample rate conversion and mono/stereo output.
|
||||||
in the case of stereo the left channel will be used for the user's audio and
|
|
||||||
the right channel for the bot.
|
|
||||||
|
|
||||||
Most of the time, user audio will be a continuous stream but it's possible
|
Events:
|
||||||
that in some cases only the spoken audio is sent. To accommodate those
|
on_audio_data: Triggered when buffer_size is reached, providing merged audio
|
||||||
cases make sure to set `user_continuous_stream` accordingly.
|
on_track_audio_data: Triggered when buffer_size is reached, providing separate tracks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_rate (Optional[int]): Desired output sample rate. If None, uses source rate
|
||||||
|
num_channels (int): Number of channels (1 for mono, 2 for stereo). Defaults to 1
|
||||||
|
buffer_size (int): Size of buffer before triggering events. 0 for no buffering
|
||||||
|
user_continuous_stream (bool): Whether user audio is continuous or speech-only
|
||||||
|
|
||||||
|
Audio handling:
|
||||||
|
- Mono output (num_channels=1): User and bot audio are mixed
|
||||||
|
- Stereo output (num_channels=2): User audio on left, bot audio on right
|
||||||
|
- Automatic resampling of incoming audio to match desired sample_rate
|
||||||
|
- Silence insertion for non-continuous audio streams
|
||||||
|
- Buffer synchronization between user and bot audio
|
||||||
|
|
||||||
|
Note:
|
||||||
|
When user_continuous_stream is False, the processor expects only speech
|
||||||
|
segments and will handle silence insertion between segments automatically.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
@@ -65,21 +77,45 @@ class AudioBufferProcessor(FrameProcessor):
         self._resampler = create_default_resampler()

         self._register_event_handler("on_audio_data")
+        self._register_event_handler("on_track_audio_data")

     @property
     def sample_rate(self) -> int:
+        """Current sample rate of the audio processor.
+
+        Returns:
+            int: The sample rate in Hz
+        """
         return self._sample_rate

     @property
     def num_channels(self) -> int:
+        """Number of channels in the audio output.
+
+        Returns:
+            int: Number of channels (1 for mono, 2 for stereo)
+        """
         return self._num_channels

     def has_audio(self) -> bool:
+        """Check if both user and bot audio buffers contain data.
+
+        Returns:
+            bool: True if both buffers contain audio data
+        """
         return self._buffer_has_audio(self._user_audio_buffer) and self._buffer_has_audio(
             self._bot_audio_buffer
         )

     def merge_audio_buffers(self) -> bytes:
+        """Merge user and bot audio buffers into a single audio stream.
+
+        For mono output, audio is mixed. For stereo output, user audio is placed
+        on the left channel and bot audio on the right channel.
+
+        Returns:
+            bytes: Mixed audio data
+        """
         if self._num_channels == 1:
             return mix_audio(bytes(self._user_audio_buffer), bytes(self._bot_audio_buffer))
         elif self._num_channels == 2:
@@ -90,14 +126,23 @@ class AudioBufferProcessor(FrameProcessor):
             return b""

     async def start_recording(self):
+        """Start recording audio from both user and bot.
+
+        Initializes recording state and resets audio buffers.
+        """
         self._recording = True
         self._reset_recording()

     async def stop_recording(self):
+        """Stop recording and trigger final audio data handlers.
+
+        Calls audio handlers with any remaining buffered audio before stopping.
+        """
         await self._call_on_audio_data_handler()
         self._recording = False

     async def process_frame(self, frame: Frame, direction: FrameDirection):
+        """Process incoming audio frames and manage audio buffers."""
         await super().process_frame(frame, direction)

         # Update output sample rate if necessary.
@@ -160,10 +205,21 @@ class AudioBufferProcessor(FrameProcessor):
         if not self.has_audio() or not self._recording:
             return

+        # Call original handler with merged audio
         merged_audio = self.merge_audio_buffers()
         await self._call_event_handler(
             "on_audio_data", merged_audio, self._sample_rate, self._num_channels
         )
+
+        # Call new handler with separate tracks
+        await self._call_event_handler(
+            "on_track_audio_data",
+            bytes(self._user_audio_buffer),
+            bytes(self._bot_audio_buffer),
+            self._sample_rate,
+            self._num_channels,
+        )
+
         self._reset_audio_buffers()

     def _buffer_has_audio(self, buffer: bytearray) -> bool:
||||||
|
|||||||
@@ -5,7 +5,6 @@
 #

 import asyncio
-import inspect
 from enum import Enum
 from typing import Awaitable, Callable, Coroutine, Optional

@@ -24,7 +23,7 @@ from pipecat.frames.frames import (
 from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
 from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics
 from pipecat.utils.asyncio import BaseTaskManager
-from pipecat.utils.utils import obj_count, obj_id
+from pipecat.utils.base_object import BaseObject


 class FrameDirection(Enum):
@@ -32,7 +31,7 @@ class FrameDirection(Enum):
     UPSTREAM = 2


-class FrameProcessor:
+class FrameProcessor(BaseObject):
     def __init__(
         self,
         *,
@@ -40,14 +39,11 @@ class FrameProcessor:
         metrics: Optional[FrameProcessorMetrics] = None,
         **kwargs,
     ):
-        self._id: int = obj_id()
-        self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
+        super().__init__(name=name)
         self._parent: Optional["FrameProcessor"] = None
         self._prev: Optional["FrameProcessor"] = None
         self._next: Optional["FrameProcessor"] = None

-        self._event_handlers: dict = {}
-
         # Clock
         self._clock: Optional[BaseClock] = None

@@ -254,23 +250,6 @@ class FrameProcessor:
         else:
             await self.__push_queue.put((frame, direction))

-    def event_handler(self, event_name: str):
-        def decorator(handler):
-            self.add_event_handler(event_name, handler)
-            return handler
-
-        return decorator
-
-    def add_event_handler(self, event_name: str, handler):
-        if event_name not in self._event_handlers:
-            raise Exception(f"Event handler {event_name} not registered")
-        self._event_handlers[event_name].append(handler)
-
-    def _register_event_handler(self, event_name: str):
-        if event_name in self._event_handlers:
-            raise Exception(f"Event handler {event_name} already registered")
-        self._event_handlers[event_name] = []
-
     async def __start(self, frame: StartFrame):
         self.__create_input_task()
         self.__create_push_task()
@@ -385,16 +364,3 @@ class FrameProcessor:
             (frame, direction) = await self.__push_queue.get()
             await self.__internal_push_frame(frame, direction)
             self.__push_queue.task_done()
-
-    async def _call_event_handler(self, event_name: str, *args, **kwargs):
-        try:
-            for handler in self._event_handlers[event_name]:
-                if inspect.iscoroutinefunction(handler):
-                    await handler(self, *args, **kwargs)
-                else:
-                    handler(self, *args, **kwargs)
-        except Exception as e:
-            logger.exception(f"Exception in event handler {event_name}: {e}")
-
-    def __str__(self):
-        return self.name
||||||
|
|||||||
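The methods removed from `FrameProcessor` in the hunks above form a small event-handler registry. Since the class now derives from `BaseObject`, they presumably live there, though that file is not shown in this part of the diff. A minimal sketch of the same pattern follows; `EventSource` and `demo` are hypothetical names, and the original's exception logging is left out for brevity.

```python
import asyncio
import inspect
from typing import Callable, Dict, List


class EventSource:
    """Hypothetical stand-in for a base class that owns event handlers."""

    def __init__(self) -> None:
        self._event_handlers: Dict[str, List[Callable]] = {}

    def _register_event_handler(self, event_name: str) -> None:
        # Declare an event; handlers can only attach to declared events.
        if event_name in self._event_handlers:
            raise Exception(f"Event handler {event_name} already registered")
        self._event_handlers[event_name] = []

    def add_event_handler(self, event_name: str, handler: Callable) -> None:
        if event_name not in self._event_handlers:
            raise Exception(f"Event handler {event_name} not registered")
        self._event_handlers[event_name].append(handler)

    def event_handler(self, event_name: str):
        # Decorator form of add_event_handler.
        def decorator(handler):
            self.add_event_handler(event_name, handler)
            return handler

        return decorator

    async def _call_event_handler(self, event_name: str, *args, **kwargs) -> None:
        # Handlers receive the emitting object first; coroutines are awaited.
        for handler in self._event_handlers[event_name]:
            if inspect.iscoroutinefunction(handler):
                await handler(self, *args, **kwargs)
            else:
                handler(self, *args, **kwargs)


async def demo() -> list:
    source = EventSource()
    source._register_event_handler("on_audio_data")
    seen = []

    @source.event_handler("on_audio_data")
    async def on_audio(src, audio, sample_rate, num_channels):
        seen.append(("async", len(audio), sample_rate, num_channels))

    @source.event_handler("on_audio_data")
    def on_audio_sync(src, audio, sample_rate, num_channels):
        seen.append(("sync", len(audio), sample_rate, num_channels))

    await source._call_event_handler("on_audio_data", b"\x00\x00", 16000, 1)
    return seen
```

Handlers run in registration order, and both plain functions and coroutine functions are accepted, which matches the `inspect.iscoroutinefunction` branch in the removed code.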
@@ -375,6 +375,22 @@ class RTVIMetricsMessage(BaseModel):
     data: Mapping[str, Any]


+class RTVIServerMessage(BaseModel):
+    label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
+    type: Literal["server-message"] = "server-message"
+    data: Any
+
+
+@dataclass
+class RTVIServerMessageFrame(SystemFrame):
+    """A frame for sending server messages to the client."""
+
+    data: Any
+
+    def __str__(self):
+        return f"{self.name}(data: {self.data})"
+
+
 class RTVIFrameProcessor(FrameProcessor):
     def __init__(self, direction: FrameDirection = FrameDirection.DOWNSTREAM, **kwargs):
         super().__init__(**kwargs)
@@ -710,6 +726,9 @@ class RTVIObserver(BaseObserver):
             mark_as_seen = False
         elif isinstance(frame, MetricsFrame):
             await self._handle_metrics(frame)
+        elif isinstance(frame, RTVIServerMessageFrame):
+            message = RTVIServerMessage(data=frame.data)
+            await self.push_transport_message_urgent(message)

         if mark_as_seen:
             self._frames_seen.add(frame.id)
||||||
|
|||||||
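The observer hunk above wraps `frame.data` in an `RTVIServerMessage` before pushing it to the transport. The shape of that envelope can be sketched with the standard library alone. This is an illustration, not the pydantic model from the diff: `ServerMessage`, `wrap_server_message`, and `PLACEHOLDER_LABEL` are hypothetical, and the label string is a stand-in because the value of `RTVI_MESSAGE_LABEL` is not shown here.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any

# Placeholder only: the real label comes from RTVI_MESSAGE_LABEL.
PLACEHOLDER_LABEL = "example-label"


@dataclass
class ServerMessage:
    """Envelope with a fixed type and an arbitrary payload."""

    data: Any
    label: str = PLACEHOLDER_LABEL
    type: str = "server-message"


def wrap_server_message(data: Any) -> str:
    """Serialize JSON-compatible data inside a server-message envelope."""
    return json.dumps(asdict(ServerMessage(data=data)))
```

A caller would pass any JSON-serializable value, for example `wrap_server_message({"status": "ready"})`, and the client can dispatch on the `type` field.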
@@ -8,10 +8,12 @@ import asyncio
 import io
 import wave
 from abc import abstractmethod
-from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple, Type

 from loguru import logger

+from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
+from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
 from pipecat.audio.utils import calculate_audio_volume, exp_smoothing
 from pipecat.frames.frames import (
     AudioRawFrame,
@@ -137,10 +139,23 @@ class AIService(FrameProcessor):
 class LLMService(AIService):
     """This class is a no-op but serves as a base class for LLM services."""

+    # OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations.
+    # However, subclasses should override this with a more specific adapter when necessary.
+    adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._callbacks = {}
         self._start_callbacks = {}
+        self._adapter = self.adapter_class()
+
+    def get_llm_adapter(self) -> BaseLLMAdapter:
+        return self._adapter
+
+    def create_context_aggregator(
+        self, context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
+    ) -> Any:
+        pass

         self._register_event_handler("on_completion_timeout")

||||||
|
|||||||
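The `adapter_class` attribute added to `LLMService` is a class-level default that subclasses override without touching `__init__`. A minimal sketch of that pattern, using hypothetical `Service` and adapter classes rather than the pipecat types:

```python
from typing import Type


class BaseAdapter:
    """Hypothetical adapter interface."""

    def name(self) -> str:
        raise NotImplementedError


class DefaultAdapter(BaseAdapter):
    def name(self) -> str:
        return "default"


class VendorAdapter(BaseAdapter):
    def name(self) -> str:
        return "vendor"


class Service:
    # Class-level default; subclasses override the attribute, not __init__.
    adapter_class: Type[BaseAdapter] = DefaultAdapter

    def __init__(self) -> None:
        # Looked up through the instance, so a subclass override is used.
        self._adapter = self.adapter_class()

    def get_adapter(self) -> BaseAdapter:
        return self._adapter


class VendorService(Service):
    adapter_class = VendorAdapter
```

Because the adapter is instantiated in `__init__`, every service instance gets its own adapter object, and a subclass such as `AnthropicLLMService` in the diff only needs the one-line class attribute.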
@@ -11,13 +11,14 @@ import io
 import json
 import re
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union

 import httpx
 from loguru import logger
 from PIL import Image
 from pydantic import BaseModel, Field

+from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
 from pipecat.frames.frames import (
     Frame,
     FunctionCallInProgressFrame,
@@ -85,6 +86,9 @@ class AnthropicLLMService(LLMService):
     use `AsyncAnthropicBedrock` and `AsyncAnthropicVertex` clients
     """

+    # Overriding the default adapter to use the Anthropic one.
+    adapter_class = AnthropicLLMAdapter
+
     class InputParams(BaseModel):
         enable_prompt_caching_beta: Optional[bool] = False
         max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
@@ -123,16 +127,38 @@ class AnthropicLLMService(LLMService):
     def enable_prompt_caching_beta(self) -> bool:
         return self._enable_prompt_caching_beta

-    @staticmethod
     def create_context_aggregator(
-        context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
+        self,
+        context: OpenAILLMContext,
+        *,
+        user_kwargs: Mapping[str, Any] = {},
+        assistant_kwargs: Mapping[str, Any] = {},
     ) -> AnthropicContextAggregatorPair:
+        """Create an instance of AnthropicContextAggregatorPair from an
+        OpenAILLMContext. Constructor keyword arguments for both the user and
+        assistant aggregators can be provided.
+
+        Args:
+            context (OpenAILLMContext): The LLM context.
+            user_kwargs (Mapping[str, Any], optional): Additional keyword
+                arguments for the user context aggregator constructor. Defaults
+                to an empty mapping.
+            assistant_kwargs (Mapping[str, Any], optional): Additional keyword
+                arguments for the assistant context aggregator
+                constructor. Defaults to an empty mapping.
+
+        Returns:
+            AnthropicContextAggregatorPair: A pair of context aggregators, one
+                for the user and one for the assistant, encapsulated in an
+                AnthropicContextAggregatorPair.
+
+        """
+        context.set_llm_adapter(self.get_llm_adapter())
+
         if isinstance(context, OpenAILLMContext):
             context = AnthropicLLMContext.from_openai_context(context)
-        user = AnthropicUserContextAggregator(context)
-        assistant = AnthropicAssistantContextAggregator(
-            context, expect_stripped_words=assistant_expect_stripped_words
-        )
+        user = AnthropicUserContextAggregator(context, **user_kwargs)
+        assistant = AnthropicAssistantContextAggregator(context, **assistant_kwargs)
         return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)

     async def _process_context(self, context: OpenAILLMContext):
@@ -152,7 +178,7 @@ class AnthropicLLMService(LLMService):
             await self.start_processing_metrics()

             logger.debug(
-                f"Generating chat: {context.system} | {context.get_messages_for_logging()}"
+                f"{self}: Generating chat [{context.system}] | [{context.get_messages_for_logging()}]"
             )

             messages = context.messages
@@ -362,6 +388,7 @@ class AnthropicLLMContext(OpenAILLMContext):
             tools=openai_context.tools,
             tool_choice=openai_context.tool_choice,
         )
+        self.set_llm_adapter(openai_context.get_llm_adapter())
         self._restructure_from_openai_messages()
         return self

||||||
|
|||||||
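The signature change to `create_context_aggregator` replaces a single `assistant_expect_stripped_words` flag with two keyword-argument mappings that are unpacked into the aggregator constructors. A minimal sketch of that forwarding, with hypothetical aggregator classes (the `aggregation_timeout` option is invented for illustration); unlike the diff, it uses `None` defaults rather than `{}`, which sidesteps shared mutable defaults:

```python
from typing import Any, Mapping, Optional, Tuple


class UserAggregator:
    """Hypothetical user-side aggregator with one illustrative option."""

    def __init__(self, context: list, *, aggregation_timeout: float = 1.0):
        self.context = context
        self.aggregation_timeout = aggregation_timeout


class AssistantAggregator:
    """Hypothetical assistant-side aggregator."""

    def __init__(self, context: list, *, expect_stripped_words: bool = True):
        self.context = context
        self.expect_stripped_words = expect_stripped_words


def create_context_aggregator(
    context: list,
    *,
    user_kwargs: Optional[Mapping[str, Any]] = None,
    assistant_kwargs: Optional[Mapping[str, Any]] = None,
) -> Tuple[UserAggregator, AssistantAggregator]:
    # Each mapping is unpacked into the matching constructor.
    user = UserAggregator(context, **(user_kwargs or {}))
    assistant = AssistantAggregator(context, **(assistant_kwargs or {}))
    return user, assistant


# Old call style: create_context_aggregator(ctx, assistant_expect_stripped_words=False)
# New call style:
user, assistant = create_context_aggregator(
    [], assistant_kwargs={"expect_stripped_words": False}
)
```

The mapping-based form lets new constructor options be passed through without changing the factory signature again.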
@@ -197,7 +197,7 @@ class PollyTTSService(TTSService):
                 return audio_data
             return None

-        logger.debug(f"Generating TTS: [{text}]")
+        logger.debug(f"{self}: Generating TTS [{text}]")

         try:
             await self.start_ttfb_metrics()
||||||
|
|||||||
@@ -578,7 +578,7 @@ class AzureTTSService(AzureBaseTTSService):
         self._audio_queue.put_nowait(None)

     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
-        logger.debug(f"Generating TTS: [{text}]")
+        logger.debug(f"{self}: Generating TTS [{text}]")

         try:
             if self._speech_synthesizer is None:
@@ -645,7 +645,7 @@ class AzureHttpTTSService(AzureBaseTTSService):
         )

     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
-        logger.debug(f"Generating TTS: [{text}]")
+        logger.debug(f"{self}: Generating TTS [{text}]")

         await self.start_ttfb_metrics()

|
|||||||
@@ -62,17 +62,21 @@ class CanonicalMetricsService(AIService):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
aiohttp_session: aiohttp.ClientSession,
|
aiohttp_session: aiohttp.ClientSession,
|
||||||
audio_buffer_processor: AudioBufferProcessor,
|
|
||||||
call_id: str,
|
call_id: str,
|
||||||
assistant: str,
|
assistant: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_url: str = "https://voiceapp.canonical.chat/api/v1",
|
api_url: str = "https://voiceapp.canonical.chat/api/v1",
|
||||||
assistant_speaks_first: bool = True,
|
assistant_speaks_first: bool = True,
|
||||||
output_dir: str = "recordings",
|
output_dir: str = "recordings",
|
||||||
|
audio_buffer_processor: Optional[AudioBufferProcessor] = None,
|
||||||
context: Optional[OpenAILLMContext] = None,
|
context: Optional[OpenAILLMContext] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
# Validate that at least one of audio_buffer_processor or context is provided
|
||||||
|
if audio_buffer_processor is None and context is None:
|
||||||
|
raise ValueError("At least one of audio_buffer_processor or context must be specified")
|
||||||
|
|
||||||
self._aiohttp_session = aiohttp_session
|
self._aiohttp_session = aiohttp_session
|
||||||
self._audio_buffer_processor = audio_buffer_processor
|
self._audio_buffer_processor = audio_buffer_processor
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
@@ -85,16 +89,36 @@ class CanonicalMetricsService(AIService):

     async def stop(self, frame: EndFrame):
         await super().stop(frame)
-        await self._process_audio()
+        await self._process_completion()

     async def cancel(self, frame: CancelFrame):
         await super().cancel(frame)
-        await self._process_audio()
+        await self._process_completion()

     async def process_frame(self, frame: Frame, direction: FrameDirection):
         await super().process_frame(frame, direction)
         await self.push_frame(frame, direction)

+    async def _process_completion(self):
+        if self._audio_buffer_processor is not None:
+            await self._process_audio()
+        elif self._context is not None:
+            await self._process_transcript()
+
+    async def _process_transcript(self):
+        params = {
+            "callId": self._call_id,
+            "assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
+            "transcript": self._context.messages,
+        }
+        response = await self._aiohttp_session.post(
+            f"{self._api_url}/call",
+            headers=self._request_headers(),
+            json=params,
+        )
+        if not response.ok:
+            logger.error(f"Failed to process transcript: {await response.text()}")
+
     async def _process_audio(self):
         audio_buffer_processor = self._audio_buffer_processor

||||||
|
|||||||
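The `CanonicalMetricsService` hunks above make the audio processor optional, require at least one of the two inputs, and prefer audio over the transcript at completion. The decision and the transcript payload can be sketched without any network calls. `choose_upload` and `build_transcript_payload` are hypothetical helpers written for this example; the payload keys are copied from the diff.

```python
from typing import Any, Dict, List, Optional


def choose_upload(audio_buffer_processor: Optional[object], context: Optional[object]) -> str:
    """Return which upload path applies, mirroring the precedence in the diff."""
    if audio_buffer_processor is None and context is None:
        raise ValueError("At least one of audio_buffer_processor or context must be specified")
    # Audio takes precedence when both are available.
    return "audio" if audio_buffer_processor is not None else "transcript"


def build_transcript_payload(
    call_id: str,
    assistant: str,
    assistant_speaks_first: bool,
    messages: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Build the JSON body sent when only a transcript is available."""
    return {
        "callId": call_id,
        "assistant": {"id": assistant, "speaksFirst": assistant_speaks_first},
        "transcript": messages,
    }
```

Note that when both inputs are supplied only the audio path runs, so the transcript is never sent in that case.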
@@ -272,7 +272,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
                 logger.error(f"{self} error, unknown message type: {msg}")

     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
-        logger.debug(f"Generating TTS: [{text}]")
+        logger.debug(f"{self}: Generating TTS [{text}]")

         try:
             if not self._websocket:
@@ -358,7 +358,7 @@ class CartesiaHttpTTSService(TTSService):
         await self._client.close()

     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
-        logger.debug(f"Generating TTS: [{text}]")
+        logger.debug(f"{self}: Generating TTS [{text}]")

         try:
             voice_controls = None
||||||
|
|||||||
@@ -70,7 +70,7 @@ class DeepgramTTSService(TTSService):
         return True

     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
-        logger.debug(f"Generating TTS: [{text}]")
+        logger.debug(f"{self}: Generating TTS [{text}]")

         options = SpeakOptions(
             model=self._voice_id,
||||||
|
|||||||
@@ -116,6 +116,44 @@ def output_format_from_sample_rate(sample_rate: int) -> str:
     return "pcm_16000"


+def build_elevenlabs_voice_settings(
+    settings: Dict[str, Any],
+) -> Optional[Dict[str, Union[float, bool]]]:
+    """Build voice settings dictionary for ElevenLabs based on provided settings.
+
+    Args:
+        settings: Dictionary containing voice settings parameters
+
+    Returns:
+        Dictionary of voice settings or None if required parameters are missing
+    """
+    voice_settings = {}
+    if settings["stability"] is not None and settings["similarity_boost"] is not None:
+        voice_settings["stability"] = settings["stability"]
+        voice_settings["similarity_boost"] = settings["similarity_boost"]
+        if settings["style"] is not None:
+            voice_settings["style"] = settings["style"]
+        if settings["use_speaker_boost"] is not None:
+            voice_settings["use_speaker_boost"] = settings["use_speaker_boost"]
+        if settings["speed"] is not None:
+            voice_settings["speed"] = settings["speed"]
+    else:
+        if settings["style"] is not None:
+            logger.warning(
+                "'style' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
+            )
+        if settings["use_speaker_boost"] is not None:
+            logger.warning(
+                "'use_speaker_boost' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
+            )
+        if settings["speed"] is not None:
+            logger.warning(
+                "'speed' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
+            )
+
+    return voice_settings or None
+
+
 def calculate_word_times(
     alignment_info: Mapping[str, Any], cumulative_time: float
 ) -> List[Tuple[str, float]]:
@@ -145,6 +183,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
         similarity_boost: Optional[float] = None
         style: Optional[float] = None
         use_speaker_boost: Optional[bool] = None
+        speed: Optional[float] = None
         auto_mode: Optional[bool] = True

         @model_validator(mode="after")
@@ -202,6 +241,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
             "similarity_boost": params.similarity_boost,
             "style": params.style,
             "use_speaker_boost": params.use_speaker_boost,
+            "speed": params.speed,
             "auto_mode": str(params.auto_mode).lower(),
         }
         self.set_model_name(model)
@@ -224,28 +264,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
         return language_to_elevenlabs_language(language)

     def _set_voice_settings(self):
-        voice_settings = {}
-        if (
-            self._settings["stability"] is not None
-            and self._settings["similarity_boost"] is not None
-        ):
-            voice_settings["stability"] = self._settings["stability"]
-            voice_settings["similarity_boost"] = self._settings["similarity_boost"]
-            if self._settings["style"] is not None:
-                voice_settings["style"] = self._settings["style"]
-            if self._settings["use_speaker_boost"] is not None:
-                voice_settings["use_speaker_boost"] = self._settings["use_speaker_boost"]
-        else:
-            if self._settings["style"] is not None:
-                logger.warning(
-                    "'style' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
-                )
-            if self._settings["use_speaker_boost"] is not None:
-                logger.warning(
-                    "'use_speaker_boost' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
-                )
-
-        return voice_settings or None
+        return build_elevenlabs_voice_settings(self._settings)

     async def set_model(self, model: str):
         await super().set_model(model)
@@ -395,7 +414,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
             await self._websocket.send(json.dumps(msg))

     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
-        logger.debug(f"Generating TTS: [{text}]")
+        logger.debug(f"{self}: Generating TTS [{text}]")

         try:
             if not self._websocket:
@@ -441,6 +460,7 @@ class ElevenLabsHttpTTSService(TTSService):
         similarity_boost: Optional[float] = None
         style: Optional[float] = None
         use_speaker_boost: Optional[bool] = None
+        speed: Optional[float] = None

     def __init__(
         self,
@@ -470,6 +490,7 @@ class ElevenLabsHttpTTSService(TTSService):
             "similarity_boost": params.similarity_boost,
             "style": params.style,
             "use_speaker_boost": params.use_speaker_boost,
+            "speed": params.speed,
         }
         self.set_model_name(model)
         self.set_voice(voice_id)
@@ -479,34 +500,8 @@ class ElevenLabsHttpTTSService(TTSService):
     def can_generate_metrics(self) -> bool:
         return True

-    def _set_voice_settings(self) -> Optional[Dict[str, Union[float, bool]]]:
-        """Configure voice settings if stability and similarity_boost are provided.
-
-        Returns:
-            Dictionary of voice settings or None if required parameters are missing.
-        """
-        voice_settings: Dict[str, Union[float, bool]] = {}
-        if (
-            self._settings["stability"] is not None
-            and self._settings["similarity_boost"] is not None
-        ):
-            voice_settings["stability"] = float(self._settings["stability"])
-            voice_settings["similarity_boost"] = float(self._settings["similarity_boost"])
-            if self._settings["style"] is not None:
-                voice_settings["style"] = float(self._settings["style"])
-            if self._settings["use_speaker_boost"] is not None:
-                voice_settings["use_speaker_boost"] = bool(self._settings["use_speaker_boost"])
-        else:
-            if self._settings["style"] is not None:
-                logger.warning(
-                    "'style' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
-                )
-            if self._settings["use_speaker_boost"] is not None:
-                logger.warning(
-                    "'use_speaker_boost' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
-                )
-
-        return voice_settings or None
+    def _set_voice_settings(self):
+        return build_elevenlabs_voice_settings(self._settings)

     async def start(self, frame: StartFrame):
         await super().start(frame)
@@ -521,7 +516,7 @@ class ElevenLabsHttpTTSService(TTSService):
         Yields:
             Frames containing audio data and status information
         """
-        logger.debug(f"Generating TTS: [{text}]")
+        logger.debug(f"{self}: Generating TTS [{text}]")

         url = f"{self._base_url}/v1/text-to-speech/{self._voice_id}/stream"

@@ -570,10 +565,12 @@ class ElevenLabsHttpTTSService(TTSService):

                 await self.start_tts_usage_metrics(text)

-                yield TTSStartedFrame()
+                # Process the streaming response
+                CHUNK_SIZE = 1024

-                async for chunk in response.content:
-                    if chunk:
+                yield TTSStartedFrame()
+                async for chunk in response.content.iter_chunked(CHUNK_SIZE):
+                    if len(chunk) > 0:
                         await self.stop_ttfb_metrics()
                         yield TTSAudioRawFrame(chunk, self.sample_rate, 1)
         except Exception as e:
||||||
|
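The `voice_settings` logic earlier in this file's diff only applies `style` and `use_speaker_boost` when both `stability` and `similarity_boost` are set. A minimal standalone sketch of that rule follows. The function name `build_voice_settings` and the use of `Mapping.get` are assumptions for illustration; the service itself reads `self._settings[...]` and logs a warning where this sketch only has a comment.

```python
from typing import Dict, Mapping, Optional, Union

SettingValue = Optional[Union[float, bool]]


def build_voice_settings(
    settings: Mapping[str, SettingValue],
) -> Optional[Dict[str, Union[float, bool]]]:
    """Return the voice settings payload, or None when nothing applies."""
    voice_settings: Dict[str, Union[float, bool]] = {}
    stability = settings.get("stability")
    similarity_boost = settings.get("similarity_boost")

    if stability is not None and similarity_boost is not None:
        voice_settings["stability"] = float(stability)
        voice_settings["similarity_boost"] = float(similarity_boost)
        # The optional fields are only honoured alongside the two required ones.
        if settings.get("style") is not None:
            voice_settings["style"] = float(settings["style"])
        if settings.get("use_speaker_boost") is not None:
            voice_settings["use_speaker_boost"] = bool(settings["use_speaker_boost"])
    # Otherwise 'style' and 'use_speaker_boost' are dropped; the diff logs a
    # warning for each of them at this point.

    # An empty dict is falsy, so callers receive None instead of {}.
    return voice_settings or None
```

Returning `None` instead of an empty dict lets the caller omit the field from the request entirely.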
|||||||
@@ -178,7 +178,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
|||||||
logger.error(f"Error processing message: {e}")
|
logger.error(f"Error processing message: {e}")
|
||||||
|
|
||||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||||
logger.debug(f"Generating Fish TTS: [{text}]")
|
logger.debug(f"{self}: Generating Fish TTS: [{text}]")
|
||||||
try:
|
try:
|
||||||
if not self._websocket or self._websocket.closed:
|
if not self._websocket or self._websocket.closed:
|
||||||
await self._connect()
|
await self._connect()
|
||||||
|
|||||||
@@ -9,12 +9,14 @@ import base64
|
|||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||||
|
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||||
from pipecat.frames.frames import (
|
from pipecat.frames.frames import (
|
||||||
BotStartedSpeakingFrame,
|
BotStartedSpeakingFrame,
|
||||||
BotStoppedSpeakingFrame,
|
BotStoppedSpeakingFrame,
|
||||||
@@ -152,6 +154,9 @@ class InputParams(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class GeminiMultimodalLiveLLMService(LLMService):
|
class GeminiMultimodalLiveLLMService(LLMService):
|
||||||
|
# Overriding the default adapter to use the Gemini one.
|
||||||
|
adapter_class = GeminiLLMAdapter
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -162,7 +167,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
|||||||
start_audio_paused: bool = False,
|
start_audio_paused: bool = False,
|
||||||
start_video_paused: bool = False,
|
start_video_paused: bool = False,
|
||||||
system_instruction: Optional[str] = None,
|
system_instruction: Optional[str] = None,
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[Union[List[dict], ToolsSchema]] = None,
|
||||||
transcribe_user_audio: bool = False,
|
transcribe_user_audio: bool = False,
|
||||||
transcribe_model_audio: bool = False,
|
transcribe_model_audio: bool = False,
|
||||||
params: InputParams = InputParams(),
|
params: InputParams = InputParams(),
|
||||||
@@ -435,7 +440,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
|||||||
)
|
)
|
||||||
if self._tools:
|
if self._tools:
|
||||||
logger.debug(f"Gemini is configuring to use tools{self._tools}")
|
logger.debug(f"Gemini is configuring to use tools{self._tools}")
|
||||||
config.setup.tools = self._tools
|
config.setup.tools = self.get_llm_adapter().from_standard_tools(self._tools)
|
||||||
await self.send_client_event(config)
|
await self.send_client_event(config)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -701,11 +706,39 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
|||||||
await self.push_frame(TTSStoppedFrame())
|
await self.push_frame(TTSStoppedFrame())
|
||||||
|
|
||||||
def create_context_aggregator(
|
def create_context_aggregator(
|
||||||
self, context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = False
|
self,
|
||||||
|
context: OpenAILLMContext,
|
||||||
|
*,
|
||||||
|
user_kwargs: Mapping[str, Any] = {},
|
||||||
|
assistant_kwargs: Mapping[str, Any] = {},
|
||||||
) -> GeminiMultimodalLiveContextAggregatorPair:
|
) -> GeminiMultimodalLiveContextAggregatorPair:
|
||||||
|
"""Create an instance of GeminiMultimodalLiveContextAggregatorPair from
|
||||||
|
an OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||||
|
assistant aggregators can be provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context (OpenAILLMContext): The LLM context.
|
||||||
|
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the user context aggregator constructor. Defaults
|
||||||
|
to an empty mapping.
|
||||||
|
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the assistant context aggregator
|
||||||
|
constructor. Defaults to an empty mapping.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeminiMultimodalLiveContextAggregatorPair: A pair of context
|
||||||
|
aggregators, one for the user and one for the assistant,
|
||||||
|
encapsulated in a GeminiMultimodalLiveContextAggregatorPair.
|
||||||
|
|
||||||
|
"""
|
||||||
|
context.set_llm_adapter(self.get_llm_adapter())
|
||||||
|
|
||||||
GeminiMultimodalLiveContext.upgrade(context)
|
GeminiMultimodalLiveContext.upgrade(context)
|
||||||
user = GeminiMultimodalLiveUserContextAggregator(context)
|
user = GeminiMultimodalLiveUserContextAggregator(context, **user_kwargs)
|
||||||
|
|
||||||
|
default_assistant_kwargs = {"expect_stripped_words": False}
|
||||||
|
default_assistant_kwargs.update(assistant_kwargs)
|
||||||
assistant = GeminiMultimodalLiveAssistantContextAggregator(
|
assistant = GeminiMultimodalLiveAssistantContextAggregator(
|
||||||
context, expect_stripped_words=assistant_expect_stripped_words
|
context, **default_assistant_kwargs
|
||||||
)
|
)
|
||||||
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)
|
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)
|
||||||
|
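The `create_context_aggregator` change above replaces the single `assistant_expect_stripped_words` flag with `user_kwargs` and `assistant_kwargs` mappings, and layers the caller's values over a service-specific default. A minimal sketch of that merge, assuming a hypothetical helper name `merge_assistant_kwargs` (the diff performs the same two steps inline):

```python
from typing import Any, Dict, Mapping


def merge_assistant_kwargs(assistant_kwargs: Mapping[str, Any]) -> Dict[str, Any]:
    """Overlay caller-supplied kwargs on the service default."""
    # The default shown in the diff for this service.
    merged: Dict[str, Any] = {"expect_stripped_words": False}
    # Caller values win; the caller's mapping itself is never mutated.
    merged.update(assistant_kwargs)
    return merged
```

Because the merge copies into a fresh dict, the shared `{}` default argument in the signature is never modified. A caller that previously passed `assistant_expect_stripped_words=True` would presumably now pass `assistant_kwargs={"expect_stripped_words": True}`.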
|||||||
@@ -136,6 +136,7 @@ class GladiaSTTService(STTService):
|
|||||||
maximum_duration_without_endpointing: Optional[int] = 10
|
maximum_duration_without_endpointing: Optional[int] = 10
|
||||||
audio_enhancer: Optional[bool] = None
|
audio_enhancer: Optional[bool] = None
|
||||||
words_accurate_timestamps: Optional[bool] = None
|
words_accurate_timestamps: Optional[bool] = None
|
||||||
|
speech_threshold: Optional[float] = 0.99
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -148,7 +149,6 @@ class GladiaSTTService(STTService):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||||
|
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
self._url = url
|
self._url = url
|
||||||
self._settings = {
|
self._settings = {
|
||||||
@@ -166,6 +166,7 @@ class GladiaSTTService(STTService):
|
|||||||
"maximum_duration_without_endpointing": params.maximum_duration_without_endpointing,
|
"maximum_duration_without_endpointing": params.maximum_duration_without_endpointing,
|
||||||
"pre_processing": {
|
"pre_processing": {
|
||||||
"audio_enhancer": params.audio_enhancer,
|
"audio_enhancer": params.audio_enhancer,
|
||||||
|
"speech_threshold": params.speech_threshold,
|
||||||
},
|
},
|
||||||
"realtime_processing": {
|
"realtime_processing": {
|
||||||
"words_accurate_timestamps": params.words_accurate_timestamps,
|
"words_accurate_timestamps": params.words_accurate_timestamps,
|
||||||
|
|||||||
@@ -12,12 +12,16 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from google.api_core.exceptions import DeadlineExceeded
|
from google.api_core.exceptions import DeadlineExceeded
|
||||||
|
from openai import AsyncStream
|
||||||
|
from openai.types.chat import ChatCompletionChunk
|
||||||
|
|
||||||
|
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||||
|
|
||||||
# Suppress gRPC fork warnings
|
# Suppress gRPC fork warnings
|
||||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Union
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -54,7 +58,10 @@ from pipecat.processors.frame_processor import FrameDirection
|
|||||||
from pipecat.services.ai_services import ImageGenService, LLMService, STTService, TTSService
|
from pipecat.services.ai_services import ImageGenService, LLMService, STTService, TTSService
|
||||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||||
from pipecat.services.openai import (
|
from pipecat.services.openai import (
|
||||||
|
BaseOpenAILLMService,
|
||||||
OpenAIAssistantContextAggregator,
|
OpenAIAssistantContextAggregator,
|
||||||
|
OpenAILLMService,
|
||||||
|
OpenAIUnhandledFunctionException,
|
||||||
OpenAIUserContextAggregator,
|
OpenAIUserContextAggregator,
|
||||||
)
|
)
|
||||||
from pipecat.transcriptions.language import Language
|
from pipecat.transcriptions.language import Language
|
||||||
@@ -722,7 +729,9 @@ class GoogleLLMContext(OpenAILLMContext):
|
|||||||
|
|
||||||
self.add_message(glm.Content(role="user", parts=parts))
|
self.add_message(glm.Content(role="user", parts=parts))
|
||||||
|
|
||||||
def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None):
|
def add_audio_frames_message(
|
||||||
|
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||||
|
):
|
||||||
if not audio_frames:
|
if not audio_frames:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -731,8 +740,9 @@ class GoogleLLMContext(OpenAILLMContext):
|
|||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
data = b"".join(frame.audio for frame in audio_frames)
|
data = b"".join(frame.audio for frame in audio_frames)
|
||||||
if text:
|
# NOTE(aleix): According to the docs only text or inline_data should be needed.
|
||||||
parts.append(glm.Part(text=text))
|
# (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
|
||||||
|
parts.append(glm.Part(text=text))
|
||||||
parts.append(
|
parts.append(
|
||||||
glm.Part(
|
glm.Part(
|
||||||
inline_data=glm.Blob(
|
inline_data=glm.Blob(
|
||||||
@@ -942,6 +952,9 @@ class GoogleLLMService(LLMService):
|
|||||||
franca for all LLM services, so that it is easy to switch between different LLMs.
|
franca for all LLM services, so that it is easy to switch between different LLMs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Overriding the default adapter to use the Gemini one.
|
||||||
|
adapter_class = GeminiLLMAdapter
|
||||||
|
|
||||||
class InputParams(BaseModel):
|
class InputParams(BaseModel):
|
||||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||||
@@ -995,8 +1008,8 @@ class GoogleLLMService(LLMService):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
# f"Generating chat: {self._system_instruction} | {context.get_messages_for_logging()}"
|
# f"{self}: Generating chat [{self._system_instruction}] | [{context.get_messages_for_logging()}]"
|
||||||
f"Generating chat: {context.get_messages_for_logging()}"
|
f"{self}: Generating chat [{context.get_messages_for_logging()}]"
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = context.messages
|
messages = context.messages
|
||||||
@@ -1172,19 +1185,155 @@ class GoogleLLMService(LLMService):
|
|||||||
if context:
|
if context:
|
||||||
await self._process_context(context)
|
await self._process_context(context)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_context_aggregator(
|
def create_context_aggregator(
|
||||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
self,
|
||||||
|
context: OpenAILLMContext,
|
||||||
|
*,
|
||||||
|
user_kwargs: Mapping[str, Any] = {},
|
||||||
|
assistant_kwargs: Mapping[str, Any] = {},
|
||||||
) -> GoogleContextAggregatorPair:
|
) -> GoogleContextAggregatorPair:
|
||||||
|
"""Create an instance of GoogleContextAggregatorPair from an
|
||||||
|
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||||
|
assistant aggregators can be provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context (OpenAILLMContext): The LLM context.
|
||||||
|
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the user context aggregator constructor. Defaults
|
||||||
|
to an empty mapping.
|
||||||
|
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the assistant context aggregator
|
||||||
|
constructor. Defaults to an empty mapping.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GoogleContextAggregatorPair: A pair of context aggregators, one for
|
||||||
|
the user and one for the assistant, encapsulated in an
|
||||||
|
GoogleContextAggregatorPair.
|
||||||
|
|
||||||
|
"""
|
||||||
|
context.set_llm_adapter(self.get_llm_adapter())
|
||||||
|
|
||||||
if isinstance(context, OpenAILLMContext):
|
if isinstance(context, OpenAILLMContext):
|
||||||
context = GoogleLLMContext.upgrade_to_google(context)
|
context = GoogleLLMContext.upgrade_to_google(context)
|
||||||
user = GoogleUserContextAggregator(context)
|
user = GoogleUserContextAggregator(context, **user_kwargs)
|
||||||
assistant = GoogleAssistantContextAggregator(
|
assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs)
|
||||||
context, expect_stripped_words=assistant_expect_stripped_words
|
|
||||||
)
|
|
||||||
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
|
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||||
|
"""This class implements inference with Google's AI LLM models using the OpenAI format.
|
||||||
|
Ref - https://ai.google.dev/gemini-api/docs/openai
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||||
|
model: str = "gemini-2.0-flash",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||||
|
|
||||||
|
async def _process_context(self, context: OpenAILLMContext):
|
||||||
|
functions_list = []
|
||||||
|
arguments_list = []
|
||||||
|
tool_id_list = []
|
||||||
|
func_idx = 0
|
||||||
|
function_name = ""
|
||||||
|
arguments = ""
|
||||||
|
tool_call_id = ""
|
||||||
|
|
||||||
|
await self.start_ttfb_metrics()
|
||||||
|
|
||||||
|
chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
|
||||||
|
context
|
||||||
|
)
|
||||||
|
|
||||||
|
async for chunk in chunk_stream:
|
||||||
|
if chunk.usage:
|
||||||
|
tokens = LLMTokenUsage(
|
||||||
|
prompt_tokens=chunk.usage.prompt_tokens,
|
||||||
|
completion_tokens=chunk.usage.completion_tokens,
|
||||||
|
total_tokens=chunk.usage.total_tokens,
|
||||||
|
)
|
||||||
|
await self.start_llm_usage_metrics(tokens)
|
||||||
|
|
||||||
|
if chunk.choices is None or len(chunk.choices) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
await self.stop_ttfb_metrics()
|
||||||
|
|
||||||
|
if not chunk.choices[0].delta:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if chunk.choices[0].delta.tool_calls:
|
||||||
|
# We're streaming the LLM response to enable the fastest response times.
|
||||||
|
# For text, we just yield each chunk as we receive it and count on consumers
|
||||||
|
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
|
||||||
|
#
|
||||||
|
# If the LLM is a function call, we'll do some coalescing here.
|
||||||
|
# If the response contains a function name, we'll yield a frame to tell consumers
|
||||||
|
# that they can start preparing to call the function with that name.
|
||||||
|
# We accumulate all the arguments for the rest of the streamed response, then when
|
||||||
|
# the response is done, we package up all the arguments and the function name and
|
||||||
|
# yield a frame containing the function name and the arguments.
|
||||||
|
logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}")
|
||||||
|
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||||
|
if tool_call.index != func_idx:
|
||||||
|
functions_list.append(function_name)
|
||||||
|
arguments_list.append(arguments)
|
||||||
|
tool_id_list.append(tool_call_id)
|
||||||
|
function_name = ""
|
||||||
|
arguments = ""
|
||||||
|
tool_call_id = ""
|
||||||
|
func_idx += 1
|
||||||
|
if tool_call.function and tool_call.function.name:
|
||||||
|
function_name += tool_call.function.name
|
||||||
|
tool_call_id = tool_call.id
|
||||||
|
if tool_call.function and tool_call.function.arguments:
|
||||||
|
# Keep iterating through the response to collect all the argument fragments
|
||||||
|
arguments += tool_call.function.arguments
|
||||||
|
elif chunk.choices[0].delta.content:
|
||||||
|
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||||
|
|
||||||
|
# if we got a function name and arguments, check to see if it's a function with
|
||||||
|
# a registered handler. If so, run the registered callback, save the result to
|
||||||
|
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||||
|
# handler, raise an exception.
|
||||||
|
if function_name and arguments:
|
||||||
|
# added to the list as last function name and arguments not added to the list
|
||||||
|
functions_list.append(function_name)
|
||||||
|
arguments_list.append(arguments)
|
||||||
|
tool_id_list.append(tool_call_id)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}"
|
||||||
|
)
|
||||||
|
for index, (function_name, arguments, tool_id) in enumerate(
|
||||||
|
zip(functions_list, arguments_list, tool_id_list), start=1
|
||||||
|
):
|
||||||
|
if function_name == "":
|
||||||
|
# TODO: Remove the _process_context method once Google resolves the bug
|
||||||
|
# where the index is incorrectly set to None instead of returning the actual index,
|
||||||
|
# which currently results in an empty function name('').
|
||||||
|
continue
|
||||||
|
if self.has_function(function_name):
|
||||||
|
run_llm = False
|
||||||
|
arguments = json.loads(arguments)
|
||||||
|
await self.call_function(
|
||||||
|
context=context,
|
||||||
|
function_name=function_name,
|
||||||
|
arguments=arguments,
|
||||||
|
tool_call_id=tool_id,
|
||||||
|
run_llm=run_llm,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise OpenAIUnhandledFunctionException(
|
||||||
|
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GoogleTTSService(TTSService):
|
class GoogleTTSService(TTSService):
|
||||||
class InputParams(BaseModel):
|
class InputParams(BaseModel):
|
||||||
pitch: Optional[str] = None
|
pitch: Optional[str] = None
|
||||||
@@ -1294,7 +1443,7 @@ class GoogleTTSService(TTSService):
|
|||||||
return ssml
|
return ssml
|
||||||
|
|
||||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.start_ttfb_metrics()
|
await self.start_ttfb_metrics()
|
||||||
|
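The `_process_context` method added in this file coalesces streamed tool-call fragments into complete calls before dispatching them. The sketch below isolates that accumulation step so it can be read without the streaming machinery. `ToolCallDelta` and `coalesce_tool_calls` are hypothetical names, and the flat dataclass is a stand-in for the nested `tool_call.function.name` / `tool_call.function.arguments` fields used in the diff.

```python
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple


@dataclass
class ToolCallDelta:
    """Hypothetical stand-in for one streamed tool-call fragment."""

    index: int
    id: Optional[str] = None
    name: Optional[str] = None
    arguments: Optional[str] = None


def coalesce_tool_calls(deltas: Iterable[ToolCallDelta]) -> List[Tuple[str, str, str]]:
    """Return (function_name, arguments_json, tool_call_id) per completed call."""
    calls: List[Tuple[str, str, str]] = []
    func_idx = 0
    function_name, arguments, tool_call_id = "", "", ""

    for delta in deltas:
        if delta.index != func_idx:
            # A new index means the previous call is complete: store and reset.
            calls.append((function_name, arguments, tool_call_id))
            function_name, arguments, tool_call_id = "", "", ""
            func_idx += 1
        if delta.name:
            function_name += delta.name
            tool_call_id = delta.id or ""
        if delta.arguments:
            # Argument JSON arrives in pieces; concatenate until the call ends.
            arguments += delta.arguments

    # The last call is never followed by an index change, so flush it here.
    if function_name and arguments:
        calls.append((function_name, arguments, tool_call_id))

    # Entries with an empty name are skipped, as in the diff's dispatch loop.
    return [call for call in calls if call[0] != ""]
```

In the service, each resulting argument string is passed through `json.loads` before the registered function is called.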
|||||||
@@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@@ -206,12 +206,34 @@ class GrokLLMService(OpenAILLMService):
|
|||||||
if tokens.completion_tokens > self._completion_tokens:
|
if tokens.completion_tokens > self._completion_tokens:
|
||||||
self._completion_tokens = tokens.completion_tokens
|
self._completion_tokens = tokens.completion_tokens
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_context_aggregator(
|
def create_context_aggregator(
|
||||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
self,
|
||||||
|
context: OpenAILLMContext,
|
||||||
|
*,
|
||||||
|
user_kwargs: Mapping[str, Any] = {},
|
||||||
|
assistant_kwargs: Mapping[str, Any] = {},
|
||||||
) -> GrokContextAggregatorPair:
|
) -> GrokContextAggregatorPair:
|
||||||
user = OpenAIUserContextAggregator(context)
|
"""Create an instance of GrokContextAggregatorPair from an
|
||||||
assistant = GrokAssistantContextAggregator(
|
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||||
context, expect_stripped_words=assistant_expect_stripped_words
|
assistant aggregators can be provided.
|
||||||
)
|
|
||||||
|
Args:
|
||||||
|
context (OpenAILLMContext): The LLM context.
|
||||||
|
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the user context aggregator constructor. Defaults
|
||||||
|
to an empty mapping.
|
||||||
|
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the assistant context aggregator
|
||||||
|
constructor. Defaults to an empty mapping.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GrokContextAggregatorPair: A pair of context aggregators, one for
|
||||||
|
the user and one for the assistant, encapsulated in an
|
||||||
|
GrokContextAggregatorPair.
|
||||||
|
|
||||||
|
"""
|
||||||
|
context.set_llm_adapter(self.get_llm_adapter())
|
||||||
|
|
||||||
|
user = OpenAIUserContextAggregator(context, **user_kwargs)
|
||||||
|
assistant = GrokAssistantContextAggregator(context, **assistant_kwargs)
|
||||||
return GrokContextAggregatorPair(_user=user, _assistant=assistant)
|
return GrokContextAggregatorPair(_user=user, _assistant=assistant)
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ class LmntTTSService(InterruptibleTTSService):
|
|||||||
|
|
||||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||||
"""Generate TTS audio from text."""
|
"""Generate TTS audio from text."""
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self._websocket:
|
if not self._websocket:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import base64
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import httpx
|
import httpx
|
||||||
@@ -178,7 +178,7 @@ class BaseOpenAILLMService(LLMService):
|
|||||||
async def _stream_chat_completions(
|
async def _stream_chat_completions(
|
||||||
self, context: OpenAILLMContext
|
self, context: OpenAILLMContext
|
||||||
) -> AsyncStream[ChatCompletionChunk]:
|
) -> AsyncStream[ChatCompletionChunk]:
|
||||||
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
|
logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
|
||||||
|
|
||||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||||
|
|
||||||
@@ -343,14 +343,35 @@ class OpenAILLMService(BaseOpenAILLMService):
|
|||||||
):
|
):
|
||||||
super().__init__(model=model, params=params, **kwargs)
|
super().__init__(model=model, params=params, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_context_aggregator(
|
def create_context_aggregator(
|
||||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
self,
|
||||||
|
context: OpenAILLMContext,
|
||||||
|
*,
|
||||||
|
user_kwargs: Mapping[str, Any] = {},
|
||||||
|
assistant_kwargs: Mapping[str, Any] = {},
|
||||||
) -> OpenAIContextAggregatorPair:
|
) -> OpenAIContextAggregatorPair:
|
||||||
user = OpenAIUserContextAggregator(context)
|
"""Create an instance of OpenAIContextAggregatorPair from an
|
||||||
assistant = OpenAIAssistantContextAggregator(
|
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||||
context, expect_stripped_words=assistant_expect_stripped_words
|
assistant aggregators can be provided.
|
||||||
)
|
|
||||||
|
Args:
|
||||||
|
context (OpenAILLMContext): The LLM context.
|
||||||
|
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the user context aggregator constructor. Defaults
|
||||||
|
to an empty mapping.
|
||||||
|
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the assistant context aggregator
|
||||||
|
constructor. Defaults to an empty mapping.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||||
|
the user and one for the assistant, encapsulated in an
|
||||||
|
OpenAIContextAggregatorPair.
|
||||||
|
|
||||||
|
"""
|
||||||
|
context.set_llm_adapter(self.get_llm_adapter())
|
||||||
|
user = OpenAIUserContextAggregator(context, **user_kwargs)
|
||||||
|
assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
|
||||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||||
|
|
||||||
|
|
||||||
@@ -508,7 +529,7 @@ class OpenAITTSService(TTSService):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
try:
|
try:
|
||||||
await self.start_ttfb_metrics()
|
await self.start_ttfb_metrics()
|
||||||
|
|
||||||
@@ -530,8 +551,10 @@ class OpenAITTSService(TTSService):
|
|||||||
|
|
||||||
await self.start_tts_usage_metrics(text)
|
await self.start_tts_usage_metrics(text)
|
||||||
|
|
||||||
|
CHUNK_SIZE = 1024
|
||||||
|
|
||||||
yield TTSStartedFrame()
|
yield TTSStartedFrame()
|
||||||
async for chunk in r.iter_bytes(8192):
|
async for chunk in r.iter_bytes(CHUNK_SIZE):
|
||||||
if len(chunk) > 0:
|
if len(chunk) > 0:
|
||||||
await self.stop_ttfb_metrics()
|
await self.stop_ttfb_metrics()
|
||||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
|
from .azure import AzureRealtimeBetaLLMService
|
||||||
from .events import InputAudioTranscription, SessionProperties, TurnDetection
|
from .events import InputAudioTranscription, SessionProperties, TurnDetection
|
||||||
from .openai import OpenAIRealtimeBetaLLMService
|
from .openai import OpenAIRealtimeBetaLLMService
|
||||||
|
|||||||
64
src/pipecat/services/openai_realtime_beta/azure.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .openai import OpenAIRealtimeBetaLLMService
|
||||||
|
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
logger.error(f"Exception: {e}")
|
||||||
|
logger.error(
|
||||||
|
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||||
|
)
|
||||||
|
raise Exception(f"Missing module: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class AzureRealtimeBetaLLMService(OpenAIRealtimeBetaLLMService):
|
||||||
|
"""Subclass of OpenAI Realtime API Service with adjustments for Azure's wss connection."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Constructor takes the same arguments as the parent class, OpenAIRealtimeBetaLLMService.
|
||||||
|
|
||||||
|
Note that the following are required arguments:
|
||||||
|
api_key: The API key for the Azure OpenAI service.
|
||||||
|
base_url: The base URL for the Azure OpenAI service.
|
||||||
|
|
||||||
|
base_url should be set to the full Azure endpoint URL including the api-version and the deployment name. For example,
|
||||||
|
|
||||||
|
wss://my-project.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=my-realtime-deployment
|
||||||
|
"""
|
||||||
|
super().__init__(base_url=base_url, api_key=api_key, **kwargs)
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url
|
||||||
|
|
||||||
|
async def _connect(self):
|
||||||
|
try:
|
||||||
|
if self._websocket:
|
||||||
|
# Here we assume that if we have a websocket, we are connected. We
|
||||||
|
# handle disconnections in the send/recv code paths.
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Connecting to {self.base_url}, api key: {self.api_key}")
|
||||||
|
self._websocket = await websockets.connect(
|
||||||
|
uri=self.base_url,
|
||||||
|
extra_headers={
|
||||||
|
"api-key": self.api_key,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._receive_task = self.create_task(self._receive_task_handler())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self} initialization error: {e}")
|
||||||
|
self._websocket = None
|
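The `_connect` method above writes the API key into its `logger.info` line. If a credential has to appear in a log message at all, one common approach is to mask most of it first. The helper below is a sketch of that idea only: `mask_secret` is a hypothetical name and is not part of the diff or of pipecat.

```python
def mask_secret(secret: str, visible: int = 4) -> str:
    """Replace all but the last `visible` characters with asterisks."""
    if len(secret) <= visible:
        # Too short to reveal any part safely; mask everything.
        return "*" * len(secret)
    return "*" * (len(secret) - visible) + secret[-visible:]
```

A log call could then format `mask_secret(api_key)` instead of the raw key, which keeps enough of the value to tell two keys apart without exposing either.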
||||||
@@ -4,14 +4,16 @@
|
|||||||
# SPDX-License-Identifier: BSD 2-Clause License
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
#
|
#
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Mapping
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import websockets
|
import websockets
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
@@ -76,6 +78,9 @@ class OpenAIUnhandledFunctionException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIRealtimeBetaLLMService(LLMService):
|
class OpenAIRealtimeBetaLLMService(LLMService):
|
||||||
|
# Overriding the default adapter to use the OpenAIRealtimeLLMAdapter one.
|
||||||
|
adapter_class = OpenAIRealtimeLLMAdapter
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -571,11 +576,37 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
|||||||
await self.send_client_event(events.InputAudioBufferAppendEvent(audio=payload))
|
await self.send_client_event(events.InputAudioBufferAppendEvent(audio=payload))
|
||||||
|
|
||||||
def create_context_aggregator(
|
def create_context_aggregator(
|
||||||
self, context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = False
|
self,
|
||||||
|
context: OpenAILLMContext,
|
||||||
|
*,
|
||||||
|
user_kwargs: Mapping[str, Any] = {},
|
||||||
|
assistant_kwargs: Mapping[str, Any] = {},
|
||||||
) -> OpenAIContextAggregatorPair:
|
) -> OpenAIContextAggregatorPair:
|
||||||
|
"""Create an instance of OpenAIContextAggregatorPair from an
|
||||||
|
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||||
|
assistant aggregators can be provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context (OpenAILLMContext): The LLM context.
|
||||||
|
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the user context aggregator constructor. Defaults
|
||||||
|
to an empty mapping.
|
||||||
|
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||||
|
arguments for the assistant context aggregator
|
||||||
|
constructor. Defaults to an empty mapping.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||||
|
the user and one for the assistant, encapsulated in an
|
||||||
|
OpenAIContextAggregatorPair.
|
||||||
|
|
||||||
|
"""
|
||||||
|
context.set_llm_adapter(self.get_llm_adapter())
|
||||||
|
|
||||||
OpenAIRealtimeLLMContext.upgrade_to_realtime(context)
|
OpenAIRealtimeLLMContext.upgrade_to_realtime(context)
|
||||||
user = OpenAIRealtimeUserContextAggregator(context)
|
user = OpenAIRealtimeUserContextAggregator(context, **user_kwargs)
|
||||||
assistant = OpenAIRealtimeAssistantContextAggregator(
|
|
||||||
context, expect_stripped_words=assistant_expect_stripped_words
|
default_assistant_kwargs = {"expect_stripped_words": False}
|
||||||
)
|
default_assistant_kwargs.update(assistant_kwargs)
|
||||||
|
assistant = OpenAIRealtimeAssistantContextAggregator(context, **default_assistant_kwargs)
|
||||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||||
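Note on the hunk above: the new `create_context_aggregator` signature replaces the single `assistant_expect_stripped_words` flag with two keyword-argument mappings, and the assistant side starts from a default that the caller can override. A minimal sketch of that merge step follows. The helper name `build_assistant_kwargs` is hypothetical and does not exist in the diff; only the `expect_stripped_words` default is taken from it.

```python
from typing import Any, Dict, Mapping, Optional


def build_assistant_kwargs(
    assistant_kwargs: Optional[Mapping[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the default-then-update merge in the diff."""
    # Start from the default used in the diff for the assistant aggregator.
    merged: Dict[str, Any] = {"expect_stripped_words": False}
    # Caller-supplied values win over the defaults; extra keys pass through.
    merged.update(assistant_kwargs or {})
    return merged
```

With this shape, a caller who previously passed `assistant_expect_stripped_words=True` would presumably pass `assistant_kwargs={"expect_stripped_words": True}` instead.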
|
|||||||
@@ -269,7 +269,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
|||||||
logger.error(f"Invalid JSON message: {message}")
|
logger.error(f"Invalid JSON message: {message}")
|
||||||
|
|
||||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Reconnect if the websocket is closed
|
# Reconnect if the websocket is closed
|
||||||
@@ -323,7 +323,8 @@ class PlayHTHttpTTSService(TTSService):
|
|||||||
api_key: str,
|
api_key: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
voice_url: str,
|
voice_url: str,
|
||||||
voice_engine: str = "Play3.0-mini-http", # Options: Play3.0-mini-http, Play3.0-mini-ws
|
voice_engine: str = "Play3.0-mini",
|
||||||
|
protocol: str = "http", # Options: http, ws
|
||||||
sample_rate: Optional[int] = None,
|
sample_rate: Optional[int] = None,
|
||||||
params: InputParams = InputParams(),
|
params: InputParams = InputParams(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -337,12 +338,24 @@ class PlayHTHttpTTSService(TTSService):
|
|||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
api_key=self._api_key,
|
api_key=self._api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if voice_engine contains protocol information (backward compatibility)
|
||||||
|
if "-http" in voice_engine:
|
||||||
|
# Extract the base engine name
|
||||||
|
voice_engine = voice_engine.replace("-http", "")
|
||||||
|
protocol = "http"
|
||||||
|
elif "-ws" in voice_engine:
|
||||||
|
# Extract the base engine name
|
||||||
|
voice_engine = voice_engine.replace("-ws", "")
|
||||||
|
protocol = "ws"
|
||||||
|
|
||||||
self._settings = {
|
self._settings = {
|
||||||
"language": self.language_to_service_language(params.language)
|
"language": self.language_to_service_language(params.language)
|
||||||
if params.language
|
if params.language
|
||||||
else "english",
|
else "english",
|
||||||
"format": Format.FORMAT_WAV,
|
"format": Format.FORMAT_WAV,
|
||||||
"voice_engine": voice_engine,
|
"voice_engine": voice_engine,
|
||||||
|
"protocol": protocol,
|
||||||
"speed": params.speed,
|
"speed": params.speed,
|
||||||
"seed": params.seed,
|
"seed": params.seed,
|
||||||
}
|
}
|
||||||
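Note on the hunk above: the backward-compatibility branch strips a legacy `-http` or `-ws` suffix from `voice_engine` and derives `protocol` from it. The same logic can be sketched as a standalone function for illustration. The function name `split_engine_and_protocol` is an assumption and is not part of the diff.

```python
from typing import Tuple


def split_engine_and_protocol(voice_engine: str, protocol: str = "http") -> Tuple[str, str]:
    """Hypothetical helper: split a legacy engine name into (engine, protocol)."""
    if "-http" in voice_engine:
        # Legacy name such as "Play3.0-mini-http": the suffix decides the protocol.
        return voice_engine.replace("-http", ""), "http"
    if "-ws" in voice_engine:
        # Legacy name such as "Play3.0-mini-ws".
        return voice_engine.replace("-ws", ""), "ws"
    # New-style name: keep the explicitly passed protocol.
    return voice_engine, protocol
```

As in the diff, a suffix embedded in `voice_engine` takes precedence over an explicitly passed `protocol` argument.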
@@ -379,23 +392,26 @@ class PlayHTHttpTTSService(TTSService):
|
|||||||
return language_to_playht_language(language)
|
return language_to_playht_language(language)
|
||||||
|
|
||||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
options = self._create_options()
|
options = self._create_options()
|
||||||
b = bytearray()
|
|
||||||
in_header = True
|
|
||||||
|
|
||||||
await self.start_ttfb_metrics()
|
await self.start_ttfb_metrics()
|
||||||
|
|
||||||
playht_gen = self._client.tts(
|
playht_gen = self._client.tts(
|
||||||
text, voice_engine=self._settings["voice_engine"], options=options
|
text,
|
||||||
|
voice_engine=self._settings["voice_engine"],
|
||||||
|
protocol=self._settings["protocol"],
|
||||||
|
options=options,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.start_tts_usage_metrics(text)
|
await self.start_tts_usage_metrics(text)
|
||||||
|
|
||||||
yield TTSStartedFrame()
|
yield TTSStartedFrame()
|
||||||
|
|
||||||
|
b = bytearray()
|
||||||
|
in_header = True
|
||||||
async for chunk in playht_gen:
|
async for chunk in playht_gen:
|
||||||
# skip the RIFF header.
|
# skip the RIFF header.
|
||||||
if in_header:
|
if in_header:
|
||||||
@@ -410,11 +426,10 @@ class PlayHTHttpTTSService(TTSService):
|
|||||||
fh.read(size)
|
fh.read(size)
|
||||||
(data, size) = struct.unpack("<4sI", fh.read(8))
|
(data, size) = struct.unpack("<4sI", fh.read(8))
|
||||||
in_header = False
|
in_header = False
|
||||||
else:
|
elif len(chunk) > 0:
|
||||||
if len(chunk):
|
await self.stop_ttfb_metrics()
|
||||||
await self.stop_ttfb_metrics()
|
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
yield frame
|
||||||
yield frame
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self} error generating TTS: {e}")
|
logger.error(f"{self} error generating TTS: {e}")
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
|||||||
Yields:
|
Yields:
|
||||||
Frames containing audio data and timing information.
|
Frames containing audio data and timing information.
|
||||||
"""
|
"""
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
try:
|
try:
|
||||||
if not self._websocket:
|
if not self._websocket:
|
||||||
await self._connect()
|
await self._connect()
|
||||||
@@ -376,7 +376,7 @@ class RimeHttpTTSService(TTSService):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "audio/pcm",
|
"Accept": "audio/pcm",
|
||||||
@@ -407,10 +407,10 @@ class RimeHttpTTSService(TTSService):
|
|||||||
yield TTSStartedFrame()
|
yield TTSStartedFrame()
|
||||||
|
|
||||||
# Process the streaming response
|
# Process the streaming response
|
||||||
chunk_size = 8192
|
CHUNK_SIZE = 1024
|
||||||
|
|
||||||
async for chunk in response.content.iter_chunked(chunk_size):
|
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
|
||||||
if chunk:
|
if len(chunk) > 0:
|
||||||
await self.stop_ttfb_metrics()
|
await self.stop_ttfb_metrics()
|
||||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||||
yield frame
|
yield frame
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ class FastPitchTTSService(TTSService):
|
|||||||
await self.start_ttfb_metrics()
|
await self.start_ttfb_metrics()
|
||||||
yield TTSStartedFrame()
|
yield TTSStartedFrame()
|
||||||
|
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
queue = asyncio.Queue()
|
queue = asyncio.Queue()
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ class XTTSService(TTSService):
|
|||||||
self._studio_speakers = await r.json()
|
self._studio_speakers = await r.json()
|
||||||
|
|
||||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||||
logger.debug(f"Generating TTS: [{text}]")
|
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||||
|
|
||||||
if not self._studio_speakers:
|
if not self._studio_speakers:
|
||||||
logger.error(f"{self} no studio speakers available")
|
logger.error(f"{self} no studio speakers available")
|
||||||
@@ -150,8 +150,10 @@ class XTTSService(TTSService):
|
|||||||
|
|
||||||
yield TTSStartedFrame()
|
yield TTSStartedFrame()
|
||||||
|
|
||||||
|
CHUNK_SIZE = 1024
|
||||||
|
|
||||||
buffer = bytearray()
|
buffer = bytearray()
|
||||||
async for chunk in r.content.iter_chunked(1024):
|
async for chunk in r.content.iter_chunked(CHUNK_SIZE):
|
||||||
if len(chunk) > 0:
|
if len(chunk) > 0:
|
||||||
await self.stop_ttfb_metrics()
|
await self.stop_ttfb_metrics()
|
||||||
# Append new chunk to the buffer.
|
# Append new chunk to the buffer.
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer, EndOfTurnState
|
||||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
||||||
from pipecat.frames.frames import (
|
from pipecat.frames.frames import (
|
||||||
BotInterruptionFrame,
|
BotInterruptionFrame,
|
||||||
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
|
|||||||
StartInterruptionFrame,
|
StartInterruptionFrame,
|
||||||
StopInterruptionFrame,
|
StopInterruptionFrame,
|
||||||
SystemFrame,
|
SystemFrame,
|
||||||
|
UserEndOfTurnFrame,
|
||||||
UserStartedSpeakingFrame,
|
UserStartedSpeakingFrame,
|
||||||
UserStoppedSpeakingFrame,
|
UserStoppedSpeakingFrame,
|
||||||
VADParamsUpdateFrame,
|
VADParamsUpdateFrame,
|
||||||
@@ -64,12 +66,19 @@ class BaseInputTransport(FrameProcessor):
|
|||||||
def vad_analyzer(self) -> Optional[VADAnalyzer]:
|
def vad_analyzer(self) -> Optional[VADAnalyzer]:
|
||||||
return self._params.vad_analyzer
|
return self._params.vad_analyzer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_of_turn_analyzer(self) -> Optional[BaseEndOfTurnAnalyzer]:
|
||||||
|
return self._params.end_of_turn_analyzer
|
||||||
|
|
||||||
async def start(self, frame: StartFrame):
|
async def start(self, frame: StartFrame):
|
||||||
self._sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate
|
self._sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate
|
||||||
|
|
||||||
# Configure VAD analyzer.
|
# Configure VAD analyzer.
|
||||||
if self._params.vad_enabled and self._params.vad_analyzer:
|
if self._params.vad_enabled and self._params.vad_analyzer:
|
||||||
self._params.vad_analyzer.set_sample_rate(self._sample_rate)
|
self._params.vad_analyzer.set_sample_rate(self._sample_rate)
|
||||||
|
# Configure End of turn analyzer.
|
||||||
|
if self._params.end_of_turn_analyzer:
|
||||||
|
self._params.end_of_turn_analyzer.set_sample_rate(self._sample_rate)
|
||||||
# Start audio filter.
|
# Start audio filter.
|
||||||
if self._params.audio_in_filter:
|
if self._params.audio_in_filter:
|
||||||
await self._params.audio_in_filter.start(self._sample_rate)
|
await self._params.audio_in_filter.start(self._sample_rate)
|
||||||
@@ -198,8 +207,25 @@ class BaseInputTransport(FrameProcessor):
|
|||||||
vad_state = new_vad_state
|
vad_state = new_vad_state
|
||||||
return vad_state
|
return vad_state
|
||||||
|
|
||||||
|
async def _end_of_turn_analyze(self, audio_frame: InputAudioRawFrame) -> EndOfTurnState:
|
||||||
|
state = EndOfTurnState.INCOMPLETE
|
||||||
|
if self.end_of_turn_analyzer:
|
||||||
|
state = await self.get_event_loop().run_in_executor(
|
||||||
|
self._executor, self.end_of_turn_analyzer.analyze_audio, audio_frame.audio
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
async def _handle_end_of_turn(
|
||||||
|
self, audio_frame: InputAudioRawFrame, end_of_turn_state: EndOfTurnState
|
||||||
|
):
|
||||||
|
new_eot_state = await self._end_of_turn_analyze(audio_frame)
|
||||||
|
if new_eot_state != end_of_turn_state:
|
||||||
|
await self.push_frame(UserEndOfTurnFrame())
|
||||||
|
return new_eot_state
|
||||||
|
|
||||||
async def _audio_task_handler(self):
|
async def _audio_task_handler(self):
|
||||||
vad_state: VADState = VADState.QUIET
|
vad_state: VADState = VADState.QUIET
|
||||||
|
end_of_turn_state: EndOfTurnState = EndOfTurnState.INCOMPLETE
|
||||||
while True:
|
while True:
|
||||||
frame: InputAudioRawFrame = await self._audio_in_queue.get()
|
frame: InputAudioRawFrame = await self._audio_in_queue.get()
|
||||||
|
|
||||||
@@ -215,6 +241,9 @@ class BaseInputTransport(FrameProcessor):
|
|||||||
vad_state = await self._handle_vad(frame, vad_state)
|
vad_state = await self._handle_vad(frame, vad_state)
|
||||||
audio_passthrough = self._params.vad_audio_passthrough
|
audio_passthrough = self._params.vad_audio_passthrough
|
||||||
|
|
||||||
|
if self._params.end_of_turn_analyzer:
|
||||||
|
end_of_turn_state = await self._handle_end_of_turn(frame, end_of_turn_state)
|
||||||
|
|
||||||
# Push audio downstream if passthrough.
|
# Push audio downstream if passthrough.
|
||||||
if audio_passthrough:
|
if audio_passthrough:
|
||||||
await self.push_frame(frame)
|
await self.push_frame(frame)
|
||||||
|
|||||||
@@ -232,6 +232,9 @@ class BaseOutputTransport(FrameProcessor):
|
|||||||
await self.push_frame(BotStoppedSpeakingFrame())
|
await self.push_frame(BotStoppedSpeakingFrame())
|
||||||
await self.push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
|
await self.push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||||
self._bot_speaking = False
|
self._bot_speaking = False
|
||||||
|
# Clear the audio buffer (small leftovers can remain if the audio is not
|
||||||
|
# a multiple of our output chunk size).
|
||||||
|
self._audio_buffer = bytearray()
|
||||||
|
|
||||||
#
|
#
|
||||||
# Sink tasks
|
# Sink tasks
|
||||||
|
|||||||
@@ -4,18 +4,17 @@
|
|||||||
# SPDX-License-Identifier: BSD 2-Clause License
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
#
|
#
|
||||||
|
|
||||||
import inspect
|
from abc import abstractmethod
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
|
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
|
||||||
from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer
|
from pipecat.audio.mixers.base_audio_mixer import BaseAudioMixer
|
||||||
|
from pipecat.audio.turn.base_turn_analyzer import BaseEndOfTurnAnalyzer
|
||||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
|
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
|
||||||
from pipecat.processors.frame_processor import FrameProcessor
|
from pipecat.processors.frame_processor import FrameProcessor
|
||||||
from pipecat.utils.utils import obj_count, obj_id
|
from pipecat.utils.base_object import BaseObject
|
||||||
|
|
||||||
|
|
||||||
class TransportParams(BaseModel):
|
class TransportParams(BaseModel):
|
||||||
@@ -41,9 +40,10 @@ class TransportParams(BaseModel):
|
|||||||
vad_enabled: bool = False
|
vad_enabled: bool = False
|
||||||
vad_audio_passthrough: bool = False
|
vad_audio_passthrough: bool = False
|
||||||
vad_analyzer: Optional[VADAnalyzer] = None
|
vad_analyzer: Optional[VADAnalyzer] = None
|
||||||
|
end_of_turn_analyzer: Optional[BaseEndOfTurnAnalyzer] = None
|
||||||
|
|
||||||
|
|
||||||
class BaseTransport(ABC):
|
class BaseTransport(BaseObject):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -51,54 +51,14 @@ class BaseTransport(ABC):
|
|||||||
input_name: Optional[str] = None,
|
input_name: Optional[str] = None,
|
||||||
output_name: Optional[str] = None,
|
output_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self._id: int = obj_id()
|
super().__init__(name=name)
|
||||||
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
|
||||||
self._input_name = input_name
|
self._input_name = input_name
|
||||||
self._output_name = output_name
|
self._output_name = output_name
|
||||||
self._event_handlers: dict = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self) -> int:
|
|
||||||
return self._id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def input(self) -> FrameProcessor:
|
def input(self) -> FrameProcessor:
|
||||||
raise NotImplementedError
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def output(self) -> FrameProcessor:
|
def output(self) -> FrameProcessor:
|
||||||
raise NotImplementedError
|
pass
|
||||||
|
|
||||||
def event_handler(self, event_name: str):
|
|
||||||
def decorator(handler):
|
|
||||||
self.add_event_handler(event_name, handler)
|
|
||||||
return handler
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def add_event_handler(self, event_name: str, handler):
|
|
||||||
if event_name not in self._event_handlers:
|
|
||||||
raise Exception(f"Event handler {event_name} not registered")
|
|
||||||
self._event_handlers[event_name].append(handler)
|
|
||||||
|
|
||||||
def _register_event_handler(self, event_name: str):
|
|
||||||
if event_name in self._event_handlers:
|
|
||||||
raise Exception(f"Event handler {event_name} already registered")
|
|
||||||
self._event_handlers[event_name] = []
|
|
||||||
|
|
||||||
async def _call_event_handler(self, event_name: str, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
for handler in self._event_handlers[event_name]:
|
|
||||||
if inspect.iscoroutinefunction(handler):
|
|
||||||
await handler(self, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
handler(self, *args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Exception in event handler {event_name}: {e}")
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.name
|
|
||||||
|
|||||||
@@ -195,6 +195,10 @@ class DailyMeetingTokenProperties(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Start cloud recording when the user joins the room. This can be used to always record and archive meetings, for example in a customer support context.",
|
description="Start cloud recording when the user joins the room. This can be used to always record and archive meetings, for example in a customer support context.",
|
||||||
)
|
)
|
||||||
|
permissions: Optional[dict] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Specifies the initial default permissions for a non-meeting-owner participant joining a call.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DailyMeetingTokenParams(BaseModel):
|
class DailyMeetingTokenParams(BaseModel):
|
||||||
|
|||||||
58
src/pipecat/utils/base_object.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from abc import ABC
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from pipecat.utils.utils import obj_count, obj_id
|
||||||
|
|
||||||
|
|
||||||
|
class BaseObject(ABC):
|
||||||
|
def __init__(self, *, name: Optional[str] = None):
|
||||||
|
self._id: int = obj_id()
|
||||||
|
self._name = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
||||||
|
self._event_handlers: dict = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> int:
|
||||||
|
return self._id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def event_handler(self, event_name: str):
|
||||||
|
def decorator(handler):
|
||||||
|
self.add_event_handler(event_name, handler)
|
||||||
|
return handler
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def add_event_handler(self, event_name: str, handler):
|
||||||
|
if event_name not in self._event_handlers:
|
||||||
|
raise Exception(f"Event handler {event_name} not registered")
|
||||||
|
self._event_handlers[event_name].append(handler)
|
||||||
|
|
||||||
|
def _register_event_handler(self, event_name: str):
|
||||||
|
if event_name in self._event_handlers:
|
||||||
|
raise Exception(f"Event handler {event_name} already registered")
|
||||||
|
self._event_handlers[event_name] = []
|
||||||
|
|
||||||
|
async def _call_event_handler(self, event_name: str, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
for handler in self._event_handlers[event_name]:
|
||||||
|
if inspect.iscoroutinefunction(handler):
|
||||||
|
await handler(self, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
handler(self, *args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Exception in event handler {event_name}: {e}")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
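Note on the new file above: `BaseObject` centralizes the register, add, and call steps for event handlers that `BaseTransport` previously implemented itself. The sketch below reproduces that pattern in a self-contained class to show how sync and async handlers are dispatched. The class name `EventEmitter`, the event name `on_connected`, and the `demo` function are illustrative assumptions, and the sketch omits the try/except logging that the diff wraps around the dispatch loop.

```python
import asyncio
import inspect


class EventEmitter:
    """Hypothetical stand-in for the event-handler part of BaseObject."""

    def __init__(self):
        self._event_handlers: dict = {}

    def _register_event_handler(self, event_name: str):
        # An event must be registered once before handlers can be attached.
        if event_name in self._event_handlers:
            raise Exception(f"Event handler {event_name} already registered")
        self._event_handlers[event_name] = []

    def add_event_handler(self, event_name: str, handler):
        if event_name not in self._event_handlers:
            raise Exception(f"Event handler {event_name} not registered")
        self._event_handlers[event_name].append(handler)

    def event_handler(self, event_name: str):
        # Decorator form of add_event_handler.
        def decorator(handler):
            self.add_event_handler(event_name, handler)
            return handler

        return decorator

    async def _call_event_handler(self, event_name: str, *args, **kwargs):
        # Handlers run in registration order; coroutine functions are awaited.
        for handler in self._event_handlers[event_name]:
            if inspect.iscoroutinefunction(handler):
                await handler(self, *args, **kwargs)
            else:
                handler(self, *args, **kwargs)


def demo() -> list:
    calls = []
    emitter = EventEmitter()
    emitter._register_event_handler("on_connected")

    @emitter.event_handler("on_connected")
    async def on_connected(obj, client):
        calls.append(("async", client))

    emitter.add_event_handler(
        "on_connected", lambda obj, client: calls.append(("sync", client))
    )
    asyncio.run(emitter._call_event_handler("on_connected", "client-1"))
    return calls
```

Each handler receives the emitting object as its first argument, matching the `handler(self, *args, **kwargs)` call in the diff.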
@@ -8,6 +8,8 @@ class TestException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class TestFrameProcessor(FrameProcessor):
|
class TestFrameProcessor(FrameProcessor):
|
||||||
|
__test__ = False # Prevents pytest from collecting this class as a test
|
||||||
|
|
||||||
def __init__(self, test_frames):
|
def __init__(self, test_frames):
|
||||||
self.test_frames = test_frames
|
self.test_frames = test_frames
|
||||||
self._list_counter = 0
|
self._list_counter = 0
|
||||||
|
|||||||
@@ -0,0 +1,96 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||||
|
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||||
|
from pipecat.frames.frames import (
|
||||||
|
LLMFullResponseEndFrame,
|
||||||
|
LLMFullResponseStartFrame,
|
||||||
|
LLMTextFrame,
|
||||||
|
)
|
||||||
|
from pipecat.processors.frame_processor import FrameDirection
|
||||||
|
from pipecat.services.ai_services import LLMService
|
||||||
|
from pipecat.services.anthropic import AnthropicLLMService
|
||||||
|
from pipecat.services.google import GoogleLLMService
|
||||||
|
from pipecat.services.openai import OpenAILLMContext, OpenAILLMContextFrame, OpenAILLMService
|
||||||
|
from pipecat.utils.test_frame_processor import TestFrameProcessor
|
||||||
|
|
||||||
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
|
||||||
|
def standard_tools() -> ToolsSchema:
|
||||||
|
weather_function = FunctionSchema(
|
||||||
|
name="get_current_weather",
|
||||||
|
description="Get the current weather",
|
||||||
|
properties={
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required=["location"],
|
||||||
|
)
|
||||||
|
tools_def = ToolsSchema(standard_tools=[weather_function])
|
||||||
|
return tools_def
|
||||||
|
|
||||||
|
|
||||||
|
async def _test_llm_function_calling(llm: LLMService):
|
||||||
|
# Create an AsyncMock for the function
|
||||||
|
mock_fetch_weather = AsyncMock()
|
||||||
|
|
||||||
|
llm.register_function(None, mock_fetch_weather)
|
||||||
|
t = TestFrameProcessor([LLMFullResponseStartFrame, LLMTextFrame, LLMFullResponseEndFrame])
|
||||||
|
llm.link(t)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant who can report the weather in any location in the universe. Respond concisely. Your response will be turned into speech so use only simple words and punctuation.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": " How is the weather today in San Francisco, California?"},
|
||||||
|
]
|
||||||
|
context = OpenAILLMContext(messages, standard_tools())
|
||||||
|
# This is done by default inside the create_context_aggregator
|
||||||
|
context.set_llm_adapter(llm.get_llm_adapter())
|
||||||
|
|
||||||
|
frame = OpenAILLMContextFrame(context)
|
||||||
|
|
||||||
|
# This will fail if an exception is raised
|
||||||
|
await llm.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||||
|
|
||||||
|
# Assert that the mock function was called
|
||||||
|
mock_fetch_weather.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unified_function_calling_openai():
|
||||||
|
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||||
|
# This will fail if an exception is raised
|
||||||
|
await _test_llm_function_calling(llm)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.getenv("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY is not set")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unified_function_calling_gemini():
|
||||||
|
llm = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"), model="gemini-2.0-flash-001")
|
||||||
|
# This will fail if an exception is raised
|
||||||
|
await _test_llm_function_calling(llm)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.getenv("ANTHROPIC_API_KEY") is None, reason="ANTHROPIC_API_KEY is not set")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unified_function_calling_anthropic():
|
||||||
|
llm = AnthropicLLMService(
|
||||||
|
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-5-sonnet-20240620"
|
||||||
|
)
|
||||||
|
# This will fail if an exception is raised
|
||||||
|
await _test_llm_function_calling(llm)
|
||||||
176
tests/test_function_calling_adapters.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
|
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||||
|
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||||
|
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||||
|
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||||
|
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||||
|
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class TestFunctionAdapters(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
"""Sets up a common tools schema for all tests."""
|
||||||
|
function_def = FunctionSchema(
|
||||||
|
name="get_weather",
|
||||||
|
description="Get the weather in a given location",
|
||||||
|
properties={
|
||||||
|
"location": {"type": "string", "description": "The city, e.g. San Francisco"},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required=["location", "format"],
|
||||||
|
)
|
||||||
|
self.tools_def = ToolsSchema(standard_tools=[function_def])
|
||||||
|
|
||||||
|
def test_openai_adapter(self):
|
||||||
|
"""Test OpenAI adapter format transformation."""
|
||||||
|
expected = [
|
||||||
|
ChatCompletionToolParam(
|
||||||
|
type="function",
|
||||||
|
function={
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city, e.g. San Francisco",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "format"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
assert OpenAILLMAdapter().to_provider_tools_format(self.tools_def) == expected
|
||||||
|
|
||||||
|
def test_anthropic_adapter(self):
|
||||||
|
"""Test Anthropic adapter format transformation."""
|
||||||
|
expected = [
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the weather in a given location",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city, e.g. San Francisco",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "format"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert AnthropicLLMAdapter().to_provider_tools_format(self.tools_def) == expected
|
||||||
|
|
||||||
|
def test_gemini_adapter(self):
|
||||||
|
"""Test Gemini adapter format transformation."""
|
||||||
|
expected = [
|
||||||
|
{
|
||||||
|
"function_declarations": [
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city, e.g. San Francisco",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "format"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert GeminiLLMAdapter().to_provider_tools_format(self.tools_def) == expected
|
||||||
|
|
||||||
|
def test_openai_realtime_adapter(self):
|
||||||
|
"""Test OpenAI Realtime adapter format transformation."""
|
||||||
|
expected = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city, e.g. San Francisco",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "format"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert OpenAIRealtimeLLMAdapter().to_provider_tools_format(self.tools_def) == expected
|
||||||
|
|
||||||
|
    def test_gemini_adapter_with_custom_tools(self):
        """Test Gemini adapter format transformation with custom tools."""
        search_tool = {"google_search": {}}
        expected = [
            {
                "function_declarations": [
                    {
                        "name": "get_weather",
                        "description": "Get the weather in a given location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city, e.g. San Francisco",
                                },
                                "format": {
                                    "type": "string",
                                    "enum": ["celsius", "fahrenheit"],
                                    "description": "The temperature unit to use.",
                                },
                            },
                            "required": ["location", "format"],
                        },
                    }
                ]
            },
            search_tool,
        ]
        tools_def = self.tools_def
        tools_def.custom_tools = {AdapterType.GEMINI: [search_tool]}
        assert GeminiLLMAdapter().to_provider_tools_format(tools_def) == expected
136
tests/test_llm_response.py
Normal file
@@ -0,0 +1,136 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import unittest

from pipecat.frames.frames import (
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMTextFrame,
    StartInterruptionFrame,
)
from pipecat.processors.aggregators.llm_response import LLMFullResponseAggregator
from pipecat.tests.utils import SleepFrame, run_test


class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
    async def test_empty(self):
        completion_ok = False

        aggregator = LLMFullResponseAggregator()

        @aggregator.event_handler("on_completion")
        async def on_completion(aggregator, completion, completed):
            nonlocal completion_ok
            completion_ok = completion == "" and completed

        frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
        expected_down_frames = [LLMFullResponseStartFrame, LLMFullResponseEndFrame]
        await run_test(
            aggregator,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        assert completion_ok

    async def test_simple(self):
        completion_ok = False

        aggregator = LLMFullResponseAggregator()

        @aggregator.event_handler("on_completion")
        async def on_completion(aggregator, completion, completed):
            nonlocal completion_ok
            completion_ok = completion == "Hello from Pipecat!" and completed

        frames_to_send = [
            LLMFullResponseStartFrame(),
            LLMTextFrame("Hello from Pipecat!"),
            LLMFullResponseEndFrame(),
        ]
        expected_down_frames = [LLMFullResponseStartFrame, LLMTextFrame, LLMFullResponseEndFrame]
        await run_test(
            aggregator,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        assert completion_ok

    async def test_multiple(self):
        completion_ok = False

        aggregator = LLMFullResponseAggregator()

        @aggregator.event_handler("on_completion")
        async def on_completion(aggregator, completion, completed):
            nonlocal completion_ok
            completion_ok = completion == "Hello from Pipecat!" and completed

        frames_to_send = [
            LLMFullResponseStartFrame(),
            LLMTextFrame("Hello "),
            LLMTextFrame("from "),
            LLMTextFrame("Pipecat!"),
            LLMFullResponseEndFrame(),
        ]
        expected_down_frames = [
            LLMFullResponseStartFrame,
            LLMTextFrame,
            LLMTextFrame,
            LLMTextFrame,
            LLMFullResponseEndFrame,
        ]
        await run_test(
            aggregator,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        assert completion_ok

    async def test_interruption(self):
        completion_ok = True

        completion_result = [("Hello ", False), ("Hello there!", True)]
        completion_index = 0

        aggregator = LLMFullResponseAggregator()

        @aggregator.event_handler("on_completion")
        async def on_completion(aggregator, completion, completed):
            nonlocal completion_result, completion_index, completion_ok
            (completion_expected, completion_completed) = completion_result[completion_index]
            completion_ok = (
                completion_ok
                and completion == completion_expected
                and completed == completion_completed
            )
            completion_index += 1

        frames_to_send = [
            LLMFullResponseStartFrame(),
            LLMTextFrame("Hello "),
            SleepFrame(),
            StartInterruptionFrame(),
            LLMFullResponseStartFrame(),
            LLMTextFrame("Hello "),
            LLMTextFrame("there!"),
            LLMFullResponseEndFrame(),
        ]
        expected_down_frames = [
            LLMFullResponseStartFrame,
            LLMTextFrame,
            StartInterruptionFrame,
            LLMFullResponseStartFrame,
            LLMTextFrame,
            LLMTextFrame,
            LLMFullResponseEndFrame,
        ]
        await run_test(
            aggregator,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        assert completion_ok