Add custom assistant context aggregator for Grok due to content requirement in function calling
This commit is contained in:
@@ -65,7 +65,7 @@ async def main():
|
||||
)
|
||||
|
||||
llm = NimLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
|
||||
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.3-70b-instruct"
|
||||
)
|
||||
# Register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
@@ -76,18 +76,18 @@ async def main():
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"description": "Returns the current weather at a location, if one is specified, and defaults to the user's location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
"description": "The location to find the weather of, or if not provided, it's the default location.",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
"description": "Whether to use SI or USCS units (celsius or fahrenheit).",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
|
||||
@@ -5,11 +5,102 @@
|
||||
#
|
||||
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.services.openai import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMService,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
|
||||
class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Custom assistant context aggregator for Grok that handles empty content requirement."""
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
return
|
||||
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
frame = self._function_call_result
|
||||
self._function_call_result = None
|
||||
if frame.result:
|
||||
# Grok requires an empty content field for function calls
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "", # Required by Grok
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps(frame.result),
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
# Only run the LLM if there are no more function calls in progress.
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._pending_image_frame_message:
|
||||
frame = self._pending_image_frame_message
|
||||
self._pending_image_frame_message = None
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.user_image_raw_frame.format,
|
||||
size=frame.user_image_raw_frame.size,
|
||||
image=frame.user_image_raw_frame.image,
|
||||
text=frame.text,
|
||||
)
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrokContextAggregatorPair:
|
||||
_user: "OpenAIUserContextAggregator"
|
||||
_assistant: "GrokAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "OpenAIUserContextAggregator":
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "GrokAssistantContextAggregator":
|
||||
return self._assistant
|
||||
|
||||
|
||||
class GrokLLMService(OpenAILLMService):
|
||||
@@ -101,3 +192,13 @@ class GrokLLMService(OpenAILLMService):
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> GrokContextAggregatorPair:
|
||||
user = OpenAIUserContextAggregator(context)
|
||||
assistant = GrokAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return GrokContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -559,7 +559,6 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "", # content field required for Grok function calling
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
|
||||
Reference in New Issue
Block a user