linting fixes to anthropic.py

This commit is contained in:
Kwindla Hultman Kramer
2024-08-14 17:27:00 -07:00
parent 48f68ba6dc
commit a6d90b0a00
2 changed files with 43 additions and 44 deletions

View File

@@ -1,6 +1,5 @@
WARNING: --strip-extras is becoming the default in version 8.0.0. To silence this warning, either use --strip-extras to opt into the new default or use --no-strip-extras to retain the existing behavior.
#
# This file is autogenerated by pip-compile with Python 3.11
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile --all-extras pyproject.toml
@@ -13,7 +12,6 @@ aiohttp==3.9.5
# langchain
# langchain-community
# pipecat-ai (pyproject.toml)
# together
aiosignal==1.3.1
# via aiohttp
annotated-types==0.7.0
@@ -29,6 +27,10 @@ anyio==4.4.0
# openai
# starlette
# watchfiles
async-timeout==4.0.3
# via
# aiohttp
# langchain
attrs==23.2.0
# via
# aiohttp
@@ -51,7 +53,6 @@ charset-normalizer==3.3.2
click==8.1.7
# via
# flask
# together
# typer
# uvicorn
coloredlogs==15.0.1
@@ -76,8 +77,8 @@ einops==0.8.0
# via pipecat-ai (pyproject.toml)
email-validator==2.2.0
# via fastapi
eval-type-backport==0.2.0
# via together
exceptiongroup==1.2.2
# via anyio
fal-client==0.4.1
# via pipecat-ai (pyproject.toml)
fastapi==0.111.1
@@ -90,7 +91,6 @@ filelock==3.15.4
# via
# huggingface-hub
# pyht
# together
# torch
# transformers
flask==3.0.3
@@ -192,13 +192,13 @@ jsonpatch==1.33
# via langchain-core
jsonpointer==3.0.0
# via jsonpatch
langchain==0.2.13
langchain==0.2.12
# via
# langchain-community
# pipecat-ai (pyproject.toml)
langchain-community==0.2.12
langchain-community==0.2.11
# via pipecat-ai (pyproject.toml)
langchain-core==0.2.30
langchain-core==0.2.29
# via
# langchain
# langchain-community
@@ -208,7 +208,7 @@ langchain-openai==0.1.20
# via pipecat-ai (pyproject.toml)
langchain-text-splitters==0.2.2
# via langchain
langsmith==0.1.99
langsmith==0.1.98
# via
# langchain
# langchain-community
@@ -247,11 +247,9 @@ numpy==1.26.4
# numba
# onnxruntime
# pipecat-ai (pyproject.toml)
# pyarrow
# pyloudnorm
# resampy
# scipy
# together
# torchvision
# transformers
onnxruntime==1.18.1
@@ -277,7 +275,6 @@ packaging==24.1
pillow==10.3.0
# via
# pipecat-ai (pyproject.toml)
# together
# torchvision
proto-plus==1.24.0
# via
@@ -294,8 +291,6 @@ protobuf==4.25.4
# pipecat-ai (pyproject.toml)
# proto-plus
# pyht
pyarrow==17.0.0
# via together
pyasn1==0.6.0
# via
# pyasn1-modules
@@ -313,7 +308,6 @@ pydantic==2.8.2
# langchain-core
# langsmith
# openai
# together
pydantic-core==2.20.1
# via pydantic
pygments==2.18.0
@@ -355,7 +349,6 @@ requests==2.32.3
# langsmith
# pyht
# tiktoken
# together
# transformers
resampy==0.4.3
# via pipecat-ai (pyproject.toml)
@@ -387,12 +380,10 @@ sqlalchemy==2.0.32
# langchain-community
starlette==0.37.2
# via fastapi
sympy==1.13.2
sympy==1.13.1
# via
# onnxruntime
# torch
tabulate==0.9.0
# via together
tenacity==8.5.0
# via
# langchain
@@ -402,8 +393,6 @@ tiktoken==0.7.0
# via langchain-openai
timm==0.9.16
# via pipecat-ai (pyproject.toml)
together==1.2.7
# via pipecat-ai (pyproject.toml)
tokenizers==0.19.1
# via
# anthropic
@@ -424,17 +413,15 @@ tqdm==4.66.5
# google-generativeai
# huggingface-hub
# openai
# together
# transformers
transformers==4.40.2
# via pipecat-ai (pyproject.toml)
typer==0.12.3
# via
# fastapi-cli
# together
# via fastapi-cli
typing-extensions==4.12.2
# via
# anthropic
# anyio
# deepgram-sdk
# fastapi
# google-generativeai
@@ -448,13 +435,14 @@ typing-extensions==4.12.2
# torch
# typer
# typing-inspect
# uvicorn
typing-inspect==0.9.0
# via dataclasses-json
uritemplate==4.1.1
# via google-api-python-client
urllib3==2.2.2
# via requests
uvicorn[standard]==0.30.6
uvicorn[standard]==0.30.5
# via
# fastapi
# fastapi-cli

