From df2ddb4b9159595dd5aec77d99fb8ea8c3736b0b Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Tue, 8 Oct 2024 17:47:32 -0700 Subject: [PATCH] context management improvements --- .../foundational/19-openai-realtime-beta.py | 37 ++++- .../openai_realtime_beta/llm_and_context.py | 148 +++++++++++++++--- 2 files changed, 161 insertions(+), 24 deletions(-) diff --git a/examples/foundational/19-openai-realtime-beta.py b/examples/foundational/19-openai-realtime-beta.py index b687eefc1..38b635e35 100644 --- a/examples/foundational/19-openai-realtime-beta.py +++ b/examples/foundational/19-openai-realtime-beta.py @@ -5,8 +5,10 @@ # import asyncio +import json import os import sys +from datetime import datetime import aiohttp from dotenv import load_dotenv @@ -20,6 +22,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, ) from pipecat.services.openai_realtime_beta import ( + InputAudioTranscription, OpenAILLMServiceRealtimeBeta, SessionProperties, ) @@ -34,7 +37,24 @@ logger.add(sys.stderr, level="DEBUG") async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback): - await result_callback({"conditions": "nice", "temperature": "75"}) + await result_callback( + { + "conditions": "nice", + "temperature": "75", + "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"), + } + ) + + +async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"example_19_{timestamp}.json" + try: + with open(filename, "w") as file: + json.dump(context.messages, file, indent=4) + await result_callback({"success": True}) + except Exception as e: + await result_callback({"success": False, "error": str(e)}) tools = [ @@ -57,7 +77,17 @@ tools = [ }, "required": ["location", "format"], }, - } + }, + { + "type": "function", + "name": "save_conversation", + "description": "Save the current conversatione. 
Use this function to persist the current conversation to external storage.", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, ] @@ -82,7 +112,7 @@ async def main(): ) session_properties = SessionProperties( - # input_audio_transcription=InputAudioTranscription(), + input_audio_transcription=InputAudioTranscription(), # Set openai TurnDetection parameters. Not setting this at all will turn it # on by default # turn_detection=TurnDetection(silence_duration_ms=1000), @@ -114,6 +144,7 @@ Remember, your responses should be short. Just one or two sentences, usually. # you can either register a single function for all function calls, or specific functions # llm.register_function(None, fetch_weather_from_api) llm.register_function("get_current_weather", fetch_weather_from_api) + llm.register_function("save_conversation", save_conversation) context = OpenAILLMContext( # [{"role": "user", "content": "What's the weather right now in San Francisco?"}], tools diff --git a/src/pipecat/services/openai_realtime_beta/llm_and_context.py b/src/pipecat/services/openai_realtime_beta/llm_and_context.py index f21f2066b..ad6b6a39a 100644 --- a/src/pipecat/services/openai_realtime_beta/llm_and_context.py +++ b/src/pipecat/services/openai_realtime_beta/llm_and_context.py @@ -1,7 +1,6 @@ import asyncio import base64 import json -import random import traceback from copy import deepcopy from dataclasses import dataclass @@ -65,20 +64,42 @@ class OpenAIUnhandledFunctionException(Exception): class OpenAIRealtimeLLMContext(OpenAILLMContext): + def __init__(self, messages=None, tools=None, **kwargs): + super().__init__(messages=messages, tools=tools, **kwargs) + self.__setup_local() + + def __setup_local(self): + # messages that have been added to the context but not yet sent to the openai server + self._unsent_messages = deepcopy(self._messages) + # messages that we added to the context because they were part of our external + # context store. 
we do not want to add these again when we see conversation.item.created + # events about them. map from item_id to True + self._manually_created_messages = {} + # "conversation items" that have been created by openai realtime api events but are + # not completely filled in, yet. map from item_id to message + self._messages_in_progress = {} + @staticmethod def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext": if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext): obj.__class__ = OpenAIRealtimeLLMContext - obj._unsent_messages = deepcopy(obj._messages) - obj._marker = random.randint(1, 1000) + obj.__setup_local() return obj - # todo: do we need to also override add_messages() ? + # cases to handle + # - tools in the context constructor (and in general?) + # - relatedly, set tools frame + # - clearing the context by deleting all messages (for scripted conversations) + # - truncating the last spoken message to maintain context when interrupted def add_message(self, message): super().add_message(message) - if message.get("role") == "tool": - self._unsent_messages.append(message) + self._unsent_messages.append(message) + return message + + def add_message_already_present_in_api_context(self, message): + super().add_message(message) + return message def set_messages(self, messages): super().set_messages(messages) @@ -90,6 +111,45 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext): def update_all_messages_sent(self): self._unsent_messages = [] + def note_manually_added_message(self, item_id): + self._manually_created_messages[item_id] = True + + def add_message_from_realtime_event(self, evt): + if evt.item.id in self._manually_created_messages: + del self._manually_created_messages[evt.item.id] + return + + # add messages. don't add function_call or function_call_output items. 
+ if evt.item.type == "message": + message = self.add_message_already_present_in_api_context( + {"role": evt.item.role, "content": []} + ) + if not evt.item.content: + self._messages_in_progress[evt.item.id] = message + return + for content in evt.item.content: + message["content"].append({"type": content.type}) + if content.text: + message["content"] = content.text + elif content.transcript: + message["content"] = content.transcript + else: + # we will get the transcript in a later event + self._messages_in_progress[evt.item.id] = message + return + + def add_transcript_to_message(self, evt): + message = self._messages_in_progress.get(evt.item_id) + if message: + cs = message["content"] + cs.extend([{"type": ""}] * (evt.content_index - len(cs) + 1)) + cs[evt.content_index] = {"type": "text", "text": evt.transcript} + del self._messages_in_progress[evt.item_id] + else: + logger.error( + f"Could not find content {evt.item_id}/{evt.content_index} to add transcript to" + ) + class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): async def process_frame( @@ -105,13 +165,55 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): async def _push_aggregation(self): # for the moment, ignore all user input coming into the pipeline. - # todo: fix this to allow text prompting + # todo: think about whether/how to fix this to allow for text input from + # upstream (transport/transcription, or other sources) pass class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator): async def _push_aggregation(self): - await super()._push_aggregation() + # the only thing we implement here is function calling. 
in all other cases, messages + # are added to the context when we receive openai realtime api events + if not self._function_call_result: + return + + self._reset() + try: + frame = self._function_call_result + self._function_call_result = None + if frame.result: + self._context.add_message_already_present_in_api_context( + { + "role": "assistant", + "tool_calls": [ + { + "id": frame.tool_call_id, + "function": { + "name": frame.function_name, + "arguments": json.dumps(frame.arguments), + }, + "type": "function", + } + ], + } + ) + self._context.add_message( + { + "role": "tool", + "content": json.dumps(frame.result), + "tool_call_id": frame.tool_call_id, + } + ) + run_llm = frame.run_llm + + if run_llm: + await self._user_context_aggregator.push_context_frame() + + frame = OpenAILLMContextFrame(self._context) + await self.push_frame(frame) + + except Exception as e: + logger.error(f"Error processing frame: {e}") class OpenAILLMServiceRealtimeBeta(LLMService): @@ -218,7 +320,6 @@ class OpenAILLMServiceRealtimeBeta(LLMService): self._session_properties = evt.session elif evt.type == "input_audio_buffer.speech_started": # user started speaking - # todo: send user started speaking if configured if self._send_user_started_speaking_frames: await self.push_frame(UserStartedSpeakingFrame()) await self.push_frame(StartInterruptionFrame()) @@ -226,7 +327,6 @@ class OpenAILLMServiceRealtimeBeta(LLMService): pass elif evt.type == "input_audio_buffer.speech_stopped": # user stopped speaking - # todo: send user stopped speaking if configured if self._send_user_started_speaking_frames: await self.push_frame(UserStoppedSpeakingFrame()) await self.push_frame(StopInterruptionFrame()) @@ -235,12 +335,10 @@ class OpenAILLMServiceRealtimeBeta(LLMService): await self.start_processing_metrics() await self.start_ttfb_metrics() elif evt.type == "conversation.item.created": - # for input, this will get sent from the server whether the - # conversation item is created by audio transcription 
or by - # sending a client conversation.item.create message. - # we could listen to this event and track conversation item IDs to - # help with context bookkeeping. - pass + # this will get sent from the server every time a new "message" is added + # to the server's conversation state + if self._context: + self._context.add_message_from_realtime_event(evt) elif evt.type == "response.created": # todo: 1. figure out TTS started/stopped frame semantics better # 2. do not push these frames in text-only mode @@ -249,12 +347,9 @@ class OpenAILLMServiceRealtimeBeta(LLMService): await self.push_frame(TTSStartedFrame()) pass elif evt.type == "conversation.item.input_audio_transcription.completed": - # or here maybe (possible send upstream to user context aggregator) if evt.transcript: if self._context: - self._context.add_message({"role": "user", "content": evt.transcript}) - else: - logger.error("Context is None, cannot add message") + self._context.add_transcript_to_message(evt) if self._send_transcription_frames: await self.push_frame( # no way to get a language code? 
@@ -284,7 +379,8 @@ class OpenAILLMServiceRealtimeBeta(LLMService): self._bot_speaking = False await self.push_frame(TTSStoppedFrame()) elif evt.type == "response.audio_transcript.done": - # this doesn't map to any Pipecat frame types + if self._context: + self._context.add_transcript_to_message(evt) pass elif evt.type == "response.content_part.done": # this doesn't map to any Pipecat frame types @@ -374,6 +470,15 @@ class OpenAILLMServiceRealtimeBeta(LLMService): content=[events.ItemContent(type="input_text", text=content)], ) ) + elif isinstance(content, list): + items.append( + events.ConversationItem( + type="message", + status="completed", + role="user", + content=content, + ) + ) else: raise Exception(f"Invalid message content {m}") elif m and m.get("role") == "tool": @@ -385,6 +490,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService): ) ) for item in items: + context.note_manually_added_message(item.id) await self.send_client_event(events.ConversationItemCreateEvent(item=item)) async def _create_response(self):