Add job context and decorators

Adds `JobContext` / `JobGroupContext` async context managers,
the `JobGroup` / `JobGroupEvent` / `JobGroupResponse` /
`JobGroupError` types, the `@job` decorator (with collector),
and the `@task_ready` decorator (with collector). These power
the bus-driven job RPC between tasks.
This commit is contained in:
Aleix Conchillo Flaqué
2026-05-13 19:14:01 -07:00
parent 7e2055b7d0
commit c0b2a8c572
3 changed files with 489 additions and 0 deletions

View File

@@ -0,0 +1,362 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Task group types for structured concurrent task execution."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from enum import StrEnum
from typing import TYPE_CHECKING, ClassVar
if TYPE_CHECKING:
from pipecat.pipeline.base_task import BaseTask
class JobStatus(StrEnum):
"""Status of a completed task.
Inherits from ``str`` so values compare naturally with plain strings
and serialize without extra handling.
Attributes:
COMPLETED: The task finished successfully.
CANCELLED: The task was cancelled by the requester.
FAILED: The task failed due to a logical or business error.
ERROR: The task encountered an unexpected runtime error.
"""
COMPLETED = "completed"
CANCELLED = "cancelled"
FAILED = "failed"
ERROR = "error"
class JobError(Exception):
    """Error reported to the requester of a single-worker job.

    Raised when the job is cancelled because the worker errored or the
    job timed out.
    """
class JobGroupError(Exception):
    """Error reported to the requester of a job group.

    Raised when the group is cancelled because a worker errored or the
    group timed out.
    """
@dataclass
class JobGroupResponse:
    """Results gathered from a finished job group.

    Parameters:
        job_id: Identifier shared by every worker in the group.
        responses: Response payloads, keyed by the name of the worker
            that produced them.
    """

    job_id: str
    responses: dict[str, dict]
@dataclass
class JobEvent:
    """Intermediate event emitted by the worker of a single-worker job.

    The class-level constants name the known event types; they are
    ``ClassVar`` values and therefore not dataclass fields.

    Parameters:
        type: Kind of event (for example one of the constants below).
        data: Payload attached to the event, if any.
    """

    # Known values for ``type``.
    UPDATE: ClassVar[str] = "update"
    STREAM_START: ClassVar[str] = "stream_start"
    STREAM_DATA: ClassVar[str] = "stream_data"
    STREAM_END: ClassVar[str] = "stream_end"

    type: str
    data: dict | None = None
@dataclass
class JobGroupEvent:
    """Intermediate event emitted by one worker of a job group.

    The class-level constants name the known event types; they are
    ``ClassVar`` values and therefore not dataclass fields.

    Parameters:
        type: Kind of event (for example one of the constants below).
        task_name: Name of the worker the event came from.
        data: Payload attached to the event, if any.
    """

    # Known values for ``type``.
    UPDATE: ClassVar[str] = "update"
    STREAM_START: ClassVar[str] = "stream_start"
    STREAM_DATA: ClassVar[str] = "stream_data"
    STREAM_END: ClassVar[str] = "stream_end"

    type: str
    task_name: str
    data: dict | None = None
