LLMService: allow executing tasks sequentially and in parallel
This commit is contained in:
@@ -11,10 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Function calls can now be executed sequentially (in the order received in the
|
||||
completion) by passing `run_in_parallel=False` when creating your LLM
|
||||
service. By default, function calls run in parallel, so if the LLM completion
|
||||
returns 2 or more function calls they run concurrently. In both cases,
|
||||
concurrently and sequentailly, a new LLM completion will run when the last
|
||||
function call finishes.
|
||||
service. By default, if the LLM completion returns 2 or more function calls
|
||||
they run concurrently. In both cases, concurrently and sequentially, a new LLM
|
||||
completion will run when the last function call finishes.
|
||||
|
||||
- Added OpenTelemetry tracing for `GeminiMultimodalLiveLLMService` and
|
||||
`OpenAIRealtimeBetaLLMService`.
|
||||
|
||||
@@ -591,8 +591,6 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
)
|
||||
return
|
||||
|
||||
in_progress = self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
properties = frame.properties
|
||||
|
||||
@@ -97,7 +97,7 @@ class FunctionCallRunner:
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
context: OpenAILLMContext
|
||||
run_llm: bool
|
||||
run_llm: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -129,12 +129,14 @@ class LLMService(AIService):
|
||||
# However, subclasses should override this with a more specific adapter when necessary.
|
||||
adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, run_in_parallel: bool = True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._run_in_parallel = run_in_parallel
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
self._function_call_runner_task: Optional[asyncio.Task] = None
|
||||
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunner] = {}
|
||||
self._sequential_runner_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
|
||||
@@ -152,17 +154,18 @@ class LLMService(AIService):
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._create_runner_task()
|
||||
if not self._run_in_parallel:
|
||||
await self._create_sequential_runner_task()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._cancel_function_call()
|
||||
await self._cancel_runner_task()
|
||||
if not self._run_in_parallel:
|
||||
await self._cancel_sequential_runner_task()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._cancel_function_call()
|
||||
await self._cancel_runner_task()
|
||||
if not self._run_in_parallel:
|
||||
await self._cancel_sequential_runner_task()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -171,9 +174,9 @@ class LLMService(AIService):
|
||||
await self._handle_interruptions(frame)
|
||||
|
||||
async def _handle_interruptions(self, _: StartInterruptionFrame):
|
||||
await self._cancel_function_call()
|
||||
await self._cancel_runner_task()
|
||||
await self._create_runner_task()
|
||||
for function_name, entry in self._functions.items():
|
||||
if entry.cancel_on_interruption:
|
||||
await self._cancel_function_call(function_name)
|
||||
|
||||
def register_function(
|
||||
self,
|
||||
@@ -227,8 +230,12 @@ class LLMService(AIService):
|
||||
)
|
||||
continue
|
||||
|
||||
# Run inference on the last function call.
|
||||
run_llm = index == total_function_calls - 1
|
||||
# If we are not running in parallel, run inference on the last
|
||||
# function call. Otherwise, the last function call to finish is the
|
||||
# one that will run the inference.
|
||||
run_llm = None
|
||||
if not self._run_in_parallel:
|
||||
run_llm = index == total_function_calls - 1
|
||||
|
||||
runner_item = FunctionCallRunner(
|
||||
registry_item=item,
|
||||
@@ -239,7 +246,12 @@ class LLMService(AIService):
|
||||
run_llm=run_llm,
|
||||
)
|
||||
|
||||
await self._function_call_runner_queue.put(runner_item)
|
||||
if self._run_in_parallel:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
else:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def call_start_function(self, context: OpenAILLMContext, function_name: str):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
@@ -267,25 +279,25 @@ class LLMService(AIService):
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
async def _create_runner_task(self):
|
||||
if not self._function_call_runner_task:
|
||||
self._current_runner: Optional[FunctionCallRunner] = None
|
||||
self._current_task: Optional[asyncio.Task] = None
|
||||
self._function_call_runner_queue = asyncio.Queue()
|
||||
self._function_call_runner_task = self.create_task(self._function_call_runner_handler())
|
||||
async def _create_sequential_runner_task(self):
|
||||
if not self._sequential_runner_task:
|
||||
self._sequential_runner_queue = asyncio.Queue()
|
||||
self._sequential_runner_task = self.create_task(self._sequential_runner_handler())
|
||||
|
||||
async def _cancel_runner_task(self):
|
||||
if self._function_call_runner_task:
|
||||
await self.cancel_task(self._function_call_runner_task)
|
||||
self._function_call_runner_task = None
|
||||
async def _cancel_sequential_runner_task(self):
|
||||
if self._sequential_runner_task:
|
||||
await self.cancel_task(self._sequential_runner_task)
|
||||
self._sequential_runner_task = None
|
||||
|
||||
async def _function_call_runner_handler(self):
|
||||
async def _sequential_runner_handler(self):
|
||||
while True:
|
||||
self._current_runner = await self._function_call_runner_queue.get()
|
||||
self._current_task = self.create_task(self._run_function_call(self._current_runner))
|
||||
await self.wait_for_task(self._current_task)
|
||||
self._current_runner = None
|
||||
self._current_task = None
|
||||
runner_item = await self._sequential_runner_queue.get()
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
self._function_call_tasks[task] = runner_item
|
||||
# Since we run tasks sequentially we don't need to call
|
||||
# task.add_done_callback(self._function_call_task_finished).
|
||||
await self.wait_for_task(task)
|
||||
del self._function_call_tasks[task]
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunner):
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
@@ -378,20 +390,37 @@ class LLMService(AIService):
|
||||
)
|
||||
await item.handler(params)
|
||||
|
||||
async def _cancel_function_call(self):
|
||||
if (
|
||||
self._current_runner
|
||||
and self._current_task
|
||||
and self._current_runner.registry_item.cancel_on_interruption
|
||||
):
|
||||
name = self._current_runner.function_name
|
||||
tool_call_id = self._current_runner.tool_call_id
|
||||
async def _cancel_function_call(self, function_name: Optional[str]):
|
||||
cancelled_tasks = set()
|
||||
for task, runner_item in self._function_call_tasks.items():
|
||||
if runner_item.registry_item.function_name == function_name:
|
||||
name = runner_item.function_name
|
||||
tool_call_id = runner_item.tool_call_id
|
||||
|
||||
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
|
||||
# We remove the callback because we are going to cancel the task
|
||||
# now, otherwise we will be removing it from the set while we
|
||||
# are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
|
||||
await self.cancel_task(self._current_task)
|
||||
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
|
||||
|
||||
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
|
||||
await self.push_frame(frame)
|
||||
await self.cancel_task(task)
|
||||
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
|
||||
await self.push_frame(frame)
|
||||
|
||||
cancelled_tasks.add(task)
|
||||
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
|
||||
# Remove all cancelled tasks from our set.
|
||||
for task in cancelled_tasks:
|
||||
self._function_call_task_finished(task)
|
||||
|
||||
def _function_call_task_finished(self, task: asyncio.Task):
|
||||
if task in self._function_call_tasks:
|
||||
del self._function_call_tasks[task]
|
||||
# The task is finished so this should exit immediately. We need to
|
||||
# do this because otherwise the task manager would report a dangling
|
||||
# task if we don't remove it.
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
|
||||
|
||||
Reference in New Issue
Block a user