diff --git a/examples/foundational/19-openai-realtime-beta.py b/examples/foundational/19-openai-realtime-beta.py index 38b635e35..4739b793e 100644 --- a/examples/foundational/19-openai-realtime-beta.py +++ b/examples/foundational/19-openai-realtime-beta.py @@ -7,6 +7,7 @@ import asyncio import json import os +import re import sys from datetime import datetime @@ -15,6 +16,7 @@ from dotenv import load_dotenv from loguru import logger from runner import configure +from pipecat.frames.frames import LLMMessagesUpdateFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask @@ -37,18 +39,34 @@ logger.add(sys.stderr, level="DEBUG") async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback): + temperature = 75 if args["format"] == "fahrenheit" else 24 await result_callback( { "conditions": "nice", - "temperature": "75", + "temperature": temperature, + "format": args["format"], "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"), } ) +async def get_saved_conversation_filenames( + function_name, tool_call_id, args, llm, context, result_callback +): + pattern = re.compile("example_19_\\d{8}_\\d{6}\\.json$") + matching_files = [] + + for filename in os.listdir("."): + if pattern.match(filename): + matching_files.append(filename) + + await result_callback({"filenames": matching_files}) + + async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"example_19_{timestamp}.json" + logger.debug(f"writing conversation to {filename}\n{json.dumps(context.messages, indent=4)}") try: with open(filename, "w") as file: json.dump(context.messages, file, indent=4) @@ -57,6 +75,18 @@ async def save_conversation(function_name, tool_call_id, args, llm, context, res await result_callback({"success": False, "error": str(e)}) +async def load_conversation(function_name, tool_call_id, args, llm, context, result_callback): + filename = args["filename"] + logger.debug(f"loading conversation from {filename}") + try: + with open(filename, "r") as file: + messages = json.load(file) + await result_callback({"success": True}) + await llm.push_frame(LLMMessagesUpdateFrame(messages)) + except Exception as e: + await result_callback({"success": False, "error": str(e)}) + + tools = [ { "type": "function", @@ -88,6 +118,31 @@ tools = [ "required": [], }, }, + { + "type": "function", + "name": "get_saved_conversation_filenames", + "description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a timestamp. Each file is conversation history that can be loaded into this session.", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + { + "type": "function", + "name": "load_conversation", + "description": "Load a conversation history. Use this function to load a conversation history into the current session.", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "The filename of the conversation history to load.", + } + }, + "required": ["filename"], + }, + }, ] @@ -118,7 +173,7 @@ async def main(): # turn_detection=TurnDetection(silence_duration_ms=1000), # Or set to False to disable openai turn detection and use transport VAD turn_detection=False, - tools=tools, + # tools=tools, instructions=""" Your knowledge cutoff is 2023-10. You are a helpful and friendly AI. @@ -145,10 +200,14 @@ Remember, your responses should be short. Just one or two sentences, usually. # llm.register_function(None, fetch_weather_from_api) llm.register_function("get_current_weather", fetch_weather_from_api) llm.register_function("save_conversation", save_conversation) + llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames) + llm.register_function("load_conversation", load_conversation) context = OpenAILLMContext( - # [{"role": "user", "content": "What's the weather right now in San Francisco?"}], tools [{"role": "user", "content": "Say 'hello'."}], + # [{"role": "user", "content": "What's the weather right now in San Francisco?"}], + # conversation load from file is a WIP -- not functional yet + # [{"role": "user", "content": "Load the most recent conversation."}], tools, ) context_aggregator = llm.create_context_aggregator(context) diff --git a/src/pipecat/services/openai_realtime_beta/llm_and_context.py b/src/pipecat/services/openai_realtime_beta/llm_and_context.py index ad6b6a39a..de7f367b4 100644 --- a/src/pipecat/services/openai_realtime_beta/llm_and_context.py +++ b/src/pipecat/services/openai_realtime_beta/llm_and_context.py @@ -4,6 +4,7 @@ import json import traceback from copy import deepcopy from dataclasses import dataclass +from typing import List import websockets from loguru import logger @@ -18,6 +19,7 @@ from pipecat.frames.frames import ( LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesUpdateFrame, + LLMSetToolsFrame, LLMUpdateSettingsFrame, StartFrame, StartInterruptionFrame, @@ -78,6 +80,9 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext): # "conversation items" that have been created by opeanai realtime api events but are # not completely filled in, yet. map from item_id to message self._messages_in_progress = {} + # count of messages prior to recent reset + self._messages_reset_count = 0 + self._tools_list_updated = True @staticmethod def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext": @@ -86,30 +91,48 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext): obj.__setup_local() return obj - # cases to handle - # - tools in the context constructor (and in general?) - # - relatedly, set tools frame - # - clearing the context by deleting all messages (for scripted conversations) + # still working on + # - clearing the context by deleting all messages + # - reloading from a standard messages list # - truncating the last spoken message to maintain context when interrupted + def set_tools(self, tools: List): + super().set_tools(tools) + self._tools_list_updated = True + def add_message(self, message): super().add_message(message) self._unsent_messages.append(message) return message + def add_messages(self, messages): + super().add_messages(messages) + self._unsent_messages.extend(messages) + def add_message_already_present_in_api_context(self, message): super().add_message(message) return message def set_messages(self, messages): + self._messages_reset_count = len(self.messages) - len(self._unsent_messages) super().set_messages(messages) self._unsent_messages = deepcopy(self._messages) def get_unsent_messages(self): return self._unsent_messages + def get_messages_reset_count(self): + return self._messages_reset_count + + def get_tools_list_updated(self): + return self._tools_list_updated + def update_all_messages_sent(self): self._unsent_messages = [] + self._messages_reset_count = 0 + + def update_tools_list_sent(self): + self._tools_list_updated = False def note_manually_added_message(self, item_id): self._manually_created_messages[item_id] = True @@ -163,6 +186,10 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): if isinstance(frame, LLMMessagesUpdateFrame): await self.push_frame(_InternalMessagesUpdateFrame(context=self._context)) + # Parent also doesn't push the LLMSetToolsFrame. + if isinstance(frame, LLMSetToolsFrame): + await self.push_frame(frame, direction) + async def _push_aggregation(self): # for the moment, ignore all user input coming into the pipeline. # todo: think about whether/how to fix this to allow for text input from @@ -232,7 +259,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService): self.api_key = api_key self.base_url = base_url - self._session_properties = session_properties + self._session_properties: events.SessionProperties = session_properties self._audio_input_paused = start_audio_paused self._send_transcription_frames = send_transcription_frames # todo: wire _send_user_started_speaking_frames up correctly @@ -304,7 +331,12 @@ class OpenAILLMServiceRealtimeBeta(LLMService): return self._websocket raise Exception("Websocket not connected") - async def _update_settings(self, settings: events.SessionProperties): + async def _update_settings(self): + settings = self._session_properties + # tools given in the context override the tools in the session properties + if self._context and self._context.tools: + settings.tools = self._context.tools + self._context.update_tools_list_sent() await self.send_client_event(events.SessionUpdateEvent(session=settings)) async def _receive_task_handler(self): @@ -315,7 +347,7 @@ class OpenAILLMServiceRealtimeBeta(LLMService): if evt.type == "session.created": # session.created is received right after connecting. send a message # to configure the session properties. - await self._update_settings(self._session_properties) + await self._update_settings() elif evt.type == "session.updated": self._session_properties = evt.session elif evt.type == "input_audio_buffer.speech_started": @@ -445,9 +477,12 @@ class OpenAILLMServiceRealtimeBeta(LLMService): f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." ) - async def _reset_conversation(self): + async def _reset_conversation(self, count): # need to think about how to implement this, and how to think about interop with messages lists # used with the HTTP API + logger.debug(f"!!! RESET CONVERSATION: {count} [WIP]") + await self._disconnect() + await self._connect() pass async def _send_messages_context_update(self): @@ -455,28 +490,70 @@ class OpenAILLMServiceRealtimeBeta(LLMService): return context = self._context messages = context.get_unsent_messages() + + needs_reset = context.get_messages_reset_count() context.update_all_messages_sent() + if needs_reset: + await self._reset_conversation(needs_reset) + # debugging + logger.debug("MESSAGE HISTORY RELOAD NOT IMPLEMENTED YET") + return + items = [] for m in messages: - if m and (m.get("role") == "user" or m.get("role") == "system"): + if m and ( + m.get("role") == "user" or m.get("role") == "system" or m.get("role") == "assistant" + ): content = m.get("content") if isinstance(content, str): - items.append( - events.ConversationItem( - type="message", - status="completed", - role="user", - content=[events.ItemContent(type="input_text", text=content)], + # skip any messages that aren't "text" and change "user" message type to "input_text" + + if m.get("type", "text") == "text": + items.append( + events.ConversationItem( + type="message", + status="completed", + role=m.get("role", "user"), + content=[ + events.ItemContent( + type="input_text" if m.get("role") == "user" else "text", + text=content, + ) + ], + ) ) - ) elif isinstance(content, list): + # skip any messages that aren't "text" and change "user" message type to "input_text" + cs = [] + for item in content: + if item.get("type", "text") == "text": + # cs.append(events.ItemContent(type="input_text", text=item.get("text"))) + ( + cs.append( + events.ItemContent( + type="input_text" if m.get("role") == "user" else "text", + text=item.get("text"), + ) + ), + ) + if cs: + items.append( + events.ConversationItem( + type="message", + status="completed", + role=m.get("role", "user"), + content=cs, + ) + ) + elif m.get("role") == "assistant" and m.get("tool_calls"): + tc = m.get("tool_calls")[0] items.append( events.ConversationItem( - type="message", - status="completed", - role="user", - content=content, + type="function_call", + call_id=tc["id"], + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], ) ) else: @@ -489,11 +566,14 @@ class OpenAILLMServiceRealtimeBeta(LLMService): output=m["content"], ) ) + for item in items: context.note_manually_added_message(item.id) await self.send_client_event(events.ConversationItemCreateEvent(item=item)) async def _create_response(self): + if self._context.get_tools_list_updated(): + await self._update_settings() await self._send_messages_context_update() logger.debug(f"Creating response: {self._context.get_messages_for_logging()}") await self.push_frame(LLMFullResponseStartFrame()) @@ -544,7 +624,10 @@ class OpenAILLMServiceRealtimeBeta(LLMService): self._context = frame.context await self._send_messages_context_update() elif isinstance(frame, LLMUpdateSettingsFrame): - await self._update_settings(frame.settings) + self._session_properties = frame.settings + await self._update_settings() + elif isinstance(frame, LLMSetToolsFrame): + await self._update_settings() await self.push_frame(frame, direction)