Compare commits

...

1 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
7214af9a88 allow LLM services to manage watchdog timers 2025-06-24 16:43:10 -07:00
8 changed files with 54 additions and 11 deletions

View File

@@ -50,8 +50,9 @@ class FrameProcessor(BaseObject):
self,
*,
name: Optional[str] = None,
metrics: Optional[FrameProcessorMetrics] = None,
enable_process_frame_watchdog: bool = True,
enable_watchdog_logging: Optional[bool] = None,
metrics: Optional[FrameProcessorMetrics] = None,
watchdog_timeout_secs: Optional[float] = None,
**kwargs,
):
@@ -60,6 +61,9 @@ class FrameProcessor(BaseObject):
self._prev: Optional["FrameProcessor"] = None
self._next: Optional["FrameProcessor"] = None
# Enable watchdog timers for the frame processing task.
self._enable_process_frame_watchdog = enable_process_frame_watchdog
# Enable watchdog logging for all tasks created by this frame processor.
self._enable_watchdog_logging = enable_watchdog_logging
@@ -416,7 +420,8 @@ class FrameProcessor(BaseObject):
(frame, direction, callback) = await self.__input_queue.get()
try:
self.start_watchdog()
if self._enable_process_frame_watchdog:
self.start_watchdog()
# Process the frame.
await self.process_frame(frame, direction)
# If this frame has an associated callback, call it now.
@@ -427,7 +432,8 @@ class FrameProcessor(BaseObject):
await self.push_error(ErrorFrame(str(e)))
finally:
self.__input_queue.task_done()
self.reset_watchdog()
if self._enable_process_frame_watchdog:
self.reset_watchdog()
def __create_push_task(self):
if not self.__push_frame_task:

View File

