Add UIWorker

UIWorker is an LLMContextWorker that observes and drives a client GUI over the
RTVI UI channel: it stores accessibility snapshots, auto-injects <ui_state> at
the start of each respond job, dispatches client events to @on_ui_event
handlers, sends UI commands back to the client, and surfaces fan-out work as
cancellable task cards via user_job_group(). The optional ReplyToolMixin exposes
a bundled reply tool.

The prompt_guide parameter auto-appends the UI wire-format guide to the LLM's
system instruction (default UI_STATE_PROMPT_GUIDE; override with a string or
disable with None), so the LLM can parse the injected <ui_state> / <ui_event>
messages without the app concatenating the guide by hand.
This commit is contained in:
Mark Backman
2026-05-21 16:51:30 -04:00
parent 02667a7255
commit f1f5a986e8
11 changed files with 3456 additions and 0 deletions

3
changelog/xxxx.added.md Normal file
View File

@@ -0,0 +1,3 @@
- Added `pipecat.workers.ui.UIWorker`, an `LLMContextWorker` that observes and drives a client GUI over the RTVI UI channel: it stores live accessibility snapshots, auto-injects `<ui_state>` at the start of each `respond` job, dispatches client events to `@on_ui_event` handlers, and sends UI commands (`scroll_to`, `highlight`, `select_text`, `click`, `set_input_value`) back to the client. The optional `ReplyToolMixin` exposes a bundled `reply` tool, and `user_job_group(...)` surfaces fan-out work to the client as cancellable task cards. A native RTVI⇄bus UI bridge is built into `PipelineWorker` (active whenever RTVI is enabled), so no decorator or manual wiring is needed: inbound UI messages are broadcast on the bus as `BusUIEventMessage`, and outbound `BusUICommandMessage` / `BusUITask*` carriers are translated into RTVI frames for the client.
- `UIWorker` auto-injects the UI wire-format guide (`UI_STATE_PROMPT_GUIDE`) into its LLM's system instruction by default, via a `prompt_guide` parameter — pass your own string to override the guide, or `None` to disable. Apps no longer need to concatenate `UI_STATE_PROMPT_GUIDE` into the LLM's `system_instruction` by hand.

View File

@@ -0,0 +1,49 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""UI worker: LLM worker that observes and drives a GUI app.
Composes the RTVI wire protocol for client UI events, accessibility-tree
snapshots, and server-emitted UI commands (in
``pipecat.processors.frameworks.rtvi.models``) with an opt-in
``ReplyToolMixin`` exposing the canonical bundled reply tool
(``answer`` + optional ``scroll_to`` / ``highlight`` / ...).
The RTVI⇄bus bridge that connects a ``UIWorker`` to the client is built
into ``PipelineWorker`` and is active whenever RTVI is enabled — there
is no decorator or bridge to wire up.
"""
from pipecat.bus.ui_messages import (
BusUICommandMessage,
BusUIEventMessage,
BusUITaskCompletedMessage,
BusUITaskGroupCompletedMessage,
BusUITaskGroupStartedMessage,
BusUITaskUpdateMessage,
)
from pipecat.workers.ui.ui_event_decorator import on_ui_event
from pipecat.workers.ui.ui_prompts import UI_STATE_PROMPT_GUIDE
from pipecat.workers.ui.ui_tools import ReplyToolMixin
from pipecat.workers.ui.ui_worker import UIWorker
# Built-in UI command payload models (Toast, Navigate, ScrollTo,
# Highlight, Focus, Click, SetInputValue, SelectText) live in
# ``pipecat.processors.frameworks.rtvi.models``. Import them from there
# directly.
__all__ = [
"BusUICommandMessage",
"BusUIEventMessage",
"BusUITaskCompletedMessage",
"BusUITaskGroupCompletedMessage",
"BusUITaskGroupStartedMessage",
"BusUITaskUpdateMessage",
"ReplyToolMixin",
"UIWorker",
"UI_STATE_PROMPT_GUIDE",
"on_ui_event",
]

View File

@@ -0,0 +1,67 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Decorator for marking worker methods as UI event handlers."""
def on_ui_event(name: str):
"""Mark a worker method as a handler for a named UI event.
On ``UIWorker`` subclasses, decorated methods are automatically
dispatched when a ``BusUIEventMessage`` with a matching ``name``
arrives.
Example::
class MyUIWorker(UIWorker):
@on_ui_event("nav_click")
async def on_nav(self, message):
view = message.payload.get("view")
...
Args:
name: The UI event name to match.
"""
def decorator(fn):
fn.is_ui_event_handler = True
fn.ui_event_name = name
return fn
return decorator
def _collect_ui_event_handlers(obj) -> dict:
"""Collect all ``@on_ui_event`` decorated bound methods from an object.
Walks the MRO so that overridden methods in subclasses take
precedence over base-class definitions.
Returns:
A dict mapping event name to the bound method.
Raises:
ValueError: If two handlers share the same event name on the
same subclass level.
"""
seen: set[str] = set()
handlers: dict[str, object] = {}
source_names: dict[str, str] = {} # event name -> defining method name, for errors
for cls in type(obj).__mro__:
for attr_name, val in cls.__dict__.items():
if attr_name in seen:
continue
seen.add(attr_name)
if callable(val) and getattr(val, "is_ui_event_handler", False):
event_name = val.ui_event_name
if event_name in handlers:
raise ValueError(
f"Duplicate @on_ui_event handler for '{event_name}': "
f"'{attr_name}' conflicts with '{source_names[event_name]}'"
)
handlers[event_name] = getattr(obj, attr_name)
source_names[event_name] = attr_name
return handlers

View File

