Add UIWorker
UIWorker is an LLMContextWorker that observes and drives a client GUI over the RTVI UI channel: it stores accessibility snapshots, auto-injects <ui_state> at the start of each respond job, dispatches client events to @on_ui_event handlers, sends UI commands back to the client, and surfaces fan-out work as cancellable task cards via user_job_group(). The optional ReplyToolMixin exposes a bundled reply tool. The prompt_guide parameter auto-appends the UI wire-format guide to the LLM's system instruction (default UI_STATE_PROMPT_GUIDE; override with a string or disable with None), so the LLM can parse the injected <ui_state> / <ui_event> messages without the app concatenating the guide by hand.
This commit is contained in:
3
changelog/xxxx.added.md
Normal file
3
changelog/xxxx.added.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- Added `pipecat.workers.ui.UIWorker`, an `LLMContextWorker` that observes and drives a client GUI over the RTVI UI channel: it stores live accessibility snapshots, auto-injects `<ui_state>` at the start of each `respond` job, dispatches client events to `@on_ui_event` handlers, and sends UI commands (`scroll_to`, `highlight`, `select_text`, `click`, `set_input_value`) back to the client. The optional `ReplyToolMixin` exposes a bundled `reply` tool, and `user_job_group(...)` surfaces fan-out work to the client as cancellable task cards. A native RTVI⇄bus UI bridge is built into `PipelineWorker` (active whenever RTVI is enabled), so no decorator or manual wiring is needed: inbound UI messages are broadcast on the bus as `BusUIEventMessage`, and outbound `BusUICommandMessage` / `BusUITask*` carriers are translated into RTVI frames for the client.
|
||||
|
||||
- `UIWorker` auto-injects the UI wire-format guide (`UI_STATE_PROMPT_GUIDE`) into its LLM's system instruction by default, via a `prompt_guide` parameter — pass your own string to override the guide, or `None` to disable. Apps no longer need to concatenate `UI_STATE_PROMPT_GUIDE` into the LLM's `system_instruction` by hand.
|
||||
49
src/pipecat/workers/ui/__init__.py
Normal file
49
src/pipecat/workers/ui/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""UI worker: LLM worker that observes and drives a GUI app.
|
||||
|
||||
Composes the RTVI wire protocol for client UI events, accessibility-tree
|
||||
snapshots, and server-emitted UI commands (in
|
||||
``pipecat.processors.frameworks.rtvi.models``) with an opt-in
|
||||
``ReplyToolMixin`` exposing the canonical bundled reply tool
|
||||
(``answer`` + optional ``scroll_to`` / ``highlight`` / ...).
|
||||
|
||||
The RTVI⇄bus bridge that connects a ``UIWorker`` to the client is built
|
||||
into ``PipelineWorker`` and is active whenever RTVI is enabled — there
|
||||
is no decorator or bridge to wire up.
|
||||
"""
|
||||
|
||||
from pipecat.bus.ui_messages import (
|
||||
BusUICommandMessage,
|
||||
BusUIEventMessage,
|
||||
BusUITaskCompletedMessage,
|
||||
BusUITaskGroupCompletedMessage,
|
||||
BusUITaskGroupStartedMessage,
|
||||
BusUITaskUpdateMessage,
|
||||
)
|
||||
from pipecat.workers.ui.ui_event_decorator import on_ui_event
|
||||
from pipecat.workers.ui.ui_prompts import UI_STATE_PROMPT_GUIDE
|
||||
from pipecat.workers.ui.ui_tools import ReplyToolMixin
|
||||
from pipecat.workers.ui.ui_worker import UIWorker
|
||||
|
||||
# Built-in UI command payload models (Toast, Navigate, ScrollTo,
|
||||
# Highlight, Focus, Click, SetInputValue, SelectText) live in
|
||||
# ``pipecat.processors.frameworks.rtvi.models``. Import them from there
|
||||
# directly.
|
||||
|
||||
__all__ = [
|
||||
"BusUICommandMessage",
|
||||
"BusUIEventMessage",
|
||||
"BusUITaskCompletedMessage",
|
||||
"BusUITaskGroupCompletedMessage",
|
||||
"BusUITaskGroupStartedMessage",
|
||||
"BusUITaskUpdateMessage",
|
||||
"ReplyToolMixin",
|
||||
"UIWorker",
|
||||
"UI_STATE_PROMPT_GUIDE",
|
||||
"on_ui_event",
|
||||
]
|
||||
67
src/pipecat/workers/ui/ui_event_decorator.py
Normal file
67
src/pipecat/workers/ui/ui_event_decorator.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Decorator for marking worker methods as UI event handlers."""
|
||||
|
||||
|
||||
def on_ui_event(name: str):
|
||||
"""Mark a worker method as a handler for a named UI event.
|
||||
|
||||
On ``UIWorker`` subclasses, decorated methods are automatically
|
||||
dispatched when a ``BusUIEventMessage`` with a matching ``name``
|
||||
arrives.
|
||||
|
||||
Example::
|
||||
|
||||
class MyUIWorker(UIWorker):
|
||||
@on_ui_event("nav_click")
|
||||
async def on_nav(self, message):
|
||||
view = message.payload.get("view")
|
||||
...
|
||||
|
||||
Args:
|
||||
name: The UI event name to match.
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
fn.is_ui_event_handler = True
|
||||
fn.ui_event_name = name
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _collect_ui_event_handlers(obj) -> dict:
|
||||
"""Collect all ``@on_ui_event`` decorated bound methods from an object.
|
||||
|
||||
Walks the MRO so that overridden methods in subclasses take
|
||||
precedence over base-class definitions.
|
||||
|
||||
Returns:
|
||||
A dict mapping event name to the bound method.
|
||||
|
||||
Raises:
|
||||
ValueError: If two handlers share the same event name on the
|
||||
same subclass level.
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
handlers: dict[str, object] = {}
|
||||
source_names: dict[str, str] = {} # event name -> defining method name, for errors
|
||||
for cls in type(obj).__mro__:
|
||||
for attr_name, val in cls.__dict__.items():
|
||||
if attr_name in seen:
|
||||
continue
|
||||
seen.add(attr_name)
|
||||
if callable(val) and getattr(val, "is_ui_event_handler", False):
|
||||
event_name = val.ui_event_name
|
||||
if event_name in handlers:
|
||||
raise ValueError(
|
||||
f"Duplicate @on_ui_event handler for '{event_name}': "
|
||||
f"'{attr_name}' conflicts with '{source_names[event_name]}'"
|
||||
)
|
||||
handlers[event_name] = getattr(obj, attr_name)
|
||||
source_names[event_name] = attr_name
|
||||
return handlers
|
||||
147
src/pipecat/workers/ui/ui_job_context.py
Normal file
147
src/pipecat/workers/ui/ui_job_context.py
Normal file
@@ -0,0 +1,147 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""User-facing job group context.
|
||||
|
||||
Wraps ``JobGroupContext`` so the work it dispatches is also surfaced
|
||||
to the UI client through the UI Worker protocol. Apps reach this via
|
||||
``UIWorker.user_job_group(...)`` rather than constructing it directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pipecat.bus.ui_messages import (
|
||||
BusUITaskGroupCompletedMessage,
|
||||
BusUITaskGroupStartedMessage,
|
||||
)
|
||||
from pipecat.pipeline.job_context import JobGroupContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.workers.ui.ui_worker import UIWorker
|
||||
|
||||
|
||||
class UserJobGroupContext(JobGroupContext):
|
||||
"""Job group whose lifecycle is forwarded to the UI client.
|
||||
|
||||
Behaves exactly like ``JobGroupContext`` for the dispatching code.
|
||||
Additionally, on enter the context registers the group with its
|
||||
parent ``UIWorker`` and publishes a ``BusUITaskGroupStartedMessage``.
|
||||
The worker forwards any subsequent ``BusJobUpdateMessage`` /
|
||||
``BusJobResponseMessage`` whose ``job_id`` matches a registered
|
||||
group as ``BusUITaskUpdateMessage`` / ``BusUITaskCompletedMessage``.
|
||||
On exit the context publishes ``BusUITaskGroupCompletedMessage`` and
|
||||
deregisters.
|
||||
|
||||
Workers don't need to know about the UI surface: any
|
||||
``send_job_update`` they emit against the group's ``job_id`` is
|
||||
forwarded automatically.
|
||||
|
||||
Example::
|
||||
|
||||
async with self.user_job_group(
|
||||
"researcher_a", "researcher_b",
|
||||
payload={"query": query},
|
||||
label=f"Research: {query}",
|
||||
cancellable=True,
|
||||
) as tg:
|
||||
async for event in tg:
|
||||
...
|
||||
results = tg.responses
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker: UIWorker,
|
||||
worker_names: tuple[str, ...],
|
||||
*,
|
||||
name: str | None = None,
|
||||
payload: dict | None = None,
|
||||
timeout: float | None = None,
|
||||
cancel_on_error: bool = True,
|
||||
label: str | None = None,
|
||||
cancellable: bool = True,
|
||||
):
|
||||
"""Initialize the UserJobGroupContext.
|
||||
|
||||
Args:
|
||||
worker: The parent ``UIWorker`` that owns this job group.
|
||||
worker_names: Names of the workers to send the job to.
|
||||
name: Optional job name for routing to named ``@job``
|
||||
handlers on the workers.
|
||||
payload: Optional structured data describing the work.
|
||||
timeout: Optional timeout in seconds covering both the
|
||||
ready-wait and job execution.
|
||||
cancel_on_error: Whether to cancel the group if a worker
|
||||
errors. Defaults to True.
|
||||
label: Optional human-readable label surfaced to the
|
||||
client (e.g. ``"Research: Radiohead"``). The client UI
|
||||
uses it to title the in-flight task card.
|
||||
cancellable: Whether the client may request cancellation
|
||||
of this group via the reserved ``__cancel_task`` event.
|
||||
Defaults to True.
|
||||
"""
|
||||
super().__init__(
|
||||
worker,
|
||||
worker_names,
|
||||
name=name,
|
||||
payload=payload,
|
||||
timeout=timeout,
|
||||
cancel_on_error=cancel_on_error,
|
||||
)
|
||||
self._ui_worker = worker
|
||||
self._label = label
|
||||
self._cancellable = cancellable
|
||||
|
||||
@property
|
||||
def label(self) -> str | None:
|
||||
"""The group's human-readable label, if any."""
|
||||
return self._label
|
||||
|
||||
@property
|
||||
def cancellable(self) -> bool:
|
||||
"""Whether the client may request cancellation."""
|
||||
return self._cancellable
|
||||
|
||||
async def __aenter__(self) -> UserJobGroupContext:
|
||||
await super().__aenter__()
|
||||
job_id = self.job_id
|
||||
self._ui_worker._register_user_job_group(
|
||||
job_id=job_id,
|
||||
worker_names=list(self._worker_names),
|
||||
label=self._label,
|
||||
cancellable=self._cancellable,
|
||||
)
|
||||
await self._ui_worker.send_bus_message(
|
||||
BusUITaskGroupStartedMessage(
|
||||
source=self._ui_worker.name,
|
||||
target=None,
|
||||
task_id=job_id,
|
||||
agents=list(self._worker_names),
|
||||
label=self._label,
|
||||
cancellable=self._cancellable,
|
||||
at=int(time.time() * 1000),
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
|
||||
job_id = self._group.job_id if self._group else None
|
||||
try:
|
||||
return await super().__aexit__(exc_type, exc_val, exc_tb)
|
||||
finally:
|
||||
if job_id:
|
||||
self._ui_worker._unregister_user_job_group(job_id)
|
||||
await self._ui_worker.send_bus_message(
|
||||
BusUITaskGroupCompletedMessage(
|
||||
source=self._ui_worker.name,
|
||||
target=None,
|
||||
task_id=job_id,
|
||||
at=int(time.time() * 1000),
|
||||
)
|
||||
)
|
||||
64
src/pipecat/workers/ui/ui_prompts.py
Normal file
64
src/pipecat/workers/ui/ui_prompts.py
Normal file
@@ -0,0 +1,64 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Canonical prompt fragments describing the UI worker wire format.
|
||||
|
||||
Apps concatenate these constants into their system prompt so the
|
||||
LLM understands the ``<ui_state>`` and ``<ui_event>`` developer
|
||||
messages the SDK injects on its behalf.
|
||||
|
||||
Example::
|
||||
|
||||
system_prompt = f'''
|
||||
You are a voice-driven music player agent.
|
||||
...app-specific tool and behavior instructions...
|
||||
|
||||
{UI_STATE_PROMPT_GUIDE}
|
||||
'''
|
||||
|
||||
The SDK updates the guide alongside the wire format, so apps get
|
||||
new tags and semantics automatically on the next release.
|
||||
"""
|
||||
|
||||
UI_STATE_PROMPT_GUIDE: str = """\
|
||||
## UI context
|
||||
|
||||
Your developer context includes two kinds of SDK-managed messages:
|
||||
|
||||
- ``<ui_event name="..." >payload</ui_event>``: an event the user just \
|
||||
triggered on the client (click, tab switch, navigation, etc.). The \
|
||||
payload is JSON for that event.
|
||||
- ``<ui_state>...</ui_state>``: an accessibility snapshot of the \
|
||||
current screen, injected at the start of every task request. \
|
||||
Indented tree in Playwright-MCP style. Each line is \
|
||||
``- role "name" [state] [ref=eN]`` with children nested one level \
|
||||
deeper.
|
||||
|
||||
State tags include ``[focused]``, ``[selected]``, ``[disabled]``, and \
|
||||
``[offscreen]``. A node tagged ``[offscreen]`` exists on the page \
|
||||
but is not currently in the user's viewport; only visible \
|
||||
(non-offscreen) nodes count for position-based references.
|
||||
|
||||
Grids carry a ``[cols=N]`` tag. Their cells are listed in reading \
|
||||
order (left-to-right, top-to-bottom); with N columns, cell K sits \
|
||||
at row ``ceil(K/N)``, column ``((K-1) mod N) + 1``. Example with \
|
||||
``[cols=8]`` and 16 children: "top right" is cell 8, "bottom left" \
|
||||
is cell 9.
|
||||
|
||||
Resolve position references ("top right", "the first one", "the \
|
||||
third new release") against the most recent ``<ui_state>`` tree. \
|
||||
Sibling order matches reading order on screen (top-to-bottom, \
|
||||
left-to-right within each region).
|
||||
|
||||
When the user has text selected on the page, the snapshot ends with \
|
||||
a ``<selection ref="eN">selected text</selection>`` block inside \
|
||||
``<ui_state>``. Treat the selection as the deictic referent for \
|
||||
"this", "that", "what I selected", and similar phrases. The ``ref`` \
|
||||
identifies the closest enclosing element that has a ref in the tree; \
|
||||
the inner text is the actual selected content (truncated if very \
|
||||
long). Text inside ``<input>`` or ``<textarea>`` selections is \
|
||||
faithful to ``selectionStart``/``selectionEnd`` on the element.\
|
||||
"""
|
||||
155
src/pipecat/workers/ui/ui_tools.py
Normal file
155
src/pipecat/workers/ui/ui_tools.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Opt-in tool mixin for ``UIWorker``.
|
||||
|
||||
Ships ``ReplyToolMixin``: a single ``reply(answer, scroll_to,
|
||||
highlight, select_text, fills, click)`` LLM tool that bundles a
|
||||
required spoken answer with the full set of standard UI actions.
|
||||
One tool call per turn; the model cannot drop the terminator because
|
||||
``answer`` is a required argument that the API schema enforces.
|
||||
|
||||
The bundled mixin covers the canonical app shapes (pointing,
|
||||
reading, form-fill) and any blend of them. Apps don't need to pick
|
||||
a mode up front; the LLM uses whichever fields make sense per turn,
|
||||
leaving the rest as ``null``.
|
||||
|
||||
Apps that want a tighter schema (only the fields they use, or
|
||||
app-specific commands like ``play_song``) hand-roll their own
|
||||
``@tool reply`` on the ``UIWorker`` subclass directly. The helper
|
||||
methods on ``UIWorker`` (``scroll_to``, ``highlight``,
|
||||
``select_text``, ``click``, ``set_input_value``) plus
|
||||
``send_command`` and the standard payload dataclasses cover the
|
||||
building blocks for custom replies.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.workers.llm.tool_decorator import tool
|
||||
|
||||
|
||||
class ReplyToolMixin:
|
||||
"""Expose a ``reply`` tool covering the full standard action set.
|
||||
|
||||
Single bundled LLM tool with a required spoken ``answer`` plus
|
||||
optional visual and state-changing actions. One tool call per
|
||||
turn, no chaining; the required ``answer`` argument is enforced
|
||||
by the API schema so the model cannot omit the terminator.
|
||||
|
||||
Compose alongside ``UIWorker``::
|
||||
|
||||
class MyUIWorker(ReplyToolMixin, UIWorker):
|
||||
...
|
||||
|
||||
Covers pointing apps (``scroll_to`` + ``highlight``), reading
|
||||
apps (``scroll_to`` + ``select_text``), form apps (``fills`` +
|
||||
``click``), and any blend (e.g. a document review with
|
||||
selection-based deixis AND voice-driven note-taking). The LLM
|
||||
uses whichever fields fit the user's request per turn; unused
|
||||
fields stay ``null`` and don't affect behavior.
|
||||
|
||||
Apps that want a minimal schema (only the fields actually used,
|
||||
or app-specific commands) write their own ``@tool reply`` on the
|
||||
``UIWorker`` subclass directly. Use the helper methods on
|
||||
``UIWorker`` plus ``send_command`` to dispatch the underlying UI
|
||||
commands.
|
||||
|
||||
The host class must provide ``scroll_to``, ``highlight``,
|
||||
``select_text``, ``click``, ``set_input_value``, and
|
||||
``respond_to_job`` (``UIWorker`` does) and must be the target of
|
||||
``@tool`` discovery on the LLM pipeline.
|
||||
"""
|
||||
|
||||
@tool
|
||||
async def reply(
|
||||
self,
|
||||
params: FunctionCallParams,
|
||||
answer: str,
|
||||
scroll_to: str | None = None,
|
||||
highlight: list[str] | None = None,
|
||||
select_text: str | None = None,
|
||||
fills: list[dict] | None = None,
|
||||
click: list[str] | None = None,
|
||||
):
|
||||
"""Reply to the user. Optionally point at content and act on inputs.
|
||||
|
||||
Always called exactly once per turn. ``answer`` is required;
|
||||
the action fields are optional and may be combined.
|
||||
|
||||
Visual / pointing actions (draw the user's attention):
|
||||
|
||||
- ``scroll_to`` brings an element into view (single ref).
|
||||
- ``highlight`` flashes elements briefly (list of refs).
|
||||
Best for short emphasis like a button or a fact.
|
||||
- ``select_text`` puts the page's text selection on an
|
||||
element (single ref). Best for "this paragraph" / "the
|
||||
section about X" so the user sees exactly what was meant.
|
||||
Persists until the user clicks elsewhere.
|
||||
|
||||
State-changing actions (modify form / app state):
|
||||
|
||||
- ``fills`` writes values into inputs (list of
|
||||
``{"ref", "value"}`` objects, multi-fill in one turn).
|
||||
- ``click`` clicks elements (list of refs in order). Use for
|
||||
checkboxes, radios, submit buttons.
|
||||
|
||||
Order of dispatch within a turn: ``scroll_to``, then
|
||||
``highlight``, then ``select_text``, then ``fills``, then
|
||||
``click``, then speak the answer.
|
||||
|
||||
Args:
|
||||
params: Framework-provided tool invocation context.
|
||||
answer: The spoken reply in plain language. One short
|
||||
sentence. No markdown, no symbols.
|
||||
scroll_to: Optional snapshot ref. Scrolls the element
|
||||
into view before speaking.
|
||||
highlight: Optional list of snapshot refs. Visually
|
||||
pulses each element.
|
||||
select_text: Optional snapshot ref. Places the page's
|
||||
text selection on that element.
|
||||
fills: Optional list of ``{"ref": "eN", "value": "..."}``
|
||||
objects. Writes each value into the input at ``ref``.
|
||||
click: Optional list of snapshot refs to click in order.
|
||||
"""
|
||||
preview = (answer or "").strip()
|
||||
if len(preview) > 80:
|
||||
preview = preview[:80] + "…"
|
||||
logger.info(
|
||||
f"{self}: reply(answer={preview!r}, scroll_to={scroll_to!r}, "
|
||||
f"highlight={highlight!r}, select_text={select_text!r}, "
|
||||
f"fills={fills!r}, click={click!r})"
|
||||
)
|
||||
# Defensive guards on the list arguments: an LLM that emits a
|
||||
# malformed entry (None, a bare string, etc.) would crash the
|
||||
# tool body before respond_to_job fires, leaving the
|
||||
# single-flight lock held until the requester's timeout cancels
|
||||
# us. Skip non-conforming entries instead.
|
||||
if scroll_to:
|
||||
await self.scroll_to(scroll_to) # type: ignore[attr-defined]
|
||||
if highlight:
|
||||
for ref in highlight:
|
||||
if not isinstance(ref, str):
|
||||
continue
|
||||
await self.highlight(ref) # type: ignore[attr-defined]
|
||||
if select_text:
|
||||
await self.select_text(select_text) # type: ignore[attr-defined]
|
||||
if fills:
|
||||
for entry in fills:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
ref = entry.get("ref")
|
||||
value = entry.get("value")
|
||||
if not isinstance(ref, str) or value is None:
|
||||
continue
|
||||
await self.set_input_value(ref, str(value)) # type: ignore[attr-defined]
|
||||
if click:
|
||||
for ref in click:
|
||||
if not isinstance(ref, str):
|
||||
continue
|
||||
await self.click(ref) # type: ignore[attr-defined]
|
||||
await self.respond_to_job(speak=answer) # type: ignore[attr-defined]
|
||||
await params.result_callback(None)
|
||||
1151
src/pipecat/workers/ui/ui_worker.py
Normal file
1151
src/pipecat/workers/ui/ui_worker.py
Normal file
File diff suppressed because it is too large
Load Diff
231
tests/test_ui_commands.py
Normal file
231
tests/test_ui_commands.py
Normal file
@@ -0,0 +1,231 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for UIWorker.send_command and standard command payload models."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from pipecat.bus.ui_messages import BusUICommandMessage
|
||||
from pipecat.processors.frameworks.rtvi.models import (
|
||||
Click,
|
||||
Focus,
|
||||
Highlight,
|
||||
Navigate,
|
||||
ScrollTo,
|
||||
SelectText,
|
||||
SetInputValue,
|
||||
Toast,
|
||||
)
|
||||
from pipecat.workers.ui import UIWorker
|
||||
|
||||
|
||||
def _make_worker():
|
||||
"""A UIWorker whose ``send_bus_message`` captures sent commands.
|
||||
|
||||
``send_command`` publishes via ``send_bus_message``; replacing it
|
||||
avoids needing an attached bus for these unit tests.
|
||||
"""
|
||||
worker = UIWorker("ui", llm=MagicMock(), active=False)
|
||||
sent: list[BusUICommandMessage] = []
|
||||
|
||||
async def _record(message):
|
||||
if isinstance(message, BusUICommandMessage):
|
||||
sent.append(message)
|
||||
|
||||
worker.send_bus_message = _record # type: ignore[method-assign]
|
||||
return worker, sent
|
||||
|
||||
|
||||
class TestSendCommand(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_serializes_pydantic_payload_via_model_dump(self):
|
||||
worker, sent = _make_worker()
|
||||
|
||||
await worker.send_command("toast", Toast(title="Saved", subtitle="Favorites"))
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
cmd = sent[0]
|
||||
self.assertEqual(cmd.source, "ui")
|
||||
self.assertIsNone(cmd.target)
|
||||
self.assertEqual(cmd.command_name, "toast")
|
||||
self.assertEqual(
|
||||
cmd.payload,
|
||||
{
|
||||
"title": "Saved",
|
||||
"subtitle": "Favorites",
|
||||
"description": None,
|
||||
"image_url": None,
|
||||
"duration_ms": None,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_forwards_dict_payload_as_is(self):
|
||||
worker, sent = _make_worker()
|
||||
|
||||
await worker.send_command("app_specific", {"foo": 1, "bar": [1, 2, 3]})
|
||||
|
||||
self.assertEqual(sent[0].command_name, "app_specific")
|
||||
self.assertEqual(sent[0].payload, {"foo": 1, "bar": [1, 2, 3]})
|
||||
|
||||
async def test_none_payload_becomes_empty_dict(self):
|
||||
worker, sent = _make_worker()
|
||||
|
||||
await worker.send_command("ping")
|
||||
|
||||
self.assertEqual(sent[0].payload, {})
|
||||
|
||||
async def test_dict_payload_for_apps_with_custom_command_names(self):
|
||||
worker, sent = _make_worker()
|
||||
|
||||
await worker.send_command("navigate", {"view": "home"})
|
||||
self.assertEqual(sent[0].payload, {"view": "home"})
|
||||
|
||||
|
||||
class TestStandardCommands(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_toast_payload_shape(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command(
|
||||
"toast",
|
||||
Toast(
|
||||
title="Now playing",
|
||||
subtitle="Nirvana",
|
||||
description="Smells Like Teen Spirit",
|
||||
image_url="https://example.com/cover.jpg",
|
||||
duration_ms=3000,
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{
|
||||
"title": "Now playing",
|
||||
"subtitle": "Nirvana",
|
||||
"description": "Smells Like Teen Spirit",
|
||||
"image_url": "https://example.com/cover.jpg",
|
||||
"duration_ms": 3000,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_navigate_payload_shape(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command("navigate", Navigate(view="detail", params={"id": "42"}))
|
||||
self.assertEqual(sent[0].payload, {"view": "detail", "params": {"id": "42"}})
|
||||
|
||||
async def test_scroll_to_payload_shape_by_target_id(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command(
|
||||
"scroll_to", ScrollTo(target_id="new_releases", behavior="smooth")
|
||||
)
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{"ref": None, "target_id": "new_releases", "behavior": "smooth"},
|
||||
)
|
||||
|
||||
async def test_scroll_to_payload_shape_by_ref(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command("scroll_to", ScrollTo(ref="e42", behavior="smooth"))
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{"ref": "e42", "target_id": None, "behavior": "smooth"},
|
||||
)
|
||||
|
||||
async def test_highlight_payload_shape(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command("highlight", Highlight(target_id="play_btn", duration_ms=1000))
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{"ref": None, "target_id": "play_btn", "duration_ms": 1000},
|
||||
)
|
||||
|
||||
async def test_focus_payload_shape(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command("focus", Focus(target_id="search_input"))
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{"ref": None, "target_id": "search_input"},
|
||||
)
|
||||
|
||||
async def test_focus_payload_shape_by_ref(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command("focus", Focus(ref="e7"))
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{"ref": "e7", "target_id": None},
|
||||
)
|
||||
|
||||
async def test_select_text_payload_whole_element(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command("select_text", SelectText(ref="e42"))
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{
|
||||
"ref": "e42",
|
||||
"target_id": None,
|
||||
"start_offset": None,
|
||||
"end_offset": None,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_select_text_payload_with_offsets(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command(
|
||||
"select_text",
|
||||
SelectText(ref="e42", start_offset=5, end_offset=12),
|
||||
)
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{
|
||||
"ref": "e42",
|
||||
"target_id": None,
|
||||
"start_offset": 5,
|
||||
"end_offset": 12,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_set_input_value_payload_replace_default(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command(
|
||||
"set_input_value",
|
||||
SetInputValue(ref="e7", value="hello"),
|
||||
)
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{
|
||||
"ref": "e7",
|
||||
"target_id": None,
|
||||
"value": "hello",
|
||||
"replace": True,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_set_input_value_payload_append(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command(
|
||||
"set_input_value",
|
||||
SetInputValue(ref="e7", value=" world", replace=False),
|
||||
)
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{
|
||||
"ref": "e7",
|
||||
"target_id": None,
|
||||
"value": " world",
|
||||
"replace": False,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_click_payload_by_ref(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command("click", Click(ref="e42"))
|
||||
self.assertEqual(sent[0].payload, {"ref": "e42", "target_id": None})
|
||||
|
||||
async def test_click_payload_by_target_id(self):
|
||||
worker, sent = _make_worker()
|
||||
await worker.send_command("click", Click(target_id="submit"))
|
||||
self.assertEqual(sent[0].payload, {"ref": None, "target_id": "submit"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
366
tests/test_ui_job_lifecycle.py
Normal file
366
tests/test_ui_job_lifecycle.py
Normal file
@@ -0,0 +1,366 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for the UIWorker user-job-group lifecycle.
|
||||
|
||||
Covers:
|
||||
- ``UIWorker.on_bus_message`` forwarding of worker job updates/responses
|
||||
for registered user job groups as ``BusUITask*`` carriers.
|
||||
- The reserved ``__cancel_task`` client event routing to
|
||||
``cancel_job_group``.
|
||||
- ``UserJobGroupContext`` publishing ``group_started`` / ``group_completed``
|
||||
envelopes and (de)registering the group on the worker.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from pipecat.bus.messages import BusJobResponseMessage, BusJobUpdateMessage
|
||||
from pipecat.bus.ui_messages import (
|
||||
_UI_CANCEL_TASK_BUS_EVENT_NAME,
|
||||
BusUIEventMessage,
|
||||
BusUITaskCompletedMessage,
|
||||
BusUITaskGroupCompletedMessage,
|
||||
BusUITaskGroupStartedMessage,
|
||||
BusUITaskUpdateMessage,
|
||||
)
|
||||
from pipecat.frames.frames import LLMMessagesAppendFrame
|
||||
from pipecat.pipeline.job_context import JobGroup, JobStatus
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
from pipecat.workers.ui import UIWorker
|
||||
|
||||
|
||||
async def _make_solo_worker(**kwargs) -> UIWorker:
|
||||
"""A UIWorker with a task manager and a ``queue_frame`` spy.
|
||||
|
||||
Suitable for testing forwarding logic by directly invoking
|
||||
``on_bus_message`` and asserting on captured ``send_bus_message``
|
||||
calls.
|
||||
"""
|
||||
worker = UIWorker("ui", llm=MagicMock(), active=False, **kwargs)
|
||||
tm = TaskManager()
|
||||
tm.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
worker._task_manager = tm
|
||||
|
||||
recorded: list = []
|
||||
|
||||
async def _record(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
recorded.append(frame)
|
||||
|
||||
worker.queue_frame = _record # type: ignore[method-assign]
|
||||
worker._recorded = recorded # type: ignore[attr-defined]
|
||||
return worker
|
||||
|
||||
|
||||
class TestUIWorkerForwarding(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_unregistered_job_update_is_not_forwarded(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusJobUpdateMessage(
|
||||
source="worker", target=worker.name, job_id="t-unknown", update={"x": 1}
|
||||
)
|
||||
)
|
||||
|
||||
forwarded = [
|
||||
c.args[0]
|
||||
for c in worker.send_bus_message.await_args_list
|
||||
if isinstance(c.args[0], BusUITaskUpdateMessage)
|
||||
]
|
||||
self.assertEqual(forwarded, [])
|
||||
|
||||
async def test_registered_job_update_is_forwarded(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker._register_user_job_group(
|
||||
job_id="t1", worker_names=["worker"], label="hello", cancellable=True
|
||||
)
|
||||
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusJobUpdateMessage(
|
||||
source="worker",
|
||||
target=worker.name,
|
||||
job_id="t1",
|
||||
update={"kind": "tool_call", "tool": "WebSearch"},
|
||||
)
|
||||
)
|
||||
|
||||
forwarded = [
|
||||
c.args[0]
|
||||
for c in worker.send_bus_message.await_args_list
|
||||
if isinstance(c.args[0], BusUITaskUpdateMessage)
|
||||
]
|
||||
self.assertEqual(len(forwarded), 1)
|
||||
self.assertEqual(forwarded[0].task_id, "t1")
|
||||
self.assertEqual(forwarded[0].agent_name, "worker")
|
||||
self.assertEqual(forwarded[0].data, {"kind": "tool_call", "tool": "WebSearch"})
|
||||
|
||||
async def test_registered_job_response_is_forwarded(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker._register_user_job_group(
|
||||
job_id="t1", worker_names=["worker"], label=None, cancellable=True
|
||||
)
|
||||
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusJobResponseMessage(
|
||||
source="worker",
|
||||
target=worker.name,
|
||||
job_id="t1",
|
||||
status=JobStatus.COMPLETED,
|
||||
response={"answer": 42},
|
||||
)
|
||||
)
|
||||
|
||||
forwarded = [
|
||||
c.args[0]
|
||||
for c in worker.send_bus_message.await_args_list
|
||||
if isinstance(c.args[0], BusUITaskCompletedMessage)
|
||||
]
|
||||
self.assertEqual(len(forwarded), 1)
|
||||
self.assertEqual(forwarded[0].task_id, "t1")
|
||||
self.assertEqual(forwarded[0].agent_name, "worker")
|
||||
self.assertEqual(forwarded[0].status, "completed")
|
||||
self.assertEqual(forwarded[0].response, {"answer": 42})
|
||||
|
||||
async def test_response_status_serializes_for_cancelled_and_error(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker._register_user_job_group(
|
||||
job_id="t1", worker_names=["w"], label=None, cancellable=True
|
||||
)
|
||||
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusJobResponseMessage(
|
||||
source="w", target=worker.name, job_id="t1", status=JobStatus.CANCELLED
|
||||
)
|
||||
)
|
||||
await worker.on_bus_message(
|
||||
BusJobResponseMessage(
|
||||
source="w", target=worker.name, job_id="t1", status=JobStatus.ERROR
|
||||
)
|
||||
)
|
||||
|
||||
statuses = [
|
||||
c.args[0].status
|
||||
for c in worker.send_bus_message.await_args_list
|
||||
if isinstance(c.args[0], BusUITaskCompletedMessage)
|
||||
]
|
||||
self.assertEqual(statuses, ["cancelled", "error"])
|
||||
|
||||
|
||||
class TestCancelJobEvent(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_cancel_event_routes_to_cancel_job_group(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker._register_user_job_group(
|
||||
job_id="t1", worker_names=["w"], label=None, cancellable=True
|
||||
)
|
||||
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusUIEventMessage(
|
||||
source="bridge",
|
||||
target=worker.name,
|
||||
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
|
||||
payload={"task_id": "t1", "reason": "user clicked cancel"},
|
||||
)
|
||||
)
|
||||
|
||||
worker.cancel_job_group.assert_awaited_once_with("t1", reason="user clicked cancel")
|
||||
|
||||
async def test_cancel_event_default_reason_when_omitted(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker._register_user_job_group(
|
||||
job_id="t1", worker_names=["w"], label=None, cancellable=True
|
||||
)
|
||||
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusUIEventMessage(
|
||||
source="bridge",
|
||||
target=worker.name,
|
||||
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
|
||||
payload={"task_id": "t1"},
|
||||
)
|
||||
)
|
||||
|
||||
worker.cancel_job_group.assert_awaited_once()
|
||||
self.assertEqual(worker.cancel_job_group.await_args.kwargs["reason"], "cancelled by user")
|
||||
|
||||
async def test_non_cancellable_group_is_ignored(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker._register_user_job_group(
|
||||
job_id="t1", worker_names=["w"], label=None, cancellable=False
|
||||
)
|
||||
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusUIEventMessage(
|
||||
source="bridge",
|
||||
target=worker.name,
|
||||
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
|
||||
payload={"task_id": "t1"},
|
||||
)
|
||||
)
|
||||
|
||||
worker.cancel_job_group.assert_not_awaited()
|
||||
|
||||
async def test_unknown_job_id_is_ignored(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusUIEventMessage(
|
||||
source="bridge",
|
||||
target=worker.name,
|
||||
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
|
||||
payload={"task_id": "nope"},
|
||||
)
|
||||
)
|
||||
|
||||
worker.cancel_job_group.assert_not_awaited()
|
||||
|
||||
async def test_missing_or_bad_payload_is_ignored(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker.cancel_job_group = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusUIEventMessage(
|
||||
source="bridge",
|
||||
target=worker.name,
|
||||
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
|
||||
payload=None,
|
||||
)
|
||||
)
|
||||
await worker.on_bus_message(
|
||||
BusUIEventMessage(
|
||||
source="bridge",
|
||||
target=worker.name,
|
||||
event_name=_UI_CANCEL_TASK_BUS_EVENT_NAME,
|
||||
payload={"task_id": 42},
|
||||
)
|
||||
)
|
||||
|
||||
worker.cancel_job_group.assert_not_awaited()
|
||||
|
||||
|
||||
class TestForwardingDoesNotInjectLLMContext(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_job_update_forwarding_does_not_queue_append_frames(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker._register_user_job_group(
|
||||
job_id="t1", worker_names=["w"], label=None, cancellable=True
|
||||
)
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusJobUpdateMessage(source="w", target=worker.name, job_id="t1", update={"x": 1})
|
||||
)
|
||||
|
||||
appends = [f for f in worker._recorded if isinstance(f, LLMMessagesAppendFrame)]
|
||||
self.assertEqual(appends, [])
|
||||
|
||||
|
||||
def _stub_job_group(worker, job_id="t1", worker_names=("w1",)):
|
||||
"""Make ``create_job_group_and_request_job`` return a self-completing group.
|
||||
|
||||
The group completes on the next loop tick so both the context-manager
|
||||
and fire-and-forget paths terminate without a running bus or workers.
|
||||
"""
|
||||
|
||||
async def _fake_create(names, *, name=None, payload=None, timeout=None, cancel_on_error=True):
|
||||
group = JobGroup(job_id=job_id, worker_names=set(names))
|
||||
|
||||
async def _finish():
|
||||
# Yield so JobGroupContext.__aenter__ can set event_queue first.
|
||||
await asyncio.sleep(0)
|
||||
group.complete()
|
||||
|
||||
asyncio.create_task(_finish())
|
||||
return group
|
||||
|
||||
worker.create_job_group_and_request_job = _fake_create # type: ignore[method-assign]
|
||||
|
||||
|
||||
class TestUserJobGroupContext(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_context_publishes_started_and_completed(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
|
||||
_stub_job_group(worker)
|
||||
|
||||
async with worker.user_job_group("w1", label="My research") as tg:
|
||||
self.assertEqual(tg.job_id, "t1")
|
||||
self.assertIn("t1", worker._user_job_groups)
|
||||
|
||||
self.assertNotIn("t1", worker._user_job_groups)
|
||||
|
||||
kinds = [type(c.args[0]).__name__ for c in worker.send_bus_message.await_args_list]
|
||||
self.assertEqual(
|
||||
kinds,
|
||||
["BusUITaskGroupStartedMessage", "BusUITaskGroupCompletedMessage"],
|
||||
)
|
||||
|
||||
started = worker.send_bus_message.await_args_list[0].args[0]
|
||||
self.assertIsInstance(started, BusUITaskGroupStartedMessage)
|
||||
self.assertEqual(started.task_id, "t1")
|
||||
self.assertEqual(started.agents, ["w1"])
|
||||
self.assertEqual(started.label, "My research")
|
||||
self.assertTrue(started.cancellable)
|
||||
|
||||
completed = worker.send_bus_message.await_args_list[1].args[0]
|
||||
self.assertIsInstance(completed, BusUITaskGroupCompletedMessage)
|
||||
self.assertEqual(completed.task_id, "t1")
|
||||
|
||||
async def test_non_cancellable_group_sets_flag_in_started_message(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
|
||||
_stub_job_group(worker)
|
||||
|
||||
async with worker.user_job_group("w1", cancellable=False):
|
||||
pass
|
||||
|
||||
started = worker.send_bus_message.await_args_list[0].args[0]
|
||||
self.assertFalse(started.cancellable)
|
||||
|
||||
async def test_unregisters_on_exit(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
|
||||
_stub_job_group(worker)
|
||||
|
||||
async with worker.user_job_group("w1") as tg:
|
||||
pass
|
||||
|
||||
self.assertNotIn(tg.job_id, worker._user_job_groups)
|
||||
|
||||
async def test_start_user_job_group_returns_id_and_publishes(self):
|
||||
worker = await _make_solo_worker()
|
||||
worker.send_bus_message = AsyncMock() # type: ignore[method-assign]
|
||||
_stub_job_group(worker)
|
||||
|
||||
job_id = await worker.start_user_job_group("w1", label="Background work")
|
||||
self.assertEqual(job_id, "t1")
|
||||
|
||||
started = worker.send_bus_message.await_args_list[0].args[0]
|
||||
self.assertIsInstance(started, BusUITaskGroupStartedMessage)
|
||||
self.assertEqual(started.label, "Background work")
|
||||
|
||||
# The background runner drains the group and publishes completion.
|
||||
for _ in range(50):
|
||||
await asyncio.sleep(0)
|
||||
if any(
|
||||
isinstance(c.args[0], BusUITaskGroupCompletedMessage)
|
||||
for c in worker.send_bus_message.await_args_list
|
||||
):
|
||||
break
|
||||
else:
|
||||
self.fail("group_completed envelope was not published")
|
||||
|
||||
self.assertNotIn("t1", worker._user_job_groups)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
418
tests/test_ui_tools.py
Normal file
418
tests/test_ui_tools.py
Normal file
@@ -0,0 +1,418 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for ``ReplyToolMixin`` and the action helper methods on ``UIWorker``.
|
||||
|
||||
The mixin exposes a single bundled ``reply(answer, scroll_to,
|
||||
highlight, ...)`` LLM tool whose ``answer`` argument is required. The
|
||||
helper methods (``scroll_to``, ``highlight``, ...) are plain instance
|
||||
methods on ``UIWorker`` that wrap ``send_command`` with the standard
|
||||
payload models; apps call them inside custom ``@tool`` bodies when the
|
||||
canonical ``reply`` shape doesn't fit.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from pipecat.bus.ui_messages import BusUICommandMessage
|
||||
from pipecat.workers.llm.tool_decorator import _collect_tools
|
||||
from pipecat.workers.ui import ReplyToolMixin, UIWorker
|
||||
|
||||
|
||||
class _WorkerWithReply(ReplyToolMixin, UIWorker):
|
||||
pass
|
||||
|
||||
|
||||
class _PlainWorker(UIWorker):
|
||||
pass
|
||||
|
||||
|
||||
def _new(cls: type) -> UIWorker:
|
||||
return cls("ui", llm=MagicMock(), active=False)
|
||||
|
||||
|
||||
def _capture(worker: UIWorker) -> list[BusUICommandMessage]:
|
||||
sent: list[BusUICommandMessage] = []
|
||||
|
||||
async def _record(message):
|
||||
sent.append(message)
|
||||
|
||||
worker.send_bus_message = _record # type: ignore[method-assign]
|
||||
return sent
|
||||
|
||||
|
||||
class TestUIWorkerActionHelpers(unittest.IsolatedAsyncioTestCase):
|
||||
"""The helper methods are plain methods, not LLM tools."""
|
||||
|
||||
async def test_scroll_to_helper_dispatches_command(self):
|
||||
worker = _new(_PlainWorker)
|
||||
sent = _capture(worker)
|
||||
|
||||
await worker.scroll_to("e42")
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
self.assertEqual(sent[0].command_name, "scroll_to")
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{"ref": "e42", "target_id": None, "behavior": None},
|
||||
)
|
||||
|
||||
async def test_highlight_helper_dispatches_command(self):
|
||||
worker = _new(_PlainWorker)
|
||||
sent = _capture(worker)
|
||||
|
||||
await worker.highlight("e7")
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
self.assertEqual(sent[0].command_name, "highlight")
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{"ref": "e7", "target_id": None, "duration_ms": None},
|
||||
)
|
||||
|
||||
async def test_select_text_helper_whole_element(self):
|
||||
worker = _new(_PlainWorker)
|
||||
sent = _capture(worker)
|
||||
|
||||
await worker.select_text("e42")
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
self.assertEqual(sent[0].command_name, "select_text")
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{
|
||||
"ref": "e42",
|
||||
"target_id": None,
|
||||
"start_offset": None,
|
||||
"end_offset": None,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_select_text_helper_with_offsets(self):
|
||||
worker = _new(_PlainWorker)
|
||||
sent = _capture(worker)
|
||||
|
||||
await worker.select_text("e42", start_offset=10, end_offset=25)
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{
|
||||
"ref": "e42",
|
||||
"target_id": None,
|
||||
"start_offset": 10,
|
||||
"end_offset": 25,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_click_helper_dispatches_command(self):
|
||||
worker = _new(_PlainWorker)
|
||||
sent = _capture(worker)
|
||||
|
||||
await worker.click("e42")
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
self.assertEqual(sent[0].command_name, "click")
|
||||
self.assertEqual(sent[0].payload, {"ref": "e42", "target_id": None})
|
||||
|
||||
async def test_set_input_value_helper_default_replace(self):
|
||||
worker = _new(_PlainWorker)
|
||||
sent = _capture(worker)
|
||||
|
||||
await worker.set_input_value("e42", "hello world")
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
self.assertEqual(sent[0].command_name, "set_input_value")
|
||||
self.assertEqual(
|
||||
sent[0].payload,
|
||||
{
|
||||
"ref": "e42",
|
||||
"target_id": None,
|
||||
"value": "hello world",
|
||||
"replace": True,
|
||||
},
|
||||
)
|
||||
|
||||
async def test_set_input_value_helper_append_mode(self):
|
||||
worker = _new(_PlainWorker)
|
||||
sent = _capture(worker)
|
||||
|
||||
await worker.set_input_value("e42", "more text", replace=False)
|
||||
|
||||
self.assertEqual(sent[0].payload["replace"], False)
|
||||
|
||||
async def test_helpers_are_not_llm_tools(self):
|
||||
worker = _new(_PlainWorker)
|
||||
tool_names = [t.__name__ for t in _collect_tools(worker)]
|
||||
for name in ("scroll_to", "highlight", "select_text", "click", "set_input_value"):
|
||||
self.assertNotIn(name, tool_names)
|
||||
|
||||
|
||||
class TestReplyToolMixin(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_mixin_exposes_reply_tool(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
tool_names = [t.__name__ for t in _collect_tools(worker)]
|
||||
self.assertEqual(tool_names, ["reply"])
|
||||
|
||||
async def test_plain_uiworker_has_no_reply_tool(self):
|
||||
worker = _new(_PlainWorker)
|
||||
tool_names = [t.__name__ for t in _collect_tools(worker)]
|
||||
self.assertNotIn("reply", tool_names)
|
||||
|
||||
async def test_reply_with_answer_only_terminates(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(params, answer="The Pixel 9 is from Google.")
|
||||
|
||||
self.assertEqual(sent, [])
|
||||
worker.respond_to_job.assert_awaited_once_with(speak="The Pixel 9 is from Google.")
|
||||
params.result_callback.assert_awaited_once_with(None)
|
||||
|
||||
async def test_reply_with_highlight_only(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="This one, the Nothing Phone 3.",
|
||||
highlight=["e29"],
|
||||
)
|
||||
|
||||
self.assertEqual([m.command_name for m in sent], ["highlight"])
|
||||
self.assertEqual(sent[0].payload["ref"], "e29")
|
||||
worker.respond_to_job.assert_awaited_once_with(speak="This one, the Nothing Phone 3.")
|
||||
|
||||
async def test_reply_with_multiple_highlights(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="Here are the Apple phones.",
|
||||
highlight=["e5", "e8", "e47"],
|
||||
)
|
||||
|
||||
self.assertEqual([m.command_name for m in sent], ["highlight"] * 3)
|
||||
self.assertEqual([m.payload["ref"] for m in sent], ["e5", "e8", "e47"])
|
||||
worker.respond_to_job.assert_awaited_once_with(speak="Here are the Apple phones.")
|
||||
|
||||
async def test_reply_with_scroll_and_highlight_runs_in_order(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="Here's the iPhone 17.",
|
||||
scroll_to="e5",
|
||||
highlight=["e5"],
|
||||
)
|
||||
|
||||
self.assertEqual([m.command_name for m in sent], ["scroll_to", "highlight"])
|
||||
worker.respond_to_job.assert_awaited_once_with(speak="Here's the iPhone 17.")
|
||||
|
||||
async def test_reply_with_select_text_only(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="Here, in this paragraph.",
|
||||
select_text="e11",
|
||||
)
|
||||
|
||||
self.assertEqual([m.command_name for m in sent], ["select_text"])
|
||||
self.assertEqual(sent[0].payload["ref"], "e11")
|
||||
worker.respond_to_job.assert_awaited_once_with(speak="Here, in this paragraph.")
|
||||
|
||||
async def test_reply_with_scroll_and_select_text(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="Here, in this paragraph.",
|
||||
scroll_to="e11",
|
||||
select_text="e11",
|
||||
)
|
||||
|
||||
self.assertEqual([m.command_name for m in sent], ["scroll_to", "select_text"])
|
||||
|
||||
async def test_reply_with_fills_writes_each_input(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="Got it.",
|
||||
fills=[
|
||||
{"ref": "e5", "value": "Mark"},
|
||||
{"ref": "e7", "value": "Backman"},
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[m.command_name for m in sent],
|
||||
["set_input_value", "set_input_value"],
|
||||
)
|
||||
self.assertEqual(sent[0].payload["ref"], "e5")
|
||||
self.assertEqual(sent[0].payload["value"], "Mark")
|
||||
self.assertEqual(sent[1].payload["ref"], "e7")
|
||||
self.assertEqual(sent[1].payload["value"], "Backman")
|
||||
|
||||
async def test_reply_with_click_clicks_each_in_order(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(params, answer="Submitted.", click=["e22", "e26"])
|
||||
|
||||
self.assertEqual([m.command_name for m in sent], ["click", "click"])
|
||||
self.assertEqual([m.payload["ref"] for m in sent], ["e22", "e26"])
|
||||
|
||||
async def test_reply_with_fills_skips_invalid_entries(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="x",
|
||||
fills=[
|
||||
{"ref": "e5", "value": "Mark"},
|
||||
{"ref": None, "value": "missing ref"},
|
||||
{"value": "no ref"},
|
||||
{"ref": "e7"},
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
self.assertEqual(sent[0].payload["ref"], "e5")
|
||||
|
||||
async def test_reply_with_non_dict_fill_entries_does_not_crash(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="x",
|
||||
fills=[None, "e5", 42, {"ref": "e9", "value": "ok"}], # type: ignore[list-item]
|
||||
)
|
||||
|
||||
self.assertEqual(len(sent), 1)
|
||||
self.assertEqual(sent[0].payload["ref"], "e9")
|
||||
worker.respond_to_job.assert_awaited_once_with(speak="x")
|
||||
params.result_callback.assert_awaited_once_with(None)
|
||||
|
||||
async def test_reply_with_non_string_highlight_refs_skipped(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="x",
|
||||
highlight=[None, "e1", 42, "e2"], # type: ignore[list-item]
|
||||
)
|
||||
|
||||
self.assertEqual([m.payload["ref"] for m in sent], ["e1", "e2"])
|
||||
worker.respond_to_job.assert_awaited_once_with(speak="x")
|
||||
|
||||
async def test_reply_with_non_string_click_refs_skipped(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
sent = _capture(worker)
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="x",
|
||||
click=[None, "e1", {"ref": "e2"}, "e3"], # type: ignore[list-item]
|
||||
)
|
||||
|
||||
self.assertEqual([m.payload["ref"] for m in sent], ["e1", "e3"])
|
||||
worker.respond_to_job.assert_awaited_once_with(speak="x")
|
||||
|
||||
async def test_reply_dispatches_via_helper_methods(self):
|
||||
worker = _new(_WorkerWithReply)
|
||||
worker.scroll_to = AsyncMock() # type: ignore[method-assign]
|
||||
worker.highlight = AsyncMock() # type: ignore[method-assign]
|
||||
worker.select_text = AsyncMock() # type: ignore[method-assign]
|
||||
worker.set_input_value = AsyncMock() # type: ignore[method-assign]
|
||||
worker.click = AsyncMock() # type: ignore[method-assign]
|
||||
worker.respond_to_job = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
params = MagicMock()
|
||||
params.result_callback = AsyncMock()
|
||||
|
||||
await worker.reply(
|
||||
params,
|
||||
answer="x",
|
||||
scroll_to="e1",
|
||||
highlight=["e2", "e3"],
|
||||
select_text="e4",
|
||||
fills=[{"ref": "e5", "value": "v"}],
|
||||
click=["e6", "e7"],
|
||||
)
|
||||
|
||||
worker.scroll_to.assert_awaited_once_with("e1")
|
||||
self.assertEqual(
|
||||
worker.highlight.await_args_list,
|
||||
[unittest.mock.call("e2"), unittest.mock.call("e3")],
|
||||
)
|
||||
worker.select_text.assert_awaited_once_with("e4")
|
||||
worker.set_input_value.assert_awaited_once_with("e5", "v")
|
||||
self.assertEqual(
|
||||
worker.click.await_args_list,
|
||||
[unittest.mock.call("e6"), unittest.mock.call("e7")],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
805
tests/test_ui_worker.py
Normal file
805
tests/test_ui_worker.py
Normal file
@@ -0,0 +1,805 @@
|
||||
#
|
||||
# Copyright (c) 2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for UIWorker dispatch, LLM-context injection, and single-flight respond."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from pipecat.bus.messages import BusJobCancelMessage, BusJobRequestMessage
|
||||
from pipecat.bus.ui_messages import _UI_SNAPSHOT_BUS_EVENT_NAME, BusUIEventMessage
|
||||
from pipecat.frames.frames import LLMMessagesAppendFrame, LLMMessagesUpdateFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
from pipecat.workers.ui import UI_STATE_PROMPT_GUIDE, UIWorker, on_ui_event
|
||||
|
||||
|
||||
class _StubUIWorker(UIWorker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.captured: list[BusUIEventMessage] = []
|
||||
|
||||
@on_ui_event("nav_click")
|
||||
async def _on_nav(self, message: BusUIEventMessage) -> None:
|
||||
self.captured.append(message)
|
||||
|
||||
|
||||
class _PlainWorker(UIWorker):
|
||||
pass
|
||||
|
||||
|
||||
async def _make_worker(cls=_StubUIWorker, **kwargs) -> UIWorker:
|
||||
"""A UIWorker wired with a task manager and a ``queue_frame`` spy.
|
||||
|
||||
``queue_frame`` is replaced with a recorder so tests can assert the
|
||||
frames the UI logic produces without running the pipeline. The
|
||||
respond handler sends its own response, so ``send_job_response`` is
|
||||
mocked to let ``respond_with_llm`` run without an active job entry.
|
||||
"""
|
||||
worker = cls("ui", llm=MagicMock(), active=False, **kwargs)
|
||||
tm = TaskManager()
|
||||
tm.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
worker._task_manager = tm
|
||||
|
||||
recorded: list = []
|
||||
|
||||
async def _record(frame, direction=FrameDirection.DOWNSTREAM):
|
||||
recorded.append(frame)
|
||||
|
||||
worker.queue_frame = _record # type: ignore[method-assign]
|
||||
worker._recorded = recorded # type: ignore[attr-defined]
|
||||
worker.send_job_response = AsyncMock() # type: ignore[method-assign]
|
||||
return worker
|
||||
|
||||
|
||||
async def _settle() -> None:
|
||||
"""Yield enough times for spawned task handlers to run/park."""
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
async def _start(worker: UIWorker, message: BusJobRequestMessage) -> asyncio.Task:
|
||||
"""Start a respond turn; it parks at ``await self._pending`` until resolved."""
|
||||
t = asyncio.create_task(worker.respond_with_llm(message))
|
||||
await _settle()
|
||||
return t
|
||||
|
||||
|
||||
def _respond_msg(job_id: str, query: str = "hi") -> BusJobRequestMessage:
|
||||
return BusJobRequestMessage(
|
||||
source="voice", target="ui", job_name="respond", job_id=job_id, payload={"query": query}
|
||||
)
|
||||
|
||||
|
||||
async def _dispatch(worker: UIWorker, message: BusUIEventMessage) -> None:
|
||||
await worker.on_bus_message(message)
|
||||
for _ in range(5):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
def _append_frames(worker) -> list[LLMMessagesAppendFrame]:
|
||||
return [f for f in worker._recorded if isinstance(f, LLMMessagesAppendFrame)]
|
||||
|
||||
|
||||
def _update_frames(worker) -> list[LLMMessagesUpdateFrame]:
|
||||
return [f for f in worker._recorded if isinstance(f, LLMMessagesUpdateFrame)]
|
||||
|
||||
|
||||
class TestUIWorkerDispatch(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_dispatches_to_matching_on_ui_event_handler(self):
|
||||
worker = await _make_worker()
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(len(worker.captured), 1)
|
||||
self.assertEqual(worker.captured[0].event_name, "nav_click")
|
||||
self.assertEqual(worker.captured[0].payload, {"view": "home"})
|
||||
|
||||
async def test_unknown_event_name_does_not_raise(self):
|
||||
worker = await _make_worker()
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music", target="ui", event_name="never_registered", payload={"x": 1}
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(worker.captured, [])
|
||||
|
||||
async def test_ignores_events_targeted_at_other_workers(self):
|
||||
worker = await _make_worker()
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music",
|
||||
target="someone_else",
|
||||
event_name="nav_click",
|
||||
payload={"view": "home"},
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(worker.captured, [])
|
||||
self.assertEqual(_append_frames(worker), [])
|
||||
|
||||
async def test_broadcast_event_with_no_target_is_handled(self):
|
||||
worker = await _make_worker()
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music", target=None, event_name="nav_click", payload={"view": "home"}
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(len(worker.captured), 1)
|
||||
|
||||
async def test_handler_runs_in_separate_task_so_bus_is_not_blocked(self):
|
||||
gate = asyncio.Event()
|
||||
observed: list[bool] = []
|
||||
|
||||
class _BlockingWorker(_StubUIWorker):
|
||||
@on_ui_event("slow")
|
||||
async def _slow(self, message):
|
||||
await gate.wait()
|
||||
observed.append(True)
|
||||
|
||||
worker = await _make_worker(cls=_BlockingWorker)
|
||||
|
||||
await worker.on_bus_message(
|
||||
BusUIEventMessage(source="music", target="ui", event_name="slow", payload={})
|
||||
)
|
||||
self.assertEqual(observed, [])
|
||||
|
||||
gate.set()
|
||||
await _settle()
|
||||
self.assertEqual(observed, [True])
|
||||
|
||||
async def test_duplicate_handler_names_raise_at_init(self):
|
||||
with self.assertRaises(ValueError):
|
||||
|
||||
class _Bad(UIWorker):
|
||||
@on_ui_event("nav")
|
||||
async def a(self, message):
|
||||
pass
|
||||
|
||||
@on_ui_event("nav")
|
||||
async def b(self, message):
|
||||
pass
|
||||
|
||||
_Bad("ui", llm=MagicMock())
|
||||
|
||||
async def test_bridged_with_default_auto_inject_raises(self):
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
_PlainWorker("ui", llm=MagicMock(), bridged=())
|
||||
self.assertIn("bridged", str(ctx.exception))
|
||||
self.assertIn("auto_inject_ui_state", str(ctx.exception))
|
||||
|
||||
async def test_bridged_with_explicit_auto_inject_disabled_is_allowed(self):
|
||||
worker = _PlainWorker("ui", llm=MagicMock(), bridged=(), auto_inject_ui_state=False)
|
||||
self.assertFalse(worker._auto_inject_ui_state)
|
||||
|
||||
async def test_default_construction_unaffected(self):
|
||||
worker = _PlainWorker("ui", llm=MagicMock())
|
||||
self.assertTrue(worker._auto_inject_ui_state)
|
||||
self.assertTrue(worker.active)
|
||||
|
||||
|
||||
class TestUIWorkerInjection(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_injects_xml_developer_message_by_default(self):
|
||||
worker = await _make_worker()
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
|
||||
),
|
||||
)
|
||||
|
||||
frames = _append_frames(worker)
|
||||
self.assertEqual(len(frames), 1)
|
||||
|
||||
frame = frames[0]
|
||||
self.assertFalse(frame.run_llm)
|
||||
self.assertEqual(len(frame.messages), 1)
|
||||
msg = frame.messages[0]
|
||||
self.assertEqual(msg["role"], "developer")
|
||||
|
||||
content = msg["content"]
|
||||
self.assertIn('<ui_event name="nav_click">', content)
|
||||
self.assertIn("</ui_event>", content)
|
||||
inner = content[len('<ui_event name="nav_click">') : -len("</ui_event>")]
|
||||
self.assertEqual(json.loads(inner), {"view": "home"})
|
||||
|
||||
async def test_inject_events_false_disables_injection(self):
|
||||
worker = await _make_worker(inject_events=False)
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(_append_frames(worker), [])
|
||||
self.assertEqual(len(worker.captured), 1)
|
||||
|
||||
async def test_render_override_replaces_default_xml(self):
|
||||
class _CustomRender(_StubUIWorker):
|
||||
def render_ui_event(self, message):
|
||||
return f"[UI] {message.event_name}"
|
||||
|
||||
worker = await _make_worker(cls=_CustomRender)
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
|
||||
),
|
||||
)
|
||||
|
||||
frames = _append_frames(worker)
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].messages[0]["content"], "[UI] nav_click")
|
||||
|
||||
async def test_empty_render_skips_injection(self):
|
||||
class _NoRender(_StubUIWorker):
|
||||
def render_ui_event(self, message):
|
||||
return ""
|
||||
|
||||
worker = await _make_worker(cls=_NoRender)
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music", target="ui", event_name="nav_click", payload={"view": "home"}
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(_append_frames(worker), [])
|
||||
|
||||
|
||||
_SAMPLE_SNAPSHOT = {
|
||||
"root": {
|
||||
"ref": "e1",
|
||||
"role": "generic",
|
||||
"children": [
|
||||
{
|
||||
"ref": "e2",
|
||||
"role": "main",
|
||||
"children": [
|
||||
{"ref": "e3", "role": "heading", "name": "Home", "level": 1},
|
||||
{
|
||||
"ref": "e4",
|
||||
"role": "region",
|
||||
"name": "Trending artists",
|
||||
"children": [
|
||||
{"ref": "e5", "role": "button", "name": "Bad Bunny"},
|
||||
{
|
||||
"ref": "e6",
|
||||
"role": "button",
|
||||
"name": "Taylor Swift",
|
||||
"state": ["focused"],
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
"captured_at": 1700000000000,
|
||||
}
|
||||
|
||||
|
||||
class TestUIWorkerSnapshot(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_reserved_snapshot_event_stored_without_dispatch(self):
|
||||
worker = await _make_worker()
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music",
|
||||
target="ui",
|
||||
event_name=_UI_SNAPSHOT_BUS_EVENT_NAME,
|
||||
payload=_SAMPLE_SNAPSHOT,
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(worker._latest_snapshot, _SAMPLE_SNAPSHOT)
|
||||
self.assertEqual(worker.captured, [])
|
||||
self.assertEqual(_append_frames(worker), [])
|
||||
|
||||
async def test_non_dict_snapshot_payload_is_ignored(self):
|
||||
worker = await _make_worker()
|
||||
|
||||
await _dispatch(
|
||||
worker,
|
||||
BusUIEventMessage(
|
||||
source="music",
|
||||
target="ui",
|
||||
event_name=_UI_SNAPSHOT_BUS_EVENT_NAME,
|
||||
payload="not a snapshot",
|
||||
),
|
||||
)
|
||||
|
||||
self.assertIsNone(worker._latest_snapshot)
|
||||
|
||||
async def test_render_ui_state_empty_without_snapshot(self):
|
||||
worker = await _make_worker()
|
||||
self.assertEqual(worker.render_ui_state(), "")
|
||||
|
||||
async def test_render_ui_state_produces_indented_block(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = _SAMPLE_SNAPSHOT
|
||||
|
||||
rendered = worker.render_ui_state()
|
||||
|
||||
self.assertTrue(rendered.startswith("<ui_state>\n"))
|
||||
self.assertTrue(rendered.endswith("\n</ui_state>"))
|
||||
|
||||
self.assertIn("- generic [ref=e1]:", rendered)
|
||||
self.assertIn("- main [ref=e2]:", rendered)
|
||||
self.assertIn('- heading "Home" [level=1] [ref=e3]', rendered)
|
||||
self.assertIn('- region "Trending artists" [ref=e4]:', rendered)
|
||||
self.assertIn('- button "Bad Bunny" [ref=e5]', rendered)
|
||||
self.assertIn('- button "Taylor Swift" [focused] [ref=e6]', rendered)
|
||||
|
||||
self.assertIn(" - main", rendered)
|
||||
self.assertIn(" - heading", rendered)
|
||||
self.assertIn(" - button", rendered)
|
||||
|
||||
async def test_inject_ui_state_queues_expected_frame(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = _SAMPLE_SNAPSHOT
|
||||
|
||||
await worker.inject_ui_state()
|
||||
|
||||
frames = _append_frames(worker)
|
||||
self.assertEqual(len(frames), 1)
|
||||
frame = frames[0]
|
||||
self.assertFalse(frame.run_llm)
|
||||
self.assertEqual(len(frame.messages), 1)
|
||||
msg = frame.messages[0]
|
||||
self.assertEqual(msg["role"], "developer")
|
||||
self.assertTrue(msg["content"].startswith("<ui_state>"))
|
||||
self.assertTrue(msg["content"].endswith("</ui_state>"))
|
||||
|
||||
async def test_inject_ui_state_no_op_without_snapshot(self):
|
||||
worker = await _make_worker()
|
||||
await worker.inject_ui_state()
|
||||
self.assertEqual(_append_frames(worker), [])
|
||||
|
||||
async def test_render_emits_grid_dims(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = {
|
||||
"root": {
|
||||
"ref": "e1",
|
||||
"role": "generic",
|
||||
"children": [
|
||||
{
|
||||
"ref": "e2",
|
||||
"role": "grid",
|
||||
"name": "Trending artists",
|
||||
"colcount": 8,
|
||||
"rowcount": 2,
|
||||
"children": [{"ref": "e3", "role": "button", "name": "Bad Bunny"}],
|
||||
},
|
||||
],
|
||||
},
|
||||
"captured_at": 1700000000000,
|
||||
}
|
||||
|
||||
rendered = worker.render_ui_state()
|
||||
self.assertIn('- grid "Trending artists" [cols=8] [rows=2] [ref=e2]', rendered)
|
||||
|
||||
async def test_render_preserves_offscreen_tag(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = {
|
||||
"root": {
|
||||
"ref": "e1",
|
||||
"role": "generic",
|
||||
"children": [
|
||||
{"ref": "e2", "role": "button", "name": "Visible"},
|
||||
{"ref": "e3", "role": "button", "name": "Below fold", "state": ["offscreen"]},
|
||||
],
|
||||
},
|
||||
"captured_at": 1700000000000,
|
||||
}
|
||||
|
||||
rendered = worker.render_ui_state()
|
||||
self.assertIn('- button "Visible" [ref=e2]', rendered)
|
||||
self.assertIn('- button "Below fold" [offscreen] [ref=e3]', rendered)
|
||||
|
||||
async def test_visible_nodes_empty_without_snapshot(self):
|
||||
worker = await _make_worker()
|
||||
self.assertEqual(worker.visible_nodes(), [])
|
||||
|
||||
async def test_render_emits_selection_block_when_present(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = {
|
||||
"root": {"ref": "e1", "role": "generic", "children": [{"ref": "e2", "role": "main"}]},
|
||||
"captured_at": 1700000000000,
|
||||
"selection": {"ref": "e2", "text": "the highlighted passage"},
|
||||
}
|
||||
|
||||
rendered = worker.render_ui_state()
|
||||
self.assertIn('<selection ref="e2">', rendered)
|
||||
self.assertIn("the highlighted passage", rendered)
|
||||
self.assertIn("</selection>", rendered)
|
||||
self.assertTrue(rendered.endswith("</ui_state>"))
|
||||
sel_idx = rendered.index('<selection ref="e2">')
|
||||
close_idx = rendered.index("</ui_state>")
|
||||
self.assertLess(sel_idx, close_idx)
|
||||
|
||||
async def test_render_omits_selection_when_missing(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = {
|
||||
"root": {"ref": "e1", "role": "generic"},
|
||||
"captured_at": 1700000000000,
|
||||
}
|
||||
|
||||
rendered = worker.render_ui_state()
|
||||
self.assertNotIn("<selection", rendered)
|
||||
|
||||
async def test_render_skips_selection_with_missing_ref_or_text(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = {
|
||||
"root": {"ref": "e1", "role": "generic"},
|
||||
"captured_at": 1,
|
||||
"selection": {"ref": "e2"},
|
||||
}
|
||||
self.assertNotIn("<selection", worker.render_ui_state())
|
||||
|
||||
worker._latest_snapshot = {
|
||||
"root": {"ref": "e1", "role": "generic"},
|
||||
"captured_at": 1,
|
||||
"selection": {"text": "stuff"},
|
||||
}
|
||||
self.assertNotIn("<selection", worker.render_ui_state())
|
||||
|
||||
async def test_visible_nodes_filters_offscreen_entries(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = {
|
||||
"root": {
|
||||
"ref": "e1",
|
||||
"role": "generic",
|
||||
"children": [
|
||||
{"ref": "e2", "role": "button", "name": "Visible"},
|
||||
{"ref": "e3", "role": "button", "name": "Below fold", "state": ["offscreen"]},
|
||||
{
|
||||
"ref": "e4",
|
||||
"role": "region",
|
||||
"name": "Tracks",
|
||||
"state": ["offscreen"],
|
||||
"children": [{"ref": "e5", "role": "button", "name": "Bloom"}],
|
||||
},
|
||||
],
|
||||
},
|
||||
"captured_at": 1700000000000,
|
||||
}
|
||||
|
||||
refs = [n["ref"] for n in worker.visible_nodes()]
|
||||
self.assertIn("e1", refs)
|
||||
self.assertIn("e2", refs)
|
||||
self.assertNotIn("e3", refs)
|
||||
self.assertNotIn("e4", refs)
|
||||
self.assertIn("e5", refs)
|
||||
|
||||
|
||||
class TestUIWorkerAutoInject(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_respond_auto_injects_latest_snapshot(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = _SAMPLE_SNAPSHOT
|
||||
|
||||
t = await _start(
|
||||
worker,
|
||||
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
|
||||
)
|
||||
|
||||
frames = _append_frames(worker)
|
||||
self.assertEqual(len(frames), 2)
|
||||
self.assertEqual(frames[0].messages[0]["role"], "developer")
|
||||
self.assertTrue(frames[0].messages[0]["content"].startswith("<ui_state>"))
|
||||
self.assertFalse(frames[0].run_llm)
|
||||
self.assertEqual(frames[1].messages[0]["content"], "hi")
|
||||
self.assertTrue(frames[1].run_llm)
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
async def test_auto_inject_ui_state_false_suppresses_injection(self):
|
||||
worker = await _make_worker(auto_inject_ui_state=False)
|
||||
worker._latest_snapshot = _SAMPLE_SNAPSHOT
|
||||
|
||||
t = await _start(
|
||||
worker,
|
||||
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
|
||||
)
|
||||
|
||||
frames = _append_frames(worker)
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertFalse(frames[0].messages[0]["content"].startswith("<ui_state>"))
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
async def test_auto_inject_no_op_without_snapshot(self):
|
||||
worker = await _make_worker()
|
||||
|
||||
t = await _start(
|
||||
worker,
|
||||
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
|
||||
)
|
||||
|
||||
frames = _append_frames(worker)
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertFalse(frames[0].messages[0]["content"].startswith("<ui_state>"))
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
|
||||
class TestUIWorkerKeepHistory(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_default_resets_context_per_job(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = _SAMPLE_SNAPSHOT
|
||||
|
||||
t = await _start(
|
||||
worker,
|
||||
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
|
||||
)
|
||||
|
||||
updates = _update_frames(worker)
|
||||
self.assertEqual(len(updates), 1)
|
||||
self.assertEqual(updates[0].messages, [])
|
||||
self.assertFalse(updates[0].run_llm)
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
async def test_reset_runs_before_inject(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = _SAMPLE_SNAPSHOT
|
||||
|
||||
t = await _start(
|
||||
worker,
|
||||
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
|
||||
)
|
||||
|
||||
frame_types = [type(f).__name__ for f in worker._recorded]
|
||||
update_idx = frame_types.index("LLMMessagesUpdateFrame")
|
||||
append_idx = frame_types.index("LLMMessagesAppendFrame")
|
||||
self.assertLess(update_idx, append_idx)
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
async def test_keep_history_true_skips_reset(self):
|
||||
worker = await _make_worker(keep_history=True)
|
||||
worker._latest_snapshot = _SAMPLE_SNAPSHOT
|
||||
|
||||
t = await _start(
|
||||
worker,
|
||||
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"query": "hi"}),
|
||||
)
|
||||
|
||||
self.assertEqual(_update_frames(worker), [])
|
||||
self.assertEqual(len(_append_frames(worker)), 2)
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
async def test_reset_context_method_emits_update_frame(self):
|
||||
worker = await _make_worker(keep_history=True)
|
||||
|
||||
await worker.reset_context()
|
||||
|
||||
updates = _update_frames(worker)
|
||||
self.assertEqual(len(updates), 1)
|
||||
self.assertEqual(updates[0].messages, [])
|
||||
self.assertFalse(updates[0].run_llm)
|
||||
|
||||
|
||||
class TestUIWorkerRespondToJob(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_current_job_tracks_in_flight_request(self):
|
||||
worker = await _make_worker()
|
||||
self.assertIsNone(worker.current_job)
|
||||
message = BusJobRequestMessage(
|
||||
source="voice", target="ui", job_id="t1", payload={"query": "hi"}
|
||||
)
|
||||
t = await _start(worker, message)
|
||||
self.assertIs(worker.current_job, message)
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
async def test_respond_to_job_clears_current_and_sends_response(self):
|
||||
worker = await _make_worker()
|
||||
message = BusJobRequestMessage(source="voice", target="ui", job_id="t1")
|
||||
t = await _start(worker, message)
|
||||
|
||||
await worker.respond_to_job(speak="hello")
|
||||
await t
|
||||
|
||||
worker.send_job_response.assert_awaited_once()
|
||||
call = worker.send_job_response.await_args
|
||||
self.assertEqual(call.args[0], "t1")
|
||||
self.assertEqual(call.kwargs["response"], {"speak": "hello"})
|
||||
self.assertIsNone(worker.current_job)
|
||||
|
||||
async def test_respond_to_job_no_op_when_idle(self):
|
||||
worker = await _make_worker()
|
||||
await worker.respond_to_job(speak="hello")
|
||||
worker.send_job_response.assert_not_awaited()
|
||||
|
||||
async def test_respond_to_job_omits_speak_when_none(self):
|
||||
worker = await _make_worker()
|
||||
t = await _start(worker, BusJobRequestMessage(source="voice", target="ui", job_id="t1"))
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
call = worker.send_job_response.await_args
|
||||
self.assertEqual(call.kwargs["response"], {})
|
||||
|
||||
async def test_respond_to_job_merges_speak_into_response(self):
|
||||
worker = await _make_worker()
|
||||
t = await _start(worker, BusJobRequestMessage(source="voice", target="ui", job_id="t1"))
|
||||
|
||||
await worker.respond_to_job({"description": "scrolled"}, speak="ok")
|
||||
await t
|
||||
|
||||
call = worker.send_job_response.await_args
|
||||
self.assertEqual(call.kwargs["response"], {"description": "scrolled", "speak": "ok"})
|
||||
|
||||
async def test_cancellation_frees_lock_for_subsequent_jobs(self):
|
||||
worker = await _make_worker()
|
||||
msg_a = _respond_msg("a")
|
||||
msg_b = _respond_msg("b")
|
||||
|
||||
await worker._handle_job_request(msg_a)
|
||||
await _settle()
|
||||
self.assertIs(worker.current_job, msg_a)
|
||||
self.assertTrue(worker._job_locks["respond"].locked())
|
||||
|
||||
await worker._handle_job_cancel(
|
||||
BusJobCancelMessage(source="voice", target="ui", job_id="a")
|
||||
)
|
||||
await _settle()
|
||||
self.assertIsNone(worker.current_job)
|
||||
self.assertFalse(worker._job_locks["respond"].locked())
|
||||
|
||||
await worker._handle_job_request(msg_b)
|
||||
await _settle()
|
||||
self.assertIs(worker.current_job, msg_b)
|
||||
|
||||
await worker.respond_to_job(speak="B done")
|
||||
await _settle()
|
||||
|
||||
async def test_cancel_unknown_job_id_is_noop(self):
|
||||
worker = await _make_worker()
|
||||
msg_a = _respond_msg("a")
|
||||
|
||||
await worker._handle_job_request(msg_a)
|
||||
await _settle()
|
||||
|
||||
await worker._handle_job_cancel(
|
||||
BusJobCancelMessage(source="voice", target="ui", job_id="unrelated")
|
||||
)
|
||||
await _settle()
|
||||
|
||||
self.assertIs(worker.current_job, msg_a)
|
||||
self.assertTrue(worker._job_locks["respond"].locked())
|
||||
|
||||
await worker.respond_to_job(speak="A done")
|
||||
await _settle()
|
||||
|
||||
async def test_concurrent_same_name_jobs_serialize(self):
|
||||
worker = await _make_worker()
|
||||
msg_a = _respond_msg("a")
|
||||
msg_b = _respond_msg("b")
|
||||
|
||||
await worker._handle_job_request(msg_a)
|
||||
await _settle()
|
||||
self.assertIs(worker.current_job, msg_a)
|
||||
|
||||
await worker._handle_job_request(msg_b)
|
||||
await _settle()
|
||||
self.assertIs(worker.current_job, msg_a)
|
||||
|
||||
await worker.respond_to_job(speak="A done")
|
||||
await _settle()
|
||||
self.assertIs(worker.current_job, msg_b)
|
||||
|
||||
await worker.respond_to_job(speak="B done")
|
||||
await _settle()
|
||||
self.assertIsNone(worker.current_job)
|
||||
|
||||
|
||||
class TestUIWorkerRespondJob(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_respond_handler_runs_after_setup(self):
|
||||
worker = await _make_worker()
|
||||
worker._latest_snapshot = _SAMPLE_SNAPSHOT
|
||||
|
||||
t = await _start(
|
||||
worker,
|
||||
BusJobRequestMessage(
|
||||
source="voice", target="ui", job_id="t1", payload={"query": "hello"}
|
||||
),
|
||||
)
|
||||
|
||||
appends = _append_frames(worker)
|
||||
self.assertEqual(len(appends), 2)
|
||||
self.assertTrue(appends[0].messages[0]["content"].startswith("<ui_state>"))
|
||||
self.assertFalse(appends[0].run_llm)
|
||||
self.assertEqual(appends[1].messages[0]["content"], "hello")
|
||||
self.assertTrue(appends[1].run_llm)
|
||||
self.assertEqual(worker.current_job.job_id, "t1")
|
||||
|
||||
await worker.respond_to_job(speak="done")
|
||||
await t
|
||||
self.assertIsNone(worker.current_job)
|
||||
worker.send_job_response.assert_awaited_once()
|
||||
|
||||
async def test_render_query_override(self):
|
||||
class _Custom(_StubUIWorker):
|
||||
def render_query(self, message):
|
||||
return f"Q: {message.payload['q']}"
|
||||
|
||||
worker = await _make_worker(cls=_Custom)
|
||||
t = await _start(
|
||||
worker,
|
||||
BusJobRequestMessage(source="voice", target="ui", job_id="t1", payload={"q": "hi"}),
|
||||
)
|
||||
|
||||
query_frames = [f for f in _append_frames(worker) if f.run_llm]
|
||||
self.assertEqual(query_frames[0].messages[0]["content"], "Q: hi")
|
||||
|
||||
await worker.respond_to_job()
|
||||
await t
|
||||
|
||||
async def test_handler_failure_clears_state(self):
|
||||
class _Boom(_StubUIWorker):
|
||||
def render_query(self, message):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
worker = await _make_worker(cls=_Boom)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
await worker.respond_with_llm(
|
||||
BusJobRequestMessage(
|
||||
source="voice", target="ui", job_id="t1", payload={"query": "x"}
|
||||
)
|
||||
)
|
||||
|
||||
self.assertIsNone(worker.current_job)
|
||||
self.assertIsNone(worker._pending)
|
||||
|
||||
|
||||
class TestUIWorkerPromptGuide(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_default_appends_ui_state_prompt_guide(self):
|
||||
worker = UIWorker("ui", llm=MagicMock(), active=False)
|
||||
worker.llm.append_system_instruction.assert_called_once_with(UI_STATE_PROMPT_GUIDE)
|
||||
|
||||
async def test_custom_prompt_guide_overrides(self):
|
||||
worker = UIWorker("ui", llm=MagicMock(), active=False, prompt_guide="MY GUIDE")
|
||||
worker.llm.append_system_instruction.assert_called_once_with("MY GUIDE")
|
||||
|
||||
async def test_none_disables_injection(self):
|
||||
worker = UIWorker("ui", llm=MagicMock(), active=False, prompt_guide=None)
|
||||
worker.llm.append_system_instruction.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user