@@ -95,7 +95,7 @@ class AnthropicLLMService(LLMService):
client=None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(enable_process_frame_watchdog=False, **kwargs)
params = params or AnthropicLLMService.InputParams()
self._client = client or AsyncAnthropic(
api_key=api_key
@@ -206,6 +206,8 @@ class AnthropicLLMService(LLMService):
async for event in response:
# Aggregate streaming content, create frames, trigger events
self.start_watchdog()
if event.type == "content_block_delta":
if hasattr(event.delta, "text"):
await self.push_frame(LLMTextFrame(event.delta.text))
@@ -279,6 +281,8 @@ class AnthropicLLMService(LLMService):
if total_input_tokens >= 1024:
context.turns_above_cache_threshold += 1
self.reset_watchdog()
await self.run_function_calls(function_calls)
except asyncio.CancelledError:
@@ -292,6 +296,7 @@ class AnthropicLLMService(LLMService):
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
self.reset_watchdog()
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
comp_tokens = (

View File

@@ -540,7 +540,7 @@ class AWSBedrockLLMService(LLMService):
client_config: Optional[Config] = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(enable_process_frame_watchdog=False, **kwargs)
params = params or AWSBedrockLLMService.InputParams()
@@ -711,6 +711,8 @@ class AWSBedrockLLMService(LLMService):
function_calls = []
for event in response["stream"]:
self.start_watchdog()
# Handle text content
if "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
@@ -762,6 +764,9 @@ class AWSBedrockLLMService(LLMService):
completion_tokens += usage.get("outputTokens", 0)
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
self.reset_watchdog()
await self.run_function_calls(function_calls)
except asyncio.CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
@@ -774,6 +779,7 @@ class AWSBedrockLLMService(LLMService):
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
self.reset_watchdog()
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
comp_tokens = (

View File

@@ -475,7 +475,7 @@ class GoogleLLMService(LLMService):
tool_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(enable_process_frame_watchdog=False, **kwargs)
params = params or GoogleLLMService.InputParams()
@@ -558,6 +558,8 @@ class GoogleLLMService(LLMService):
function_calls = []
async for chunk in response:
self.start_watchdog()
# Stop TTFB metrics after the first chunk
await self.stop_ttfb_metrics()
if chunk.usage_metadata:
@@ -566,6 +568,7 @@ class GoogleLLMService(LLMService):
total_tokens += chunk.usage_metadata.total_token_count or 0
if not chunk.candidates:
self.reset_watchdog()
continue
for candidate in chunk.candidates:
@@ -626,12 +629,15 @@ class GoogleLLMService(LLMService):
"origins": origins,
}
self.reset_watchdog()
await self.run_function_calls(function_calls)
except DeadlineExceeded:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
self.reset_watchdog()
if grounding_metadata and isinstance(grounding_metadata, dict):
llm_search_frame = LLMSearchResponseFrame(
search_result=search_result,

View File

@@ -54,6 +54,8 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
)
async for chunk in chunk_stream:
self.start_watchdog()
if chunk.usage:
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
@@ -63,11 +65,13 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
await self.start_llm_usage_metrics(tokens)
if chunk.choices is None or len(chunk.choices) == 0:
self.reset_watchdog()
continue
await self.stop_ttfb_metrics()
if not chunk.choices[0].delta:
self.reset_watchdog()
continue
if chunk.choices[0].delta.tool_calls:
@@ -100,6 +104,8 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
elif chunk.choices[0].delta.content:
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
self.reset_watchdog()
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to
# the context, and re-prompt to get a chat answer. If we don't have a registered

View File

@@ -77,7 +77,7 @@ class BaseOpenAILLMService(LLMService):
params: Optional[InputParams] = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(enable_process_frame_watchdog=False, **kwargs)
params = params or BaseOpenAILLMService.InputParams()
@@ -193,6 +193,8 @@ class BaseOpenAILLMService(LLMService):
)
async for chunk in chunk_stream:
self.start_watchdog()
if chunk.usage:
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
@@ -202,11 +204,13 @@ class BaseOpenAILLMService(LLMService):
await self.start_llm_usage_metrics(tokens)
if chunk.choices is None or len(chunk.choices) == 0:
self.reset_watchdog()
continue
await self.stop_ttfb_metrics()
if not chunk.choices[0].delta:
self.reset_watchdog()
continue
if chunk.choices[0].delta.tool_calls:
@@ -246,6 +250,8 @@ class BaseOpenAILLMService(LLMService):
):
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
self.reset_watchdog()
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to
# the context, and re-prompt to get a chat answer. If we don't have a registered
@@ -301,3 +307,4 @@ class BaseOpenAILLMService(LLMService):
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
self.reset_watchdog()

View File

@@ -95,6 +95,8 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
)
async for chunk in chunk_stream:
self.start_watchdog()
if chunk.usage:
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
@@ -104,11 +106,13 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
await self.start_llm_usage_metrics(tokens)
if chunk.choices is None or len(chunk.choices) == 0:
self.reset_watchdog()
continue
await self.stop_ttfb_metrics()
if not chunk.choices[0].delta:
self.reset_watchdog()
continue
if chunk.choices[0].delta.tool_calls:
@@ -148,6 +152,8 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
):
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
self.reset_watchdog()
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to
# the context, and re-prompt to get a chat answer. If we don't have a registered

View File

@@ -288,14 +288,15 @@ class TaskManager(BaseTaskManager):
logger.warning(f"Unable to start watchdog timer: task {name} does not exist")
def reset_watchdog(self, task: asyncio.Task):
"""Resets the given task watchdog timer. If not reset, a warning will be
logged indicating the task is stalling.
"""Resets the given task watchdog timer. If not reset on time, a warning
will be logged indicating the task is stalling.
"""
name = task.get_name()
if name in self._tasks:
self._tasks[name].watchdog_start.clear()
self._tasks[name].watchdog_timer.set()
if self._tasks[name].watchdog_start.is_set():
self._tasks[name].watchdog_start.clear()
self._tasks[name].watchdog_timer.set()
else:
logger.warning(f"Unable to reset watchdog timer: task {name} does not exist")