LLMService: allow executing tasks sequentially and in parallel

This commit is contained in:
Aleix Conchillo Flaqué
2025-05-20 22:23:26 -07:00
parent 4809684a13
commit 04bf85ddfe
3 changed files with 75 additions and 49 deletions

View File

@@ -11,10 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Function calls can now be executed sequentially (in the order received in the
completion) by passing `run_in_parallel=False` when creating your LLM
service. By default, function calls run in parallel, so if the LLM completion
returns 2 or more function calls they run concurrently. In both cases,
concurrently and sequentially, a new LLM completion will run when the last
function call finishes.
service. By default, if the LLM completion returns 2 or more function calls
they run concurrently. In both cases, concurrently and sequentially, a new LLM
completion will run when the last function call finishes.
- Added OpenTelemetry tracing for `GeminiMultimodalLiveLLMService` and
`OpenAIRealtimeBetaLLMService`.

View File

@@ -591,8 +591,6 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
)
return
in_progress = self._function_calls_in_progress[frame.tool_call_id]
del self._function_calls_in_progress[frame.tool_call_id]
properties = frame.properties

View File

@@ -97,7 +97,7 @@ class FunctionCallRunner:
tool_call_id: str
arguments: Mapping[str, Any]
context: OpenAILLMContext
run_llm: bool
run_llm: Optional[bool] = None
@dataclass
@@ -129,12 +129,14 @@ class LLMService(AIService):
# However, subclasses should override this with a more specific adapter when necessary.
adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter
def __init__(self, **kwargs):
def __init__(self, run_in_parallel: bool = True, **kwargs):
super().__init__(**kwargs)
self._run_in_parallel = run_in_parallel
self._start_callbacks = {}
self._adapter = self.adapter_class()
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
self._function_call_runner_task: Optional[asyncio.Task] = None
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunner] = {}
self._sequential_runner_task: Optional[asyncio.Task] = None
self._register_event_handler("on_completion_timeout")
@@ -152,17 +154,18 @@ class LLMService(AIService):
async def start(self, frame: StartFrame):
await super().start(frame)
await self._create_runner_task()
if not self._run_in_parallel:
await self._create_sequential_runner_task()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._cancel_function_call()
await self._cancel_runner_task()
if not self._run_in_parallel:
await self._cancel_sequential_runner_task()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._cancel_function_call()
await self._cancel_runner_task()
if not self._run_in_parallel:
await self._cancel_sequential_runner_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -171,9 +174,9 @@ class LLMService(AIService):
await self._handle_interruptions(frame)
async def _handle_interruptions(self, _: StartInterruptionFrame):
await self._cancel_function_call()
await self._cancel_runner_task()
await self._create_runner_task()
for function_name, entry in self._functions.items():
if entry.cancel_on_interruption:
await self._cancel_function_call(function_name)
def register_function(
self,
@@ -227,8 +230,12 @@ class LLMService(AIService):
)
continue
# Run inference on the last function call.
run_llm = index == total_function_calls - 1
# If we are not running in parallel, run inference on the last
# function call. Otherwise, the last function call to finish is the
# one that will run the inference.
run_llm = None
if not self._run_in_parallel:
run_llm = index == total_function_calls - 1
runner_item = FunctionCallRunner(
registry_item=item,
@@ -239,7 +246,12 @@ class LLMService(AIService):
run_llm=run_llm,
)
await self._function_call_runner_queue.put(runner_item)
if self._run_in_parallel:
task = self.create_task(self._run_function_call(runner_item))
self._function_call_tasks[task] = runner_item
task.add_done_callback(self._function_call_task_finished)
else:
await self._sequential_runner_queue.put(runner_item)
async def call_start_function(self, context: OpenAILLMContext, function_name: str):
if function_name in self._start_callbacks.keys():
@@ -267,25 +279,25 @@ class LLMService(AIService):
FrameDirection.UPSTREAM,
)
async def _create_runner_task(self):
if not self._function_call_runner_task:
self._current_runner: Optional[FunctionCallRunner] = None
self._current_task: Optional[asyncio.Task] = None
self._function_call_runner_queue = asyncio.Queue()
self._function_call_runner_task = self.create_task(self._function_call_runner_handler())
async def _create_sequential_runner_task(self):
if not self._sequential_runner_task:
self._sequential_runner_queue = asyncio.Queue()
self._sequential_runner_task = self.create_task(self._sequential_runner_handler())
async def _cancel_runner_task(self):
if self._function_call_runner_task:
await self.cancel_task(self._function_call_runner_task)
self._function_call_runner_task = None
async def _cancel_sequential_runner_task(self):
if self._sequential_runner_task:
await self.cancel_task(self._sequential_runner_task)
self._sequential_runner_task = None
async def _function_call_runner_handler(self):
async def _sequential_runner_handler(self):
while True:
self._current_runner = await self._function_call_runner_queue.get()
self._current_task = self.create_task(self._run_function_call(self._current_runner))
await self.wait_for_task(self._current_task)
self._current_runner = None
self._current_task = None
runner_item = await self._sequential_runner_queue.get()
task = self.create_task(self._run_function_call(runner_item))
self._function_call_tasks[task] = runner_item
# Since we run tasks sequentially we don't need to call
# task.add_done_callback(self._function_call_task_finished).
await self.wait_for_task(task)
del self._function_call_tasks[task]
async def _run_function_call(self, runner_item: FunctionCallRunner):
if runner_item.function_name in self._functions.keys():
@@ -378,20 +390,37 @@ class LLMService(AIService):
)
await item.handler(params)
async def _cancel_function_call(self):
if (
self._current_runner
and self._current_task
and self._current_runner.registry_item.cancel_on_interruption
):
name = self._current_runner.function_name
tool_call_id = self._current_runner.tool_call_id
async def _cancel_function_call(self, function_name: Optional[str]):
cancelled_tasks = set()
for task, runner_item in self._function_call_tasks.items():
if runner_item.registry_item.function_name == function_name:
name = runner_item.function_name
tool_call_id = runner_item.tool_call_id
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
# We remove the done callback because we are about to cancel the task
# ourselves; otherwise the callback would delete the task from
# self._function_call_tasks while we are still iterating over it.
task.remove_done_callback(self._function_call_task_finished)
await self.cancel_task(self._current_task)
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
await self.push_frame(frame)
await self.cancel_task(task)
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
await self.push_frame(frame)
cancelled_tasks.add(task)
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
# Remove all cancelled tasks from self._function_call_tasks.
for task in cancelled_tasks:
self._function_call_task_finished(task)
def _function_call_task_finished(self, task: asyncio.Task):
if task in self._function_call_tasks:
del self._function_call_tasks[task]
# The task is already finished, so waiting on it should return
# immediately. We still need to wait on it; otherwise the task manager
# would report it as a dangling task.
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())