Fix pyright errors in new bus/task/proxy code
- `TaskBus._router_task`: cast the narrowed `SystemFrame` back to `BusMessage` for the subscriber callback. - `bus.network.__init__`: expose `PgmqBus` / `RedisBus` to the type-checker via a TYPE_CHECKING block so `__all__` is satisfied; runtime path still goes through `__getattr__`. - `RedisBus`: subscribe through a local before assigning `self._pubsub`, and `assert self._pubsub is not None` in the reader loop. - `BaseTask.on_job_error` accepts `BusJobResponseMessage | BusJobResponseUrgentMessage` to match what is dispatched. - `JobGroupContext.__aexit__` / `JobContext.__aexit__`: assert `self._group is not None` before `wait()`. - `@task_ready` collector: type handlers dict as `dict[str, Callable]` so the `.__name__` read on a duplicate handler typechecks. - WebSocket proxy client/server: assert the socket is set in `_receive_loop`, and decode `str` payloads to bytes before handing them to the serializer.
This commit is contained in:
@@ -13,6 +13,7 @@ Provides the abstract `TaskBus` base class. Concrete implementations
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import cast
|
||||
|
||||
from pipecat.bus.messages import BusLocalMessage, BusMessage
|
||||
from pipecat.bus.queue import BusMessageQueue
|
||||
@@ -167,7 +168,7 @@ class TaskBus(BaseObject):
|
||||
while True:
|
||||
message = await sub.queue.get()
|
||||
if isinstance(message, SystemFrame):
|
||||
await sub.subscriber.on_bus_message(message)
|
||||
await sub.subscriber.on_bus_message(cast(BusMessage, message))
|
||||
else:
|
||||
sub.data_queue.put_nowait(message)
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -11,6 +11,12 @@ package can be loaded with only the extras you need; importing a specific
|
||||
bus without its extra raises a clear error from that submodule.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.bus.network.pgmq import PgmqBus
|
||||
from pipecat.bus.network.redis import RedisBus
|
||||
|
||||
__all__ = ["PgmqBus", "RedisBus"]
|
||||
|
||||
|
||||
|
||||
@@ -68,8 +68,9 @@ class RedisBus(TaskBus):
|
||||
async def start(self):
|
||||
"""Subscribe to Redis channel and start the reader task."""
|
||||
await super().start()
|
||||
self._pubsub = self._redis.pubsub()
|
||||
await self._pubsub.subscribe(self._channel)
|
||||
pubsub = self._redis.pubsub()
|
||||
await pubsub.subscribe(self._channel)
|
||||
self._pubsub = pubsub
|
||||
self._reader_task = self.create_task(self._reader_loop(), f"{self}::redis_reader")
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@@ -96,6 +97,7 @@ class RedisBus(TaskBus):
|
||||
|
||||
async def _reader_loop(self) -> None:
|
||||
"""Read messages from Redis pub/sub and deliver to subscribers."""
|
||||
assert self._pubsub is not None, "start() must be called before _reader_loop"
|
||||
async for raw_message in self._pubsub.listen():
|
||||
if raw_message["type"] != "message":
|
||||
continue
|
||||
|
||||
@@ -441,7 +441,9 @@ class BaseTask(BaseObject, BusSubscriber):
|
||||
"""Called when all workers in a job group have responded."""
|
||||
pass
|
||||
|
||||
async def on_job_error(self, message: BusJobResponseMessage) -> None:
|
||||
async def on_job_error(
|
||||
self, message: BusJobResponseMessage | BusJobResponseUrgentMessage
|
||||
) -> None:
|
||||
"""Called when a job group is cancelled due to a worker error.
|
||||
|
||||
Fires when a worker responds with ``ERROR`` or ``FAILED`` status
|
||||
|
||||
@@ -256,6 +256,7 @@ class JobGroupContext:
|
||||
)
|
||||
return False
|
||||
|
||||
assert self._group is not None
|
||||
await self._group.wait()
|
||||
return False
|
||||
|
||||
@@ -355,6 +356,7 @@ class JobContext:
|
||||
)
|
||||
return False
|
||||
|
||||
assert self._group is not None
|
||||
try:
|
||||
await self._group.wait()
|
||||
except JobGroupError as e:
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
"""Decorator for marking methods as task-ready handlers."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def task_ready(*, name: str):
|
||||
"""Mark a method as a handler for a specific task becoming ready.
|
||||
@@ -41,7 +43,7 @@ def _collect_task_ready_handlers(obj) -> dict:
|
||||
ValueError: If two handlers watch the same task name.
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
handlers: dict[str, object] = {}
|
||||
handlers: dict[str, Callable] = {}
|
||||
for cls in type(obj).__mro__:
|
||||
for attr_name, val in cls.__dict__.items():
|
||||
if attr_name in seen:
|
||||
|
||||
@@ -183,10 +183,12 @@ class WebSocketProxyClientTask(BaseTask):
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Read messages from the WebSocket and put them on the local bus."""
|
||||
assert self._ws is not None, "on_activated() must run before _receive_loop"
|
||||
try:
|
||||
async for data in self._ws:
|
||||
try:
|
||||
message = self._serializer.deserialize(data)
|
||||
payload = data if isinstance(data, bytes) else data.encode()
|
||||
message = self._serializer.deserialize(payload)
|
||||
if not message:
|
||||
continue
|
||||
|
||||
|
||||
@@ -200,6 +200,7 @@ class WebSocketProxyServerTask(BaseTask):
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Read messages from the WebSocket and put them on the local bus."""
|
||||
assert self._ws is not None, "start() must run before _receive_loop"
|
||||
try:
|
||||
while True:
|
||||
data = await self._ws.receive_bytes()
|
||||
|
||||
Reference in New Issue
Block a user