Add custom assistant context aggregator for Grok due to content requirement in function calling

This commit is contained in:
Mark Backman
2024-12-17 09:11:21 -05:00
parent fe0a7d07bd
commit ca086a856f
3 changed files with 107 additions and 7 deletions

View File

@@ -65,7 +65,7 @@ async def main():
)
llm = NimLLMService(
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.3-70b-instruct"
)
# Register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
@@ -76,18 +76,18 @@ async def main():
type="function",
function={
"name": "get_current_weather",
"description": "Get the current weather",
"description": "Returns the current weather at a location, if one is specified, and defaults to the user's location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
"description": "The location to find the weather of, or if not provided, it's the default location.",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
"description": "Whether to use SI or USCS units (celsius or fahrenheit).",
},
},
"required": ["location", "format"],

View File

@@ -5,11 +5,102 @@
#
import json
from dataclasses import dataclass
from loguru import logger
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai import OpenAILLMService
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.services.openai import (
OpenAIAssistantContextAggregator,
OpenAILLMService,
OpenAIUserContextAggregator,
)
class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
"""Custom assistant context aggregator for Grok that handles empty content requirement."""
async def _push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
return
run_llm = False
aggregation = self._aggregation
self._reset()
try:
if self._function_call_result:
frame = self._function_call_result
self._function_call_result = None
if frame.result:
# Grok requires an empty content field for function calls
self._context.add_message(
{
"role": "assistant",
"content": "", # Required by Grok
"tool_calls": [
{
"id": frame.tool_call_id,
"function": {
"name": frame.function_name,
"arguments": json.dumps(frame.arguments),
},
"type": "function",
}
],
}
)
self._context.add_message(
{
"role": "tool",
"content": json.dumps(frame.result),
"tool_call_id": frame.tool_call_id,
}
)
# Only run the LLM if there are no more function calls in progress.
run_llm = not bool(self._function_calls_in_progress)
else:
self._context.add_message({"role": "assistant", "content": aggregation})
if self._pending_image_frame_message:
frame = self._pending_image_frame_message
self._pending_image_frame_message = None
self._context.add_image_frame_message(
format=frame.user_image_raw_frame.format,
size=frame.user_image_raw_frame.size,
image=frame.user_image_raw_frame.image,
text=frame.text,
)
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
except Exception as e:
logger.error(f"Error processing frame: {e}")
@dataclass
class GrokContextAggregatorPair:
_user: "OpenAIUserContextAggregator"
_assistant: "GrokAssistantContextAggregator"
def user(self) -> "OpenAIUserContextAggregator":
return self._user
def assistant(self) -> "GrokAssistantContextAggregator":
return self._assistant
class GrokLLMService(OpenAILLMService):
@@ -101,3 +192,13 @@ class GrokLLMService(OpenAILLMService):
# Update completion tokens count if it has increased
if tokens.completion_tokens > self._completion_tokens:
self._completion_tokens = tokens.completion_tokens
@staticmethod
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> GrokContextAggregatorPair:
user = OpenAIUserContextAggregator(context)
assistant = GrokAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
)
return GrokContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -559,7 +559,6 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
self._context.add_message(
{
"role": "assistant",
"content": "", # content field required for Grok function calling
"tool_calls": [
{
"id": frame.tool_call_id,