Fold BaseTask.handoff_to into activate_task(deactivate_self=...)

BaseTask.handoff_to was just deactivate_self + activate_task. Remove
it and add a deactivate_self flag on activate_task instead, so there's
one entry point for activating another task.

LLMTask now overrides activate_task (mirroring its end() override) to
keep the messages / result_callback hooks that finish an in-progress
tool call before the target is activated. All multi-task examples and
unit tests switch to the new call.
This commit is contained in:
Aleix Conchillo Flaqué
2026-05-20 15:50:03 -07:00
parent e8bbb5ee09
commit 373894fc65
8 changed files with 39 additions and 50 deletions

View File

@@ -68,7 +68,7 @@ uv run local-handoff/local-handoff-two-agents.py --transport daily
### Overview
- **[`local-handoff-two-agents.py`](local-handoff/local-handoff-two-agents.py)** — Two LLM tasks (greeter + support) that hand off via `BaseTask.handoff_to(...)`. The main task owns STT, TTS, transport, and a `BusBridgeProcessor`.
- **[`local-handoff-two-agents.py`](local-handoff/local-handoff-two-agents.py)** — Two LLM tasks (greeter + support) that hand off via `activate_task(..., deactivate_self=True)`. The main task owns STT, TTS, transport, and a `BusBridgeProcessor`.
- **[`local-handoff-two-agents-tts.py`](local-handoff/local-handoff-two-agents-tts.py)** — Same shape, but each child task ships with its own `CartesiaTTSService` in a custom pipeline. The main task has no TTS — audio comes from whichever child is active over the bus.
## Parallel debate

View File

