Replaces every "task" identifier that referred to the BaseTask abstraction with "worker". Asyncio task plumbing (asyncio.Task, BaseTaskManager, TaskManager, create_task, cancel_task, etc.) stays untouched. Highlights: - Classes: BaseTask → BaseWorker, PipelineTask → PipelineWorker, LLMTask → LLMWorker, LLMContextTask → LLMContextWorker, TaskBus → WorkerBus, TaskRegistry → WorkerRegistry, TaskActivationArgs → WorkerActivationArgs, TaskReadyData → WorkerReadyData, TaskRegistryEntry → WorkerRegistryEntry, TaskObserver → WorkerObserver, all Bus*TaskMessage → Bus*WorkerMessage, BusAddTaskMessage.task field → worker, BusWorkerRegistryMessage.tasks field → workers. - Methods/decorators: activate_task → activate_worker, deactivate_task → deactivate_worker, add_task → add_worker, watch_task → watch_worker, @task_ready → @worker_ready, setup_pipeline_task hook → setup_pipeline_worker. - Params/fields: FrameProcessorSetup.pipeline_task and FunctionCallParams.pipeline_task → pipeline_worker. Parameter names like task_name → worker_name; spawn/run accept worker:. - Files: pipeline/base_task.py → base_worker.py, pipeline/task.py → worker.py (plus a re-export shim at pipeline/task.py), task_observer.py → worker_observer.py, task_ready_decorator.py → worker_ready_decorator.py, pipecat.tasks → pipecat.workers, llm_task.py → llm_worker.py, llm_context_task.py → llm_context_worker.py, examples/multi-task → examples/multi-worker. Back-compat: - PipelineTask kept as a deprecated subclass of PipelineWorker that warns on construction. - pipecat.pipeline.task re-exports PipelineWorker/PipelineTask/etc. so existing user imports keep working. - FrameProcessor.pipeline_task kept as a deprecated property that forwards to pipeline_worker. Local variables in examples that hold a worker (task = PipelineTask(...)) are renamed to worker = PipelineWorker(...). Asyncio-task locals (runner_task, etc.) are preserved.
319 lines
10 KiB
Python
319 lines
10 KiB
Python
#
|
|
# Copyright (c) 2026, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
import asyncio
|
|
import unittest
|
|
from unittest.mock import MagicMock
|
|
|
|
from pipecat.bus import (
|
|
AsyncQueueBus,
|
|
BusAddWorkerMessage,
|
|
BusDataMessage,
|
|
)
|
|
from pipecat.bus.serializers import JSONMessageSerializer
|
|
from pipecat.pipeline.base_worker import BaseWorker
|
|
from pipecat.registry import WorkerRegistry
|
|
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
|
|
|
|
|
async def create_test_bus():
|
|
"""Create an AsyncQueueBus with a TaskManager for testing."""
|
|
bus = AsyncQueueBus()
|
|
tm = TaskManager()
|
|
tm.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
|
await bus.setup(tm)
|
|
return bus, tm
|
|
|
|
|
|
class FakeWebSocket:
|
|
"""Fake WebSocket for testing the client proxy."""
|
|
|
|
def __init__(self):
|
|
self._sent: list[bytes] = []
|
|
self._receive_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
|
self.closed = False
|
|
|
|
async def send(self, data: bytes):
|
|
self._sent.append(data)
|
|
|
|
async def close(self):
|
|
self.closed = True
|
|
|
|
def __aiter__(self):
|
|
return self
|
|
|
|
async def __anext__(self):
|
|
try:
|
|
return await asyncio.wait_for(self._receive_queue.get(), timeout=0.5)
|
|
except TimeoutError:
|
|
raise StopAsyncIteration
|
|
|
|
def inject(self, data: bytes):
|
|
"""Inject data as if received from the remote side."""
|
|
self._receive_queue.put_nowait(data)
|
|
|
|
|
|
class FakeStarletteWebSocket:
|
|
"""Fake Starlette WebSocket for testing the server proxy."""
|
|
|
|
def __init__(self):
|
|
self._sent: list[bytes] = []
|
|
self._receive_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
|
self.closed = False
|
|
|
|
async def send_bytes(self, data: bytes):
|
|
self._sent.append(data)
|
|
|
|
async def receive_bytes(self) -> bytes:
|
|
return await self._receive_queue.get()
|
|
|
|
async def close(self):
|
|
self.closed = True
|
|
|
|
@property
|
|
def client_state(self):
|
|
return MagicMock(value=1) if not self.closed else MagicMock(value=3)
|
|
|
|
def inject(self, data: bytes):
|
|
"""Inject data as if received from the remote client."""
|
|
self._receive_queue.put_nowait(data)
|
|
|
|
|
|
class TestWebSocketProxyClientTask(unittest.IsolatedAsyncioTestCase):
|
|
async def asyncSetUp(self):
|
|
self.bus, self.tm = await create_test_bus()
|
|
self.registry = WorkerRegistry(runner_name="test-runner")
|
|
self.serializer = JSONMessageSerializer()
|
|
|
|
async def _create_client(self, fake_ws):
|
|
from pipecat.workers.proxy.websocket.client import WebSocketProxyClientTask
|
|
|
|
task = WebSocketProxyClientTask(
|
|
"proxy",
|
|
url="ws://fake",
|
|
remote_worker_name="worker",
|
|
local_worker_name="voice",
|
|
serializer=self.serializer,
|
|
)
|
|
task.attach(registry=self.registry, bus=self.bus)
|
|
await task.setup(self.tm)
|
|
task._ws = fake_ws
|
|
return task
|
|
|
|
async def test_forwards_targeted_messages(self):
|
|
"""Messages targeted at the remote task are forwarded."""
|
|
fake_ws = FakeWebSocket()
|
|
task = await self._create_client(fake_ws)
|
|
|
|
msg = BusDataMessage(source="voice", target="worker")
|
|
await task.on_bus_message(msg)
|
|
|
|
self.assertEqual(len(fake_ws._sent), 1)
|
|
restored = self.serializer.deserialize(fake_ws._sent[0])
|
|
self.assertEqual(restored.source, "voice")
|
|
self.assertEqual(restored.target, "worker")
|
|
|
|
async def test_skips_messages_for_other_tasks(self):
|
|
"""Messages targeted at other tasks are not forwarded."""
|
|
fake_ws = FakeWebSocket()
|
|
task = await self._create_client(fake_ws)
|
|
|
|
msg = BusDataMessage(source="voice", target="other_task")
|
|
await task.on_bus_message(msg)
|
|
|
|
self.assertEqual(len(fake_ws._sent), 0)
|
|
|
|
async def test_skips_broadcast_messages(self):
|
|
"""Broadcast messages (target=None) are not forwarded."""
|
|
fake_ws = FakeWebSocket()
|
|
task = await self._create_client(fake_ws)
|
|
|
|
msg = BusDataMessage(source="voice")
|
|
await task.on_bus_message(msg)
|
|
|
|
self.assertEqual(len(fake_ws._sent), 0)
|
|
|
|
async def test_skips_local_messages(self):
|
|
"""BusLocalMessage messages are not forwarded."""
|
|
fake_ws = FakeWebSocket()
|
|
task = await self._create_client(fake_ws)
|
|
|
|
stub = MagicMock(spec=BaseWorker)
|
|
stub.name = "child"
|
|
msg = BusAddWorkerMessage(source="parent", target="worker", worker=stub)
|
|
await task.on_bus_message(msg)
|
|
|
|
self.assertEqual(len(fake_ws._sent), 0)
|
|
|
|
async def test_accepts_inbound_for_local_task(self):
|
|
"""Inbound messages targeted at the local task are accepted."""
|
|
fake_ws = FakeWebSocket()
|
|
task = await self._create_client(fake_ws)
|
|
|
|
sent_to_bus = []
|
|
original_send = task.send_bus_message
|
|
|
|
async def capture_send(message):
|
|
sent_to_bus.append(message)
|
|
await original_send(message)
|
|
|
|
task.send_bus_message = capture_send
|
|
|
|
msg = BusDataMessage(source="worker", target="voice")
|
|
fake_ws.inject(self.serializer.serialize(msg))
|
|
|
|
receive_task = asyncio.create_task(task._receive_loop())
|
|
await asyncio.sleep(0.1)
|
|
receive_task.cancel()
|
|
try:
|
|
await receive_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
|
|
self.assertEqual(len(sent_to_bus), 1)
|
|
self.assertEqual(sent_to_bus[0].source, "worker")
|
|
self.assertEqual(sent_to_bus[0].target, "voice")
|
|
|
|
async def test_drops_inbound_for_other_tasks(self):
|
|
"""Inbound messages targeted at other tasks are dropped."""
|
|
fake_ws = FakeWebSocket()
|
|
task = await self._create_client(fake_ws)
|
|
|
|
sent_to_bus = []
|
|
original_send = task.send_bus_message
|
|
|
|
async def capture_send(message):
|
|
sent_to_bus.append(message)
|
|
await original_send(message)
|
|
|
|
task.send_bus_message = capture_send
|
|
|
|
msg = BusDataMessage(source="worker", target="other_task")
|
|
fake_ws.inject(self.serializer.serialize(msg))
|
|
|
|
receive_task = asyncio.create_task(task._receive_loop())
|
|
await asyncio.sleep(0.1)
|
|
receive_task.cancel()
|
|
try:
|
|
await receive_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
|
|
self.assertEqual(len(sent_to_bus), 0)
|
|
|
|
|
|
class TestWebSocketProxyServerTask(unittest.IsolatedAsyncioTestCase):
|
|
async def asyncSetUp(self):
|
|
self.bus, self.tm = await create_test_bus()
|
|
self.registry = WorkerRegistry(runner_name="test-runner")
|
|
self.serializer = JSONMessageSerializer()
|
|
|
|
async def _create_server(self, fake_ws):
|
|
from pipecat.workers.proxy.websocket.server import WebSocketProxyServerTask
|
|
|
|
task = WebSocketProxyServerTask(
|
|
"gateway",
|
|
websocket=fake_ws,
|
|
worker_name="worker",
|
|
remote_worker_name="voice",
|
|
serializer=self.serializer,
|
|
)
|
|
task.attach(registry=self.registry, bus=self.bus)
|
|
await task.setup(self.tm)
|
|
return task
|
|
|
|
async def test_forwards_messages_from_local_task(self):
|
|
"""Messages from the local task targeted at the remote task are forwarded."""
|
|
fake_ws = FakeStarletteWebSocket()
|
|
task = await self._create_server(fake_ws)
|
|
|
|
msg = BusDataMessage(source="worker", target="voice")
|
|
await task.on_bus_message(msg)
|
|
|
|
self.assertEqual(len(fake_ws._sent), 1)
|
|
restored = self.serializer.deserialize(fake_ws._sent[0])
|
|
self.assertEqual(restored.source, "worker")
|
|
self.assertEqual(restored.target, "voice")
|
|
|
|
async def test_skips_messages_from_other_tasks(self):
|
|
"""Messages from other tasks are not forwarded."""
|
|
fake_ws = FakeStarletteWebSocket()
|
|
task = await self._create_server(fake_ws)
|
|
|
|
msg = BusDataMessage(source="other_task", target="voice")
|
|
await task.on_bus_message(msg)
|
|
|
|
self.assertEqual(len(fake_ws._sent), 0)
|
|
|
|
async def test_skips_messages_to_other_targets(self):
|
|
"""Messages from the local task to other targets are not forwarded."""
|
|
fake_ws = FakeStarletteWebSocket()
|
|
task = await self._create_server(fake_ws)
|
|
|
|
msg = BusDataMessage(source="worker", target="other_task")
|
|
await task.on_bus_message(msg)
|
|
|
|
self.assertEqual(len(fake_ws._sent), 0)
|
|
|
|
async def test_accepts_inbound_for_local_task(self):
|
|
"""Inbound messages targeted at the local task are accepted."""
|
|
fake_ws = FakeStarletteWebSocket()
|
|
task = await self._create_server(fake_ws)
|
|
|
|
sent_to_bus = []
|
|
original_send = task.send_bus_message
|
|
|
|
async def capture_send(message):
|
|
sent_to_bus.append(message)
|
|
await original_send(message)
|
|
|
|
task.send_bus_message = capture_send
|
|
|
|
msg = BusDataMessage(source="voice", target="worker")
|
|
fake_ws.inject(self.serializer.serialize(msg))
|
|
|
|
receive_task = asyncio.create_task(task._receive_loop())
|
|
await asyncio.sleep(0.1)
|
|
receive_task.cancel()
|
|
try:
|
|
await receive_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
|
|
self.assertEqual(len(sent_to_bus), 1)
|
|
self.assertEqual(sent_to_bus[0].target, "worker")
|
|
|
|
async def test_drops_inbound_for_other_tasks(self):
|
|
"""Inbound messages targeted at other tasks are dropped."""
|
|
fake_ws = FakeStarletteWebSocket()
|
|
task = await self._create_server(fake_ws)
|
|
|
|
sent_to_bus = []
|
|
original_send = task.send_bus_message
|
|
|
|
async def capture_send(message):
|
|
sent_to_bus.append(message)
|
|
await original_send(message)
|
|
|
|
task.send_bus_message = capture_send
|
|
|
|
msg = BusDataMessage(source="voice", target="other_task")
|
|
fake_ws.inject(self.serializer.serialize(msg))
|
|
|
|
receive_task = asyncio.create_task(task._receive_loop())
|
|
await asyncio.sleep(0.1)
|
|
receive_task.cancel()
|
|
try:
|
|
await receive_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
|
|
self.assertEqual(len(sent_to_bus), 0)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|