okay, both files now

This commit is contained in:
Kwindla Hultman Kramer
2024-08-15 00:57:10 -07:00
parent 6e0dd4a779
commit 94deec01c9

View File

@@ -79,15 +79,21 @@ class AnthropicLLMService(LLMService):
api_key: str,
model: str = "claude-3-5-sonnet-20240620",
max_tokens: int = 4096,
enable_prompt_caching_beta: bool = False,
**kwargs):
super().__init__(**kwargs)
self._client = AsyncAnthropic(api_key=api_key)
self._model = model
self._max_tokens = max_tokens
self._enable_prompt_caching_beta = enable_prompt_caching_beta
def can_generate_metrics(self) -> bool:
return True
@property
def enable_prompt_caching_beta(self) -> bool:
return self._enable_prompt_caching_beta
@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
user = AnthropicUserContextAggregator(context)
@@ -98,6 +104,17 @@ class AnthropicLLMService(LLMService):
)
async def _process_context(self, context: OpenAILLMContext):
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
# completion_tokens. We also estimate the completion tokens from output text
# and use that estimate if we are interrupted, because we almost certainly won't
# get a complete usage report if the task we're running in is cancelled.
prompt_tokens = 0
completion_tokens = 0
completion_tokens_estimate = 0
use_completion_tokens_estimate = False
cache_creation_input_tokens = 0
cache_read_input_tokens = 0
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
@@ -106,13 +123,19 @@ class AnthropicLLMService(LLMService):
f"Generating chat: {context.system} | {context.get_messages_for_logging()}")
messages = context.messages
if self._enable_prompt_caching_beta:
messages = context.get_messages_with_cache_control_markers()
api_call = self._client.messages.create
if self._enable_prompt_caching_beta:
api_call = self._client.beta.prompt_caching.messages.create
await self.start_ttfb_metrics()
response = await self._client.messages.create(
response = await api_call(
tools=context.tools or [],
system=context.system or [],
messages=messages,
tools=context.tools or [],
model=self._model,
max_tokens=self._max_tokens,
stream=True)
@@ -123,15 +146,6 @@ class AnthropicLLMService(LLMService):
tool_use_block = None
json_accumulator = ''
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
# completion_tokens. We also estimate the completion tokens from output text
# and use that estimate if we are interrupted, because we almost certainly won't
# get a complete usage report if the task we're running in is cancelled.
prompt_tokens = 0
completion_tokens = 0
completion_tokens_estimate = 0
use_completion_tokens_estimate = False
async for event in response:
# logger.debug(f"Anthropic LLM event: {event}")
@@ -170,6 +184,15 @@ class AnthropicLLMService(LLMService):
event.message.usage, "input_tokens") else 0
completion_tokens += event.message.usage.output_tokens if hasattr(
event.message.usage, "output_tokens") else 0
if hasattr(event.message.usage, "cache_creation_input_tokens"):
cache_creation_input_tokens += event.message.usage.cache_creation_input_tokens
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
if hasattr(event.message.usage, "cache_read_input_tokens"):
cache_read_input_tokens += event.message.usage.cache_read_input_tokens
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
total_input_tokens = prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
if total_input_tokens >= 1024:
context.turns_above_cache_threshold += 1
except CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
@@ -241,6 +264,12 @@ class AnthropicLLMContext(OpenAILLMContext):
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
self._user_image_request_context = {}
# For beta prompt caching. This is a counter that tracks the number of turns
# we've seen above the cache threshold. We reset this when we reset the
# messages list. We only care about this number being 0, 1, or 2. But
# it's easiest just to treat it as a counter.
self.turns_above_cache_threshold = 0
self.system = system
@classmethod
@@ -270,6 +299,7 @@ class AnthropicLLMContext(OpenAILLMContext):
return context
def set_messages(self, messages: List):
self.turns_above_cache_threshold = 0
self._messages[:] = messages
self._restructure_from_openai_messages()
@@ -313,6 +343,23 @@ class AnthropicLLMContext(OpenAILLMContext):
except Exception as e:
logger.error(f"Error adding message: {e}")
def get_messages_with_cache_control_markers(self) -> List[dict]:
try:
messages = copy.deepcopy(self.messages)
if self.turns_above_cache_threshold >= 1 and messages[-1]["role"] == "user":
if isinstance(messages[-1]["content"], str):
messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}]
messages[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
if (self.turns_above_cache_threshold >= 2 and
len(messages) > 2 and messages[-3]["role"] == "user"):
if isinstance(messages[-3]["content"], str):
messages[-3]["content"] = [{"type": "text", "text": messages[-3]["content"]}]
messages[-3]["content"][-1]["cache_control"] = {"type": "ephemeral"}
return messages
except Exception as e:
logger.error(f"Error adding cache control marker: {e}")
return self.messages
def _restructure_from_openai_messages(self):
# See if we should pull the system message out of our context.messages list. (For
# compatibility with Open AI messages format.)