diff --git a/src/pipecat/pipeline/job_context.py b/src/pipecat/pipeline/job_context.py new file mode 100644 index 000000000..0767b127a --- /dev/null +++ b/src/pipecat/pipeline/job_context.py @@ -0,0 +1,362 @@ +# +# Copyright (c) 2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Task group types for structured concurrent task execution.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from enum import StrEnum +from typing import TYPE_CHECKING, ClassVar + +if TYPE_CHECKING: + from pipecat.pipeline.base_task import BaseTask + + +class JobStatus(StrEnum): + """Status of a completed task. + + Inherits from ``str`` so values compare naturally with plain strings + and serialize without extra handling. + + Attributes: + COMPLETED: The task finished successfully. + CANCELLED: The task was cancelled by the requester. + FAILED: The task failed due to a logical or business error. + ERROR: The task encountered an unexpected runtime error. + """ + + COMPLETED = "completed" + CANCELLED = "cancelled" + FAILED = "failed" + ERROR = "error" + + +class JobError(Exception): + """Raised when a task is cancelled due to a worker error or timeout.""" + + pass + + +class JobGroupError(Exception): + """Raised when a task group is cancelled due to a worker error or timeout.""" + + pass + + +@dataclass +class JobGroupResponse: + """Collected results from a completed task group. + + Parameters: + job_id: The shared task identifier. + responses: Collected responses keyed by worker name. + """ + + job_id: str + responses: dict[str, dict] + + +@dataclass +class JobEvent: + """An event received from a worker during a single-worker job. + + Parameters: + type: The event type. + data: Optional event payload. + """ + + UPDATE: ClassVar[str] = "update" + STREAM_START: ClassVar[str] = "stream_start" + STREAM_DATA: ClassVar[str] = "stream_data" + STREAM_END: ClassVar[str] = "stream_end" + + type: str + data: dict | None = None + + +@dataclass +class JobGroupEvent: + """An event received from a worker during task group execution. + + Parameters: + type: The event type. + task_name: The name of the worker that sent the event. + data: Optional event payload. + """ + + UPDATE: ClassVar[str] = "update" + STREAM_START: ClassVar[str] = "stream_start" + STREAM_DATA: ClassVar[str] = "stream_data" + STREAM_END: ClassVar[str] = "stream_end" + + type: str + task_name: str + data: dict | None = None + + +@dataclass +class JobGroup: + """Tracks a group of workers launched together. + + Parameters: + job_id: Shared identifier for all workers in this group. + task_names: Names of the workers in the group. + responses: Collected responses keyed by worker name. + timeout_task: Optional asyncio task that cancels the group on timeout. + cancel_on_error: Whether to cancel the group if a worker errors. + event_queue: Optional queue for streaming events to a + ``JobGroupContext`` async iterator. + """ + + job_id: str + task_names: set[str] + responses: dict[str, dict] = field(default_factory=dict) + timeout_task: asyncio.Task | None = None + cancel_on_error: bool = True + event_queue: asyncio.Queue | None = field(default=None, repr=False) + _done: asyncio.Event = field(default_factory=asyncio.Event, repr=False) + _error: str | None = field(default=None, repr=False) + + @property + def is_done(self) -> bool: + """Whether the group has completed or failed.""" + return self._done.is_set() + + async def wait(self) -> None: + """Wait for all workers in the group to respond. + + Raises: + JobGroupError: If the group was cancelled due to error or timeout. + """ + await self._done.wait() + if self._error: + raise JobGroupError(self._error) + + def complete(self) -> None: + """Signal that all workers have responded.""" + self._done.set() + if self.event_queue: + self.event_queue.put_nowait(None) + + def fail(self, reason: str | None = None) -> None: + """Signal that the group was cancelled. + + Args: + reason: Human-readable reason for the failure. + """ + self._error = reason + self._done.set() + if self.event_queue: + self.event_queue.put_nowait(None) + + +class JobGroupContext: + """Async context manager and iterator for structured task group execution. + + Sends task requests on enter, waits for all responses on exit. + Supports ``async for`` to receive intermediate events (updates + and streaming data) from workers while waiting for completion. + + On normal completion, results are available via ``responses``. + On worker error (with ``cancel_on_error=True``) or timeout, raises + ``JobGroupError``. If the ``async with`` block raises, remaining + tasks are cancelled. + + Example:: + + async with self.job_group("w1", "w2", payload=data) as tg: + async for event in tg: + print(f"{event.task_name} [{event.type}]: {event.data}") + + for name, result in tg.responses.items(): + print(name, result) + """ + + def __init__( + self, + task: BaseTask, + task_names: tuple[str, ...], + *, + name: str | None = None, + payload: dict | None = None, + timeout: float | None = None, + cancel_on_error: bool = True, + ): + """Initialize the JobGroupContext. + + Args: + task: The parent `BaseTask` that owns this job group. + task_names: Names of the workers to send the job to. + name: Optional task name for routing to named handlers. + payload: Optional structured data describing the work. + timeout: Optional timeout in seconds covering both the + ready-wait and task execution. + cancel_on_error: Whether to cancel the group if a worker + errors. Defaults to True. + """ + self._task = task + self._task_names = task_names + self._name = name + self._payload = payload + self._timeout = timeout + self._cancel_on_error = cancel_on_error + self._group: JobGroup | None = None + + @property + def job_id(self) -> str: + """The shared task identifier for this group.""" + if not self._group: + raise RuntimeError("Task group has not been started") + return self._group.job_id + + @property + def responses(self) -> dict[str, dict]: + """Collected responses keyed by worker name.""" + if not self._group: + raise RuntimeError("Task group has not been started") + return self._group.responses + + def __aiter__(self): + return self + + async def __anext__(self) -> JobGroupEvent: + if not self._group or not self._group.event_queue: + raise StopAsyncIteration + event = await self._group.event_queue.get() + if event is None: + raise StopAsyncIteration + return event + + async def __aenter__(self) -> JobGroupContext: + self._group = await self._task.create_job_group_and_request_job( + list(self._task_names), + name=self._name, + payload=self._payload, + timeout=self._timeout, + cancel_on_error=self._cancel_on_error, + ) + self._group.event_queue = asyncio.Queue() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: + if exc_type is not None: + if self._group and self._group.job_id in self._task.job_groups: + # Shield the cleanup so it completes even if the + # surrounding task is being cancelled (e.g. tool + # interruption). + await asyncio.shield( + self._task.cancel_job_group( + self._group.job_id, reason="context exited with error" + ) + ) + return False + + await self._group.wait() + return False + + +class JobContext: + """Async context manager and iterator for a single-worker job. + + Sends a task request on enter, waits for the response on exit. + Supports ``async for`` to receive intermediate events (updates + and streaming data) from the worker while waiting for completion. + + On normal completion, the result is available via ``response``. + On worker error or timeout, raises ``JobError``. If the + ``async with`` block raises, the task is cancelled. + + Example:: + + async with self.job("worker", payload=data) as t: + async for event in t: + print(f"[{event.type}]: {event.data}") + + print(t.response) + """ + + def __init__( + self, + task: BaseTask, + task_name: str, + *, + name: str | None = None, + payload: dict | None = None, + timeout: float | None = None, + ): + """Initialize the JobContext. + + Args: + task: The parent `BaseTask` that owns this job. + task_name: Name of the worker to send the job to. + name: Optional task name for routing to a named handler. + payload: Optional structured data describing the work. + timeout: Optional timeout in seconds covering both the + ready-wait and task execution. + """ + self._task = task + self._task_name = task_name + self._name = name + self._payload = payload + self._timeout = timeout + self._group: JobGroup | None = None + + @property + def job_id(self) -> str: + """The task identifier.""" + if not self._group: + raise RuntimeError("Task has not been started") + return self._group.job_id + + @property + def response(self) -> dict: + """The worker's response payload.""" + if not self._group: + raise RuntimeError("Task has not been started") + return self._group.responses.get(self._task_name, {}) + + def __aiter__(self): + return self + + async def __anext__(self) -> JobEvent: + if not self._group or not self._group.event_queue: + raise StopAsyncIteration + event = await self._group.event_queue.get() + if event is None: + raise StopAsyncIteration + return JobEvent(type=event.type, data=event.data) + + async def __aenter__(self) -> JobContext: + self._group = await self._task.create_job_group_and_request_job( + [self._task_name], + name=self._name, + payload=self._payload, + timeout=self._timeout, + cancel_on_error=True, + ) + self._group.event_queue = asyncio.Queue() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: + if exc_type is not None: + if self._group and self._group.job_id in self._task.job_groups: + # Shield the cleanup so it completes even if the + # surrounding task is being cancelled (e.g. tool + # interruption). + await asyncio.shield( + self._task.cancel_job_group( + self._group.job_id, reason="context exited with error" + ) + ) + return False + + try: + await self._group.wait() + except JobGroupError as e: + raise JobError(str(e)) from e + return False diff --git a/src/pipecat/pipeline/job_decorator.py b/src/pipecat/pipeline/job_decorator.py new file mode 100644 index 000000000..e91e8e7ca --- /dev/null +++ b/src/pipecat/pipeline/job_decorator.py @@ -0,0 +1,68 @@ +# +# Copyright (c) 2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Decorator for marking task methods as job handlers.""" + +from collections.abc import Callable + + +def job(*, name: str, sequential: bool = False): + """Mark a task method as a job handler. + + Decorated methods are automatically collected by ``BaseTask`` at + initialization and dispatched when matching job requests arrive. + Each request runs in its own asyncio task so the bus message loop + is never blocked. + + Example:: + + @job(name="research") + async def on_research(self, message): + ... + + @job(name="write", sequential=True) + async def on_write(self, message): + ... + + Args: + name: Job name to match. The handler only receives requests + with a matching name. + sequential: When ``True``, requests with this name run one at + a time in FIFO order. Concurrent requests wait for the + previous one to finish before running. When ``False`` (the + default), multiple requests run concurrently. The wait + time counts against the requester's timeout, so a slow + predecessor can cause queued requests to time out before + they start. + """ + + def decorator(fn: Callable) -> Callable: + fn.is_job_handler = True # type: ignore[attr-defined] + fn.job_name = name # type: ignore[attr-defined] + fn.job_sequential = sequential # type: ignore[attr-defined] + return fn + + return decorator + + +def _collect_job_handlers(obj) -> dict[str, Callable]: + seen: set[str] = set() + handlers: dict[str, Callable] = {} + for cls in type(obj).__mro__: + for attr_name, val in cls.__dict__.items(): + if attr_name in seen: + continue + seen.add(attr_name) + if callable(val) and getattr(val, "is_job_handler", False): + job_name: str = getattr(val, "job_name") + if job_name in handlers: + existing = handlers[job_name].__name__ + raise ValueError( + f"Duplicate @job handler for '{job_name}': " + f"'{attr_name}' conflicts with '{existing}'" + ) + handlers[job_name] = getattr(obj, attr_name) + return handlers diff --git a/src/pipecat/pipeline/task_ready_decorator.py b/src/pipecat/pipeline/task_ready_decorator.py new file mode 100644 index 000000000..ee773d293 --- /dev/null +++ b/src/pipecat/pipeline/task_ready_decorator.py @@ -0,0 +1,59 @@ +# +# Copyright (c) 2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Decorator for marking agent methods as agent-ready handlers.""" + + +def task_ready(*, name: str): + """Mark a method as a handler for a specific agent becoming ready. + + Decorated methods are automatically collected by ``BaseTask`` at + initialization. When ``on_ready`` fires, the agent calls + ``watch_task`` for each decorated handler. When the watched agent + registers, the decorated method is called with the ready data. + + Example:: + + @task_ready(name="greeter") + async def on_greeter_ready(self, data: TaskReadyData) -> None: + await self.activate_task("greeter", args=...) + + Args: + name: The name of the agent to watch. + """ + + def decorator(fn): + fn.agent_ready_name = name + return fn + + return decorator + + +def _collect_task_ready_handlers(obj) -> dict: + """Collect all ``@task_ready`` decorated bound methods from an object. + + Returns a dict mapping agent name to the bound method. + + Raises: + ValueError: If two handlers watch the same agent name. + """ + seen: set[str] = set() + handlers: dict[str, object] = {} + for cls in type(obj).__mro__: + for attr_name, val in cls.__dict__.items(): + if attr_name in seen: + continue + seen.add(attr_name) + if callable(val) and hasattr(val, "agent_ready_name"): + agent_name = val.agent_ready_name + if agent_name in handlers: + existing = handlers[agent_name].__name__ + raise ValueError( + f"Duplicate @task_ready handler for '{agent_name}': " + f"'{attr_name}' conflicts with '{existing}'" + ) + handlers[agent_name] = getattr(obj, attr_name) + return handlers