Remove deprecated OpenAILLMContext as well as everything (code paths or whole types) dependent on it (all of which were also deprecated)
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.1
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: ruff
|
||||
language_version: python3
|
||||
args: [--fix]
|
||||
name: ruff
|
||||
entry: uv run ruff check --fix
|
||||
language: system
|
||||
types: [python]
|
||||
- id: ruff-format
|
||||
name: ruff-format
|
||||
entry: uv run ruff format
|
||||
language: system
|
||||
types: [python]
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.google.openai.llm import GoogleLLMOpenAIBetaService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = ElevenLabsTTSService(
|
||||
api_key=os.getenv("ELEVENLABS_API_KEY", ""),
|
||||
settings=ElevenLabsTTSService.Settings(
|
||||
voice=os.getenv("ELEVENLABS_VOICE_ID", ""),
|
||||
),
|
||||
)
|
||||
|
||||
llm = GoogleLLMOpenAIBetaService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
settings=GoogleLLMOpenAIBetaService.Settings(
|
||||
system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.",
|
||||
),
|
||||
)
|
||||
# You can aslo register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
@llm.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls):
|
||||
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
messages = [
|
||||
{
|
||||
"role": "developer",
|
||||
"content": "Start a conversation with 'Hey there' to get the current weather.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
user_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -1,219 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, TranscriptionMessage
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeBetaLLMService,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
session_properties = SessionProperties(
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=SemanticTurnDetection(),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
input_audio_noise_reduction=InputAudioNoiseReduction(type="near_field"),
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
You have access to the following tools:
|
||||
- get_current_weather: Get the current weather for a given location.
|
||||
- get_restaurant_recommendation: Get a restaurant recommendation for a given location.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually. Respond in English.""",
|
||||
)
|
||||
|
||||
llm = OpenAIRealtimeBetaLLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
transcript = TranscriptProcessor()
|
||||
|
||||
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||
# OpenAIRealtimeBetaLLMService will convert this internally to messages that the
|
||||
# openai WebSocket API can understand.
|
||||
context = OpenAILLMContext(
|
||||
[{"role": "developer", "content": "Say hello!"}],
|
||||
tools,
|
||||
)
|
||||
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
transcript.user(), # Placed after the LLM, as LLM pushes TranscriptionFrames downstream
|
||||
transport.output(), # Transport bot output
|
||||
transcript.assistant(), # After the transcript output, to time with the audio output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
# Register event handler for transcript updates
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
for msg in frame.messages:
|
||||
if isinstance(msg, TranscriptionMessage):
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
line = f"{timestamp}{msg.role}: {msg.content}"
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -1,214 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
AzureRealtimeBetaLLMService,
|
||||
InputAudioTranscription,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
# Define weather function using standardized schema
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
session_properties = SessionProperties(
|
||||
input_audio_transcription=InputAudioTranscription(model="whisper-1"),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
# turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
-
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
You have access to the following tools:
|
||||
- get_current_weather: Get the current weather for a given location.
|
||||
- get_restaurant_recommendation: Get a restaurant recommendation for a given location.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually. Respond in English.""",
|
||||
)
|
||||
|
||||
llm = AzureRealtimeBetaLLMService(
|
||||
api_key=os.getenv("AZURE_REALTIME_API_KEY"),
|
||||
base_url=os.getenv("AZURE_REALTIME_BASE_URL"),
|
||||
session_properties=session_properties,
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||
# OpenAIRealtimeBetaLLMService will convert this internally to messages that the
|
||||
# openai WebSocket API can understand.
|
||||
context = OpenAILLMContext(
|
||||
[{"role": "developer", "content": "Say hello!"}],
|
||||
# [{"role": "developer", "content": [{"type": "text", "text": "Say hello!"}]}],
|
||||
# [
|
||||
# {
|
||||
# "role": "developer",
|
||||
# "content": [
|
||||
# {"type": "text", "text": "Say"},
|
||||
# {"type": "text", "text": "yo what's up!"},
|
||||
# ],
|
||||
# }
|
||||
# ],
|
||||
tools,
|
||||
)
|
||||
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -1,215 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeBetaLLMService,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
session_properties = SessionProperties(
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
modalities=["text"],
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=SemanticTurnDetection(),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
input_audio_noise_reduction=InputAudioNoiseReduction(type="near_field"),
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
You have access to the following tools:
|
||||
- get_current_weather: Get the current weather for a given location.
|
||||
- get_restaurant_recommendation: Get a restaurant recommendation for a given location.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually. Respond in English.""",
|
||||
)
|
||||
|
||||
llm = OpenAIRealtimeBetaLLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
settings=CartesiaTTSService.Settings(
|
||||
voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
),
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||
# OpenAIRealtimeBetaLLMService will convert this internally to messages that the
|
||||
# openai WebSocket API can understand.
|
||||
context = OpenAILLMContext(
|
||||
[{"role": "developer", "content": "Say hello!"}],
|
||||
tools,
|
||||
)
|
||||
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -1,267 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeBetaLLMService,
|
||||
SessionProperties,
|
||||
TurnDetection,
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
BASE_FILENAME = "/tmp/pipecat_conversation_"
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_saved_conversation_filenames(params: FunctionCallParams):
|
||||
# Construct the full pattern including the BASE_FILENAME
|
||||
full_pattern = f"{BASE_FILENAME}*.json"
|
||||
|
||||
# Use glob to find all matching files
|
||||
matching_files = glob.glob(full_pattern)
|
||||
logger.debug(f"matching files: {matching_files}")
|
||||
|
||||
await params.result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.messages, indent=4)}"
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
await params.result_callback({"success": True})
|
||||
except Exception as e:
|
||||
await params.result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
async def load_conversation(params: FunctionCallParams):
|
||||
async def _reset():
|
||||
filename = params.arguments["filename"]
|
||||
logger.debug(f"loading conversation from {filename}")
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
params.context.set_messages(json.load(file))
|
||||
await params.llm.reset_conversation()
|
||||
await params.llm._create_response()
|
||||
except Exception as e:
|
||||
await params.result_callback({"success": False, "error": str(e)})
|
||||
|
||||
asyncio.create_task(_reset())
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "save_conversation",
|
||||
"description": "Save the current conversation. Use this function to persist the current conversation to external storage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_saved_conversation_filenames",
|
||||
"description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "load_conversation",
|
||||
"description": "Load a conversation history. Use this function to load a conversation history into the current session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The filename of the conversation history to load.",
|
||||
}
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
session_properties = SessionProperties(
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn
|
||||
# it on by default
|
||||
turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
# tools=tools,
|
||||
instructions="""Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
-
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually.""",
|
||||
)
|
||||
|
||||
llm = OpenAIRealtimeBetaLLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("save_conversation", save_conversation)
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext([], tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -1,27 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Color codes for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo "🔍 Running pre-commit checks..."
|
||||
|
||||
# Change to project root (one level up from scripts/)
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
# Format check
|
||||
echo "📝 Checking code formatting..."
|
||||
if ! NO_COLOR=1 uv run ruff format --diff --check; then
|
||||
echo -e "${RED}❌ Code formatting issues found. Run 'uv run ruff format' to fix.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Lint check
|
||||
echo "🔍 Running linter..."
|
||||
if ! uv run ruff check; then
|
||||
echo -e "${RED}❌ Linting issues found.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}✅ All pre-commit checks passed!${NC}"
|
||||
@@ -22,12 +22,9 @@ class AdapterType(Enum):
|
||||
|
||||
Parameters:
|
||||
GEMINI: Google Gemini adapter - currently the only service supporting custom tools.
|
||||
SHIM: Backward compatibility shim for creating ToolsSchemas from lists of tools in
|
||||
any format, used by LLMContext.from_openai_context.
|
||||
"""
|
||||
|
||||
GEMINI = "gemini" # that is the only service where we are able to add custom tools for now
|
||||
SHIM = "shim" # for use as backward compatibility shim for creating ToolsSchemas from list of tools in any format
|
||||
|
||||
|
||||
class ToolsSchema:
|
||||
|
||||
@@ -222,18 +222,4 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
List of dictionaries in AWS Nova Sonic function format.
|
||||
"""
|
||||
functions_schema = tools_schema.standard_tools
|
||||
standard_tools = [
|
||||
self._to_aws_nova_sonic_function_format(func) for func in functions_schema
|
||||
]
|
||||
|
||||
# For backward compatibility, AWS Nova Sonic can still be used with
|
||||
# tools in dict format, even though it always uses `LLMContext` under
|
||||
# the hood (via `LLMContext.from_openai_context()`).
|
||||
# To support this behavior, we use "shimmed" custom tools here.
|
||||
# (We maintain this backward compatibility because users aren't
|
||||
# *knowingly* opting into the new `LLMContext`.)
|
||||
shimmed_tools = []
|
||||
if tools_schema.custom_tools:
|
||||
shimmed_tools = tools_schema.custom_tools.get(AdapterType.SHIM, [])
|
||||
|
||||
return standard_tools + shimmed_tools
|
||||
return [self._to_aws_nova_sonic_function_format(func) for func in functions_schema]
|
||||
|
||||
@@ -256,11 +256,4 @@ class GrokRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
"""
|
||||
# Convert standard function tools
|
||||
functions_schema = tools_schema.standard_tools
|
||||
standard_tools = [self._to_grok_function_format(func) for func in functions_schema]
|
||||
|
||||
# Support shimmed custom tools for backward compatibility
|
||||
shimmed_tools = []
|
||||
if tools_schema.custom_tools:
|
||||
shimmed_tools = tools_schema.custom_tools.get(AdapterType.SHIM, [])
|
||||
|
||||
return standard_tools + shimmed_tools
|
||||
return [self._to_grok_function_format(func) for func in functions_schema]
|
||||
|
||||
@@ -236,18 +236,4 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
List of function definitions in OpenAI Realtime format.
|
||||
"""
|
||||
functions_schema = tools_schema.standard_tools
|
||||
standard_tools = [
|
||||
self._to_openai_realtime_function_format(func) for func in functions_schema
|
||||
]
|
||||
|
||||
# For backward compatibility, OpenAI Realtime can still be used with
|
||||
# tools in dict format, even though it always uses `LLMContext` under
|
||||
# the hood (via `LLMContext.from_openai_context()`).
|
||||
# To support this behavior, we use "shimmed" custom tools here.
|
||||
# (We maintain this backward compatibility because users aren't
|
||||
# *knowingly* opting into the new `LLMContext`.)
|
||||
shimmed_tools = []
|
||||
if tools_schema.custom_tools:
|
||||
shimmed_tools = tools_schema.custom_tools.get(AdapterType.SHIM, [])
|
||||
|
||||
return standard_tools + shimmed_tools
|
||||
return [self._to_openai_realtime_function_format(func) for func in functions_schema]
|
||||
|
||||
@@ -31,7 +31,6 @@ from pipecat.frames.frames import (
|
||||
VADParamsUpdateFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
@@ -444,7 +443,7 @@ Remember: Respond with `<dtmf>NUMBER</dtmf>` (single or multiple for sequences),
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
if isinstance(frame, (OpenAILLMContextFrame, LLMContextFrame)):
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
# Extract messages and pass to IVR processor
|
||||
all_messages = frame.context.get_messages()
|
||||
|
||||
|
||||
@@ -451,36 +451,6 @@ class TranslationFrame(TextFrame):
|
||||
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAILLMContextAssistantTimestampFrame(DataFrame):
|
||||
"""Timestamp information for assistant messages in LLM context.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAILLMContextAssistantTimestampFrame` is deprecated and will be removed in a future version.
|
||||
Use `LLMContextAssistantTimestampFrame` with the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Parameters:
|
||||
timestamp: Timestamp when the assistant message was created.
|
||||
"""
|
||||
|
||||
timestamp: str
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"OpenAILLMContextAssistantTimestampFrame is deprecated and will be removed in a future version. "
|
||||
"Use LLMContextAssistantTimestampFrame with the universal LLMContext and LLMContextAggregatorPair instead. "
|
||||
"See OpenAILLMContext docstring for migration guide.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMContextAssistantTimestampFrame(DataFrame):
|
||||
"""Timestamp information for assistant messages in LLM context.
|
||||
@@ -706,44 +676,6 @@ class LLMThoughtEndFrame(ControlFrame):
|
||||
return f"{self.name}(pts: {pts}, signature: {self.signature})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessagesFrame(DataFrame):
|
||||
"""Frame containing LLM messages for chat completion.
|
||||
|
||||
.. deprecated:: 0.0.79
|
||||
This class is deprecated and will be removed in a future version.
|
||||
Instead, use either:
|
||||
- `LLMMessagesUpdateFrame` with `run_llm=True`
|
||||
- `OpenAILLMContextFrame` with desired messages in a new context
|
||||
|
||||
A frame containing a list of LLM messages. Used to signal that an LLM
|
||||
service should run a chat completion and emit an LLMFullResponseStartFrame,
|
||||
TextFrames and an LLMFullResponseEndFrame. Note that the `messages`
|
||||
property in this class is mutable, and will be updated by various
|
||||
aggregators.
|
||||
|
||||
Parameters:
|
||||
messages: List of message dictionaries in LLM format.
|
||||
"""
|
||||
|
||||
messages: List[dict]
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"LLMMessagesFrame is deprecated and will be removed in a future version. "
|
||||
"Instead, use either "
|
||||
"`LLMMessagesUpdateFrame` with `run_llm=True`, or "
|
||||
"`OpenAILLMContextFrame` with desired messages in a new context",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMRunFrame(DataFrame):
|
||||
"""Frame to trigger LLM processing with current context.
|
||||
|
||||
@@ -14,11 +14,9 @@ from pipecat.frames.frames import (
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
@@ -32,8 +30,6 @@ class LLMLogObserver(BaseObserver):
|
||||
- LLMFullResponseEndFrame
|
||||
- LLMTextFrame
|
||||
- FunctionCallInProgressFrame
|
||||
- LLMMessagesFrame
|
||||
- OpenAILLMContextFrame
|
||||
|
||||
This allows you to track when the LLM starts responding, what it generates,
|
||||
and when it finishes.
|
||||
@@ -74,18 +70,9 @@ class LLMLogObserver(BaseObserver):
|
||||
logger.debug(
|
||||
f"🧠 {src} {arrow} LLM FUNCTION CALL ({frame.tool_call_id}): {frame.function_name!r}({frame.arguments}) at {time_sec:.2f}s"
|
||||
)
|
||||
# Log LLMMessagesFrame (input)
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
logger.debug(
|
||||
f"🧠 {arrow} {dst} LLM MESSAGES FRAME: {frame.messages} at {time_sec:.2f}s"
|
||||
)
|
||||
# Log OpenAILLMContextFrame (input)
|
||||
elif isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
messages = (
|
||||
frame.context.messages
|
||||
if isinstance(frame, OpenAILLMContextFrame)
|
||||
else frame.context.get_messages()
|
||||
)
|
||||
# Log LLMContextFrame (input)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
messages = frame.context.get_messages()
|
||||
logger.debug(f"🧠 {arrow} {dst} LLM CONTEXT FRAME: {messages} at {time_sec:.2f}s")
|
||||
# Log function call result (input)
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
|
||||
@@ -48,7 +48,6 @@ from pipecat.pipeline.base_pipeline import BasePipeline
|
||||
from pipecat.pipeline.base_task import BasePipelineTask, PipelineTaskParams
|
||||
from pipecat.pipeline.pipeline import Pipeline, PipelineSink, PipelineSource
|
||||
from pipecat.pipeline.task_observer import TaskObserver
|
||||
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIObserverParams, RTVIProcessor
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager, TaskManager, TaskManagerParams
|
||||
@@ -1028,10 +1027,6 @@ class PipelineTask(BasePipelineTask):
|
||||
"""Build and return start metadata including user-provided values."""
|
||||
start_metadata = {}
|
||||
|
||||
# NOTE(aleix): Remove when OpenAILLMContext/LLMUserContextAggregator is removed.
|
||||
if self._find_processor(self._pipeline, LLMUserContextAggregator):
|
||||
start_metadata["deprecated_openaillmcontext"] = True
|
||||
|
||||
# Update with user provided metadata.
|
||||
start_metadata.update(self._params.start_metadata)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
"""Gated LLM context aggregator for controlled message flow."""
|
||||
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, LLMContextFrame, StartFrame
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
@@ -49,7 +48,7 @@ class GatedLLMContextAggregator(FrameProcessor):
|
||||
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
await self.push_frame(frame)
|
||||
elif isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
if self._start_open:
|
||||
self._start_open = False
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Gated OpenAI LLM context aggregator for controlled message flow."""
|
||||
|
||||
from pipecat.processors.aggregators.gated_llm_context import GatedLLMContextAggregator
|
||||
|
||||
# Alias for backward compatibility with the previous name
|
||||
GatedOpenAILLMContextAggregator = GatedLLMContextAggregator
|
||||
@@ -33,9 +33,6 @@ from PIL import Image
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.frames.frames import AudioRawFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
|
||||
# "Re-export" types from OpenAI that we're using as universal context types.
|
||||
# NOTE: if universal message types need to someday diverge from OpenAI's, we
|
||||
# should consider managing our own definitions. But we should do so carefully,
|
||||
@@ -70,51 +67,6 @@ class LLMContext:
|
||||
and content formatting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_openai_context(openai_context: "OpenAILLMContext") -> "LLMContext":
|
||||
"""Create a universal LLM context from an OpenAI-specific context.
|
||||
|
||||
NOTE: this should only be used internally, for facilitating migration
|
||||
from OpenAILLMContext to LLMContext. New user code should use
|
||||
LLMContext directly.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`from_openai_context()` is deprecated and will be removed in a future version.
|
||||
Directly use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Args:
|
||||
openai_context: The OpenAI LLM context to convert.
|
||||
|
||||
Returns:
|
||||
New LLMContext instance with converted messages and settings.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"from_openai_context() (likely invoked by create_context_aggregator()) is deprecated and will be removed in a future version. "
|
||||
"Directly use the universal LLMContext and LLMContextAggregatorPair instead. "
|
||||
"See OpenAILLMContext docstring for migration guide.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Convert tools to ToolsSchema if needed.
|
||||
# If the tools are already a ToolsSchema, this is a no-op.
|
||||
# Otherwise, we wrap them in a shim ToolsSchema.
|
||||
converted_tools = openai_context.tools
|
||||
if isinstance(converted_tools, list):
|
||||
converted_tools = ToolsSchema(
|
||||
standard_tools=[], custom_tools={AdapterType.SHIM: converted_tools}
|
||||
)
|
||||
return LLMContext(
|
||||
messages=openai_context.get_messages(),
|
||||
tools=converted_tools,
|
||||
tool_choice=openai_context.tool_choice,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[LLMContextMessage]] = None,
|
||||
@@ -246,33 +198,6 @@ class LLMContext:
|
||||
"""
|
||||
return self.get_messages()
|
||||
|
||||
def get_messages_for_persistent_storage(self) -> List[LLMContextMessage]:
|
||||
"""Get messages suitable for persistent storage.
|
||||
|
||||
NOTE: the only reason this method exists is because we're "silently"
|
||||
switching from OpenAILLMContext to LLMContext under the hood in some
|
||||
services and don't want to trip up users who may have been relying on
|
||||
this method, which is part of the public API of OpenAILLMContext but
|
||||
doesn't need to be for LLMContext.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
Use `get_messages()` instead.
|
||||
|
||||
Returns:
|
||||
List of conversation messages.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"get_messages_for_persistent_storage() is deprecated, use get_messages() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self.get_messages()
|
||||
|
||||
def get_messages(self, llm_specific_filter: Optional[str] = None) -> List[LLMContextMessage]:
|
||||
"""Get the current messages list.
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,413 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI LLM context management for Pipecat.
|
||||
|
||||
This module provides classes for managing OpenAI-specific conversation contexts,
|
||||
including message handling, tool management, and image/audio processing capabilities.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
This module is deprecated.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolChoiceOptionParam,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import AudioRawFrame, Frame
|
||||
|
||||
# JSON custom encoder to handle bytes arrays so that we can log contexts
|
||||
# with images to the console.
|
||||
|
||||
|
||||
class CustomEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder for handling special data types in logging.
|
||||
|
||||
Provides specialized encoding for io.BytesIO objects to display
|
||||
readable representations in log output instead of raw binary data.
|
||||
"""
|
||||
|
||||
def default(self, obj):
|
||||
"""Encode special objects for JSON serialization.
|
||||
|
||||
Args:
|
||||
obj: The object to encode.
|
||||
|
||||
Returns:
|
||||
Encoded representation of the object.
|
||||
"""
|
||||
if isinstance(obj, io.BytesIO):
|
||||
# Convert the first 8 bytes to an ASCII hex string
|
||||
return f"{obj.getbuffer()[0:8].hex()}..."
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
class OpenAILLMContext:
|
||||
"""Manages conversation context for OpenAI LLM interactions.
|
||||
|
||||
Handles message history, tool definitions, tool choices, and multimedia content
|
||||
for OpenAI API conversations. Provides methods for message manipulation,
|
||||
content formatting, and integration with various LLM adapters.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAILLMContext` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
|
||||
**Before:**
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
**After:**
|
||||
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[ChatCompletionMessageParam]] = None,
|
||||
tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = NOT_GIVEN,
|
||||
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
|
||||
):
|
||||
"""Initialize the OpenAI LLM context.
|
||||
|
||||
Args:
|
||||
messages: Initial list of conversation messages.
|
||||
tools: Available tools for the LLM to use.
|
||||
tool_choice: Tool selection strategy for the LLM.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAILLMContext` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"OpenAILLMContext is deprecated and will be removed in a future version. "
|
||||
"Use the universal LLMContext and LLMContextAggregatorPair instead. "
|
||||
"See OpenAILLMContext docstring for migration guide.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._messages: List[ChatCompletionMessageParam] = messages if messages else []
|
||||
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
|
||||
self._tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = tools
|
||||
self._llm_adapter: Optional[BaseLLMAdapter] = None
|
||||
|
||||
def get_llm_adapter(self) -> Optional[BaseLLMAdapter]:
|
||||
"""Get the current LLM adapter.
|
||||
|
||||
Returns:
|
||||
The currently set LLM adapter, or None if not set.
|
||||
"""
|
||||
return self._llm_adapter
|
||||
|
||||
def set_llm_adapter(self, llm_adapter: BaseLLMAdapter):
|
||||
"""Set the LLM adapter for context processing.
|
||||
|
||||
Args:
|
||||
llm_adapter: The LLM adapter to use for tool conversion.
|
||||
"""
|
||||
self._llm_adapter = llm_adapter
|
||||
|
||||
@staticmethod
|
||||
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
|
||||
"""Create a context from a list of message dictionaries.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries to convert to context.
|
||||
|
||||
Returns:
|
||||
New OpenAILLMContext instance with the provided messages.
|
||||
"""
|
||||
context = OpenAILLMContext()
|
||||
|
||||
for message in messages:
|
||||
context.add_message(message)
|
||||
return context
|
||||
|
||||
@property
|
||||
def messages(self) -> List[ChatCompletionMessageParam]:
|
||||
"""Get the current messages list.
|
||||
|
||||
Returns:
|
||||
List of conversation messages.
|
||||
"""
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def tools(self) -> List[ChatCompletionToolParam] | NotGiven | List[Any]:
|
||||
"""Get the tools list, converting through adapter if available.
|
||||
|
||||
Returns:
|
||||
Tools list, potentially converted by the LLM adapter.
|
||||
"""
|
||||
if self._llm_adapter:
|
||||
return self._llm_adapter.from_standard_tools(self._tools)
|
||||
return self._tools
|
||||
|
||||
@property
|
||||
def tool_choice(self) -> ChatCompletionToolChoiceOptionParam | NotGiven:
|
||||
"""Get the current tool choice setting.
|
||||
|
||||
Returns:
|
||||
The tool choice configuration.
|
||||
"""
|
||||
return self._tool_choice
|
||||
|
||||
def add_message(self, message: ChatCompletionMessageParam):
|
||||
"""Add a single message to the context.
|
||||
|
||||
Args:
|
||||
message: The message to add to the conversation history.
|
||||
"""
|
||||
self._messages.append(message)
|
||||
|
||||
def add_messages(self, messages: List[ChatCompletionMessageParam]):
|
||||
"""Add multiple messages to the context.
|
||||
|
||||
Args:
|
||||
messages: List of messages to add to the conversation history.
|
||||
"""
|
||||
self._messages.extend(messages)
|
||||
|
||||
def set_messages(self, messages: List[ChatCompletionMessageParam]):
|
||||
"""Replace all messages in the context.
|
||||
|
||||
Args:
|
||||
messages: New list of messages to replace the current history.
|
||||
"""
|
||||
self._messages[:] = messages
|
||||
|
||||
def get_messages(self) -> List[ChatCompletionMessageParam]:
|
||||
"""Get a copy of the current messages list.
|
||||
|
||||
Returns:
|
||||
List of all messages in the conversation history.
|
||||
"""
|
||||
return self._messages
|
||||
|
||||
def get_messages_json(self) -> str:
|
||||
"""Get messages as a formatted JSON string.
|
||||
|
||||
Returns:
|
||||
JSON string representation of all messages with custom encoding.
|
||||
"""
|
||||
return json.dumps(self._messages, cls=CustomEncoder, ensure_ascii=False, indent=2)
|
||||
|
||||
def get_messages_for_logging(self) -> List[Dict[str, Any]]:
|
||||
"""Get sanitized messages suitable for logging.
|
||||
|
||||
Removes or truncates sensitive data like image content for safe logging.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image_url":
|
||||
if item["image_url"]["url"].startswith("data:image/"):
|
||||
item["image_url"]["url"] = "data:image/..."
|
||||
if "mime_type" in msg and msg["mime_type"].startswith("image/"):
|
||||
msg["data"] = "..."
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert from OpenAI message format to OpenAI message format (passthrough).
|
||||
|
||||
OpenAI's format allows both simple string content and structured content::
|
||||
|
||||
Simple: {"role": "user", "content": "Hello"}
|
||||
Structured: {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
|
||||
|
||||
Since OpenAI is our standard format, this is a passthrough function.
|
||||
|
||||
Args:
|
||||
message: Message in OpenAI format.
|
||||
|
||||
Returns:
|
||||
Same message, unchanged.
|
||||
"""
|
||||
return message
|
||||
|
||||
def to_standard_messages(self, obj) -> list:
|
||||
"""Convert from OpenAI message format to OpenAI message format (passthrough).
|
||||
|
||||
OpenAI's format is our standard format throughout Pipecat. This function
|
||||
returns a list containing the original message to maintain consistency with
|
||||
other LLM services that may need to return multiple messages.
|
||||
|
||||
Args:
|
||||
obj: Message in OpenAI format with either simple string content
|
||||
or structured list content.
|
||||
|
||||
Returns:
|
||||
List containing the original messages, preserving the content format.
|
||||
"""
|
||||
return [obj]
|
||||
|
||||
def get_messages_for_initializing_history(self):
|
||||
"""Get messages for initializing conversation history.
|
||||
|
||||
Returns:
|
||||
List of messages suitable for history initialization.
|
||||
"""
|
||||
return self._messages
|
||||
|
||||
def get_messages_for_persistent_storage(self):
|
||||
"""Get messages formatted for persistent storage.
|
||||
|
||||
Returns:
|
||||
List of messages converted to standard format for storage.
|
||||
"""
|
||||
messages = []
|
||||
for m in self._messages:
|
||||
standard_messages = self.to_standard_messages(m)
|
||||
messages.extend(standard_messages)
|
||||
return messages
|
||||
|
||||
def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven):
|
||||
"""Set the tool choice configuration.
|
||||
|
||||
Args:
|
||||
tool_choice: Tool selection strategy for the LLM.
|
||||
"""
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = NOT_GIVEN):
|
||||
"""Set the available tools for the LLM.
|
||||
|
||||
Args:
|
||||
tools: List of tools available to the LLM, or NOT_GIVEN to disable tools.
|
||||
"""
|
||||
if tools != NOT_GIVEN and isinstance(tools, list) and len(tools) == 0:
|
||||
tools = NOT_GIVEN
|
||||
self._tools = tools
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
"""Add a message containing an image frame.
|
||||
|
||||
Args:
|
||||
format: Image format (e.g., 'RGB', 'RGBA').
|
||||
size: Image dimensions as (width, height) tuple.
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
content = []
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
content.append(
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
|
||||
)
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None):
|
||||
"""Add a message containing audio frames.
|
||||
|
||||
Args:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
|
||||
Note:
|
||||
This method is currently a placeholder for future implementation.
|
||||
"""
|
||||
# todo: implement for OpenAI models and others
|
||||
pass
|
||||
|
||||
def create_wav_header(self, sample_rate, num_channels, bits_per_sample, data_size):
|
||||
"""Create a WAV file header for audio data.
|
||||
|
||||
Args:
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
num_channels: Number of audio channels.
|
||||
bits_per_sample: Bits per audio sample.
|
||||
data_size: Size of audio data in bytes.
|
||||
|
||||
Returns:
|
||||
WAV header as a bytearray.
|
||||
"""
|
||||
# RIFF chunk descriptor
|
||||
header = bytearray()
|
||||
header.extend(b"RIFF") # ChunkID
|
||||
header.extend((data_size + 36).to_bytes(4, "little")) # ChunkSize: total size - 8
|
||||
header.extend(b"WAVE") # Format
|
||||
# "fmt " sub-chunk
|
||||
header.extend(b"fmt ") # Subchunk1ID
|
||||
header.extend((16).to_bytes(4, "little")) # Subchunk1Size (16 for PCM)
|
||||
header.extend((1).to_bytes(2, "little")) # AudioFormat (1 for PCM)
|
||||
header.extend(num_channels.to_bytes(2, "little")) # NumChannels
|
||||
header.extend(sample_rate.to_bytes(4, "little")) # SampleRate
|
||||
# Calculate byte rate and block align
|
||||
byte_rate = sample_rate * num_channels * (bits_per_sample // 8)
|
||||
block_align = num_channels * (bits_per_sample // 8)
|
||||
header.extend(byte_rate.to_bytes(4, "little")) # ByteRate
|
||||
header.extend(block_align.to_bytes(2, "little")) # BlockAlign
|
||||
header.extend(bits_per_sample.to_bytes(2, "little")) # BitsPerSample
|
||||
# "data" sub-chunk
|
||||
header.extend(b"data") # Subchunk2ID
|
||||
header.extend(data_size.to_bytes(4, "little")) # Subchunk2Size
|
||||
return header
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAILLMContextFrame(Frame):
|
||||
"""Frame containing OpenAI-specific LLM context.
|
||||
|
||||
Like an LLMMessagesFrame, but with extra context specific to the OpenAI
|
||||
API. The context in this message is also mutable, and will be changed by the
|
||||
OpenAIContextAggregator frame processor.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAILLMContextFrame` is deprecated and will be removed in a future version.
|
||||
Use `LLMContextFrame` with the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Parameters:
|
||||
context: The OpenAI LLM context containing messages, tools, and configuration.
|
||||
"""
|
||||
|
||||
context: OpenAILLMContext
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"OpenAILLMContextFrame is deprecated and will be removed in a future version. "
|
||||
"Use LLMContextFrame with the universal `LLMContext` and `LLMContextAggregatorPair` instead. "
|
||||
"See OpenAILLMContext docstring for migration guide.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,81 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Vision image frame aggregation for Pipecat.
|
||||
|
||||
This module provides frame aggregation functionality to combine text and image
|
||||
frames into vision frames for multimodal processing.
|
||||
"""
|
||||
|
||||
from pipecat.frames.frames import Frame, InputImageRawFrame, TextFrame
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class VisionImageFrameAggregator(FrameProcessor):
|
||||
"""Aggregates consecutive text and image frames into vision frames.
|
||||
|
||||
.. deprecated:: 0.0.85
|
||||
VisionImageRawFrame has been removed in favor of context frames
|
||||
(LLMContextFrame or OpenAILLMContextFrame), so this aggregator is not
|
||||
needed anymore. See the 12* examples for the new recommended pattern.
|
||||
|
||||
This aggregator waits for a consecutive TextFrame and an InputImageRawFrame.
|
||||
After the InputImageRawFrame arrives it will output a VisionImageRawFrame
|
||||
combining both the text and image data for multimodal processing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the vision image frame aggregator.
|
||||
|
||||
The aggregator starts with no cached text, waiting for the first
|
||||
TextFrame to arrive before it can create vision frames.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"VisionImageFrameAggregator is deprecated. "
|
||||
"VisionImageRawFrame has been removed in favor of context frames "
|
||||
"(LLMContextFrame or OpenAILLMContextFrame), so this aggregator is "
|
||||
"not needed anymore. See the 12* examples for the new recommended "
|
||||
"pattern.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__()
|
||||
self._describe_text = None
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames and aggregate text with images.
|
||||
|
||||
Caches TextFrames and combines them with subsequent InputImageRawFrames
|
||||
to create VisionImageRawFrames. Other frames are passed through unchanged.
|
||||
|
||||
Args:
|
||||
frame: The incoming frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
self._describe_text = frame.text
|
||||
elif isinstance(frame, InputImageRawFrame):
|
||||
if self._describe_text:
|
||||
context = OpenAILLMContext()
|
||||
context.add_image_frame_message(
|
||||
text=self._describe_text,
|
||||
image=frame.image,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
)
|
||||
frame = OpenAILLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
self._describe_text = None
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -196,7 +196,6 @@ class FrameProcessor(BaseObject):
|
||||
# Other properties (deprecated)
|
||||
self._allow_interruptions = False
|
||||
self._interruption_strategies: List[BaseInterruptionStrategy] = []
|
||||
self._deprecated_openaillmcontext = False
|
||||
|
||||
# Indicates whether we have received the StartFrame.
|
||||
self.__started = False
|
||||
@@ -826,9 +825,6 @@ class FrameProcessor(BaseObject):
|
||||
self._interruption_strategies = frame.interruption_strategies
|
||||
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
|
||||
|
||||
# NOTE(aleix): Remove when OpenAILLMContext/LLMUserContextAggregator is removed.
|
||||
self._deprecated_openaillmcontext = "deprecated_openaillmcontext" in frame.metadata
|
||||
|
||||
self.__create_process_task()
|
||||
|
||||
async def __cancel(self, frame: CancelFrame):
|
||||
|
||||
@@ -17,7 +17,6 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseStartFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
try:
|
||||
@@ -65,15 +64,11 @@ class LangchainProcessor(FrameProcessor):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
# Messages are accumulated on the context as a list of messages.
|
||||
# The last one by the human is the one we want to send to the LLM.
|
||||
logger.debug(f"Got transcription frame {frame}")
|
||||
messages = (
|
||||
frame.context.messages
|
||||
if isinstance(frame, OpenAILLMContextFrame)
|
||||
else frame.context.get_messages()
|
||||
)
|
||||
messages = frame.context.get_messages()
|
||||
text: str = messages[-1]["content"]
|
||||
|
||||
await self._ainvoke(text.strip())
|
||||
|
||||
@@ -59,7 +59,6 @@ from pipecat.metrics.metrics import (
|
||||
TTSUsageMetricsData,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frameworks.rtvi.frames import (
|
||||
RTVIServerMessageFrame,
|
||||
@@ -358,10 +357,7 @@ class RTVIObserver(BaseObserver):
|
||||
and self._params.user_transcription_enabled
|
||||
):
|
||||
await self._handle_user_transcriptions(frame)
|
||||
elif (
|
||||
isinstance(frame, (OpenAILLMContextFrame, LLMContextFrame))
|
||||
and self._params.user_llm_enabled
|
||||
):
|
||||
elif isinstance(frame, LLMContextFrame) and self._params.user_llm_enabled:
|
||||
await self._handle_context(frame)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
|
||||
await self.send_rtvi_message(RTVI.BotLLMStartedMessage())
|
||||
@@ -575,13 +571,10 @@ class RTVIObserver(BaseObserver):
|
||||
if message:
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
async def _handle_context(self, frame: OpenAILLMContextFrame | LLMContextFrame):
|
||||
async def _handle_context(self, frame: LLMContextFrame):
|
||||
"""Process LLM context frames to extract user messages for the RTVI client."""
|
||||
try:
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
messages = frame.context.messages
|
||||
else:
|
||||
messages = frame.context.get_messages()
|
||||
messages = frame.context.get_messages()
|
||||
if not messages:
|
||||
return
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
try:
|
||||
@@ -72,7 +71,7 @@ class StrandsAgentsProcessor(FrameProcessor):
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
messages = frame.context.get_messages()
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
|
||||
@@ -37,7 +37,6 @@ from pipecat.frames.frames import (
|
||||
LLMEnablePromptCachingFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
@@ -45,16 +44,6 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN
|
||||
@@ -115,44 +104,6 @@ class AnthropicLLMSettings(LLMSettings):
|
||||
return instance
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnthropicContextAggregatorPair:
|
||||
"""Pair of context aggregators for Anthropic conversations.
|
||||
|
||||
Encapsulates both user and assistant context aggregators
|
||||
to manage conversation flow and message formatting.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AnthropicContextAggregatorPair` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Parameters:
|
||||
_user: The user context aggregator.
|
||||
_assistant: The assistant context aggregator.
|
||||
"""
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
_user: "AnthropicUserContextAggregator"
|
||||
_assistant: "AnthropicAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "AnthropicUserContextAggregator":
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "AnthropicAssistantContextAggregator":
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class AnthropicLLMService(LLMService):
|
||||
"""LLM service for Anthropic's Claude models.
|
||||
|
||||
@@ -351,7 +302,7 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
async def run_inference(
|
||||
self,
|
||||
context: LLMContext | OpenAILLMContext,
|
||||
context: LLMContext,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
@@ -371,21 +322,15 @@ class AnthropicLLMService(LLMService):
|
||||
system = NOT_GIVEN
|
||||
tools = []
|
||||
effective_instruction = system_instruction or self._settings.system_instruction
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
invocation_params = adapter.get_llm_invocation_params(
|
||||
context,
|
||||
enable_prompt_caching=self._settings.enable_prompt_caching,
|
||||
system_instruction=effective_instruction,
|
||||
)
|
||||
messages = invocation_params["messages"]
|
||||
system = invocation_params["system"]
|
||||
tools = invocation_params["tools"]
|
||||
else:
|
||||
context = AnthropicLLMContext.upgrade_to_anthropic(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", NOT_GIVEN)
|
||||
tools = context.tools or []
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
invocation_params = adapter.get_llm_invocation_params(
|
||||
context,
|
||||
enable_prompt_caching=self._settings.enable_prompt_caching,
|
||||
system_instruction=effective_instruction,
|
||||
)
|
||||
messages = invocation_params["messages"]
|
||||
system = invocation_params["system"]
|
||||
tools = invocation_params["tools"]
|
||||
|
||||
# Build params using the same method as streaming completions
|
||||
params = {
|
||||
@@ -410,70 +355,17 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
return next((block.text for block in response.content if hasattr(block, "text")), None)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> AnthropicContextAggregatorPair:
|
||||
"""Create Anthropic-specific context aggregators.
|
||||
|
||||
Creates a pair of context aggregators optimized for Anthropic's message format,
|
||||
including support for function calls, tool usage, and image handling.
|
||||
|
||||
Args:
|
||||
context: The LLM context.
|
||||
user_params: User aggregator parameters.
|
||||
assistant_params: Assistant aggregator parameters.
|
||||
|
||||
Returns:
|
||||
A pair of context aggregators, one for the user and one for the assistant,
|
||||
encapsulated in an AnthropicContextAggregatorPair.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = AnthropicLLMContext.from_openai_context(context)
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
user = AnthropicUserContextAggregator(context, params=user_params)
|
||||
assistant = AnthropicAssistantContextAggregator(context, params=assistant_params)
|
||||
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
def _get_llm_invocation_params(
|
||||
self, context: OpenAILLMContext | LLMContext
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
# Universal LLMContext
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context,
|
||||
enable_prompt_caching=self._settings.enable_prompt_caching,
|
||||
system_instruction=self._settings.system_instruction,
|
||||
)
|
||||
return params
|
||||
|
||||
# Anthropic-specific context
|
||||
messages = (
|
||||
context.get_messages_with_cache_control_markers()
|
||||
if self._settings.enable_prompt_caching
|
||||
else context.messages
|
||||
)
|
||||
return AnthropicLLMInvocationParams(
|
||||
system=context.system,
|
||||
messages=messages,
|
||||
tools=context.tools or [],
|
||||
def _get_llm_invocation_params(self, context: LLMContext) -> AnthropicLLMInvocationParams:
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context,
|
||||
enable_prompt_caching=self._settings.enable_prompt_caching,
|
||||
system_instruction=self._settings.system_instruction,
|
||||
)
|
||||
return params
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
|
||||
# completion_tokens. We also estimate the completion tokens from output text
|
||||
# and use that estimate if we are interrupted, because we almost certainly won't
|
||||
@@ -491,15 +383,10 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
params_from_context = self._get_llm_invocation_params(context)
|
||||
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
context_type_for_logging = "universal"
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
else:
|
||||
context_type_for_logging = "LLM-specific"
|
||||
messages_for_logging = context.get_messages_for_logging()
|
||||
adapter = self.get_llm_adapter()
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from {context_type_for_logging} context [{params_from_context['system']}] | {messages_for_logging}"
|
||||
f"{self}: Generating chat from context [{params_from_context['system']}] | {messages_for_logging}"
|
||||
)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
@@ -666,14 +553,8 @@ class AnthropicLLMService(LLMService):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
# NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal
|
||||
# LLMContext with it
|
||||
context = AnthropicLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, LLMEnablePromptCachingFrame):
|
||||
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
|
||||
self._settings.enable_prompt_caching = frame.enable
|
||||
@@ -707,581 +588,3 @@ class AnthropicLLMService(LLMService):
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
|
||||
class AnthropicLLMContext(OpenAILLMContext):
|
||||
"""LLM context specialized for Anthropic's message format and features.
|
||||
|
||||
Extends OpenAILLMContext to handle Anthropic-specific features like
|
||||
system messages, prompt caching, and message format conversions.
|
||||
Manages conversation state and message history formatting.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AnthropicLLMContext` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[dict]] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Optional[dict] = None,
|
||||
*,
|
||||
system: Union[str, NotGiven] = NOT_GIVEN,
|
||||
):
|
||||
"""Initialize the Anthropic LLM context.
|
||||
|
||||
Args:
|
||||
messages: Initial list of conversation messages.
|
||||
tools: Available function calling tools.
|
||||
tool_choice: Tool selection preference.
|
||||
system: System message content.
|
||||
"""
|
||||
# Super handles deprecation warning
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self.__setup_local()
|
||||
self.system = system
|
||||
|
||||
def __setup_local(self):
|
||||
# For beta prompt caching. This is a counter that tracks the number of turns
|
||||
# we've seen above the cache threshold. We reset this when we reset the
|
||||
# messages list. We only care about this number being 0, 1, or 2. But
|
||||
# it's easiest just to treat it as a counter.
|
||||
self.turns_above_cache_threshold = 0
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_anthropic(obj: OpenAILLMContext) -> "AnthropicLLMContext":
|
||||
"""Upgrade an OpenAI context to Anthropic format.
|
||||
|
||||
Converts message format and restructures content for Anthropic compatibility.
|
||||
|
||||
Args:
|
||||
obj: The OpenAI context to upgrade.
|
||||
|
||||
Returns:
|
||||
The upgraded Anthropic context.
|
||||
"""
|
||||
logger.debug(f"Upgrading to Anthropic: {obj}")
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AnthropicLLMContext):
|
||||
obj.__class__ = AnthropicLLMContext
|
||||
obj.__setup_local()
|
||||
obj._restructure_from_openai_messages()
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def from_openai_context(cls, openai_context: OpenAILLMContext):
|
||||
"""Create Anthropic context from OpenAI context.
|
||||
|
||||
Args:
|
||||
openai_context: The OpenAI context to convert.
|
||||
|
||||
Returns:
|
||||
New Anthropic context with converted messages.
|
||||
"""
|
||||
self = cls(
|
||||
messages=openai_context.messages,
|
||||
tools=openai_context.tools,
|
||||
tool_choice=openai_context.tool_choice,
|
||||
)
|
||||
self.set_llm_adapter(openai_context.get_llm_adapter())
|
||||
self._restructure_from_openai_messages()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[dict]) -> "AnthropicLLMContext":
|
||||
"""Create context from a list of messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
|
||||
Returns:
|
||||
New Anthropic context with the provided messages.
|
||||
"""
|
||||
self = cls(messages=messages)
|
||||
self._restructure_from_openai_messages()
|
||||
return self
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
"""Set the messages list and reset cache tracking.
|
||||
|
||||
Args:
|
||||
messages: New list of messages to set.
|
||||
"""
|
||||
self.turns_above_cache_threshold = 0
|
||||
self._messages[:] = messages
|
||||
self._restructure_from_openai_messages()
|
||||
|
||||
def to_standard_messages(self, obj):
|
||||
"""Convert Anthropic message format to standard structured format.
|
||||
|
||||
Handles text content and function calls for both user and assistant messages.
|
||||
|
||||
Args:
|
||||
obj: Message in Anthropic format.
|
||||
|
||||
Returns:
|
||||
List of messages in standard format.
|
||||
|
||||
Examples:
|
||||
Input Anthropic format::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "tool_use", "id": "123", "name": "search", "input": {"q": "test"}}
|
||||
]
|
||||
}
|
||||
|
||||
Output standard format::
|
||||
|
||||
[
|
||||
{"role": "assistant", "content": [{"type": "text", "text": "Hello"}]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"id": "123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
# todo: image format (?)
|
||||
# tool_use
|
||||
role = obj.get("role")
|
||||
content = obj.get("content")
|
||||
if role == "assistant":
|
||||
if isinstance(content, str):
|
||||
return [{"role": role, "content": [{"type": "text", "text": content}]}]
|
||||
elif isinstance(content, list):
|
||||
text_items = []
|
||||
tool_items = []
|
||||
for item in content:
|
||||
if item["type"] == "text":
|
||||
text_items.append({"type": "text", "text": item["text"]})
|
||||
elif item["type"] == "tool_use":
|
||||
tool_items.append(
|
||||
{
|
||||
"type": "function",
|
||||
"id": item["id"],
|
||||
"function": {
|
||||
"name": item["name"],
|
||||
"arguments": json.dumps(item["input"]),
|
||||
},
|
||||
}
|
||||
)
|
||||
messages = []
|
||||
if text_items:
|
||||
messages.append({"role": role, "content": text_items})
|
||||
if tool_items:
|
||||
messages.append({"role": role, "tool_calls": tool_items})
|
||||
return messages
|
||||
elif role == "user":
|
||||
if isinstance(content, str):
|
||||
return [{"role": role, "content": [{"type": "text", "text": content}]}]
|
||||
elif isinstance(content, list):
|
||||
text_items = []
|
||||
tool_items = []
|
||||
for item in content:
|
||||
if item["type"] == "text":
|
||||
text_items.append({"type": "text", "text": item["text"]})
|
||||
elif item["type"] == "tool_result":
|
||||
tool_items.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item["tool_use_id"],
|
||||
"content": item["content"],
|
||||
}
|
||||
)
|
||||
messages = []
|
||||
if text_items:
|
||||
messages.append({"role": role, "content": text_items})
|
||||
messages.extend(tool_items)
|
||||
return messages
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert standard format message to Anthropic format.
|
||||
|
||||
Handles conversion of text content, tool calls, and tool results.
|
||||
Empty text content is converted to "(empty)".
|
||||
|
||||
Args:
|
||||
message: Message in standard format.
|
||||
|
||||
Returns:
|
||||
Message in Anthropic format.
|
||||
|
||||
Examples:
|
||||
Input standard format::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Output Anthropic format::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "123",
|
||||
"name": "search",
|
||||
"input": {"q": "test"}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
# todo: image messages (?)
|
||||
if message["role"] == "tool":
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
],
|
||||
}
|
||||
if message.get("tool_calls"):
|
||||
tc = message["tool_calls"]
|
||||
ret = {"role": "assistant", "content": []}
|
||||
for tool_call in tc:
|
||||
function = tool_call["function"]
|
||||
arguments = json.loads(function["arguments"])
|
||||
new_tool_use = {
|
||||
"type": "tool_use",
|
||||
"id": tool_call["id"],
|
||||
"name": function["name"],
|
||||
"input": arguments,
|
||||
}
|
||||
ret["content"].append(new_tool_use)
|
||||
return ret
|
||||
# check for empty text strings
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
if content == "":
|
||||
content = "(empty)"
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if item["type"] == "text" and item["text"] == "":
|
||||
item["text"] = "(empty)"
|
||||
|
||||
return message
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
"""Add an image message to the context.
|
||||
|
||||
Converts the image to base64 JPEG format and adds it as a user message
|
||||
with optional accompanying text.
|
||||
|
||||
Args:
|
||||
format: The image format (e.g., 'RGB', 'RGBA').
|
||||
size: Image dimensions as (width, height).
|
||||
image: Raw image bytes.
|
||||
text: Optional text to accompany the image.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# Anthropic docs say that the image should be the first content block in the message.
|
||||
content = [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": encoded_image,
|
||||
},
|
||||
}
|
||||
]
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
def add_message(self, message):
|
||||
"""Add a message to the context, merging with previous message if same role.
|
||||
|
||||
Anthropic requires alternating roles, so consecutive messages from the same
|
||||
role are merged together.
|
||||
|
||||
Args:
|
||||
message: The message to add to the context.
|
||||
"""
|
||||
try:
|
||||
if self.messages:
|
||||
# Anthropic requires that roles alternate. If this message's role is the same as the
|
||||
# last message, we should add this message's content to the last message.
|
||||
if self.messages[-1]["role"] == message["role"]:
|
||||
# if the last message has just a content string, convert it to a list
|
||||
# in the proper format
|
||||
if isinstance(self.messages[-1]["content"], str):
|
||||
self.messages[-1]["content"] = [
|
||||
{"type": "text", "text": self.messages[-1]["content"]}
|
||||
]
|
||||
# if this message has just a content string, convert it to a list
|
||||
# in the proper format
|
||||
if isinstance(message["content"], str):
|
||||
message["content"] = [{"type": "text", "text": message["content"]}]
|
||||
# append the content of this message to the last message
|
||||
self.messages[-1]["content"].extend(message["content"])
|
||||
else:
|
||||
self.messages.append(message)
|
||||
else:
|
||||
self.messages.append(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message: {e}")
|
||||
|
||||
def get_messages_with_cache_control_markers(self) -> List[dict]:
|
||||
"""Get messages with prompt caching markers applied.
|
||||
|
||||
Adds cache control markers to appropriate messages based on the
|
||||
number of turns above the cache threshold.
|
||||
|
||||
Returns:
|
||||
List of messages with cache control markers added.
|
||||
"""
|
||||
try:
|
||||
messages = copy.deepcopy(self.messages)
|
||||
if self.turns_above_cache_threshold >= 1 and messages[-1]["role"] == "user":
|
||||
if isinstance(messages[-1]["content"], str):
|
||||
messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}]
|
||||
messages[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
if (
|
||||
self.turns_above_cache_threshold >= 2
|
||||
and len(messages) > 2
|
||||
and messages[-3]["role"] == "user"
|
||||
):
|
||||
if isinstance(messages[-3]["content"], str):
|
||||
messages[-3]["content"] = [{"type": "text", "text": messages[-3]["content"]}]
|
||||
messages[-3]["content"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding cache control marker: {e}")
|
||||
return self.messages
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
|
||||
try:
|
||||
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
|
||||
# See if we should pull the system message out of our context.messages list. (For
|
||||
# compatibility with Open AI messages format.)
|
||||
if self.messages and self.messages[0]["role"] == "system":
|
||||
if len(self.messages) == 1:
|
||||
# If we have only have a system message in the list, all we can really do
|
||||
# without introducing too much magic is change the role to "user".
|
||||
self.messages[0]["role"] = "user"
|
||||
else:
|
||||
# If we have more than one message, we'll pull the system message out of the
|
||||
# list.
|
||||
self.system = self.messages[0]["content"]
|
||||
self.messages.pop(0)
|
||||
|
||||
# Merge consecutive messages with the same role.
|
||||
i = 0
|
||||
while i < len(self.messages) - 1:
|
||||
current_message = self.messages[i]
|
||||
next_message = self.messages[i + 1]
|
||||
if current_message["role"] == next_message["role"]:
|
||||
# Convert content to list of dictionaries if it's a string
|
||||
if isinstance(current_message["content"], str):
|
||||
current_message["content"] = [
|
||||
{"type": "text", "text": current_message["content"]}
|
||||
]
|
||||
if isinstance(next_message["content"], str):
|
||||
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
|
||||
# Concatenate the content
|
||||
current_message["content"].extend(next_message["content"])
|
||||
# Remove the next message from the list
|
||||
self.messages.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Avoid empty content in messages
|
||||
for message in self.messages:
|
||||
if isinstance(message["content"], str) and message["content"] == "":
|
||||
message["content"] = "(empty)"
|
||||
elif isinstance(message["content"], list) and len(message["content"]) == 0:
|
||||
message["content"] = [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
def get_messages_for_persistent_storage(self):
|
||||
"""Get messages formatted for persistent storage.
|
||||
|
||||
Includes system message at the beginning if present.
|
||||
|
||||
Returns:
|
||||
List of messages suitable for storage.
|
||||
"""
|
||||
messages = super().get_messages_for_persistent_storage()
|
||||
if self.system:
|
||||
messages.insert(0, {"role": "system", "content": self.system})
|
||||
return messages
|
||||
|
||||
def get_messages_for_logging(self) -> List[Dict[str, Any]]:
|
||||
"""Get messages formatted for logging with sensitive data redacted.
|
||||
|
||||
Replaces image data with placeholder text for cleaner logs.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image":
|
||||
item["source"]["data"] = "..."
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
|
||||
class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
"""Anthropic-specific user context aggregator.
|
||||
|
||||
Handles aggregation of user messages for Anthropic LLM services.
|
||||
Inherits all functionality from the base LLMUserContextAggregator.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AnthropicUserContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# Claude returns a text content block along with a tool use content block. This works quite nicely
|
||||
# with streaming. We get the text first, so we can start streaming it right away. Then we get the
|
||||
# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
|
||||
#
|
||||
# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
|
||||
# chattiness about it's tool thinking.
|
||||
#
|
||||
|
||||
|
||||
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"""Context aggregator for assistant messages in Anthropic conversations.
|
||||
|
||||
Handles function call lifecycle management including in-progress tracking,
|
||||
result handling, and cancellation for Anthropic's tool use format.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AnthropicAssistantContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
"""Handle a function call that is starting.
|
||||
|
||||
Creates tool use message and placeholder tool result for tracking.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call details.
|
||||
"""
|
||||
assistant_message = {"role": "assistant", "content": []}
|
||||
assistant_message["content"].append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": frame.tool_call_id,
|
||||
"name": frame.function_name,
|
||||
"input": frame.arguments,
|
||||
}
|
||||
)
|
||||
self._context.add_message(assistant_message)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": frame.tool_call_id,
|
||||
"content": "IN_PROGRESS",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle the result of a completed function call.
|
||||
|
||||
Updates the tool result with actual return value or completion status.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call result.
|
||||
"""
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result, ensure_ascii=False)
|
||||
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
"""Handle cancellation of a function call.
|
||||
|
||||
Updates the tool result to indicate cancellation.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call cancellation details.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if message["role"] == "user":
|
||||
for content in message["content"]:
|
||||
if (
|
||||
isinstance(content, dict)
|
||||
and content["type"] == "tool_result"
|
||||
and content["tool_use_id"] == tool_call_id
|
||||
):
|
||||
content["content"] = result
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
"""Handle a user image frame with function call context.
|
||||
|
||||
Marks the associated function call as completed and adds the image
|
||||
to the conversation context.
|
||||
|
||||
Args:
|
||||
frame: User image frame with request context.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
@@ -26,15 +26,11 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
def default_context_to_payload_transformer(
|
||||
context: LLMContext | OpenAILLMContext,
|
||||
context: LLMContext,
|
||||
) -> Optional[str]:
|
||||
"""Default transformer to create AgentCore payload from LLM context.
|
||||
|
||||
@@ -118,9 +114,7 @@ class AWSAgentCoreProcessor(FrameProcessor):
|
||||
aws_secret_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_region: Optional[str] = None,
|
||||
context_to_payload_transformer: Optional[
|
||||
Callable[[LLMContext | OpenAILLMContext], Optional[str]]
|
||||
] = None,
|
||||
context_to_payload_transformer: Optional[Callable[[LLMContext], Optional[str]]] = None,
|
||||
response_to_output_transformer: Optional[Callable[[str], Optional[str]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -200,7 +194,7 @@ class AWSAgentCoreProcessor(FrameProcessor):
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
# Create payload to invoke AgentCore agent
|
||||
payload = self._context_to_payload_transformer(frame.context)
|
||||
|
||||
|
||||
@@ -38,21 +38,10 @@ from pipecat.frames.frames import (
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven
|
||||
@@ -87,657 +76,6 @@ class AWSBedrockLLMSettings(LLMSettings):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSBedrockContextAggregatorPair:
|
||||
"""Container for AWS Bedrock context aggregators.
|
||||
|
||||
Provides convenient access to both user and assistant context aggregators
|
||||
for AWS Bedrock LLM operations.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AWSBedrockContextAggregatorPair` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Parameters:
|
||||
_user: The user context aggregator instance.
|
||||
_assistant: The assistant context aggregator instance.
|
||||
"""
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
_user: "AWSBedrockUserContextAggregator"
|
||||
_assistant: "AWSBedrockAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "AWSBedrockUserContextAggregator":
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "AWSBedrockAssistantContextAggregator":
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class AWSBedrockLLMContext(OpenAILLMContext):
|
||||
"""AWS Bedrock-specific LLM context implementation.
|
||||
|
||||
Extends OpenAI LLM context to handle AWS Bedrock's specific message format
|
||||
and system message handling. Manages conversion between OpenAI and Bedrock
|
||||
message formats.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AWSBedrockLLMContext` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[dict]] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Optional[dict] = None,
|
||||
*,
|
||||
system: Optional[str] = None,
|
||||
):
|
||||
"""Initialize AWS Bedrock LLM context.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages in OpenAI format.
|
||||
tools: List of available function calling tools.
|
||||
tool_choice: Tool selection strategy or specific tool choice.
|
||||
system: System message content for AWS Bedrock.
|
||||
"""
|
||||
# Super handles deprecation warning
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self.system = system
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext":
|
||||
"""Upgrade an OpenAI LLM context to AWS Bedrock format.
|
||||
|
||||
Args:
|
||||
obj: The OpenAI LLM context to upgrade.
|
||||
|
||||
Returns:
|
||||
The upgraded AWS Bedrock LLM context.
|
||||
"""
|
||||
logger.debug(f"Upgrading to AWS Bedrock: {obj}")
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSBedrockLLMContext):
|
||||
obj.__class__ = AWSBedrockLLMContext
|
||||
obj._restructure_from_openai_messages()
|
||||
else:
|
||||
obj._restructure_from_bedrock_messages()
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def from_openai_context(cls, openai_context: OpenAILLMContext):
|
||||
"""Create AWS Bedrock context from OpenAI context.
|
||||
|
||||
Args:
|
||||
openai_context: The OpenAI LLM context to convert.
|
||||
|
||||
Returns:
|
||||
New AWS Bedrock LLM context instance.
|
||||
"""
|
||||
self = cls(
|
||||
messages=openai_context.messages,
|
||||
tools=openai_context.tools,
|
||||
tool_choice=openai_context.tool_choice,
|
||||
)
|
||||
self.set_llm_adapter(openai_context.get_llm_adapter())
|
||||
self._restructure_from_openai_messages()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext":
|
||||
"""Create AWS Bedrock context from message list.
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format.
|
||||
|
||||
Returns:
|
||||
New AWS Bedrock LLM context instance.
|
||||
"""
|
||||
self = cls(messages=messages)
|
||||
self._restructure_from_openai_messages()
|
||||
return self
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
"""Set the messages list and restructure for Bedrock format.
|
||||
|
||||
Args:
|
||||
messages: List of messages to set.
|
||||
"""
|
||||
self._messages[:] = messages
|
||||
self._restructure_from_openai_messages()
|
||||
|
||||
def to_standard_messages(self, obj):
|
||||
"""Convert AWS Bedrock message format to standard structured format.
|
||||
|
||||
Handles text content and function calls for both user and assistant messages.
|
||||
|
||||
Args:
|
||||
obj: Message in AWS Bedrock format.
|
||||
|
||||
Returns:
|
||||
List of messages in standard format.
|
||||
|
||||
Examples:
|
||||
AWS Bedrock format input::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"text": "Hello"},
|
||||
{"toolUse": {"toolUseId": "123", "name": "search", "input": {"q": "test"}}}
|
||||
]
|
||||
}
|
||||
|
||||
Standard format output::
|
||||
|
||||
[
|
||||
{"role": "assistant", "content": [{"type": "text", "text": "Hello"}]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"id": "123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
role = obj.get("role")
|
||||
content = obj.get("content")
|
||||
|
||||
if role == "assistant":
|
||||
if isinstance(content, str):
|
||||
return [{"role": role, "content": [{"type": "text", "text": content}]}]
|
||||
elif isinstance(content, list):
|
||||
text_items = []
|
||||
tool_items = []
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
text_items.append({"type": "text", "text": item["text"]})
|
||||
elif "toolUse" in item:
|
||||
tool_use = item["toolUse"]
|
||||
tool_items.append(
|
||||
{
|
||||
"type": "function",
|
||||
"id": tool_use["toolUseId"],
|
||||
"function": {
|
||||
"name": tool_use["name"],
|
||||
"arguments": json.dumps(tool_use["input"]),
|
||||
},
|
||||
}
|
||||
)
|
||||
messages = []
|
||||
if text_items:
|
||||
messages.append({"role": role, "content": text_items})
|
||||
if tool_items:
|
||||
messages.append({"role": role, "tool_calls": tool_items})
|
||||
return messages
|
||||
elif role == "user":
|
||||
if isinstance(content, str):
|
||||
return [{"role": role, "content": [{"type": "text", "text": content}]}]
|
||||
elif isinstance(content, list):
|
||||
text_items = []
|
||||
tool_items = []
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
text_items.append({"type": "text", "text": item["text"]})
|
||||
elif "toolResult" in item:
|
||||
tool_result = item["toolResult"]
|
||||
# Extract content from toolResult
|
||||
result_content = ""
|
||||
if isinstance(tool_result["content"], list):
|
||||
for content_item in tool_result["content"]:
|
||||
if "text" in content_item:
|
||||
result_content = content_item["text"]
|
||||
elif "json" in content_item:
|
||||
result_content = json.dumps(content_item["json"])
|
||||
else:
|
||||
result_content = tool_result["content"]
|
||||
|
||||
tool_items.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_result["toolUseId"],
|
||||
"content": result_content,
|
||||
}
|
||||
)
|
||||
messages = []
|
||||
if text_items:
|
||||
messages.append({"role": role, "content": text_items})
|
||||
messages.extend(tool_items)
|
||||
return messages
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert standard format message to AWS Bedrock format.
|
||||
|
||||
Handles conversion of text content, tool calls, and tool results.
|
||||
Empty text content is converted to "(empty)".
|
||||
|
||||
Args:
|
||||
message: Message in standard format.
|
||||
|
||||
Returns:
|
||||
Message in AWS Bedrock format.
|
||||
|
||||
Examples:
|
||||
Standard format input::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
AWS Bedrock format output::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "123",
|
||||
"name": "search",
|
||||
"input": {"q": "test"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
if message["role"] == "tool":
|
||||
# Try to parse the content as JSON if it looks like JSON
|
||||
try:
|
||||
if message["content"].strip().startswith("{") and message[
|
||||
"content"
|
||||
].strip().endswith("}"):
|
||||
content_json = json.loads(message["content"])
|
||||
tool_result_content = [{"json": content_json}]
|
||||
else:
|
||||
tool_result_content = [{"text": message["content"]}]
|
||||
except (json.JSONDecodeError, ValueError, AttributeError):
|
||||
tool_result_content = [{"text": message["content"]}]
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message["tool_call_id"],
|
||||
"content": tool_result_content,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
if message.get("tool_calls"):
|
||||
tc = message["tool_calls"]
|
||||
ret = {"role": "assistant", "content": []}
|
||||
for tool_call in tc:
|
||||
function = tool_call["function"]
|
||||
arguments = json.loads(function["arguments"])
|
||||
new_tool_use = {
|
||||
"toolUse": {
|
||||
"toolUseId": tool_call["id"],
|
||||
"name": function["name"],
|
||||
"input": arguments,
|
||||
}
|
||||
}
|
||||
ret["content"].append(new_tool_use)
|
||||
return ret
|
||||
|
||||
# Handle text content
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
if content == "":
|
||||
return {"role": message["role"], "content": [{"text": "(empty)"}]}
|
||||
else:
|
||||
return {"role": message["role"], "content": [{"text": content}]}
|
||||
elif isinstance(content, list):
|
||||
new_content = []
|
||||
for item in content:
|
||||
# fix empty text
|
||||
if item.get("type", "") == "text":
|
||||
text_content = item["text"] if item["text"] != "" else "(empty)"
|
||||
new_content.append({"text": text_content})
|
||||
# handle image_url -> image conversion
|
||||
if item["type"] == "image_url":
|
||||
new_item = {
|
||||
"image": {
|
||||
"format": "jpeg",
|
||||
"source": {
|
||||
"bytes": base64.b64decode(item["image_url"]["url"].split(",")[1])
|
||||
},
|
||||
}
|
||||
}
|
||||
new_content.append(new_item)
|
||||
# In the case where there's a single image in the list (like what
|
||||
# would result from a UserImageRawFrame), ensure that the image
|
||||
# comes before text
|
||||
image_indices = [i for i, item in enumerate(new_content) if "image" in item]
|
||||
text_indices = [i for i, item in enumerate(new_content) if "text" in item]
|
||||
if len(image_indices) == 1 and text_indices:
|
||||
img_idx = image_indices[0]
|
||||
first_txt_idx = text_indices[0]
|
||||
if img_idx > first_txt_idx:
|
||||
# Move image before the first text
|
||||
image_item = new_content.pop(img_idx)
|
||||
new_content.insert(first_txt_idx, image_item)
|
||||
return {"role": message["role"], "content": new_content}
|
||||
|
||||
return message
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
"""Add an image message to the context.
|
||||
|
||||
Args:
|
||||
format: The image format (e.g., 'RGB', 'RGBA').
|
||||
size: The image dimensions as (width, height).
|
||||
image: The raw image data as bytes.
|
||||
text: Optional text to accompany the image.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# Image should be the first content block in the message
|
||||
content = [{"type": "image", "format": "jpeg", "source": {"bytes": encoded_image}}]
|
||||
if text:
|
||||
content.append({"text": text})
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
def add_message(self, message):
|
||||
"""Add a message to the context, merging with previous message if same role.
|
||||
|
||||
AWS Bedrock requires alternating roles, so consecutive messages from the
|
||||
same role are merged together.
|
||||
|
||||
Args:
|
||||
message: The message to add to the context.
|
||||
"""
|
||||
try:
|
||||
if self.messages:
|
||||
# AWS Bedrock requires that roles alternate. If this message's
|
||||
# role is the same as the last message, we should add this
|
||||
# message's content to the last message.
|
||||
if self.messages[-1]["role"] == message["role"]:
|
||||
# if the last message has just a content string, convert it to a list
|
||||
# in the proper format
|
||||
if isinstance(self.messages[-1]["content"], str):
|
||||
self.messages[-1]["content"] = [{"text": self.messages[-1]["content"]}]
|
||||
# if this message has just a content string, convert it to a list
|
||||
# in the proper format
|
||||
if isinstance(message["content"], str):
|
||||
message["content"] = [{"text": message["content"]}]
|
||||
# append the content of this message to the last message
|
||||
self.messages[-1]["content"].extend(message["content"])
|
||||
else:
|
||||
self.messages.append(message)
|
||||
else:
|
||||
self.messages.append(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message: {e}")
|
||||
|
||||
def _restructure_from_bedrock_messages(self):
|
||||
"""Restructure messages in AWS Bedrock format.
|
||||
|
||||
Handles system messages, merging consecutive messages with the same role,
|
||||
and ensuring proper content formatting.
|
||||
"""
|
||||
# Handle system message if present at the beginning
|
||||
if self.messages and self.messages[0]["role"] == "system":
|
||||
if len(self.messages) == 1:
|
||||
self.messages[0]["role"] = "user"
|
||||
else:
|
||||
system_content = self.messages.pop(0)["content"]
|
||||
if isinstance(system_content, str):
|
||||
system_content = [{"text": system_content}]
|
||||
|
||||
if self.system:
|
||||
if isinstance(self.system, str):
|
||||
self.system = [{"text": self.system}]
|
||||
self.system.extend(system_content)
|
||||
else:
|
||||
self.system = system_content
|
||||
|
||||
# Ensure content is properly formatted
|
||||
for msg in self.messages:
|
||||
if isinstance(msg["content"], str):
|
||||
msg["content"] = [{"text": msg["content"]}]
|
||||
elif not msg["content"]:
|
||||
msg["content"] = [{"text": "(empty)"}]
|
||||
elif isinstance(msg["content"], list):
|
||||
for idx, item in enumerate(msg["content"]):
|
||||
if isinstance(item, dict) and "text" in item and item["text"] == "":
|
||||
item["text"] = "(empty)"
|
||||
elif isinstance(item, str) and item == "":
|
||||
msg["content"][idx] = {"text": "(empty)"}
|
||||
|
||||
# Merge consecutive messages with the same role
|
||||
merged_messages = []
|
||||
for msg in self.messages:
|
||||
if merged_messages and merged_messages[-1]["role"] == msg["role"]:
|
||||
merged_messages[-1]["content"].extend(msg["content"])
|
||||
else:
|
||||
merged_messages.append(msg)
|
||||
|
||||
self.messages.clear()
|
||||
self.messages.extend(merged_messages)
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
|
||||
try:
|
||||
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
|
||||
# See if we should pull the system message out of our context.messages list. (For
|
||||
# compatibility with Open AI messages format.)
|
||||
if self.messages and self.messages[0]["role"] == "system":
|
||||
self.system = self.messages[0]["content"]
|
||||
self.messages.pop(0)
|
||||
|
||||
# Merge consecutive messages with the same role.
|
||||
i = 0
|
||||
while i < len(self.messages) - 1:
|
||||
current_message = self.messages[i]
|
||||
next_message = self.messages[i + 1]
|
||||
if current_message["role"] == next_message["role"]:
|
||||
# Convert content to list of dictionaries if it's a string
|
||||
if isinstance(current_message["content"], str):
|
||||
current_message["content"] = [
|
||||
{"type": "text", "text": current_message["content"]}
|
||||
]
|
||||
if isinstance(next_message["content"], str):
|
||||
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
|
||||
# Concatenate the content
|
||||
current_message["content"].extend(next_message["content"])
|
||||
# Remove the next message from the list
|
||||
self.messages.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Avoid empty content in messages
|
||||
for message in self.messages:
|
||||
if isinstance(message["content"], str) and message["content"] == "":
|
||||
message["content"] = "(empty)"
|
||||
elif isinstance(message["content"], list) and len(message["content"]) == 0:
|
||||
message["content"] = [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
def get_messages_for_persistent_storage(self):
|
||||
"""Get messages formatted for persistent storage.
|
||||
|
||||
Returns:
|
||||
List of messages including system message if present.
|
||||
"""
|
||||
messages = super().get_messages_for_persistent_storage()
|
||||
if self.system:
|
||||
messages.insert(0, {"role": "system", "content": self.system})
|
||||
return messages
|
||||
|
||||
def get_messages_for_logging(self) -> List[Dict[str, Any]]:
|
||||
"""Get messages formatted for logging with sensitive data redacted.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item.get("image"):
|
||||
item["image"]["source"]["bytes"] = "..."
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
|
||||
class AWSBedrockUserContextAggregator(LLMUserContextAggregator):
|
||||
"""User context aggregator for AWS Bedrock LLM service.
|
||||
|
||||
Handles aggregation of user messages and frames for AWS Bedrock format.
|
||||
Inherits all functionality from the base LLM user context aggregator.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AWSBedrockUserContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Args:
|
||||
context: The LLM context to aggregate messages into.
|
||||
params: Configuration parameters for the aggregator.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
pass
|
||||
|
||||
|
||||
class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"""Assistant context aggregator for AWS Bedrock LLM service.
|
||||
|
||||
Handles aggregation of assistant responses and function calls for AWS Bedrock
|
||||
format, including tool use and tool result handling.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AWSBedrockAssistantContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Args:
|
||||
context: The LLM context to aggregate messages into.
|
||||
params: Configuration parameters for the aggregator.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
"""Handle function call in progress frame.
|
||||
|
||||
Args:
|
||||
frame: The function call in progress frame to handle.
|
||||
"""
|
||||
# Format tool use according to AWS Bedrock API
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": frame.tool_call_id,
|
||||
"name": frame.function_name,
|
||||
"input": frame.arguments if frame.arguments else {},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": frame.tool_call_id,
|
||||
"content": [{"text": "IN_PROGRESS"}],
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle function call result frame.
|
||||
|
||||
Args:
|
||||
frame: The function call result frame to handle.
|
||||
"""
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result, ensure_ascii=False)
|
||||
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
"""Handle function call cancel frame.
|
||||
|
||||
Args:
|
||||
frame: The function call cancel frame to handle.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if message["role"] == "user":
|
||||
for content in message["content"]:
|
||||
if (
|
||||
isinstance(content, dict)
|
||||
and content.get("toolResult")
|
||||
and content["toolResult"]["toolUseId"] == tool_call_id
|
||||
):
|
||||
content["toolResult"]["content"] = [{"text": result}]
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
"""Handle user image frame.
|
||||
|
||||
Args:
|
||||
frame: The user image frame to handle.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
|
||||
class AWSBedrockLLMService(LLMService):
|
||||
"""AWS Bedrock Large Language Model service implementation.
|
||||
|
||||
@@ -924,7 +262,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
|
||||
async def run_inference(
|
||||
self,
|
||||
context: LLMContext | OpenAILLMContext,
|
||||
context: LLMContext,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
@@ -943,17 +281,12 @@ class AWSBedrockLLMService(LLMService):
|
||||
messages = []
|
||||
system = []
|
||||
effective_instruction = system_instruction or self._settings.system_instruction
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context, system_instruction=effective_instruction
|
||||
)
|
||||
messages = params["messages"]
|
||||
system = params["system"] # [{"text": "system message"}] or None
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context, system_instruction=effective_instruction
|
||||
)
|
||||
messages = params["messages"]
|
||||
system = params["system"] # [{"text": "system message"}] or None
|
||||
|
||||
# Prepare request parameters using the same method as streaming
|
||||
inference_config = self._build_inference_config()
|
||||
@@ -1021,44 +354,6 @@ class AWSBedrockLLMService(LLMService):
|
||||
response = await client.converse_stream(**request_params)
|
||||
return response
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> AWSBedrockContextAggregatorPair:
|
||||
"""Create AWS Bedrock-specific context aggregators.
|
||||
|
||||
Creates a pair of context aggregators optimized for AWS Bedrocks's message
|
||||
format, including support for function calls, tool usage, and image handling.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
AWSBedrockContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
AWSBedrockContextAggregatorPair.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = AWSBedrockLLMContext.from_openai_context(context)
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
user = AWSBedrockUserContextAggregator(context, params=user_params)
|
||||
assistant = AWSBedrockAssistantContextAggregator(context, params=assistant_params)
|
||||
|
||||
return AWSBedrockContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
def _create_no_op_tool(self):
|
||||
"""Create a no-operation tool for AWS Bedrock when tool content exists but no tools are defined.
|
||||
|
||||
@@ -1074,27 +369,15 @@ class AWSBedrockLLMService(LLMService):
|
||||
}
|
||||
}
|
||||
|
||||
def _get_llm_invocation_params(
|
||||
self, context: OpenAILLMContext | LLMContext
|
||||
) -> AWSBedrockLLMInvocationParams:
|
||||
# Universal LLMContext
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context, system_instruction=self._settings.system_instruction
|
||||
)
|
||||
return params
|
||||
|
||||
# AWS Bedrock-specific context
|
||||
return AWSBedrockLLMInvocationParams(
|
||||
system=getattr(context, "system", None),
|
||||
messages=context.messages,
|
||||
tools=context.tools or [],
|
||||
tool_choice=context.tool_choice,
|
||||
def _get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context, system_instruction=self._settings.system_instruction
|
||||
)
|
||||
return params
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: AWSBedrockLLMContext | LLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
# Usage tracking
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
@@ -1173,15 +456,10 @@ class AWSBedrockLLMService(LLMService):
|
||||
request_params["performanceConfig"] = {"latency": self._settings.latency}
|
||||
|
||||
# Log request params with messages redacted for logging
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
context_type_for_logging = "universal"
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
else:
|
||||
context_type_for_logging = "LLM-specific"
|
||||
messages_for_logging = context.get_messages_for_logging()
|
||||
adapter = self.get_llm_adapter()
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from {context_type_for_logging} context [{system}] | {messages_for_logging}"
|
||||
f"{self}: Generating chat from context [{system}] | {messages_for_logging}"
|
||||
)
|
||||
|
||||
async with self._aws_session.client(
|
||||
@@ -1287,14 +565,8 @@ class AWSBedrockLLMService(LLMService):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context)
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
# NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal
|
||||
# LLMContext with it
|
||||
context = AWSBedrockLLMContext.from_messages(frame.messages)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Context management for AWS Nova Sonic LLM service.
|
||||
|
||||
This module provides specialized context aggregators and message handling for AWS Nova Sonic,
|
||||
including conversation history management and role-specific message processing.
|
||||
|
||||
.. deprecated:: 0.0.91
|
||||
AWS Nova Sonic no longer uses types from this module under the hood.
|
||||
It now uses ``LLMContext`` and ``LLMContextAggregatorPair``.
|
||||
Using the new patterns should allow you to not need types from this module.
|
||||
|
||||
BEFORE::
|
||||
|
||||
# Setup
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# Context frame type
|
||||
frame: OpenAILLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: AWSNovaSonicLLMContext
|
||||
# or
|
||||
context: OpenAILLMContext
|
||||
|
||||
AFTER::
|
||||
|
||||
# Setup
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
# Context frame type
|
||||
frame: LLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: LLMContext
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.aws.nova_sonic.context (or "
|
||||
"pipecat.services.aws_nova_sonic.context) are deprecated. \n"
|
||||
"AWS Nova Sonic no longer uses types from this module under the hood. \n"
|
||||
"It now uses `LLMContext` and `LLMContextAggregatorPair`. \n"
|
||||
"Using the new patterns should allow you to not need types from this module.\n\n"
|
||||
"BEFORE:\n"
|
||||
"```\n"
|
||||
"# Setup\n"
|
||||
"context = OpenAILLMContext(messages, tools)\n"
|
||||
"context_aggregator = llm.create_context_aggregator(context)\n\n"
|
||||
"# Context frame type\n"
|
||||
"frame: OpenAILLMContextFrame\n\n"
|
||||
"# Context type\n"
|
||||
"context: AWSNovaSonicLLMContext\n"
|
||||
"# or\n"
|
||||
"context: OpenAILLMContext\n\n"
|
||||
"```\n\n"
|
||||
"AFTER:\n"
|
||||
"```\n"
|
||||
"# Setup\n"
|
||||
"context = LLMContext(messages, tools)\n"
|
||||
"context_aggregator = LLMContextAggregatorPair(context)\n\n"
|
||||
"# Context frame type\n"
|
||||
"frame: LLMContextFrame\n\n"
|
||||
"# Context type\n"
|
||||
"context: LLMContext\n\n"
|
||||
"```",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
DataFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
"""Roles supported in AWS Nova Sonic conversations.
|
||||
|
||||
Parameters:
|
||||
SYSTEM: System-level messages (not used in conversation history).
|
||||
USER: Messages sent by the user.
|
||||
ASSISTANT: Messages sent by the assistant.
|
||||
TOOL: Messages sent by tools (not used in conversation history).
|
||||
"""
|
||||
|
||||
SYSTEM = "SYSTEM"
|
||||
USER = "USER"
|
||||
ASSISTANT = "ASSISTANT"
|
||||
TOOL = "TOOL"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicConversationHistoryMessage:
|
||||
"""A single message in AWS Nova Sonic conversation history.
|
||||
|
||||
Parameters:
|
||||
role: The role of the message sender (USER or ASSISTANT only).
|
||||
text: The text content of the message.
|
||||
"""
|
||||
|
||||
role: Role # only USER and ASSISTANT
|
||||
text: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicConversationHistory:
|
||||
"""Complete conversation history for AWS Nova Sonic initialization.
|
||||
|
||||
Parameters:
|
||||
system_instruction: System-level instruction for the conversation.
|
||||
messages: List of conversation messages between user and assistant.
|
||||
"""
|
||||
|
||||
system_instruction: str = None
|
||||
messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list)
|
||||
|
||||
|
||||
class AWSNovaSonicLLMContext(OpenAILLMContext):
|
||||
"""Specialized LLM context for AWS Nova Sonic service.
|
||||
|
||||
Extends OpenAI context with Nova Sonic-specific message handling,
|
||||
conversation history management, and text buffering capabilities.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AWSNovaSonicLLMContext` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
def __init__(self, messages=None, tools=None, **kwargs):
|
||||
"""Initialize AWS Nova Sonic LLM context.
|
||||
|
||||
Args:
|
||||
messages: Initial messages for the context.
|
||||
tools: Available tools for the context.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
# Super handles deprecation warning
|
||||
super().__init__(messages=messages, tools=tools, **kwargs)
|
||||
self.__setup_local()
|
||||
|
||||
def __setup_local(self, system_instruction: str = ""):
|
||||
self._assistant_text = ""
|
||||
self._user_text = ""
|
||||
self._system_instruction = system_instruction
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_nova_sonic(
|
||||
obj: OpenAILLMContext, system_instruction: str
|
||||
) -> "AWSNovaSonicLLMContext":
|
||||
"""Upgrade an OpenAI context to AWS Nova Sonic context.
|
||||
|
||||
Args:
|
||||
obj: The OpenAI context to upgrade.
|
||||
system_instruction: System instruction for the context.
|
||||
|
||||
Returns:
|
||||
The upgraded AWS Nova Sonic context.
|
||||
"""
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext):
|
||||
obj.__class__ = AWSNovaSonicLLMContext
|
||||
obj.__setup_local(system_instruction)
|
||||
return obj
|
||||
|
||||
# NOTE: this method has the side-effect of updating _system_instruction from messages
|
||||
def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory:
|
||||
"""Get conversation history for initializing AWS Nova Sonic session.
|
||||
|
||||
Processes stored messages and extracts system instruction and conversation
|
||||
history in the format expected by AWS Nova Sonic.
|
||||
|
||||
Returns:
|
||||
Formatted conversation history with system instruction and messages.
|
||||
"""
|
||||
history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction)
|
||||
|
||||
# Bail if there are no messages
|
||||
if not self.messages:
|
||||
return history
|
||||
|
||||
messages = copy.deepcopy(self.messages)
|
||||
|
||||
# If we have a "system" message as our first message, let's pull that out into "instruction"
|
||||
if messages[0].get("role") == "system":
|
||||
system = messages.pop(0)
|
||||
content = system.get("content")
|
||||
if isinstance(content, str):
|
||||
history.system_instruction = content
|
||||
elif isinstance(content, list):
|
||||
history.system_instruction = content[0].get("text")
|
||||
if history.system_instruction:
|
||||
self._system_instruction = history.system_instruction
|
||||
|
||||
# Process remaining messages to fill out conversation history.
|
||||
# Nova Sonic supports "user" and "assistant" messages in history.
|
||||
for message in messages:
|
||||
history_message = self.from_standard_message(message)
|
||||
if history_message:
|
||||
history.messages.append(history_message)
|
||||
|
||||
return history
|
||||
|
||||
def get_messages_for_persistent_storage(self):
|
||||
"""Get messages formatted for persistent storage.
|
||||
|
||||
Returns:
|
||||
List of messages including system instruction if present.
|
||||
"""
|
||||
messages = super().get_messages_for_persistent_storage()
|
||||
# If we have a system instruction and messages doesn't already contain it, add it
|
||||
if self._system_instruction and not (messages and messages[0].get("role") == "system"):
|
||||
messages.insert(0, {"role": "system", "content": self._system_instruction})
|
||||
return messages
|
||||
|
||||
def from_standard_message(self, message) -> AWSNovaSonicConversationHistoryMessage:
|
||||
"""Convert standard message format to Nova Sonic format.
|
||||
|
||||
Args:
|
||||
message: Standard message dictionary to convert.
|
||||
|
||||
Returns:
|
||||
Nova Sonic conversation history message, or None if not convertible.
|
||||
"""
|
||||
role = message.get("role")
|
||||
if message.get("role") == "user" or message.get("role") == "assistant":
|
||||
content = message.get("content")
|
||||
if isinstance(message.get("content"), list):
|
||||
content = ""
|
||||
for c in message.get("content"):
|
||||
if c.get("type") == "text":
|
||||
content += " " + c.get("text")
|
||||
else:
|
||||
logger.error(
|
||||
f"Unhandled content type in context message: {c.get('type')} - {message}"
|
||||
)
|
||||
# There won't be content if this is an assistant tool call entry.
|
||||
# We're ignoring those since they can't be loaded into AWS Nova Sonic conversation
|
||||
# history
|
||||
if content:
|
||||
return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content)
|
||||
# NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova
|
||||
# Sonic conversation history
|
||||
|
||||
def buffer_user_text(self, text):
|
||||
"""Buffer user text for later flushing to context.
|
||||
|
||||
Args:
|
||||
text: User text to buffer.
|
||||
"""
|
||||
self._user_text += f" {text}" if self._user_text else text
|
||||
# logger.debug(f"User text buffered: {self._user_text}")
|
||||
|
||||
def flush_aggregated_user_text(self) -> str:
|
||||
"""Flush buffered user text to context as a complete message.
|
||||
|
||||
Returns:
|
||||
The flushed user text, or empty string if no text was buffered.
|
||||
"""
|
||||
if not self._user_text:
|
||||
return ""
|
||||
user_text = self._user_text
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}],
|
||||
}
|
||||
self._user_text = ""
|
||||
self.add_message(message)
|
||||
# logger.debug(f"Context updated (user): {self.get_messages_for_logging()}")
|
||||
return user_text
|
||||
|
||||
def buffer_assistant_text(self, text):
|
||||
"""Buffer assistant text for later flushing to context.
|
||||
|
||||
Args:
|
||||
text: Assistant text to buffer.
|
||||
"""
|
||||
self._assistant_text += text
|
||||
# logger.debug(f"Assistant text buffered: {self._assistant_text}")
|
||||
|
||||
def flush_aggregated_assistant_text(self):
|
||||
"""Flush buffered assistant text to context as a complete message."""
|
||||
if not self._assistant_text:
|
||||
return
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": self._assistant_text}],
|
||||
}
|
||||
self._assistant_text = ""
|
||||
self.add_message(message)
|
||||
# logger.debug(f"Context updated (assistant): {self.get_messages_for_logging()}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicMessagesUpdateFrame(DataFrame):
|
||||
"""Frame containing updated AWS Nova Sonic context.
|
||||
|
||||
Parameters:
|
||||
context: The updated AWS Nova Sonic LLM context.
|
||||
"""
|
||||
|
||||
context: AWSNovaSonicLLMContext
|
||||
|
||||
|
||||
class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
|
||||
"""Context aggregator for user messages in AWS Nova Sonic conversations.
|
||||
|
||||
Extends the OpenAI user context aggregator to emit Nova Sonic-specific
|
||||
context update frames.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AWSNovaSonicUserContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
async def process_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Process frames and emit Nova Sonic-specific context updates.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction the frame is traveling.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Parent does not push LLMMessagesUpdateFrame
|
||||
if isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self.push_frame(AWSNovaSonicMessagesUpdateFrame(context=self._context))
|
||||
|
||||
|
||||
class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Context aggregator for assistant messages in AWS Nova Sonic conversations.
|
||||
|
||||
Provides specialized handling for assistant responses and function calls
|
||||
in AWS Nova Sonic context, with custom frame processing logic.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AWSNovaSonicAssistantContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with Nova Sonic-specific logic.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction the frame is traveling.
|
||||
"""
|
||||
# HACK: For now, disable the context aggregator by making it just pass through all frames
|
||||
# that the parent handles (except the function call stuff, which we still need).
|
||||
# For an explanation of this hack, see
|
||||
# AWSNovaSonicLLMService._report_assistant_response_text_added.
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
InterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
TextFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
UserImageRawFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
),
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle function call results for AWS Nova Sonic.
|
||||
|
||||
Args:
|
||||
frame: The function call result frame to handle.
|
||||
"""
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the LLM
|
||||
# itself, so we didn't have a chance to add the result to the AWS Nova Sonic server-side
|
||||
# context. Let's push a special frame to do that.
|
||||
await self.push_frame(
|
||||
AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicContextAggregatorPair:
|
||||
"""Pair of user and assistant context aggregators for AWS Nova Sonic.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`AWSNovaSonicContextAggregatorPair` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Parameters:
|
||||
_user: The user context aggregator.
|
||||
_assistant: The assistant context aggregator.
|
||||
"""
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
_user: AWSNovaSonicUserContextAggregator
|
||||
_assistant: AWSNovaSonicAssistantContextAggregator
|
||||
|
||||
def user(self) -> AWSNovaSonicUserContextAggregator:
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> AWSNovaSonicAssistantContextAggregator:
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Custom frames for AWS Nova Sonic LLM service."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicFunctionCallResultFrame(DataFrame):
|
||||
"""Frame containing function call result for AWS Nova Sonic processing.
|
||||
|
||||
This frame wraps a standard function call result frame to enable
|
||||
AWS Nova Sonic-specific handling and context updates.
|
||||
|
||||
Parameters:
|
||||
result_frame: The underlying function call result frame.
|
||||
"""
|
||||
|
||||
result_frame: FunctionCallResultFrame
|
||||
@@ -49,15 +49,7 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven
|
||||
@@ -531,12 +523,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
context = (
|
||||
frame.context
|
||||
if isinstance(frame, LLMContextFrame)
|
||||
else LLMContext.from_openai_context(frame.context)
|
||||
)
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
context = frame.context
|
||||
await self._handle_context(context)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
await self._handle_input_audio_frame(frame)
|
||||
@@ -1353,44 +1341,6 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# We're no longer waiting for a trigger transcription
|
||||
self._waiting_for_trigger_transcription = False
|
||||
|
||||
#
|
||||
# context
|
||||
#
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> LLMContextAggregatorPair:
|
||||
"""Create context aggregator pair for managing conversation context.
|
||||
|
||||
NOTE: this method exists only for backward compatibility. New code
|
||||
should instead do::
|
||||
|
||||
context = LLMContext(...)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
Args:
|
||||
context: The OpenAI LLM context.
|
||||
user_params: Parameters for the user context aggregator.
|
||||
assistant_params: Parameters for the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
A pair of user and assistant context aggregators.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
# from_openai_context handles deprecation warning
|
||||
context = LLMContext.from_openai_context(context)
|
||||
return LLMContextAggregatorPair(
|
||||
context, user_params=user_params, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
#
|
||||
# assistant response trigger
|
||||
# HACK: only needed for the older Nova Sonic (as opposed to Nova 2 Sonic) model
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Context management for AWS Nova Sonic LLM service.
|
||||
|
||||
This module provides specialized context aggregators and message handling for AWS Nova Sonic,
|
||||
including conversation history management and role-specific message processing.
|
||||
|
||||
.. deprecated:: 0.0.91
|
||||
AWS Nova Sonic no longer uses types from this module under the hood.
|
||||
It now uses `LLMContext` and `LLMContextAggregatorPair`.
|
||||
Using the new patterns should allow you to not need types from this module.
|
||||
|
||||
See deprecation warning in pipecat.services.aws.nova_sonic.context for more
|
||||
details.
|
||||
"""
|
||||
|
||||
from pipecat.services.aws.nova_sonic.context import *
|
||||
@@ -1,21 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Custom frames for AWS Nova Sonic LLM service."""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.services.aws.nova_sonic.frames import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.aws_nova_sonic.frames are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.aws.nova_sonic.frames instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,7 +0,0 @@
|
||||
from .file_api import GeminiFileAPI
|
||||
from .gemini import GeminiMultimodalLiveLLMService
|
||||
|
||||
__all__ = [
|
||||
"GeminiFileAPI",
|
||||
"GeminiMultimodalLiveLLMService",
|
||||
]
|
||||
@@ -1,44 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Event models and utilities for Google Gemini Multimodal Live API.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
Importing StartSensitivity and EndSensitivity from this module is deprecated.
|
||||
Import them directly from google.genai.types instead.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from google.genai.types import (
|
||||
EndSensitivity as _EndSensitivity,
|
||||
)
|
||||
from google.genai.types import (
|
||||
StartSensitivity as _StartSensitivity,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
# These aliases are just here for backward compatibility, since we used to
|
||||
# define public-facing StartSensitivity and EndSensitivity enums in this
|
||||
# module.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Importing StartSensitivity and EndSensitivity from "
|
||||
"pipecat.services.gemini_multimodal_live.events is deprecated. "
|
||||
"Please import them directly from google.genai.types instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
StartSensitivity = _StartSensitivity
|
||||
EndSensitivity = _EndSensitivity
|
||||
@@ -1,39 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Gemini File API client for uploading and managing files.
|
||||
|
||||
This module provides a client for Google's Gemini File API, enabling file
|
||||
uploads, metadata retrieval, listing, and deletion. Files uploaded through
|
||||
this API can be referenced in Gemini generative model calls.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
Importing GeminiFileAPI from this module is deprecated.
|
||||
Import it from pipecat.services.google.gemini_live.file_api instead.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from pipecat.services.google.gemini_live.file_api import GeminiFileAPI as _GeminiFileAPI
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
# These aliases are just here for backward compatibility, since we used to
|
||||
# define public-facing StartSensitivity and EndSensitivity enums in this
|
||||
# module.
|
||||
warnings.warn(
|
||||
"Importing GeminiFileAPI from "
|
||||
"pipecat.services.gemini_multimodal_live.file_api is deprecated. "
|
||||
"Please import it from pipecat.services.google.gemini_live.file_api instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
GeminiFileAPI = _GeminiFileAPI
|
||||
@@ -1,57 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Google Gemini Live API service implementation.
|
||||
|
||||
This module provides real-time conversational AI capabilities using Google's
|
||||
Gemini Live API, supporting both text and audio modalities with
|
||||
voice transcription, streaming responses, and tool usage.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
This module is deprecated. Please use the equivalent types from
|
||||
pipecat.services.google.gemini_live.llm instead. Note that the new type names
|
||||
do not include 'Multimodal'.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.services.google.gemini_live.llm import (
|
||||
ContextWindowCompressionParams as _ContextWindowCompressionParams,
|
||||
)
|
||||
from pipecat.services.google.gemini_live.llm import (
|
||||
GeminiLiveAssistantContextAggregator,
|
||||
GeminiLiveContext,
|
||||
GeminiLiveContextAggregatorPair,
|
||||
GeminiLiveLLMService,
|
||||
GeminiLiveUserContextAggregator,
|
||||
GeminiModalities,
|
||||
)
|
||||
from pipecat.services.google.gemini_live.llm import GeminiMediaResolution as _GeminiMediaResolution
|
||||
from pipecat.services.google.gemini_live.llm import GeminiVADParams as _GeminiVADParams
|
||||
from pipecat.services.google.gemini_live.llm import InputParams as _InputParams
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.gemini_multimodal_live.gemini are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.google.gemini_live.llm instead. Note that the new type "
|
||||
"names do not include 'Multimodal' "
|
||||
"(e.g. `GeminiMultimodalLiveLLMService` is now `GeminiLiveLLMService`).",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
GeminiMultimodalLiveContext = GeminiLiveContext
|
||||
GeminiMultimodalLiveUserContextAggregator = GeminiLiveUserContextAggregator
|
||||
GeminiMultimodalLiveAssistantContextAggregator = GeminiLiveAssistantContextAggregator
|
||||
GeminiMultimodalLiveContextAggregatorPair = GeminiLiveContextAggregatorPair
|
||||
GeminiMultimodalModalities = GeminiModalities
|
||||
GeminiMediaResolution = _GeminiMediaResolution
|
||||
GeminiVADParams = _GeminiVADParams
|
||||
ContextWindowCompressionParams = _ContextWindowCompressionParams
|
||||
InputParams = _InputParams
|
||||
GeminiMultimodalLiveLLMService = GeminiLiveLLMService
|
||||
@@ -59,23 +59,11 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.google.frames import LLMSearchOrigin, LLMSearchResponseFrame, LLMSearchResult
|
||||
from pipecat.services.google.utils import update_google_client_http_options
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
@@ -224,274 +212,6 @@ def language_to_gemini_language(language: Language) -> Optional[str]:
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
|
||||
|
||||
|
||||
class GeminiLiveContext(OpenAILLMContext):
|
||||
"""Extended OpenAI context for Gemini Live API.
|
||||
|
||||
Provides Gemini-specific context management including system instruction
|
||||
extraction and message format conversion for the Live API.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
Gemini Live no longer uses `GeminiLiveContext` under the hood.
|
||||
It now uses `LLMContext`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def upgrade(obj: OpenAILLMContext) -> "GeminiLiveContext":
|
||||
"""Upgrade an OpenAI context to Gemini context.
|
||||
|
||||
Args:
|
||||
obj: The OpenAI context to upgrade.
|
||||
|
||||
Returns:
|
||||
The upgraded Gemini context instance.
|
||||
"""
|
||||
# This warning is here rather than `__init__` since `upgrade()` was the
|
||||
# "main" way that GeminiLiveContext instances were created.
|
||||
# Almost no users should be seeing this message anyway, as
|
||||
# GeminiLiveContext instances were typically created under the hood:
|
||||
# the user would pass an OpenAILLMContext instance, which would be
|
||||
# upgraded without them necessarily knowing.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"GeminiLiveContext is deprecated. "
|
||||
"Gemini Live no longer uses GeminiLiveContext under the hood. "
|
||||
"It now uses LLMContext.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GeminiLiveContext):
|
||||
logger.debug(f"Upgrading to Gemini Live Context: {obj}")
|
||||
obj.__class__ = GeminiLiveContext
|
||||
obj._restructure_from_openai_messages()
|
||||
return obj
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
pass
|
||||
|
||||
def extract_system_instructions(self):
|
||||
"""Extract system instructions from context messages.
|
||||
|
||||
Returns:
|
||||
Combined system instruction text from all system messages.
|
||||
"""
|
||||
system_instruction = ""
|
||||
for item in self.messages:
|
||||
if item.get("role") == "system":
|
||||
content = item.get("content", "")
|
||||
if content:
|
||||
if system_instruction and not system_instruction.endswith("\n"):
|
||||
system_instruction += "\n"
|
||||
system_instruction += str(content)
|
||||
return system_instruction
|
||||
|
||||
def add_file_reference(self, file_uri: str, mime_type: str, text: Optional[str] = None):
|
||||
"""Add a file reference to the context.
|
||||
|
||||
This adds a user message with a file reference that will be sent during context initialization.
|
||||
|
||||
Args:
|
||||
file_uri: URI of the uploaded file
|
||||
mime_type: MIME type of the file
|
||||
text: Optional text prompt to accompany the file
|
||||
"""
|
||||
# Create parts list with file reference
|
||||
parts = []
|
||||
if text:
|
||||
parts.append({"type": "text", "text": text})
|
||||
|
||||
# Add file reference part
|
||||
parts.append(
|
||||
{"type": "file_data", "file_data": {"mime_type": mime_type, "file_uri": file_uri}}
|
||||
)
|
||||
|
||||
# Add to messages
|
||||
message = {"role": "user", "content": parts}
|
||||
self.messages.append(message)
|
||||
logger.info(f"Added file reference to context: {file_uri}")
|
||||
|
||||
def get_messages_for_initializing_history(self) -> List[Content]:
|
||||
"""Get messages formatted for Gemini history initialization.
|
||||
|
||||
Returns:
|
||||
List of messages in Gemini format for conversation history.
|
||||
"""
|
||||
messages: List[Content] = []
|
||||
for item in self.messages:
|
||||
role = item.get("role")
|
||||
|
||||
if role == "system":
|
||||
continue
|
||||
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
content = item.get("content")
|
||||
parts: List[Part] = []
|
||||
if isinstance(content, str):
|
||||
parts = [Part(text=content)]
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if part.get("type") == "text":
|
||||
parts.append(Part(text=part.get("text")))
|
||||
elif part.get("type") == "file_data":
|
||||
file_data = part.get("file_data", {})
|
||||
parts.append(
|
||||
Part(
|
||||
file_data=FileData(
|
||||
mime_type=file_data.get("mime_type"),
|
||||
file_uri=file_data.get("file_uri"),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unsupported content type: {str(part)[:80]}")
|
||||
else:
|
||||
logger.warning(f"Unsupported content type: {str(content)[:80]}")
|
||||
messages.append(Content(role=role, parts=parts))
|
||||
return messages
|
||||
|
||||
|
||||
class GeminiLiveUserContextAggregator(OpenAIUserContextAggregator):
|
||||
"""User context aggregator for Gemini Live.
|
||||
|
||||
Extends OpenAI user aggregator to handle Gemini-specific message passing
|
||||
while maintaining compatibility with the standard aggregation pipeline.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
Gemini Live no longer expects a `GeminiLiveUserContextAggregator`.
|
||||
It now expects a `LLMUserAggregator`.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize Gemini Live user context aggregator."""
|
||||
# Almost no users should be seeing this message, as
|
||||
# `GeminiLiveUserContextAggregator`` instances were typically created
|
||||
# under the hood, as part of `llm.create_context_aggregator()`.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"GeminiLiveUserContextAggregator is deprecated. "
|
||||
"Gemini Live no longer expects a GeminiLiveUserContextAggregator. "
|
||||
"It now expects a LLMUserAggregator.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
"""Process incoming frames for user context aggregation.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The frame processing direction.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
# kind of a hack just to pass the LLMMessagesAppendFrame through, but it's fine for now
|
||||
if isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class GeminiLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Assistant context aggregator for Gemini Live.
|
||||
|
||||
Handles assistant response aggregation while filtering out LLMTextFrames
|
||||
to prevent duplicate context entries, as Gemini Live pushes both
|
||||
LLMTextFrames and TTSTextFrames.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
Gemini Live no longer uses `GeminiLiveAssistantContextAggregator` under the hood.
|
||||
It now uses `LLMAssistantAggregator`.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize Gemini Live assistant context aggregator."""
|
||||
# Almost no users should be seeing this message, as
|
||||
# `GeminiLiveAssistantContextAggregator` instances were typically
|
||||
# created under the hood, as part of `llm.create_context_aggregator()`.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"GeminiLiveAssistantContextAggregator is deprecated. "
|
||||
"Gemini Live no longer uses GeminiLiveAssistantContextAggregator under the hood. "
|
||||
"It now uses LLMAssistantAggregator.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames for assistant context aggregation.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The frame processing direction.
|
||||
"""
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the GeminiLiveAssistantContextAggregator pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
if not isinstance(frame, LLMTextFrame):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
"""Handle user image frames.
|
||||
|
||||
Args:
|
||||
frame: The user image frame to handle.
|
||||
"""
|
||||
# We don't want to store any images in the context. Revisit this later
|
||||
# when the API evolves.
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeminiLiveContextAggregatorPair:
|
||||
"""Pair of user and assistant context aggregators for Gemini Live.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
`GeminiLiveContextAggregatorPair` is deprecated.
|
||||
Use `LLMContextAggregatorPair` instead.
|
||||
|
||||
Parameters:
|
||||
_user: The user context aggregator instance.
|
||||
_assistant: The assistant context aggregator instance.
|
||||
"""
|
||||
|
||||
_user: GeminiLiveUserContextAggregator
|
||||
_assistant: GeminiLiveAssistantContextAggregator
|
||||
|
||||
def __post_init__(self):
|
||||
# Almost no users should be seeing this message, as
|
||||
# `GeminiLiveContextAggregatorPair` instances were typically created
|
||||
# under the hood, with `llm.create_context_aggregator()`.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"GeminiLiveContextAggregatorPair is deprecated. "
|
||||
"Use LLMContextAggregatorPair instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def user(self) -> GeminiLiveUserContextAggregator:
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> GeminiLiveAssistantContextAggregator:
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class GeminiModalities(Enum):
|
||||
"""Supported modalities for Gemini Live.
|
||||
|
||||
@@ -945,23 +665,6 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._settings.language = self._language_code
|
||||
logger.info(f"Set Gemini language to: {self._language_code}")
|
||||
|
||||
async def set_context(self, context: OpenAILLMContext):
|
||||
"""Set the context explicitly from outside the pipeline.
|
||||
|
||||
This is useful when initializing a conversation because in server-side VAD mode we might not have a
|
||||
way to trigger the pipeline. This sends the history to the server. The `inference_on_context_initialization`
|
||||
flag controls whether to set the turnComplete flag when we do this. Without that flag, the model will
|
||||
not respond. This is often what we want when setting the context at the beginning of a conversation.
|
||||
|
||||
Args:
|
||||
context: The OpenAI LLM context to set.
|
||||
"""
|
||||
if self._context:
|
||||
logger.error("Context already set. Can only set up Gemini Live context once.")
|
||||
return
|
||||
self._context = GeminiLiveContext.upgrade(context)
|
||||
await self._create_initial_response()
|
||||
|
||||
#
|
||||
# standard AIService frame handling
|
||||
#
|
||||
@@ -1053,13 +756,8 @@ class GeminiLiveLLMService(LLMService):
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
context = (
|
||||
frame.context
|
||||
if isinstance(frame, LLMContextFrame)
|
||||
else LLMContext.from_openai_context(frame.context)
|
||||
)
|
||||
await self._handle_context(context)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
await self._handle_context(frame.context)
|
||||
elif isinstance(frame, InputTextRawFrame):
|
||||
await self._send_user_text(frame.text)
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -2078,40 +1776,3 @@ class GeminiLiveLLMService(LLMService):
|
||||
# cost/stability implications for a service cluster, let's just treat a
|
||||
# send-side error as fatal.
|
||||
await self.push_error(error_msg=f"Send error: {error}")
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> LLMContextAggregatorPair:
|
||||
"""Create an instance of GeminiLiveContextAggregatorPair from an OpenAILLMContext.
|
||||
|
||||
Constructor keyword arguments for both the user and assistant aggregators can be provided.
|
||||
|
||||
NOTE: this method exists only for backward compatibility. New code
|
||||
should instead do::
|
||||
|
||||
context = LLMContext(...)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
Args:
|
||||
context: The LLM context to use.
|
||||
user_params: User aggregator parameters. Defaults to LLMUserAggregatorParams().
|
||||
assistant_params: Assistant aggregator parameters. Defaults to LLMAssistantAggregatorParams().
|
||||
|
||||
Returns:
|
||||
A pair of user and assistant context aggregators.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
# from_openai_context handles deprecation warning
|
||||
context = LLMContext.from_openai_context(context)
|
||||
assistant_params.expect_stripped_words = False
|
||||
return LLMContextAggregatorPair(
|
||||
context, user_params=user_params, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
@@ -34,29 +34,16 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||
from pipecat.services.google.utils import update_google_client_http_options
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
from pipecat.services.settings import (
|
||||
NOT_GIVEN,
|
||||
LLMSettings,
|
||||
@@ -90,595 +77,6 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
"""Google-specific user context aggregator.
|
||||
|
||||
Extends OpenAI user context aggregator to handle Google AI's specific
|
||||
Content and Part message format for user messages.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAIUserContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
async def handle_aggregation(self, aggregation: str):
|
||||
"""Add the aggregated user text to the context as a Google Content message.
|
||||
|
||||
Args:
|
||||
aggregation: The aggregated user text to add as a user message.
|
||||
"""
|
||||
self._context.add_message(Content(role="user", parts=[Part(text=aggregation)]))
|
||||
|
||||
|
||||
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Google-specific assistant context aggregator.
|
||||
|
||||
Extends OpenAI assistant context aggregator to handle Google AI's specific
|
||||
Content and Part message format for assistant responses and function calls.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`GoogleAssistantContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
async def handle_aggregation(self, aggregation: str):
|
||||
"""Handle aggregated assistant text response.
|
||||
|
||||
Args:
|
||||
aggregation: The aggregated text response from the assistant.
|
||||
"""
|
||||
self._context.add_message(Content(role="model", parts=[Part(text=aggregation)]))
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
"""Handle function call in progress frame.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call details.
|
||||
"""
|
||||
self._context.add_message(
|
||||
Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
self._context.add_message(
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
id=frame.tool_call_id,
|
||||
name=frame.function_name,
|
||||
response={"response": "IN_PROGRESS"},
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle function call result frame.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call result.
|
||||
"""
|
||||
if frame.result:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, frame.result
|
||||
)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
"""Handle function call cancellation frame.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call cancellation details.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if message.role == "user":
|
||||
for part in message.parts:
|
||||
if part.function_response and part.function_response.id == tool_call_id:
|
||||
part.function_response.response = {
|
||||
"value": json.dumps(result, ensure_ascii=False)
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleContextAggregatorPair:
|
||||
"""Pair of Google context aggregators for user and assistant messages.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`GoogleContextAggregatorPair` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Parameters:
|
||||
_user: User context aggregator for handling user messages.
|
||||
_assistant: Assistant context aggregator for handling assistant responses.
|
||||
"""
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
_user: GoogleUserContextAggregator
|
||||
_assistant: GoogleAssistantContextAggregator
|
||||
|
||||
def user(self) -> GoogleUserContextAggregator:
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> GoogleAssistantContextAggregator:
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class GoogleLLMContext(OpenAILLMContext):
|
||||
"""Google AI LLM context that extends OpenAI context for Google-specific formatting.
|
||||
|
||||
This class handles conversion between OpenAI-style messages and Google AI's
|
||||
Content/Part format, including system messages, function calls, and media.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`GoogleLLMContext` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[dict]] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize GoogleLLMContext.
|
||||
|
||||
Args:
|
||||
messages: Initial messages in OpenAI format.
|
||||
tools: Available tools/functions for the model.
|
||||
tool_choice: Tool choice configuration.
|
||||
"""
|
||||
# Super handles deprecation warning
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self.system_message = None
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext":
|
||||
"""Upgrade an OpenAI context to a Google context.
|
||||
|
||||
Args:
|
||||
obj: OpenAI LLM context to upgrade.
|
||||
|
||||
Returns:
|
||||
GoogleLLMContext instance with converted messages.
|
||||
"""
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext):
|
||||
logger.debug(f"Upgrading to Google: {obj}")
|
||||
obj.__class__ = GoogleLLMContext
|
||||
obj._restructure_from_openai_messages()
|
||||
return obj
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
"""Set messages and restructure them for Google format.
|
||||
|
||||
Args:
|
||||
messages: List of messages to set.
|
||||
"""
|
||||
self._messages[:] = messages
|
||||
self._restructure_from_openai_messages()
|
||||
|
||||
def add_messages(self, messages: List):
|
||||
"""Add messages to the context, converting to Google format as needed.
|
||||
|
||||
Args:
|
||||
messages: List of messages to add (can be mixed formats).
|
||||
"""
|
||||
# Convert each message individually
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, Content):
|
||||
# Already in Gemini format
|
||||
converted_messages.append(msg)
|
||||
else:
|
||||
# Convert from standard format to Gemini format
|
||||
converted = self.from_standard_message(msg)
|
||||
if converted is not None:
|
||||
converted_messages.append(converted)
|
||||
|
||||
# Add the converted messages to our existing messages
|
||||
self._messages.extend(converted_messages)
|
||||
|
||||
def get_messages_for_logging(self) -> List[Dict[str, Any]]:
|
||||
"""Get messages formatted for logging with sensitive data redacted.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
obj = message.to_json_dict()
|
||||
try:
|
||||
if "parts" in obj:
|
||||
for part in obj["parts"]:
|
||||
if "inline_data" in part:
|
||||
part["inline_data"]["data"] = "..."
|
||||
except Exception as e:
|
||||
logger.debug(f"Error: {e}")
|
||||
msgs.append(obj)
|
||||
return msgs
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
"""Add an image message to the context.
|
||||
|
||||
Args:
|
||||
format: Image format (e.g., 'RGB', 'RGBA').
|
||||
size: Image dimensions as (width, height).
|
||||
image: Raw image bytes.
|
||||
text: Optional text to accompany the image.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(Part(text=text))
|
||||
parts.append(Part(inline_data=Blob(mime_type="image/jpeg", data=buffer.getvalue())))
|
||||
|
||||
self.add_message(Content(role="user", parts=parts))
|
||||
|
||||
def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
):
|
||||
"""Add audio frames as a message to the context.
|
||||
|
||||
Args:
|
||||
audio_frames: List of audio frames to add.
|
||||
text: Text description of the audio content.
|
||||
"""
|
||||
if not audio_frames:
|
||||
return
|
||||
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
parts = []
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
# NOTE(aleix): According to the docs only text or inline_data should be needed.
|
||||
# (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
|
||||
parts.append(Part(text=text))
|
||||
parts.append(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="audio/wav",
|
||||
data=(
|
||||
bytes(
|
||||
self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data
|
||||
)
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
self.add_message(Content(role="user", parts=parts))
|
||||
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
|
||||
# self.add_message(message)
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert standard format message to Google Content object.
|
||||
|
||||
Handles conversion of text, images, and function calls to Google's format.
|
||||
System messages are stored separately and return None.
|
||||
|
||||
Args:
|
||||
message: Message in standard format.
|
||||
|
||||
Returns:
|
||||
Content object with role and parts, or None for system messages.
|
||||
|
||||
Examples:
|
||||
Standard text message::
|
||||
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello there"
|
||||
}
|
||||
|
||||
Converts to Google Content with::
|
||||
|
||||
Content(
|
||||
role="user",
|
||||
parts=[Part(text="Hello there")]
|
||||
)
|
||||
|
||||
Standard function call message::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"query": "test"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Converts to Google Content with::
|
||||
|
||||
Content(
|
||||
role="model",
|
||||
parts=[Part(function_call=FunctionCall(name="search", args={"query": "test"}))]
|
||||
)
|
||||
|
||||
System message returns None and stores content in self.system_message.
|
||||
"""
|
||||
role = message["role"]
|
||||
content = message.get("content", [])
|
||||
if role == "system":
|
||||
# System instructions are returned as plain text
|
||||
if isinstance(content, str):
|
||||
self.system_message = content
|
||||
elif isinstance(content, list):
|
||||
# If content is a list, we assume it's a list of text parts, per the standard
|
||||
self.system_message = " ".join(
|
||||
part["text"] for part in content if part.get("type") == "text"
|
||||
)
|
||||
return None
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
parts.append(
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
args=json.loads(tc["function"]["arguments"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
try:
|
||||
response = json.loads(message["content"])
|
||||
if isinstance(response, dict):
|
||||
response_dict = response
|
||||
else:
|
||||
response_dict = {"value": response}
|
||||
except Exception as e:
|
||||
# Response might not be JSON-deserializable (e.g. plain text).
|
||||
response_dict = {"value": message["content"]}
|
||||
parts.append(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name="tool_call_result", # seems to work to hard-code the same name every time
|
||||
response=response_dict,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(content, str):
|
||||
parts.append(Part(text=content))
|
||||
elif isinstance(content, list):
|
||||
for c in content:
|
||||
if c["type"] == "text":
|
||||
parts.append(Part(text=c["text"]))
|
||||
elif c["type"] == "image_url":
|
||||
# Extract MIME type from data URL (format: "data:image/jpeg;base64,...")
|
||||
url = c["image_url"]["url"]
|
||||
mime_type = (
|
||||
url.split(":")[1].split(";")[0] if url.startswith("data:") else "image/jpeg"
|
||||
)
|
||||
parts.append(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type=mime_type,
|
||||
data=base64.b64decode(url.split(",")[1]),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
message = Content(role=role, parts=parts)
|
||||
return message
|
||||
|
||||
def to_standard_messages(self, obj) -> list:
|
||||
"""Convert Google Content object to standard structured format.
|
||||
|
||||
Handles text, images, and function calls from Google's Content/Part objects.
|
||||
|
||||
Args:
|
||||
obj: Google Content object with role and parts.
|
||||
|
||||
Returns:
|
||||
List containing a single message in standard format.
|
||||
|
||||
Examples:
|
||||
Google Content with text::
|
||||
|
||||
Content(
|
||||
role="user",
|
||||
parts=[Part(text="Hello")]
|
||||
)
|
||||
|
||||
Converts to::
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
}
|
||||
]
|
||||
|
||||
Google Content with function call::
|
||||
|
||||
Content(
|
||||
role="model",
|
||||
parts=[Part(function_call=FunctionCall(name="search", args={"q": "test"}))]
|
||||
)
|
||||
|
||||
Converts to::
|
||||
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "search",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"q": "test"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
Google Content with image::
|
||||
|
||||
Content(
|
||||
role="user",
|
||||
parts=[Part(inline_data=Blob(mime_type="image/jpeg", data=bytes_data))]
|
||||
)
|
||||
|
||||
Converts to::
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,<encoded_data>"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
msg = {"role": obj.role, "content": []}
|
||||
if msg["role"] == "model":
|
||||
msg["role"] = "assistant"
|
||||
|
||||
for part in obj.parts:
|
||||
if part.text:
|
||||
msg["content"].append({"type": "text", "text": part.text})
|
||||
elif part.inline_data:
|
||||
encoded = base64.b64encode(part.inline_data.data).decode("utf-8")
|
||||
msg["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"},
|
||||
}
|
||||
)
|
||||
elif part.function_call:
|
||||
args = part.function_call.args if hasattr(part.function_call, "args") else {}
|
||||
msg["tool_calls"] = [
|
||||
{
|
||||
"id": part.function_call.name,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": part.function_call.name,
|
||||
"arguments": json.dumps(args),
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
elif part.function_response:
|
||||
msg["role"] = "tool"
|
||||
resp = (
|
||||
part.function_response.response
|
||||
if hasattr(part.function_response, "response")
|
||||
else {}
|
||||
)
|
||||
msg["tool_call_id"] = part.function_response.name
|
||||
msg["content"] = json.dumps(resp)
|
||||
|
||||
# there might be no content parts for tool_calls messages
|
||||
if not msg["content"]:
|
||||
del msg["content"]
|
||||
return [msg]
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
"""Restructures messages to ensure proper Google format and message ordering.
|
||||
|
||||
This method handles conversion of OpenAI-formatted messages to Google format,
|
||||
with special handling for function calls, function responses, and system messages.
|
||||
System messages are added back to the context as user messages when needed.
|
||||
|
||||
The final message order is preserved as:
|
||||
1. Function calls (from model)
|
||||
2. Function responses (from user)
|
||||
3. Text messages (converted from system messages)
|
||||
|
||||
Note:
|
||||
System messages are only added back when there are no regular text
|
||||
messages in the context, ensuring proper conversation continuity
|
||||
after function calls.
|
||||
"""
|
||||
self.system_message = None
|
||||
converted_messages = []
|
||||
|
||||
# Process each message, preserving Google-formatted messages and converting others
|
||||
for message in self._messages:
|
||||
if isinstance(message, Content):
|
||||
# Keep existing Google-formatted messages (e.g., function calls/responses)
|
||||
converted_messages.append(message)
|
||||
continue
|
||||
|
||||
# Convert OpenAI format to Google format, system messages return None
|
||||
converted = self.from_standard_message(message)
|
||||
if converted is not None:
|
||||
converted_messages.append(converted)
|
||||
|
||||
# Update message list
|
||||
self._messages[:] = converted_messages
|
||||
|
||||
# Check if we only have function-related messages (no regular text)
|
||||
has_regular_messages = any(
|
||||
len(msg.parts) == 1
|
||||
and getattr(msg.parts[0], "text", None)
|
||||
and not getattr(msg.parts[0], "function_call", None)
|
||||
and not getattr(msg.parts[0], "function_response", None)
|
||||
for msg in self._messages
|
||||
)
|
||||
|
||||
# Add system message back as a user message if we only have function messages
|
||||
if self.system_message and not has_regular_messages:
|
||||
self._messages.append(Content(role="user", parts=[Part(text=self.system_message)]))
|
||||
|
||||
# Remove any empty messages
|
||||
self._messages = [m for m in self._messages if m.parts]
|
||||
|
||||
|
||||
class GoogleThinkingConfig(BaseModel):
|
||||
"""Configuration for controlling the model's internal "thinking" process used before generating a response.
|
||||
|
||||
@@ -741,8 +139,7 @@ class GoogleLLMService(LLMService):
|
||||
"""Google AI (Gemini) LLM service implementation.
|
||||
|
||||
This class implements inference with Google's AI models, translating internally
|
||||
from an OpenAILLMContext or a universal LLMContext to the messages format
|
||||
expected by the Google AI model.
|
||||
from an LLMContext to the messages format expected by the Google AI model.
|
||||
"""
|
||||
|
||||
Settings = GoogleLLMSettings
|
||||
@@ -885,7 +282,7 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
async def run_inference(
|
||||
self,
|
||||
context: LLMContext | OpenAILLMContext,
|
||||
context: LLMContext,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
@@ -905,19 +302,13 @@ class GoogleLLMService(LLMService):
|
||||
system = []
|
||||
tools = []
|
||||
effective_instruction = system_instruction or self._settings.system_instruction
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context, system_instruction=effective_instruction
|
||||
)
|
||||
messages = params["messages"]
|
||||
system = params["system_instruction"]
|
||||
tools = params["tools"]
|
||||
else:
|
||||
context = GoogleLLMContext.upgrade_to_google(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system_message", None)
|
||||
tools = context.tools or []
|
||||
adapter = self.get_llm_adapter()
|
||||
params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context, system_instruction=effective_instruction
|
||||
)
|
||||
messages = params["messages"]
|
||||
system = params["system_instruction"]
|
||||
tools = params["tools"]
|
||||
|
||||
# Build generation config using the same method as streaming
|
||||
generation_params = self._build_generation_params(
|
||||
@@ -1004,17 +395,24 @@ class GoogleLLMService(LLMService):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to unset thinking budget: {e}")
|
||||
|
||||
async def _stream_content(
|
||||
self, params_from_context: GeminiLLMInvocationParams
|
||||
) -> AsyncIterator[GenerateContentResponse]:
|
||||
messages = params_from_context["messages"]
|
||||
async def _stream_content(self, context: LLMContext) -> AsyncIterator[GenerateContentResponse]:
|
||||
adapter = self.get_llm_adapter()
|
||||
params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context, system_instruction=self._settings.system_instruction
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from context [{params['system_instruction']}] | {adapter.get_messages_for_logging(context)}"
|
||||
)
|
||||
|
||||
messages = params["messages"]
|
||||
|
||||
# The adapter already resolved system_instruction vs context system message.
|
||||
system_instruction = params_from_context["system_instruction"]
|
||||
system_instruction = params["system_instruction"]
|
||||
|
||||
tools = []
|
||||
if params_from_context["tools"]:
|
||||
tools = params_from_context["tools"]
|
||||
if params["tools"]:
|
||||
tools = params["tools"]
|
||||
elif self._tools:
|
||||
tools = self._tools
|
||||
tool_config = None
|
||||
@@ -1040,37 +438,8 @@ class GoogleLLMService(LLMService):
|
||||
config=generation_config,
|
||||
)
|
||||
|
||||
async def _stream_content_specific_context(
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncIterator[GenerateContentResponse]:
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from LLM-specific context [{context.system_message}] | {context.get_messages_for_logging()}"
|
||||
)
|
||||
|
||||
params = GeminiLLMInvocationParams(
|
||||
messages=context.messages,
|
||||
system_instruction=context.system_message,
|
||||
tools=context.tools,
|
||||
)
|
||||
|
||||
return await self._stream_content(params)
|
||||
|
||||
async def _stream_content_universal_context(
|
||||
self, context: LLMContext
|
||||
) -> AsyncIterator[GenerateContentResponse]:
|
||||
adapter = self.get_llm_adapter()
|
||||
params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context, system_instruction=self._settings.system_instruction
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from universal context [{params['system_instruction']}] | {adapter.get_messages_for_logging(context)}"
|
||||
)
|
||||
|
||||
return await self._stream_content(params)
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
prompt_tokens = 0
|
||||
@@ -1083,12 +452,8 @@ class GoogleLLMService(LLMService):
|
||||
accumulated_text = ""
|
||||
|
||||
try:
|
||||
# Generate content using either OpenAILLMContext or universal LLMContext
|
||||
response = await (
|
||||
self._stream_content_specific_context(context)
|
||||
if isinstance(context, OpenAILLMContext)
|
||||
else self._stream_content_universal_context(context)
|
||||
)
|
||||
# Generate content from LLMContext
|
||||
response = await self._stream_content(context)
|
||||
|
||||
function_calls = []
|
||||
async for chunk in response:
|
||||
@@ -1274,15 +639,8 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
context = None
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = GoogleLLMContext.upgrade_to_google(frame.context)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
# Handle universal (LLM-agnostic) LLM context frames
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
# NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal
|
||||
# LLMContext with it
|
||||
context = GoogleLLMContext(frame.messages)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -1305,41 +663,3 @@ class GoogleLLMService(LLMService):
|
||||
except Exception:
|
||||
# Do nothing - we're shutting down anyway
|
||||
pass
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> GoogleContextAggregatorPair:
|
||||
"""Create Google-specific context aggregators.
|
||||
|
||||
Creates a pair of context aggregators optimized for Google's message format,
|
||||
including support for function calls, tool usage, and image handling.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
GoogleContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
GoogleContextAggregatorPair.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = GoogleLLMContext.upgrade_to_google(context)
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
user = GoogleUserContextAggregator(context, params=user_params)
|
||||
assistant = GoogleAssistantContextAggregator(context, params=assistant_params)
|
||||
|
||||
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deprecated: use ``pipecat.services.google.openai.llm`` instead."""
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Module `pipecat.services.google.llm_openai` is deprecated, "
|
||||
"use `pipecat.services.google.openai.llm` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from pipecat.services.google.openai.llm import * # noqa: E402, F401, F403
|
||||
@@ -1,5 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
@@ -1,213 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Google LLM service using OpenAI-compatible API format.
|
||||
|
||||
This module provides integration with Google's AI LLM models using the OpenAI
|
||||
API format through Google's Gemini API OpenAI compatibility layer.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import LLMTextFrame
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleOpenAILLMSettings(BaseOpenAILLMService.Settings):
|
||||
"""Settings for GoogleLLMOpenAIBetaService."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||
"""Google LLM service using OpenAI-compatible API format.
|
||||
|
||||
This service provides access to Google's AI LLM models (like Gemini) through
|
||||
the OpenAI API format. It handles streaming responses, function calls, and
|
||||
tool usage while maintaining compatibility with OpenAI's interface.
|
||||
|
||||
Note: This service includes a workaround for a Google API bug where function
|
||||
call indices may be incorrectly set to None, resulting in empty function names.
|
||||
|
||||
.. deprecated:: 0.0.82
|
||||
GoogleLLMOpenAIBetaService is deprecated and will be removed in a future version.
|
||||
Use GoogleLLMService instead for better integration with Google's native API.
|
||||
|
||||
Reference:
|
||||
https://ai.google.dev/gemini-api/docs/openai
|
||||
"""
|
||||
|
||||
Settings = GoogleOpenAILLMSettings
|
||||
_settings: Settings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
model: Optional[str] = None,
|
||||
settings: Optional[Settings] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Google LLM service.
|
||||
|
||||
Args:
|
||||
api_key: Google API key for authentication.
|
||||
base_url: Base URL for Google's OpenAI-compatible API.
|
||||
model: Google model name to use (e.g., "gemini-2.0-flash").
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=GoogleLLMOpenAIBetaService.Settings(model=...)`` instead.
|
||||
|
||||
settings: Runtime-updatable settings. When provided alongside deprecated
|
||||
parameters, ``settings`` values take precedence.
|
||||
**kwargs: Additional arguments passed to the parent OpenAILLMService.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"GoogleLLMOpenAIBetaService is deprecated and will be removed in a future version. "
|
||||
"Use GoogleLLMService instead for better integration with Google's native API.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
default_settings = self.Settings(model="gemini-2.0-flash")
|
||||
|
||||
# 2. Apply direct init arg overrides (deprecated)
|
||||
if model is not None:
|
||||
self._warn_init_param_moved_to_settings("model", "model")
|
||||
default_settings.model = model
|
||||
|
||||
# 3. (No step 3, as there's no params object to apply)
|
||||
|
||||
# 4. Apply settings delta (canonical API, always wins)
|
||||
if settings is not None:
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
super().__init__(api_key=api_key, base_url=base_url, settings=default_settings, **kwargs)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
chunk_stream: AsyncStream[
|
||||
ChatCompletionChunk
|
||||
] = await self._stream_chat_completions_specific_context(context)
|
||||
|
||||
# Use context manager to ensure stream is closed on cancellation/exception.
|
||||
# Without this, CancelledError during iteration leaves the underlying socket open.
|
||||
async with chunk_stream:
|
||||
async for chunk in chunk_stream:
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens or 0,
|
||||
completion_tokens=chunk.usage.completion_tokens or 0,
|
||||
total_tokens=chunk.usage.total_tokens or 0,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
# We're streaming the LLM response to enable the fastest response times.
|
||||
# For text, we just yield each chunk as we receive it and count on consumers
|
||||
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
|
||||
#
|
||||
# If the LLM is a function call, we'll do some coalescing here.
|
||||
# If the response contains a function name, we'll yield a frame to tell consumers
|
||||
# that they can start preparing to call the function with that name.
|
||||
# We accumulate all the arguments for the rest of the streamed response, then when
|
||||
# the response is done, we package up all the arguments and the function name and
|
||||
# yield a frame containing the function name and the arguments.
|
||||
logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}")
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
logger.debug(
|
||||
f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}"
|
||||
)
|
||||
|
||||
function_calls = []
|
||||
for function_name, arguments, tool_id in zip(
|
||||
functions_list, arguments_list, tool_id_list
|
||||
):
|
||||
if function_name == "":
|
||||
# TODO: Remove the _process_context method once Google resolves the bug
|
||||
# where the index is incorrectly set to None instead of returning the actual index,
|
||||
# which currently results in an empty function name('').
|
||||
continue
|
||||
|
||||
arguments = json.loads(arguments)
|
||||
|
||||
function_calls.append(
|
||||
FunctionCallFromLLM(
|
||||
context=context,
|
||||
tool_call_id=tool_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
@@ -55,11 +55,6 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMSpecificMessage,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import LLMSettings
|
||||
@@ -110,7 +105,7 @@ class FunctionCallParams:
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
llm: "LLMService"
|
||||
context: OpenAILLMContext | LLMContext
|
||||
context: LLMContext
|
||||
result_callback: FunctionCallResultCallback
|
||||
|
||||
|
||||
@@ -153,7 +148,7 @@ class FunctionCallRunnerItem:
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
context: OpenAILLMContext | LLMContext
|
||||
context: LLMContext
|
||||
run_llm: Optional[bool] = None
|
||||
|
||||
|
||||
@@ -247,7 +242,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
|
||||
async def run_inference(
|
||||
self,
|
||||
context: LLMContext | OpenAILLMContext,
|
||||
context: LLMContext,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
@@ -267,41 +262,6 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
"""
|
||||
raise NotImplementedError(f"run_inference() not supported by {self.__class__.__name__}")
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> Any:
|
||||
"""Create a context aggregator for managing LLM conversation context.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create an aggregator for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
A context aggregator instance.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"create_context_aggregator() is deprecated and will be removed in a future version. "
|
||||
"Use the universal LLMContext and LLMContextAggregatorPair directly instead. "
|
||||
"See OpenAILLMContext docstring for migration guide.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
pass
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the LLM service.
|
||||
|
||||
|
||||
@@ -17,12 +17,8 @@ from typing import Any, Dict, List, Optional
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame, LLMMessagesFrame
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
try:
|
||||
@@ -227,9 +223,7 @@ class Mem0MemoryService(FrameProcessor):
|
||||
logger.error(f"Error retrieving memories from Mem0: {e}")
|
||||
return []
|
||||
|
||||
async def _enhance_context_with_memories(
|
||||
self, context: LLMContext | OpenAILLMContext, query: str
|
||||
):
|
||||
async def _enhance_context_with_memories(self, context: LLMContext, query: str):
|
||||
"""Enhance the LLM context with relevant memories.
|
||||
|
||||
Args:
|
||||
@@ -272,13 +266,9 @@ class Mem0MemoryService(FrameProcessor):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
messages = None
|
||||
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
messages = frame.messages
|
||||
context = LLMContext(messages)
|
||||
|
||||
if context:
|
||||
try:
|
||||
@@ -302,12 +292,8 @@ class Mem0MemoryService(FrameProcessor):
|
||||
# Store the conversation in Mem0 as a background task
|
||||
self.create_task(self._store_messages(messages_to_store), name="mem0_store")
|
||||
|
||||
# If we received an LLMMessagesFrame, create a new one with the enhanced messages
|
||||
if messages is not None:
|
||||
await self.push_frame(LLMMessagesFrame(context.get_messages()))
|
||||
else:
|
||||
# Otherwise, pass the enhanced context frame downstream
|
||||
await self.push_frame(frame)
|
||||
# Pass the enhanced context frame downstream
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Error processing with Mem0: {str(e)}", exception=e
|
||||
|
||||
@@ -15,7 +15,6 @@ from typing import Optional
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
@@ -84,7 +83,7 @@ class NvidiaLLMService(OpenAILLMService):
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
|
||||
@@ -31,15 +31,10 @@ from pipecat.frames.frames import (
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN
|
||||
@@ -61,11 +56,10 @@ class OpenAILLMSettings(LLMSettings):
|
||||
class BaseOpenAILLMService(LLMService):
|
||||
"""Base class for all services that use the AsyncOpenAI client.
|
||||
|
||||
This service consumes OpenAILLMContextFrame or LLMContextFrame frames,
|
||||
which contain a reference to an OpenAILLMContext or LLMContext object. The
|
||||
context defines what is sent to the LLM for completion, including user,
|
||||
assistant, and system messages, as well as tool choices and function call
|
||||
configurations.
|
||||
This service consumes LLMContextFrame frames, which contain a reference to
|
||||
an LLMContext object. The context defines what is sent to the LLM for
|
||||
completion, including user, assistant, and system messages, as well as tool
|
||||
choices and function call configurations.
|
||||
"""
|
||||
|
||||
Settings = OpenAILLMSettings
|
||||
@@ -274,19 +268,27 @@ class BaseOpenAILLMService(LLMService):
|
||||
"""
|
||||
return self._full_model_name
|
||||
|
||||
async def get_chat_completions(
|
||||
self, params_from_context: OpenAILLMInvocationParams
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
async def get_chat_completions(self, context: LLMContext) -> AsyncStream[ChatCompletionChunk]:
|
||||
"""Get streaming chat completions from OpenAI API with optional timeout and retry.
|
||||
|
||||
Args:
|
||||
params_from_context: Parameters, derived from the LLM context, to
|
||||
use for the chat completion. Contains messages, tools, and tool
|
||||
choice.
|
||||
context: Context to use for the chat completion.
|
||||
Contains messages, tools, and tool choice.
|
||||
|
||||
Returns:
|
||||
Async stream of chat completion chunks.
|
||||
"""
|
||||
adapter = self.get_llm_adapter()
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from context {adapter.get_messages_for_logging(context)}"
|
||||
)
|
||||
|
||||
params_from_context: OpenAILLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context,
|
||||
system_instruction=self._settings.system_instruction,
|
||||
convert_developer_to_user=not self.supports_developer_role,
|
||||
)
|
||||
|
||||
params = self.build_chat_completion_params(params_from_context)
|
||||
|
||||
if self._retry_on_timeout:
|
||||
@@ -340,7 +342,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
|
||||
async def run_inference(
|
||||
self,
|
||||
context: LLMContext | OpenAILLMContext,
|
||||
context: LLMContext,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
@@ -357,17 +359,12 @@ class BaseOpenAILLMService(LLMService):
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
effective_instruction = system_instruction or self._settings.system_instruction
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
invocation_params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context,
|
||||
system_instruction=effective_instruction,
|
||||
convert_developer_to_user=not self.supports_developer_role,
|
||||
)
|
||||
else:
|
||||
invocation_params = OpenAILLMInvocationParams(
|
||||
messages=context.messages, tools=context.tools, tool_choice=context.tool_choice
|
||||
)
|
||||
adapter = self.get_llm_adapter()
|
||||
invocation_params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context,
|
||||
system_instruction=effective_instruction,
|
||||
convert_developer_to_user=not self.supports_developer_role,
|
||||
)
|
||||
|
||||
# Build params using the same method as streaming completions
|
||||
params = self.build_chat_completion_params(invocation_params)
|
||||
@@ -389,59 +386,8 @@ class BaseOpenAILLMService(LLMService):
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
async def _stream_chat_completions_specific_context(
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from LLM-specific context {context.get_messages_for_logging()}"
|
||||
)
|
||||
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
|
||||
# base64 encode any images
|
||||
for message in messages:
|
||||
if message.get("mime_type") == "image/jpeg":
|
||||
# Avoid .getvalue() which makes a full copy of BytesIO
|
||||
raw_bytes = message["data"].read()
|
||||
encoded_image = base64.b64encode(raw_bytes).decode("utf-8")
|
||||
text = message.get("content", "")
|
||||
message["content"] = [
|
||||
{"type": "text", "text": text},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
|
||||
},
|
||||
]
|
||||
# Explicit cleanup
|
||||
del message["data"]
|
||||
del message["mime_type"]
|
||||
|
||||
params = OpenAILLMInvocationParams(
|
||||
messages=messages, tools=context.tools, tool_choice=context.tool_choice
|
||||
)
|
||||
chunks = await self.get_chat_completions(params)
|
||||
|
||||
return chunks
|
||||
|
||||
async def _stream_chat_completions_universal_context(
|
||||
self, context: LLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
adapter = self.get_llm_adapter()
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from universal context {adapter.get_messages_for_logging(context)}"
|
||||
)
|
||||
|
||||
params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params(
|
||||
context,
|
||||
system_instruction=self._settings.system_instruction,
|
||||
convert_developer_to_user=not self.supports_developer_role,
|
||||
)
|
||||
chunks = await self.get_chat_completions(params)
|
||||
|
||||
return chunks
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
@@ -452,12 +398,8 @@ class BaseOpenAILLMService(LLMService):
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Generate chat completions using either OpenAILLMContext or universal LLMContext
|
||||
chunk_stream = await (
|
||||
self._stream_chat_completions_specific_context(context)
|
||||
if isinstance(context, OpenAILLMContext)
|
||||
else self._stream_chat_completions_universal_context(context)
|
||||
)
|
||||
# Generate chat completions from LLMContext
|
||||
chunk_stream = await self.get_chat_completions(context)
|
||||
|
||||
# Ensure stream and its async iterator are closed on cancellation/exception
|
||||
# to prevent socket leaks and uvloop crashes. Closing the iterator first
|
||||
@@ -582,9 +524,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for LLM completion requests.
|
||||
|
||||
Handles OpenAILLMContextFrame, LLMContextFrame, LLMMessagesFrame,
|
||||
and LLMUpdateSettingsFrame to trigger LLM completions and manage
|
||||
settings.
|
||||
Handles LLMContextFrame to trigger LLM completions.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
@@ -593,16 +533,8 @@ class BaseOpenAILLMService(LLMService):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
# Handle OpenAI-specific context frames
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
# Handle universal (LLM-agnostic) LLM context frames
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
# NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal
|
||||
# LLMContext with it
|
||||
context = OpenAILLMContext.from_messages(frame.messages)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -18,51 +18,9 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIContextAggregatorPair:
|
||||
"""Pair of OpenAI context aggregators for user and assistant messages.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAIContextAggregatorPair` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Parameters:
|
||||
_user: User context aggregator for processing user messages.
|
||||
_assistant: Assistant context aggregator for processing assistant messages.
|
||||
"""
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
_user: "OpenAIUserContextAggregator"
|
||||
_assistant: "OpenAIAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "OpenAIUserContextAggregator":
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "OpenAIAssistantContextAggregator":
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
class OpenAILLMService(BaseOpenAILLMService):
|
||||
"""OpenAI LLM service implementation.
|
||||
|
||||
@@ -145,161 +103,3 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
super().__init__(service_tier=service_tier, settings=default_settings, **kwargs)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
"""Create OpenAI-specific context aggregators.
|
||||
|
||||
Creates a pair of context aggregators optimized for OpenAI's message format,
|
||||
including support for function calls, tool usage, and image handling.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for user message aggregation.
|
||||
assistant_params: Parameters for assistant message aggregation.
|
||||
|
||||
Returns:
|
||||
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIContextAggregatorPair.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
user = OpenAIUserContextAggregator(context, params=user_params)
|
||||
assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
|
||||
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
"""OpenAI-specific user context aggregator.
|
||||
|
||||
Handles aggregation of user messages for OpenAI LLM services.
|
||||
Inherits all functionality from the base LLMUserContextAggregator.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAIUserContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"""OpenAI-specific assistant context aggregator.
|
||||
|
||||
Handles aggregation of assistant messages for OpenAI LLM services,
|
||||
with specialized support for OpenAI's function calling format,
|
||||
tool usage tracking, and image message handling.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAIAssistantContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
"""Handle a function call in progress.
|
||||
|
||||
Adds the function call to the context with an IN_PROGRESS status
|
||||
to track ongoing function execution.
|
||||
|
||||
Args:
|
||||
frame: Frame containing function call progress information.
|
||||
"""
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle the result of a function call.
|
||||
|
||||
Updates the context with the function call result, replacing any
|
||||
previous IN_PROGRESS status.
|
||||
|
||||
Args:
|
||||
frame: Frame containing the function call result.
|
||||
"""
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result, ensure_ascii=False)
|
||||
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
"""Handle a cancelled function call.
|
||||
|
||||
Updates the context to mark the function call as cancelled.
|
||||
|
||||
Args:
|
||||
frame: Frame containing the function call cancellation information.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if (
|
||||
message["role"] == "tool"
|
||||
and message["tool_call_id"]
|
||||
and message["tool_call_id"] == tool_call_id
|
||||
):
|
||||
message["content"] = result
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
"""Handle a user image frame from a function call request.
|
||||
|
||||
Marks the associated function call as completed and adds the image
|
||||
to the context for processing.
|
||||
|
||||
Args:
|
||||
frame: Frame containing the user image and request context.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Realtime LLM context and aggregator implementations.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
OpenAI Realtime no longer uses types from this module under the hood.
|
||||
It now uses ``LLMContext`` and ``LLMContextAggregatorPair``.
|
||||
Using the new patterns should allow you to not need types from this module.
|
||||
|
||||
BEFORE::
|
||||
|
||||
# Setup
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# Context aggregator type
|
||||
context_aggregator: OpenAIContextAggregatorPair
|
||||
|
||||
# Context frame type
|
||||
frame: OpenAILLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: OpenAIRealtimeLLMContext
|
||||
# or
|
||||
context: OpenAILLMContext
|
||||
|
||||
AFTER::
|
||||
|
||||
# Setup
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
# Context aggregator type
|
||||
context_aggregator: LLMContextAggregatorPair
|
||||
|
||||
# Context frame type
|
||||
frame: LLMContextFrame
|
||||
|
||||
# Context type
|
||||
context: LLMContext
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.openai.realtime.llm (or "
|
||||
"pipecat.services.openai_realtime.llm) are deprecated. \n"
|
||||
"OpenAI Realtime no longer uses types from this module under the hood. \n"
|
||||
"It now uses `LLMContext` and `LLMContextAggregatorPair`. \n"
|
||||
"Using the new patterns should allow you to not need types from this module.\n\n"
|
||||
"BEFORE:\n"
|
||||
"```\n"
|
||||
"# Setup\n"
|
||||
"context = OpenAILLMContext(messages, tools)\n"
|
||||
"context_aggregator = llm.create_context_aggregator(context)\n\n"
|
||||
"# Context aggregator type\n"
|
||||
"context_aggregator: OpenAIContextAggregatorPair\n\n"
|
||||
"# Context frame type\n"
|
||||
"frame: OpenAILLMContextFrame\n\n"
|
||||
"# Context type\n"
|
||||
"context: OpenAIRealtimeLLMContext\n"
|
||||
"# or\n"
|
||||
"context: OpenAILLMContext\n\n"
|
||||
"```\n\n"
|
||||
"AFTER:\n"
|
||||
"```\n"
|
||||
"# Setup\n"
|
||||
"context = LLMContext(messages, tools)\n"
|
||||
"context_aggregator = LLMContextAggregatorPair(context)\n\n"
|
||||
"# Context aggregator type\n"
|
||||
"context_aggregator: LLMContextAggregatorPair\n\n"
|
||||
"# Context frame type\n"
|
||||
"frame: LLMContextFrame\n\n"
|
||||
"# Context type\n"
|
||||
"context: LLMContext\n\n"
|
||||
"```\n",
|
||||
)
|
||||
|
||||
import copy
|
||||
import json
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
from . import events
|
||||
from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
|
||||
|
||||
|
||||
class OpenAIRealtimeLLMContext(OpenAILLMContext):
|
||||
"""OpenAI Realtime LLM context with session management and message conversion.
|
||||
|
||||
Extends the standard OpenAI LLM context to support real-time session properties,
|
||||
instruction management, and conversion between standard message formats and
|
||||
realtime conversation items.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAIRealtimeLLMContext` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
|
||||
def __init__(self, messages=None, tools=None, **kwargs):
|
||||
"""Initialize the OpenAIRealtimeLLMContext.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages. Defaults to None.
|
||||
tools: Available function tools. Defaults to None.
|
||||
**kwargs: Additional arguments passed to parent OpenAILLMContext.
|
||||
"""
|
||||
# Super handles deprecation warning
|
||||
super().__init__(messages=messages, tools=tools, **kwargs)
|
||||
self.__setup_local()
|
||||
|
||||
def __setup_local(self):
|
||||
self.llm_needs_settings_update = True
|
||||
self.llm_needs_initial_messages = True
|
||||
self._session_instructions = ""
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
|
||||
"""Upgrade a standard OpenAI LLM context to a realtime context.
|
||||
|
||||
Args:
|
||||
obj: The OpenAILLMContext instance to upgrade.
|
||||
|
||||
Returns:
|
||||
The upgraded OpenAIRealtimeLLMContext instance.
|
||||
"""
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext):
|
||||
obj.__class__ = OpenAIRealtimeLLMContext
|
||||
obj.__setup_local()
|
||||
return obj
|
||||
|
||||
# todo
|
||||
# - finish implementing all frames
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert a standard message format to a realtime conversation item.
|
||||
|
||||
Args:
|
||||
message: The standard message dictionary to convert.
|
||||
|
||||
Returns:
|
||||
A ConversationItem instance for the realtime API.
|
||||
"""
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if isinstance(message.get("content"), list):
|
||||
content = ""
|
||||
for c in message.get("content"):
|
||||
if c.get("type") == "text":
|
||||
content += " " + c.get("text")
|
||||
else:
|
||||
logger.error(
|
||||
f"Unhandled content type in context message: {c.get('type')} - {message}"
|
||||
)
|
||||
return events.ConversationItem(
|
||||
role="user",
|
||||
type="message",
|
||||
content=[events.ItemContent(type="input_text", text=content)],
|
||||
)
|
||||
if message.get("role") == "assistant" and message.get("tool_calls"):
|
||||
tc = message.get("tool_calls")[0]
|
||||
return events.ConversationItem(
|
||||
type="function_call",
|
||||
call_id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
)
|
||||
logger.error(f"Unhandled message type in from_standard_message: {message}")
|
||||
|
||||
def get_messages_for_initializing_history(self):
|
||||
"""Get conversation items for initializing the realtime session history.
|
||||
|
||||
Converts the context's messages to a format suitable for the realtime API,
|
||||
handling system instructions and conversation history packaging.
|
||||
|
||||
Returns:
|
||||
List of conversation items for session initialization.
|
||||
"""
|
||||
# We can't load a long conversation history into the openai realtime api yet. (The API/model
|
||||
# forgets that it can do audio, if you do a series of `conversation.item.create` calls.) So
|
||||
# our general strategy until this is fixed is just to put everything into a first "user"
|
||||
# message as a single input.
|
||||
if not self.messages:
|
||||
return []
|
||||
|
||||
messages = copy.deepcopy(self.messages)
|
||||
|
||||
# If we have a "system" message as our first message, let's pull that out into session
|
||||
# "instructions"
|
||||
if messages[0].get("role") == "system":
|
||||
self.llm_needs_settings_update = True
|
||||
system = messages.pop(0)
|
||||
content = system.get("content")
|
||||
if isinstance(content, str):
|
||||
self._session_instructions = content
|
||||
elif isinstance(content, list):
|
||||
self._session_instructions = content[0].get("text")
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# If we have just a single "user" item, we can just send it normally
|
||||
if len(messages) == 1 and messages[0].get("role") == "user":
|
||||
return [self.from_standard_message(messages[0])]
|
||||
|
||||
# Otherwise, let's pack everything into a single "user" message with a bit of
|
||||
# explanation for the LLM
|
||||
intro_text = """
|
||||
This is a previously saved conversation. Please treat this conversation history as a
|
||||
starting point for the current conversation."""
|
||||
|
||||
trailing_text = """
|
||||
This is the end of the previously saved conversation. Please continue the conversation
|
||||
from here. If the last message is a user instruction or question, act on that instruction
|
||||
or answer the question. If the last message is an assistant response, simple say that you
|
||||
are ready to continue the conversation."""
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "\n\n".join(
|
||||
[intro_text, json.dumps(messages, indent=2), trailing_text]
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def add_user_content_item_as_message(self, item):
|
||||
"""Add a user content item as a standard message to the context.
|
||||
|
||||
Args:
|
||||
item: The conversation item to add as a user message.
|
||||
"""
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": item.content[0].transcript}],
|
||||
}
|
||||
self.add_message(message)
|
||||
|
||||
|
||||
class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
"""User context aggregator for OpenAI Realtime API.
|
||||
|
||||
Handles user input frames and generates appropriate context updates
|
||||
for the realtime conversation, including message updates and tool settings.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAIRealtimeUserContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Args:
|
||||
context: The OpenAI realtime LLM context.
|
||||
**kwargs: Additional arguments passed to parent aggregator.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
async def process_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Process incoming frames and handle realtime-specific frame types.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
# Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline,
|
||||
# messages are only processed by the user context aggregator, which is generally what we want. But
|
||||
# we also need to send new messages over the websocket, so the openai realtime API has them
|
||||
# in its context.
|
||||
if isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self.push_frame(RealtimeMessagesUpdateFrame(context=self._context))
|
||||
|
||||
# Parent also doesn't push the LLMSetToolsFrame.
|
||||
if isinstance(frame, LLMSetToolsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Push user input aggregation.
|
||||
|
||||
Currently ignores all user input coming into the pipeline as realtime
|
||||
audio input is handled directly by the service.
|
||||
"""
|
||||
# for the moment, ignore all user input coming into the pipeline.
|
||||
# todo: think about whether/how to fix this to allow for text input from
|
||||
# upstream (transport/transcription, or other sources)
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Assistant context aggregator for OpenAI Realtime API.
|
||||
|
||||
Handles assistant output frames from the realtime service, filtering
|
||||
out duplicate text frames and managing function call results.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`OpenAIRealtimeAssistantContextAggregator` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Args:
|
||||
context: The OpenAI realtime LLM context.
|
||||
**kwargs: Additional arguments passed to parent aggregator.
|
||||
"""
|
||||
|
||||
# Super handles deprecation warning
|
||||
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
# OpenAIRealtimeLLMService also pushes TranscriptionFrames and InterimTranscriptionFrames,
|
||||
# so we need to ignore pushing those as well, as they're also TextFrames.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process assistant frames, filtering out duplicate text content.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle function call result and notify the realtime service.
|
||||
|
||||
Args:
|
||||
frame: The function call result frame to handle.
|
||||
"""
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
|
||||
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
|
||||
# special frame to do that.
|
||||
await self.push_frame(
|
||||
RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
@@ -1,58 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Custom frame types for OpenAI Realtime API integration.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
OpenAI Realtime no longer uses types from this module under the hood.
|
||||
|
||||
It now works more like most LLM services in Pipecat, relying on updates to
|
||||
its context, pushed by context aggregators, to update its internal state.
|
||||
|
||||
Listen for ``LLMContextFrame`` s for context updates.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.openai.realtime.frames are deprecated. \n"
|
||||
"OpenAI Realtime no longer uses types from this module under the hood. \n\n"
|
||||
"It now works more like other LLM services in Pipecat, relying on updates to \n"
|
||||
"its context, pushed by context aggregators, to update its internal state.\n\n"
|
||||
"Listen for `LLMContextFrame`s for context updates.\n"
|
||||
)
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.services.openai.realtime.context import OpenAIRealtimeLLMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealtimeMessagesUpdateFrame(DataFrame):
|
||||
"""Frame indicating that the realtime context messages have been updated.
|
||||
|
||||
Parameters:
|
||||
context: The updated OpenAI realtime LLM context.
|
||||
"""
|
||||
|
||||
context: "OpenAIRealtimeLLMContext"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealtimeFunctionCallResultFrame(DataFrame):
|
||||
"""Frame containing function call results for the realtime service.
|
||||
|
||||
Parameters:
|
||||
result_frame: The function call result frame to send to the realtime API.
|
||||
"""
|
||||
|
||||
result_frame: FunctionCallResultFrame
|
||||
@@ -48,15 +48,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.settings import (
|
||||
@@ -564,13 +556,8 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
pass
|
||||
elif isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
context = (
|
||||
frame.context
|
||||
if isinstance(frame, LLMContextFrame)
|
||||
else LLMContext.from_openai_context(frame.context)
|
||||
)
|
||||
await self._handle_context(context)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
await self._handle_context(frame.context)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
if not self._audio_input_paused:
|
||||
await self._send_user_audio(frame)
|
||||
@@ -1133,74 +1120,3 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
output=json.dumps(result, ensure_ascii=False),
|
||||
)
|
||||
await self.send_client_event(events.ConversationItemCreateEvent(item=item))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> LLMContextAggregatorPair:
|
||||
"""Create an instance of OpenAIContextAggregatorPair from an OpenAILLMContext.
|
||||
|
||||
NOTE: this method exists only for backward compatibility. New code
|
||||
should instead do::
|
||||
|
||||
context = LLMContext(...)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
Constructor keyword arguments for both the user and assistant aggregators can be provided.
|
||||
|
||||
Args:
|
||||
context: The LLM context.
|
||||
user_params: User aggregator parameters.
|
||||
assistant_params: Assistant aggregator parameters.
|
||||
|
||||
Returns:
|
||||
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIContextAggregatorPair.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
# Log warning about transcription frame direction change in 0.0.92.
|
||||
# We're putting this warning here rather than in the constructor so
|
||||
# that it shows up for folks who haven't updated their code at all
|
||||
# since 0.0.92, gives them a way to acknowledge and dismiss the
|
||||
# warning, and encourages adoption of a new preferred pattern.
|
||||
logger.warning(
|
||||
"As of version 0.0.92, TranscriptionFrames and InterimTranscriptionFrames "
|
||||
"now go upstream from OpenAIRealtimeLLMService, so if you're using "
|
||||
"TranscriptProcessor, say, you'll want to adjust accordingly:\n\n"
|
||||
"pipeline = Pipeline(\n"
|
||||
" [\n"
|
||||
" transport.input(),\n"
|
||||
" context_aggregator.user(),\n\n"
|
||||
" # BEFORE\n"
|
||||
" llm,\n"
|
||||
" transcript.user(),\n\n"
|
||||
" # AFTER\n"
|
||||
" transcript.user(),\n"
|
||||
" llm,\n\n"
|
||||
" transport.output(),\n"
|
||||
" transcript.assistant(),\n"
|
||||
" context_aggregator.assistant(),\n"
|
||||
" ]\n"
|
||||
")\n\n"
|
||||
"Also, LLMTextFrames are no longer pushed from "
|
||||
"OpenAIRealtimeLLMService when it's configured with "
|
||||
"output_modalities=['audio']. Listen for TTSTextFrames instead.\n\n"
|
||||
"Once you've made the appropriate changes (if needed), you can "
|
||||
"dismiss this warning by updating to the new context-setup pattern:\n\n"
|
||||
" context = LLMContext(messages, tools)\n"
|
||||
" context_aggregator = LLMContextAggregatorPair(context)\n"
|
||||
)
|
||||
# from_openai_context handles deprecation warning already
|
||||
context = LLMContext.from_openai_context(context)
|
||||
assistant_params.expect_stripped_words = False
|
||||
return LLMContextAggregatorPair(
|
||||
context, user_params=user_params, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Realtime LLM context and aggregator implementations.
|
||||
|
||||
.. deprecated:: 0.0.91
|
||||
OpenAI Realtime no longer uses types from this module under the hood.
|
||||
It now uses `LLMContext` and `LLMContextAggregatorPair`.
|
||||
Using the new patterns should allow you to not need types from this module.
|
||||
|
||||
See deprecation warning in pipecat.services.openai.realtime.context for
|
||||
more details.
|
||||
"""
|
||||
|
||||
from pipecat.services.openai.realtime.context import *
|
||||
@@ -1,21 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Custom frame types for OpenAI Realtime API integration."""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.services.openai.realtime.frames import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Types in pipecat.services.openai_realtime.frames are deprecated. "
|
||||
"Please use the equivalent types from "
|
||||
"pipecat.services.openai.realtime.frames instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,19 +0,0 @@
|
||||
from .azure import AzureRealtimeBetaLLMService
|
||||
from .events import (
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
TurnDetection,
|
||||
)
|
||||
from .openai import OpenAIRealtimeBetaLLMService
|
||||
|
||||
__all__ = [
|
||||
"AzureRealtimeBetaLLMService",
|
||||
"InputAudioNoiseReduction",
|
||||
"InputAudioTranscription",
|
||||
"SemanticTurnDetection",
|
||||
"SessionProperties",
|
||||
"TurnDetection",
|
||||
"OpenAIRealtimeBetaLLMService",
|
||||
]
|
||||
@@ -1,94 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Azure OpenAI Realtime Beta LLM service implementation."""
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .openai import OpenAIRealtimeBetaLLMService
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AzureRealtimeBetaLLMSettings(OpenAIRealtimeBetaLLMService.Settings):
|
||||
"""Settings for AzureRealtimeBetaLLMService."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AzureRealtimeBetaLLMService(OpenAIRealtimeBetaLLMService):
|
||||
"""Azure OpenAI Realtime Beta LLM service with Azure-specific authentication.
|
||||
|
||||
.. deprecated:: 0.0.84
|
||||
`AzureRealtimeBetaLLMService` is deprecated, use `AzureRealtimeLLMService` instead.
|
||||
This class will be removed in version 1.0.0.
|
||||
|
||||
Extends the OpenAI Realtime service to work with Azure OpenAI endpoints,
|
||||
using Azure's authentication headers and endpoint format. Provides the same
|
||||
real-time audio and text communication capabilities as the base OpenAI service.
|
||||
"""
|
||||
|
||||
Settings = AzureRealtimeBetaLLMSettings
|
||||
_settings: Settings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Azure Realtime Beta LLM service.
|
||||
|
||||
Args:
|
||||
api_key: The API key for the Azure OpenAI service.
|
||||
base_url: The full Azure WebSocket endpoint URL including api-version and deployment.
|
||||
Example: "wss://my-project.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=my-realtime-deployment"
|
||||
**kwargs: Additional arguments passed to parent OpenAIRealtimeBetaLLMService.
|
||||
"""
|
||||
super().__init__(base_url=base_url, api_key=api_key, **kwargs)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"AzureRealtimeBetaLLMService is deprecated and will be removed in version 1.0.0. "
|
||||
"Use AzureRealtimeLLMService instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
async def _connect(self):
|
||||
try:
|
||||
if self._websocket:
|
||||
# Here we assume that if we have a websocket, we are connected. We
|
||||
# handle disconnections in the send/recv code paths.
|
||||
return
|
||||
|
||||
logger.info(f"Connecting to {self.base_url}")
|
||||
self._websocket = await websocket_connect(
|
||||
uri=self.base_url,
|
||||
additional_headers={
|
||||
"api-key": self.api_key,
|
||||
},
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
self._websocket = None
|
||||
@@ -1,272 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Realtime LLM context and aggregator implementations."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
from . import events
|
||||
from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
|
||||
|
||||
|
||||
class OpenAIRealtimeLLMContext(OpenAILLMContext):
|
||||
"""OpenAI Realtime LLM context with session management and message conversion.
|
||||
|
||||
Extends the standard OpenAI LLM context to support real-time session properties,
|
||||
instruction management, and conversion between standard message formats and
|
||||
realtime conversation items.
|
||||
"""
|
||||
|
||||
def __init__(self, messages=None, tools=None, **kwargs):
|
||||
"""Initialize the OpenAIRealtimeLLMContext.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages. Defaults to None.
|
||||
tools: Available function tools. Defaults to None.
|
||||
**kwargs: Additional arguments passed to parent OpenAILLMContext.
|
||||
"""
|
||||
super().__init__(messages=messages, tools=tools, **kwargs)
|
||||
self.__setup_local()
|
||||
|
||||
def __setup_local(self):
|
||||
self.llm_needs_settings_update = True
|
||||
self.llm_needs_initial_messages = True
|
||||
self._session_instructions = ""
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
|
||||
"""Upgrade a standard OpenAI LLM context to a realtime context.
|
||||
|
||||
Args:
|
||||
obj: The OpenAILLMContext instance to upgrade.
|
||||
|
||||
Returns:
|
||||
The upgraded OpenAIRealtimeLLMContext instance.
|
||||
"""
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext):
|
||||
obj.__class__ = OpenAIRealtimeLLMContext
|
||||
obj.__setup_local()
|
||||
return obj
|
||||
|
||||
# todo
|
||||
# - finish implementing all frames
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert a standard message format to a realtime conversation item.
|
||||
|
||||
Args:
|
||||
message: The standard message dictionary to convert.
|
||||
|
||||
Returns:
|
||||
A ConversationItem instance for the realtime API.
|
||||
"""
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if isinstance(message.get("content"), list):
|
||||
content = ""
|
||||
for c in message.get("content"):
|
||||
if c.get("type") == "text":
|
||||
content += " " + c.get("text")
|
||||
else:
|
||||
logger.error(
|
||||
f"Unhandled content type in context message: {c.get('type')} - {message}"
|
||||
)
|
||||
return events.ConversationItem(
|
||||
role="user",
|
||||
type="message",
|
||||
content=[events.ItemContent(type="input_text", text=content)],
|
||||
)
|
||||
if message.get("role") == "assistant" and message.get("tool_calls"):
|
||||
tc = message.get("tool_calls")[0]
|
||||
return events.ConversationItem(
|
||||
type="function_call",
|
||||
call_id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
)
|
||||
logger.error(f"Unhandled message type in from_standard_message: {message}")
|
||||
|
||||
def get_messages_for_initializing_history(self):
|
||||
"""Get conversation items for initializing the realtime session history.
|
||||
|
||||
Converts the context's messages to a format suitable for the realtime API,
|
||||
handling system instructions and conversation history packaging.
|
||||
|
||||
Returns:
|
||||
List of conversation items for session initialization.
|
||||
"""
|
||||
# We can't load a long conversation history into the openai realtime api yet. (The API/model
|
||||
# forgets that it can do audio, if you do a series of `conversation.item.create` calls.) So
|
||||
# our general strategy until this is fixed is just to put everything into a first "user"
|
||||
# message as a single input.
|
||||
if not self.messages:
|
||||
return []
|
||||
|
||||
messages = copy.deepcopy(self.messages)
|
||||
|
||||
# If we have a "system" message as our first message, let's pull that out into session
|
||||
# "instructions"
|
||||
if messages[0].get("role") == "system":
|
||||
self.llm_needs_settings_update = True
|
||||
system = messages.pop(0)
|
||||
content = system.get("content")
|
||||
if isinstance(content, str):
|
||||
self._session_instructions = content
|
||||
elif isinstance(content, list):
|
||||
self._session_instructions = content[0].get("text")
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# If we have just a single "user" item, we can just send it normally
|
||||
if len(messages) == 1 and messages[0].get("role") == "user":
|
||||
return [self.from_standard_message(messages[0])]
|
||||
|
||||
# Otherwise, let's pack everything into a single "user" message with a bit of
|
||||
# explanation for the LLM
|
||||
intro_text = """
|
||||
This is a previously saved conversation. Please treat this conversation history as a
|
||||
starting point for the current conversation."""
|
||||
|
||||
trailing_text = """
|
||||
This is the end of the previously saved conversation. Please continue the conversation
|
||||
from here. If the last message is a user instruction or question, act on that instruction
|
||||
or answer the question. If the last message is an assistant response, simple say that you
|
||||
are ready to continue the conversation."""
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "\n\n".join(
|
||||
[intro_text, json.dumps(messages, indent=2), trailing_text]
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def add_user_content_item_as_message(self, item):
|
||||
"""Add a user content item as a standard message to the context.
|
||||
|
||||
Args:
|
||||
item: The conversation item to add as a user message.
|
||||
"""
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": item.content[0].transcript}],
|
||||
}
|
||||
self.add_message(message)
|
||||
|
||||
|
||||
class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
"""User context aggregator for OpenAI Realtime API.
|
||||
|
||||
Handles user input frames and generates appropriate context updates
|
||||
for the realtime conversation, including message updates and tool settings.
|
||||
|
||||
Args:
|
||||
context: The OpenAI realtime LLM context.
|
||||
**kwargs: Additional arguments passed to parent aggregator.
|
||||
"""
|
||||
|
||||
async def process_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Process incoming frames and handle realtime-specific frame types.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
# Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline,
|
||||
# messages are only processed by the user context aggregator, which is generally what we want. But
|
||||
# we also need to send new messages over the websocket, so the openai realtime API has them
|
||||
# in its context.
|
||||
if isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self.push_frame(RealtimeMessagesUpdateFrame(context=self._context))
|
||||
|
||||
# Parent also doesn't push the LLMSetToolsFrame.
|
||||
if isinstance(frame, LLMSetToolsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Push user input aggregation.
|
||||
|
||||
Currently ignores all user input coming into the pipeline as realtime
|
||||
audio input is handled directly by the service.
|
||||
"""
|
||||
# for the moment, ignore all user input coming into the pipeline.
|
||||
# todo: think about whether/how to fix this to allow for text input from
|
||||
# upstream (transport/transcription, or other sources)
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Assistant context aggregator for OpenAI Realtime API.
|
||||
|
||||
Handles assistant output frames from the realtime service, filtering
|
||||
out duplicate text frames and managing function call results.
|
||||
|
||||
Args:
|
||||
context: The OpenAI realtime LLM context.
|
||||
**kwargs: Additional arguments passed to parent aggregator.
|
||||
"""
|
||||
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
# OpenAIRealtimeLLMService also pushes TranscriptionFrames and InterimTranscriptionFrames,
|
||||
# so we need to ignore pushing those as well, as they're also TextFrames.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process assistant frames, filtering out duplicate text content.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
"""Handle function call result and notify the realtime service.
|
||||
|
||||
Args:
|
||||
frame: The function call result frame to handle.
|
||||
"""
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
|
||||
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
|
||||
# special frame to do that.
|
||||
await self.push_frame(
|
||||
RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
@@ -1,978 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Event models and data structures for OpenAI Realtime API communication."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
#
|
||||
# session properties
|
||||
#
|
||||
|
||||
|
||||
class InputAudioTranscription(BaseModel):
|
||||
"""Configuration for audio transcription settings."""
|
||||
|
||||
model: str = "gpt-4o-transcribe"
|
||||
language: Optional[str]
|
||||
prompt: Optional[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = "gpt-4o-transcribe",
|
||||
language: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
):
|
||||
"""Initialize InputAudioTranscription.
|
||||
|
||||
Args:
|
||||
model: Transcription model to use (e.g., "gpt-4o-transcribe", "whisper-1").
|
||||
language: Optional language code for transcription.
|
||||
prompt: Optional transcription hint text.
|
||||
"""
|
||||
super().__init__(model=model, language=language, prompt=prompt)
|
||||
|
||||
|
||||
class TurnDetection(BaseModel):
|
||||
"""Server-side voice activity detection configuration.
|
||||
|
||||
Parameters:
|
||||
type: Detection type, must be "server_vad".
|
||||
threshold: Voice activity detection threshold (0.0-1.0). Defaults to 0.5.
|
||||
prefix_padding_ms: Padding before speech starts in milliseconds. Defaults to 300.
|
||||
silence_duration_ms: Silence duration to detect speech end in milliseconds. Defaults to 800.
|
||||
"""
|
||||
|
||||
type: Optional[Literal["server_vad"]] = "server_vad"
|
||||
threshold: Optional[float] = 0.5
|
||||
prefix_padding_ms: Optional[int] = 300
|
||||
silence_duration_ms: Optional[int] = 800
|
||||
|
||||
|
||||
class SemanticTurnDetection(BaseModel):
|
||||
"""Semantic-based turn detection configuration.
|
||||
|
||||
Parameters:
|
||||
type: Detection type, must be "semantic_vad".
|
||||
eagerness: Turn detection eagerness level. Can be "low", "medium", "high", or "auto".
|
||||
create_response: Whether to automatically create responses on turn detection.
|
||||
interrupt_response: Whether to interrupt ongoing responses on turn detection.
|
||||
"""
|
||||
|
||||
type: Optional[Literal["semantic_vad"]] = "semantic_vad"
|
||||
eagerness: Optional[Literal["low", "medium", "high", "auto"]] = None
|
||||
create_response: Optional[bool] = None
|
||||
interrupt_response: Optional[bool] = None
|
||||
|
||||
|
||||
class InputAudioNoiseReduction(BaseModel):
|
||||
"""Input audio noise reduction configuration.
|
||||
|
||||
Parameters:
|
||||
type: Noise reduction type for different microphone scenarios.
|
||||
"""
|
||||
|
||||
type: Optional[Literal["near_field", "far_field"]]
|
||||
|
||||
|
||||
class SessionProperties(BaseModel):
|
||||
"""Configuration properties for an OpenAI Realtime session.
|
||||
|
||||
Parameters:
|
||||
modalities: Communication modalities to enable (text, audio, or both).
|
||||
instructions: System instructions for the assistant.
|
||||
voice: Voice ID for text-to-speech output.
|
||||
input_audio_format: Format for input audio data.
|
||||
output_audio_format: Format for output audio data.
|
||||
input_audio_transcription: Configuration for input audio transcription.
|
||||
input_audio_noise_reduction: Configuration for input audio noise reduction.
|
||||
turn_detection: Turn detection configuration or False to disable.
|
||||
tools: Available function tools for the assistant.
|
||||
tool_choice: Tool usage strategy ("auto", "none", or "required").
|
||||
temperature: Sampling temperature for response generation.
|
||||
max_response_output_tokens: Maximum tokens in response or "inf" for unlimited.
|
||||
"""
|
||||
|
||||
modalities: Optional[List[Literal["text", "audio"]]] = None
|
||||
instructions: Optional[str] = None
|
||||
voice: Optional[str] = None
|
||||
input_audio_format: Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]] = None
|
||||
output_audio_format: Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]] = None
|
||||
input_audio_transcription: Optional[InputAudioTranscription] = None
|
||||
input_audio_noise_reduction: Optional[InputAudioNoiseReduction] = None
|
||||
# set turn_detection to False to disable turn detection
|
||||
turn_detection: Optional[Union[TurnDetection, SemanticTurnDetection, bool]] = Field(
|
||||
default=None
|
||||
)
|
||||
tools: Optional[List[Dict]] = None
|
||||
tool_choice: Optional[Literal["auto", "none", "required"]] = None
|
||||
temperature: Optional[float] = None
|
||||
max_response_output_tokens: Optional[Union[int, Literal["inf"]]] = None
|
||||
|
||||
|
||||
#
|
||||
# context
|
||||
#
|
||||
|
||||
|
||||
class ItemContent(BaseModel):
|
||||
"""Content within a conversation item.
|
||||
|
||||
Parameters:
|
||||
type: Content type (text, audio, input_text, or input_audio).
|
||||
text: Text content for text-based items.
|
||||
audio: Base64-encoded audio data for audio items.
|
||||
transcript: Transcribed text for audio items.
|
||||
"""
|
||||
|
||||
type: Literal["text", "audio", "input_text", "input_audio"]
|
||||
text: Optional[str] = None
|
||||
audio: Optional[str] = None # base64-encoded audio
|
||||
transcript: Optional[str] = None
|
||||
|
||||
|
||||
class ConversationItem(BaseModel):
|
||||
"""A conversation item in the realtime session.
|
||||
|
||||
Parameters:
|
||||
id: Unique identifier for the item, auto-generated if not provided.
|
||||
object: Object type identifier for the realtime API.
|
||||
type: Item type (message, function_call, or function_call_output).
|
||||
status: Current status of the item.
|
||||
role: Speaker role for message items (user, assistant, or system).
|
||||
content: Content list for message items.
|
||||
call_id: Function call identifier for function_call items.
|
||||
name: Function name for function_call items.
|
||||
arguments: Function arguments as JSON string for function_call items.
|
||||
output: Function output as JSON string for function_call_output items.
|
||||
"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4().hex))
|
||||
object: Optional[Literal["realtime.item"]] = None
|
||||
type: Literal["message", "function_call", "function_call_output"]
|
||||
status: Optional[Literal["completed", "in_progress", "incomplete"]] = None
|
||||
# role and content are present for message items
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[List[ItemContent]] = None
|
||||
# these four fields are present for function_call items
|
||||
call_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
output: Optional[str] = None
|
||||
|
||||
|
||||
class RealtimeConversation(BaseModel):
|
||||
"""A realtime conversation session.
|
||||
|
||||
Parameters:
|
||||
id: Unique identifier for the conversation.
|
||||
object: Object type identifier, always "realtime.conversation".
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: Literal["realtime.conversation"]
|
||||
|
||||
|
||||
class ResponseProperties(BaseModel):
|
||||
"""Properties for configuring assistant responses.
|
||||
|
||||
Parameters:
|
||||
modalities: Output modalities for the response. Defaults to ["audio", "text"].
|
||||
instructions: Specific instructions for this response.
|
||||
voice: Voice ID for text-to-speech in this response.
|
||||
output_audio_format: Audio format for this response.
|
||||
tools: Available tools for this response.
|
||||
tool_choice: Tool usage strategy for this response.
|
||||
temperature: Sampling temperature for this response.
|
||||
max_response_output_tokens: Maximum tokens for this response.
|
||||
"""
|
||||
|
||||
modalities: Optional[List[Literal["text", "audio"]]] = ["audio", "text"]
|
||||
instructions: Optional[str] = None
|
||||
voice: Optional[str] = None
|
||||
output_audio_format: Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]] = None
|
||||
tools: Optional[List[Dict]] = Field(default_factory=list)
|
||||
tool_choice: Optional[Literal["auto", "none", "required"]] = None
|
||||
temperature: Optional[float] = None
|
||||
max_response_output_tokens: Optional[Union[int, Literal["inf"]]] = None
|
||||
|
||||
|
||||
#
|
||||
# error class
|
||||
#
|
||||
class RealtimeError(BaseModel):
|
||||
"""Error information from the realtime API.
|
||||
|
||||
Parameters:
|
||||
type: Error type identifier.
|
||||
code: Specific error code.
|
||||
message: Human-readable error message.
|
||||
param: Parameter name that caused the error, if applicable.
|
||||
event_id: Event ID associated with the error, if applicable.
|
||||
"""
|
||||
|
||||
type: str
|
||||
code: Optional[str] = ""
|
||||
message: str
|
||||
param: Optional[str] = None
|
||||
event_id: Optional[str] = None
|
||||
|
||||
|
||||
#
|
||||
# client events
|
||||
#
|
||||
|
||||
|
||||
class ClientEvent(BaseModel):
|
||||
"""Base class for client events sent to the realtime API.
|
||||
|
||||
Parameters:
|
||||
event_id: Unique identifier for the event, auto-generated if not provided.
|
||||
"""
|
||||
|
||||
event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
|
||||
class SessionUpdateEvent(ClientEvent):
|
||||
"""Event to update session properties.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "session.update".
|
||||
session: Updated session properties.
|
||||
"""
|
||||
|
||||
type: Literal["session.update"] = "session.update"
|
||||
session: SessionProperties
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
"""Serialize the event to a dictionary.
|
||||
|
||||
Handles special serialization for turn_detection where False becomes null.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments passed to parent model_dump.
|
||||
**kwargs: Keyword arguments passed to parent model_dump.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the event.
|
||||
"""
|
||||
dump = super().model_dump(*args, **kwargs)
|
||||
|
||||
# Handle turn_detection so that False is serialized as null
|
||||
if "turn_detection" in dump["session"]:
|
||||
if dump["session"]["turn_detection"] is False:
|
||||
dump["session"]["turn_detection"] = None
|
||||
|
||||
return dump
|
||||
|
||||
|
||||
class InputAudioBufferAppendEvent(ClientEvent):
|
||||
"""Event to append audio data to the input buffer.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "input_audio_buffer.append".
|
||||
audio: Base64-encoded audio data to append.
|
||||
"""
|
||||
|
||||
type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append"
|
||||
audio: str # base64-encoded audio
|
||||
|
||||
|
||||
class InputAudioBufferCommitEvent(ClientEvent):
|
||||
"""Event to commit the current input audio buffer.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "input_audio_buffer.commit".
|
||||
"""
|
||||
|
||||
type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
|
||||
|
||||
|
||||
class InputAudioBufferClearEvent(ClientEvent):
|
||||
"""Event to clear the input audio buffer.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "input_audio_buffer.clear".
|
||||
"""
|
||||
|
||||
type: Literal["input_audio_buffer.clear"] = "input_audio_buffer.clear"
|
||||
|
||||
|
||||
class ConversationItemCreateEvent(ClientEvent):
|
||||
"""Event to create a new conversation item.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.create".
|
||||
previous_item_id: ID of the item to insert after, if any.
|
||||
item: The conversation item to create.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.create"] = "conversation.item.create"
|
||||
previous_item_id: Optional[str] = None
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ConversationItemTruncateEvent(ClientEvent):
|
||||
"""Event to truncate a conversation item's audio content.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.truncate".
|
||||
item_id: ID of the item to truncate.
|
||||
content_index: Index of the content to truncate within the item.
|
||||
audio_end_ms: End time in milliseconds for the truncated audio.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.truncate"] = "conversation.item.truncate"
|
||||
item_id: str
|
||||
content_index: int
|
||||
audio_end_ms: int
|
||||
|
||||
|
||||
class ConversationItemDeleteEvent(ClientEvent):
|
||||
"""Event to delete a conversation item.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.delete".
|
||||
item_id: ID of the item to delete.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.delete"] = "conversation.item.delete"
|
||||
item_id: str
|
||||
|
||||
|
||||
class ConversationItemRetrieveEvent(ClientEvent):
|
||||
"""Event to retrieve a conversation item by ID.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.retrieve".
|
||||
item_id: ID of the item to retrieve.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.retrieve"] = "conversation.item.retrieve"
|
||||
item_id: str
|
||||
|
||||
|
||||
class ResponseCreateEvent(ClientEvent):
|
||||
"""Event to create a new assistant response.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.create".
|
||||
response: Optional response configuration properties.
|
||||
"""
|
||||
|
||||
type: Literal["response.create"] = "response.create"
|
||||
response: Optional[ResponseProperties] = None
|
||||
|
||||
|
||||
class ResponseCancelEvent(ClientEvent):
|
||||
"""Event to cancel the current assistant response.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.cancel".
|
||||
"""
|
||||
|
||||
type: Literal["response.cancel"] = "response.cancel"
|
||||
|
||||
|
||||
#
|
||||
# server events
|
||||
#
|
||||
|
||||
|
||||
class ServerEvent(BaseModel):
|
||||
"""Base class for server events received from the realtime API.
|
||||
|
||||
Parameters:
|
||||
event_id: Unique identifier for the event.
|
||||
type: Type of the server event.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
event_id: str
|
||||
type: str
|
||||
|
||||
|
||||
class SessionCreatedEvent(ServerEvent):
|
||||
"""Event indicating a session has been created.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "session.created".
|
||||
session: The created session properties.
|
||||
"""
|
||||
|
||||
type: Literal["session.created"]
|
||||
session: SessionProperties
|
||||
|
||||
|
||||
class SessionUpdatedEvent(ServerEvent):
|
||||
"""Event indicating a session has been updated.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "session.updated".
|
||||
session: The updated session properties.
|
||||
"""
|
||||
|
||||
type: Literal["session.updated"]
|
||||
session: SessionProperties
|
||||
|
||||
|
||||
class ConversationCreated(ServerEvent):
|
||||
"""Event indicating a conversation has been created.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.created".
|
||||
conversation: The created conversation.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.created"]
|
||||
conversation: RealtimeConversation
|
||||
|
||||
|
||||
class ConversationItemCreated(ServerEvent):
|
||||
"""Event indicating a conversation item has been created.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.created".
|
||||
previous_item_id: ID of the previous item, if any.
|
||||
item: The created conversation item.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.created"]
|
||||
previous_item_id: Optional[str] = None
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ConversationItemInputAudioTranscriptionDelta(ServerEvent):
|
||||
"""Event containing incremental input audio transcription.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.input_audio_transcription.delta".
|
||||
item_id: ID of the conversation item being transcribed.
|
||||
content_index: Index of the content within the item.
|
||||
delta: Incremental transcription text.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.input_audio_transcription.delta"]
|
||||
item_id: str
|
||||
content_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ConversationItemInputAudioTranscriptionCompleted(ServerEvent):
|
||||
"""Event indicating input audio transcription is complete.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.input_audio_transcription.completed".
|
||||
item_id: ID of the conversation item that was transcribed.
|
||||
content_index: Index of the content within the item.
|
||||
transcript: Complete transcription text.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.input_audio_transcription.completed"]
|
||||
item_id: str
|
||||
content_index: int
|
||||
transcript: str
|
||||
|
||||
|
||||
class ConversationItemInputAudioTranscriptionFailed(ServerEvent):
|
||||
"""Event indicating input audio transcription failed.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.input_audio_transcription.failed".
|
||||
item_id: ID of the conversation item that failed transcription.
|
||||
content_index: Index of the content within the item.
|
||||
error: Error details for the transcription failure.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.input_audio_transcription.failed"]
|
||||
item_id: str
|
||||
content_index: int
|
||||
error: RealtimeError
|
||||
|
||||
|
||||
class ConversationItemTruncated(ServerEvent):
|
||||
"""Event indicating a conversation item has been truncated.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.truncated".
|
||||
item_id: ID of the truncated conversation item.
|
||||
content_index: Index of the content within the item.
|
||||
audio_end_ms: End time in milliseconds for the truncated audio.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.truncated"]
|
||||
item_id: str
|
||||
content_index: int
|
||||
audio_end_ms: int
|
||||
|
||||
|
||||
class ConversationItemDeleted(ServerEvent):
|
||||
"""Event indicating a conversation item has been deleted.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.deleted".
|
||||
item_id: ID of the deleted conversation item.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.deleted"]
|
||||
item_id: str
|
||||
|
||||
|
||||
class ConversationItemRetrieved(ServerEvent):
|
||||
"""Event containing a retrieved conversation item.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "conversation.item.retrieved".
|
||||
item: The retrieved conversation item.
|
||||
"""
|
||||
|
||||
type: Literal["conversation.item.retrieved"]
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ResponseCreated(ServerEvent):
|
||||
"""Event indicating an assistant response has been created.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.created".
|
||||
response: The created response object.
|
||||
"""
|
||||
|
||||
type: Literal["response.created"]
|
||||
response: "Response"
|
||||
|
||||
|
||||
class ResponseDone(ServerEvent):
|
||||
"""Event indicating an assistant response is complete.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.done".
|
||||
response: The completed response object.
|
||||
"""
|
||||
|
||||
type: Literal["response.done"]
|
||||
response: "Response"
|
||||
|
||||
|
||||
class ResponseOutputItemAdded(ServerEvent):
|
||||
"""Event indicating an output item has been added to a response.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.output_item.added".
|
||||
response_id: ID of the response.
|
||||
output_index: Index of the output item.
|
||||
item: The added conversation item.
|
||||
"""
|
||||
|
||||
type: Literal["response.output_item.added"]
|
||||
response_id: str
|
||||
output_index: int
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ResponseOutputItemDone(ServerEvent):
|
||||
"""Event indicating an output item is complete.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.output_item.done".
|
||||
response_id: ID of the response.
|
||||
output_index: Index of the output item.
|
||||
item: The completed conversation item.
|
||||
"""
|
||||
|
||||
type: Literal["response.output_item.done"]
|
||||
response_id: str
|
||||
output_index: int
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ResponseContentPartAdded(ServerEvent):
|
||||
"""Event indicating a content part has been added to a response.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.content_part.added".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
content_index: Index of the content part.
|
||||
part: The added content part.
|
||||
"""
|
||||
|
||||
type: Literal["response.content_part.added"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ItemContent
|
||||
|
||||
|
||||
class ResponseContentPartDone(ServerEvent):
|
||||
"""Event indicating a content part is complete.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.content_part.done".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
content_index: Index of the content part.
|
||||
part: The completed content part.
|
||||
"""
|
||||
|
||||
type: Literal["response.content_part.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ItemContent
|
||||
|
||||
|
||||
class ResponseTextDelta(ServerEvent):
|
||||
"""Event containing incremental text from a response.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.text.delta".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
content_index: Index of the content part.
|
||||
delta: Incremental text content.
|
||||
"""
|
||||
|
||||
type: Literal["response.text.delta"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseTextDone(ServerEvent):
|
||||
"""Event indicating text content is complete.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.text.done".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
content_index: Index of the content part.
|
||||
text: Complete text content.
|
||||
"""
|
||||
|
||||
type: Literal["response.text.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseAudioTranscriptDelta(ServerEvent):
|
||||
"""Event containing incremental audio transcript from a response.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.audio_transcript.delta".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
content_index: Index of the content part.
|
||||
delta: Incremental transcript text.
|
||||
"""
|
||||
|
||||
type: Literal["response.audio_transcript.delta"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseAudioTranscriptDone(ServerEvent):
|
||||
"""Event indicating audio transcript is complete.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.audio_transcript.done".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
content_index: Index of the content part.
|
||||
transcript: Complete transcript text.
|
||||
"""
|
||||
|
||||
type: Literal["response.audio_transcript.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
transcript: str
|
||||
|
||||
|
||||
class ResponseAudioDelta(ServerEvent):
|
||||
"""Event containing incremental audio data from a response.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.audio.delta".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
content_index: Index of the content part.
|
||||
delta: Base64-encoded incremental audio data.
|
||||
"""
|
||||
|
||||
type: Literal["response.audio.delta"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str # base64-encoded audio
|
||||
|
||||
|
||||
class ResponseAudioDone(ServerEvent):
|
||||
"""Event indicating audio content is complete.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.audio.done".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
content_index: Index of the content part.
|
||||
"""
|
||||
|
||||
type: Literal["response.audio.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
|
||||
|
||||
class ResponseFunctionCallArgumentsDelta(ServerEvent):
|
||||
"""Event containing incremental function call arguments.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.function_call_arguments.delta".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
call_id: ID of the function call.
|
||||
delta: Incremental function arguments as JSON.
|
||||
"""
|
||||
|
||||
type: Literal["response.function_call_arguments.delta"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
call_id: str
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseFunctionCallArgumentsDone(ServerEvent):
|
||||
"""Event indicating function call arguments are complete.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "response.function_call_arguments.done".
|
||||
response_id: ID of the response.
|
||||
item_id: ID of the conversation item.
|
||||
output_index: Index of the output item.
|
||||
call_id: ID of the function call.
|
||||
arguments: Complete function arguments as JSON string.
|
||||
"""
|
||||
|
||||
type: Literal["response.function_call_arguments.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
call_id: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class InputAudioBufferSpeechStarted(ServerEvent):
|
||||
"""Event indicating speech has started in the input audio buffer.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "input_audio_buffer.speech_started".
|
||||
audio_start_ms: Start time of speech in milliseconds.
|
||||
item_id: ID of the associated conversation item.
|
||||
"""
|
||||
|
||||
type: Literal["input_audio_buffer.speech_started"]
|
||||
audio_start_ms: int
|
||||
item_id: str
|
||||
|
||||
|
||||
class InputAudioBufferSpeechStopped(ServerEvent):
|
||||
"""Event indicating speech has stopped in the input audio buffer.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "input_audio_buffer.speech_stopped".
|
||||
audio_end_ms: End time of speech in milliseconds.
|
||||
item_id: ID of the associated conversation item.
|
||||
"""
|
||||
|
||||
type: Literal["input_audio_buffer.speech_stopped"]
|
||||
audio_end_ms: int
|
||||
item_id: str
|
||||
|
||||
|
||||
class InputAudioBufferCommitted(ServerEvent):
|
||||
"""Event indicating the input audio buffer has been committed.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "input_audio_buffer.committed".
|
||||
previous_item_id: ID of the previous item, if any.
|
||||
item_id: ID of the committed conversation item.
|
||||
"""
|
||||
|
||||
type: Literal["input_audio_buffer.committed"]
|
||||
previous_item_id: Optional[str] = None
|
||||
item_id: str
|
||||
|
||||
|
||||
class InputAudioBufferCleared(ServerEvent):
|
||||
"""Event indicating the input audio buffer has been cleared.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "input_audio_buffer.cleared".
|
||||
"""
|
||||
|
||||
type: Literal["input_audio_buffer.cleared"]
|
||||
|
||||
|
||||
class ErrorEvent(ServerEvent):
|
||||
"""Event indicating an error occurred.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "error".
|
||||
error: Error details.
|
||||
"""
|
||||
|
||||
type: Literal["error"]
|
||||
error: RealtimeError
|
||||
|
||||
|
||||
class RateLimitsUpdated(ServerEvent):
|
||||
"""Event indicating rate limits have been updated.
|
||||
|
||||
Parameters:
|
||||
type: Event type, always "rate_limits.updated".
|
||||
rate_limits: List of rate limit information.
|
||||
"""
|
||||
|
||||
type: Literal["rate_limits.updated"]
|
||||
rate_limits: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class TokenDetails(BaseModel):
|
||||
"""Detailed token usage information.
|
||||
|
||||
Parameters:
|
||||
cached_tokens: Number of cached tokens used. Defaults to 0.
|
||||
text_tokens: Number of text tokens used. Defaults to 0.
|
||||
audio_tokens: Number of audio tokens used. Defaults to 0.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
cached_tokens: Optional[int] = 0
|
||||
text_tokens: Optional[int] = 0
|
||||
audio_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
"""Token usage statistics for a response.
|
||||
|
||||
Parameters:
|
||||
total_tokens: Total number of tokens used.
|
||||
input_tokens: Number of input tokens used.
|
||||
output_tokens: Number of output tokens used.
|
||||
input_token_details: Detailed breakdown of input token usage.
|
||||
output_token_details: Detailed breakdown of output token usage.
|
||||
"""
|
||||
|
||||
total_tokens: int
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
input_token_details: TokenDetails
|
||||
output_token_details: TokenDetails
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
"""A complete assistant response.
|
||||
|
||||
Parameters:
|
||||
id: Unique identifier for the response.
|
||||
object: Object type, always "realtime.response".
|
||||
status: Current status of the response.
|
||||
status_details: Additional status information.
|
||||
output: List of conversation items in the response.
|
||||
usage: Token usage statistics for the response.
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: Literal["realtime.response"]
|
||||
status: Literal["completed", "in_progress", "incomplete", "cancelled", "failed"]
|
||||
status_details: Any
|
||||
output: List[ConversationItem]
|
||||
usage: Optional[Usage] = None
|
||||
|
||||
|
||||
_server_event_types = {
|
||||
"error": ErrorEvent,
|
||||
"session.created": SessionCreatedEvent,
|
||||
"session.updated": SessionUpdatedEvent,
|
||||
"conversation.created": ConversationCreated,
|
||||
"input_audio_buffer.committed": InputAudioBufferCommitted,
|
||||
"input_audio_buffer.cleared": InputAudioBufferCleared,
|
||||
"input_audio_buffer.speech_started": InputAudioBufferSpeechStarted,
|
||||
"input_audio_buffer.speech_stopped": InputAudioBufferSpeechStopped,
|
||||
"conversation.item.created": ConversationItemCreated,
|
||||
"conversation.item.input_audio_transcription.delta": ConversationItemInputAudioTranscriptionDelta,
|
||||
"conversation.item.input_audio_transcription.completed": ConversationItemInputAudioTranscriptionCompleted,
|
||||
"conversation.item.input_audio_transcription.failed": ConversationItemInputAudioTranscriptionFailed,
|
||||
"conversation.item.truncated": ConversationItemTruncated,
|
||||
"conversation.item.deleted": ConversationItemDeleted,
|
||||
"conversation.item.retrieved": ConversationItemRetrieved,
|
||||
"response.created": ResponseCreated,
|
||||
"response.done": ResponseDone,
|
||||
"response.output_item.added": ResponseOutputItemAdded,
|
||||
"response.output_item.done": ResponseOutputItemDone,
|
||||
"response.content_part.added": ResponseContentPartAdded,
|
||||
"response.content_part.done": ResponseContentPartDone,
|
||||
"response.text.delta": ResponseTextDelta,
|
||||
"response.text.done": ResponseTextDone,
|
||||
"response.audio_transcript.delta": ResponseAudioTranscriptDelta,
|
||||
"response.audio_transcript.done": ResponseAudioTranscriptDone,
|
||||
"response.audio.delta": ResponseAudioDelta,
|
||||
"response.audio.done": ResponseAudioDone,
|
||||
"response.function_call_arguments.delta": ResponseFunctionCallArgumentsDelta,
|
||||
"response.function_call_arguments.done": ResponseFunctionCallArgumentsDone,
|
||||
"rate_limits.updated": RateLimitsUpdated,
|
||||
}
|
||||
|
||||
|
||||
def parse_server_event(str):
|
||||
"""Parse a server event from JSON string.
|
||||
|
||||
Args:
|
||||
str: JSON string containing the server event.
|
||||
|
||||
Returns:
|
||||
Parsed server event object of the appropriate type.
|
||||
|
||||
Raises:
|
||||
Exception: If the event type is unimplemented or parsing fails.
|
||||
"""
|
||||
try:
|
||||
event = json.loads(str)
|
||||
event_type = event["type"]
|
||||
if event_type not in _server_event_types:
|
||||
raise Exception(f"Unimplemented server event type: {event_type}")
|
||||
return _server_event_types[event_type].model_validate(event)
|
||||
except Exception as e:
|
||||
raise Exception(f"{e} \n\n{str}")
|
||||
@@ -1,37 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Custom frame types for OpenAI Realtime API integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.services.openai_realtime_beta.context import OpenAIRealtimeLLMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealtimeMessagesUpdateFrame(DataFrame):
|
||||
"""Frame indicating that the realtime context messages have been updated.
|
||||
|
||||
Parameters:
|
||||
context: The updated OpenAI realtime LLM context.
|
||||
"""
|
||||
|
||||
context: "OpenAIRealtimeLLMContext"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealtimeFunctionCallResultFrame(DataFrame):
|
||||
"""Frame containing function call results for the realtime service.
|
||||
|
||||
Parameters:
|
||||
result_frame: The function call result frame to send to the realtime API.
|
||||
"""
|
||||
|
||||
result_frame: FunctionCallResultFrame
|
||||
@@ -1,858 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Realtime Beta LLM service implementation with WebSocket support."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_openai_realtime, traced_stt
|
||||
|
||||
from . import events
|
||||
from .context import (
|
||||
OpenAIRealtimeAssistantContextAggregator,
|
||||
OpenAIRealtimeLLMContext,
|
||||
OpenAIRealtimeUserContextAggregator,
|
||||
)
|
||||
from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use OpenAI, you need to `pip install pipecat-ai[openai]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CurrentAudioResponse:
|
||||
"""Tracks the current audio response from the assistant.
|
||||
|
||||
Parameters:
|
||||
item_id: Unique identifier for the audio response item.
|
||||
content_index: Index of the audio content within the item.
|
||||
start_time_ms: Timestamp when the audio response started in milliseconds.
|
||||
total_size: Total size of audio data received in bytes. Defaults to 0.
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
content_index: int
|
||||
start_time_ms: int
|
||||
total_size: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIRealtimeBetaLLMSettings(LLMSettings):
|
||||
"""Settings for OpenAIRealtimeBetaLLMService."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
"""OpenAI Realtime Beta LLM service providing real-time audio and text communication.
|
||||
|
||||
.. deprecated:: 0.0.84
|
||||
`OpenAIRealtimeBetaLLMService` is deprecated, use `OpenAIRealtimeLLMService` instead.
|
||||
This class will be removed in version 1.0.0.
|
||||
|
||||
Implements the OpenAI Realtime API Beta with WebSocket communication for low-latency
|
||||
bidirectional audio and text interactions. Supports function calling, conversation
|
||||
management, and real-time transcription.
|
||||
"""
|
||||
|
||||
Settings = OpenAIRealtimeBetaLLMSettings
|
||||
_settings: Settings
|
||||
|
||||
# Overriding the default adapter to use the OpenAIRealtimeLLMAdapter one.
|
||||
adapter_class = OpenAIRealtimeLLMAdapter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: Optional[str] = None,
|
||||
base_url: str = "wss://api.openai.com/v1/realtime",
|
||||
session_properties: Optional[events.SessionProperties] = None,
|
||||
settings: Optional[Settings] = None,
|
||||
start_audio_paused: bool = False,
|
||||
send_transcription_frames: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the OpenAI Realtime Beta LLM service.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key for authentication.
|
||||
model: OpenAI model name.
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=OpenAIRealtimeBetaLLMService.Settings(model=...)`` instead.
|
||||
|
||||
base_url: WebSocket base URL for the realtime API.
|
||||
Defaults to "wss://api.openai.com/v1/realtime".
|
||||
session_properties: Configuration properties for the realtime session.
|
||||
If None, uses default SessionProperties.
|
||||
settings: Runtime-updatable settings for this service.
|
||||
start_audio_paused: Whether to start with audio input paused. Defaults to False.
|
||||
send_transcription_frames: Whether to emit transcription frames. Defaults to True.
|
||||
**kwargs: Additional arguments passed to parent LLMService.
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"OpenAIRealtimeBetaLLMService is deprecated and will be removed in version 1.0.0. "
|
||||
"Use OpenAIRealtimeLLMService instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
default_settings = self.Settings(
|
||||
model="gpt-4o-realtime-preview-2025-06-03",
|
||||
system_instruction=None,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
seed=None,
|
||||
filter_incomplete_user_turns=False,
|
||||
user_turn_completion_config=None,
|
||||
)
|
||||
|
||||
# 2. Apply direct init arg overrides (deprecated)
|
||||
if model is not None:
|
||||
self._warn_init_param_moved_to_settings("model", "model")
|
||||
default_settings.model = model
|
||||
# 3. Apply settings delta (canonical API, always wins)
|
||||
if settings is not None:
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
full_url = f"{base_url}?model={default_settings.model}"
|
||||
super().__init__(
|
||||
base_url=full_url,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = full_url
|
||||
self._session_properties = session_properties or events.SessionProperties()
|
||||
self._audio_input_paused = start_audio_paused
|
||||
self._send_transcription_frames = send_transcription_frames
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._context = None
|
||||
|
||||
self._disconnecting = False
|
||||
self._api_session_ready = False
|
||||
self._run_llm_when_api_session_ready = False
|
||||
|
||||
self._current_assistant_response = None
|
||||
self._current_audio_response = None
|
||||
|
||||
self._messages_added_manually = {}
|
||||
self._user_and_response_message_tuple = None
|
||||
|
||||
self._register_event_handler("on_conversation_item_created")
|
||||
self._register_event_handler("on_conversation_item_updated")
|
||||
self._retrieve_conversation_item_futures = {}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate usage metrics.
|
||||
|
||||
Returns:
|
||||
True if metrics generation is supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
def set_audio_input_paused(self, paused: bool):
|
||||
"""Set whether audio input is paused.
|
||||
|
||||
Args:
|
||||
paused: True to pause audio input, False to resume.
|
||||
"""
|
||||
self._audio_input_paused = paused
|
||||
|
||||
def _is_modality_enabled(self, modality: str) -> bool:
|
||||
"""Check if a specific modality is enabled, "text" or "audio"."""
|
||||
modalities = self._session_properties.modalities or ["audio", "text"]
|
||||
return modality in modalities
|
||||
|
||||
def _get_enabled_modalities(self) -> list[str]:
|
||||
"""Get the list of enabled modalities."""
|
||||
return self._session_properties.modalities or ["audio", "text"]
|
||||
|
||||
async def retrieve_conversation_item(self, item_id: str):
|
||||
"""Retrieve a conversation item by ID from the server.
|
||||
|
||||
Args:
|
||||
item_id: The ID of the conversation item to retrieve.
|
||||
|
||||
Returns:
|
||||
The retrieved conversation item.
|
||||
"""
|
||||
future = self.get_event_loop().create_future()
|
||||
retrieval_in_flight = False
|
||||
if not self._retrieve_conversation_item_futures.get(item_id):
|
||||
self._retrieve_conversation_item_futures[item_id] = []
|
||||
else:
|
||||
retrieval_in_flight = True
|
||||
self._retrieve_conversation_item_futures[item_id].append(future)
|
||||
if not retrieval_in_flight:
|
||||
await self.send_client_event(
|
||||
# Set event_id to "rci_{item_id}" so that we can identify an
|
||||
# error later if the retrieval fails. We don't need a UUID
|
||||
# suffix to the event_id because we're ensuring only one
|
||||
# in-flight retrieval per item_id. (Note: "rci" = "retrieve
|
||||
# conversation item")
|
||||
events.ConversationItemRetrieveEvent(item_id=item_id, event_id=f"rci_{item_id}")
|
||||
)
|
||||
return await future
|
||||
|
||||
#
|
||||
# standard AIService frame handling
|
||||
#
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the service and establish WebSocket connection.
|
||||
|
||||
Args:
|
||||
frame: The start frame triggering service initialization.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the service and close WebSocket connection.
|
||||
|
||||
Args:
|
||||
frame: The end frame triggering service shutdown.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the service and close WebSocket connection.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame triggering service cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
#
|
||||
# speech and interruption handling
|
||||
#
|
||||
|
||||
async def _handle_interruption(self):
|
||||
# None and False are different. Check for False. None means we're using OpenAI's
|
||||
# built-in turn detection defaults.
|
||||
if self._session_properties.turn_detection is False:
|
||||
await self.send_client_event(events.InputAudioBufferClearEvent())
|
||||
await self.send_client_event(events.ResponseCancelEvent())
|
||||
await self._truncate_current_audio_response()
|
||||
await self.stop_all_metrics()
|
||||
if self._current_assistant_response:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
# Only push TTSStoppedFrame if audio modality is enabled
|
||||
if self._is_modality_enabled("audio"):
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
|
||||
async def _handle_user_started_speaking(self, frame):
|
||||
pass
|
||||
|
||||
async def _handle_user_stopped_speaking(self, frame):
|
||||
# None and False are different. Check for False. None means we're using OpenAI's
|
||||
# built-in turn detection defaults.
|
||||
if self._session_properties.turn_detection is False:
|
||||
await self.send_client_event(events.InputAudioBufferCommitEvent())
|
||||
await self.send_client_event(events.ResponseCreateEvent())
|
||||
|
||||
async def _handle_bot_stopped_speaking(self):
|
||||
self._current_audio_response = None
|
||||
|
||||
def _calculate_audio_duration_ms(
|
||||
self, total_bytes: int, sample_rate: int = 24000, bytes_per_sample: int = 2
|
||||
) -> int:
|
||||
"""Calculate audio duration in milliseconds based on PCM audio parameters."""
|
||||
samples = total_bytes / bytes_per_sample
|
||||
duration_seconds = samples / sample_rate
|
||||
return int(duration_seconds * 1000)
|
||||
|
||||
async def _truncate_current_audio_response(self):
|
||||
"""Truncates the current audio response at the appropriate duration.
|
||||
|
||||
Calculates the actual duration of the audio content and truncates at the shorter of
|
||||
either the wall clock time or the actual audio duration to prevent invalid truncation
|
||||
requests.
|
||||
"""
|
||||
if not self._current_audio_response:
|
||||
return
|
||||
|
||||
# if the bot is still speaking, truncate the last message
|
||||
try:
|
||||
current = self._current_audio_response
|
||||
self._current_audio_response = None
|
||||
|
||||
# Calculate actual audio duration instead of using wall clock time
|
||||
audio_duration_ms = self._calculate_audio_duration_ms(current.total_size)
|
||||
|
||||
# Use the shorter of wall clock time or actual audio duration
|
||||
elapsed_ms = int(time.time() * 1000 - current.start_time_ms)
|
||||
truncate_ms = min(elapsed_ms, audio_duration_ms)
|
||||
|
||||
logger.trace(
|
||||
f"Truncating audio: duration={audio_duration_ms}ms, "
|
||||
f"elapsed={elapsed_ms}ms, truncate={truncate_ms}ms"
|
||||
)
|
||||
|
||||
await self.send_client_event(
|
||||
events.ConversationItemTruncateEvent(
|
||||
item_id=current.item_id,
|
||||
content_index=current.content_index,
|
||||
audio_end_ms=truncate_ms,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Log warning and don't re-raise - allow session to continue
|
||||
logger.warning(f"Audio truncation failed (non-fatal): {e}")
|
||||
|
||||
#
|
||||
# frame processing
|
||||
#
|
||||
# StartFrame, StopFrame, CancelFrame implemented in base class
|
||||
#
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames from the pipeline.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
# Backward-compatible dict path: frame.settings contains SessionProperties
|
||||
# fields, not our Settings fields, so we construct SessionProperties
|
||||
# directly. The frame.delta path falls through to super, which calls
|
||||
# _update_settings → our override handles the rest.
|
||||
if isinstance(frame, LLMUpdateSettingsFrame) and frame.delta is None:
|
||||
self._session_properties = events.SessionProperties(**frame.settings)
|
||||
await self._send_session_update()
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
pass
|
||||
elif isinstance(frame, OpenAILLMContextFrame):
|
||||
context: OpenAIRealtimeLLMContext = OpenAIRealtimeLLMContext.upgrade_to_realtime(
|
||||
frame.context
|
||||
)
|
||||
if not self._context:
|
||||
self._context = context
|
||||
elif frame.context is not self._context:
|
||||
# If the context has changed, reset the conversation
|
||||
self._context = context
|
||||
await self.reset_conversation()
|
||||
# Run the LLM at next opportunity
|
||||
await self._create_response()
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
raise NotImplementedError(
|
||||
"Universal LLMContext is not yet supported for OpenAI Realtime."
|
||||
)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
if not self._audio_input_paused:
|
||||
await self._send_user_audio(frame)
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption()
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking(frame)
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
await self._handle_user_stopped_speaking(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._handle_bot_stopped_speaking()
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_messages_append(frame)
|
||||
elif isinstance(frame, RealtimeMessagesUpdateFrame):
|
||||
self._context = frame.context
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
await self._send_session_update()
|
||||
elif isinstance(frame, RealtimeFunctionCallResultFrame):
|
||||
await self._handle_function_call_result(frame.result_frame)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_messages_append(self, frame):
|
||||
logger.error("!!! NEED TO IMPLEMENT MESSAGES APPEND")
|
||||
|
||||
async def _handle_function_call_result(self, frame):
|
||||
item = events.ConversationItem(
|
||||
type="function_call_output",
|
||||
call_id=frame.tool_call_id,
|
||||
output=json.dumps(frame.result, ensure_ascii=False),
|
||||
)
|
||||
await self.send_client_event(events.ConversationItemCreateEvent(item=item))
|
||||
|
||||
#
|
||||
# websocket communication
|
||||
#
|
||||
|
||||
async def send_client_event(self, event: events.ClientEvent):
|
||||
"""Send a client event to the OpenAI Realtime API.
|
||||
|
||||
Args:
|
||||
event: The client event to send.
|
||||
"""
|
||||
await self._ws_send(event.model_dump(exclude_none=True))
|
||||
|
||||
async def _connect(self):
|
||||
try:
|
||||
if self._websocket:
|
||||
# Here we assume that if we have a websocket, we are connected. We
|
||||
# handle disconnections in the send/recv code paths.
|
||||
return
|
||||
self._websocket = await websocket_connect(
|
||||
uri=self.base_url,
|
||||
additional_headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
},
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
try:
|
||||
self._disconnecting = True
|
||||
self._api_session_ready = False
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=1.0)
|
||||
self._receive_task = None
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
|
||||
async def _ws_send(self, realtime_message):
|
||||
try:
|
||||
if self._websocket:
|
||||
await self._websocket.send(json.dumps(realtime_message))
|
||||
except Exception as e:
|
||||
if self._disconnecting:
|
||||
return
|
||||
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
|
||||
|
||||
async def _update_settings(self, delta):
|
||||
"""Apply a settings delta."""
|
||||
changed = await super()._update_settings(delta)
|
||||
self._warn_unhandled_updated_settings(changed.keys())
|
||||
return changed
|
||||
|
||||
async def _send_session_update(self):
|
||||
settings = self._session_properties
|
||||
# tools given in the context override the tools in the session properties
|
||||
if self._context and self._context.tools:
|
||||
settings.tools = self._context.tools
|
||||
# instructions in the context come from an initial "system" message in the
|
||||
# messages list, and override instructions in the session properties
|
||||
if self._context and self._context._session_instructions:
|
||||
settings.instructions = self._context._session_instructions
|
||||
await self.send_client_event(events.SessionUpdateEvent(session=settings))
|
||||
|
||||
#
|
||||
# inbound server event handling
|
||||
# https://platform.openai.com/docs/api-reference/realtime-server-events
|
||||
#
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
async for message in self._websocket:
|
||||
evt = events.parse_server_event(message)
|
||||
if evt.type == "session.created":
|
||||
await self._handle_evt_session_created(evt)
|
||||
elif evt.type == "session.updated":
|
||||
await self._handle_evt_session_updated(evt)
|
||||
elif evt.type == "response.audio.delta":
|
||||
await self._handle_evt_audio_delta(evt)
|
||||
elif evt.type == "response.audio.done":
|
||||
await self._handle_evt_audio_done(evt)
|
||||
elif evt.type == "conversation.item.created":
|
||||
await self._handle_evt_conversation_item_created(evt)
|
||||
elif evt.type == "conversation.item.input_audio_transcription.delta":
|
||||
await self._handle_evt_input_audio_transcription_delta(evt)
|
||||
elif evt.type == "conversation.item.input_audio_transcription.completed":
|
||||
await self.handle_evt_input_audio_transcription_completed(evt)
|
||||
elif evt.type == "conversation.item.retrieved":
|
||||
await self._handle_conversation_item_retrieved(evt)
|
||||
elif evt.type == "response.done":
|
||||
await self._handle_evt_response_done(evt)
|
||||
elif evt.type == "input_audio_buffer.speech_started":
|
||||
await self._handle_evt_speech_started(evt)
|
||||
elif evt.type == "input_audio_buffer.speech_stopped":
|
||||
await self._handle_evt_speech_stopped(evt)
|
||||
elif evt.type == "response.text.delta":
|
||||
await self._handle_evt_text_delta(evt)
|
||||
elif evt.type == "response.audio_transcript.delta":
|
||||
await self._handle_evt_audio_transcript_delta(evt)
|
||||
elif evt.type == "error":
|
||||
if not await self._maybe_handle_evt_retrieve_conversation_item_error(evt):
|
||||
if evt.error.code in (
|
||||
"response_cancel_not_active",
|
||||
"conversation_already_has_active_response",
|
||||
):
|
||||
logger.debug(f"{self} {evt.error.message}")
|
||||
else:
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
|
||||
@traced_openai_realtime(operation="llm_setup")
|
||||
async def _handle_evt_session_created(self, evt):
|
||||
# session.created is received right after connecting. Send a message
|
||||
# to configure the session properties.
|
||||
await self._send_session_update()
|
||||
|
||||
async def _handle_evt_session_updated(self, evt):
|
||||
# If this is our first context frame, run the LLM
|
||||
self._api_session_ready = True
|
||||
# Now that we've configured the session, we can run the LLM if we need to.
|
||||
if self._run_llm_when_api_session_ready:
|
||||
self._run_llm_when_api_session_ready = False
|
||||
await self._create_response()
|
||||
|
||||
async def _handle_evt_audio_delta(self, evt):
|
||||
# note: ttfb is faster by 1/2 RTT than ttfb as measured for other services, since we're getting
|
||||
# this event from the server
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if self._current_audio_response and self._current_audio_response.item_id != evt.item_id:
|
||||
logger.warning(
|
||||
f"Received a new audio delta for an already completed audio response before receiving the BotStoppedSpeakingFrame."
|
||||
)
|
||||
logger.debug("Forcing previous audio response to None")
|
||||
self._current_audio_response = None
|
||||
|
||||
if not self._current_audio_response:
|
||||
self._current_audio_response = CurrentAudioResponse(
|
||||
item_id=evt.item_id,
|
||||
content_index=evt.content_index,
|
||||
start_time_ms=int(time.time() * 1000),
|
||||
)
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
audio = base64.b64decode(evt.delta)
|
||||
self._current_audio_response.total_size += len(audio)
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=audio,
|
||||
sample_rate=24000,
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_evt_audio_done(self, evt):
|
||||
if self._current_audio_response:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
# Don't clear the self._current_audio_response here. We need to wait until we
|
||||
# receive a BotStoppedSpeakingFrame from the output transport.
|
||||
|
||||
async def _handle_evt_conversation_item_created(self, evt):
|
||||
await self._call_event_handler("on_conversation_item_created", evt.item.id, evt.item)
|
||||
|
||||
# This will get sent from the server every time a new "message" is added
|
||||
# to the server's conversation state, whether we create it via the API
|
||||
# or the server creates it from LLM output.
|
||||
if self._messages_added_manually.get(evt.item.id):
|
||||
del self._messages_added_manually[evt.item.id]
|
||||
return
|
||||
|
||||
if evt.item.role == "user":
|
||||
# We need to wait for completion of both user message and response message. Then we'll
|
||||
# add both to the context. User message is complete when we have a "transcript" field
|
||||
# that is not None. Response message is complete when we get a "response.done" event.
|
||||
self._user_and_response_message_tuple = (evt.item, {"done": False, "output": []})
|
||||
elif evt.item.role == "assistant":
|
||||
self._current_assistant_response = evt.item
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
async def _handle_evt_input_audio_transcription_delta(self, evt):
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
# no way to get a language code?
|
||||
InterimTranscriptionFrame(evt.delta, "", time_now_iso8601(), result=evt)
|
||||
)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_user_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def handle_evt_input_audio_transcription_completed(self, evt):
|
||||
"""Handle completion of input audio transcription.
|
||||
|
||||
Args:
|
||||
evt: The transcription completed event.
|
||||
"""
|
||||
await self._call_event_handler("on_conversation_item_updated", evt.item_id, None)
|
||||
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
# no way to get a language code?
|
||||
TranscriptionFrame(evt.transcript, "", time_now_iso8601(), result=evt)
|
||||
)
|
||||
await self._handle_user_transcription(evt.transcript, True, Language.EN)
|
||||
pair = self._user_and_response_message_tuple
|
||||
if pair:
|
||||
user, assistant = pair
|
||||
user.content[0].transcript = evt.transcript
|
||||
if assistant["done"]:
|
||||
self._user_and_response_message_tuple = None
|
||||
self._context.add_user_content_item_as_message(user)
|
||||
await self._handle_assistant_output(assistant["output"])
|
||||
else:
|
||||
# User message without preceding conversation.item.created. Bug?
|
||||
logger.warning(f"Transcript for unknown user message: {evt}")
|
||||
|
||||
async def _handle_conversation_item_retrieved(self, evt: events.ConversationItemRetrieved):
|
||||
futures = self._retrieve_conversation_item_futures.pop(evt.item.id, None)
|
||||
if futures:
|
||||
for future in futures:
|
||||
future.set_result(evt.item)
|
||||
|
||||
@traced_openai_realtime(operation="llm_response")
|
||||
async def _handle_evt_response_done(self, evt):
|
||||
# todo: figure out whether there's anything we need to do for "cancelled" events
|
||||
# usage metrics
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=evt.response.usage.input_tokens,
|
||||
completion_tokens=evt.response.usage.output_tokens,
|
||||
total_tokens=evt.response.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._current_assistant_response = None
|
||||
# error handling
|
||||
if evt.response.status == "failed":
|
||||
await self.push_error(ErrorFrame(error=evt.response.status_details["error"]["message"]))
|
||||
return
|
||||
# response content
|
||||
for item in evt.response.output:
|
||||
await self._call_event_handler("on_conversation_item_updated", item.id, item)
|
||||
pair = self._user_and_response_message_tuple
|
||||
if pair:
|
||||
user, assistant = pair
|
||||
assistant["done"] = True
|
||||
assistant["output"] = evt.response.output
|
||||
if user.content[0].transcript is not None:
|
||||
self._user_and_response_message_tuple = None
|
||||
self._context.add_user_content_item_as_message(user)
|
||||
await self._handle_assistant_output(assistant["output"])
|
||||
else:
|
||||
# Response message without preceding user message. Add it to the context.
|
||||
await self._handle_assistant_output(evt.response.output)
|
||||
|
||||
async def _handle_evt_text_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(LLMTextFrame(evt.delta))
|
||||
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(LLMTextFrame(evt.delta))
|
||||
await self.push_frame(TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE))
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):
|
||||
"""Maybe handle an error event related to retrieving a conversation item.
|
||||
|
||||
If the given error event is an error retrieving a conversation item:
|
||||
|
||||
- set an exception on the future that retrieve_conversation_item() is waiting on
|
||||
- return true
|
||||
Otherwise:
|
||||
- return false
|
||||
"""
|
||||
if evt.error.code == "item_retrieve_invalid_item_id":
|
||||
item_id = evt.error.event_id.split("_", 1)[1] # event_id is of the form "rci_{item_id}"
|
||||
futures = self._retrieve_conversation_item_futures.pop(item_id, None)
|
||||
if futures:
|
||||
for future in futures:
|
||||
future.set_exception(Exception(evt.error.message))
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _handle_evt_error(self, evt):
|
||||
# Errors are fatal to this connection. Send an ErrorFrame.
|
||||
await self.push_error(error_msg=f"Error: {evt}")
|
||||
|
||||
async def _handle_assistant_output(self, output):
|
||||
# We haven't seen intermixed audio and function_call items in the same response. But let's
|
||||
# try to write logic that handles that, if it does happen.
|
||||
# Also, the assistant output is pushed as LLMTextFrame and TTSTextFrame to be handled by
|
||||
# the assistant context aggregator.
|
||||
function_calls = [item for item in output if item.type == "function_call"]
|
||||
await self._handle_function_call_items(function_calls)
|
||||
|
||||
async def _handle_function_call_items(self, items):
|
||||
function_calls = []
|
||||
for item in items:
|
||||
args = json.loads(item.arguments)
|
||||
function_calls.append(
|
||||
FunctionCallFromLLM(
|
||||
context=self._context,
|
||||
tool_call_id=item.call_id,
|
||||
function_name=item.name,
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
#
|
||||
# state and client events for the current conversation
|
||||
# https://platform.openai.com/docs/api-reference/realtime-client-events
|
||||
#
|
||||
|
||||
async def reset_conversation(self):
|
||||
"""Reset the conversation by disconnecting and reconnecting.
|
||||
|
||||
This is the safest way to start a new conversation. Note that this will
|
||||
fail if called from the receive task.
|
||||
"""
|
||||
logger.debug("Resetting conversation")
|
||||
await self._disconnect()
|
||||
if self._context:
|
||||
self._context.llm_needs_settings_update = True
|
||||
self._context.llm_needs_initial_messages = True
|
||||
await self._connect()
|
||||
|
||||
@traced_openai_realtime(operation="llm_request")
|
||||
async def _create_response(self):
|
||||
if not self._api_session_ready:
|
||||
self._run_llm_when_api_session_ready = True
|
||||
return
|
||||
|
||||
if self._context.llm_needs_initial_messages:
|
||||
messages = self._context.get_messages_for_initializing_history()
|
||||
for item in messages:
|
||||
evt = events.ConversationItemCreateEvent(item=item)
|
||||
self._messages_added_manually[evt.item.id] = True
|
||||
await self.send_client_event(evt)
|
||||
self._context.llm_needs_initial_messages = False
|
||||
|
||||
if self._context.llm_needs_settings_update:
|
||||
await self._send_session_update()
|
||||
self._context.llm_needs_settings_update = False
|
||||
|
||||
logger.debug(f"Creating response: {self._context.get_messages_for_logging()}")
|
||||
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
await self.send_client_event(
|
||||
events.ResponseCreateEvent(
|
||||
response=events.ResponseProperties(modalities=self._get_enabled_modalities())
|
||||
)
|
||||
)
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
payload = base64.b64encode(frame.audio).decode("utf-8")
|
||||
await self.send_client_event(events.InputAudioBufferAppendEvent(audio=payload))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
"""Create an instance of OpenAIContextAggregatorPair from an OpenAILLMContext.
|
||||
|
||||
Constructor keyword arguments for both the user and assistant aggregators can be provided.
|
||||
|
||||
Args:
|
||||
context: The LLM context.
|
||||
user_params: User aggregator parameters.
|
||||
assistant_params: Assistant aggregator parameters.
|
||||
|
||||
Returns:
|
||||
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIContextAggregatorPair.
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
OpenAIRealtimeLLMContext.upgrade_to_realtime(context)
|
||||
user = OpenAIRealtimeUserContextAggregator(context, params=user_params)
|
||||
|
||||
assistant_params.expect_stripped_words = False
|
||||
assistant = OpenAIRealtimeAssistantContextAggregator(context, params=assistant_params)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
@@ -20,7 +20,6 @@ from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
|
||||
from pipecat.adapters.services.perplexity_adapter import PerplexityLLMAdapter
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
@@ -126,7 +125,7 @@ class PerplexityLLMService(OpenAILLMService):
|
||||
|
||||
return params
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle
|
||||
|
||||
@@ -20,7 +20,6 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
@@ -138,9 +137,7 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
|
||||
return params
|
||||
|
||||
@traced_llm # type: ignore
|
||||
async def _process_context(
|
||||
self, context: OpenAILLMContext | LLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
async def _process_context(self, context: LLMContext) -> AsyncStream[ChatCompletionChunk]:
|
||||
"""Process OpenAI LLM context and stream chat completion chunks.
|
||||
|
||||
This method handles the streaming response from SambaNova API, including
|
||||
@@ -163,11 +160,7 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
chunk_stream = await (
|
||||
self._stream_chat_completions_specific_context(context)
|
||||
if isinstance(context, OpenAILLMContext)
|
||||
else self._stream_chat_completions_universal_context(context)
|
||||
)
|
||||
chunk_stream = await self.get_chat_completions(context)
|
||||
|
||||
# Use context manager to ensure stream is closed on cancellation/exception.
|
||||
# Without this, CancelledError during iteration leaves the underlying socket open.
|
||||
|
||||
@@ -46,15 +46,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven
|
||||
@@ -404,13 +396,8 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
context = (
|
||||
frame.context
|
||||
if isinstance(frame, LLMContextFrame)
|
||||
else LLMContext.from_openai_context(frame.context)
|
||||
)
|
||||
await self._handle_context(context)
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
await self._handle_context(frame.context)
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self.stop_all_metrics()
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -629,40 +616,3 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
self._bot_responding = "text"
|
||||
await self.push_frame(LLMTextFrame(text=text or delta))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> LLMContextAggregatorPair:
|
||||
"""Create an instance of LLMContextAggregatorPair from an OpenAILLMContext.
|
||||
|
||||
Constructor keyword arguments for both the user and assistant aggregators can be provided.
|
||||
|
||||
NOTE: this method exists only for backward compatibility. New code
|
||||
should instead do::
|
||||
|
||||
context = LLMContext(...)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
Args:
|
||||
context: The LLM context to use.
|
||||
user_params: User aggregator parameters. Defaults to LLMUserAggregatorParams().
|
||||
assistant_params: Assistant aggregator parameters. Defaults to LLMAssistantAggregatorParams().
|
||||
|
||||
Returns:
|
||||
A pair of user and assistant context aggregators.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
# from_openai_context handles deprecation warning
|
||||
context = LLMContext.from_openai_context(context)
|
||||
assistant_params.expect_stripped_words = False
|
||||
return LLMContextAggregatorPair(
|
||||
context, user_params=user_params, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
@@ -18,57 +18,12 @@ from loguru import logger
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMService,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrokContextAggregatorPair:
|
||||
"""Pair of context aggregators for user and assistant interactions.
|
||||
|
||||
Provides a convenient container for managing both user and assistant
|
||||
context aggregators together for Grok LLM interactions.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`GrokContextAggregatorPair` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
|
||||
Parameters:
|
||||
_user: The user context aggregator instance.
|
||||
_assistant: The assistant context aggregator instance.
|
||||
"""
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
_user: OpenAIUserContextAggregator
|
||||
_assistant: OpenAIAssistantContextAggregator
|
||||
|
||||
def user(self) -> OpenAIUserContextAggregator:
|
||||
"""Get the user context aggregator.
|
||||
|
||||
Returns:
|
||||
The user context aggregator instance.
|
||||
"""
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> OpenAIAssistantContextAggregator:
|
||||
"""Get the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
The assistant context aggregator instance.
|
||||
"""
|
||||
return self._assistant
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrokLLMSettings(BaseOpenAILLMService.Settings):
|
||||
"""Settings for GrokLLMService."""
|
||||
@@ -147,7 +102,7 @@ class GrokLLMService(OpenAILLMService):
|
||||
logger.debug(f"Creating Grok client with api {base_url}")
|
||||
return super().create_client(api_key, base_url, **kwargs)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
async def _process_context(self, context: LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle Grok's
|
||||
@@ -213,38 +168,3 @@ class GrokLLMService(OpenAILLMService):
|
||||
|
||||
if tokens.reasoning_tokens is not None:
|
||||
self._reasoning_tokens = tokens.reasoning_tokens
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> GrokContextAggregatorPair:
|
||||
"""Create an instance of GrokContextAggregatorPair from an OpenAILLMContext.
|
||||
|
||||
Constructor keyword arguments for both the user and assistant aggregators
|
||||
can be provided.
|
||||
|
||||
Args:
|
||||
context: The LLM context to create aggregators for.
|
||||
user_params: Parameters for configuring the user aggregator.
|
||||
assistant_params: Parameters for configuring the assistant aggregator.
|
||||
|
||||
Returns:
|
||||
GrokContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
GrokContextAggregatorPair.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`create_context_aggregator()` is deprecated and will be removed in a future version.
|
||||
Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
|
||||
See `OpenAILLMContext` docstring for migration guide.
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
# Aggregators handle deprecation warnings
|
||||
user = OpenAIUserContextAggregator(context, params=user_params)
|
||||
assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
|
||||
|
||||
return GrokContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -46,14 +46,9 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.settings import (
|
||||
@@ -946,26 +941,3 @@ class GrokRealtimeLLMService(LLMService):
|
||||
output=json.dumps(result, ensure_ascii=False),
|
||||
)
|
||||
await self.send_client_event(events.ConversationItemCreateEvent(item=item))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> LLMContextAggregatorPair:
|
||||
"""Create context aggregators for the Grok Realtime service.
|
||||
|
||||
Args:
|
||||
context: The LLM context.
|
||||
user_params: User aggregator parameters.
|
||||
assistant_params: Assistant aggregator parameters.
|
||||
|
||||
Returns:
|
||||
LLMContextAggregatorPair for user and assistant context aggregation.
|
||||
"""
|
||||
context = LLMContext.from_openai_context(context)
|
||||
assistant_params.expect_stripped_words = False
|
||||
return LLMContextAggregatorPair(
|
||||
context, user_params=user_params, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
@@ -480,7 +480,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
self, audio_frame: InputAudioRawFrame, vad_state: VADState
|
||||
) -> VADState:
|
||||
"""Handle Voice Activity Detection results and generate appropriate frames."""
|
||||
if self._params.turn_analyzer or self._deprecated_openaillmcontext:
|
||||
if self._params.turn_analyzer:
|
||||
return await self._deprecated_old_handle_vad(audio_frame, vad_state)
|
||||
else:
|
||||
return await self._deprecated_new_handle_vad(audio_frame, vad_state)
|
||||
|
||||
@@ -47,7 +47,6 @@ def _get_provider_name_from_service_name(service_name: str) -> str:
|
||||
"AzureLLMService": "az.ai.openai",
|
||||
# Google
|
||||
"GoogleLLMService": "gcp.gemini",
|
||||
"GoogleLLMOpenAIBetaService": "gcp.gemini",
|
||||
"GoogleVertexLLMService": "gcp.vertex_ai",
|
||||
# Others
|
||||
"GrokLLMService": "xai",
|
||||
|
||||
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
|
||||
from opentelemetry import trace
|
||||
|
||||
from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.utils.tracing.service_attributes import (
|
||||
add_gemini_live_span_attributes,
|
||||
add_llm_span_attributes,
|
||||
@@ -459,40 +458,30 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
self.push_frame = traced_push_frame
|
||||
|
||||
# Get messages for logging
|
||||
# For OpenAILLMContext: use context's own get_messages_for_logging() method
|
||||
# For LLMContext: use adapter's get_messages_for_logging() which returns
|
||||
# Use adapter's get_messages_for_logging() which returns
|
||||
# messages in provider's native format with sensitive data sanitized
|
||||
messages = None
|
||||
serialized_messages = None
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
# OpenAILLMContext and subclasses have their own method
|
||||
messages = context.get_messages_for_logging()
|
||||
elif isinstance(context, LLMContext):
|
||||
# Universal LLMContext - use adapter for provider-native format
|
||||
if hasattr(self, "get_llm_adapter"):
|
||||
adapter = self.get_llm_adapter()
|
||||
messages = adapter.get_messages_for_logging(context)
|
||||
# Use adapter for provider-native format
|
||||
if hasattr(self, "get_llm_adapter"):
|
||||
adapter = self.get_llm_adapter()
|
||||
messages = adapter.get_messages_for_logging(context)
|
||||
|
||||
# Serialize messages if available
|
||||
if messages:
|
||||
serialized_messages = json.dumps(messages)
|
||||
|
||||
# Get tools
|
||||
# For OpenAILLMContext: tools may need adapter conversion if set
|
||||
# For LLMContext: use adapter's from_standard_tools() to convert ToolsSchema
|
||||
# Use adapter's from_standard_tools() to convert ToolsSchema
|
||||
tools = None
|
||||
serialized_tools = None
|
||||
tool_count = 0
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
# OpenAILLMContext: tools property handles adapter conversion internally
|
||||
tools = context.tools
|
||||
elif isinstance(context, LLMContext):
|
||||
# Universal LLMContext - use adapter to convert ToolsSchema
|
||||
if hasattr(self, "get_llm_adapter") and hasattr(context, "tools"):
|
||||
adapter = self.get_llm_adapter()
|
||||
tools = adapter.from_standard_tools(context.tools)
|
||||
# Use adapter to convert ToolsSchema
|
||||
if hasattr(self, "get_llm_adapter") and hasattr(context, "tools"):
|
||||
adapter = self.get_llm_adapter()
|
||||
tools = adapter.from_standard_tools(context.tools)
|
||||
|
||||
# Serialize and count tools if available
|
||||
# Check if tools is not None and not NOT_GIVEN
|
||||
@@ -501,36 +490,27 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
tool_count = len(tools) if isinstance(tools, list) else 1
|
||||
|
||||
# Handle system message for different services
|
||||
system_message = None
|
||||
if isinstance(context, LLMContext):
|
||||
# settings.system_instruction takes priority (matches service behavior)
|
||||
if hasattr(self, "_settings") and getattr(
|
||||
self._settings, "system_instruction", None
|
||||
):
|
||||
system_message = self._settings.system_instruction
|
||||
else:
|
||||
# Fall back to extracting from context messages
|
||||
ctx_messages = context.get_messages()
|
||||
if ctx_messages:
|
||||
first = ctx_messages[0]
|
||||
if (
|
||||
isinstance(first, dict)
|
||||
and first.get("role") == "system"
|
||||
):
|
||||
content = first.get("content")
|
||||
if isinstance(content, str):
|
||||
system_message = content
|
||||
elif isinstance(content, list):
|
||||
system_message = " ".join(
|
||||
part.get("text", "")
|
||||
for part in content
|
||||
if isinstance(part, dict)
|
||||
and part.get("type") == "text"
|
||||
)
|
||||
elif hasattr(context, "system"):
|
||||
system_message = context.system
|
||||
elif hasattr(context, "system_message"):
|
||||
system_message = context.system_message
|
||||
# settings.system_instruction takes priority (matches service behavior)
|
||||
if hasattr(self, "_settings") and getattr(
|
||||
self._settings, "system_instruction", None
|
||||
):
|
||||
system_message = self._settings.system_instruction
|
||||
else:
|
||||
# Fall back to extracting from context messages
|
||||
ctx_messages = context.get_messages()
|
||||
if ctx_messages:
|
||||
first = ctx_messages[0]
|
||||
if isinstance(first, dict) and first.get("role") == "system":
|
||||
content = first.get("content")
|
||||
if isinstance(content, str):
|
||||
system_message = content
|
||||
elif isinstance(content, list):
|
||||
system_message = " ".join(
|
||||
part.get("text", "")
|
||||
for part in content
|
||||
if isinstance(part, dict)
|
||||
and part.get("type") == "text"
|
||||
)
|
||||
|
||||
# Use given_fields() defensively in case a service doesn't
|
||||
# initialize all settings.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,13 +4,16 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
FunctionCallFromLLM,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallResultProperties,
|
||||
FunctionCallsStartedFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
@@ -26,6 +29,7 @@ from pipecat.frames.frames import (
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TranslationFrame,
|
||||
UserMuteStartedFrame,
|
||||
@@ -588,6 +592,165 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(should_stop)
|
||||
self.assertEqual(stop_message.content, "Hello from Pipecat!")
|
||||
|
||||
async def test_multiple_text_with_spaces(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
def make_text_frame(text: str) -> TextFrame:
|
||||
frame = TextFrame(text=text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
return frame
|
||||
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
make_text_frame("Hello "),
|
||||
make_text_frame("Pipecat. "),
|
||||
make_text_frame("How are "),
|
||||
make_text_frame("you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [LLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hello Pipecat. How are you?"
|
||||
|
||||
async def test_multiple_text_stripped(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello"),
|
||||
TextFrame(text="Pipecat."),
|
||||
TextFrame(text="How are"),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [LLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hello Pipecat. How are you?"
|
||||
|
||||
async def test_multiple_text_mixed_spaces(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
def make_text_frame(text: str, includes_spaces: bool) -> TextFrame:
|
||||
frame = TextFrame(text=text)
|
||||
frame.includes_inter_frame_spaces = includes_spaces
|
||||
return frame
|
||||
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
make_text_frame("Hello ", includes_spaces=True),
|
||||
make_text_frame("Pipecat. ", includes_spaces=True),
|
||||
make_text_frame("Here's some", includes_spaces=True),
|
||||
make_text_frame(
|
||||
" code:", includes_spaces=True
|
||||
), # Validates ending includes_inter_frame_spaces run with no space
|
||||
make_text_frame("```python\nprint('Hello, World!')\n```", includes_spaces=False),
|
||||
make_text_frame(
|
||||
"```javascript\nconsole.log('Hello, World!');\n```", includes_spaces=False
|
||||
),
|
||||
make_text_frame(
|
||||
" And some more: ", includes_spaces=True
|
||||
), # Validates starting includes_inter_frame_spaces run with a space and ending it with no space
|
||||
make_text_frame("```html\n<div>Hello, World!</div>\n```", includes_spaces=False),
|
||||
make_text_frame(
|
||||
"Hope that ", includes_spaces=True
|
||||
), # Validates starting includes_inter_frame_spaces run with no space
|
||||
make_text_frame("helps!", includes_spaces=True),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [LLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == (
|
||||
"Hello Pipecat. Here's some code: "
|
||||
"```python\nprint('Hello, World!')\n``` "
|
||||
"```javascript\nconsole.log('Hello, World!');\n``` "
|
||||
"And some more: "
|
||||
"```html\n<div>Hello, World!</div>\n``` "
|
||||
"Hope that helps!"
|
||||
)
|
||||
|
||||
async def test_multiple_responses(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
def make_text_frame(text: str) -> TextFrame:
|
||||
frame = TextFrame(text=text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
return frame
|
||||
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
make_text_frame("Hello "),
|
||||
make_text_frame("Pipecat."),
|
||||
LLMFullResponseEndFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
make_text_frame(text="How are "),
|
||||
make_text_frame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
LLMContextFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMContextFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hello Pipecat."
|
||||
assert context.messages[1]["content"] == "How are you?"
|
||||
|
||||
async def test_multiple_responses_interruption(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
def make_text_frame(text: str) -> TextFrame:
|
||||
frame = TextFrame(text=text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
return frame
|
||||
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
make_text_frame("Hello "),
|
||||
make_text_frame("Pipecat."),
|
||||
LLMFullResponseEndFrame(),
|
||||
SleepFrame(0.15),
|
||||
InterruptionFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
make_text_frame("How are "),
|
||||
make_text_frame("you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
LLMContextFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hello Pipecat."
|
||||
assert context.messages[1]["content"] == "How are you?"
|
||||
|
||||
async def test_interruption(self):
|
||||
context = LLMContext()
|
||||
|
||||
@@ -635,6 +798,67 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(stop_messages[0].content, "Hello")
|
||||
self.assertEqual(stop_messages[1].content, "Hello there!")
|
||||
|
||||
async def test_function_call(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
frames_to_send = [
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={"location": "Los Angeles"},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
SleepFrame(),
|
||||
FunctionCallResultFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={"location": "Los Angeles"},
|
||||
result={"conditions": "Sunny"},
|
||||
),
|
||||
]
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert json.loads(context.messages[-1]["content"]) == {"conditions": "Sunny"}
|
||||
|
||||
async def test_function_call_on_context_updated(self):
|
||||
context_updated = False
|
||||
|
||||
async def on_context_updated():
|
||||
nonlocal context_updated
|
||||
context_updated = True
|
||||
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
frames_to_send = [
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={"location": "Los Angeles"},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
SleepFrame(),
|
||||
FunctionCallResultFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={"location": "Los Angeles"},
|
||||
result={"conditions": "Sunny"},
|
||||
properties=FunctionCallResultProperties(on_context_updated=on_context_updated),
|
||||
),
|
||||
SleepFrame(),
|
||||
]
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert json.loads(context.messages[-1]["content"]) == {"conditions": "Sunny"}
|
||||
assert context_updated
|
||||
|
||||
async def test_thought(self):
|
||||
context = LLMContext()
|
||||
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Unit tests for Google LLM OpenAI Beta service."""
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
|
||||
try:
|
||||
from pipecat.services.google.openai.llm import GoogleLLMOpenAIBetaService
|
||||
|
||||
google_available = True
|
||||
except Exception:
|
||||
google_available = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not google_available, reason="Google dependencies not installed")
|
||||
async def test_google_llm_openai_stream_closed_on_cancellation():
|
||||
"""Test that the stream is closed when CancelledError occurs during iteration.
|
||||
|
||||
This prevents socket leaks when the pipeline is interrupted (e.g., user interruption).
|
||||
See issue #3639.
|
||||
"""
|
||||
with patch.object(GoogleLLMOpenAIBetaService, "create_client"):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
service = GoogleLLMOpenAIBetaService(api_key="test-key", model="test-model")
|
||||
service._client = AsyncMock()
|
||||
|
||||
stream_closed = False
|
||||
|
||||
class MockAsyncStream:
|
||||
"""Mock AsyncStream that tracks close() calls and raises CancelledError."""
|
||||
|
||||
def __init__(self):
|
||||
self.iteration_count = 0
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
nonlocal stream_closed
|
||||
stream_closed = True
|
||||
return False
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
self.iteration_count += 1
|
||||
if self.iteration_count > 1:
|
||||
raise asyncio.CancelledError()
|
||||
mock_chunk = AsyncMock()
|
||||
mock_chunk.usage = None
|
||||
mock_chunk.choices = []
|
||||
return mock_chunk
|
||||
|
||||
mock_stream = MockAsyncStream()
|
||||
|
||||
service._stream_chat_completions_specific_context = AsyncMock(return_value=mock_stream)
|
||||
service.start_ttfb_metrics = AsyncMock()
|
||||
service.stop_ttfb_metrics = AsyncMock()
|
||||
service.start_llm_usage_metrics = AsyncMock()
|
||||
|
||||
context = OpenAILLMContext(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await service._process_context(context)
|
||||
|
||||
assert stream_closed, "Stream should be closed even when CancelledError occurs"
|
||||
@@ -84,61 +84,6 @@ async def test_openai_run_inference_with_llm_context():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_run_inference_with_openai_llm_context():
|
||||
"""Test run_inference with OpenAILLMContext returns expected response."""
|
||||
# Create service with mocked client and specific parameters
|
||||
with patch.object(OpenAILLMService, "create_client"):
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
|
||||
params = BaseOpenAILLMService.InputParams(
|
||||
temperature=0.8, max_completion_tokens=150, presence_penalty=0.3, top_p=0.9
|
||||
)
|
||||
service = OpenAILLMService(model="gpt-4", params=params)
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Create OpenAILLMContext
|
||||
context = OpenAILLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
],
|
||||
tools=OPENAI_NOT_GIVEN,
|
||||
tool_choice=OPENAI_NOT_GIVEN,
|
||||
)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Hello! How can I help you today?"
|
||||
service._client.chat.completions.create.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service._client.chat.completions.create.assert_called_once_with(
|
||||
model="gpt-4",
|
||||
stream=False,
|
||||
frequency_penalty=OPENAI_NOT_GIVEN,
|
||||
presence_penalty=0.3,
|
||||
seed=OPENAI_NOT_GIVEN,
|
||||
temperature=0.8,
|
||||
top_p=0.9,
|
||||
max_tokens=OPENAI_NOT_GIVEN,
|
||||
max_completion_tokens=150,
|
||||
service_tier=OPENAI_NOT_GIVEN,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
],
|
||||
tools=OPENAI_NOT_GIVEN,
|
||||
tool_choice=OPENAI_NOT_GIVEN,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_run_inference_client_exception():
|
||||
"""Test that exceptions from the client are propagated."""
|
||||
@@ -209,54 +154,6 @@ async def test_anthropic_run_inference_with_llm_context():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_run_inference_with_openai_llm_context():
|
||||
"""Test run_inference with OpenAILLMContext returns expected response for Anthropic."""
|
||||
# Create service with mocked client and specific parameters
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
|
||||
params = AnthropicLLMService.InputParams(max_tokens=1024, temperature=0.7, top_k=40, top_p=0.9)
|
||||
service = AnthropicLLMService(
|
||||
api_key="test-key", model="claude-3-sonnet-20240229", params=params
|
||||
)
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Create OpenAILLMContext
|
||||
context = OpenAILLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
],
|
||||
tools=NOT_GIVEN,
|
||||
tool_choice=NOT_GIVEN,
|
||||
)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Hello! How can I help you today?"
|
||||
service._client.beta.messages.create.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
service._client.beta.messages.create.assert_called_once_with(
|
||||
model="claude-3-sonnet-20240229",
|
||||
max_tokens=1024,
|
||||
stream=False,
|
||||
temperature=0.7,
|
||||
top_k=40,
|
||||
top_p=0.9,
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
system="You are a helpful assistant",
|
||||
tools=[],
|
||||
betas=["interleaved-thinking-2025-05-14"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_run_inference_client_exception():
|
||||
"""Test that exceptions from the Anthropic client are propagated."""
|
||||
@@ -336,61 +233,6 @@ async def test_google_run_inference_client_exception():
|
||||
await service.run_inference(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_run_inference_with_openai_llm_context():
|
||||
"""Test run_inference with OpenAILLMContext returns expected response for Google."""
|
||||
# Create service with mocked client and specific parameters
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
|
||||
params = GoogleLLMService.InputParams(max_tokens=256, temperature=0.4, top_k=30, top_p=0.75)
|
||||
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash", params=params)
|
||||
service._client = AsyncMock()
|
||||
|
||||
# Create OpenAILLMContext
|
||||
context = OpenAILLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
],
|
||||
tools=NOT_GIVEN,
|
||||
tool_choice=NOT_GIVEN,
|
||||
)
|
||||
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [MagicMock()]
|
||||
mock_response.candidates[0].content = MagicMock()
|
||||
mock_response.candidates[0].content.parts = [MagicMock()]
|
||||
mock_response.candidates[0].content.parts[0].text = "Hello! How can I help you today?"
|
||||
service._client.aio = AsyncMock()
|
||||
service._client.aio.models = AsyncMock()
|
||||
service._client.aio.models.generate_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Execute
|
||||
result = await service.run_inference(context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
|
||||
# Verify the call includes configured parameters
|
||||
call_kwargs = service._client.aio.models.generate_content.call_args.kwargs
|
||||
assert call_kwargs["model"] == "gemini-2.0-flash"
|
||||
# Contents is a Google Content object, so check its structure
|
||||
contents = call_kwargs["contents"]
|
||||
assert len(contents) == 1
|
||||
assert contents[0].role == "user"
|
||||
assert len(contents[0].parts) == 1
|
||||
assert contents[0].parts[0].text == "Hello, world!"
|
||||
assert "config" in call_kwargs
|
||||
config = call_kwargs["config"]
|
||||
# Config is a GenerateContentConfig object, so access attributes
|
||||
assert config.system_instruction == "You are a helpful assistant"
|
||||
assert config.temperature == 0.4
|
||||
assert config.top_k == 30
|
||||
assert config.top_p == 0.75
|
||||
assert config.max_output_tokens == 256
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_with_llm_context():
|
||||
"""Test run_inference with LLMContext returns expected response for AWS Bedrock."""
|
||||
@@ -445,57 +287,6 @@ async def test_aws_bedrock_run_inference_with_llm_context():
|
||||
assert call_kwargs["inferenceConfig"]["topP"] == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_with_openai_llm_context():
|
||||
"""Test run_inference with OpenAILLMContext returns expected response for AWS Bedrock."""
|
||||
# Create service with specific parameters
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
|
||||
params = AWSBedrockLLMService.InputParams(max_tokens=512, temperature=0.8, top_p=0.95)
|
||||
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0", params=params)
|
||||
|
||||
# Create OpenAILLMContext
|
||||
context = OpenAILLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello, world!"},
|
||||
],
|
||||
tools=NOT_GIVEN,
|
||||
tool_choice=NOT_GIVEN,
|
||||
)
|
||||
|
||||
# Mock the client and response
|
||||
mock_client = AsyncMock()
|
||||
mock_response = {
|
||||
"output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
|
||||
}
|
||||
mock_client.converse.return_value = mock_response
|
||||
|
||||
# Patch the _aws_session.client method to be an async context manager
|
||||
mock_context_manager = AsyncMock()
|
||||
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
|
||||
# Execute
|
||||
result = await service.run_inference(context)
|
||||
|
||||
# Verify
|
||||
assert result == "Hello! How can I help you today?"
|
||||
|
||||
# Verify the call includes configured parameters
|
||||
call_kwargs = mock_client.converse.call_args.kwargs
|
||||
assert call_kwargs["modelId"] == "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
assert call_kwargs["messages"] == [{"role": "user", "content": [{"text": "Hello, world!"}]}]
|
||||
assert call_kwargs["system"] == [{"text": "You are a helpful assistant"}]
|
||||
assert call_kwargs["additionalModelRequestFields"] == {}
|
||||
assert "inferenceConfig" in call_kwargs
|
||||
assert call_kwargs["inferenceConfig"]["maxTokens"] == 512
|
||||
assert call_kwargs["inferenceConfig"]["temperature"] == 0.8
|
||||
assert call_kwargs["inferenceConfig"]["topP"] == 0.95
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aws_bedrock_run_inference_client_exception():
|
||||
"""Test that exceptions from the AWS Bedrock client are propagated."""
|
||||
|
||||
Reference in New Issue
Block a user