View File

@@ -30,8 +30,14 @@ from pipecat.frames.frames import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator, LLMAssistantContextAggregator
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame
)
from pipecat.processors.aggregators.llm_response import (
LLMUserContextAggregator,
LLMAssistantContextAggregator
)
from loguru import logger
@@ -40,7 +46,8 @@ try:
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`. Also, set `ANTHROPIC_API_KEY` environment variable.")
"In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`. " +
"Also, set `ANTHROPIC_API_KEY` environment variable.")
raise Exception(f"Missing module: {e}")
@@ -81,7 +88,7 @@ class AnthropicLLMService(LLMService):
def can_generate_metrics(self) -> bool:
return True
@ staticmethod
@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
user = AnthropicUserContextAggregator(context)
assistant = AnthropicAssistantContextAggregator(user)
@@ -140,16 +147,17 @@ class AnthropicLLMService(LLMService):
if event.content_block.type == "tool_use":
tool_use_block = event.content_block
json_accumulator = ''
elif (event.type == "message_delta" and
hasattr(event.delta, 'stop_reason') and event.delta.stop_reason == 'tool_use'):
elif ((event.type == "message_delta" and
hasattr(event.delta, 'stop_reason')
and event.delta.stop_reason == 'tool_use')):
if tool_use_block:
await self.call_function(context=context,
tool_call_id=tool_use_block.id,
function_name=tool_use_block.name,
arguments=json.loads(json_accumulator))
# Calculate usage. Do this here in its own if statement, because there may be usage data
# embedded in messages that we do other processing for, above.
# Calculate usage. Do this here in its own if statement, because there may be usage
# data embedded in messages that we do other processing for, above.
if hasattr(event, "usage"):
prompt_tokens += event.usage.input_tokens if hasattr(
event.usage, "input_tokens") else 0
@@ -161,7 +169,7 @@ class AnthropicLLMService(LLMService):
completion_tokens += event.message.usage.output_tokens if hasattr(
event.message.usage, "output_tokens") else 0
except CancelledError as e:
except CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
# token estimate. The reraise the exception so all the processors running in this task
# also get cancelled.
@@ -174,7 +182,8 @@ class AnthropicLLMService(LLMService):
await self.push_frame(LLMFullResponseEndFrame())
await self._report_usage_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens if not use_completion_tokens_estimate else completion_tokens_estimate)
completion_tokens=(completion_tokens if not use_completion_tokens_estimate
else completion_tokens_estimate))
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -200,7 +209,8 @@ class AnthropicLLMService(LLMService):
await self._process_context(context)
async def request_image_frame(self, user_id: str, *, text_content: str = None):
await self.push_frame(UserImageRequestFrame(user_id=user_id, context=text_content), FrameDirection.UPSTREAM)
await self.push_frame(UserImageRequestFrame(user_id=user_id, context=text_content),
FrameDirection.UPSTREAM)
def _estimate_tokens(self, text: str) -> int:
return int(len(re.split(r'[^\w]+', text)) * 1.3)
@@ -231,7 +241,7 @@ class AnthropicLLMContext(OpenAILLMContext):
self.system_message = system
@ classmethod
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
self = cls(
messages=openai_context.messages,
@@ -252,11 +262,11 @@ class AnthropicLLMContext(OpenAILLMContext):
self.messages.pop(0)
return self
@ classmethod
@classmethod
def from_messages(cls, messages: List[dict]) -> "AnthropicLLMContext":
return cls(messages=messages)
@ classmethod
@classmethod
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AnthropicLLMContext":
context = cls()
context.add_image_frame_message(
@@ -389,12 +399,13 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
elif isinstance(frame, FunctionCallInProgressFrame):
self._function_call_in_progress = frame
elif isinstance(frame, FunctionCallResultFrame):
if self._function_call_in_progress and self._function_call_in_progress.tool_call_id == frame.tool_call_id:
if (self._function_call_in_progress and self._function_call_in_progress.tool_call_id ==
frame.tool_call_id):
self._function_call_in_progress = None
self._function_call_result = frame
else:
logger.warning(
f"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id")
"FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id")
self._function_call_in_progress = None
self._function_call_result = None
elif isinstance(frame, AnthropicImageMessageFrame):