tools frame support and wip message resetting/loading

This commit is contained in:
Kwindla Hultman Kramer
2024-10-09 11:03:53 -07:00
parent df2ddb4b91
commit 07124bfafc
2 changed files with 166 additions and 24 deletions

View File

@@ -7,6 +7,7 @@
import asyncio
import json
import os
import re
import sys
from datetime import datetime
@@ -15,6 +16,7 @@ from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.frames.frames import LLMMessagesUpdateFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -37,18 +39,34 @@ logger.add(sys.stderr, level="DEBUG")
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
    """Report canned weather data for a weather tool call.

    No external service is contacted: conditions are always "nice" and the
    temperature is 75 for a "fahrenheit" request and 24 for any other format.

    Args:
        function_name: Name of the invoked tool (unused here).
        tool_call_id: Identifier of the tool call (unused here).
        args: Tool arguments; must contain a ``"format"`` key.
        llm: The LLM service that issued the call (unused here).
        context: The conversation context (unused here).
        result_callback: Awaitable callable that receives the result dict.

    Raises:
        KeyError: If ``args`` has no ``"format"`` entry.
    """
    temperature_format = args["format"]
    temperature = 75 if temperature_format == "fahrenheit" else 24
    await result_callback(
        {
            "conditions": "nice",
            # Exactly one "temperature" entry: a second, hard-coded "75" string
            # under the same key would be silently overwritten by this value.
            "temperature": temperature,
            "format": temperature_format,
            "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
        }
    )
async def get_saved_conversation_filenames(
    function_name, tool_call_id, args, llm, context, result_callback
):
    """Report the saved conversation files found in the current directory.

    A file qualifies when its name is ``example_19_<8 digits>_<6 digits>.json``.
    The result is passed to ``result_callback`` as ``{"filenames": [...]}``.

    Args:
        function_name: Name of the invoked tool (unused here).
        tool_call_id: Identifier of the tool call (unused here).
        args: Tool arguments (unused here).
        llm: The LLM service that issued the call (unused here).
        context: The conversation context (unused here).
        result_callback: Awaitable callable that receives the result dict.
    """
    pattern = re.compile(r"example_19_\d{8}_\d{6}\.json$")
    # os.listdir() returns entries in arbitrary order. The names embed a
    # YYYYMMDD_HHMMSS timestamp, so a plain sort is chronological and the most
    # recent conversation is reliably the last entry.
    matching_files = sorted(name for name in os.listdir(".") if pattern.match(name))
    await result_callback({"filenames": matching_files})