@@ -122,11 +122,10 @@ class AcmeLLMTask(LLMTask):
reason (str): Why the user is being transferred.
"""
logger.info(f"Task '{self.name}': transferring to '{agent}' ({reason})")
await self.handoff_to(
await self.activate_task(
agent,
activation_args=LLMTaskActivationArgs(
messages=[{"role": "developer", "content": reason}]
),
args=LLMTaskActivationArgs(messages=[{"role": "developer", "content": reason}]),
deactivate_self=True,
result_callback=params.result_callback,
)

View File

@@ -104,11 +104,10 @@ class AcmeLLMTask(LLMTask):
reason (str): Why the user is being transferred.
"""
logger.info(f"Task '{self.name}': transferring to '{agent}' ({reason})")
await self.handoff_to(
await self.activate_task(
agent,
activation_args=LLMTaskActivationArgs(
messages=[{"role": "developer", "content": reason}]
),
args=LLMTaskActivationArgs(messages=[{"role": "developer", "content": reason}]),
deactivate_self=True,
result_callback=params.result_callback,
)

View File

@@ -106,7 +106,7 @@ class AcmeTTSTask(LLMTask):
reason (str): Why the user is being transferred.
"""
logger.info(f"Task '{self.name}': transferring to '{agent}' ({reason})")
await self.handoff_to(
await self.activate_task(
agent,
messages=[
{
@@ -114,9 +114,10 @@ class AcmeTTSTask(LLMTask):
"content": f"Tell the user about the transfer ({reason}).",
}
],
activation_args=LLMTaskActivationArgs(
args=LLMTaskActivationArgs(
messages=[{"role": "developer", "content": reason}],
),
deactivate_self=True,
result_callback=params.result_callback,
)

View File

@@ -85,11 +85,12 @@ class AcmeLLMTask(LLMTask):
reason (str): Why the user is being transferred.
"""
logger.info(f"Task '{self.name}': transferring to '{agent}' ({reason})")
await self.handoff_to(
await self.activate_task(
agent,
activation_args=LLMTaskActivationArgs(
args=LLMTaskActivationArgs(
messages=[{"role": "developer", "content": reason}],
),
deactivate_self=True,
result_callback=params.result_callback,
)

View File

@@ -565,6 +565,7 @@ class BaseTask(BaseObject, BusSubscriber):
task_name: str,
*,
args: TaskActivationArgs | None = None,
deactivate_self: bool = False,
) -> None:
"""Activate a task by name.
@@ -575,7 +576,11 @@ class BaseTask(BaseObject, BusSubscriber):
task_name: The name of the task to activate.
args: Optional ``TaskActivationArgs`` forwarded to the
target task's ``on_activated``.
deactivate_self: Whether to deactivate this task before activating
the target.
"""
if self._active and deactivate_self:
await self.deactivate_task(self.name)
await self.send_message(
BusActivateTaskMessage(
source=self.name, target=task_name, args=args.to_dict() if args else None
@@ -592,26 +597,6 @@ class BaseTask(BaseObject, BusSubscriber):
"""
await self.send_message(BusDeactivateTaskMessage(source=self.name, target=task_name))
async def handoff_to(
self,
task_name: str,
*,
activation_args: TaskActivationArgs | None = None,
) -> None:
"""Hand off to another task.
Deactivates this task and activates the target. For independent
control, use ``activate_task()`` and ``deactivate_task()`` directly.
Args:
task_name: The name of the task to hand off to.
activation_args: Optional arguments forwarded to the target
task's ``on_activated`` handler.
"""
if self._active:
await self.deactivate_task(self.name)
await self.activate_task(task_name, args=activation_args)
async def watch_task(self, task_name: str) -> None:
"""Request notification when a task registers.

View File

@@ -231,30 +231,34 @@ class LLMTask(PipelineTask):
await self._finish_function_call(result_callback, messages=messages)
await super().end(reason=reason)
async def handoff_to(
async def activate_task(
self,
task_name: str,
*,
activation_args: TaskActivationArgs | None = None,
args: TaskActivationArgs | None = None,
deactivate_self: bool = False,
messages: list | None = None,
result_callback: FunctionCallResultCallback | None = None,
) -> None:
"""Hand off to another task.
"""Activate another task, optionally finishing an in-progress tool call.
When called from a ``@tool`` handler, pass ``params.result_callback`` to
ensure any pending LLM output is fully delivered before handing off.
ensure any pending LLM output is fully delivered before the target is
activated.
Args:
task_name: The name of the task to hand off to.
activation_args: Optional arguments forwarded to the target
task_name: The name of the task to activate.
args: Optional ``TaskActivationArgs`` forwarded to the target
task's ``on_activated`` handler.
deactivate_self: Whether to deactivate this task before activating
the target.
messages: Optional LLM messages to inject and speak before
handing off. The LLM runs immediately so the output is
delivered before the transfer completes.
activating the target. The LLM runs immediately so the output
is delivered before the transfer completes.
result_callback: The ``result_callback`` from `FunctionCallParams`.
"""
await self._finish_function_call(result_callback, messages=messages)
await super().handoff_to(task_name, activation_args=activation_args)
await super().activate_task(task_name, args=args, deactivate_self=deactivate_self)
async def process_deferred_tool_frames(
self, frames: list[tuple[Frame, FrameDirection]]

View File

@@ -200,14 +200,14 @@ class TestPipelineTaskLifecycle(unittest.IsolatedAsyncioTestCase):
task._pending_activation = False
self.assertIsNone(task.activation_args)
async def test_handoff_to_sends_activate_and_deactivates(self):
"""handoff_to() sends BusDeactivateTaskMessage and BusActivateTaskMessage."""
async def test_activate_task_with_deactivate_self_sends_both_messages(self):
"""activate_task(deactivate_self=True) sends deactivate then activate."""
sent = capture_bus(self.bus)
task = make_stub_pipeline_task("task_a", bridged=())
task.attach(registry=self.registry, bus=self.bus)
await task.handoff_to("task_b")
await task.activate_task("task_b", deactivate_self=True)
deactivate_msgs = [m for m in sent if isinstance(m, BusDeactivateTaskMessage)]
self.assertEqual(len(deactivate_msgs), 1)
@@ -331,14 +331,14 @@ class TestPipelineTaskLifecycle(unittest.IsolatedAsyncioTestCase):
self.assertTrue(finished_fired.is_set())
async def test_handoff_deactivates(self):
"""handoff_to() sends a deactivate message for the calling task."""
async def test_activate_task_with_deactivate_self_deactivates(self):
"""activate_task(deactivate_self=True) sends a deactivate for the calling task."""
sent = capture_bus(self.bus)
task = make_stub_pipeline_task("test", bridged=())
task.attach(registry=self.registry, bus=self.bus)
self.assertTrue(task.active)
await task.handoff_to("other")
await task.activate_task("other", deactivate_self=True)
deactivate_msgs = [m for m in sent if isinstance(m, BusDeactivateTaskMessage)]
self.assertEqual(len(deactivate_msgs), 1)
self.assertEqual(deactivate_msgs[0].target, "test")
@@ -429,7 +429,7 @@ class TestPipelineTaskLifecycle(unittest.IsolatedAsyncioTestCase):
self.assertEqual(received[1].text, "b")
async def test_self_handoff(self):
"""A task can handoff to itself via handoff_to(self.name)."""
"""A task can hand off to itself via activate_task(self.name, deactivate_self=True)."""
task = make_stub_pipeline_task("test", bridged=())
handoff_done = asyncio.Event()
@@ -442,7 +442,7 @@ class TestPipelineTaskLifecycle(unittest.IsolatedAsyncioTestCase):
# Wait for first activation (from active=True)
await asyncio.sleep(0.05)
handoff_done.clear()
await task.handoff_to("test")
await task.activate_task("test", deactivate_self=True)
await asyncio.wait_for(handoff_done.wait(), timeout=2.0)
await task.queue_frame(EndFrame())