Not invoking on_function_calls_started for the cancel function, and creating on_function_calls_cancelled
This commit is contained in:
@@ -183,7 +183,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
|
||||
- on_completion_timeout: Called when an LLM completion timeout occurs
|
||||
- on_function_calls_started: Called when function calls are received and
|
||||
execution is about to start
|
||||
execution is about to start. Built-in tools (e.g. ``cancel_async_tool_call``)
|
||||
are excluded from this event.
|
||||
- on_function_calls_cancelled: Called after one or more async tool calls are
|
||||
cancelled.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -194,6 +197,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
@task.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls):
|
||||
logger.info(f"Starting {len(function_calls)} function calls")
|
||||
|
||||
@task.event_handler("on_function_calls_cancelled")
|
||||
async def on_function_calls_cancelled(service, cancelled):
|
||||
logger.info(f"Cancelled {len(cancelled)} async function calls")
|
||||
"""
|
||||
|
||||
_settings: LLMSettings
|
||||
@@ -252,6 +259,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
self._summary_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._register_event_handler("on_function_calls_started")
|
||||
self._register_event_handler("on_function_calls_cancelled")
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
|
||||
def get_llm_adapter(self) -> BaseLLMAdapter:
|
||||
@@ -693,9 +701,14 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
if len(function_calls) == 0:
|
||||
return
|
||||
|
||||
await self._call_event_handler("on_function_calls_started", function_calls)
|
||||
|
||||
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)
|
||||
# Exclude the built-in cancel tool — it's an internal mechanism and
|
||||
# should not be surfaced to user-facing event handlers or frames.
|
||||
user_visible_calls = [
|
||||
fc for fc in function_calls if fc.function_name != CANCEL_ASYNC_TOOL_NAME
|
||||
]
|
||||
if user_visible_calls:
|
||||
await self._call_event_handler("on_function_calls_started", user_visible_calls)
|
||||
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=user_visible_calls)
|
||||
|
||||
# When group_parallel_tools is True all calls share a group_id so the
|
||||
# aggregator triggers the LLM exactly once after the last one completes.
|
||||
@@ -954,6 +967,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
tool_call_id: tool_call_id to cancel.
|
||||
"""
|
||||
cancelled_tasks = set()
|
||||
cancelled_items = []
|
||||
for task, runner_item in self._function_call_tasks.items():
|
||||
if runner_item.tool_call_id == tool_call_id:
|
||||
name = runner_item.function_name
|
||||
@@ -973,13 +987,18 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
FunctionCallCancelFrame, function_name=name, tool_call_id=tool_call_id
|
||||
)
|
||||
|
||||
cancelled_items.append(runner_item)
|
||||
logger.debug(f"{self} Async function call [{name}:{tool_call_id}] cancelled")
|
||||
|
||||
for task in cancelled_tasks:
|
||||
self._function_call_task_finished(task)
|
||||
|
||||
if cancelled_items:
|
||||
await self._call_event_handler("on_function_calls_cancelled", cancelled_items)
|
||||
|
||||
async def _cancel_function_call(self, function_name: Optional[str]):
|
||||
cancelled_tasks = set()
|
||||
cancelled_items = []
|
||||
for task, runner_item in self._function_call_tasks.items():
|
||||
if runner_item.registry_item.function_name == function_name:
|
||||
name = runner_item.function_name
|
||||
@@ -999,12 +1018,16 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
FunctionCallCancelFrame, function_name=name, tool_call_id=tool_call_id
|
||||
)
|
||||
|
||||
cancelled_items.append(runner_item)
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
|
||||
# Remove all cancelled tasks from our set.
|
||||
for task in cancelled_tasks:
|
||||
self._function_call_task_finished(task)
|
||||
|
||||
if cancelled_items:
|
||||
await self._call_event_handler("on_function_calls_cancelled", cancelled_items)
|
||||
|
||||
def _function_call_task_finished(self, task: asyncio.Task):
|
||||
if task in self._function_call_tasks:
|
||||
del self._function_call_tasks[task]
|
||||
|
||||
Reference in New Issue
Block a user