Compare commits
2 Commits
hush/realt
...
khk/togeth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c58a5b024 | ||
|
|
37bbb687de |
@@ -43,6 +43,17 @@ async def get_current_weather(
|
|||||||
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||||
|
|
||||||
|
|
||||||
|
async def save_checkpoint(
|
||||||
|
function_name,
|
||||||
|
tool_call_id,
|
||||||
|
arguments,
|
||||||
|
llm,
|
||||||
|
context,
|
||||||
|
result_callback):
|
||||||
|
logger.debug("IN save_checkpoint")
|
||||||
|
await result_callback({"status": "success"})
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
(room_url, token) = await configure(session)
|
(room_url, token) = await configure(session)
|
||||||
@@ -69,7 +80,9 @@ async def main():
|
|||||||
model=os.getenv("TOGETHER_MODEL"),
|
model=os.getenv("TOGETHER_MODEL"),
|
||||||
)
|
)
|
||||||
llm.register_function("get_current_weather", get_current_weather)
|
llm.register_function("get_current_weather", get_current_weather)
|
||||||
|
llm.register_function("save_checkpoint", save_checkpoint)
|
||||||
|
|
||||||
|
# standard function call that's in all the LLM docs!
|
||||||
weatherTool = {
|
weatherTool = {
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"description": "Get the current weather in a given location",
|
"description": "Get the current weather in a given location",
|
||||||
@@ -85,12 +98,26 @@ async def main():
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# a function to test function calls with no arguments
|
||||||
|
saveCheckpoint = {
|
||||||
|
"name": "save_checkpoint",
|
||||||
|
"description": "Save the current state of the conversation",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
system_prompt = f"""\
|
system_prompt = f"""\
|
||||||
You have access to the following functions:
|
You have access to the following functions:
|
||||||
|
|
||||||
Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
|
Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
|
||||||
{json.dumps(weatherTool)}
|
{json.dumps(weatherTool)}
|
||||||
|
|
||||||
|
Use the function '{saveCheckpoint["name"]}' to '{saveCheckpoint["description"]}':
|
||||||
|
{json.dumps(saveCheckpoint)}
|
||||||
|
|
||||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||||
|
|
||||||
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
|
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
|
||||||
|
|||||||
@@ -18,9 +18,7 @@ from pipecat.frames.frames import (
|
|||||||
Frame,
|
Frame,
|
||||||
LLMModelUpdateFrame,
|
LLMModelUpdateFrame,
|
||||||
TextFrame,
|
TextFrame,
|
||||||
VisionImageRawFrame,
|
|
||||||
UserImageRequestFrame,
|
UserImageRequestFrame,
|
||||||
UserImageRawFrame,
|
|
||||||
LLMMessagesFrame,
|
LLMMessagesFrame,
|
||||||
LLMFullResponseStartFrame,
|
LLMFullResponseStartFrame,
|
||||||
LLMFullResponseEndFrame,
|
LLMFullResponseEndFrame,
|
||||||
@@ -100,8 +98,12 @@ class TogetherLLMService(LLMService):
|
|||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Function calling
|
|
||||||
got_first_chunk = False
|
got_first_chunk = False
|
||||||
|
|
||||||
|
# Function calling. We should be able to prompt Llama 3.1 to always return either plain
|
||||||
|
# text or a function call. However, occasionally we see a function call after plain text.
|
||||||
|
# Try to account for that.
|
||||||
|
most_recent_chunk_was_function_call_start_char = False # function call start char is '<'
|
||||||
accumulating_function_call = False
|
accumulating_function_call = False
|
||||||
function_call_accumulator = ""
|
function_call_accumulator = ""
|
||||||
|
|
||||||
@@ -131,10 +133,24 @@ class TogetherLLMService(LLMService):
|
|||||||
if accumulating_function_call:
|
if accumulating_function_call:
|
||||||
function_call_accumulator += chunk.choices[0].delta.content
|
function_call_accumulator += chunk.choices[0].delta.content
|
||||||
else:
|
else:
|
||||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
text = chunk.choices[0].delta.content
|
||||||
|
if most_recent_chunk_was_function_call_start_char:
|
||||||
|
most_recent_chunk_was_function_call_start_char = False
|
||||||
|
if text == "function":
|
||||||
|
accumulating_function_call = True
|
||||||
|
function_call_accumulator = "<function"
|
||||||
|
else:
|
||||||
|
await self.push_frame("<" + TextFrame(chunk.choices[0].delta.content))
|
||||||
|
elif text == '<':
|
||||||
|
most_recent_chunk_was_function_call_start_char = True
|
||||||
|
else:
|
||||||
|
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||||
|
|
||||||
if chunk.choices[0].finish_reason == 'eos' and accumulating_function_call:
|
if chunk.choices[0].finish_reason:
|
||||||
await self._extract_function_call(context, function_call_accumulator)
|
if accumulating_function_call:
|
||||||
|
await self._extract_function_call(context, function_call_accumulator)
|
||||||
|
elif most_recent_chunk_was_function_call_start_char:
|
||||||
|
await self.push_frame(TextFrame("<"))
|
||||||
|
|
||||||
except CancelledError as e:
|
except CancelledError as e:
|
||||||
# todo: implement token counting estimates for use when the user interrupts a long generation
|
# todo: implement token counting estimates for use when the user interrupts a long generation
|
||||||
@@ -164,14 +180,27 @@ class TogetherLLMService(LLMService):
|
|||||||
await self._process_context(context)
|
await self._process_context(context)
|
||||||
|
|
||||||
async def _extract_function_call(self, context, function_call_accumulator):
|
async def _extract_function_call(self, context, function_call_accumulator):
|
||||||
|
# logger.debug(f"Extracting function call: {function_call_accumulator}")
|
||||||
context.add_message({"role": "assistant", "content": function_call_accumulator})
|
context.add_message({"role": "assistant", "content": function_call_accumulator})
|
||||||
|
|
||||||
function_regex = r"<function=(\w+)>(.*?)</function>"
|
# Function format regex. Llama 3.1 sometimes adds an extra " or space just before the
|
||||||
|
# </function> tag. This regexp just ignores the extra characters if they are there. (That's
|
||||||
|
# the [\s"]? part of the regex.) Occasionally the </function> close tag is also missing.
|
||||||
|
function_regex = r'<function=(\w+)>(.*?)<\/function>|<function=(\w+)>(.*)'
|
||||||
match = re.search(function_regex, function_call_accumulator)
|
match = re.search(function_regex, function_call_accumulator)
|
||||||
if match:
|
if match:
|
||||||
function_name, args_string = match.groups()
|
function_name = ""
|
||||||
|
args_string = ""
|
||||||
|
if match.group(1): # Case with closing tag
|
||||||
|
function_name = match.group(1)
|
||||||
|
args_string = match.group(2)
|
||||||
|
else: # Case without closing tag
|
||||||
|
function_name = match.group(3)
|
||||||
|
args_string = match.group(4)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
arguments = json.loads(args_string)
|
args_string = re.sub(r'[\s"]+$', '', args_string)
|
||||||
|
arguments = json.loads(args_string) if args_string else ""
|
||||||
await self.call_function(context=context,
|
await self.call_function(context=context,
|
||||||
tool_call_id=str(uuid.uuid4()),
|
tool_call_id=str(uuid.uuid4()),
|
||||||
function_name=function_name,
|
function_name=function_name,
|
||||||
@@ -181,7 +210,8 @@ class TogetherLLMService(LLMService):
|
|||||||
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
|
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
|
||||||
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
|
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
|
||||||
# Should we do anything more than log a warning?
|
# Should we do anything more than log a warning?
|
||||||
logger.debug(f"Error parsing function arguments: {error}")
|
logger.debug(
|
||||||
|
f"Error parsing function arguments: {error} - {function_call_accumulator}")
|
||||||
|
|
||||||
|
|
||||||
class TogetherLLMContext(OpenAILLMContext):
|
class TogetherLLMContext(OpenAILLMContext):
|
||||||
@@ -247,15 +277,6 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing frame: {e}")
|
logger.error(f"Error processing frame: {e}")
|
||||||
|
|
||||||
#
|
|
||||||
# Claude returns a text content block along with a tool use content block. This works quite nicely
|
|
||||||
# with streaming. We get the text first, so we can start streaming it right away. Then we get the
|
|
||||||
# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
|
|
||||||
#
|
|
||||||
# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
|
|
||||||
# chattiness about it's tool thinking.
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
|
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||||
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
|
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
|
||||||
|
|||||||
Reference in New Issue
Block a user