From 42204c4d0f4ae5acf6d25de6dd40a56a04437b48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 13 May 2026 22:46:47 -0700 Subject: [PATCH] Fix pyright errors in new bus/task/proxy code - `TaskBus._router_task`: cast the narrowed `SystemFrame` back to `BusMessage` for the subscriber callback. - `bus.network.__init__`: expose `PgmqBus` / `RedisBus` to the type-checker via a TYPE_CHECKING block so `__all__` is satisfied; runtime path still goes through `__getattr__`. - `RedisBus`: subscribe through a local before assigning `self._pubsub`, and `assert self._pubsub is not None` in the reader loop. - `BaseTask.on_job_error` accepts `BusJobResponseMessage | BusJobResponseUrgentMessage` to match what is dispatched. - `JobGroupContext.__aexit__` / `JobContext.__aexit__`: assert `self._group is not None` before `wait()`. - `@task_ready` collector: type handlers dict as `dict[str, Callable]` so the `.__name__` read on a duplicate handler typechecks. - WebSocket proxy client/server: assert the socket is set in `_receive_loop`, and decode `str` payloads to bytes before handing them to the serializer. --- src/pipecat/bus/bus.py | 3 ++- src/pipecat/bus/network/__init__.py | 6 ++++++ src/pipecat/bus/network/redis.py | 6 ++++-- src/pipecat/pipeline/base_task.py | 4 +++- src/pipecat/pipeline/job_context.py | 2 ++ src/pipecat/pipeline/task_ready_decorator.py | 4 +++- src/pipecat/tasks/proxy/websocket/client.py | 4 +++- src/pipecat/tasks/proxy/websocket/server.py | 1 + 8 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/pipecat/bus/bus.py b/src/pipecat/bus/bus.py index f8fec267b..14bb2a084 100644 --- a/src/pipecat/bus/bus.py +++ b/src/pipecat/bus/bus.py @@ -13,6 +13,7 @@ Provides the abstract `TaskBus` base class. Concrete implementations import asyncio from abc import abstractmethod from dataclasses import dataclass, field +from typing import cast from pipecat.bus.messages import BusLocalMessage, BusMessage from pipecat.bus.queue import BusMessageQueue @@ -167,7 +168,7 @@ class TaskBus(BaseObject): while True: message = await sub.queue.get() if isinstance(message, SystemFrame): - await sub.subscriber.on_bus_message(message) + await sub.subscriber.on_bus_message(cast(BusMessage, message)) else: sub.data_queue.put_nowait(message) except asyncio.CancelledError: diff --git a/src/pipecat/bus/network/__init__.py b/src/pipecat/bus/network/__init__.py index 2146207e3..4ff4848a2 100644 --- a/src/pipecat/bus/network/__init__.py +++ b/src/pipecat/bus/network/__init__.py @@ -11,6 +11,12 @@ package can be loaded with only the extras you need; importing a specific bus without its extra raises a clear error from that submodule. """ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pipecat.bus.network.pgmq import PgmqBus + from pipecat.bus.network.redis import RedisBus + __all__ = ["PgmqBus", "RedisBus"] diff --git a/src/pipecat/bus/network/redis.py b/src/pipecat/bus/network/redis.py index 784b2b5e0..55e2d55ea 100644 --- a/src/pipecat/bus/network/redis.py +++ b/src/pipecat/bus/network/redis.py @@ -68,8 +68,9 @@ class RedisBus(TaskBus): async def start(self): """Subscribe to Redis channel and start the reader task.""" await super().start() - self._pubsub = self._redis.pubsub() - await self._pubsub.subscribe(self._channel) + pubsub = self._redis.pubsub() + await pubsub.subscribe(self._channel) + self._pubsub = pubsub self._reader_task = self.create_task(self._reader_loop(), f"{self}::redis_reader") await asyncio.sleep(0) @@ -96,6 +97,7 @@ class RedisBus(TaskBus): async def _reader_loop(self) -> None: """Read messages from Redis pub/sub and deliver to subscribers.""" + assert self._pubsub is not None, "start() must be called before _reader_loop" async for raw_message in self._pubsub.listen(): if raw_message["type"] != "message": continue diff --git a/src/pipecat/pipeline/base_task.py b/src/pipecat/pipeline/base_task.py index 08a052024..ef7daeee3 100644 --- a/src/pipecat/pipeline/base_task.py +++ b/src/pipecat/pipeline/base_task.py @@ -441,7 +441,9 @@ class BaseTask(BaseObject, BusSubscriber): """Called when all workers in a job group have responded.""" pass - async def on_job_error(self, message: BusJobResponseMessage) -> None: + async def on_job_error( + self, message: BusJobResponseMessage | BusJobResponseUrgentMessage + ) -> None: """Called when a job group is cancelled due to a worker error. Fires when a worker responds with ``ERROR`` or ``FAILED`` status diff --git a/src/pipecat/pipeline/job_context.py b/src/pipecat/pipeline/job_context.py index 4cc42b1e3..222896c66 100644 --- a/src/pipecat/pipeline/job_context.py +++ b/src/pipecat/pipeline/job_context.py @@ -256,6 +256,7 @@ class JobGroupContext: ) return False + assert self._group is not None await self._group.wait() return False @@ -355,6 +356,7 @@ class JobContext: ) return False + assert self._group is not None try: await self._group.wait() except JobGroupError as e: diff --git a/src/pipecat/pipeline/task_ready_decorator.py b/src/pipecat/pipeline/task_ready_decorator.py index 982652a61..25933e32b 100644 --- a/src/pipecat/pipeline/task_ready_decorator.py +++ b/src/pipecat/pipeline/task_ready_decorator.py @@ -6,6 +6,8 @@ """Decorator for marking methods as task-ready handlers.""" +from collections.abc import Callable + def task_ready(*, name: str): """Mark a method as a handler for a specific task becoming ready. @@ -41,7 +43,7 @@ def _collect_task_ready_handlers(obj) -> dict: ValueError: If two handlers watch the same task name. """ seen: set[str] = set() - handlers: dict[str, object] = {} + handlers: dict[str, Callable] = {} for cls in type(obj).__mro__: for attr_name, val in cls.__dict__.items(): if attr_name in seen: diff --git a/src/pipecat/tasks/proxy/websocket/client.py b/src/pipecat/tasks/proxy/websocket/client.py index 56107b26e..ab7770588 100644 --- a/src/pipecat/tasks/proxy/websocket/client.py +++ b/src/pipecat/tasks/proxy/websocket/client.py @@ -183,10 +183,12 @@ class WebSocketProxyClientTask(BaseTask): async def _receive_loop(self) -> None: """Read messages from the WebSocket and put them on the local bus.""" + assert self._ws is not None, "on_activated() must run before _receive_loop" try: async for data in self._ws: try: - message = self._serializer.deserialize(data) + payload = data if isinstance(data, bytes) else data.encode() + message = self._serializer.deserialize(payload) if not message: continue diff --git a/src/pipecat/tasks/proxy/websocket/server.py b/src/pipecat/tasks/proxy/websocket/server.py index 8fa141a1d..e67ddfd8d 100644 --- a/src/pipecat/tasks/proxy/websocket/server.py +++ b/src/pipecat/tasks/proxy/websocket/server.py @@ -200,6 +200,7 @@ class WebSocketProxyServerTask(BaseTask): async def _receive_loop(self) -> None: """Read messages from the WebSocket and put them on the local bus.""" + assert self._ws is not None, "start() must run before _receive_loop" try: while True: data = await self._ws.receive_bytes()