LLMService: don't allow running functions concurrently for now
This commit is contained in:
@@ -675,7 +675,6 @@ class FunctionCallInProgressFrame(SystemFrame):
|
||||
tool_call_id: str
|
||||
arguments: Any
|
||||
cancel_on_interruption: bool = False
|
||||
run_concurrently: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -603,13 +603,13 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
if frame.result:
|
||||
run_llm = False
|
||||
if properties and properties.run_llm is not None:
|
||||
# If the tool call result has a run_llm property, use it
|
||||
# If the tool call result has a run_llm property, use it.
|
||||
run_llm = properties.run_llm
|
||||
elif frame.run_llm is not None:
|
||||
# If the frame is indicating we should run the LLM, do it.
|
||||
run_llm = frame.run_llm
|
||||
elif in_progress.run_concurrently:
|
||||
# If this was a parallel function call and there are no pending function call, run the LLM.
|
||||
else:
|
||||
# If this is the last function call in progress, run the LLM.
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
if run_llm:
|
||||
|
||||
@@ -52,14 +52,12 @@ class FunctionCallItem:
|
||||
function_name (Optional[str]): The name of the function.
|
||||
handler (FunctionCallHandler): The handler for processing function call parameters.
|
||||
cancel_on_interruption (bool): Flag indicating whether to cancel the call on interruption.
|
||||
run_concurrently (bool): Flag to indicate if this function call should run concurrently or not.
|
||||
|
||||
"""
|
||||
|
||||
function_name: Optional[str]
|
||||
handler: FunctionCallHandler
|
||||
cancel_on_interruption: bool
|
||||
run_concurrently: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -76,7 +74,7 @@ class FunctionCallRunnerItem:
|
||||
|
||||
"""
|
||||
|
||||
registry_name: Optional[str]
|
||||
registry_item: FunctionCallItem
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
@@ -118,7 +116,7 @@ class LLMService(AIService):
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallItem] = {}
|
||||
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
|
||||
self._function_call_runner_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
|
||||
@@ -136,18 +134,17 @@ class LLMService(AIService):
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._function_call_runner_queue = asyncio.Queue()
|
||||
self._function_call_runner_task = self.create_task(self._function_call_runner_handler())
|
||||
await self._create_runner_task()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._cancel_function_call(None)
|
||||
await self.cancel_task(self._function_call_runner_task)
|
||||
await self._cancel_function_call()
|
||||
await self._cancel_runner_task()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._cancel_function_call(None)
|
||||
await self.cancel_task(self._function_call_runner_task)
|
||||
await self._cancel_function_call()
|
||||
await self._cancel_runner_task()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -156,9 +153,9 @@ class LLMService(AIService):
|
||||
await self._handle_interruptions(frame)
|
||||
|
||||
async def _handle_interruptions(self, _: StartInterruptionFrame):
|
||||
for function_name, entry in self._functions.items():
|
||||
if entry.cancel_on_interruption:
|
||||
await self._cancel_function_call(function_name)
|
||||
await self._cancel_function_call()
|
||||
await self._cancel_runner_task()
|
||||
await self._create_runner_task()
|
||||
|
||||
def register_function(
|
||||
self,
|
||||
@@ -167,7 +164,6 @@ class LLMService(AIService):
|
||||
start_callback=None,
|
||||
*,
|
||||
cancel_on_interruption: bool = False,
|
||||
run_concurrently: bool = False,
|
||||
):
|
||||
# Registering a function with the function_name set to None will run
|
||||
# that handler for all functions
|
||||
@@ -175,7 +171,6 @@ class LLMService(AIService):
|
||||
function_name=function_name,
|
||||
handler=handler,
|
||||
cancel_on_interruption=cancel_on_interruption,
|
||||
run_concurrently=run_concurrently,
|
||||
)
|
||||
|
||||
# Start callbacks are now deprecated.
|
||||
@@ -212,27 +207,21 @@ class LLMService(AIService):
|
||||
):
|
||||
if function_name in self._functions.keys():
|
||||
item = self._functions[function_name]
|
||||
registry_name = function_name
|
||||
elif None in self._functions.keys():
|
||||
item = self._functions[None]
|
||||
registry_name = None
|
||||
else:
|
||||
return
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_name=registry_name,
|
||||
registry_item=item,
|
||||
function_name=function_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
context=context,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
if item.run_concurrently:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
else:
|
||||
await self._function_call_runner_queue.put(runner_item)
|
||||
|
||||
await self._function_call_runner_queue.put(runner_item)
|
||||
|
||||
async def call_start_function(self, context: OpenAILLMContext, function_name: str):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
@@ -260,11 +249,25 @@ class LLMService(AIService):
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
async def _create_runner_task(self):
|
||||
if not self._function_call_runner_task:
|
||||
self._current_runner: Optional[FunctionCallRunnerItem] = None
|
||||
self._current_task: Optional[asyncio.Task] = None
|
||||
self._function_call_runner_queue = asyncio.Queue()
|
||||
self._function_call_runner_task = self.create_task(self._function_call_runner_handler())
|
||||
|
||||
async def _cancel_runner_task(self):
|
||||
if self._function_call_runner_task:
|
||||
await self.cancel_task(self._function_call_runner_task)
|
||||
self._function_call_runner_task = None
|
||||
|
||||
async def _function_call_runner_handler(self):
|
||||
while True:
|
||||
runner_item = await self._function_call_runner_queue.get()
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
await self.wait_for_task(task)
|
||||
self._current_runner = await self._function_call_runner_queue.get()
|
||||
self._current_task = self.create_task(self._run_function_call(self._current_runner))
|
||||
await self.wait_for_task(self._current_task)
|
||||
self._current_runner = None
|
||||
self._current_task = None
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
@@ -291,14 +294,12 @@ class LLMService(AIService):
|
||||
tool_call_id=runner_item.tool_call_id,
|
||||
arguments=runner_item.arguments,
|
||||
cancel_on_interruption=item.cancel_on_interruption,
|
||||
run_concurrently=item.run_concurrently,
|
||||
)
|
||||
progress_frame_upstream = FunctionCallInProgressFrame(
|
||||
function_name=runner_item.function_name,
|
||||
tool_call_id=runner_item.tool_call_id,
|
||||
arguments=runner_item.arguments,
|
||||
cancel_on_interruption=item.cancel_on_interruption,
|
||||
run_concurrently=item.run_concurrently,
|
||||
)
|
||||
|
||||
# Push frame both downstream and upstream
|
||||
@@ -359,37 +360,20 @@ class LLMService(AIService):
|
||||
)
|
||||
await item.handler(params)
|
||||
|
||||
async def _cancel_function_call(self, function_name: Optional[str]):
|
||||
cancelled_tasks = set()
|
||||
for task, runner_item in self._function_call_tasks.items():
|
||||
if runner_item.registry_name == function_name:
|
||||
name = runner_item.function_name
|
||||
tool_call_id = runner_item.tool_call_id
|
||||
async def _cancel_function_call(self):
|
||||
if (
|
||||
self._current_runner
|
||||
and self._current_task
|
||||
and self._current_runner.registry_item.cancel_on_interruption
|
||||
):
|
||||
name = self._current_runner.function_name
|
||||
tool_call_id = self._current_runner.tool_call_id
|
||||
|
||||
# We remove the callback because we are going to cancel the task
|
||||
# now, otherwise we will be removing it from the set while we
|
||||
# are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
|
||||
|
||||
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
|
||||
await self.cancel_task(self._current_task)
|
||||
|
||||
await self.cancel_task(task)
|
||||
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
|
||||
await self.push_frame(frame)
|
||||
|
||||
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
|
||||
await self.push_frame(frame)
|
||||
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
|
||||
cancelled_tasks.add(task)
|
||||
|
||||
# Remove all cancelled tasks from our set.
|
||||
for task in cancelled_tasks:
|
||||
self._function_call_task_finished(task)
|
||||
|
||||
def _function_call_task_finished(self, task: asyncio.Task):
|
||||
if task in self._function_call_tasks:
|
||||
del self._function_call_tasks[task]
|
||||
# The task is finished so this should exit immediately. We need to
|
||||
# do this because otherwise the task manager would report a dangling
|
||||
# task if we don't remove it.
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
|
||||
Reference in New Issue
Block a user