okay, both files now
This commit is contained in:
@@ -79,15 +79,21 @@ class AnthropicLLMService(LLMService):
|
||||
api_key: str,
|
||||
model: str = "claude-3-5-sonnet-20240620",
|
||||
max_tokens: int = 4096,
|
||||
enable_prompt_caching_beta: bool = False,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._client = AsyncAnthropic(api_key=api_key)
|
||||
self._model = model
|
||||
self._max_tokens = max_tokens
|
||||
self._enable_prompt_caching_beta = enable_prompt_caching_beta
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def enable_prompt_caching_beta(self) -> bool:
|
||||
return self._enable_prompt_caching_beta
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
|
||||
user = AnthropicUserContextAggregator(context)
|
||||
@@ -98,6 +104,17 @@ class AnthropicLLMService(LLMService):
|
||||
)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
|
||||
# completion_tokens. We also estimate the completion tokens from output text
|
||||
# and use that estimate if we are interrupted, because we almost certainly won't
|
||||
# get a complete usage report if the task we're running in is cancelled.
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
completion_tokens_estimate = 0
|
||||
use_completion_tokens_estimate = False
|
||||
cache_creation_input_tokens = 0
|
||||
cache_read_input_tokens = 0
|
||||
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
@@ -106,13 +123,19 @@ class AnthropicLLMService(LLMService):
|
||||
f"Generating chat: {context.system} | {context.get_messages_for_logging()}")
|
||||
|
||||
messages = context.messages
|
||||
if self._enable_prompt_caching_beta:
|
||||
messages = context.get_messages_with_cache_control_markers()
|
||||
|
||||
api_call = self._client.messages.create
|
||||
if self._enable_prompt_caching_beta:
|
||||
api_call = self._client.beta.prompt_caching.messages.create
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
response = await self._client.messages.create(
|
||||
response = await api_call(
|
||||
tools=context.tools or [],
|
||||
system=context.system or [],
|
||||
messages=messages,
|
||||
tools=context.tools or [],
|
||||
model=self._model,
|
||||
max_tokens=self._max_tokens,
|
||||
stream=True)
|
||||
@@ -123,15 +146,6 @@ class AnthropicLLMService(LLMService):
|
||||
tool_use_block = None
|
||||
json_accumulator = ''
|
||||
|
||||
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
|
||||
# completion_tokens. We also estimate the completion tokens from output text
|
||||
# and use that estimate if we are interrupted, because we almost certainly won't
|
||||
# get a complete usage report if the task we're running in is cancelled.
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
completion_tokens_estimate = 0
|
||||
use_completion_tokens_estimate = False
|
||||
|
||||
async for event in response:
|
||||
# logger.debug(f"Anthropic LLM event: {event}")
|
||||
|
||||
@@ -170,6 +184,15 @@ class AnthropicLLMService(LLMService):
|
||||
event.message.usage, "input_tokens") else 0
|
||||
completion_tokens += event.message.usage.output_tokens if hasattr(
|
||||
event.message.usage, "output_tokens") else 0
|
||||
if hasattr(event.message.usage, "cache_creation_input_tokens"):
|
||||
cache_creation_input_tokens += event.message.usage.cache_creation_input_tokens
|
||||
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
|
||||
if hasattr(event.message.usage, "cache_read_input_tokens"):
|
||||
cache_read_input_tokens += event.message.usage.cache_read_input_tokens
|
||||
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
|
||||
total_input_tokens = prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
if total_input_tokens >= 1024:
|
||||
context.turns_above_cache_threshold += 1
|
||||
|
||||
except CancelledError:
|
||||
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
|
||||
@@ -241,6 +264,12 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self._user_image_request_context = {}
|
||||
|
||||
# For beta prompt caching. This is a counter that tracks the number of turns
|
||||
# we've seen above the cache threshold. We reset this when we reset the
|
||||
# messages list. We only care about this number being 0, 1, or 2. But
|
||||
# it's easiest just to treat it as a counter.
|
||||
self.turns_above_cache_threshold = 0
|
||||
|
||||
self.system = system
|
||||
|
||||
@classmethod
|
||||
@@ -270,6 +299,7 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
return context
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
self.turns_above_cache_threshold = 0
|
||||
self._messages[:] = messages
|
||||
self._restructure_from_openai_messages()
|
||||
|
||||
@@ -313,6 +343,23 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message: {e}")
|
||||
|
||||
def get_messages_with_cache_control_markers(self) -> List[dict]:
|
||||
try:
|
||||
messages = copy.deepcopy(self.messages)
|
||||
if self.turns_above_cache_threshold >= 1 and messages[-1]["role"] == "user":
|
||||
if isinstance(messages[-1]["content"], str):
|
||||
messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}]
|
||||
messages[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
if (self.turns_above_cache_threshold >= 2 and
|
||||
len(messages) > 2 and messages[-3]["role"] == "user"):
|
||||
if isinstance(messages[-3]["content"], str):
|
||||
messages[-3]["content"] = [{"type": "text", "text": messages[-3]["content"]}]
|
||||
messages[-3]["content"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding cache control marker: {e}")
|
||||
return self.messages
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
# See if we should pull the system message out of our context.messages list. (For
|
||||
# compatibility with Open AI messages format.)
|
||||
|
||||
Reference in New Issue
Block a user