async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"example_19_{timestamp}.json"
logger.debug(f"writing conversation to {filename}\n{json.dumps(context.messages, indent=4)}")
try:
with open(filename, "w") as file:
json.dump(context.messages, file, indent=4)
@@ -57,6 +75,18 @@ async def save_conversation(function_name, tool_call_id, args, llm, context, res
await result_callback({"success": False, "error": str(e)})
async def load_conversation(function_name, tool_call_id, args, llm, context, result_callback):
    """Load a saved message list from a JSON file and push it to the pipeline.

    On success, ``result_callback`` receives ``{"success": True}`` and an
    ``LLMMessagesUpdateFrame`` carrying the loaded messages is pushed through
    ``llm``. If the file cannot be read or parsed, ``result_callback`` receives
    ``{"success": False, "error": <message>}`` and nothing is pushed.

    Args:
        function_name: Name of the invoked tool (unused here).
        tool_call_id: Identifier of the tool call (unused here).
        args: Tool arguments; must contain a ``"filename"`` key.
        llm: Object whose ``push_frame`` coroutine receives the update frame.
        context: The conversation context (unused here).
        result_callback: Awaitable callable that receives the result dict.

    Raises:
        KeyError: If ``args`` has no ``"filename"`` entry.
    """
    filename = args["filename"]
    logger.debug(f"loading conversation from {filename}")
    # Only reading/parsing is guarded: if the success callback or push_frame
    # were inside the try, a failure there would report success AND failure
    # for the same tool call.
    try:
        with open(filename, "r") as file:
            messages = json.load(file)
    except Exception as e:
        await result_callback({"success": False, "error": str(e)})
        return
    await result_callback({"success": True})
    await llm.push_frame(LLMMessagesUpdateFrame(messages))
tools = [
{
"type": "function",
@@ -88,6 +118,31 @@ tools = [
"required": [],
},
},
{
"type": "function",
"name": "get_saved_conversation_filenames",
"description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a timestamp. Each file is conversation history that can be loaded into this session.",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
{
"type": "function",
"name": "load_conversation",
"description": "Load a conversation history. Use this function to load a conversation history into the current session.",
"parameters": {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "The filename of the conversation history to load.",
}
},
"required": ["filename"],
},
},
]
@@ -118,7 +173,7 @@ async def main():
# turn_detection=TurnDetection(silence_duration_ms=1000),
# Or set to False to disable openai turn detection and use transport VAD
turn_detection=False,
tools=tools,
# tools=tools,
instructions="""
Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
@@ -145,10 +200,14 @@ Remember, your responses should be short. Just one or two sentences, usually.
# llm.register_function(None, fetch_weather_from_api)
llm.register_function("get_current_weather", fetch_weather_from_api)
llm.register_function("save_conversation", save_conversation)
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
llm.register_function("load_conversation", load_conversation)
context = OpenAILLMContext(
# [{"role": "user", "content": "What's the weather right now in San Francisco?"}], tools
[{"role": "user", "content": "Say 'hello'."}],
# [{"role": "user", "content": "What's the weather right now in San Francisco?"}],
# conversation load from file is a WIP -- not functional yet
# [{"role": "user", "content": "Load the most recent conversation."}],
tools,
)
context_aggregator = llm.create_context_aggregator(context)

View File

@@ -4,6 +4,7 @@ import json
import traceback
from copy import deepcopy
from dataclasses import dataclass
from typing import List
import websockets
from loguru import logger
@@ -18,6 +19,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
LLMUpdateSettingsFrame,
StartFrame,
StartInterruptionFrame,
@@ -78,6 +80,9 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
# "conversation items" that have been created by OpenAI realtime API events but are
# not completely filled in yet. Map from item_id to message.
self._messages_in_progress = {}
# count of messages prior to recent reset
self._messages_reset_count = 0
self._tools_list_updated = True
@staticmethod
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
@@ -86,30 +91,48 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
obj.__setup_local()
return obj
# cases to handle
# - tools in the context constructor (and in general?)
# - relatedly, set tools frame
# - clearing the context by deleting all messages (for scripted conversations)
# still working on
# - clearing the context by deleting all messages
# - reloading from a standard messages list
# - truncating the last spoken message to maintain context when interrupted
def set_tools(self, tools: List):
    """Store ``tools`` via the parent class and mark the tool list as updated."""
    super().set_tools(tools)
    # Raised here, lowered again by update_tools_list_sent().
    self._tools_list_updated = True
def add_message(self, message):
    """Add ``message`` through the parent class and queue it as unsent.

    Returns:
        The same ``message`` object that was passed in.
    """
    super().add_message(message)
    unsent = self._unsent_messages
    unsent.append(message)
    return message
def add_messages(self, messages):
    """Add several messages through the parent class and queue them all as unsent."""
    super().add_messages(messages)
    unsent = self._unsent_messages
    unsent.extend(messages)
def add_message_already_present_in_api_context(self, message):
    """Add ``message`` through the parent class without queuing it as unsent.

    Returns:
        The same ``message`` object that was passed in.
    """
    # Deliberately bypasses this class's add_message(), so the unsent queue
    # is left untouched.
    super().add_message(message)
    return message
def set_messages(self, messages):
    """Replace the whole message list and queue the new list as unsent.

    Before replacing, the number of messages that were in the context but not
    in the unsent queue is added to the pending reset count, which
    ``get_messages_reset_count()`` exposes.

    Args:
        messages: The new message list, handed to the parent's ``set_messages``.
    """
    # Accumulate rather than assign: after one set_messages() every message is
    # in the unsent queue, so a second call before the next send would compute
    # 0 and overwrite (lose) the still-pending reset count.
    self._messages_reset_count += len(self.messages) - len(self._unsent_messages)
    super().set_messages(messages)
    self._unsent_messages = deepcopy(self._messages)
def get_unsent_messages(self):
    """Return the list of queued unsent messages (the list itself, not a copy)."""
    unsent = self._unsent_messages
    return unsent
def get_messages_reset_count(self):
    """Return the pending reset count recorded by ``set_messages``."""
    pending = self._messages_reset_count
    return pending
def get_tools_list_updated(self):
    """Return the flag that records whether the tool list has been updated."""
    updated = self._tools_list_updated
    return updated
def update_all_messages_sent(self):
    """Empty the unsent queue and zero the pending reset count."""
    # A fresh list is bound; any list previously returned by
    # get_unsent_messages() is left unmodified.
    self._unsent_messages, self._messages_reset_count = [], 0
def update_tools_list_sent(self):
    """Lower the tools-updated flag."""
    setattr(self, "_tools_list_updated", False)
def note_manually_added_message(self, item_id):
self._manually_created_messages[item_id] = True
@@ -163,6 +186,10 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
if isinstance(frame, LLMMessagesUpdateFrame):
await self.push_frame(_InternalMessagesUpdateFrame(context=self._context))
# Parent also doesn't push the LLMSetToolsFrame.
if isinstance(frame, LLMSetToolsFrame):
await self.push_frame(frame, direction)
async def _push_aggregation(self):
# for the moment, ignore all user input coming into the pipeline.
# todo: think about whether/how to fix this to allow for text input from
@@ -232,7 +259,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
self.api_key = api_key
self.base_url = base_url
self._session_properties = session_properties
self._session_properties: events.SessionProperties = session_properties
self._audio_input_paused = start_audio_paused
self._send_transcription_frames = send_transcription_frames
# todo: wire _send_user_started_speaking_frames up correctly
@@ -304,7 +331,12 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
return self._websocket
raise Exception("Websocket not connected")
async def _update_settings(self):
    """Send the stored session properties to the API as a session update event.

    When a context is attached and it has tools, those tools replace the
    ``tools`` field of the stored session properties before sending.
    """
    settings = self._session_properties
    if self._context:
        # tools given in the context override the tools in the session properties
        if self._context.tools:
            settings.tools = self._context.tools
        # Lower the "tools updated" flag even when the context has no tools;
        # otherwise the flag stays raised and every later check of it sends
        # another, redundant session update.
        self._context.update_tools_list_sent()
    await self.send_client_event(events.SessionUpdateEvent(session=settings))
async def _receive_task_handler(self):
@@ -315,7 +347,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
if evt.type == "session.created":
# session.created is received right after connecting. send a message
# to configure the session properties.
await self._update_settings(self._session_properties)
await self._update_settings()
elif evt.type == "session.updated":
self._session_properties = evt.session
elif evt.type == "input_audio_buffer.speech_started":
@@ -445,9 +477,12 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
)
async def _reset_conversation(self, count):
    """Reset the conversation by disconnecting and reconnecting (work in progress).

    Args:
        count: Reset count supplied by the caller; only used in the debug log.
    """
    # need to think about how to implement this, and how to think about interop with messages lists
    # used with the HTTP API
    logger.debug(f"!!! RESET CONVERSATION: {count} [WIP]")
    await self._disconnect()
    await self._connect()
async def _send_messages_context_update(self):
@@ -455,28 +490,70 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
return
context = self._context
messages = context.get_unsent_messages()
needs_reset = context.get_messages_reset_count()
context.update_all_messages_sent()
if needs_reset:
await self._reset_conversation(needs_reset)
# debugging
logger.debug("MESSAGE HISTORY RELOAD NOT IMPLEMENTED YET")
return
items = []
for m in messages:
if m and (m.get("role") == "user" or m.get("role") == "system"):
if m and (
m.get("role") == "user" or m.get("role") == "system" or m.get("role") == "assistant"
):
content = m.get("content")
if isinstance(content, str):
items.append(
events.ConversationItem(
type="message",
status="completed",
role="user",
content=[events.ItemContent(type="input_text", text=content)],
# skip any messages that aren't "text" and change "user" message type to "input_text"
if m.get("type", "text") == "text":
items.append(
events.ConversationItem(
type="message",
status="completed",
role=m.get("role", "user"),
content=[
events.ItemContent(
type="input_text" if m.get("role") == "user" else "text",
text=content,
)
],
)
)
)
elif isinstance(content, list):
# skip any messages that aren't "text" and change "user" message type to "input_text"
cs = []
for item in content:
if item.get("type", "text") == "text":
# cs.append(events.ItemContent(type="input_text", text=item.get("text")))
(
cs.append(
events.ItemContent(
type="input_text" if m.get("role") == "user" else "text",
text=item.get("text"),
)
),
)
if cs:
items.append(
events.ConversationItem(
type="message",
status="completed",
role=m.get("role", "user"),
content=cs,
)
)
elif m.get("role") == "assistant" and m.get("tool_calls"):
tc = m.get("tool_calls")[0]
items.append(
events.ConversationItem(
type="message",
status="completed",
role="user",
content=content,
type="function_call",
call_id=tc["id"],
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
)
)
else:
@@ -489,11 +566,14 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
output=m["content"],
)
)
for item in items:
context.note_manually_added_message(item.id)
await self.send_client_event(events.ConversationItemCreateEvent(item=item))
async def _create_response(self):
if self._context.get_tools_list_updated():
await self._update_settings()
await self._send_messages_context_update()
logger.debug(f"Creating response: {self._context.get_messages_for_logging()}")
await self.push_frame(LLMFullResponseStartFrame())
@@ -544,7 +624,10 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
self._context = frame.context
await self._send_messages_context_update()
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
self._session_properties = frame.settings
await self._update_settings()
elif isinstance(frame, LLMSetToolsFrame):
await self._update_settings()
await self.push_frame(frame, direction)