context management improvements
This commit is contained in:
@@ -5,8 +5,10 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -20,6 +22,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
)
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioTranscription,
|
||||
OpenAILLMServiceRealtimeBeta,
|
||||
SessionProperties,
|
||||
)
|
||||
@@ -34,7 +37,24 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
await result_callback({"conditions": "nice", "temperature": "75"})
|
||||
await result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": "75",
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"example_19_{timestamp}.json"
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
json.dump(context.messages, file, indent=4)
|
||||
await result_callback({"success": True})
|
||||
except Exception as e:
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
tools = [
|
||||
@@ -57,7 +77,17 @@ tools = [
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "save_conversation",
|
||||
"description": "Save the current conversatione. Use this function to persist the current conversation to external storage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -82,7 +112,7 @@ async def main():
|
||||
)
|
||||
|
||||
session_properties = SessionProperties(
|
||||
# input_audio_transcription=InputAudioTranscription(),
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
# turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
@@ -114,6 +144,7 @@ Remember, your responses should be short. Just one or two sentences, usually.
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("save_conversation", save_conversation)
|
||||
|
||||
context = OpenAILLMContext(
|
||||
# [{"role": "user", "content": "What's the weather right now in San Francisco?"}], tools
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
@@ -65,20 +64,42 @@ class OpenAIUnhandledFunctionException(Exception):
|
||||
|
||||
|
||||
class OpenAIRealtimeLLMContext(OpenAILLMContext):
|
||||
def __init__(self, messages=None, tools=None, **kwargs):
|
||||
super().__init__(messages=messages, tools=tools, **kwargs)
|
||||
self.__setup_local()
|
||||
|
||||
def __setup_local(self):
|
||||
# messages that have been added to the context but not yet sent to the openai server
|
||||
self._unsent_messages = deepcopy(self._messages)
|
||||
# messages that we added to the context because they were part of our external
|
||||
# context store. we do not want to add these again when we see conversation.item.created
|
||||
# events about them. map from item_id to True
|
||||
self._manually_created_messages = {}
|
||||
# "conversation items" that have been created by opeanai realtime api events but are
|
||||
# not completely filled in, yet. map from item_id to message
|
||||
self._messages_in_progress = {}
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext):
|
||||
obj.__class__ = OpenAIRealtimeLLMContext
|
||||
obj._unsent_messages = deepcopy(obj._messages)
|
||||
obj._marker = random.randint(1, 1000)
|
||||
obj.__setup_local()
|
||||
return obj
|
||||
|
||||
# todo: do we need to also override add_messages() ?
|
||||
# cases to handle
|
||||
# - tools in the context constructor (and in general?)
|
||||
# - relatedly, set tools frame
|
||||
# - clearing the context by deleting all messages (for scripted conversations)
|
||||
# - truncating the last spoken message to maintain context when interrupted
|
||||
|
||||
def add_message(self, message):
|
||||
super().add_message(message)
|
||||
if message.get("role") == "tool":
|
||||
self._unsent_messages.append(message)
|
||||
self._unsent_messages.append(message)
|
||||
return message
|
||||
|
||||
def add_message_already_present_in_api_context(self, message):
|
||||
super().add_message(message)
|
||||
return message
|
||||
|
||||
def set_messages(self, messages):
|
||||
super().set_messages(messages)
|
||||
@@ -90,6 +111,45 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
|
||||
def update_all_messages_sent(self):
|
||||
self._unsent_messages = []
|
||||
|
||||
def note_manually_added_message(self, item_id):
|
||||
self._manually_created_messages[item_id] = True
|
||||
|
||||
def add_message_from_realtime_event(self, evt):
|
||||
if evt.item.id in self._manually_created_messages:
|
||||
del self._manually_created_messages[evt.item.id]
|
||||
return
|
||||
|
||||
# add messages. don't add function_call or function_call_output items.
|
||||
if evt.item.type == "message":
|
||||
message = self.add_message_already_present_in_api_context(
|
||||
{"role": evt.item.role, "content": []}
|
||||
)
|
||||
if not evt.item.content:
|
||||
self._messages_in_progress[evt.item.id] = message
|
||||
return
|
||||
for content in evt.item.content:
|
||||
message["content"].append({"type": content.type})
|
||||
if content.text:
|
||||
message["content"] = content.text
|
||||
elif content.transcript:
|
||||
message["content"] = content.transcript
|
||||
else:
|
||||
# we will get the transcript in a later event
|
||||
self._messages_in_progress[evt.item.id] = message
|
||||
return
|
||||
|
||||
def add_transcript_to_message(self, evt):
|
||||
message = self._messages_in_progress.get(evt.item_id)
|
||||
if message:
|
||||
cs = message["content"]
|
||||
cs.extend([{"type": ""}] * (evt.content_index - len(cs) + 1))
|
||||
cs[evt.content_index] = {"type": "text", "text": evt.transcript}
|
||||
del self._messages_in_progress[evt.item_id]
|
||||
else:
|
||||
logger.error(
|
||||
f"Could not find content {evt.item_id}/{evt.content_index} to add transcript to"
|
||||
)
|
||||
|
||||
|
||||
class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def process_frame(
|
||||
@@ -105,13 +165,55 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
async def _push_aggregation(self):
|
||||
# for the moment, ignore all user input coming into the pipeline.
|
||||
# todo: fix this to allow text prompting
|
||||
# todo: think about whether/how to fix this to allow for text input from
|
||||
# upstream (transport/transcription, or other sources)
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
await super()._push_aggregation()
|
||||
# the only thing we implement here is function calling. in all other cases, messages
|
||||
# are added to the context when we receive openai realtime api events
|
||||
if not self._function_call_result:
|
||||
return
|
||||
|
||||
self._reset()
|
||||
try:
|
||||
frame = self._function_call_result
|
||||
self._function_call_result = None
|
||||
if frame.result:
|
||||
self._context.add_message_already_present_in_api_context(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps(frame.result),
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
run_llm = frame.run_llm
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
@@ -218,7 +320,6 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
self._session_properties = evt.session
|
||||
elif evt.type == "input_audio_buffer.speech_started":
|
||||
# user started speaking
|
||||
# todo: send user started speaking if configured
|
||||
if self._send_user_started_speaking_frames:
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
@@ -226,7 +327,6 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
pass
|
||||
elif evt.type == "input_audio_buffer.speech_stopped":
|
||||
# user stopped speaking
|
||||
# todo: send user stopped speaking if configured
|
||||
if self._send_user_started_speaking_frames:
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
await self.push_frame(StopInterruptionFrame())
|
||||
@@ -235,12 +335,10 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
elif evt.type == "conversation.item.created":
|
||||
# for input, this will get sent from the server whether the
|
||||
# conversation item is created by audio transcription or by
|
||||
# sending a client conversation.item.create message.
|
||||
# we could listen to this event and track conversation item IDs to
|
||||
# help with context bookkeeping.
|
||||
pass
|
||||
# this will get sent from the server every time a new "message" is added
|
||||
# to the server's conversation state
|
||||
if self._context:
|
||||
self._context.add_message_from_realtime_event(evt)
|
||||
elif evt.type == "response.created":
|
||||
# todo: 1. figure out TTS started/stopped frame semantics better
|
||||
# 2. do not push these frames in text-only mode
|
||||
@@ -249,12 +347,9 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
pass
|
||||
elif evt.type == "conversation.item.input_audio_transcription.completed":
|
||||
# or here maybe (possible send upstream to user context aggregator)
|
||||
if evt.transcript:
|
||||
if self._context:
|
||||
self._context.add_message({"role": "user", "content": evt.transcript})
|
||||
else:
|
||||
logger.error("Context is None, cannot add message")
|
||||
self._context.add_transcript_to_message(evt)
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
# no way to get a language code?
|
||||
@@ -284,7 +379,8 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
self._bot_speaking = False
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
elif evt.type == "response.audio_transcript.done":
|
||||
# this doesn't map to any Pipecat frame types
|
||||
if self._context:
|
||||
self._context.add_transcript_to_message(evt)
|
||||
pass
|
||||
elif evt.type == "response.content_part.done":
|
||||
# this doesn't map to any Pipecat frame types
|
||||
@@ -374,6 +470,15 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
content=[events.ItemContent(type="input_text", text=content)],
|
||||
)
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
items.append(
|
||||
events.ConversationItem(
|
||||
type="message",
|
||||
status="completed",
|
||||
role="user",
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Invalid message content {m}")
|
||||
elif m and m.get("role") == "tool":
|
||||
@@ -385,6 +490,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
)
|
||||
)
|
||||
for item in items:
|
||||
context.note_manually_added_message(item.id)
|
||||
await self.send_client_event(events.ConversationItemCreateEvent(item=item))
|
||||
|
||||
async def _create_response(self):
|
||||
|
||||
Reference in New Issue
Block a user