code-assistant: work around CancelledError swallow in ClaudeSDKClient
claude_agent_sdk's _AsyncioTaskHandle.wait() wraps its await in `with suppress(asyncio.CancelledError)` to silence the expected cancellation of the inner read task. However, if the outer task's cancellation is delivered at that same await, it is swallowed as well, which causes cancel_task to time out. To work around this, bypass `async with ClaudeSDKClient` and drive connect()/disconnect() ourselves, so that disconnect() runs in a `finally` block: by then the outer CancelledError has already been raised and is held by Python's exception machinery, out of reach of the SDK's suppress.
This commit is contained in:
@@ -10,11 +10,7 @@ import asyncio
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.bus import (
|
||||
BusCancelTaskMessage,
|
||||
BusEndTaskMessage,
|
||||
BusJobRequestMessage,
|
||||
)
|
||||
from pipecat.bus import BusJobRequestMessage
|
||||
from pipecat.pipeline.base_task import BaseTask
|
||||
from pipecat.pipeline.job_context import JobStatus
|
||||
|
||||
@@ -66,7 +62,7 @@ class CodeWorker(BaseTask):
|
||||
async def start(self) -> None:
|
||||
"""Launch the Claude SDK worker loop alongside the standard task start."""
|
||||
await super().start()
|
||||
self._worker_task = self.create_task(self._worker_loop(), f"{self.name}::worker")
|
||||
self._worker_task = self.create_task(self._worker_loop(), "worker")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Cancel the worker loop before tearing down the task."""
|
||||
@@ -81,40 +77,42 @@ class CodeWorker(BaseTask):
|
||||
logger.info(f"Worker '{self.name}': queued '{message.payload['question']}'")
|
||||
self._queue.put_nowait(message)
|
||||
|
||||
async def _handle_task_end(self, message: BusEndTaskMessage) -> None:
|
||||
"""Signal the run loop to finish on a graceful end."""
|
||||
await super()._handle_task_end(message)
|
||||
self._finished_event.set()
|
||||
|
||||
async def _handle_task_cancel(self, message: BusCancelTaskMessage) -> None:
|
||||
"""Signal the run loop to finish on cancellation."""
|
||||
await super()._handle_task_cancel(message)
|
||||
self._finished_event.set()
|
||||
|
||||
async def _worker_loop(self):
|
||||
client = ClaudeSDKClient(options=self._claude_options)
|
||||
try:
|
||||
async with ClaudeSDKClient(options=self._claude_options) as client:
|
||||
while True:
|
||||
message = await self._queue.get()
|
||||
question = message.payload["question"]
|
||||
logger.info(f"Worker '{self.name}': researching '{question}'")
|
||||
|
||||
try:
|
||||
answer = ""
|
||||
await client.query(prompt=question)
|
||||
async for msg in client.receive_response():
|
||||
if type(msg).__name__ == "AssistantMessage":
|
||||
for block in msg.content:
|
||||
if type(block).__name__ == "TextBlock":
|
||||
answer += block.text
|
||||
|
||||
logger.info(f"Worker '{self.name}': completed ({len(answer)} chars)")
|
||||
await self.send_job_response(message.job_id, {"answer": answer})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Worker '{self.name}': error: {e}")
|
||||
await self.send_job_response(
|
||||
message.job_id, {"error": str(e)}, status=JobStatus.ERROR
|
||||
)
|
||||
await client.connect()
|
||||
except Exception as e:
|
||||
logger.error(f"Worker '{self.name}': failed to start Claude SDK: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await self._queue.get()
|
||||
question = message.payload["question"]
|
||||
logger.info(f"Worker '{self.name}': researching '{question}'")
|
||||
|
||||
try:
|
||||
answer = ""
|
||||
await client.query(prompt=question)
|
||||
async for msg in client.receive_response():
|
||||
if type(msg).__name__ == "AssistantMessage":
|
||||
for block in msg.content:
|
||||
if type(block).__name__ == "TextBlock":
|
||||
answer += block.text
|
||||
|
||||
logger.info(f"Worker '{self.name}': completed ({len(answer)} chars)")
|
||||
await self.send_job_response(message.job_id, {"answer": answer})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Worker '{self.name}': error: {e}")
|
||||
await self.send_job_response(
|
||||
message.job_id, {"error": str(e)}, status=JobStatus.ERROR
|
||||
)
|
||||
finally:
|
||||
# Bypass `async with ClaudeSDKClient` and call disconnect()
|
||||
# ourselves: __aexit__ → Query.close() → _read_task.wait() uses
|
||||
# `with suppress(asyncio.CancelledError)`, which would swallow the
|
||||
# outer task's cancellation. By the time this finally runs, our
|
||||
# CancelledError has already been raised once, so _must_cancel is
|
||||
# cleared and disconnect()'s awaits proceed normally.
|
||||
await client.disconnect()
|
||||
|
||||
Reference in New Issue
Block a user