diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 506d71243..088db7861 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -79,15 +79,21 @@ class AnthropicLLMService(LLMService): api_key: str, model: str = "claude-3-5-sonnet-20240620", max_tokens: int = 4096, + enable_prompt_caching_beta: bool = False, **kwargs): super().__init__(**kwargs) self._client = AsyncAnthropic(api_key=api_key) self._model = model self._max_tokens = max_tokens + self._enable_prompt_caching_beta = enable_prompt_caching_beta def can_generate_metrics(self) -> bool: return True + @property + def enable_prompt_caching_beta(self) -> bool: + return self._enable_prompt_caching_beta + @staticmethod def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair: user = AnthropicUserContextAggregator(context) @@ -98,6 +104,17 @@ class AnthropicLLMService(LLMService): ) async def _process_context(self, context: OpenAILLMContext): + # Usage tracking. We track the usage reported by Anthropic in prompt_tokens and + # completion_tokens. We also estimate the completion tokens from output text + # and use that estimate if we are interrupted, because we almost certainly won't + # get a complete usage report if the task we're running in is cancelled. 
+ prompt_tokens = 0 + completion_tokens = 0 + completion_tokens_estimate = 0 + use_completion_tokens_estimate = False + cache_creation_input_tokens = 0 + cache_read_input_tokens = 0 + try: await self.push_frame(LLMFullResponseStartFrame()) await self.start_processing_metrics() @@ -106,13 +123,19 @@ class AnthropicLLMService(LLMService): f"Generating chat: {context.system} | {context.get_messages_for_logging()}") messages = context.messages + if self._enable_prompt_caching_beta: + messages = context.get_messages_with_cache_control_markers() + + api_call = self._client.messages.create + if self._enable_prompt_caching_beta: + api_call = self._client.beta.prompt_caching.messages.create await self.start_ttfb_metrics() - response = await self._client.messages.create( + response = await api_call( + tools=context.tools or [], system=context.system or [], messages=messages, - tools=context.tools or [], model=self._model, max_tokens=self._max_tokens, stream=True) @@ -123,15 +146,6 @@ class AnthropicLLMService(LLMService): tool_use_block = None json_accumulator = '' - # Usage tracking. We track the usage reported by Anthropic in prompt_tokens and - # completion_tokens. We also estimate the completion tokens from output text - # and use that estimate if we are interrupted, because we almost certainly won't - # get a complete usage report if the task we're running in is cancelled. 
- prompt_tokens = 0 - completion_tokens = 0 - completion_tokens_estimate = 0 - use_completion_tokens_estimate = False - async for event in response: # logger.debug(f"Anthropic LLM event: {event}") @@ -170,6 +184,15 @@ class AnthropicLLMService(LLMService): event.message.usage, "input_tokens") else 0 completion_tokens += event.message.usage.output_tokens if hasattr( event.message.usage, "output_tokens") else 0 + if hasattr(event.message.usage, "cache_creation_input_tokens"): + cache_creation_input_tokens += event.message.usage.cache_creation_input_tokens + logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}") + if hasattr(event.message.usage, "cache_read_input_tokens"): + cache_read_input_tokens += event.message.usage.cache_read_input_tokens + logger.debug(f"Cache read input tokens: {cache_read_input_tokens}") + total_input_tokens = prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens + if total_input_tokens >= 1024: + context.turns_above_cache_threshold += 1 except CancelledError: # If we're interrupted, we won't get a complete usage report. So set our flag to use the @@ -241,6 +264,12 @@ class AnthropicLLMContext(OpenAILLMContext): super().__init__(messages=messages, tools=tools, tool_choice=tool_choice) self._user_image_request_context = {} + # For beta prompt caching. This is a counter that tracks the number of turns + # we've seen above the cache threshold. We reset this when we reset the + # messages list. We only care about this number being 0, 1, or 2. But + # it's easiest just to treat it as a counter. 
def get_messages_with_cache_control_markers(self) -> List[dict]:
    """Return a deep copy of the messages with prompt-caching markers added.

    A ``cache_control`` marker of type ``ephemeral`` is attached to the last
    content block of:

    * the most recent message, when it is a user message and
      ``turns_above_cache_threshold`` is at least 1;
    * the third-from-last message, when it is a user message and
      ``turns_above_cache_threshold`` is at least 2.

    String content is first converted to a single text block so that it can
    carry the marker. ``self.messages`` itself is never modified.

    Returns:
        The marked-up copy of the messages. If anything unexpected goes
        wrong, the error is logged and the unmodified ``self.messages``
        list is returned instead (best effort: caching is an optimization,
        so a failure here must not break the request).
    """
    # (position counted from the end of the list, turns required).
    candidates = ((1, 1), (3, 2))
    try:
        messages = copy.deepcopy(self.messages)
        for offset, required_turns in candidates:
            if self.turns_above_cache_threshold < required_turns:
                continue
            # Guard against short (or empty) conversations instead of
            # relying on an IndexError being swallowed below.
            if len(messages) < offset:
                continue
            message = messages[-offset]
            if message.get("role") != "user":
                continue
            content = message.get("content")
            if isinstance(content, str):
                content = [{"type": "text", "text": content}]
                message["content"] = content
            # Only a non-empty list of dict blocks can carry a marker; skip
            # anything else so one odd message does not discard the other
            # marker.
            if isinstance(content, list) and content and isinstance(content[-1], dict):
                content[-1]["cache_control"] = {"type": "ephemeral"}
        return messages
    except Exception as e:
        logger.error(f"Error adding cache control marker: {e}")
        return self.messages