@dataclass
class JobGroup:
    """Tracks a group of workers launched together under one job id.

    The group reaches a terminal state through either ``complete()`` or
    ``fail()``. Both wake every ``wait()`` caller and, when an event queue
    is attached, enqueue a ``None`` sentinel to mark the end of the stream.

    Parameters:
        job_id: Shared identifier for all workers in this group.
        task_names: Names of the workers in the group.
        responses: Collected responses keyed by worker name.
        timeout_task: Optional asyncio task that cancels the group on timeout.
        cancel_on_error: Whether to cancel the group if a worker errors.
        event_queue: Optional queue for streaming events to a
            ``JobGroupContext`` async iterator.
    """

    job_id: str
    task_names: set[str]
    responses: dict[str, dict] = field(default_factory=dict)
    timeout_task: asyncio.Task | None = None
    cancel_on_error: bool = True
    event_queue: asyncio.Queue | None = field(default=None, repr=False)
    # Set once the group reaches a terminal state (completed or failed).
    _done: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
    # Failure reason; remains None unless ``fail()`` was called.
    _error: str | None = field(default=None, repr=False)

    @property
    def is_done(self) -> bool:
        """Whether the group has completed or failed."""
        return self._done.is_set()

    async def wait(self) -> None:
        """Wait for the group to reach a terminal state.

        Raises:
            JobGroupError: If the group was failed via ``fail()``.
        """
        await self._done.wait()
        if self._error is not None:
            raise JobGroupError(self._error)

    def complete(self) -> None:
        """Signal that all workers have responded."""
        self._signal_done()

    def fail(self, reason: str | None = None) -> None:
        """Signal that the group was cancelled.

        Args:
            reason: Human-readable reason for the failure. When omitted or
                empty, a generic reason is recorded.
        """
        # Always record a non-empty reason: ``wait()`` decides whether to
        # raise based on ``_error``, so a missing reason must not make a
        # failed group look like a successful one.
        self._error = reason or "job group cancelled"
        self._signal_done()

    def _signal_done(self) -> None:
        """Wake ``wait()`` callers and terminate the event stream, if any."""
        self._done.set()
        if self.event_queue is not None:
            # ``None`` is the end-of-stream sentinel for async iterators.
            self.event_queue.put_nowait(None)
