context management improvements

This commit is contained in:
Kwindla Hultman Kramer
2024-10-08 17:47:32 -07:00
parent 0db5c86494
commit df2ddb4b91
2 changed files with 161 additions and 24 deletions

View File

@@ -5,8 +5,10 @@
#
import asyncio
import json
import os
import sys
from datetime import datetime
import aiohttp
from dotenv import load_dotenv
@@ -20,6 +22,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
)
from pipecat.services.openai_realtime_beta import (
InputAudioTranscription,
OpenAILLMServiceRealtimeBeta,
SessionProperties,
)
@@ -34,7 +37,24 @@ logger.add(sys.stderr, level="DEBUG")
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
await result_callback({"conditions": "nice", "temperature": "75"})
await result_callback(
{
"conditions": "nice",
"temperature": "75",
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
}
)
async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"example_19_{timestamp}.json"
try:
with open(filename, "w") as file:
json.dump(context.messages, file, indent=4)
await result_callback({"success": True})
except Exception as e:
await result_callback({"success": False, "error": str(e)})
tools = [
@@ -57,7 +77,17 @@ tools = [
},
"required": ["location", "format"],
},
}
},
{
"type": "function",
"name": "save_conversation",
"description": "Save the current conversatione. Use this function to persist the current conversation to external storage.",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
]
@@ -82,7 +112,7 @@ async def main():
)
session_properties = SessionProperties(
# input_audio_transcription=InputAudioTranscription(),
input_audio_transcription=InputAudioTranscription(),
# Set openai TurnDetection parameters. Not setting this at all will turn it
# on by default
# turn_detection=TurnDetection(silence_duration_ms=1000),
@@ -114,6 +144,7 @@ Remember, your responses should be short. Just one or two sentences, usually.
# you can either register a single function for all function calls, or specific functions
# llm.register_function(None, fetch_weather_from_api)
llm.register_function("get_current_weather", fetch_weather_from_api)
llm.register_function("save_conversation", save_conversation)
context = OpenAILLMContext(
# [{"role": "user", "content": "What's the weather right now in San Francisco?"}], tools

View File

@@ -1,7 +1,6 @@
import asyncio
import base64
import json
import random
import traceback
from copy import deepcopy
from dataclasses import dataclass
@@ -65,20 +64,42 @@ class OpenAIUnhandledFunctionException(Exception):
class OpenAIRealtimeLLMContext(OpenAILLMContext):
def __init__(self, messages=None, tools=None, **kwargs):
super().__init__(messages=messages, tools=tools, **kwargs)
self.__setup_local()
def __setup_local(self):
# messages that have been added to the context but not yet sent to the openai server
self._unsent_messages = deepcopy(self._messages)
# messages that we added to the context because they were part of our external
# context store. we do not want to add these again when we see conversation.item.created
# events about them. map from item_id to True
self._manually_created_messages = {}
# "conversation items" that have been created by opeanai realtime api events but are
# not completely filled in, yet. map from item_id to message
self._messages_in_progress = {}
@staticmethod
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext):
obj.__class__ = OpenAIRealtimeLLMContext
obj._unsent_messages = deepcopy(obj._messages)
obj._marker = random.randint(1, 1000)
obj.__setup_local()
return obj
# todo: do we need to also override add_messages() ?
# cases to handle
# - tools in the context constructor (and in general?)
# - relatedly, set tools frame
# - clearing the context by deleting all messages (for scripted conversations)
# - truncating the last spoken message to maintain context when interrupted
def add_message(self, message):
super().add_message(message)
if message.get("role") == "tool":
self._unsent_messages.append(message)
self._unsent_messages.append(message)
return message
def add_message_already_present_in_api_context(self, message):
super().add_message(message)
return message
def set_messages(self, messages):
super().set_messages(messages)
@@ -90,6 +111,45 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
def update_all_messages_sent(self):
self._unsent_messages = []
def note_manually_added_message(self, item_id):
self._manually_created_messages[item_id] = True
def add_message_from_realtime_event(self, evt):
if evt.item.id in self._manually_created_messages:
del self._manually_created_messages[evt.item.id]
return
# add messages. don't add function_call or function_call_output items.
if evt.item.type == "message":
message = self.add_message_already_present_in_api_context(
{"role": evt.item.role, "content": []}
)
if not evt.item.content:
self._messages_in_progress[evt.item.id] = message
return
for content in evt.item.content:
message["content"].append({"type": content.type})
if content.text:
message["content"] = content.text
elif content.transcript:
message["content"] = content.transcript
else:
# we will get the transcript in a later event
self._messages_in_progress[evt.item.id] = message
return
def add_transcript_to_message(self, evt):
message = self._messages_in_progress.get(evt.item_id)
if message:
cs = message["content"]
cs.extend([{"type": ""}] * (evt.content_index - len(cs) + 1))
cs[evt.content_index] = {"type": "text", "text": evt.transcript}
del self._messages_in_progress[evt.item_id]
else:
logger.error(
f"Could not find content {evt.item_id}/{evt.content_index} to add transcript to"
)
class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
async def process_frame(
@@ -105,13 +165,55 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
async def _push_aggregation(self):
# for the moment, ignore all user input coming into the pipeline.
# todo: fix this to allow text prompting
# todo: think about whether/how to fix this to allow for text input from
# upstream (transport/transcription, or other sources)
pass
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def _push_aggregation(self):
await super()._push_aggregation()
# the only thing we implement here is function calling. in all other cases, messages
# are added to the context when we receive openai realtime api events
if not self._function_call_result:
return
self._reset()
try:
frame = self._function_call_result
self._function_call_result = None
if frame.result:
self._context.add_message_already_present_in_api_context(
{
"role": "assistant",
"tool_calls": [
{
"id": frame.tool_call_id,
"function": {
"name": frame.function_name,
"arguments": json.dumps(frame.arguments),
},
"type": "function",
}
],
}
)
self._context.add_message(
{
"role": "tool",
"content": json.dumps(frame.result),
"tool_call_id": frame.tool_call_id,
}
)
run_llm = frame.run_llm
if run_llm:
await self._user_context_aggregator.push_context_frame()
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
except Exception as e:
logger.error(f"Error processing frame: {e}")
class OpenAILLMServiceRealtimeBeta(LLMService):
@@ -218,7 +320,6 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
self._session_properties = evt.session
elif evt.type == "input_audio_buffer.speech_started":
# user started speaking
# todo: send user started speaking if configured
if self._send_user_started_speaking_frames:
await self.push_frame(UserStartedSpeakingFrame())
await self.push_frame(StartInterruptionFrame())
@@ -226,7 +327,6 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
pass
elif evt.type == "input_audio_buffer.speech_stopped":
# user stopped speaking
# todo: send user stopped speaking if configured
if self._send_user_started_speaking_frames:
await self.push_frame(UserStoppedSpeakingFrame())
await self.push_frame(StopInterruptionFrame())
@@ -235,12 +335,10 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
await self.start_processing_metrics()
await self.start_ttfb_metrics()
elif evt.type == "conversation.item.created":
# for input, this will get sent from the server whether the
# conversation item is created by audio transcription or by
# sending a client conversation.item.create message.
# we could listen to this event and track conversation item IDs to
# help with context bookkeeping.
pass
# this will get sent from the server every time a new "message" is added
# to the server's conversation state
if self._context:
self._context.add_message_from_realtime_event(evt)
elif evt.type == "response.created":
# todo: 1. figure out TTS started/stopped frame semantics better
# 2. do not push these frames in text-only mode
@@ -249,12 +347,9 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
await self.push_frame(TTSStartedFrame())
pass
elif evt.type == "conversation.item.input_audio_transcription.completed":
# or here maybe (possible send upstream to user context aggregator)
if evt.transcript:
if self._context:
self._context.add_message({"role": "user", "content": evt.transcript})
else:
logger.error("Context is None, cannot add message")
self._context.add_transcript_to_message(evt)
if self._send_transcription_frames:
await self.push_frame(
# no way to get a language code?
@@ -284,7 +379,8 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
self._bot_speaking = False
await self.push_frame(TTSStoppedFrame())
elif evt.type == "response.audio_transcript.done":
# this doesn't map to any Pipecat frame types
if self._context:
self._context.add_transcript_to_message(evt)
pass
elif evt.type == "response.content_part.done":
# this doesn't map to any Pipecat frame types
@@ -374,6 +470,15 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
content=[events.ItemContent(type="input_text", text=content)],
)
)
elif isinstance(content, list):
items.append(
events.ConversationItem(
type="message",
status="completed",
role="user",
content=content,
)
)
else:
raise Exception(f"Invalid message content {m}")
elif m and m.get("role") == "tool":
@@ -385,6 +490,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
)
)
for item in items:
context.note_manually_added_message(item.id)
await self.send_client_event(events.ConversationItemCreateEvent(item=item))
async def _create_response(self):