diff --git a/src/pipecat/bus/bus.py b/src/pipecat/bus/bus.py index f8fec267b..14bb2a084 100644 --- a/src/pipecat/bus/bus.py +++ b/src/pipecat/bus/bus.py @@ -13,6 +13,7 @@ Provides the abstract `TaskBus` base class. Concrete implementations import asyncio from abc import abstractmethod from dataclasses import dataclass, field +from typing import cast from pipecat.bus.messages import BusLocalMessage, BusMessage from pipecat.bus.queue import BusMessageQueue @@ -167,7 +168,7 @@ class TaskBus(BaseObject): while True: message = await sub.queue.get() if isinstance(message, SystemFrame): - await sub.subscriber.on_bus_message(message) + await sub.subscriber.on_bus_message(cast(BusMessage, message)) else: sub.data_queue.put_nowait(message) except asyncio.CancelledError: diff --git a/src/pipecat/bus/network/__init__.py b/src/pipecat/bus/network/__init__.py index 2146207e3..4ff4848a2 100644 --- a/src/pipecat/bus/network/__init__.py +++ b/src/pipecat/bus/network/__init__.py @@ -11,6 +11,12 @@ package can be loaded with only the extras you need; importing a specific bus without its extra raises a clear error from that submodule. """ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pipecat.bus.network.pgmq import PgmqBus + from pipecat.bus.network.redis import RedisBus + __all__ = ["PgmqBus", "RedisBus"] diff --git a/src/pipecat/bus/network/redis.py b/src/pipecat/bus/network/redis.py index 784b2b5e0..55e2d55ea 100644 --- a/src/pipecat/bus/network/redis.py +++ b/src/pipecat/bus/network/redis.py @@ -68,8 +68,9 @@ class RedisBus(TaskBus): async def start(self): """Subscribe to Redis channel and start the reader task.""" await super().start() - self._pubsub = self._redis.pubsub() - await self._pubsub.subscribe(self._channel) + pubsub = self._redis.pubsub() + await pubsub.subscribe(self._channel) + self._pubsub = pubsub self._reader_task = self.create_task(self._reader_loop(), f"{self}::redis_reader") await asyncio.sleep(0) @@ -96,6 +97,7 @@ class RedisBus(TaskBus): async def _reader_loop(self) -> None: """Read messages from Redis pub/sub and deliver to subscribers.""" + assert self._pubsub is not None, "start() must be called before _reader_loop" async for raw_message in self._pubsub.listen(): if raw_message["type"] != "message": continue diff --git a/src/pipecat/pipeline/base_task.py b/src/pipecat/pipeline/base_task.py index 08a052024..ef7daeee3 100644 --- a/src/pipecat/pipeline/base_task.py +++ b/src/pipecat/pipeline/base_task.py @@ -441,7 +441,9 @@ class BaseTask(BaseObject, BusSubscriber): """Called when all workers in a job group have responded.""" pass - async def on_job_error(self, message: BusJobResponseMessage) -> None: + async def on_job_error( + self, message: BusJobResponseMessage | BusJobResponseUrgentMessage + ) -> None: """Called when a job group is cancelled due to a worker error. Fires when a worker responds with ``ERROR`` or ``FAILED`` status diff --git a/src/pipecat/pipeline/job_context.py b/src/pipecat/pipeline/job_context.py index 4cc42b1e3..222896c66 100644 --- a/src/pipecat/pipeline/job_context.py +++ b/src/pipecat/pipeline/job_context.py @@ -256,6 +256,7 @@ class JobGroupContext: ) return False + assert self._group is not None await self._group.wait() return False @@ -355,6 +356,7 @@ class JobContext: ) return False + assert self._group is not None try: await self._group.wait() except JobGroupError as e: diff --git a/src/pipecat/pipeline/task_ready_decorator.py b/src/pipecat/pipeline/task_ready_decorator.py index 982652a61..25933e32b 100644 --- a/src/pipecat/pipeline/task_ready_decorator.py +++ b/src/pipecat/pipeline/task_ready_decorator.py @@ -6,6 +6,8 @@ """Decorator for marking methods as task-ready handlers.""" +from collections.abc import Callable + def task_ready(*, name: str): """Mark a method as a handler for a specific task becoming ready. @@ -41,7 +43,7 @@ def _collect_task_ready_handlers(obj) -> dict: ValueError: If two handlers watch the same task name. """ seen: set[str] = set() - handlers: dict[str, object] = {} + handlers: dict[str, Callable] = {} for cls in type(obj).__mro__: for attr_name, val in cls.__dict__.items(): if attr_name in seen: diff --git a/src/pipecat/tasks/proxy/websocket/client.py b/src/pipecat/tasks/proxy/websocket/client.py index 56107b26e..ab7770588 100644 --- a/src/pipecat/tasks/proxy/websocket/client.py +++ b/src/pipecat/tasks/proxy/websocket/client.py @@ -183,10 +183,12 @@ class WebSocketProxyClientTask(BaseTask): async def _receive_loop(self) -> None: """Read messages from the WebSocket and put them on the local bus.""" + assert self._ws is not None, "on_activated() must run before _receive_loop" try: async for data in self._ws: try: - message = self._serializer.deserialize(data) + payload = data if isinstance(data, bytes) else data.encode() + message = self._serializer.deserialize(payload) if not message: continue diff --git a/src/pipecat/tasks/proxy/websocket/server.py b/src/pipecat/tasks/proxy/websocket/server.py index 8fa141a1d..e67ddfd8d 100644 --- a/src/pipecat/tasks/proxy/websocket/server.py +++ b/src/pipecat/tasks/proxy/websocket/server.py @@ -200,6 +200,7 @@ class WebSocketProxyServerTask(BaseTask): async def _receive_loop(self) -> None: """Read messages from the WebSocket and put them on the local bus.""" + assert self._ws is not None, "start() must run before _receive_loop" try: while True: data = await self._ws.receive_bytes()