class JobGroupContext:
    """Async context manager and iterator for structured task group execution.

    Sends task requests on enter, waits for all responses on exit.
    Supports ``async for`` to receive intermediate events (updates
    and streaming data) from workers while waiting for completion.
    Iteration ends when the end-of-stream sentinel is received, or
    immediately when the group is already finished and no events remain
    queued, so iterating a finished group never blocks.

    On normal completion, results are available via ``responses``.
    On worker error (with ``cancel_on_error=True``) or timeout, raises
    ``JobGroupError``. If the ``async with`` block raises, remaining
    tasks are cancelled.

    Example::

        async with self.job_group("w1", "w2", payload=data) as tg:
            async for event in tg:
                print(f"{event.task_name} [{event.type}]: {event.data}")

        for name, result in tg.responses.items():
            print(name, result)
    """

    def __init__(
        self,
        task: BaseTask,
        task_names: tuple[str, ...],
        *,
        name: str | None = None,
        payload: dict | None = None,
        timeout: float | None = None,
        cancel_on_error: bool = True,
    ):
        """Initialize the JobGroupContext.

        Args:
            task: The parent `BaseTask` that owns this job group.
            task_names: Names of the workers to send the job to.
            name: Optional task name for routing to named handlers.
            payload: Optional structured data describing the work.
            timeout: Optional timeout in seconds covering both the
                ready-wait and task execution.
            cancel_on_error: Whether to cancel the group if a worker
                errors. Defaults to True.
        """
        self._task = task
        self._task_names = task_names
        self._name = name
        self._payload = payload
        self._timeout = timeout
        self._cancel_on_error = cancel_on_error
        # Populated by ``__aenter__``; None until the context is entered.
        self._group: JobGroup | None = None

    @property
    def job_id(self) -> str:
        """The shared task identifier for this group.

        Raises:
            RuntimeError: If the context has not been entered yet.
        """
        if not self._group:
            raise RuntimeError("Task group has not been started")
        return self._group.job_id

    @property
    def responses(self) -> dict[str, dict]:
        """Collected responses keyed by worker name.

        Raises:
            RuntimeError: If the context has not been entered yet.
        """
        if not self._group:
            raise RuntimeError("Task group has not been started")
        return self._group.responses

    def __aiter__(self):
        """Return this context as its own async iterator."""
        return self

    async def __anext__(self) -> JobGroupEvent:
        """Return the next queued event, or stop when the stream has ended.

        Raises:
            StopAsyncIteration: If the context was never entered, the
                end-of-stream sentinel was received, or the group is
                already finished with nothing left in the queue.
        """
        group = self._group
        if group is None or group.event_queue is None:
            raise StopAsyncIteration
        queue = group.event_queue
        # A finished group will never enqueue anything else. Without this
        # guard, ``get()`` would block forever when the sentinel was already
        # consumed (re-iteration) or was never enqueued because the group
        # finished before the queue was attached in ``__aenter__``.
        if group.is_done and queue.empty():
            raise StopAsyncIteration
        event = await queue.get()
        if event is None:
            raise StopAsyncIteration
        return event

    async def __aenter__(self) -> JobGroupContext:
        """Create the job group, request the job, and attach an event queue."""
        self._group = await self._task.create_job_group_and_request_job(
            list(self._task_names),
            name=self._name,
            payload=self._payload,
            timeout=self._timeout,
            cancel_on_error=self._cancel_on_error,
        )
        self._group.event_queue = asyncio.Queue()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Cancel the group if the body raised; otherwise wait for it.

        Returns:
            Always False, so an exception from the body is never suppressed.
        """
        if exc_type is not None:
            if self._group and self._group.job_id in self._task.job_groups:
                # Shield the cleanup so it completes even if the
                # surrounding task is being cancelled (e.g. tool
                # interruption).
                await asyncio.shield(
                    self._task.cancel_job_group(
                        self._group.job_id, reason="context exited with error"
                    )
                )
            return False
        await self._group.wait()
        return False
class JobContext:
    """Async context manager and iterator for a single-worker job.

    Sends a task request on enter, waits for the response on exit.
    Supports ``async for`` to receive intermediate events (updates
    and streaming data) from the worker while waiting for completion.
    Iteration ends when the end-of-stream sentinel is received, or
    immediately when the job is already finished and no events remain
    queued, so iterating a finished job never blocks.

    On normal completion, the result is available via ``response``.
    On worker error or timeout, raises ``JobError``. If the
    ``async with`` block raises, the task is cancelled.

    Example::

        async with self.job("worker", payload=data) as t:
            async for event in t:
                print(f"[{event.type}]: {event.data}")

        print(t.response)
    """

    def __init__(
        self,
        task: BaseTask,
        task_name: str,
        *,
        name: str | None = None,
        payload: dict | None = None,
        timeout: float | None = None,
    ):
        """Initialize the JobContext.

        Args:
            task: The parent `BaseTask` that owns this job.
            task_name: Name of the worker to send the job to.
            name: Optional task name for routing to a named handler.
            payload: Optional structured data describing the work.
            timeout: Optional timeout in seconds covering both the
                ready-wait and task execution.
        """
        self._task = task
        self._task_name = task_name
        self._name = name
        self._payload = payload
        self._timeout = timeout
        # Populated by ``__aenter__``; None until the context is entered.
        self._group: JobGroup | None = None

    @property
    def job_id(self) -> str:
        """The task identifier.

        Raises:
            RuntimeError: If the context has not been entered yet.
        """
        if not self._group:
            raise RuntimeError("Task has not been started")
        return self._group.job_id

    @property
    def response(self) -> dict:
        """The worker's response payload, or an empty dict if none was recorded.

        Raises:
            RuntimeError: If the context has not been entered yet.
        """
        if not self._group:
            raise RuntimeError("Task has not been started")
        return self._group.responses.get(self._task_name, {})

    def __aiter__(self):
        """Return this context as its own async iterator."""
        return self

    async def __anext__(self) -> JobEvent:
        """Return the next queued event as a ``JobEvent``, or stop at stream end.

        Raises:
            StopAsyncIteration: If the context was never entered, the
                end-of-stream sentinel was received, or the job is
                already finished with nothing left in the queue.
        """
        group = self._group
        if group is None or group.event_queue is None:
            raise StopAsyncIteration
        queue = group.event_queue
        # A finished job will never enqueue anything else. Without this
        # guard, ``get()`` would block forever when the sentinel was already
        # consumed (re-iteration) or was never enqueued because the job
        # finished before the queue was attached in ``__aenter__``.
        if group.is_done and queue.empty():
            raise StopAsyncIteration
        event = await queue.get()
        if event is None:
            raise StopAsyncIteration
        # Re-wrap the queued event, keeping only its type and payload.
        return JobEvent(type=event.type, data=event.data)

    async def __aenter__(self) -> JobContext:
        """Create a one-worker job group, request the job, and attach a queue."""
        self._group = await self._task.create_job_group_and_request_job(
            [self._task_name],
            name=self._name,
            payload=self._payload,
            timeout=self._timeout,
            cancel_on_error=True,
        )
        self._group.event_queue = asyncio.Queue()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Cancel the job if the body raised; otherwise wait for it.

        Returns:
            Always False, so an exception from the body is never suppressed.

        Raises:
            JobError: If waiting for the job raised ``JobGroupError``.
        """
        if exc_type is not None:
            if self._group and self._group.job_id in self._task.job_groups:
                # Shield the cleanup so it completes even if the
                # surrounding task is being cancelled (e.g. tool
                # interruption).
                await asyncio.shield(
                    self._task.cancel_job_group(
                        self._group.job_id, reason="context exited with error"
                    )
                )
            return False
        try:
            await self._group.wait()
        except JobGroupError as e:
            # Surface the group-level failure as the single-job error type.
            raise JobError(str(e)) from e
        return False

