Add job context and decorators
Adds `JobContext` / `JobGroupContext` async context managers, the `JobGroup` / `JobGroupEvent` / `JobGroupResponse` / `JobGroupError` types, the `@job` decorator (with collector), and the `@task_ready` decorator (with collector). These power the bus-driven job RPC between tasks.
This commit is contained in:
362
src/pipecat/pipeline/job_context.py
Normal file
362
src/pipecat/pipeline/job_context.py
Normal file
@@ -0,0 +1,362 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Task group types for structured concurrent task execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.pipeline.base_task import BaseTask
|
||||
|
||||
|
||||
class JobStatus(StrEnum):
|
||||
"""Status of a completed task.
|
||||
|
||||
Inherits from ``str`` so values compare naturally with plain strings
|
||||
and serialize without extra handling.
|
||||
|
||||
Attributes:
|
||||
COMPLETED: The task finished successfully.
|
||||
CANCELLED: The task was cancelled by the requester.
|
||||
FAILED: The task failed due to a logical or business error.
|
||||
ERROR: The task encountered an unexpected runtime error.
|
||||
"""
|
||||
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
FAILED = "failed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class JobError(Exception):
|
||||
"""Raised when a task is cancelled due to a worker error or timeout."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class JobGroupError(Exception):
|
||||
"""Raised when a task group is cancelled due to a worker error or timeout."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobGroupResponse:
|
||||
"""Collected results from a completed task group.
|
||||
|
||||
Parameters:
|
||||
job_id: The shared task identifier.
|
||||
responses: Collected responses keyed by worker name.
|
||||
"""
|
||||
|
||||
job_id: str
|
||||
responses: dict[str, dict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobEvent:
|
||||
"""An event received from a worker during a single-worker job.
|
||||
|
||||
Parameters:
|
||||
type: The event type.
|
||||
data: Optional event payload.
|
||||
"""
|
||||
|
||||
UPDATE: ClassVar[str] = "update"
|
||||
STREAM_START: ClassVar[str] = "stream_start"
|
||||
STREAM_DATA: ClassVar[str] = "stream_data"
|
||||
STREAM_END: ClassVar[str] = "stream_end"
|
||||
|
||||
type: str
|
||||
data: dict | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobGroupEvent:
|
||||
"""An event received from a worker during task group execution.
|
||||
|
||||
Parameters:
|
||||
type: The event type.
|
||||
task_name: The name of the worker that sent the event.
|
||||
data: Optional event payload.
|
||||
"""
|
||||
|
||||
UPDATE: ClassVar[str] = "update"
|
||||
STREAM_START: ClassVar[str] = "stream_start"
|
||||
STREAM_DATA: ClassVar[str] = "stream_data"
|
||||
STREAM_END: ClassVar[str] = "stream_end"
|
||||
|
||||
type: str
|
||||
task_name: str
|
||||
data: dict | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobGroup:
|
||||
"""Tracks a group of workers launched together.
|
||||
|
||||
Parameters:
|
||||
job_id: Shared identifier for all workers in this group.
|
||||
task_names: Names of the workers in the group.
|
||||
responses: Collected responses keyed by worker name.
|
||||
timeout_task: Optional asyncio task that cancels the group on timeout.
|
||||
cancel_on_error: Whether to cancel the group if a worker errors.
|
||||
event_queue: Optional queue for streaming events to a
|
||||
``JobGroupContext`` async iterator.
|
||||
"""
|
||||
|
||||
job_id: str
|
||||
task_names: set[str]
|
||||
responses: dict[str, dict] = field(default_factory=dict)
|
||||
timeout_task: asyncio.Task | None = None
|
||||
cancel_on_error: bool = True
|
||||
event_queue: asyncio.Queue | None = field(default=None, repr=False)
|
||||
_done: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
|
||||
_error: str | None = field(default=None, repr=False)
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
"""Whether the group has completed or failed."""
|
||||
return self._done.is_set()
|
||||
|
||||
async def wait(self) -> None:
|
||||
"""Wait for all workers in the group to respond.
|
||||
|
||||
Raises:
|
||||
JobGroupError: If the group was cancelled due to error or timeout.
|
||||
"""
|
||||
await self._done.wait()
|
||||
if self._error:
|
||||
raise JobGroupError(self._error)
|
||||
|
||||
def complete(self) -> None:
|
||||
"""Signal that all workers have responded."""
|
||||
self._done.set()
|
||||
if self.event_queue:
|
||||
self.event_queue.put_nowait(None)
|
||||
|
||||
def fail(self, reason: str | None = None) -> None:
|
||||
"""Signal that the group was cancelled.
|
||||
|
||||
Args:
|
||||
reason: Human-readable reason for the failure.
|
||||
"""
|
||||
self._error = reason
|
||||
self._done.set()
|
||||
if self.event_queue:
|
||||
self.event_queue.put_nowait(None)
|
||||
|
||||
|
||||
class JobGroupContext:
|
||||
"""Async context manager and iterator for structured task group execution.
|
||||
|
||||
Sends task requests on enter, waits for all responses on exit.
|
||||
Supports ``async for`` to receive intermediate events (updates
|
||||
and streaming data) from workers while waiting for completion.
|
||||
|
||||
On normal completion, results are available via ``responses``.
|
||||
On worker error (with ``cancel_on_error=True``) or timeout, raises
|
||||
``JobGroupError``. If the ``async with`` block raises, remaining
|
||||
tasks are cancelled.
|
||||
|
||||
Example::
|
||||
|
||||
async with self.job_group("w1", "w2", payload=data) as tg:
|
||||
async for event in tg:
|
||||
print(f"{event.task_name} [{event.type}]: {event.data}")
|
||||
|
||||
for name, result in tg.responses.items():
|
||||
print(name, result)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: BaseTask,
|
||||
task_names: tuple[str, ...],
|
||||
*,
|
||||
name: str | None = None,
|
||||
payload: dict | None = None,
|
||||
timeout: float | None = None,
|
||||
cancel_on_error: bool = True,
|
||||
):
|
||||
"""Initialize the JobGroupContext.
|
||||
|
||||
Args:
|
||||
task: The parent `BaseTask` that owns this job group.
|
||||
task_names: Names of the workers to send the job to.
|
||||
name: Optional task name for routing to named handlers.
|
||||
payload: Optional structured data describing the work.
|
||||
timeout: Optional timeout in seconds covering both the
|
||||
ready-wait and task execution.
|
||||
cancel_on_error: Whether to cancel the group if a worker
|
||||
errors. Defaults to True.
|
||||
"""
|
||||
self._task = task
|
||||
self._task_names = task_names
|
||||
self._name = name
|
||||
self._payload = payload
|
||||
self._timeout = timeout
|
||||
self._cancel_on_error = cancel_on_error
|
||||
self._group: JobGroup | None = None
|
||||
|
||||
@property
|
||||
def job_id(self) -> str:
|
||||
"""The shared task identifier for this group."""
|
||||
if not self._group:
|
||||
raise RuntimeError("Task group has not been started")
|
||||
return self._group.job_id
|
||||
|
||||
@property
|
||||
def responses(self) -> dict[str, dict]:
|
||||
"""Collected responses keyed by worker name."""
|
||||
if not self._group:
|
||||
raise RuntimeError("Task group has not been started")
|
||||
return self._group.responses
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> JobGroupEvent:
|
||||
if not self._group or not self._group.event_queue:
|
||||
raise StopAsyncIteration
|
||||
event = await self._group.event_queue.get()
|
||||
if event is None:
|
||||
raise StopAsyncIteration
|
||||
return event
|
||||
|
||||
async def __aenter__(self) -> JobGroupContext:
|
||||
self._group = await self._task.create_job_group_and_request_job(
|
||||
list(self._task_names),
|
||||
name=self._name,
|
||||
payload=self._payload,
|
||||
timeout=self._timeout,
|
||||
cancel_on_error=self._cancel_on_error,
|
||||
)
|
||||
self._group.event_queue = asyncio.Queue()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
|
||||
if exc_type is not None:
|
||||
if self._group and self._group.job_id in self._task.job_groups:
|
||||
# Shield the cleanup so it completes even if the
|
||||
# surrounding task is being cancelled (e.g. tool
|
||||
# interruption).
|
||||
await asyncio.shield(
|
||||
self._task.cancel_job_group(
|
||||
self._group.job_id, reason="context exited with error"
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
||||
await self._group.wait()
|
||||
return False
|
||||
|
||||
|
||||
class JobContext:
|
||||
"""Async context manager and iterator for a single-worker job.
|
||||
|
||||
Sends a task request on enter, waits for the response on exit.
|
||||
Supports ``async for`` to receive intermediate events (updates
|
||||
and streaming data) from the worker while waiting for completion.
|
||||
|
||||
On normal completion, the result is available via ``response``.
|
||||
On worker error or timeout, raises ``JobError``. If the
|
||||
``async with`` block raises, the task is cancelled.
|
||||
|
||||
Example::
|
||||
|
||||
async with self.job("worker", payload=data) as t:
|
||||
async for event in t:
|
||||
print(f"[{event.type}]: {event.data}")
|
||||
|
||||
print(t.response)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: BaseTask,
|
||||
task_name: str,
|
||||
*,
|
||||
name: str | None = None,
|
||||
payload: dict | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
"""Initialize the JobContext.
|
||||
|
||||
Args:
|
||||
task: The parent `BaseTask` that owns this job.
|
||||
task_name: Name of the worker to send the job to.
|
||||
name: Optional task name for routing to a named handler.
|
||||
payload: Optional structured data describing the work.
|
||||
timeout: Optional timeout in seconds covering both the
|
||||
ready-wait and task execution.
|
||||
"""
|
||||
self._task = task
|
||||
self._task_name = task_name
|
||||
self._name = name
|
||||
self._payload = payload
|
||||
self._timeout = timeout
|
||||
self._group: JobGroup | None = None
|
||||
|
||||
@property
|
||||
def job_id(self) -> str:
|
||||
"""The task identifier."""
|
||||
if not self._group:
|
||||
raise RuntimeError("Task has not been started")
|
||||
return self._group.job_id
|
||||
|
||||
@property
|
||||
def response(self) -> dict:
|
||||
"""The worker's response payload."""
|
||||
if not self._group:
|
||||
raise RuntimeError("Task has not been started")
|
||||
return self._group.responses.get(self._task_name, {})
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> JobEvent:
|
||||
if not self._group or not self._group.event_queue:
|
||||
raise StopAsyncIteration
|
||||
event = await self._group.event_queue.get()
|
||||
if event is None:
|
||||
raise StopAsyncIteration
|
||||
return JobEvent(type=event.type, data=event.data)
|
||||
|
||||
async def __aenter__(self) -> JobContext:
|
||||
self._group = await self._task.create_job_group_and_request_job(
|
||||
[self._task_name],
|
||||
name=self._name,
|
||||
payload=self._payload,
|
||||
timeout=self._timeout,
|
||||
cancel_on_error=True,
|
||||
)
|
||||
self._group.event_queue = asyncio.Queue()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
|
||||
if exc_type is not None:
|
||||
if self._group and self._group.job_id in self._task.job_groups:
|
||||
# Shield the cleanup so it completes even if the
|
||||
# surrounding task is being cancelled (e.g. tool
|
||||
# interruption).
|
||||
await asyncio.shield(
|
||||
self._task.cancel_job_group(
|
||||
self._group.job_id, reason="context exited with error"
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._group.wait()
|
||||
except JobGroupError as e:
|
||||
raise JobError(str(e)) from e
|
||||
return False
|
||||
68
src/pipecat/pipeline/job_decorator.py
Normal file
68
src/pipecat/pipeline/job_decorator.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Decorator for marking task methods as job handlers."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def job(*, name: str, sequential: bool = False):
|
||||
"""Mark a task method as a job handler.
|
||||
|
||||
Decorated methods are automatically collected by ``BaseTask`` at
|
||||
initialization and dispatched when matching job requests arrive.
|
||||
Each request runs in its own asyncio task so the bus message loop
|
||||
is never blocked.
|
||||
|
||||
Example::
|
||||
|
||||
@job(name="research")
|
||||
async def on_research(self, message):
|
||||
...
|
||||
|
||||
@job(name="write", sequential=True)
|
||||
async def on_write(self, message):
|
||||
...
|
||||
|
||||
Args:
|
||||
name: Job name to match. The handler only receives requests
|
||||
with a matching name.
|
||||
sequential: When ``True``, requests with this name run one at
|
||||
a time in FIFO order. Concurrent requests wait for the
|
||||
previous one to finish before running. When ``False`` (the
|
||||
default), multiple requests run concurrently. The wait
|
||||
time counts against the requester's timeout, so a slow
|
||||
predecessor can cause queued requests to time out before
|
||||
they start.
|
||||
"""
|
||||
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
fn.is_job_handler = True # type: ignore[attr-defined]
|
||||
fn.job_name = name # type: ignore[attr-defined]
|
||||
fn.job_sequential = sequential # type: ignore[attr-defined]
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _collect_job_handlers(obj) -> dict[str, Callable]:
|
||||
seen: set[str] = set()
|
||||
handlers: dict[str, Callable] = {}
|
||||
for cls in type(obj).__mro__:
|
||||
for attr_name, val in cls.__dict__.items():
|
||||
if attr_name in seen:
|
||||
continue
|
||||
seen.add(attr_name)
|
||||
if callable(val) and getattr(val, "is_job_handler", False):
|
||||
job_name: str = getattr(val, "job_name")
|
||||
if job_name in handlers:
|
||||
existing = handlers[job_name].__name__
|
||||
raise ValueError(
|
||||
f"Duplicate @job handler for '{job_name}': "
|
||||
f"'{attr_name}' conflicts with '{existing}'"
|
||||
)
|
||||
handlers[job_name] = getattr(obj, attr_name)
|
||||
return handlers
|
||||
59
src/pipecat/pipeline/task_ready_decorator.py
Normal file
59
src/pipecat/pipeline/task_ready_decorator.py
Normal file
@@ -0,0 +1,59 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Decorator for marking agent methods as agent-ready handlers."""
|
||||
|
||||
|
||||
def task_ready(*, name: str):
|
||||
"""Mark a method as a handler for a specific agent becoming ready.
|
||||
|
||||
Decorated methods are automatically collected by ``BaseTask`` at
|
||||
initialization. When ``on_ready`` fires, the agent calls
|
||||
``watch_task`` for each decorated handler. When the watched agent
|
||||
registers, the decorated method is called with the ready data.
|
||||
|
||||
Example::
|
||||
|
||||
@task_ready(name="greeter")
|
||||
async def on_greeter_ready(self, data: TaskReadyData) -> None:
|
||||
await self.activate_task("greeter", args=...)
|
||||
|
||||
Args:
|
||||
name: The name of the agent to watch.
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
fn.agent_ready_name = name
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _collect_task_ready_handlers(obj) -> dict:
|
||||
"""Collect all ``@task_ready`` decorated bound methods from an object.
|
||||
|
||||
Returns a dict mapping agent name to the bound method.
|
||||
|
||||
Raises:
|
||||
ValueError: If two handlers watch the same agent name.
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
handlers: dict[str, object] = {}
|
||||
for cls in type(obj).__mro__:
|
||||
for attr_name, val in cls.__dict__.items():
|
||||
if attr_name in seen:
|
||||
continue
|
||||
seen.add(attr_name)
|
||||
if callable(val) and hasattr(val, "agent_ready_name"):
|
||||
agent_name = val.agent_ready_name
|
||||
if agent_name in handlers:
|
||||
existing = handlers[agent_name].__name__
|
||||
raise ValueError(
|
||||
f"Duplicate @task_ready handler for '{agent_name}': "
|
||||
f"'{attr_name}' conflicts with '{existing}'"
|
||||
)
|
||||
handlers[agent_name] = getattr(obj, attr_name)
|
||||
return handlers
|
||||
Reference in New Issue
Block a user