@@ -0,0 +1,147 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""User-facing job group context.
Wraps ``JobGroupContext`` so the work it dispatches is also surfaced
to the UI client through the UI Worker protocol. Apps reach this via
``UIWorker.user_job_group(...)`` rather than constructing it directly.
"""
from __future__ import annotations
import time
from typing import TYPE_CHECKING
from pipecat.bus.ui_messages import (
BusUITaskGroupCompletedMessage,
BusUITaskGroupStartedMessage,
)
from pipecat.pipeline.job_context import JobGroupContext
if TYPE_CHECKING:
from pipecat.workers.ui.ui_worker import UIWorker
class UserJobGroupContext(JobGroupContext):
"""Job group whose lifecycle is forwarded to the UI client.
Behaves exactly like ``JobGroupContext`` for the dispatching code.
Additionally, on enter the context registers the group with its
parent ``UIWorker`` and publishes a ``BusUITaskGroupStartedMessage``.
The worker forwards any subsequent ``BusJobUpdateMessage`` /
``BusJobResponseMessage`` whose ``job_id`` matches a registered
group as ``BusUITaskUpdateMessage`` / ``BusUITaskCompletedMessage``.
On exit the context publishes ``BusUITaskGroupCompletedMessage`` and
deregisters.
Workers don't need to know about the UI surface: any
``send_job_update`` they emit against the group's ``job_id`` is
forwarded automatically.
Example::
async with self.user_job_group(
"researcher_a", "researcher_b",
payload={"query": query},
label=f"Research: {query}",
cancellable=True,
) as tg:
async for event in tg:
...
results = tg.responses
"""
def __init__(
self,
worker: UIWorker,
worker_names: tuple[str, ...],
*,
name: str | None = None,
payload: dict | None = None,
timeout: float | None = None,
cancel_on_error: bool = True,
label: str | None = None,
cancellable: bool = True,
):
"""Initialize the UserJobGroupContext.
Args:
worker: The parent ``UIWorker`` that owns this job group.
worker_names: Names of the workers to send the job to.
name: Optional job name for routing to named ``@job``
handlers on the workers.
payload: Optional structured data describing the work.
timeout: Optional timeout in seconds covering both the
ready-wait and job execution.
cancel_on_error: Whether to cancel the group if a worker
errors. Defaults to True.
label: Optional human-readable label surfaced to the
client (e.g. ``"Research: Radiohead"``). The client UI
uses it to title the in-flight task card.
cancellable: Whether the client may request cancellation
of this group via the reserved ``__cancel_task`` event.
Defaults to True.
"""
super().__init__(
worker,
worker_names,
name=name,
payload=payload,
timeout=timeout,
cancel_on_error=cancel_on_error,
)
self._ui_worker = worker
self._label = label
self._cancellable = cancellable
@property
def label(self) -> str | None:
"""The group's human-readable label, if any."""
return self._label
@property
def cancellable(self) -> bool:
"""Whether the client may request cancellation."""
return self._cancellable
async def __aenter__(self) -> UserJobGroupContext:
await super().__aenter__()
job_id = self.job_id
self._ui_worker._register_user_job_group(
job_id=job_id,
worker_names=list(self._worker_names),
label=self._label,
cancellable=self._cancellable,
)
await self._ui_worker.send_bus_message(
BusUITaskGroupStartedMessage(
source=self._ui_worker.name,
target=None,
task_id=job_id,
agents=list(self._worker_names),
label=self._label,
cancellable=self._cancellable,
at=int(time.time() * 1000),
)
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
job_id = self._group.job_id if self._group else None
try:
return await super().__aexit__(exc_type, exc_val, exc_tb)
finally:
if job_id:
self._ui_worker._unregister_user_job_group(job_id)
await self._ui_worker.send_bus_message(
BusUITaskGroupCompletedMessage(
source=self._ui_worker.name,
target=None,
task_id=job_id,
at=int(time.time() * 1000),
)
)

View File

@@ -0,0 +1,64 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Canonical prompt fragments describing the UI worker wire format.
Apps concatenate these constants into their system prompt so the
LLM understands the ``<ui_state>`` and ``<ui_event>`` developer
messages the SDK injects on its behalf.
Example::
system_prompt = f'''
You are a voice-driven music player agent.
...app-specific tool and behavior instructions...
{UI_STATE_PROMPT_GUIDE}
'''
The SDK updates the guide alongside the wire format, so apps get
new tags and semantics automatically on the next release.
"""
UI_STATE_PROMPT_GUIDE: str = """\
## UI context
Your developer context includes two kinds of SDK-managed messages:
- ``<ui_event name="..." >payload</ui_event>``: an event the user just \
triggered on the client (click, tab switch, navigation, etc.). The \
payload is JSON for that event.
- ``<ui_state>...</ui_state>``: an accessibility snapshot of the \
current screen, injected at the start of every task request. \
Indented tree in Playwright-MCP style. Each line is \
``- role "name" [state] [ref=eN]`` with children nested one level \
deeper.
State tags include ``[focused]``, ``[selected]``, ``[disabled]``, and \
``[offscreen]``. A node tagged ``[offscreen]`` exists on the page \
but is not currently in the user's viewport; only visible \
(non-offscreen) nodes count for position-based references.
Grids carry a ``[cols=N]`` tag. Their cells are listed in reading \
order (left-to-right, top-to-bottom); with N columns, cell K sits \
at row ``ceil(K/N)``, column ``((K-1) mod N) + 1``. Example with \
``[cols=8]`` and 16 children: "top right" is cell 8, "bottom left" \
is cell 9.
Resolve position references ("top right", "the first one", "the \
third new release") against the most recent ``<ui_state>`` tree. \
Sibling order matches reading order on screen (top-to-bottom, \
left-to-right within each region).
When the user has text selected on the page, the snapshot ends with \
a ``<selection ref="eN">selected text</selection>`` block inside \
``<ui_state>``. Treat the selection as the deictic referent for \
"this", "that", "what I selected", and similar phrases. The ``ref`` \
identifies the closest enclosing element that has a ref in the tree; \
the inner text is the actual selected content (truncated if very \
long). Text inside ``<input>`` or ``<textarea>`` selections is \
faithful to ``selectionStart``/``selectionEnd`` on the element.\
"""

View File

@@ -0,0 +1,155 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Opt-in tool mixin for ``UIWorker``.
Ships ``ReplyToolMixin``: a single ``reply(answer, scroll_to,
highlight, select_text, fills, click)`` LLM tool that bundles a
required spoken answer with the full set of standard UI actions.
One tool call per turn; the model cannot drop the terminator because
``answer`` is a required argument that the API schema enforces.
The bundled mixin covers the canonical app shapes (pointing,
reading, form-fill) and any blend of them. Apps don't need to pick
a mode up front; the LLM uses whichever fields make sense per turn,
leaving the rest as ``null``.
Apps that want a tighter schema (only the fields they use, or
app-specific commands like ``play_song``) hand-roll their own
``@tool reply`` on the ``UIWorker`` subclass directly. The helper
methods on ``UIWorker`` (``scroll_to``, ``highlight``,
``select_text``, ``click``, ``set_input_value``) plus
``send_command`` and the standard payload dataclasses cover the
building blocks for custom replies.
"""
from loguru import logger
from pipecat.services.llm_service import FunctionCallParams
from pipecat.workers.llm.tool_decorator import tool
class ReplyToolMixin:
"""Expose a ``reply`` tool covering the full standard action set.
Single bundled LLM tool with a required spoken ``answer`` plus
optional visual and state-changing actions. One tool call per
turn, no chaining; the required ``answer`` argument is enforced
by the API schema so the model cannot omit the terminator.
Compose alongside ``UIWorker``::
class MyUIWorker(ReplyToolMixin, UIWorker):
...
Covers pointing apps (``scroll_to`` + ``highlight``), reading
apps (``scroll_to`` + ``select_text``), form apps (``fills`` +
``click``), and any blend (e.g. a document review with
selection-based deixis AND voice-driven note-taking). The LLM
uses whichever fields fit the user's request per turn; unused
fields stay ``null`` and don't affect behavior.
Apps that want a minimal schema (only the fields actually used,
or app-specific commands) write their own ``@tool reply`` on the
``UIWorker`` subclass directly. Use the helper methods on
``UIWorker`` plus ``send_command`` to dispatch the underlying UI
commands.
The host class must provide ``scroll_to``, ``highlight``,
``select_text``, ``click``, ``set_input_value``, and
``respond_to_job`` (``UIWorker`` does) and must be the target of
``@tool`` discovery on the LLM pipeline.
"""
@tool
async def reply(
self,
params: FunctionCallParams,
answer: str,
scroll_to: str | None = None,
highlight: list[str] | None = None,
select_text: str | None = None,
fills: list[dict] | None = None,
click: list[str] | None = None,
):
"""Reply to the user. Optionally point at content and act on inputs.
Always called exactly once per turn. ``answer`` is required;
the action fields are optional and may be combined.
Visual / pointing actions (draw the user's attention):
- ``scroll_to`` brings an element into view (single ref).
- ``highlight`` flashes elements briefly (list of refs).
Best for short emphasis like a button or a fact.
- ``select_text`` puts the page's text selection on an
element (single ref). Best for "this paragraph" / "the
section about X" so the user sees exactly what was meant.
Persists until the user clicks elsewhere.
State-changing actions (modify form / app state):
- ``fills`` writes values into inputs (list of
``{"ref", "value"}`` objects, multi-fill in one turn).
- ``click`` clicks elements (list of refs in order). Use for
checkboxes, radios, submit buttons.
Order of dispatch within a turn: ``scroll_to``, then
``highlight``, then ``select_text``, then ``fills``, then
``click``, then speak the answer.
Args:
params: Framework-provided tool invocation context.
answer: The spoken reply in plain language. One short
sentence. No markdown, no symbols.
scroll_to: Optional snapshot ref. Scrolls the element
into view before speaking.
highlight: Optional list of snapshot refs. Visually
pulses each element.
select_text: Optional snapshot ref. Places the page's
text selection on that element.
fills: Optional list of ``{"ref": "eN", "value": "..."}``
objects. Writes each value into the input at ``ref``.
click: Optional list of snapshot refs to click in order.
"""
preview = (answer or "").strip()
if len(preview) > 80:
preview = preview[:80] + ""
logger.info(
f"{self}: reply(answer={preview!r}, scroll_to={scroll_to!r}, "
f"highlight={highlight!r}, select_text={select_text!r}, "
f"fills={fills!r}, click={click!r})"
)
# Defensive guards on the list arguments: an LLM that emits a
# malformed entry (None, a bare string, etc.) would crash the
# tool body before respond_to_job fires, leaving the
# single-flight lock held until the requester's timeout cancels
# us. Skip non-conforming entries instead.
if scroll_to:
await self.scroll_to(scroll_to) # type: ignore[attr-defined]
if highlight:
for ref in highlight:
if not isinstance(ref, str):
continue
await self.highlight(ref) # type: ignore[attr-defined]
if select_text:
await self.select_text(select_text) # type: ignore[attr-defined]
if fills:
for entry in fills:
if not isinstance(entry, dict):
continue
ref = entry.get("ref")
value = entry.get("value")
if not isinstance(ref, str) or value is None:
continue
await self.set_input_value(ref, str(value)) # type: ignore[attr-defined]
if click:
for ref in click:
if not isinstance(ref, str):
continue
await self.click(ref) # type: ignore[attr-defined]
await self.respond_to_job(speak=answer) # type: ignore[attr-defined]
await params.result_callback(None)

File diff suppressed because it is too large Load Diff

231
tests/test_ui_commands.py Normal file
View File

@@ -0,0 +1,231 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for UIWorker.send_command and standard command payload models."""
import unittest
from unittest.mock import MagicMock
from pipecat.bus.ui_messages import BusUICommandMessage
from pipecat.processors.frameworks.rtvi.models import (
Click,
Focus,
Highlight,
Navigate,
ScrollTo,
SelectText,
SetInputValue,
Toast,
)
from pipecat.workers.ui import UIWorker
def _make_worker():
"""A UIWorker whose ``send_bus_message`` captures sent commands.
``send_command`` publishes via ``send_bus_message``; replacing it
avoids needing an attached bus for these unit tests.
"""
worker = UIWorker("ui", llm=MagicMock(), active=False)
sent: list[BusUICommandMessage] = []
async def _record(message):
if isinstance(message, BusUICommandMessage):
sent.append(message)
worker.send_bus_message = _record # type: ignore[method-assign]
return worker, sent
class TestSendCommand(unittest.IsolatedAsyncioTestCase):
async def test_serializes_pydantic_payload_via_model_dump(self):
worker, sent = _make_worker()
await worker.send_command("toast", Toast(title="Saved", subtitle="Favorites"))
self.assertEqual(len(sent), 1)
cmd = sent[0]
self.assertEqual(cmd.source, "ui")
self.assertIsNone(cmd.target)
self.assertEqual(cmd.command_name, "toast")
self.assertEqual(
cmd.payload,
{
"title": "Saved",
"subtitle": "Favorites",
"description": None,
"image_url": None,
"duration_ms": None,
},
)
async def test_forwards_dict_payload_as_is(self):
worker, sent = _make_worker()
await worker.send_command("app_specific", {"foo": 1, "bar": [1, 2, 3]})
self.assertEqual(sent[0].command_name, "app_specific")
self.assertEqual(sent[0].payload, {"foo": 1, "bar": [1, 2, 3]})
async def test_none_payload_becomes_empty_dict(self):
worker, sent = _make_worker()
await worker.send_command("ping")
self.assertEqual(sent[0].payload, {})
async def test_dict_payload_for_apps_with_custom_command_names(self):
worker, sent = _make_worker()
await worker.send_command("navigate", {"view": "home"})
self.assertEqual(sent[0].payload, {"view": "home"})
class TestStandardCommands(unittest.IsolatedAsyncioTestCase):
async def test_toast_payload_shape(self):
worker, sent = _make_worker()
await worker.send_command(
"toast",
Toast(
title="Now playing",
subtitle="Nirvana",
description="Smells Like Teen Spirit",
image_url="https://example.com/cover.jpg",
duration_ms=3000,
),
)
self.assertEqual(
sent[0].payload,
{
"title": "Now playing",
"subtitle": "Nirvana",
"description": "Smells Like Teen Spirit",
"image_url": "https://example.com/cover.jpg",
"duration_ms": 3000,
},
)
async def test_navigate_payload_shape(self):
worker, sent = _make_worker()
await worker.send_command("navigate", Navigate(view="detail", params={"id": "42"}))
self.assertEqual(sent[0].payload, {"view": "detail", "params": {"id": "42"}})
async def test_scroll_to_payload_shape_by_target_id(self):
worker, sent = _make_worker()
await worker.send_command(
"scroll_to", ScrollTo(target_id="new_releases", behavior="smooth")
)
self.assertEqual(
sent[0].payload,
{"ref": None, "target_id": "new_releases", "behavior": "smooth"},
)
async def test_scroll_to_payload_shape_by_ref(self):
worker, sent = _make_worker()
await worker.send_command("scroll_to", ScrollTo(ref="e42", behavior="smooth"))
self.assertEqual(
sent[0].payload,
{"ref": "e42", "target_id": None, "behavior": "smooth"},
)
async def test_highlight_payload_shape(self):
worker, sent = _make_worker()
await worker.send_command("highlight", Highlight(target_id="play_btn", duration_ms=1000))
self.assertEqual(
sent[0].payload,
{"ref": None, "target_id": "play_btn", "duration_ms": 1000},
)
async def test_focus_payload_shape(self):
worker, sent = _make_worker()
await worker.send_command("focus", Focus(target_id="search_input"))
self.assertEqual(
sent[0].payload,
{"ref": None, "target_id": "search_input"},
)
async def test_focus_payload_shape_by_ref(self):
worker, sent = _make_worker()
await worker.send_command("focus", Focus(ref="e7"))
self.assertEqual(
sent[0].payload,
{"ref": "e7", "target_id": None},
)
async def test_select_text_payload_whole_element(self):
worker, sent = _make_worker()
await worker.send_command("select_text", SelectText(ref="e42"))
self.assertEqual(
sent[0].payload,
{
"ref": "e42",
"target_id": None,
"start_offset": None,
"end_offset": None,
},
)
async def test_select_text_payload_with_offsets(self):
worker, sent = _make_worker()
await worker.send_command(
"select_text",
SelectText(ref="e42", start_offset=5, end_offset=12),
)
self.assertEqual(
sent[0].payload,
{
"ref": "e42",
"target_id": None,
"start_offset": 5,
"end_offset": 12,
},
)
async def test_set_input_value_payload_replace_default(self):
worker, sent = _make_worker()
await worker.send_command(
"set_input_value",
SetInputValue(ref="e7", value="hello"),
)
self.assertEqual(
sent[0].payload,
{
"ref": "e7",
"target_id": None,
"value": "hello",
"replace": True,
},
)
async def test_set_input_value_payload_append(self):
worker, sent = _make_worker()
await worker.send_command(
"set_input_value",
SetInputValue(ref="e7", value=" world", replace=False),
)
self.assertEqual(
sent[0].payload,
{
"ref": "e7",
"target_id": None,
"value": " world",
"replace": False,
},
)
async def test_click_payload_by_ref(self):
worker, sent = _make_worker()
await worker.send_command("click", Click(ref="e42"))
self.assertEqual(sent[0].payload, {"ref": "e42", "target_id": None})
async def test_click_payload_by_target_id(self):
worker, sent = _make_worker()
await worker.send_command("click", Click(target_id="submit"))
self.assertEqual(sent[0].payload, {"ref": None, "target_id": "submit"})
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,366 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for the UIWorker user-job-group lifecycle.
Covers:
- ``UIWorker.on_bus_message`` forwarding of worker job updates/responses
for registered user job groups as ``BusUITask*`` carriers.
- The reserved ``__cancel_task`` client event routing to
``cancel_job_group``.
- ``UserJobGroupContext`` publishing ``group_started`` / ``group_completed``
envelopes and (de)registering the group on the worker.
"""
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock
from pipecat.bus.messages import BusJobResponseMessage, BusJobUpdateMessage
from pipecat.bus.ui_messages import (
_UI_CANCEL_TASK_BUS_EVENT_NAME,
BusUIEventMessage,
BusUITaskCompletedMessage,
BusUITaskGroupCompletedMessage,
BusUITaskGroupStartedMessage,
BusUITaskUpdateMessage,
)
from pipecat.frames.frames import LLMMessagesAppendFrame
from pipecat.pipeline.job_context import JobGroup, JobStatus
from pipecat.processors.frame_processor import FrameDirection
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
from pipecat.workers.ui import UIWorker
async def _make_solo_worker(**kwargs) -> UIWorker:
"""A UIWorker with a task manager and a ``queue_frame`` spy.
Suitable for testing forwarding logic by directly invoking
``on_bus_message`` and asserting on captured ``send_bus_message``
calls.
"""
worker = UIWorker("ui", llm=MagicMock(), active=False, **kwargs)
tm = TaskManager()
tm.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
worker._task_manager = tm
recorded: list = []
async def _record(frame, direction=FrameDirection.DOWNSTREAM):
recorded.append(frame)
worker.queue_frame = _record # type: ignore[method-assign]
worker._recorded = recorded # type: ignore[attr-defined]
return worker
class TestUIWorkerForwarding(unittest.IsolatedAsyncioTestCase):
async def test_unregistered_job_update_is_not_forwarded(self):
worker = await _make_solo_worker()
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusJobUpdateMessage(
source="worker", target=worker.name, job_id="t-unknown", update={"x": 1}
)
)
forwarded = [
c.args[0]
for c in worker.send_bus_message.await_args_list
if isinstance(c.args[0], BusUITaskUpdateMessage)
]
self.assertEqual(forwarded, [])
async def test_registered_job_update_is_forwarded(self):
worker = await _make_solo_worker()
worker._register_user_job_group(
job_id="t1", worker_names=["worker"], label="hello", cancellable=True
)
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusJobUpdateMessage(
source="worker",
target=worker.name,
job_id="t1",
update={"kind": "tool_call", "tool": "WebSearch"},
)
)
forwarded = [
c.args[0]
for c in worker.send_bus_message.await_args_list
if isinstance(c.args[0], BusUITaskUpdateMessage)
]
self.assertEqual(len(forwarded), 1)
self.assertEqual(forwarded[0].task_id, "t1")
self.assertEqual(forwarded[0].agent_name, "worker")
self.assertEqual(forwarded[0].data, {"kind": "tool_call", "tool": "WebSearch"})
async def test_registered_job_response_is_forwarded(self):
worker = await _make_solo_worker()
worker._register_user_job_group(
job_id="t1", worker_names=["worker"], label=None, cancellable=True
)
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusJobResponseMessage(
source="worker",
target=worker.name,
job_id="t1",
status=JobStatus.COMPLETED,
response={"answer": 42},
)
)
forwarded = [
c.args[0]
for c in worker.send_bus_message.await_args_list
if isinstance(c.args[0], BusUITaskCompletedMessage)
]
self.assertEqual(len(forwarded), 1)
self.assertEqual(forwarded[0].task_id, "t1")
self.assertEqual(forwarded[0].agent_name, "worker")
self.assertEqual(forwarded[0].status, "completed")
self.assertEqual(forwarded[0].response, {"answer": 42})
async def test_response_status_serializes_for_cancelled_and_error(self):
worker = await _make_solo_worker()
worker._register_user_job_group(
job_id="t1", worker_names=["w"], label=None, cancellable=True
)
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusJobResponseMessage(
source="w", target=worker.name, job_id="t1", status=JobStatus.CANCELLED
)
)
await worker.on_bus_message(
BusJobResponseMessage(
source="w", target=worker.name, job_id="t1", status=JobStatus.ERROR
)
)
statuses = [
c.args[0].status
for c in worker.send_bus_message.await_args_list
if isinstance(c.args[0], BusUITaskCompletedMessage)
]
self.assertEqual(statuses, ["cancelled", "error"])
class TestCancelJobEvent(unittest.IsolatedAsyncioTestCase):
async def test_cancel_event_routes_to_cancel_job_group(self):
worker = await _make_solo_worker()
worker._register_user_job_group(
job_id="t1", worker_names=["w"], label=None, cancellable=True
)
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusUIEventMessage(
source="bridge",
target=worker.name,
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
payload={"task_id": "t1", "reason": "user clicked cancel"},
)
)
worker.cancel_job_group.assert_awaited_once_with("t1", reason="user clicked cancel")
async def test_cancel_event_default_reason_when_omitted(self):
worker = await _make_solo_worker()
worker._register_user_job_group(
job_id="t1", worker_names=["w"], label=None, cancellable=True
)
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusUIEventMessage(
source="bridge",
target=worker.name,
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
payload={"task_id": "t1"},
)
)
worker.cancel_job_group.assert_awaited_once()
self.assertEqual(worker.cancel_job_group.await_args.kwargs["reason"], "cancelled by user")
async def test_non_cancellable_group_is_ignored(self):
worker = await _make_solo_worker()
worker._register_user_job_group(
job_id="t1", worker_names=["w"], label=None, cancellable=False
)
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusUIEventMessage(
source="bridge",
target=worker.name,
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
payload={"task_id": "t1"},
)
)
worker.cancel_job_group.assert_not_awaited()
async def test_unknown_job_id_is_ignored(self):
worker = await _make_solo_worker()
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusUIEventMessage(
source="bridge",
target=worker.name,
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
payload={"task_id": "nope"},
)
)
worker.cancel_job_group.assert_not_awaited()
async def test_missing_or_bad_payload_is_ignored(self):
worker = await _make_solo_worker()
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
await worker.on_bus_message(
BusUIEventMessage(
source="bridge",
target=worker.name,
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
payload=None,
)
)
await worker.on_bus_message(
BusUIEventMessage(
source="bridge",
target=worker.name,
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
payload={"task_id": 42},
)
)
worker.cancel_job_group.assert_not_awaited()
class TestForwardingDoesNotInjectLLMContext(unittest.IsolatedAsyncioTestCase):
async def test_job_update_forwarding_does_not_queue_append_frames(self):
worker = await _make_solo_worker()
worker._register_user_job_group(
job_id="t1", worker_names=["w"], label=None, cancellable=True
)
await worker.on_bus_message(
BusJobUpdateMessage(source="w", target=worker.name, job_id="t1", update={"x": 1})
)
appends = [f for f in worker._recorded if isinstance(f, LLMMessagesAppendFrame)]
self.assertEqual(appends, [])
def _stub_job_group(worker, job_id="t1", worker_names=("w1",)):
"""Make ``create_job_group_and_request_job`` return a self-completing group.
The group completes on the next loop tick so both the context-manager
and fire-and-forget paths terminate without a running bus or workers.
"""
async def _fake_create(names, *, name=None, payload=None, timeout=None, cancel_on_error=True):
group = JobGroup(job_id=job_id, worker_names=set(names))
async def _finish():
# Yield so JobGroupContext.__aenter__ can set event_queue first.
await asyncio.sleep(0)
group.complete()
asyncio.create_task(_finish())
return group
worker.create_job_group_and_request_job = _fake_create # type: ignore[method-assign]
class TestUserJobGroupContext(unittest.IsolatedAsyncioTestCase):
async def test_context_publishes_started_and_completed(self):
worker = await _make_solo_worker()
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
_stub_job_group(worker)
async with worker.user_job_group("w1", label="My research") as tg:
self.assertEqual(tg.job_id, "t1")
self.assertIn("t1", worker._user_job_groups)
self.assertNotIn("t1", worker._user_job_groups)
kinds = [type(c.args[0]).__name__ for c in worker.send_bus_message.await_args_list]
self.assertEqual(
kinds,
["BusUITaskGroupStartedMessage", "BusUITaskGroupCompletedMessage"],
)
started = worker.send_bus_message.await_args_list[0].args[0]
self.assertIsInstance(started, BusUITaskGroupStartedMessage)
self.assertEqual(started.task_id, "t1")
self.assertEqual(started.agents, ["w1"])
self.assertEqual(started.label, "My research")
self.assertTrue(started.cancellable)
completed = worker.send_bus_message.await_args_list[1].args[0]
self.assertIsInstance(completed, BusUITaskGroupCompletedMessage)
self.assertEqual(completed.task_id, "t1")
async def test_non_cancellable_group_sets_flag_in_started_message(self):
worker = await _make_solo_worker()
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
_stub_job_group(worker)
async with worker.user_job_group("w1", cancellable=False):
pass
started = worker.send_bus_message.await_args_list[0].args[0]
self.assertFalse(started.cancellable)
async def test_unregisters_on_exit(self):
worker = await _make_solo_worker()
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
_stub_job_group(worker)
async with worker.user_job_group("w1") as tg:
pass
self.assertNotIn(tg.job_id, worker._user_job_groups)
async def test_start_user_job_group_returns_id_and_publishes(self):
worker = await _make_solo_worker()
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
_stub_job_group(worker)
job_id = await worker.start_user_job_group("w1", label="Background work")
self.assertEqual(job_id, "t1")
started = worker.send_bus_message.await_args_list[0].args[0]
self.assertIsInstance(started, BusUITaskGroupStartedMessage)
self.assertEqual(started.label, "Background work")
# The background runner drains the group and publishes completion.
for _ in range(50):
await asyncio.sleep(0)
if any(
isinstance(c.args[0], BusUITaskGroupCompletedMessage)
for c in worker.send_bus_message.await_args_list
):
break
else:
self.fail("group_completed envelope was not published")
self.assertNotIn("t1", worker._user_job_groups)
if __name__ == "__main__":
unittest.main()

