Files
pipecat/src/pipecat/services/openai.py
Aleix Conchillo Flaqué 96fa6c19a8 Merge pull request #1398 from nicougou/feature/openai_custom_headers
feature: add custom headers to AsyncOpenAI
2025-03-20 13:45:57 -07:00

623 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import base64
import io
import json
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
import aiohttp
import httpx
from loguru import logger
from openai import (
NOT_GIVEN,
AsyncOpenAI,
AsyncStream,
BadRequestError,
DefaultAsyncHttpxClient,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
ErrorFrame,
Frame,
FunctionCallCancelFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMTextFrame,
LLMUpdateSettingsFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
URLImageRawFrame,
UserImageRawFrame,
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import (
ImageGenService,
LLMService,
TTSService,
)
from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
from pipecat.transcriptions.language import Language
ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
VALID_VOICES: Dict[str, ValidVoice] = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
class OpenAIUnhandledFunctionException(Exception):
pass
class BaseOpenAILLMService(LLMService):
"""This is the base for all services that use the AsyncOpenAI client.
This service consumes OpenAILLMContextFrame frames, which contain a reference
to an OpenAILLMContext frame. The OpenAILLMContext object defines the context
sent to the LLM for a completion. This includes user, assistant and system messages
as well as tool choices and the tool, which is used if requesting function
calls from the LLM.
"""
class InputParams(BaseModel):
frequency_penalty: Optional[float] = Field(
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
)
presence_penalty: Optional[float] = Field(
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
)
seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
# Note: top_k is currently not supported by the OpenAI client library,
# so top_k is ignored right now.
top_k: Optional[int] = Field(default=None, ge=0)
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
max_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
max_completion_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
def __init__(
self,
*,
model: str,
api_key=None,
base_url=None,
organization=None,
project=None,
default_headers: Mapping[str, str] | None = None,
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(**kwargs)
self._settings = {
"frequency_penalty": params.frequency_penalty,
"presence_penalty": params.presence_penalty,
"seed": params.seed,
"temperature": params.temperature,
"top_p": params.top_p,
"max_tokens": params.max_tokens,
"max_completion_tokens": params.max_completion_tokens,
"extra": params.extra if isinstance(params.extra, dict) else {},
}
self.set_model_name(model)
self._client = self.create_client(
api_key=api_key, base_url=base_url, organization=organization, project=project, default_headers=default_headers, **kwargs
)
def create_client(self, api_key=None, base_url=None, organization=None, project=None, default_headers=None, **kwargs):
return AsyncOpenAI(
api_key=api_key,
base_url=base_url,
organization=organization,
project=project,
http_client=DefaultAsyncHttpxClient(
limits=httpx.Limits(
max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
)
),
default_headers=default_headers
)
def can_generate_metrics(self) -> bool:
return True
async def get_chat_completions(
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
) -> AsyncStream[ChatCompletionChunk]:
params = {
"model": self.model_name,
"stream": True,
"messages": messages,
"tools": context.tools,
"tool_choice": context.tool_choice,
"stream_options": {"include_usage": True},
"frequency_penalty": self._settings["frequency_penalty"],
"presence_penalty": self._settings["presence_penalty"],
"seed": self._settings["seed"],
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"max_tokens": self._settings["max_tokens"],
"max_completion_tokens": self._settings["max_completion_tokens"],
}
params.update(self._settings["extra"])
chunks = await self._client.chat.completions.create(**params)
return chunks
async def _stream_chat_completions(
self, context: OpenAILLMContext
) -> AsyncStream[ChatCompletionChunk]:
logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
messages: List[ChatCompletionMessageParam] = context.get_messages()
# base64 encode any images
for message in messages:
if message.get("mime_type") == "image/jpeg":
encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
text = message["content"]
message["content"] = [
{"type": "text", "text": text},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
]
del message["data"]
del message["mime_type"]
chunks = await self.get_chat_completions(context, messages)
return chunks
async def _process_context(self, context: OpenAILLMContext):
functions_list = []
arguments_list = []
tool_id_list = []
func_idx = 0
function_name = ""
arguments = ""
tool_call_id = ""
await self.start_ttfb_metrics()
chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
context
)
async for chunk in chunk_stream:
if chunk.usage:
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
)
await self.start_llm_usage_metrics(tokens)
if chunk.choices is None or len(chunk.choices) == 0:
continue
await self.stop_ttfb_metrics()
if not chunk.choices[0].delta:
continue
if chunk.choices[0].delta.tool_calls:
# We're streaming the LLM response to enable the fastest response times.
# For text, we just yield each chunk as we receive it and count on consumers
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
#
# If the LLM is a function call, we'll do some coalescing here.
# If the response contains a function name, we'll yield a frame to tell consumers
# that they can start preparing to call the function with that name.
# We accumulate all the arguments for the rest of the streamed response, then when
# the response is done, we package up all the arguments and the function name and
# yield a frame containing the function name and the arguments.
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != func_idx:
functions_list.append(function_name)
arguments_list.append(arguments)
tool_id_list.append(tool_call_id)
function_name = ""
arguments = ""
tool_call_id = ""
func_idx += 1
if tool_call.function and tool_call.function.name:
function_name += tool_call.function.name
tool_call_id = tool_call.id
if tool_call.function and tool_call.function.arguments:
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to
# the context, and re-prompt to get a chat answer. If we don't have a registered
# handler, raise an exception.
if function_name and arguments:
# added to the list as last function name and arguments not added to the list
functions_list.append(function_name)
arguments_list.append(arguments)
tool_id_list.append(tool_call_id)
for index, (function_name, arguments, tool_id) in enumerate(
zip(functions_list, arguments_list, tool_id_list), start=1
):
if self.has_function(function_name):
run_llm = False
arguments = json.loads(arguments)
await self.call_function(
context=context,
function_name=function_name,
arguments=arguments,
tool_call_id=tool_id,
run_llm=run_llm,
)
else:
raise OpenAIUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context: OpenAILLMContext = frame.context
elif isinstance(frame, LLMMessagesFrame):
context = OpenAILLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
context = OpenAILLMContext()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)
if context:
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
await self._process_context(context)
except httpx.TimeoutException:
await self._call_event_handler("on_completion_timeout")
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
@dataclass
class OpenAIContextAggregatorPair:
_user: "OpenAIUserContextAggregator"
_assistant: "OpenAIAssistantContextAggregator"
def user(self) -> "OpenAIUserContextAggregator":
return self._user
def assistant(self) -> "OpenAIAssistantContextAggregator":
return self._assistant
class OpenAILLMService(BaseOpenAILLMService):
def __init__(
self,
*,
model: str = "gpt-4o",
params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(),
**kwargs,
):
super().__init__(model=model, params=params, **kwargs)
def create_context_aggregator(
self,
context: OpenAILLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
) -> OpenAIContextAggregatorPair:
"""Create an instance of OpenAIContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
Args:
context (OpenAILLMContext): The LLM context.
user_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the user context aggregator constructor. Defaults
to an empty mapping.
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the assistant context aggregator
constructor. Defaults to an empty mapping.
Returns:
OpenAIContextAggregatorPair: A pair of context aggregators, one for
the user and one for the assistant, encapsulated in an
OpenAIContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())
user = OpenAIUserContextAggregator(context, **user_kwargs)
assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
class OpenAIImageGenService(ImageGenService):
def __init__(
self,
*,
api_key: str,
aiohttp_session: aiohttp.ClientSession,
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
model: str = "dall-e-3",
):
super().__init__()
self.set_model_name(model)
self._image_size = image_size
self._client = AsyncOpenAI(api_key=api_key)
self._aiohttp_session = aiohttp_session
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating image from prompt: {prompt}")
image = await self._client.images.generate(
prompt=prompt, model=self.model_name, n=1, size=self._image_size
)
image_url = image.data[0].url
if not image_url:
logger.error(f"{self} No image provided in response: {image}")
yield ErrorFrame("Image generation failed")
return
# Load the image from the url
async with self._aiohttp_session.get(image_url) as response:
image_stream = io.BytesIO(await response.content.read())
image = Image.open(image_stream)
frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format)
yield frame
class OpenAISTTService(BaseWhisperSTTService):
"""OpenAI Speech-to-Text service that generates text from audio.
Uses OpenAI's transcription API to convert audio to text. Requires an OpenAI API key
set via the api_key parameter or OPENAI_API_KEY environment variable.
Args:
model: Model to use — either gpt-4o or Whisper. Defaults to "gpt-4o-transcribe".
api_key: OpenAI API key. Defaults to None.
base_url: API base URL. Defaults to None.
language: Language of the audio input. Defaults to English.
prompt: Optional text to guide the model's style or continue a previous segment.
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
**kwargs: Additional arguments passed to BaseWhisperSTTService.
"""
def __init__(
self,
*,
model: str = "gpt-4o-transcribe",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
language: Optional[Language] = Language.EN,
prompt: Optional[str] = None,
temperature: Optional[float] = None,
**kwargs,
):
super().__init__(
model=model,
api_key=api_key,
base_url=base_url,
language=language,
prompt=prompt,
temperature=temperature,
**kwargs,
)
async def _transcribe(self, audio: bytes) -> Transcription:
assert self._language is not None # Assigned in the BaseWhisperSTTService class
# Build kwargs dict with only set parameters
kwargs = {
"file": ("audio.wav", audio, "audio/wav"),
"model": self.model_name,
"language": self._language,
}
if self._prompt is not None:
kwargs["prompt"] = self._prompt
if self._temperature is not None:
kwargs["temperature"] = self._temperature
return await self._client.audio.transcriptions.create(**kwargs)
class OpenAITTSService(TTSService):
"""OpenAI Text-to-Speech service that generates audio from text.
This service uses the OpenAI TTS API to generate PCM-encoded audio at 24kHz.
Args:
api_key: OpenAI API key. Defaults to None.
voice: Voice ID to use. Defaults to "alloy".
model: TTS model to use. Defaults to "gpt-4o-mini-tts".
sample_rate: Output audio sample rate in Hz. Defaults to None.
**kwargs: Additional keyword arguments passed to TTSService.
The service returns PCM-encoded audio at the specified sample rate.
"""
OPENAI_SAMPLE_RATE = 24000 # OpenAI TTS always outputs at 24kHz
def __init__(
self,
*,
api_key: Optional[str] = None,
voice: str = "alloy",
model: str = "gpt-4o-mini-tts",
sample_rate: Optional[int] = None,
**kwargs,
):
if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
logger.warning(
f"OpenAI TTS only supports {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
f"Current rate of {self.sample_rate}Hz may cause issues."
)
super().__init__(sample_rate=sample_rate, **kwargs)
self.set_model_name(model)
self.set_voice(voice)
self._client = AsyncOpenAI(api_key=api_key)
def can_generate_metrics(self) -> bool:
return True
async def set_model(self, model: str):
logger.info(f"Switching TTS model to: [{model}]")
self.set_model_name(model)
async def start(self, frame: StartFrame):
await super().start(frame)
if self.sample_rate != self.OPENAI_SAMPLE_RATE:
logger.warning(
f"OpenAI TTS requires {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
f"Current rate of {self.sample_rate}Hz may cause issues."
)
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
async with self._client.audio.speech.with_streaming_response.create(
input=text or " ", # Text must contain at least one character
model=self.model_name,
voice=VALID_VOICES[self._voice_id],
response_format="pcm",
) as r:
if r.status_code != 200:
error = await r.text()
logger.error(
f"{self} error getting audio (status: {r.status_code}, error: {error})"
)
yield ErrorFrame(
f"Error getting audio (status: {r.status_code}, error: {error})"
)
return
await self.start_tts_usage_metrics(text)
CHUNK_SIZE = 1024
yield TTSStartedFrame()
async for chunk in r.iter_bytes(CHUNK_SIZE):
if len(chunk) > 0:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
yield frame
yield TTSStoppedFrame()
except BadRequestError as e:
logger.exception(f"{self} error generating TTS: {e}")
class OpenAIUserContextAggregator(LLMUserContextAggregator):
pass
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
self._context.add_message(
{
"role": "assistant",
"tool_calls": [
{
"id": frame.tool_call_id,
"function": {
"name": frame.function_name,
"arguments": json.dumps(frame.arguments),
},
"type": "function",
}
],
}
)
self._context.add_message(
{
"role": "tool",
"content": "IN_PROGRESS",
"tool_call_id": frame.tool_call_id,
}
)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
if frame.result:
result = json.dumps(frame.result)
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
else:
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "COMPLETED"
)
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "CANCELLED"
)
async def _update_function_call_result(
self, function_name: str, tool_call_id: str, result: str
):
for message in self._context.messages:
if (
message["role"] == "tool"
and message["tool_call_id"]
and message["tool_call_id"] == tool_call_id
):
message["content"] = result
async def handle_user_image_frame(self, frame: UserImageRawFrame):
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
self._context.add_image_frame_message(
format=frame.format,
size=frame.size,
image=frame.image,
text=frame.request.context,
)