Compare commits
2 Commits
mb/static-
...
khk/togeth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c58a5b024 | ||
|
|
37bbb687de |
@@ -43,6 +43,17 @@ async def get_current_weather(
|
||||
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def save_checkpoint(
|
||||
function_name,
|
||||
tool_call_id,
|
||||
arguments,
|
||||
llm,
|
||||
context,
|
||||
result_callback):
|
||||
logger.debug("IN save_checkpoint")
|
||||
await result_callback({"status": "success"})
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
@@ -69,7 +80,9 @@ async def main():
|
||||
model=os.getenv("TOGETHER_MODEL"),
|
||||
)
|
||||
llm.register_function("get_current_weather", get_current_weather)
|
||||
llm.register_function("save_checkpoint", save_checkpoint)
|
||||
|
||||
# standard function call that's in all the LLM docs!
|
||||
weatherTool = {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
@@ -85,12 +98,26 @@ async def main():
|
||||
},
|
||||
}
|
||||
|
||||
# a function to test function calls with no arguments
|
||||
saveCheckpoint = {
|
||||
"name": "save_checkpoint",
|
||||
"description": "Save the current state of the conversation",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
system_prompt = f"""\
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
|
||||
{json.dumps(weatherTool)}
|
||||
|
||||
Use the function '{saveCheckpoint["name"]}' to '{saveCheckpoint["description"]}':
|
||||
{json.dumps(saveCheckpoint)}
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
|
||||
|
||||
@@ -18,9 +18,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMModelUpdateFrame,
|
||||
TextFrame,
|
||||
VisionImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
UserImageRawFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -100,8 +98,12 @@ class TogetherLLMService(LLMService):
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Function calling
|
||||
got_first_chunk = False
|
||||
|
||||
# Function calling. We should be able to prompt Llama 3.1 to always return either plain
|
||||
# text or a function call. However, occasionally we see a function call after plain text.
|
||||
# Try to account for that.
|
||||
most_recent_chunk_was_function_call_start_char = False # function call start char is '<'
|
||||
accumulating_function_call = False
|
||||
function_call_accumulator = ""
|
||||
|
||||
@@ -131,10 +133,24 @@ class TogetherLLMService(LLMService):
|
||||
if accumulating_function_call:
|
||||
function_call_accumulator += chunk.choices[0].delta.content
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
text = chunk.choices[0].delta.content
|
||||
if most_recent_chunk_was_function_call_start_char:
|
||||
most_recent_chunk_was_function_call_start_char = False
|
||||
if text == "function":
|
||||
accumulating_function_call = True
|
||||
function_call_accumulator = "<function"
|
||||
else:
|
||||
await self.push_frame("<" + TextFrame(chunk.choices[0].delta.content))
|
||||
elif text == '<':
|
||||
most_recent_chunk_was_function_call_start_char = True
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
if chunk.choices[0].finish_reason == 'eos' and accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
if chunk.choices[0].finish_reason:
|
||||
if accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
elif most_recent_chunk_was_function_call_start_char:
|
||||
await self.push_frame(TextFrame("<"))
|
||||
|
||||
except CancelledError as e:
|
||||
# todo: implement token counting estimates for use when the user interrupts a long generation
|
||||
@@ -164,14 +180,27 @@ class TogetherLLMService(LLMService):
|
||||
await self._process_context(context)
|
||||
|
||||
async def _extract_function_call(self, context, function_call_accumulator):
|
||||
# logger.debug(f"Extracting function call: {function_call_accumulator}")
|
||||
context.add_message({"role": "assistant", "content": function_call_accumulator})
|
||||
|
||||
function_regex = r"<function=(\w+)>(.*?)</function>"
|
||||
# Function format regex. Llama 3.1 sometimes adds an extra " or space just before the
|
||||
# </function> tag. This regexp just ignores the extra characters if they are there. (That's
|
||||
# the [\s"]? part of the regex.) Occasionally the </function> close tag is also missing.
|
||||
function_regex = r'<function=(\w+)>(.*?)<\/function>|<function=(\w+)>(.*)'
|
||||
match = re.search(function_regex, function_call_accumulator)
|
||||
if match:
|
||||
function_name, args_string = match.groups()
|
||||
function_name = ""
|
||||
args_string = ""
|
||||
if match.group(1): # Case with closing tag
|
||||
function_name = match.group(1)
|
||||
args_string = match.group(2)
|
||||
else: # Case without closing tag
|
||||
function_name = match.group(3)
|
||||
args_string = match.group(4)
|
||||
|
||||
try:
|
||||
arguments = json.loads(args_string)
|
||||
args_string = re.sub(r'[\s"]+$', '', args_string)
|
||||
arguments = json.loads(args_string) if args_string else ""
|
||||
await self.call_function(context=context,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
function_name=function_name,
|
||||
@@ -181,7 +210,8 @@ class TogetherLLMService(LLMService):
|
||||
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
|
||||
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
|
||||
# Should we do anything more than log a warning?
|
||||
logger.debug(f"Error parsing function arguments: {error}")
|
||||
logger.debug(
|
||||
f"Error parsing function arguments: {error} - {function_call_accumulator}")
|
||||
|
||||
|
||||
class TogetherLLMContext(OpenAILLMContext):
|
||||
@@ -247,15 +277,6 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
#
|
||||
# Claude returns a text content block along with a tool use content block. This works quite nicely
|
||||
# with streaming. We get the text first, so we can start streaming it right away. Then we get the
|
||||
# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
|
||||
#
|
||||
# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
|
||||
# chattiness about it's tool thinking.
|
||||
#
|
||||
|
||||
|
||||
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
|
||||
|
||||
Reference in New Issue
Block a user