From 07712cdb16e11fb78e5af30469504791a72646ff Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Fri, 18 Oct 2024 17:14:57 -0700 Subject: [PATCH 1/2] gemini function calling and partial implementation of standard context stuff --- .../14e-function-calling-gemini.py | 159 +++++++++++++++ pyproject.toml | 2 +- src/pipecat/services/google.py | 189 +++++++++++++++++- 3 files changed, 342 insertions(+), 8 deletions(-) create mode 100644 examples/foundational/14e-function-calling-gemini.py diff --git a/examples/foundational/14e-function-calling-gemini.py b/examples/foundational/14e-function-calling-gemini.py new file mode 100644 index 000000000..ed1b904ce --- /dev/null +++ b/examples/foundational/14e-function-calling-gemini.py @@ -0,0 +1,159 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import aiohttp +import os +import sys + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.services.google import GoogleLLMService +from pipecat.services.openai import OpenAILLMContext +from pipecat.transports.services.daily import DailyParams, DailyTransport + +from runner import configure + +from loguru import logger + +from dotenv import load_dotenv + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + +video_participant_id = None + + +async def get_weather(function_name, tool_call_id, arguments, llm, context, result_callback): + location = arguments["location"] + await result_callback(f"The weather in {location} is currently 72 degrees and sunny.") + + +async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback): + logger.debug(f"!!! IN get_image {video_participant_id}, {arguments}") + question = arguments["question"] + await llm.request_image_frame(user_id=video_participant_id, text_content=question) + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady + ) + + llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY")) + llm.register_function("get_weather", get_weather) + llm.register_function("get_image", get_image) + + tools = [ + { + "function_declarations": [ + { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + } + ] + } + ] + + system_prompt = """\ +You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions. + +Your response will be turned into speech so use only simple words and punctuation. + +You have access to two tools: get_weather and get_image. + +You can respond to questions about the weather using the get_weather tool. + +You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \ +indicate you should use the get_image tool are: + - What do you see? + - What's in the video? + - Can you describe the video? + - Tell me about what you see. + - Tell me something interesting about what you see. + - What's happening in the video? +""" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "Say hello."}, + ] + + context = OpenAILLMContext(messages, tools) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + ] + ) + + task = PipelineTask( + pipeline, + PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + report_only_initial_ttfb=True, + ), + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + global video_participant_id + video_participant_id = participant["id"] + transport.capture_participant_transcription(participant["id"]) + transport.capture_participant_video(video_participant_id, framerate=0) + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 7625ce68a..6687193c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ elevenlabs = [ "websockets~=13.1" ] examples = [ "python-dotenv~=1.0.1", "flask~=3.0.3", "flask_cors~=4.0.1" ] fal = [ "fal-client~=0.4.1" ] gladia = [ "websockets~=13.1" ] -google = [ "google-generativeai~=0.7.2", "google-cloud-texttospeech~=2.17.2" ] +google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.17.2" ] gstreamer = [ "pygobject~=3.48.2" ] fireworks = [ "openai~=1.37.2" ] langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ] diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py index 8fc5151a3..9a21f8f13 100644 --- a/src/pipecat/services/google.py +++ b/src/pipecat/services/google.py @@ -5,10 +5,14 @@ # import asyncio +from dataclasses import dataclass import json +import io from typing import AsyncGenerator, List, Literal, Optional + from loguru import logger +from PIL import Image from pydantic import BaseModel from pipecat.frames.frames import ( @@ -28,6 +32,10 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) +from pipecat.services.openai import ( + OpenAIAssistantContextAggregator, + OpenAIUserContextAggregator, +) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import LLMService, TTSService from pipecat.transcriptions.language import Language @@ -45,6 +53,148 @@ except ModuleNotFoundError as e: raise Exception(f"Missing module: {e}") +class GoogleUserContextAggregator(OpenAIUserContextAggregator): + async def _push_aggregation(self): + if len(self._aggregation) > 0: + self._context.add_message({"role": "user", "parts": [glm.Part(text=self._aggregation)]}) + + # Reset the aggregation. Reset it before pushing it down, otherwise + # if the tasks gets cancelled we won't be able to clear things up. + self._aggregation = "" + + frame = OpenAILLMContextFrame(self._context) + await self.push_frame(frame) + + # Reset our accumulator state. + self._reset() + + +class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): + async def _push_aggregation(self): + if not ( + self._aggregation or self._function_call_result or self._pending_image_frame_message + ): + return + + run_llm = False + + aggregation = self._aggregation + self._reset() + + try: + if self._function_call_result: + frame = self._function_call_result + self._function_call_result = None + if frame.result: + logger.debug(f"FunctionCallResultFrame result: {frame.arguments}") + self._context.add_message( + { + "role": "model", + "parts": [ + glm.Part( + function_call=glm.FunctionCall( + name=frame.function_name, args=frame.arguments + ) + ) + ], + } + ) + response = frame.result + if isinstance(response, str): + response = {"response": response} + self._context.add_message( + { + "role": "user", + "parts": [ + glm.Part( + function_response=glm.FunctionResponse( + name=frame.function_name, response=response + ) + ) + ], + } + ) + run_llm = not bool(self._function_calls_in_progress) + else: + self._context.add_message({"role": "model", "parts": [glm.Part(text=aggregation)]}) + + if self._pending_image_frame_message: + frame = self._pending_image_frame_message + self._pending_image_frame_message = None + self._context.add_image_frame_message( + format=frame.user_image_raw_frame.format, + size=frame.user_image_raw_frame.size, + image=frame.user_image_raw_frame.image, + text=frame.text, + ) + run_llm = True + + if run_llm: + await self._user_context_aggregator.push_context_frame() + + frame = OpenAILLMContextFrame(self._context) + await self.push_frame(frame) + + except Exception as e: + logger.exception(f"Error processing frame: {e}") + + +@dataclass +class GoogleContextAggregatorPair: + _user: "GoogleUserContextAggregator" + _assistant: "GoogleAssistantContextAggregator" + + def user(self) -> "GoogleUserContextAggregator": + return self._user + + def assistant(self) -> "GoogleAssistantContextAggregator": + return self._assistant + + +class GoogleLLMContext(OpenAILLMContext): + @staticmethod + def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext": + if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext): + logger.debug(f"Upgrading to Google: {obj}") + obj.__class__ = GoogleLLMContext + obj._restructure_from_openai_messages() + return obj + + def from_standard_message(self, message): + role = message["role"] + content = message["content"] + if role == "system": + role = "user" + elif role == "assistant": + role = "model" + + parts = [] + if isinstance(content, str): + parts.append(glm.Part(text=content)) + elif isinstance(content, list): + logger.debug("!!!NEED TO IMPL CONTENT LIST") + + message = {"role": role, "parts": parts} + return message + + def add_image_frame_message( + self, *, format: str, size: tuple[int, int], image: bytes, text: str = None + ): + buffer = io.BytesIO() + Image.frombytes(format, size, image).save(buffer, format="JPEG") + + parts = [] + if text: + parts.append(glm.Part(text=text)) + parts.append( + glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())), + ) + self.add_message({"role": "user", "parts": parts}) + + def _restructure_from_openai_messages(self): + self._messages[:] = [self.from_standard_message(m) for m in self._messages] + + class GoogleLLMService(LLMService): """This class implements inference with Google's AI models @@ -98,20 +248,34 @@ class GoogleLLMService(LLMService): async def _process_context(self, context: OpenAILLMContext): await self.push_frame(LLMFullResponseStartFrame()) try: - logger.debug(f"Generating chat: {context.get_messages_json()}") + logger.debug(f"Generating chat: {context.messages}") - messages = self._get_messages_from_openai_context(context) + # todo: move this into the new context code structure, convert from openai context one time + # todo: add system instructions + # messages = self._get_messages_from_openai_context(context) + messages = context.messages await self.start_ttfb_metrics() - response = self._client.generate_content(messages, stream=True) + tools = context.tools if context.tools else [] + response = self._client.generate_content(contents=messages, tools=tools, stream=True) await self.stop_ttfb_metrics() async for chunk in self._async_generator_wrapper(response): + # todo: usage try: - text = chunk.text - await self.push_frame(TextFrame(text)) + for c in chunk.parts: + if c.text: + await self.push_frame(TextFrame(c.text)) + elif c.function_call: + args = type(c.function_call).to_dict(c.function_call).get("args", {}) + await self.call_function( + context=context, + tool_call_id="what_should_this_be", + function_name=c.function_call.name, + arguments=args, + ) except Exception as e: # Google LLMs seem to flag safety issues a lot! if chunk.candidates[0].finish_reason == 3: @@ -132,10 +296,11 @@ class GoogleLLMService(LLMService): context = None if isinstance(frame, OpenAILLMContextFrame): - context: OpenAILLMContext = frame.context + context: GoogleLLMContext = GoogleLLMContext.upgrade_to_google(frame.context) elif isinstance(frame, LLMMessagesFrame): - context = OpenAILLMContext.from_messages(frame.messages) + context = GoogleLLMContext(frame.messages) elif isinstance(frame, VisionImageRawFrame): + # todo: fix this context = OpenAILLMContext.from_image_frame(frame) elif isinstance(frame, LLMUpdateSettingsFrame): await self._update_settings(frame.settings) @@ -145,6 +310,16 @@ class GoogleLLMService(LLMService): if context: await self._process_context(context) + @staticmethod + def create_context_aggregator( + context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + ) -> GoogleContextAggregatorPair: + user = GoogleUserContextAggregator(context) + assistant = GoogleAssistantContextAggregator( + user, expect_stripped_words=assistant_expect_stripped_words + ) + return GoogleContextAggregatorPair(_user=user, _assistant=assistant) + class GoogleTTSService(TTSService): class InputParams(BaseModel): From e032b0b70acf1f1cf49d2439c6428c53aec2d55b Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Wed, 23 Oct 2024 18:44:09 -0700 Subject: [PATCH 2/2] gemini context aggregators --- .../14e-function-calling-gemini.py | 16 +- .../20d-persistent-context-gemini.py | 290 ++++++++++++++++++ .../aggregators/openai_llm_context.py | 6 + src/pipecat/services/google.py | 136 +++++++- 4 files changed, 430 insertions(+), 18 deletions(-) create mode 100644 examples/foundational/20d-persistent-context-gemini.py diff --git a/examples/foundational/14e-function-calling-gemini.py b/examples/foundational/14e-function-calling-gemini.py index ed1b904ce..c124e4c2e 100644 --- a/examples/foundational/14e-function-calling-gemini.py +++ b/examples/foundational/14e-function-calling-gemini.py @@ -89,7 +89,21 @@ async def main(): }, "required": ["location", "format"], }, - } + }, + { + "name": "get_image", + "description": "Get and image from the camera or video stream.", + "parameters": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The question to to use when running inference on the acquired image.", + }, + }, + "required": ["question"], + }, + }, ] } ] diff --git a/examples/foundational/20d-persistent-context-gemini.py b/examples/foundational/20d-persistent-context-gemini.py new file mode 100644 index 000000000..96abd8a41 --- /dev/null +++ b/examples/foundational/20d-persistent-context-gemini.py @@ -0,0 +1,290 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import glob +import json +import os +import sys +from datetime import datetime + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.audio.vad.vad_analyzer import VADParams +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, +) +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.services.google import GoogleLLMService + +from pipecat.transports.services.daily import DailyParams, DailyTransport + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + +video_participant_id = None + + +BASE_FILENAME = "/tmp/pipecat_conversation_" +tts = None + + +async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback): + temperature = 75 if args["format"] == "fahrenheit" else 24 + await result_callback( + { + "conditions": "nice", + "temperature": temperature, + "format": args["format"], + "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"), + } + ) + + +async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback): + question = arguments["question"] + await llm.request_image_frame(user_id=video_participant_id, text_content=question) + + +async def get_saved_conversation_filenames( + function_name, tool_call_id, args, llm, context, result_callback +): + # Construct the full pattern including the BASE_FILENAME + full_pattern = f"{BASE_FILENAME}*.json" + + # Use glob to find all matching files + matching_files = glob.glob(full_pattern) + logger.debug(f"matching files: {matching_files}") + + await result_callback({"filenames": matching_files}) + + +async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback): + timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + filename = f"{BASE_FILENAME}{timestamp}.json" + logger.debug( + f"writing conversation to {filename}\n{json.dumps(context.get_messages_for_logging(), indent=4)}" + ) + try: + with open(filename, "w") as file: + # todo: extract 'system' into the first message in the list + messages = context.get_messages_for_persistent_storage() + # remove the last message (the instruction to save the context) + messages.pop() + json.dump(messages, file, indent=2) + await result_callback({"success": True}) + except Exception as e: + logger.debug(f"error saving conversation: {e}") + await result_callback({"success": False, "error": str(e)}) + + +async def load_conversation(function_name, tool_call_id, args, llm, context, result_callback): + global tts + filename = args["filename"] + logger.debug(f"loading conversation from {filename}") + try: + with open(filename, "r") as file: + context.set_messages(json.load(file)) + await result_callback( + { + "success": True, + "message": "The most recent conversation has been loaded. Awaiting further instructions.", + } + ) + except Exception as e: + await result_callback({"success": False, "error": str(e)}) + + +# Test message munging ... +messages = [ + { + "role": "system", + "content": """You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your +capabilities in a succinct way. Your output will be converted to audio so don't include special +characters in your answers. Respond to what the user said in a creative and helpful way. + +You have several tools you can use to help you. + +You can respond to questions about the weather using the get_weather tool. + +You can save the current conversation using the save_conversation tool. This tool allows you to save +the current conversation to external storage. If the user asks you to save the conversation, use this +save_conversation too. + +You can load a saved conversation using the load_conversation tool. This tool allows you to load a +conversation from external storage. You can get a list of conversations that have been saved using the +get_saved_conversation_filenames tool. + +You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \ +indicate you should use the get_image tool are: + - What do you see? + - What's in the video? + - Can you describe the video? + - Tell me about what you see. + - Tell me something interesting about what you see. + - What's happening in the video? + """, + }, + # {"role": "user", "content": ""}, + # {"role": "assistant", "content": []}, + # {"role": "user", "content": "Tell me"}, + # {"role": "user", "content": "a joke"}, +] +tools = [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + { + "name": "save_conversation", + "description": "Save the current conversation. Use this function to persist the current conversation to external storage.", + "parameters": { + "type": "object", + "properties": { + "user_request_text": { + "type": "string", + "description": "The text of the user's request to save the conversation.", + } + }, + "required": ["user_request_text"], + }, + }, + { + "name": "get_saved_conversation_filenames", + "description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.", + "parameters": None, + }, + { + "name": "load_conversation", + "description": "Load a conversation history. Use this function to load a conversation history into the current session.", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "The filename of the conversation history to load.", + } + }, + "required": ["filename"], + }, + }, + { + "name": "get_image", + "description": "Get and image from the camera or video stream.", + "parameters": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The question to to use when running inference on the acquired image.", + }, + }, + "required": ["question"], + }, + }, + ] + }, +] + + +async def main(): + global tts + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)), + ), + ) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady + ) + + llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY")) + + # you can either register a single function for all function calls, or specific functions + # llm.register_function(None, fetch_weather_from_api) + llm.register_function("get_current_weather", fetch_weather_from_api) + llm.register_function("save_conversation", save_conversation) + llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames) + llm.register_function("load_conversation", load_conversation) + llm.register_function("get_image", get_image) + + context = OpenAILLMContext(messages, tools) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + context_aggregator.user(), + llm, # LLM + tts, + context_aggregator.assistant(), + transport.output(), # Transport bot output + ] + ) + + task = PipelineTask( + pipeline, + PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + # report_only_initial_ttfb=True, + ), + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + global video_participant_id + video_participant_id = participant["id"] + transport.capture_participant_transcription(participant["id"]) + transport.capture_participant_video(video_participant_id, framerate=0) + # Kick off the conversation. + await task.queue_frames([context_aggregator.user().get_context_frame()]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py index 0f3b5cf59..d70f0e25b 100644 --- a/src/pipecat/processors/aggregators/openai_llm_context.py +++ b/src/pipecat/processors/aggregators/openai_llm_context.py @@ -70,6 +70,8 @@ class OpenAILLMContext: context.add_message(message) return context + # todo: deprecate from_image_frame. It's only used to create a single-use + # context, which isn't useful for most real-world applications. @staticmethod def from_image_frame(frame: VisionImageRawFrame) -> "OpenAILLMContext": """ @@ -77,6 +79,10 @@ class OpenAILLMContext: expects images to be base64 encoded, but other vision models may not. So we'll store the image as bytes and do the base64 encoding as needed in the LLM service. + + NOTE: the above only applies to the deprecated use of this method. The + add_image_frame_message() below does the base64 encoding as expected + in the OpenAI format. """ context = OpenAILLMContext() buffer = io.BytesIO() diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py index 9a21f8f13..2b7d380eb 100644 --- a/src/pipecat/services/google.py +++ b/src/pipecat/services/google.py @@ -5,6 +5,7 @@ # import asyncio +import base64 from dataclasses import dataclass import json import io @@ -56,7 +57,9 @@ except ModuleNotFoundError as e: class GoogleUserContextAggregator(OpenAIUserContextAggregator): async def _push_aggregation(self): if len(self._aggregation) > 0: - self._context.add_message({"role": "user", "parts": [glm.Part(text=self._aggregation)]}) + self._context.add_message( + glm.Content(role="user", parts=[glm.Part(text=self._aggregation)]) + ) # Reset the aggregation. Reset it before pushing it down, otherwise # if the tasks gets cancelled we won't be able to clear things up. @@ -88,35 +91,37 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): if frame.result: logger.debug(f"FunctionCallResultFrame result: {frame.arguments}") self._context.add_message( - { - "role": "model", - "parts": [ + glm.Content( + role="model", + parts=[ glm.Part( function_call=glm.FunctionCall( name=frame.function_name, args=frame.arguments ) ) ], - } + ) ) response = frame.result if isinstance(response, str): response = {"response": response} self._context.add_message( - { - "role": "user", - "parts": [ + glm.Content( + role="user", + parts=[ glm.Part( function_response=glm.FunctionResponse( name=frame.function_name, response=response ) ) ], - } + ) ) run_llm = not bool(self._function_calls_in_progress) else: - self._context.add_message({"role": "model", "parts": [glm.Part(text=aggregation)]}) + self._context.add_message( + glm.Content(role="model", parts=[glm.Part(text=aggregation)]) + ) if self._pending_image_frame_message: frame = self._pending_image_frame_message @@ -160,21 +165,70 @@ class GoogleLLMContext(OpenAILLMContext): obj._restructure_from_openai_messages() return obj + def set_messages(self, messages: List): + self._messages[:] = messages + self._restructure_from_openai_messages() + + def get_messages_for_logging(self): + msgs = [] + for message in self.messages: + obj = glm.Content.to_dict(message) + try: + if "parts" in obj: + for part in obj["parts"]: + if "inline_data" in part: + part["inline_data"]["data"] = "..." + except Exception as e: + logger.debug(f"Error: {e}") + msgs.append(obj) + return msgs + def from_standard_message(self, message): role = message["role"] - content = message["content"] + content = message.get("content", []) if role == "system": role = "user" elif role == "assistant": role = "model" parts = [] - if isinstance(content, str): + if message.get("tool_calls"): + for tc in message["tool_calls"]: + parts.append( + glm.Part( + function_call=glm.FunctionCall( + name=tc["function"]["name"], + args=json.loads(tc["function"]["arguments"]), + ) + ) + ) + elif role == "tool": + role = "model" + parts.append( + glm.Part( + function_response=glm.FunctionResponse( + name="tool_call_result", # seems to work to hard-code the same name every time + response=json.loads(message["content"]), + ) + ) + ) + elif isinstance(content, str): parts.append(glm.Part(text=content)) elif isinstance(content, list): - logger.debug("!!!NEED TO IMPL CONTENT LIST") + for c in content: + if c["type"] == "text": + parts.append(glm.Part(text=c["text"])) + elif c["type"] == "image_url": + parts.append( + glm.Part( + inline_data=glm.Blob( + mime_type="image/jpeg", + data=base64.b64decode(c["image_url"]["url"].split(",")[1]), + ) + ) + ) - message = {"role": role, "parts": parts} + message = glm.Content(role=role, parts=parts) return message def add_image_frame_message( @@ -189,10 +243,58 @@ class GoogleLLMContext(OpenAILLMContext): parts.append( glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())), ) - self.add_message({"role": "user", "parts": parts}) + self.add_message(glm.Content(role="user", parts=parts)) + + def to_standard_messages(self, obj) -> list: + msg = {"role": obj.role, "content": []} + if msg["role"] == "model": + msg["role"] = "assistant" + + for part in obj.parts: + if part.text: + msg["content"].append({"type": "text", "text": part.text}) + elif part.inline_data: + encoded = base64.b64encode(part.inline_data.data).decode("utf-8") + msg["content"].append( + { + "type": "image_url", + "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"}, + } + ) + elif part.function_call: + args = type(part.function_call).to_dict(part.function_call).get("args", {}) + msg["tool_calls"] = [ + { + "id": part.function_call.name, + "type": "function", + "function": { + "name": part.function_call.name, + "arguments": json.dumps(args), + }, + } + ] + + elif part.function_response: + msg["role"] = "tool" + resp = ( + type(part.function_response).to_dict(part.function_response).get("response", {}) + ) + msg["tool_call_id"] = part.function_response.name + msg["content"] = json.dumps(resp) + + # there might be no content parts for tool_calls messages + if not msg["content"]: + del msg["content"] + return [msg] def _restructure_from_openai_messages(self): - self._messages[:] = [self.from_standard_message(m) for m in self._messages] + # first, map across self._messages calling self.from_standard_message(m) to modify messages in place + try: + self._messages[:] = [self.from_standard_message(m) for m in self._messages] + except Exception as e: + logger.error(f"Error mapping messages: {e}") + # iterate over messages and remove any messages that have an empty content list + self._messages = [m for m in self._messages if m.parts] class GoogleLLMService(LLMService): @@ -248,7 +350,7 @@ class GoogleLLMService(LLMService): async def _process_context(self, context: OpenAILLMContext): await self.push_frame(LLMFullResponseStartFrame()) try: - logger.debug(f"Generating chat: {context.messages}") + logger.debug(f"Generating chat: {context.get_messages_for_logging()}") # todo: move this into the new context code structure, convert from openai context one time # todo: add system instructions