processors(aggregators): multiple LLM aggregators updates
This commit is contained in:
@@ -82,10 +82,6 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
#
|
||||
# and T2 would be dropped.
|
||||
|
||||
async def _set_tools(self, tools: List):
|
||||
# noop in the base class
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -129,34 +125,16 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self._messages.extend(frame.messages)
|
||||
messages_frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(messages_frame)
|
||||
self._add_messages(frame.messages)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
# We push the frame downstream so the assistant aggregator gets
|
||||
# updated as well.
|
||||
# TODO-CB: Now we're replacing the contents of the array so we
|
||||
# don't need to push the frame here
|
||||
# await self.push_frame(frame)
|
||||
# We can now reset this one.
|
||||
self._reset()
|
||||
self._set_messages(frame.messages)
|
||||
# messages_frame = LLMMessagesFrame(self._messages)
|
||||
# await self.push_frame(messages_frame)
|
||||
await self.push_messages_frame()
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
await self.push_frame(frame)
|
||||
await self._set_tools(frame.tools)
|
||||
self._set_tools(frame.tools)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if send_aggregation:
|
||||
await self._push_aggregation()
|
||||
|
||||
# TODO-CB: Types
|
||||
def _set_messages(self, messages):
|
||||
self._messages.clear()
|
||||
self._messages.extend(messages)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
@@ -169,6 +147,19 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# TODO-CB: Types
|
||||
def _add_messages(self, messages):
|
||||
self._messages.extend(messages)
|
||||
|
||||
def _set_messages(self, messages):
|
||||
self._reset()
|
||||
self._messages.clear()
|
||||
self._messages.extend(messages)
|
||||
|
||||
def _set_tools(self, tools):
|
||||
# noop in the base class
|
||||
pass
|
||||
|
||||
def _reset(self):
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
@@ -257,23 +248,29 @@ class LLMFullResponseAggregator(FrameProcessor):
|
||||
|
||||
class LLMContextAggregator(LLMResponseAggregator):
|
||||
def __init__(self, *, context: OpenAILLMContext, **kwargs):
|
||||
|
||||
self._context = context
|
||||
super().__init__(**kwargs)
|
||||
# TODO-CB: thanks, I hate it
|
||||
self._messages = context.messages
|
||||
|
||||
|
||||
async def _set_tools(self, tools: List):
|
||||
# We push the frame downstream so the assistant aggregator gets
|
||||
# updated as well.
|
||||
self._context.tools = tools
|
||||
|
||||
# TODO-CB: Types
|
||||
def _set_messages(self, messages):
|
||||
self._messages.clear()
|
||||
self._messages.extend(messages)
|
||||
self._context = context
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
return self._context
|
||||
|
||||
def get_context_frame(self) -> OpenAILLMContextFrame:
|
||||
return OpenAILLMContextFrame(context=self._context)
|
||||
|
||||
async def push_context_frame(self):
|
||||
frame = self.get_context_frame()
|
||||
await self.push_frame(frame)
|
||||
|
||||
# TODO-CB: Types
|
||||
def _add_messages(self, messages):
|
||||
self._context.add_messages(messages)
|
||||
|
||||
def _set_messages(self, messages):
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def _set_tools(self, tools: List):
|
||||
self._context.set_tools(tools)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
|
||||
@@ -44,10 +44,10 @@ class OpenAILLMContext:
|
||||
tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
|
||||
):
|
||||
self.messages: List[ChatCompletionMessageParam] = messages if messages else [
|
||||
self._messages: List[ChatCompletionMessageParam] = messages if messages else [
|
||||
]
|
||||
self.tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
|
||||
self.tools: List[ChatCompletionToolParam] | NotGiven = tools
|
||||
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
|
||||
self._tools: List[ChatCompletionToolParam] | NotGiven = tools
|
||||
|
||||
@staticmethod
|
||||
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
|
||||
@@ -84,26 +84,43 @@ class OpenAILLMContext:
|
||||
})
|
||||
return context
|
||||
|
||||
@property
|
||||
def messages(self) -> List[ChatCompletionMessageParam]:
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def tools(self) -> List[ChatCompletionToolParam] | NotGiven:
|
||||
return self._tools
|
||||
|
||||
@property
|
||||
def tool_choice(self) -> ChatCompletionToolChoiceOptionParam | NotGiven:
|
||||
return self._tool_choice
|
||||
|
||||
def add_message(self, message: ChatCompletionMessageParam):
|
||||
self.messages.append(message)
|
||||
self._messages.append(message)
|
||||
|
||||
def add_messages(self, messages: List[ChatCompletionMessageParam]):
|
||||
self._messages.extend(messages)
|
||||
|
||||
def set_messages(self, messages: List[ChatCompletionMessageParam]):
|
||||
self._messages[:] = messages
|
||||
|
||||
def get_messages(self) -> List[ChatCompletionMessageParam]:
|
||||
return self.messages
|
||||
return self._messages
|
||||
|
||||
def get_messages_json(self) -> str:
|
||||
return json.dumps(self.messages, cls=CustomEncoder)
|
||||
return json.dumps(self._messages, cls=CustomEncoder)
|
||||
|
||||
def set_tool_choice(
|
||||
self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven
|
||||
):
|
||||
self.tool_choice = tool_choice
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN):
|
||||
if tools != NOT_GIVEN and len(tools) == 0:
|
||||
tools = NOT_GIVEN
|
||||
self._tools = tools
|
||||
|
||||
self.tools = tools
|
||||
|
||||
async def call_function(
|
||||
self,
|
||||
f: callable,
|
||||
|
||||
@@ -336,10 +336,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
self._context = AnthropicLLMContext.from_openai_context(context)
|
||||
|
||||
async def push_messages_frame(self):
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
@@ -415,7 +411,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
size=frame.user_image_raw_frame.size,
|
||||
image=frame.user_image_raw_frame.image,
|
||||
text=frame.text)
|
||||
await self._user_context_aggregator.push_messages_frame()
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing AnthropicImageMessageFrame: {e}")
|
||||
|
||||
@@ -465,7 +461,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_messages_frame()
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -177,7 +177,7 @@ class CartesiaTTSService(TTSService):
|
||||
elif msg["type"] == "error":
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.stop_all_metrics()
|
||||
await self.push_frame(ErrorFrame(f'{self} error: {msg["error"]}'))
|
||||
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
|
||||
else:
|
||||
logger.error(f"Cartesia error, unknown message type: {msg}")
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -355,10 +355,6 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
super().__init__(context=context)
|
||||
|
||||
async def push_messages_frame(self):
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
|
||||
@@ -426,8 +422,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if run_llm:
|
||||
|
||||
await self._user_context_aggregator.push_messages_frame()
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
Reference in New Issue
Block a user