From 4824220260770f2561e51df5523d1ee486f36d9e Mon Sep 17 00:00:00 2001 From: Vaibhav159 Date: Thu, 27 Feb 2025 21:30:06 +0530 Subject: [PATCH 1/2] adding GoogleLLMOpenAIBetaService --- CHANGELOG.md | 7 +++++++ src/pipecat/services/google/google.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff7b62e9b..d294ab6e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- Added `GoogleLLMOpenAIBetaService` for Google LLM integration with an + OpenAI-compatible interface. + ## [0.0.58] - 2025-02-26 ### Added diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index c0941ee33..cd51f35e8 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -54,6 +54,7 @@ from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import ImageGenService, LLMService, STTService, TTSService from pipecat.services.google.frames import LLMSearchResponseFrame from pipecat.services.openai import ( + BaseOpenAILLMService, OpenAIAssistantContextAggregator, OpenAIUserContextAggregator, ) @@ -1188,6 +1189,22 @@ class GoogleLLMService(LLMService): return GoogleContextAggregatorPair(_user=user, _assistant=assistant) +class GoogleLLMOpenAIBetaService(BaseOpenAILLMService): + """This class implements inference with Google's AI LLM models using the OpenAI format. + Ref - https://ai.google.dev/gemini-api/docs/openai + """ + + def __init__( + self, + *, + api_key: str, + base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/", + model: str = "gemini-2.0-flash", + **kwargs, + ): + super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs) + + class GoogleTTSService(TTSService): class InputParams(BaseModel): pitch: Optional[str] = None From 59fb6313909b9bb4929ca3e6337010a43bdc1a72 Mon Sep 17 00:00:00 2001 From: Vaibhav159 Date: Thu, 27 Feb 2025 22:55:25 +0530 Subject: [PATCH 2/2] fixing function calling and adding example --- CHANGELOG.md | 3 +- ...o-function-calling-gemini-openai-format.py | 137 ++++++++++++++++++ src/pipecat/services/google/google.py | 104 ++++++++++++- 3 files changed, 242 insertions(+), 2 deletions(-) create mode 100644 examples/foundational/14o-function-calling-gemini-openai-format.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d294ab6e5..700ce425c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `GoogleLLMOpenAIBetaService` for Google LLM integration with an - OpenAI-compatible interface. + OpenAI-compatible interface. Added foundational example + `14o-function-calling-gemini-openai-format.py`. ## [0.0.58] - 2025-02-26 diff --git a/examples/foundational/14o-function-calling-gemini-openai-format.py b/examples/foundational/14o-function-calling-gemini-openai-format.py new file mode 100644 index 000000000..4b04f9285 --- /dev/null +++ b/examples/foundational/14o-function-calling-gemini-openai-format.py @@ -0,0 +1,137 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from openai.types.chat import ChatCompletionToolParam +from runner import configure + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import TTSSpeakFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.services.elevenlabs import ElevenLabsTTSService +from pipecat.services.google import GoogleLLMOpenAIBetaService +from pipecat.services.openai import OpenAILLMContext +from pipecat.transports.services.daily import DailyParams, DailyTransport + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +async def start_fetch_weather(function_name, llm, context): + """Push a frame to the LLM; this is handy when the LLM response might take a while.""" + await llm.push_frame(TTSSpeakFrame("Let me check on that.")) + logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}") + + +async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback): + await result_callback({"conditions": "nice", "temperature": "75"}) + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + tts = ElevenLabsTTSService( + api_key=os.getenv("ELEVENLABS_API_KEY", ""), + voice_id=os.getenv("ELEVENLABS_VOICE_ID", ""), + ) + + llm = GoogleLLMOpenAIBetaService(api_key=os.getenv("GEMINI_API_KEY")) + # Register a function_name of None to get all functions + # sent to the same callback with an additional function_name parameter. + llm.register_function( + "get_current_weather", fetch_weather_from_api, start_callback=start_fetch_weather + ) + + tools = [ + ChatCompletionToolParam( + type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + ) + ] + messages = [ + { + "role": "user", + "content": "Start a conversation with 'Hey there' to get the current weather.", + }, + ] + + context = OpenAILLMContext(messages, tools) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + ), + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index cd51f35e8..df8f5836a 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -12,6 +12,8 @@ import os import time from google.api_core.exceptions import DeadlineExceeded +from openai import AsyncStream +from openai.types.chat import ChatCompletionChunk # Suppress gRPC fork warnings os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" @@ -56,6 +58,8 @@ from pipecat.services.google.frames import LLMSearchResponseFrame from pipecat.services.openai import ( BaseOpenAILLMService, OpenAIAssistantContextAggregator, + OpenAILLMService, + OpenAIUnhandledFunctionException, OpenAIUserContextAggregator, ) from pipecat.transcriptions.language import Language @@ -1189,7 +1193,7 @@ class GoogleLLMService(LLMService): return GoogleContextAggregatorPair(_user=user, _assistant=assistant) -class GoogleLLMOpenAIBetaService(BaseOpenAILLMService): +class GoogleLLMOpenAIBetaService(OpenAILLMService): """This class implements inference with Google's AI LLM models using the OpenAI format. Ref - https://ai.google.dev/gemini-api/docs/openai """ @@ -1204,6 +1208,104 @@ class GoogleLLMOpenAIBetaService(BaseOpenAILLMService): ): super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs) + async def _process_context(self, context: OpenAILLMContext): + functions_list = [] + arguments_list = [] + tool_id_list = [] + func_idx = 0 + function_name = "" + arguments = "" + tool_call_id = "" + + await self.start_ttfb_metrics() + + chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions( + context + ) + + async for chunk in chunk_stream: + if chunk.usage: + tokens = LLMTokenUsage( + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + total_tokens=chunk.usage.total_tokens, + ) + await self.start_llm_usage_metrics(tokens) + + if chunk.choices is None or len(chunk.choices) == 0: + continue + + await self.stop_ttfb_metrics() + + if not chunk.choices[0].delta: + continue + + if chunk.choices[0].delta.tool_calls: + # We're streaming the LLM response to enable the fastest response times. + # For text, we just yield each chunk as we receive it and count on consumers + # to do whatever coalescing they need (eg. to pass full sentences to TTS) + # + # If the LLM is a function call, we'll do some coalescing here. + # If the response contains a function name, we'll yield a frame to tell consumers + # that they can start preparing to call the function with that name. + # We accumulate all the arguments for the rest of the streamed response, then when + # the response is done, we package up all the arguments and the function name and + # yield a frame containing the function name and the arguments. + logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}") + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.index != func_idx: + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + function_name = "" + arguments = "" + tool_call_id = "" + func_idx += 1 + if tool_call.function and tool_call.function.name: + function_name += tool_call.function.name + tool_call_id = tool_call.id + if tool_call.function and tool_call.function.arguments: + # Keep iterating through the response to collect all the argument fragments + arguments += tool_call.function.arguments + elif chunk.choices[0].delta.content: + await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content)) + + # if we got a function name and arguments, check to see if it's a function with + # a registered handler. If so, run the registered callback, save the result to + # the context, and re-prompt to get a chat answer. If we don't have a registered + # handler, raise an exception. + if function_name and arguments: + # added to the list as last function name and arguments not added to the list + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + + logger.debug( + f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}" + ) + for index, (function_name, arguments, tool_id) in enumerate( + zip(functions_list, arguments_list, tool_id_list), start=1 + ): + if function_name == "": + # TODO: Remove the _process_context method once Google resolves the bug + # where the index is incorrectly set to None instead of returning the actual index, + # which currently results in an empty function name(''). + continue + if self.has_function(function_name): + run_llm = False + arguments = json.loads(arguments) + await self.call_function( + context=context, + function_name=function_name, + arguments=arguments, + tool_call_id=tool_id, + run_llm=run_llm, + ) + else: + raise OpenAIUnhandledFunctionException( + f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." + ) + class GoogleTTSService(TTSService): class InputParams(BaseModel):