Merge pull request #1308 from Vaibhav159/vl_google_openai_format
adding GoogleLLMOpenAIBetaService
This commit is contained in:
@@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
the `RTVIObserver` and will be delivered to the client's `onServerMessage`
|
||||
callback or `ServerMessage` event.
|
||||
|
||||
- Added `GoogleLLMOpenAIBetaService` for Google LLM integration with an
|
||||
OpenAI-compatible interface. Added foundational example
|
||||
`14o-function-calling-gemini-openai-format.py`.
|
||||
|
||||
## [0.0.58] - 2025-02-26
|
||||
|
||||
### Added
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.google import GoogleLLMOpenAIBetaService
|
||||
from pipecat.services.openai import OpenAILLMContext
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def start_fetch_weather(function_name, llm, context):
|
||||
"""Push a frame to the LLM; this is handy when the LLM response might take a while."""
|
||||
await llm.push_frame(TTSSpeakFrame("Let me check on that."))
|
||||
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
await result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = ElevenLabsTTSService(
|
||||
api_key=os.getenv("ELEVENLABS_API_KEY", ""),
|
||||
voice_id=os.getenv("ELEVENLABS_VOICE_ID", ""),
|
||||
)
|
||||
|
||||
llm = GoogleLLMOpenAIBetaService(api_key=os.getenv("GEMINI_API_KEY"))
|
||||
# Register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function(
|
||||
"get_current_weather", fetch_weather_from_api, start_callback=start_fetch_weather
|
||||
)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Start a conversation with 'Hey there' to get the current weather.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -12,6 +12,8 @@ import os
|
||||
import time
|
||||
|
||||
from google.api_core.exceptions import DeadlineExceeded
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
@@ -54,7 +56,10 @@ from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import ImageGenService, LLMService, STTService, TTSService
|
||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||
from pipecat.services.openai import (
|
||||
BaseOpenAILLMService,
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMService,
|
||||
OpenAIUnhandledFunctionException,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -1188,6 +1193,120 @@ class GoogleLLMService(LLMService):
|
||||
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||
"""This class implements inference with Google's AI LLM models using the OpenAI format.
|
||||
Ref - https://ai.google.dev/gemini-api/docs/openai
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
model: str = "gemini-2.0-flash",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
|
||||
context
|
||||
)
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
# We're streaming the LLM response to enable the fastest response times.
|
||||
# For text, we just yield each chunk as we receive it and count on consumers
|
||||
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
|
||||
#
|
||||
# If the LLM is a function call, we'll do some coalescing here.
|
||||
# If the response contains a function name, we'll yield a frame to tell consumers
|
||||
# that they can start preparing to call the function with that name.
|
||||
# We accumulate all the arguments for the rest of the streamed response, then when
|
||||
# the response is done, we package up all the arguments and the function name and
|
||||
# yield a frame containing the function name and the arguments.
|
||||
logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}")
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
logger.debug(
|
||||
f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}"
|
||||
)
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list), start=1
|
||||
):
|
||||
if function_name == "":
|
||||
# TODO: Remove the _process_context method once Google resolves the bug
|
||||
# where the index is incorrectly set to None instead of returning the actual index,
|
||||
# which currently results in an empty function name('').
|
||||
continue
|
||||
if self.has_function(function_name):
|
||||
run_llm = False
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
|
||||
class GoogleTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
pitch: Optional[str] = None
|
||||
|
||||
Reference in New Issue
Block a user