Compare commits

...

2 Commits

Author SHA1 Message Date
Kwindla Hultman Kramer
6c58a5b024 tiny bit more cleanup 2024-09-07 12:38:17 -07:00
Kwindla Hultman Kramer
37bbb687de function calling fixes for together/llama-3.1 2024-09-07 12:05:16 -07:00
2 changed files with 67 additions and 19 deletions

View File

@@ -43,6 +43,17 @@ async def get_current_weather(
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
async def save_checkpoint(
    function_name,
    tool_call_id,
    arguments,
    llm,
    context,
    result_callback,
):
    """Report a successful checkpoint save back to the caller.

    Handler for the ``save_checkpoint`` tool. It does not inspect
    ``function_name``, ``tool_call_id``, ``arguments``, ``llm`` or
    ``context``; it only logs a debug line and reports success.

    Args:
        function_name: Unused.
        tool_call_id: Unused.
        arguments: Unused.
        llm: Unused.
        context: Unused.
        result_callback: Awaitable callable that receives the result dict.
    """
    logger.debug("IN save_checkpoint")
    # Always report success; nothing is actually persisted here.
    outcome = {"status": "success"}
    await result_callback(outcome)
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
@@ -69,7 +80,9 @@ async def main():
model=os.getenv("TOGETHER_MODEL"),
)
llm.register_function("get_current_weather", get_current_weather)
llm.register_function("save_checkpoint", save_checkpoint)
# standard function call that's in all the LLM docs!
weatherTool = {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
@@ -85,12 +98,26 @@ async def main():
},
}
# a function to test function calls with no arguments
saveCheckpoint = {
"name": "save_checkpoint",
"description": "Save the current state of the conversation",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
}
system_prompt = f"""\
You have access to the following functions:
Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
{json.dumps(weatherTool)}
Use the function '{saveCheckpoint["name"]}' to '{saveCheckpoint["description"]}':
{json.dumps(saveCheckpoint)}
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>

View File

@@ -18,9 +18,7 @@ from pipecat.frames.frames import (
Frame,
LLMModelUpdateFrame,
TextFrame,
VisionImageRawFrame,
UserImageRequestFrame,
UserImageRawFrame,
LLMMessagesFrame,
LLMFullResponseStartFrame,
LLMFullResponseEndFrame,
@@ -100,8 +98,12 @@ class TogetherLLMService(LLMService):
stream=True,
)
# Function calling
got_first_chunk = False
# Function calling. We should be able to prompt Llama 3.1 to always return either plain
# text or a function call. However, occasionally we see a function call after plain text.
# Try to account for that.
most_recent_chunk_was_function_call_start_char = False # function call start char is '<'
accumulating_function_call = False
function_call_accumulator = ""
@@ -131,10 +133,24 @@ class TogetherLLMService(LLMService):
if accumulating_function_call:
function_call_accumulator += chunk.choices[0].delta.content
else:
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
text = chunk.choices[0].delta.content
if most_recent_chunk_was_function_call_start_char:
most_recent_chunk_was_function_call_start_char = False
if text == "function":
accumulating_function_call = True
function_call_accumulator = "<function"
else:
await self.push_frame("<" + TextFrame(chunk.choices[0].delta.content))
elif text == '<':
most_recent_chunk_was_function_call_start_char = True
else:
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
if chunk.choices[0].finish_reason == 'eos' and accumulating_function_call:
await self._extract_function_call(context, function_call_accumulator)
if chunk.choices[0].finish_reason:
if accumulating_function_call:
await self._extract_function_call(context, function_call_accumulator)
elif most_recent_chunk_was_function_call_start_char:
await self.push_frame(TextFrame("<"))
except CancelledError as e:
# todo: implement token counting estimates for use when the user interrupts a long generation
@@ -164,14 +180,27 @@ class TogetherLLMService(LLMService):
await self._process_context(context)
async def _extract_function_call(self, context, function_call_accumulator):
# logger.debug(f"Extracting function call: {function_call_accumulator}")
context.add_message({"role": "assistant", "content": function_call_accumulator})
function_regex = r"<function=(\w+)>(.*?)</function>"
# Function format regex. Llama 3.1 sometimes adds an extra " or space just before the
# </function> tag. The regex itself does not filter those characters; they are stripped from
# the captured argument string afterwards (the [\s"]+$ substitution below). Occasionally the
# </function> close tag is also missing, which the second alternative of the regex handles.
function_regex = r'<function=(\w+)>(.*?)<\/function>|<function=(\w+)>(.*)'
match = re.search(function_regex, function_call_accumulator)
if match:
function_name, args_string = match.groups()
function_name = ""
args_string = ""
if match.group(1): # Case with closing tag
function_name = match.group(1)
args_string = match.group(2)
else: # Case without closing tag
function_name = match.group(3)
args_string = match.group(4)
try:
arguments = json.loads(args_string)
args_string = re.sub(r'[\s"]+$', '', args_string)
arguments = json.loads(args_string) if args_string else ""
await self.call_function(context=context,
tool_call_id=str(uuid.uuid4()),
function_name=function_name,
@@ -181,7 +210,8 @@ class TogetherLLMService(LLMService):
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
# Should we do anything more than log a warning?
logger.debug(f"Error parsing function arguments: {error}")
logger.debug(
f"Error parsing function arguments: {error} - {function_call_accumulator}")
class TogetherLLMContext(OpenAILLMContext):
@@ -247,15 +277,6 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
except Exception as e:
logger.error(f"Error processing frame: {e}")
#
# Claude returns a text content block along with a tool use content block. This works quite nicely
# with streaming. We get the text first, so we can start streaming it right away. Then we get the
# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
#
# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
# chattiness about its tool thinking.
#
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):