processors(aggregators): multiple LLM aggregators updates

This commit is contained in:
Aleix Conchillo Flaqué
2024-08-14 20:23:18 -07:00
parent b78981bb9d
commit 3a5cd17ea3
5 changed files with 66 additions and 61 deletions

View File

@@ -82,10 +82,6 @@ class LLMResponseAggregator(FrameProcessor):
#
# and T2 would be dropped.
async def _set_tools(self, tools: List):
# noop in the base class
pass
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -129,34 +125,16 @@ class LLMResponseAggregator(FrameProcessor):
self._reset()
await self.push_frame(frame, direction)
elif isinstance(frame, LLMMessagesAppendFrame):
self._messages.extend(frame.messages)
messages_frame = LLMMessagesFrame(self._messages)
await self.push_frame(messages_frame)
self._add_messages(frame.messages)
elif isinstance(frame, LLMMessagesUpdateFrame):
# We push the frame downstream so the assistant aggregator gets
# updated as well.
# TODO-CB: Now we're replacing the contents of the array so we
# don't need to push the frame here
# await self.push_frame(frame)
# We can now reset this one.
self._reset()
self._set_messages(frame.messages)
# messages_frame = LLMMessagesFrame(self._messages)
# await self.push_frame(messages_frame)
await self.push_messages_frame()
elif isinstance(frame, LLMSetToolsFrame):
await self.push_frame(frame)
await self._set_tools(frame.tools)
self._set_tools(frame.tools)
else:
await self.push_frame(frame, direction)
if send_aggregation:
await self._push_aggregation()
# TODO-CB: Types
def _set_messages(self, messages):
self._messages.clear()
self._messages.extend(messages)
async def _push_aggregation(self):
if len(self._aggregation) > 0:
@@ -169,6 +147,19 @@ class LLMResponseAggregator(FrameProcessor):
frame = LLMMessagesFrame(self._messages)
await self.push_frame(frame)
# TODO-CB: Types
def _add_messages(self, messages):
self._messages.extend(messages)
def _set_messages(self, messages):
self._reset()
self._messages.clear()
self._messages.extend(messages)
def _set_tools(self, tools):
# noop in the base class
pass
def _reset(self):
self._aggregation = ""
self._aggregating = False
@@ -257,23 +248,29 @@ class LLMFullResponseAggregator(FrameProcessor):
class LLMContextAggregator(LLMResponseAggregator):
def __init__(self, *, context: OpenAILLMContext, **kwargs):
self._context = context
super().__init__(**kwargs)
# TODO-CB: thanks, I hate it
self._messages = context.messages
async def _set_tools(self, tools: List):
# We push the frame downstream so the assistant aggregator gets
# updated as well.
self._context.tools = tools
# TODO-CB: Types
def _set_messages(self, messages):
self._messages.clear()
self._messages.extend(messages)
self._context = context
@property
def context(self):
return self._context
def get_context_frame(self) -> OpenAILLMContextFrame:
return OpenAILLMContextFrame(context=self._context)
async def push_context_frame(self):
frame = self.get_context_frame()
await self.push_frame(frame)
# TODO-CB: Types
def _add_messages(self, messages):
self._context.add_messages(messages)
def _set_messages(self, messages):
self._context.set_messages(messages)
def _set_tools(self, tools: List):
self._context.set_tools(tools)
async def _push_aggregation(self):
if len(self._aggregation) > 0:

View File

@@ -44,10 +44,10 @@ class OpenAILLMContext:
tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
):
self.messages: List[ChatCompletionMessageParam] = messages if messages else [
self._messages: List[ChatCompletionMessageParam] = messages if messages else [
]
self.tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
self.tools: List[ChatCompletionToolParam] | NotGiven = tools
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
self._tools: List[ChatCompletionToolParam] | NotGiven = tools
@staticmethod
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
@@ -84,26 +84,43 @@ class OpenAILLMContext:
})
return context
@property
def messages(self) -> List[ChatCompletionMessageParam]:
return self._messages
@property
def tools(self) -> List[ChatCompletionToolParam] | NotGiven:
return self._tools
@property
def tool_choice(self) -> ChatCompletionToolChoiceOptionParam | NotGiven:
return self._tool_choice
def add_message(self, message: ChatCompletionMessageParam):
self.messages.append(message)
self._messages.append(message)
def add_messages(self, messages: List[ChatCompletionMessageParam]):
self._messages.extend(messages)
def set_messages(self, messages: List[ChatCompletionMessageParam]):
self._messages[:] = messages
def get_messages(self) -> List[ChatCompletionMessageParam]:
return self.messages
return self._messages
def get_messages_json(self) -> str:
return json.dumps(self.messages, cls=CustomEncoder)
return json.dumps(self._messages, cls=CustomEncoder)
def set_tool_choice(
self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven
):
self.tool_choice = tool_choice
self._tool_choice = tool_choice
def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN):
if tools != NOT_GIVEN and len(tools) == 0:
tools = NOT_GIVEN
self._tools = tools
self.tools = tools
async def call_function(
self,
f: callable,

View File

@@ -336,10 +336,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
if isinstance(context, OpenAILLMContext):
self._context = AnthropicLLMContext.from_openai_context(context)
async def push_messages_frame(self):
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# Our parent method has already called push_frame(). So we can't interrupt the
@@ -415,7 +411,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
size=frame.user_image_raw_frame.size,
image=frame.user_image_raw_frame.image,
text=frame.text)
await self._user_context_aggregator.push_messages_frame()
await self._user_context_aggregator.push_context_frame()
except Exception as e:
logger.error(f"Error processing AnthropicImageMessageFrame: {e}")
@@ -465,7 +461,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
self._context.add_message({"role": "assistant", "content": aggregation})
if run_llm:
await self._user_context_aggregator.push_messages_frame()
await self._user_context_aggregator.push_context_frame()
except Exception as e:
logger.error(f"Error processing frame: {e}")

View File

@@ -177,7 +177,7 @@ class CartesiaTTSService(TTSService):
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.stop_all_metrics()
await self.push_frame(ErrorFrame(f'{self} error: {msg["error"]}'))
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"Cartesia error, unknown message type: {msg}")
except asyncio.CancelledError:

View File

@@ -355,10 +355,6 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
def __init__(self, context: OpenAILLMContext):
super().__init__(context=context)
async def push_messages_frame(self):
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
@@ -426,8 +422,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
self._context.add_message({"role": "assistant", "content": aggregation})
if run_llm:
await self._user_context_aggregator.push_messages_frame()
await self._user_context_aggregator.push_context_frame()
except Exception as e:
logger.error(f"Error processing frame: {e}")