tools frame support and wip message resetting/loading
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
@@ -15,6 +16,7 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesUpdateFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -37,18 +39,34 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
temperature = 75 if args["format"] == "fahrenheit" else 24
|
||||
await result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": "75",
|
||||
"temperature": temperature,
|
||||
"format": args["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_saved_conversation_filenames(
|
||||
function_name, tool_call_id, args, llm, context, result_callback
|
||||
):
|
||||
pattern = re.compile("example_19_\\d{8}_\\d{6}\\.json$")
|
||||
matching_files = []
|
||||
|
||||
for filename in os.listdir("."):
|
||||
if pattern.match(filename):
|
||||
matching_files.append(filename)
|
||||
|
||||
await result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"example_19_{timestamp}.json"
|
||||
logger.debug(f"writing conversation to {filename}\n{json.dumps(context.messages, indent=4)}")
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
json.dump(context.messages, file, indent=4)
|
||||
@@ -57,6 +75,18 @@ async def save_conversation(function_name, tool_call_id, args, llm, context, res
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
async def load_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
filename = args["filename"]
|
||||
logger.debug(f"loading conversation from {filename}")
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
messages = json.load(file)
|
||||
await result_callback({"success": True})
|
||||
await llm.push_frame(LLMMessagesUpdateFrame(messages))
|
||||
except Exception as e:
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
@@ -88,6 +118,31 @@ tools = [
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_saved_conversation_filenames",
|
||||
"description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a timestamp. Each file is conversation history that can be loaded into this session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "load_conversation",
|
||||
"description": "Load a conversation history. Use this function to load a conversation history into the current session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The filename of the conversation history to load.",
|
||||
}
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -118,7 +173,7 @@ async def main():
|
||||
# turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
turn_detection=False,
|
||||
tools=tools,
|
||||
# tools=tools,
|
||||
instructions="""
|
||||
Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
|
||||
|
||||
@@ -145,10 +200,14 @@ Remember, your responses should be short. Just one or two sentences, usually.
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("save_conversation", save_conversation)
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext(
|
||||
# [{"role": "user", "content": "What's the weather right now in San Francisco?"}], tools
|
||||
[{"role": "user", "content": "Say 'hello'."}],
|
||||
# [{"role": "user", "content": "What's the weather right now in San Francisco?"}],
|
||||
# conversation load from file is a WIP -- not functional yet
|
||||
# [{"role": "user", "content": "Load the most recent conversation."}],
|
||||
tools,
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
@@ -18,6 +19,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
@@ -78,6 +80,9 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
|
||||
# "conversation items" that have been created by opeanai realtime api events but are
|
||||
# not completely filled in, yet. map from item_id to message
|
||||
self._messages_in_progress = {}
|
||||
# count of messages prior to recent reset
|
||||
self._messages_reset_count = 0
|
||||
self._tools_list_updated = True
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
|
||||
@@ -86,30 +91,48 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
|
||||
obj.__setup_local()
|
||||
return obj
|
||||
|
||||
# cases to handle
|
||||
# - tools in the context constructor (and in general?)
|
||||
# - relatedly, set tools frame
|
||||
# - clearing the context by deleting all messages (for scripted conversations)
|
||||
# still working on
|
||||
# - clearing the context by deleting all messages
|
||||
# - reloading from a standard messages list
|
||||
# - truncating the last spoken message to maintain context when interrupted
|
||||
|
||||
def set_tools(self, tools: List):
|
||||
super().set_tools(tools)
|
||||
self._tools_list_updated = True
|
||||
|
||||
def add_message(self, message):
|
||||
super().add_message(message)
|
||||
self._unsent_messages.append(message)
|
||||
return message
|
||||
|
||||
def add_messages(self, messages):
|
||||
super().add_messages(messages)
|
||||
self._unsent_messages.extend(messages)
|
||||
|
||||
def add_message_already_present_in_api_context(self, message):
|
||||
super().add_message(message)
|
||||
return message
|
||||
|
||||
def set_messages(self, messages):
|
||||
self._messages_reset_count = len(self.messages) - len(self._unsent_messages)
|
||||
super().set_messages(messages)
|
||||
self._unsent_messages = deepcopy(self._messages)
|
||||
|
||||
def get_unsent_messages(self):
|
||||
return self._unsent_messages
|
||||
|
||||
def get_messages_reset_count(self):
|
||||
return self._messages_reset_count
|
||||
|
||||
def get_tools_list_updated(self):
|
||||
return self._tools_list_updated
|
||||
|
||||
def update_all_messages_sent(self):
|
||||
self._unsent_messages = []
|
||||
self._messages_reset_count = 0
|
||||
|
||||
def update_tools_list_sent(self):
|
||||
self._tools_list_updated = False
|
||||
|
||||
def note_manually_added_message(self, item_id):
|
||||
self._manually_created_messages[item_id] = True
|
||||
@@ -163,6 +186,10 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
if isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self.push_frame(_InternalMessagesUpdateFrame(context=self._context))
|
||||
|
||||
# Parent also doesn't push the LLMSetToolsFrame.
|
||||
if isinstance(frame, LLMSetToolsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
# for the moment, ignore all user input coming into the pipeline.
|
||||
# todo: think about whether/how to fix this to allow for text input from
|
||||
@@ -232,7 +259,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
self._session_properties = session_properties
|
||||
self._session_properties: events.SessionProperties = session_properties
|
||||
self._audio_input_paused = start_audio_paused
|
||||
self._send_transcription_frames = send_transcription_frames
|
||||
# todo: wire _send_user_started_speaking_frames up correctly
|
||||
@@ -304,7 +331,12 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _update_settings(self, settings: events.SessionProperties):
|
||||
async def _update_settings(self):
|
||||
settings = self._session_properties
|
||||
# tools given in the context override the tools in the session properties
|
||||
if self._context and self._context.tools:
|
||||
settings.tools = self._context.tools
|
||||
self._context.update_tools_list_sent()
|
||||
await self.send_client_event(events.SessionUpdateEvent(session=settings))
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
@@ -315,7 +347,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
if evt.type == "session.created":
|
||||
# session.created is received right after connecting. send a message
|
||||
# to configure the session properties.
|
||||
await self._update_settings(self._session_properties)
|
||||
await self._update_settings()
|
||||
elif evt.type == "session.updated":
|
||||
self._session_properties = evt.session
|
||||
elif evt.type == "input_audio_buffer.speech_started":
|
||||
@@ -445,9 +477,12 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def _reset_conversation(self):
|
||||
async def _reset_conversation(self, count):
|
||||
# need to think about how to implement this, and how to think about interop with messages lists
|
||||
# used with the HTTP API
|
||||
logger.debug(f"!!! RESET CONVERSATION: {count} [WIP]")
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
pass
|
||||
|
||||
async def _send_messages_context_update(self):
|
||||
@@ -455,28 +490,70 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
return
|
||||
context = self._context
|
||||
messages = context.get_unsent_messages()
|
||||
|
||||
needs_reset = context.get_messages_reset_count()
|
||||
context.update_all_messages_sent()
|
||||
|
||||
if needs_reset:
|
||||
await self._reset_conversation(needs_reset)
|
||||
# debugging
|
||||
logger.debug("MESSAGE HISTORY RELOAD NOT IMPLEMENTED YET")
|
||||
return
|
||||
|
||||
items = []
|
||||
for m in messages:
|
||||
if m and (m.get("role") == "user" or m.get("role") == "system"):
|
||||
if m and (
|
||||
m.get("role") == "user" or m.get("role") == "system" or m.get("role") == "assistant"
|
||||
):
|
||||
content = m.get("content")
|
||||
if isinstance(content, str):
|
||||
items.append(
|
||||
events.ConversationItem(
|
||||
type="message",
|
||||
status="completed",
|
||||
role="user",
|
||||
content=[events.ItemContent(type="input_text", text=content)],
|
||||
# skip any messages that aren't "text" and change "user" message type to "input_text"
|
||||
|
||||
if m.get("type", "text") == "text":
|
||||
items.append(
|
||||
events.ConversationItem(
|
||||
type="message",
|
||||
status="completed",
|
||||
role=m.get("role", "user"),
|
||||
content=[
|
||||
events.ItemContent(
|
||||
type="input_text" if m.get("role") == "user" else "text",
|
||||
text=content,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
# skip any messages that aren't "text" and change "user" message type to "input_text"
|
||||
cs = []
|
||||
for item in content:
|
||||
if item.get("type", "text") == "text":
|
||||
# cs.append(events.ItemContent(type="input_text", text=item.get("text")))
|
||||
(
|
||||
cs.append(
|
||||
events.ItemContent(
|
||||
type="input_text" if m.get("role") == "user" else "text",
|
||||
text=item.get("text"),
|
||||
)
|
||||
),
|
||||
)
|
||||
if cs:
|
||||
items.append(
|
||||
events.ConversationItem(
|
||||
type="message",
|
||||
status="completed",
|
||||
role=m.get("role", "user"),
|
||||
content=cs,
|
||||
)
|
||||
)
|
||||
elif m.get("role") == "assistant" and m.get("tool_calls"):
|
||||
tc = m.get("tool_calls")[0]
|
||||
items.append(
|
||||
events.ConversationItem(
|
||||
type="message",
|
||||
status="completed",
|
||||
role="user",
|
||||
content=content,
|
||||
type="function_call",
|
||||
call_id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -489,11 +566,14 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
output=m["content"],
|
||||
)
|
||||
)
|
||||
|
||||
for item in items:
|
||||
context.note_manually_added_message(item.id)
|
||||
await self.send_client_event(events.ConversationItemCreateEvent(item=item))
|
||||
|
||||
async def _create_response(self):
|
||||
if self._context.get_tools_list_updated():
|
||||
await self._update_settings()
|
||||
await self._send_messages_context_update()
|
||||
logger.debug(f"Creating response: {self._context.get_messages_for_logging()}")
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
@@ -544,7 +624,10 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
self._context = frame.context
|
||||
await self._send_messages_context_update()
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
self._session_properties = frame.settings
|
||||
await self._update_settings()
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
await self._update_settings()
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user