418
tests/test_ui_tools.py Normal file
View File

@@ -0,0 +1,418 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for ``ReplyToolMixin`` and the action helper methods on ``UIWorker``.
The mixin exposes a single bundled ``reply(answer, scroll_to,
highlight, ...)`` LLM tool whose ``answer`` argument is required. The
helper methods (``scroll_to``, ``highlight``, ...) are plain instance
methods on ``UIWorker`` that wrap ``send_command`` with the standard
payload models; apps call them inside custom ``@tool`` bodies when the
canonical ``reply`` shape doesn't fit.
"""
import unittest
from unittest.mock import AsyncMock, MagicMock
from pipecat.bus.ui_messages import BusUICommandMessage
from pipecat.workers.llm.tool_decorator import _collect_tools
from pipecat.workers.ui import ReplyToolMixin, UIWorker
class _WorkerWithReply(ReplyToolMixin, UIWorker):
pass
class _PlainWorker(UIWorker):
pass
def _new(cls: type) -> UIWorker:
return cls("ui", llm=MagicMock(), active=False)
def _capture(worker: UIWorker) -> list[BusUICommandMessage]:
sent: list[BusUICommandMessage] = []
async def _record(message):
sent.append(message)
worker.send_bus_message = _record # type: ignore[method-assign]
return sent
class TestUIWorkerActionHelpers(unittest.IsolatedAsyncioTestCase):
"""The helper methods are plain methods, not LLM tools."""
async def test_scroll_to_helper_dispatches_command(self):
worker = _new(_PlainWorker)
sent = _capture(worker)
await worker.scroll_to("e42")
self.assertEqual(len(sent), 1)
self.assertEqual(sent[0].command_name, "scroll_to")
self.assertEqual(
sent[0].payload,
{"ref": "e42", "target_id": None, "behavior": None},
)
async def test_highlight_helper_dispatches_command(self):
worker = _new(_PlainWorker)
sent = _capture(worker)
await worker.highlight("e7")
self.assertEqual(len(sent), 1)
self.assertEqual(sent[0].command_name, "highlight")
self.assertEqual(
sent[0].payload,
{"ref": "e7", "target_id": None, "duration_ms": None},
)
async def test_select_text_helper_whole_element(self):
worker = _new(_PlainWorker)
sent = _capture(worker)
await worker.select_text("e42")
self.assertEqual(len(sent), 1)
self.assertEqual(sent[0].command_name, "select_text")
self.assertEqual(
sent[0].payload,
{
"ref": "e42",
"target_id": None,
"start_offset": None,
"end_offset": None,
},
)
async def test_select_text_helper_with_offsets(self):
worker = _new(_PlainWorker)
sent = _capture(worker)
await worker.select_text("e42", start_offset=10, end_offset=25)
self.assertEqual(len(sent), 1)
self.assertEqual(
sent[0].payload,
{
"ref": "e42",
"target_id": None,
"start_offset": 10,
"end_offset": 25,
},
)
async def test_click_helper_dispatches_command(self):
worker = _new(_PlainWorker)
sent = _capture(worker)
await worker.click("e42")
self.assertEqual(len(sent), 1)
self.assertEqual(sent[0].command_name, "click")
self.assertEqual(sent[0].payload, {"ref": "e42", "target_id": None})
async def test_set_input_value_helper_default_replace(self):
worker = _new(_PlainWorker)
sent = _capture(worker)
await worker.set_input_value("e42", "hello world")
self.assertEqual(len(sent), 1)
self.assertEqual(sent[0].command_name, "set_input_value")
self.assertEqual(
sent[0].payload,
{
"ref": "e42",
"target_id": None,
"value": "hello world",
"replace": True,
},
)
async def test_set_input_value_helper_append_mode(self):
worker = _new(_PlainWorker)
sent = _capture(worker)
await worker.set_input_value("e42", "more text", replace=False)
self.assertEqual(sent[0].payload["replace"], False)
async def test_helpers_are_not_llm_tools(self):
worker = _new(_PlainWorker)
tool_names = [t.__name__ for t in _collect_tools(worker)]
for name in ("scroll_to", "highlight", "select_text", "click", "set_input_value"):
self.assertNotIn(name, tool_names)
class TestReplyToolMixin(unittest.IsolatedAsyncioTestCase):
async def test_mixin_exposes_reply_tool(self):
worker = _new(_WorkerWithReply)
tool_names = [t.__name__ for t in _collect_tools(worker)]
self.assertEqual(tool_names, ["reply"])
async def test_plain_uiworker_has_no_reply_tool(self):
worker = _new(_PlainWorker)
tool_names = [t.__name__ for t in _collect_tools(worker)]
self.assertNotIn("reply", tool_names)
async def test_reply_with_answer_only_terminates(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(params, answer="The Pixel 9 is from Google.")
self.assertEqual(sent, [])
worker.respond_to_job.assert_awaited_once_with(speak="The Pixel 9 is from Google.")
params.result_callback.assert_awaited_once_with(None)
async def test_reply_with_highlight_only(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="This one, the Nothing Phone 3.",
highlight=["e29"],
)
self.assertEqual([m.command_name for m in sent], ["highlight"])
self.assertEqual(sent[0].payload["ref"], "e29")
worker.respond_to_job.assert_awaited_once_with(speak="This one, the Nothing Phone 3.")
async def test_reply_with_multiple_highlights(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="Here are the Apple phones.",
highlight=["e5", "e8", "e47"],
)
self.assertEqual([m.command_name for m in sent], ["highlight"] * 3)
self.assertEqual([m.payload["ref"] for m in sent], ["e5", "e8", "e47"])
worker.respond_to_job.assert_awaited_once_with(speak="Here are the Apple phones.")
async def test_reply_with_scroll_and_highlight_runs_in_order(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="Here's the iPhone 17.",
scroll_to="e5",
highlight=["e5"],
)
self.assertEqual([m.command_name for m in sent], ["scroll_to", "highlight"])
worker.respond_to_job.assert_awaited_once_with(speak="Here's the iPhone 17.")
async def test_reply_with_select_text_only(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="Here, in this paragraph.",
select_text="e11",
)
self.assertEqual([m.command_name for m in sent], ["select_text"])
self.assertEqual(sent[0].payload["ref"], "e11")
worker.respond_to_job.assert_awaited_once_with(speak="Here, in this paragraph.")
async def test_reply_with_scroll_and_select_text(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="Here, in this paragraph.",
scroll_to="e11",
select_text="e11",
)
self.assertEqual([m.command_name for m in sent], ["scroll_to", "select_text"])
async def test_reply_with_fills_writes_each_input(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="Got it.",
fills=[
{"ref": "e5", "value": "Mark"},
{"ref": "e7", "value": "Backman"},
],
)
self.assertEqual(
[m.command_name for m in sent],
["set_input_value", "set_input_value"],
)
self.assertEqual(sent[0].payload["ref"], "e5")
self.assertEqual(sent[0].payload["value"], "Mark")
self.assertEqual(sent[1].payload["ref"], "e7")
self.assertEqual(sent[1].payload["value"], "Backman")
async def test_reply_with_click_clicks_each_in_order(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(params, answer="Submitted.", click=["e22", "e26"])
self.assertEqual([m.command_name for m in sent], ["click", "click"])
self.assertEqual([m.payload["ref"] for m in sent], ["e22", "e26"])
async def test_reply_with_fills_skips_invalid_entries(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="x",
fills=[
{"ref": "e5", "value": "Mark"},
{"ref": None, "value": "missing ref"},
{"value": "no ref"},
{"ref": "e7"},
],
)
self.assertEqual(len(sent), 1)
self.assertEqual(sent[0].payload["ref"], "e5")
async def test_reply_with_non_dict_fill_entries_does_not_crash(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="x",
fills=[None, "e5", 42, {"ref": "e9", "value": "ok"}], # type: ignore[list-item]
)
self.assertEqual(len(sent), 1)
self.assertEqual(sent[0].payload["ref"], "e9")
worker.respond_to_job.assert_awaited_once_with(speak="x")
params.result_callback.assert_awaited_once_with(None)
async def test_reply_with_non_string_highlight_refs_skipped(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="x",
highlight=[None, "e1", 42, "e2"], # type: ignore[list-item]
)
self.assertEqual([m.payload["ref"] for m in sent], ["e1", "e2"])
worker.respond_to_job.assert_awaited_once_with(speak="x")
async def test_reply_with_non_string_click_refs_skipped(self):
worker = _new(_WorkerWithReply)
sent = _capture(worker)
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="x",
click=[None, "e1", {"ref": "e2"}, "e3"], # type: ignore[list-item]
)
self.assertEqual([m.payload["ref"] for m in sent], ["e1", "e3"])
worker.respond_to_job.assert_awaited_once_with(speak="x")
async def test_reply_dispatches_via_helper_methods(self):
worker = _new(_WorkerWithReply)
worker.scroll_to = AsyncMock() # type: ignore[method-assign]
worker.highlight = AsyncMock() # type: ignore[method-assign]
worker.select_text = AsyncMock() # type: ignore[method-assign]
worker.set_input_value = AsyncMock() # type: ignore[method-assign]
worker.click = AsyncMock() # type: ignore[method-assign]
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
params = MagicMock()
params.result_callback = AsyncMock()
await worker.reply(
params,
answer="x",
scroll_to="e1",
highlight=["e2", "e3"],
select_text="e4",
fills=[{"ref": "e5", "value": "v"}],
click=["e6", "e7"],
)
worker.scroll_to.assert_awaited_once_with("e1")
self.assertEqual(
worker.highlight.await_args_list,
[unittest.mock.call("e2"), unittest.mock.call("e3")],
)
worker.select_text.assert_awaited_once_with("e4")
worker.set_input_value.assert_awaited_once_with("e5", "v")
self.assertEqual(
worker.click.await_args_list,
[unittest.mock.call("e6"), unittest.mock.call("e7")],
)
if __name__ == "__main__":
unittest.main()

805
tests/test_ui_worker.py Normal file
View File

@@ -0,0 +1,805 @@
#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for UIWorker dispatch, LLM-context injection, and single-flight respond."""
import asyncio
import json
import unittest
from unittest.mock import AsyncMock, MagicMock
from pipecat.bus.messages import BusJobCancelMessage, BusJobRequestMessage
from pipecat.bus.ui_messages import _UI_SNAPSHOT_BUS_EVENT_NAME, BusUIEventMessage
from pipecat.frames.frames import LLMMessagesAppendFrame, LLMMessagesUpdateFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
from pipecat.workers.ui import UI_STATE_PROMPT_GUIDE, UIWorker, on_ui_event
class _StubUIWorker(UIWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.captured: list[BusUIEventMessage] = []
@on_ui_event("nav_click")
async def _on_nav(self, message: BusUIEventMessage) -> None:
self.captured.append(message)
class _PlainWorker(UIWorker):
pass
async def _make_worker(cls=_StubUIWorker, **kwargs) -> UIWorker:
"""A UIWorker wired with a task manager and a ``queue_frame`` spy.
``queue_frame`` is replaced with a recorder so tests can assert the
frames the UI logic produces without running the pipeline. The
respond handler sends its own response, so ``send_job_response`` is
mocked to let ``respond_with_llm`` run without an active job entry.
"""
worker = cls("ui", llm=MagicMock(), active=False, **kwargs)
tm = TaskManager()
tm.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
worker._task_manager = tm
recorded: list = []
async def _record(frame, direction=FrameDirection.DOWNSTREAM):
recorded.append(frame)
worker.queue_frame = _record # type: ignore[method-assign]
worker._recorded = recorded # type: ignore[attr-defined]
worker.send_job_response = AsyncMock() # type: ignore[method-assign]
return worker
async def _settle() -> None:
"""Yield enough times for spawned task handlers to run/park."""
for _ in range(10):
await asyncio.sleep(0)
async def _start(worker: UIWorker, message: BusJobRequestMessage) -> asyncio.Task:
"""Start a respond turn; it parks at ``await self._pending`` until resolved."""
t = asyncio.create_task(worker.respond_with_llm(message))
await _settle()
return t
def _respond_msg(job_id: str, query: str = "hi") -> BusJobRequestMessage:
return BusJobRequestMessage(
source="voice", target="ui", job_name="respond", job_id=job_id, payload={"query": query}
)
async def _dispatch(worker: UIWorker, message: BusUIEventMessage) -> None:
await worker.on_bus_message(message)
for _ in range(5):
await asyncio.sleep(0)
def _append_frames(worker) -> list[LLMMessagesAppendFrame]:
return [f for f in worker._recorded if isinstance(f, LLMMessagesAppendFrame)]
def _update_frames(worker) -> list[LLMMessagesUpdateFrame]:
return [f for f in worker._recorded if isinstance(f, LLMMessagesUpdateFrame)]
class TestUIWorkerDispatch(unittest.IsolatedAsyncioTestCase):
async def test_dispatches_to_matching_on_ui_event_handler(self):
worker = await _make_worker()
await _dispatch(
worker,
BusUIEventMessage(
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
),
)
self.assertEqual(len(worker.captured), 1)
self.assertEqual(worker.captured[0].event_name, "nav_click")
self.assertEqual(worker.captured[0].payload, {"view": "home"})
async def test_unknown_event_name_does_not_raise(self):
worker = await _make_worker()
await _dispatch(
worker,
BusUIEventMessage(
source="music", target="ui", event_name="never_registered", payload={"x": 1}
),
)
self.assertEqual(worker.captured, [])
async def test_ignores_events_targeted_at_other_workers(self):
worker = await _make_worker()
await _dispatch(
worker,
BusUIEventMessage(
source="music",
target="someone_else",
event_name="nav_click",
payload={"view": "home"},
),
)
self.assertEqual(worker.captured, [])
self.assertEqual(_append_frames(worker), [])
async def test_broadcast_event_with_no_target_is_handled(self):
worker = await _make_worker()
await _dispatch(
worker,
BusUIEventMessage(
source="music", target=None, event_name="nav_click", payload={"view": "home"}
),
)
self.assertEqual(len(worker.captured), 1)
async def test_handler_runs_in_separate_task_so_bus_is_not_blocked(self):
gate = asyncio.Event()
observed: list[bool] = []
class _BlockingWorker(_StubUIWorker):
@on_ui_event("slow")
async def _slow(self, message):
await gate.wait()
observed.append(True)
worker = await _make_worker(cls=_BlockingWorker)
await worker.on_bus_message(
BusUIEventMessage(source="music", target="ui", event_name="slow", payload={})
)
self.assertEqual(observed, [])
gate.set()
await _settle()
self.assertEqual(observed, [True])
async def test_duplicate_handler_names_raise_at_init(self):
with self.assertRaises(ValueError):
class _Bad(UIWorker):
@on_ui_event("nav")
async def a(self, message):
pass
@on_ui_event("nav")
async def b(self, message):
pass
_Bad("ui", llm=MagicMock())
async def test_bridged_with_default_auto_inject_raises(self):
with self.assertRaises(ValueError) as ctx:
_PlainWorker("ui", llm=MagicMock(), bridged=())
self.assertIn("bridged", str(ctx.exception))
self.assertIn("auto_inject_ui_state", str(ctx.exception))
async def test_bridged_with_explicit_auto_inject_disabled_is_allowed(self):
worker = _PlainWorker("ui", llm=MagicMock(), bridged=(), auto_inject_ui_state=False)
self.assertFalse(worker._auto_inject_ui_state)
async def test_default_construction_unaffected(self):
worker = _PlainWorker("ui", llm=MagicMock())
self.assertTrue(worker._auto_inject_ui_state)
self.assertTrue(worker.active)
class TestUIWorkerInjection(unittest.IsolatedAsyncioTestCase):
async def test_injects_xml_developer_message_by_default(self):
worker = await _make_worker()
await _dispatch(
worker,
BusUIEventMessage(
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
),
)
frames = _append_frames(worker)
self.assertEqual(len(frames), 1)
frame = frames[0]
self.assertFalse(frame.run_llm)
self.assertEqual(len(frame.messages), 1)
msg = frame.messages[0]
self.assertEqual(msg["role"], "developer")
content = msg["content"]
self.assertIn('<ui_event name="nav_click">', content)
self.assertIn("</ui_event>", content)
inner = content[len('<ui_event name="nav_click">') : -len("</ui_event>")]
self.assertEqual(json.loads(inner), {"view": "home"})
async def test_inject_events_false_disables_injection(self):
worker = await _make_worker(inject_events=False)
await _dispatch(
worker,
BusUIEventMessage(
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
),
)
self.assertEqual(_append_frames(worker), [])
self.assertEqual(len(worker.captured), 1)
async def test_render_override_replaces_default_xml(self):
class _CustomRender(_StubUIWorker):
def render_ui_event(self, message):
return f"[UI] {message.event_name}"
worker = await _make_worker(cls=_CustomRender)
await _dispatch(
worker,
BusUIEventMessage(
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
),
)
frames = _append_frames(worker)
self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].messages[0]["content"], "[UI] nav_click")
async def test_empty_render_skips_injection(self):
class _NoRender(_StubUIWorker):
def render_ui_event(self, message):
return ""
worker = await _make_worker(cls=_NoRender)
await _dispatch(
worker,
BusUIEventMessage(
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
),
)
self.assertEqual(_append_frames(worker), [])
_SAMPLE_SNAPSHOT = {
"root": {
"ref": "e1",
"role": "generic",
"children": [
{
"ref": "e2",
"role": "main",
"children": [
{"ref": "e3", "role": "heading", "name": "Home", "level": 1},
{
"ref": "e4",
"role": "region",
"name": "Trending artists",
"children": [
{"ref": "e5", "role": "button", "name": "Bad Bunny"},
{
"ref": "e6",
"role": "button",
"name": "Taylor Swift",
"state": ["focused"],
},
],
},
],
},
],
},
"captured_at": 1700000000000,
}
class TestUIWorkerSnapshot(unittest.IsolatedAsyncioTestCase):
async def test_reserved_snapshot_event_stored_without_dispatch(self):
worker = await _make_worker()
await _dispatch(
worker,
BusUIEventMessage(
source="music",
target="ui",
event_name=_UI_SNAPSHOT_BUS_EVENT_NAME,
payload=_SAMPLE_SNAPSHOT,
),
)
self.assertEqual(worker._latest_snapshot, _SAMPLE_SNAPSHOT)
self.assertEqual(worker.captured, [])
self.assertEqual(_append_frames(worker), [])
async def test_non_dict_snapshot_payload_is_ignored(self):
worker = await _make_worker()
await _dispatch(
worker,
BusUIEventMessage(
source="music",
target="ui",
event_name=_UI_SNAPSHOT_BUS_EVENT_NAME,
payload="not a snapshot",
),
)
self.assertIsNone(worker._latest_snapshot)
async def test_render_ui_state_empty_without_snapshot(self):
worker = await _make_worker()
self.assertEqual(worker.render_ui_state(), "")
async def test_render_ui_state_produces_indented_block(self):
worker = await _make_worker()
worker._latest_snapshot = _SAMPLE_SNAPSHOT
rendered = worker.render_ui_state()
self.assertTrue(rendered.startswith("<ui_state>\n"))
self.assertTrue(rendered.endswith("\n</ui_state>"))
self.assertIn("- generic [ref=e1]:", rendered)
self.assertIn("- main [ref=e2]:", rendered)
self.assertIn('- heading "Home" [level=1] [ref=e3]', rendered)
self.assertIn('- region "Trending artists" [ref=e4]:', rendered)
self.assertIn('- button "Bad Bunny" [ref=e5]', rendered)
self.assertIn('- button "Taylor Swift" [focused] [ref=e6]', rendered)
self.assertIn(" - main", rendered)
self.assertIn(" - heading", rendered)
self.assertIn(" - button", rendered)
async def test_inject_ui_state_queues_expected_frame(self):
worker = await _make_worker()
worker._latest_snapshot = _SAMPLE_SNAPSHOT
await worker.inject_ui_state()
frames = _append_frames(worker)
self.assertEqual(len(frames), 1)
frame = frames[0]
self.assertFalse(frame.run_llm)
self.assertEqual(len(frame.messages), 1)
msg = frame.messages[0]
self.assertEqual(msg["role"], "developer")
self.assertTrue(msg["content"].startswith("<ui_state>"))
self.assertTrue(msg["content"].endswith("</ui_state>"))
async def test_inject_ui_state_no_op_without_snapshot(self):
worker = await _make_worker()
await worker.inject_ui_state()
self.assertEqual(_append_frames(worker), [])
async def test_render_emits_grid_dims(self):
worker = await _make_worker()
worker._latest_snapshot = {
"root": {
"ref": "e1",
"role": "generic",
"children": [
{
"ref": "e2",
"role": "grid",
"name": "Trending artists",
"colcount": 8,
"rowcount": 2,
"children": [{"ref": "e3", "role": "button", "name": "Bad Bunny"}],
},
],
},
"captured_at": 1700000000000,
}
rendered = worker.render_ui_state()
self.assertIn('- grid "Trending artists" [cols=8] [rows=2] [ref=e2]', rendered)
async def test_render_preserves_offscreen_tag(self):
worker = await _make_worker()
worker._latest_snapshot = {
"root": {
"ref": "e1",
"role": "generic",
"children": [
{"ref": "e2", "role": "button", "name": "Visible"},
{"ref": "e3", "role": "button", "name": "Below fold", "state": ["offscreen"]},
],
},
"captured_at": 1700000000000,
}
rendered = worker.render_ui_state()
self.assertIn('- button "Visible" [ref=e2]', rendered)
self.assertIn('- button "Below fold" [offscreen] [ref=e3]', rendered)
async def test_visible_nodes_empty_without_snapshot(self):
worker = await _make_worker()
self.assertEqual(worker.visible_nodes(), [])
async def test_render_emits_selection_block_when_present(self):
worker = await _make_worker()
worker._latest_snapshot = {
"root": {"ref": "e1", "role": "generic", "children": [{"ref": "e2", "role": "main"}]},
"captured_at": 1700000000000,
"selection": {"ref": "e2", "text": "the highlighted passage"},
}
rendered = worker.render_ui_state()
self.assertIn('<selection ref="e2">', rendered)
self.assertIn("the highlighted passage", rendered)
self.assertIn("</selection>", rendered)
self.assertTrue(rendered.endswith("</ui_state>"))
sel_idx = rendered.index('<selection ref="e2">')
close_idx = rendered.index("</ui_state>")
self.assertLess(sel_idx, close_idx)
async def test_render_omits_selection_when_missing(self):
worker = await _make_worker()
worker._latest_snapshot = {
"root": {"ref": "e1", "role": "generic"},
"captured_at": 1700000000000,
}
rendered = worker.render_ui_state()
self.assertNotIn("<selection", rendered)
async def test_render_skips_selection_with_missing_ref_or_text(self):
worker = await _make_worker()
worker._latest_snapshot = {
"root": {"ref": "e1", "role": "generic"},
"captured_at": 1,
"selection": {"ref": "e2"},
}
self.assertNotIn("<selection", worker.render_ui_state())
worker._latest_snapshot = {
"root": {"ref": "e1", "role": "generic"},
"captured_at": 1,
"selection": {"text": "stuff"},
}
self.assertNotIn("<selection", worker.render_ui_state())
async def test_visible_nodes_filters_offscreen_entries(self):
worker = await _make_worker()
worker._latest_snapshot = {
"root": {
"ref": "e1",
"role": "generic",
"children": [
{"ref": "e2", "role": "button", "name": "Visible"},
{"ref": "e3", "role": "button", "name": "Below fold", "state": ["offscreen"]},
{
"ref": "e4",
"role": "region",
"name": "Tracks",
"state": ["offscreen"],
"children": [{"ref": "e5", "role": "button", "name": "Bloom"}],
},
],
},
"captured_at": 1700000000000,
}
refs = [n["ref"] for n in worker.visible_nodes()]
self.assertIn("e1", refs)
self.assertIn("e2", refs)
self.assertNotIn("e3", refs)
self.assertNotIn("e4", refs)
self.assertIn("e5", refs)
class TestUIWorkerAutoInject(unittest.IsolatedAsyncioTestCase):
async def test_respond_auto_injects_latest_snapshot(self):
worker = await _make_worker()
worker._latest_snapshot = _SAMPLE_SNAPSHOT
t = await _start(
worker,
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
)
frames = _append_frames(worker)
self.assertEqual(len(frames), 2)
self.assertEqual(frames[0].messages[0]["role"], "developer")
self.assertTrue(frames[0].messages[0]["content"].startswith("<ui_state>"))
self.assertFalse(frames[0].run_llm)
self.assertEqual(frames[1].messages[0]["content"], "hi")
self.assertTrue(frames[1].run_llm)
await worker.respond_to_job()
await t
async def test_auto_inject_ui_state_false_suppresses_injection(self):
worker = await _make_worker(auto_inject_ui_state=False)
worker._latest_snapshot = _SAMPLE_SNAPSHOT
t = await _start(
worker,
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
)
frames = _append_frames(worker)
self.assertEqual(len(frames), 1)
self.assertFalse(frames[0].messages[0]["content"].startswith("<ui_state>"))
await worker.respond_to_job()
await t
async def test_auto_inject_no_op_without_snapshot(self):
worker = await _make_worker()
t = await _start(
worker,
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
)
frames = _append_frames(worker)
self.assertEqual(len(frames), 1)
self.assertFalse(frames[0].messages[0]["content"].startswith("<ui_state>"))
await worker.respond_to_job()
await t
class TestUIWorkerKeepHistory(unittest.IsolatedAsyncioTestCase):
async def test_default_resets_context_per_job(self):
worker = await _make_worker()
worker._latest_snapshot = _SAMPLE_SNAPSHOT
t = await _start(
worker,
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
)
updates = _update_frames(worker)
self.assertEqual(len(updates), 1)
self.assertEqual(updates[0].messages, [])
self.assertFalse(updates[0].run_llm)
await worker.respond_to_job()
await t
async def test_reset_runs_before_inject(self):
worker = await _make_worker()
worker._latest_snapshot = _SAMPLE_SNAPSHOT
t = await _start(
worker,
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
)
frame_types = [type(f).__name__ for f in worker._recorded]
update_idx = frame_types.index("LLMMessagesUpdateFrame")
append_idx = frame_types.index("LLMMessagesAppendFrame")
self.assertLess(update_idx, append_idx)
await worker.respond_to_job()
await t
async def test_keep_history_true_skips_reset(self):
worker = await _make_worker(keep_history=True)
worker._latest_snapshot = _SAMPLE_SNAPSHOT
t = await _start(
worker,
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
)
self.assertEqual(_update_frames(worker), [])
self.assertEqual(len(_append_frames(worker)), 2)
await worker.respond_to_job()
await t
async def test_reset_context_method_emits_update_frame(self):
worker = await _make_worker(keep_history=True)
await worker.reset_context()
updates = _update_frames(worker)
self.assertEqual(len(updates), 1)
self.assertEqual(updates[0].messages, [])
self.assertFalse(updates[0].run_llm)
class TestUIWorkerRespondToJob(unittest.IsolatedAsyncioTestCase):
async def test_current_job_tracks_in_flight_request(self):
worker = await _make_worker()
self.assertIsNone(worker.current_job)
message = BusJobRequestMessage(
source="voice", target="ui", job_id="t1", payload={"query": "hi"}
)
t = await _start(worker, message)
self.assertIs(worker.current_job, message)
await worker.respond_to_job()
await t
async def test_respond_to_job_clears_current_and_sends_response(self):
worker = await _make_worker()
message = BusJobRequestMessage(source="voice", target="ui", job_id="t1")
t = await _start(worker, message)
await worker.respond_to_job(speak="hello")
await t
worker.send_job_response.assert_awaited_once()
call = worker.send_job_response.await_args
self.assertEqual(call.args[0], "t1")
self.assertEqual(call.kwargs["response"], {"speak": "hello"})
self.assertIsNone(worker.current_job)
async def test_respond_to_job_no_op_when_idle(self):
worker = await _make_worker()
await worker.respond_to_job(speak="hello")
worker.send_job_response.assert_not_awaited()
async def test_respond_to_job_omits_speak_when_none(self):
worker = await _make_worker()
t = await _start(worker, BusJobRequestMessage(source="voice", target="ui", job_id="t1"))
await worker.respond_to_job()
await t
call = worker.send_job_response.await_args
self.assertEqual(call.kwargs["response"], {})
async def test_respond_to_job_merges_speak_into_response(self):
worker = await _make_worker()
t = await _start(worker, BusJobRequestMessage(source="voice", target="ui", job_id="t1"))
await worker.respond_to_job({"description": "scrolled"}, speak="ok")
await t
call = worker.send_job_response.await_args
self.assertEqual(call.kwargs["response"], {"description": "scrolled", "speak": "ok"})
async def test_cancellation_frees_lock_for_subsequent_jobs(self):
worker = await _make_worker()
msg_a = _respond_msg("a")
msg_b = _respond_msg("b")
await worker._handle_job_request(msg_a)
await _settle()
self.assertIs(worker.current_job, msg_a)
self.assertTrue(worker._job_locks["respond"].locked())
await worker._handle_job_cancel(
BusJobCancelMessage(source="voice", target="ui", job_id="a")
)
await _settle()
self.assertIsNone(worker.current_job)
self.assertFalse(worker._job_locks["respond"].locked())
await worker._handle_job_request(msg_b)
await _settle()
self.assertIs(worker.current_job, msg_b)
await worker.respond_to_job(speak="B done")
await _settle()
async def test_cancel_unknown_job_id_is_noop(self):
worker = await _make_worker()
msg_a = _respond_msg("a")
await worker._handle_job_request(msg_a)
await _settle()
await worker._handle_job_cancel(
BusJobCancelMessage(source="voice", target="ui", job_id="unrelated")
)
await _settle()
self.assertIs(worker.current_job, msg_a)
self.assertTrue(worker._job_locks["respond"].locked())
await worker.respond_to_job(speak="A done")
await _settle()
async def test_concurrent_same_name_jobs_serialize(self):
worker = await _make_worker()
msg_a = _respond_msg("a")
msg_b = _respond_msg("b")
await worker._handle_job_request(msg_a)
await _settle()
self.assertIs(worker.current_job, msg_a)
await worker._handle_job_request(msg_b)
await _settle()
self.assertIs(worker.current_job, msg_a)
await worker.respond_to_job(speak="A done")
await _settle()
self.assertIs(worker.current_job, msg_b)
await worker.respond_to_job(speak="B done")
await _settle()
self.assertIsNone(worker.current_job)
class TestUIWorkerRespondJob(unittest.IsolatedAsyncioTestCase):
async def test_respond_handler_runs_after_setup(self):
worker = await _make_worker()
worker._latest_snapshot = _SAMPLE_SNAPSHOT
t = await _start(
worker,
BusJobRequestMessage(
source="voice", target="ui", job_id="t1", payload={"query": "hello"}
),
)
appends = _append_frames(worker)
self.assertEqual(len(appends), 2)
self.assertTrue(appends[0].messages[0]["content"].startswith("<ui_state>"))
self.assertFalse(appends[0].run_llm)
self.assertEqual(appends[1].messages[0]["content"], "hello")
self.assertTrue(appends[1].run_llm)
self.assertEqual(worker.current_job.job_id, "t1")
await worker.respond_to_job(speak="done")
await t
self.assertIsNone(worker.current_job)
worker.send_job_response.assert_awaited_once()
async def test_render_query_override(self):
class _Custom(_StubUIWorker):
def render_query(self, message):
return f"Q: {message.payload['q']}"
worker = await _make_worker(cls=_Custom)
t = await _start(
worker,
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"q": "hi"}),
)
query_frames = [f for f in _append_frames(worker) if f.run_llm]
self.assertEqual(query_frames[0].messages[0]["content"], "Q: hi")
await worker.respond_to_job()
await t
async def test_handler_failure_clears_state(self):
class _Boom(_StubUIWorker):
def render_query(self, message):
raise RuntimeError("boom")
worker = await _make_worker(cls=_Boom)
with self.assertRaises(RuntimeError):
await worker.respond_with_llm(
BusJobRequestMessage(
source="voice", target="ui", job_id="t1", payload={"query": "x"}
)
)
self.assertIsNone(worker.current_job)
self.assertIsNone(worker._pending)
class TestUIWorkerPromptGuide(unittest.IsolatedAsyncioTestCase):
async def test_default_appends_ui_state_prompt_guide(self):
worker = UIWorker("ui", llm=MagicMock(), active=False)
worker.llm.append_system_instruction.assert_called_once_with(UI_STATE_PROMPT_GUIDE)
async def test_custom_prompt_guide_overrides(self):
worker = UIWorker("ui", llm=MagicMock(), active=False, prompt_guide="MY GUIDE")
worker.llm.append_system_instruction.assert_called_once_with("MY GUIDE")
async def test_none_disables_injection(self):
worker = UIWorker("ui", llm=MagicMock(), active=False, prompt_guide=None)
worker.llm.append_system_instruction.assert_not_called()
if __name__ == "__main__":
unittest.main()