LLMAssistantAggregator: add assistant turn and thought events
This commit is contained in:
@@ -118,6 +118,38 @@ class UserTurnStoppedMessage:
|
||||
user_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class AssistantTurnStoppedMessage:
    """Payload delivered when an assistant turn stops.

    Carries the aggregated assistant transcript for the turn, i.e. the same
    text that is used in the conversation context.

    Parameters:
        content: The aggregated assistant text for the turn.
    """

    content: str
|
||||
|
||||
|
||||
@dataclass
class AssistantThoughtMessage:
    """Payload delivered when an assistant thought is available.

    Represents one assistant thought entry in a conversation transcript.

    Parameters:
        content: The thought text.
        timestamp: When the thought started, as a string.
    """

    content: str
    timestamp: str
|
||||
|
||||
|
||||
class LLMContextAggregator(FrameProcessor):
|
||||
"""Base LLM aggregator that uses an LLMContext for conversation storage.
|
||||
|
||||
@@ -537,6 +569,27 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
The aggregator manages function calls in progress and coordinates between
|
||||
text generation and tool execution phases of LLM responses.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_assistant_turn_started: Called when the assistant turn starts
|
||||
- on_assistant_turn_stopped: Called when the assistant turn ends
|
||||
- on_assistant_thought: Called when an assistant thought is available
|
||||
|
||||
Example::
|
||||
|
||||
@aggregator.event_handler("on_assistant_turn_started")
|
||||
async def on_assistant_turn_started(aggregator):
|
||||
...
|
||||
|
||||
@aggregator.event_handler("on_assistant_turn_stopped")
|
||||
async def on_assistant_turn_stopped(aggregator, message: AssistantTurnStoppedMessage):
|
||||
...
|
||||
|
||||
@aggregator.event_handler("on_assistant_thought")
|
||||
async def on_assistant_thought(aggregator, message: AssistantThoughtMessage):
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -583,6 +636,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
self._thought_aggregation_enabled = False
|
||||
self._thought_llm: str = ""
|
||||
self._thought_aggregation: List[TextPartForConcatenation] = []
|
||||
self._thought_start_time: str = ""
|
||||
|
||||
self._register_event_handler("on_assistant_turn_started")
|
||||
self._register_event_handler("on_assistant_turn_stopped")
|
||||
self._register_event_handler("on_assistant_thought")
|
||||
|
||||
@property
|
||||
def has_function_calls_in_progress(self) -> bool:
|
||||
@@ -661,8 +719,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
aggregation = self.aggregation_string()
|
||||
await self.reset()
|
||||
|
||||
if aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
# Push context frame
|
||||
await self.push_context_frame()
|
||||
@@ -687,7 +744,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
    """Close the current assistant turn when an interruption arrives.

    Args:
        frame: The interruption frame; its contents are not used here.
    """
    # _trigger_assistant_turn_stopped() calls push_aggregation() itself and
    # only emits the event when that call returns text. Flushing with a
    # direct push_aggregation() beforehand would leave nothing for it to
    # report, so the turn-stopped event would never fire.
    await self._trigger_assistant_turn_stopped()
    # Whatever response was in flight is abandoned: zero the response
    # counter and reset the aggregator.
    self._started = 0
    await self.reset()
|
||||
|
||||
@@ -810,7 +867,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
text=frame.text,
|
||||
)
|
||||
|
||||
await self.push_aggregation()
|
||||
await self._trigger_assistant_turn_stopped()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_assistant_image_frame(self, frame: AssistantImageRawFrame):
|
||||
@@ -833,10 +890,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
    """Count a newly started LLM response and announce the assistant turn.

    Args:
        _: The response-start frame; its contents are not used.
    """
    # ``_started`` is a counter rather than a flag; it is incremented here
    # once per response-start frame.
    self._started = self._started + 1
    await self._trigger_assistant_turn_started()
|
||||
|
||||
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
    """Count a finished LLM response and close the assistant turn.

    Args:
        _: The response-end frame; its contents are not used.
    """
    self._started -= 1
    # _trigger_assistant_turn_stopped() performs the push_aggregation() call
    # and emits the event only when that call returns text. A separate
    # push_aggregation() here would flush first and leave it nothing to
    # report, silently suppressing the turn-stopped event.
    await self._trigger_assistant_turn_stopped()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started or not frame.append_to_context:
|
||||
@@ -859,6 +917,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self._reset_thought_aggregation()
|
||||
self._thought_aggregation_enabled = frame.append_to_context
|
||||
self._thought_llm = frame.llm
|
||||
self._thought_start_time = time_now_iso8601()
|
||||
|
||||
async def _handle_thought_text(self, frame: LLMThoughtTextFrame):
|
||||
if not self._started or not self._thought_aggregation_enabled:
|
||||
@@ -893,9 +952,21 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
)
|
||||
)
|
||||
|
||||
message = AssistantThoughtMessage(content=thought, timestamp=self._thought_start_time)
|
||||
await self._call_event_handler("on_assistant_thought", message)
|
||||
|
||||
def _context_updated_task_finished(self, task: asyncio.Task):
    """Stop tracking ``task`` in ``_context_updated_tasks``.

    Args:
        task: The task to remove. Uses ``discard``, so a task that is not
            currently tracked is ignored rather than raising.
    """
    # NOTE(review): presumably registered as a task done-callback — confirm
    # against the code that creates these tasks.
    tracked = self._context_updated_tasks
    tracked.discard(task)
|
||||
|
||||
async def _trigger_assistant_turn_started(self):
    """Fire the ``on_assistant_turn_started`` event with no extra arguments."""
    event_name = "on_assistant_turn_started"
    await self._call_event_handler(event_name)
|
||||
|
||||
async def _trigger_assistant_turn_stopped(self):
    """Flush the pending aggregation and report it as a stopped assistant turn.

    Calls ``push_aggregation()`` and, when it returns a non-empty value,
    fires ``on_assistant_turn_stopped`` with an
    ``AssistantTurnStoppedMessage`` wrapping that value. When nothing was
    aggregated, no event is fired.
    """
    transcript = await self.push_aggregation()
    if not transcript:
        # Nothing was aggregated for this turn, so there is nothing to report.
        return

    await self._call_event_handler(
        "on_assistant_turn_stopped",
        AssistantTurnStoppedMessage(content=transcript),
    )
|
||||
|
||||
|
||||
class LLMContextAggregatorPair:
|
||||
"""Pair of LLM context aggregators for updating context with user and assistant messages."""
|
||||
|
||||
Reference in New Issue
Block a user