View File

@@ -0,0 +1,68 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Decorator for marking task methods as job handlers."""
from collections.abc import Callable
def job(*, name: str, sequential: bool = False):
    """Flag a task method as the handler for a named job.

    ``BaseTask`` collects flagged methods at initialization and dispatches
    to them when a job request with a matching name arrives. Every request
    runs in its own asyncio task, so the bus message loop is never blocked.

    Example::

        @job(name="research")
        async def on_research(self, message):
            ...

        @job(name="write", sequential=True)
        async def on_write(self, message):
            ...

    Args:
        name: Job name to match. The handler only receives requests
            carrying this name.
        sequential: When ``False`` (the default), multiple requests with
            this name run concurrently. When ``True``, they run one at a
            time in FIFO order, each waiting for the previous one to
            finish. That wait counts against the requester's timeout, so
            a slow predecessor can make queued requests time out before
            they start.

    Returns:
        A decorator that tags the function in place and returns it
        unwrapped.
    """

    def _mark(fn: Callable) -> Callable:
        # Attach the markers the handler collector looks for.
        markers = {
            "is_job_handler": True,
            "job_name": name,
            "job_sequential": sequential,
        }
        for attr, value in markers.items():
            setattr(fn, attr, value)
        return fn

    return _mark
def _collect_job_handlers(obj) -> dict[str, Callable]:
seen: set[str] = set()
handlers: dict[str, Callable] = {}
for cls in type(obj).__mro__:
for attr_name, val in cls.__dict__.items():
if attr_name in seen:
continue
seen.add(attr_name)
if callable(val) and getattr(val, "is_job_handler", False):
job_name: str = getattr(val, "job_name")
if job_name in handlers:
existing = handlers[job_name].__name__
raise ValueError(
f"Duplicate @job handler for '{job_name}': "
f"'{attr_name}' conflicts with '{existing}'"
)
handlers[job_name] = getattr(obj, attr_name)
return handlers

View File

@@ -0,0 +1,59 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Decorator for marking agent methods as agent-ready handlers."""
def task_ready(*, name: str):
    """Flag a method as the handler for a specific agent becoming ready.

    ``BaseTask`` collects flagged methods at initialization. When
    ``on_ready`` fires, the agent calls ``watch_task`` for each flagged
    handler; once the watched agent registers, the handler is invoked
    with the ready data.

    Example::

        @task_ready(name="greeter")
        async def on_greeter_ready(self, data: TaskReadyData) -> None:
            await self.activate_task("greeter", args=...)

    Args:
        name: The name of the agent to watch.

    Returns:
        A decorator that tags the function in place and returns it
        unwrapped.
    """

    def _mark(fn):
        # Record the watched name where the handler collector looks for it.
        setattr(fn, "agent_ready_name", name)
        return fn

    return _mark
def _collect_task_ready_handlers(obj) -> dict:
"""Collect all ``@task_ready`` decorated bound methods from an object.
Returns a dict mapping agent name to the bound method.
Raises:
ValueError: If two handlers watch the same agent name.
"""
seen: set[str] = set()
handlers: dict[str, object] = {}
for cls in type(obj).__mro__:
for attr_name, val in cls.__dict__.items():
if attr_name in seen:
continue
seen.add(attr_name)
if callable(val) and hasattr(val, "agent_ready_name"):
agent_name = val.agent_ready_name
if agent_name in handlers:
existing = handlers[agent_name].__name__
raise ValueError(
f"Duplicate @task_ready handler for '{agent_name}': "
f"'{attr_name}' conflicts with '{existing}'"
)
handlers[agent_name] = getattr(obj, attr_name)
return handlers