Compare commits
1 Commit
mb/cli
...
aleix/llm-
| Author | SHA1 | Date | |
|---|---|---|---|
| | 7214af9a88 | | |
@@ -50,8 +50,9 @@ class FrameProcessor(BaseObject):
|
||||
self,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
metrics: Optional[FrameProcessorMetrics] = None,
|
||||
enable_process_frame_watchdog: bool = True,
|
||||
enable_watchdog_logging: Optional[bool] = None,
|
||||
metrics: Optional[FrameProcessorMetrics] = None,
|
||||
watchdog_timeout_secs: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -60,6 +61,9 @@ class FrameProcessor(BaseObject):
|
||||
self._prev: Optional["FrameProcessor"] = None
|
||||
self._next: Optional["FrameProcessor"] = None
|
||||
|
||||
# Enable watchdog timers for the frame processing task.
|
||||
self._enable_process_frame_watchdog = enable_process_frame_watchdog
|
||||
|
||||
# Enable watchdog logging for all tasks created by this frame processor.
|
||||
self._enable_watchdog_logging = enable_watchdog_logging
|
||||
|
||||
@@ -416,7 +420,8 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
(frame, direction, callback) = await self.__input_queue.get()
|
||||
try:
|
||||
self.start_watchdog()
|
||||
if self._enable_process_frame_watchdog:
|
||||
self.start_watchdog()
|
||||
# Process the frame.
|
||||
await self.process_frame(frame, direction)
|
||||
# If this frame has an associated callback, call it now.
|
||||
@@ -427,7 +432,8 @@ class FrameProcessor(BaseObject):
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
finally:
|
||||
self.__input_queue.task_done()
|
||||
self.reset_watchdog()
|
||||
if self._enable_process_frame_watchdog:
|
||||
self.reset_watchdog()
|
||||
|
||||
def __create_push_task(self):
|
||||
if not self.__push_frame_task:
|
||||
|
||||
@@ -95,7 +95,7 @@ class AnthropicLLMService(LLMService):
|
||||
client=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(enable_process_frame_watchdog=False, **kwargs)
|
||||
params = params or AnthropicLLMService.InputParams()
|
||||
self._client = client or AsyncAnthropic(
|
||||
api_key=api_key
|
||||
@@ -206,6 +206,8 @@ class AnthropicLLMService(LLMService):
|
||||
async for event in response:
|
||||
# Aggregate streaming content, create frames, trigger events
|
||||
|
||||
self.start_watchdog()
|
||||
|
||||
if event.type == "content_block_delta":
|
||||
if hasattr(event.delta, "text"):
|
||||
await self.push_frame(LLMTextFrame(event.delta.text))
|
||||
@@ -279,6 +281,8 @@ class AnthropicLLMService(LLMService):
|
||||
if total_input_tokens >= 1024:
|
||||
context.turns_above_cache_threshold += 1
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -292,6 +296,7 @@ class AnthropicLLMService(LLMService):
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
comp_tokens = (
|
||||
|
||||
@@ -540,7 +540,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
client_config: Optional[Config] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(enable_process_frame_watchdog=False, **kwargs)
|
||||
|
||||
params = params or AWSBedrockLLMService.InputParams()
|
||||
|
||||
@@ -711,6 +711,8 @@ class AWSBedrockLLMService(LLMService):
|
||||
|
||||
function_calls = []
|
||||
for event in response["stream"]:
|
||||
self.start_watchdog()
|
||||
|
||||
# Handle text content
|
||||
if "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
@@ -762,6 +764,9 @@ class AWSBedrockLLMService(LLMService):
|
||||
completion_tokens += usage.get("outputTokens", 0)
|
||||
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
|
||||
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
except asyncio.CancelledError:
|
||||
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
|
||||
@@ -774,6 +779,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
comp_tokens = (
|
||||
|
||||
@@ -475,7 +475,7 @@ class GoogleLLMService(LLMService):
|
||||
tool_config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(enable_process_frame_watchdog=False, **kwargs)
|
||||
|
||||
params = params or GoogleLLMService.InputParams()
|
||||
|
||||
@@ -558,6 +558,8 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
function_calls = []
|
||||
async for chunk in response:
|
||||
self.start_watchdog()
|
||||
|
||||
# Stop TTFB metrics after the first chunk
|
||||
await self.stop_ttfb_metrics()
|
||||
if chunk.usage_metadata:
|
||||
@@ -566,6 +568,7 @@ class GoogleLLMService(LLMService):
|
||||
total_tokens += chunk.usage_metadata.total_token_count or 0
|
||||
|
||||
if not chunk.candidates:
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
for candidate in chunk.candidates:
|
||||
@@ -626,12 +629,15 @@ class GoogleLLMService(LLMService):
|
||||
"origins": origins,
|
||||
}
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
except DeadlineExceeded:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
if grounding_metadata and isinstance(grounding_metadata, dict):
|
||||
llm_search_frame = LLMSearchResponseFrame(
|
||||
search_result=search_result,
|
||||
|
||||
@@ -54,6 +54,8 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||
)
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
self.start_watchdog()
|
||||
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
@@ -63,11 +65,13 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
@@ -100,6 +104,8 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
|
||||
@@ -77,7 +77,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(enable_process_frame_watchdog=False, **kwargs)
|
||||
|
||||
params = params or BaseOpenAILLMService.InputParams()
|
||||
|
||||
@@ -193,6 +193,8 @@ class BaseOpenAILLMService(LLMService):
|
||||
)
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
self.start_watchdog()
|
||||
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
@@ -202,11 +204,13 @@ class BaseOpenAILLMService(LLMService):
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
@@ -246,6 +250,8 @@ class BaseOpenAILLMService(LLMService):
|
||||
):
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
@@ -301,3 +307,4 @@ class BaseOpenAILLMService(LLMService):
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self.reset_watchdog()
|
||||
|
||||
@@ -95,6 +95,8 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
|
||||
)
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
self.start_watchdog()
|
||||
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
@@ -104,11 +106,13 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
@@ -148,6 +152,8 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
|
||||
):
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
|
||||
@@ -288,14 +288,15 @@ class TaskManager(BaseTaskManager):
|
||||
logger.warning(f"Unable to start watchdog timer: task {name} does not exist")
|
||||
|
||||
def reset_watchdog(self, task: asyncio.Task):
|
||||
"""Resets the given task watchdog timer. If not reset, a warning will be
|
||||
logged indicating the task is stalling.
|
||||
"""Resets the given task watchdog timer. If not reset on time, a warning
|
||||
will be logged indicating the task is stalling.
|
||||
|
||||
"""
|
||||
name = task.get_name()
|
||||
if name in self._tasks:
|
||||
self._tasks[name].watchdog_start.clear()
|
||||
self._tasks[name].watchdog_timer.set()
|
||||
if self._tasks[name].watchdog_start.is_set():
|
||||
self._tasks[name].watchdog_start.clear()
|
||||
self._tasks[name].watchdog_timer.set()
|
||||
else:
|
||||
logger.warning(f"Unable to reset watchdog timer: task {name} does not exist")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user