Fix pyright errors in new bus/task/proxy code

- `TaskBus._router_task`: cast the narrowed `SystemFrame` back
  to `BusMessage` for the subscriber callback.
- `bus.network.__init__`: expose `PgmqBus` / `RedisBus` to
  the type-checker via a TYPE_CHECKING block so `__all__` is
  satisfied; runtime path still goes through `__getattr__`.
- `RedisBus`: subscribe through a local before assigning
  `self._pubsub`, and `assert self._pubsub is not None` in
  the reader loop.
- `BaseTask.on_job_error` accepts
  `BusJobResponseMessage | BusJobResponseUrgentMessage` to match
  what is dispatched.
- `JobGroupContext.__aexit__` / `JobContext.__aexit__`: assert
  `self._group is not None` before `wait()`.
- `@task_ready` collector: type handlers dict as `dict[str, Callable]`
  so the `.__name__` read on a duplicate handler typechecks.
- WebSocket proxy client/server: assert the socket is set in
  `_receive_loop`, and decode `str` payloads to bytes before
  handing them to the serializer.
This commit is contained in:
Aleix Conchillo Flaqué
2026-05-13 22:46:47 -07:00
parent 5f86e39038
commit 42204c4d0f
8 changed files with 24 additions and 6 deletions

View File

@@ -13,6 +13,7 @@ Provides the abstract `TaskBus` base class. Concrete implementations
import asyncio
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import cast
from pipecat.bus.messages import BusLocalMessage, BusMessage
from pipecat.bus.queue import BusMessageQueue
@@ -167,7 +168,7 @@ class TaskBus(BaseObject):
while True:
message = await sub.queue.get()
if isinstance(message, SystemFrame):
await sub.subscriber.on_bus_message(message)
await sub.subscriber.on_bus_message(cast(BusMessage, message))
else:
sub.data_queue.put_nowait(message)
except asyncio.CancelledError:

View File

@@ -11,6 +11,12 @@ package can be loaded with only the extras you need; importing a specific
bus without its extra raises a clear error from that submodule.
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pipecat.bus.network.pgmq import PgmqBus
from pipecat.bus.network.redis import RedisBus
__all__ = ["PgmqBus", "RedisBus"]

View File

@@ -68,8 +68,9 @@ class RedisBus(TaskBus):
async def start(self):
"""Subscribe to Redis channel and start the reader task."""
await super().start()
self._pubsub = self._redis.pubsub()
await self._pubsub.subscribe(self._channel)
pubsub = self._redis.pubsub()
await pubsub.subscribe(self._channel)
self._pubsub = pubsub
self._reader_task = self.create_task(self._reader_loop(), f"{self}::redis_reader")
await asyncio.sleep(0)
@@ -96,6 +97,7 @@ class RedisBus(TaskBus):
async def _reader_loop(self) -> None:
"""Read messages from Redis pub/sub and deliver to subscribers."""
assert self._pubsub is not None, "start() must be called before _reader_loop"
async for raw_message in self._pubsub.listen():
if raw_message["type"] != "message":
continue

View File

@@ -441,7 +441,9 @@ class BaseTask(BaseObject, BusSubscriber):
"""Called when all workers in a job group have responded."""
pass
async def on_job_error(self, message: BusJobResponseMessage) -> None:
async def on_job_error(
self, message: BusJobResponseMessage | BusJobResponseUrgentMessage
) -> None:
"""Called when a job group is cancelled due to a worker error.
Fires when a worker responds with ``ERROR`` or ``FAILED`` status

View File

@@ -256,6 +256,7 @@ class JobGroupContext:
)
return False
assert self._group is not None
await self._group.wait()
return False
@@ -355,6 +356,7 @@ class JobContext:
)
return False
assert self._group is not None
try:
await self._group.wait()
except JobGroupError as e:

View File

@@ -6,6 +6,8 @@
"""Decorator for marking methods as task-ready handlers."""
from collections.abc import Callable
def task_ready(*, name: str):
"""Mark a method as a handler for a specific task becoming ready.
@@ -41,7 +43,7 @@ def _collect_task_ready_handlers(obj) -> dict:
ValueError: If two handlers watch the same task name.
"""
seen: set[str] = set()
handlers: dict[str, object] = {}
handlers: dict[str, Callable] = {}
for cls in type(obj).__mro__:
for attr_name, val in cls.__dict__.items():
if attr_name in seen:

View File

@@ -183,10 +183,12 @@ class WebSocketProxyClientTask(BaseTask):
async def _receive_loop(self) -> None:
"""Read messages from the WebSocket and put them on the local bus."""
assert self._ws is not None, "on_activated() must run before _receive_loop"
try:
async for data in self._ws:
try:
message = self._serializer.deserialize(data)
payload = data if isinstance(data, bytes) else data.encode()
message = self._serializer.deserialize(payload)
if not message:
continue

View File

@@ -200,6 +200,7 @@ class WebSocketProxyServerTask(BaseTask):
async def _receive_loop(self) -> None:
"""Read messages from the WebSocket and put them on the local bus."""
assert self._ws is not None, "start() must run before _receive_loop"
try:
while True:
data = await self._ws.receive_bytes()