# File: attribution-2steps-method/app/services/service_anthropic.py
# Page-header residue from the file viewer: thanhnt-cf's picture · step 1 · 0dd08cb
import json
import os
from typing import Any, Dict, List, Type, Union
import anthropic
import weave
from anthropic import APIStatusError, AsyncAnthropic
from pydantic import BaseModel
from app.config import get_settings
from app.core import errors
from app.core.errors import BadRequestError, VendorError
from app.core.prompts import get_prompts
from app.services.base import BaseAttributionService
from app.utils.converter import product_data_to_str
from app.utils.image_processing import get_data_format, get_image_data
from app.utils.logger import exception_to_str, setup_logger
# Deployment environment name; defaults to "LOCAL" when the ENV variable is unset.
ENV = os.getenv("ENV", "LOCAL")

# Weave project used for tracing in each environment. "PROD" (and any
# unrecognised value) intentionally has no entry, so the lookup below yields
# None instead of leaving the name undefined.
_WEAVE_PROJECT_BY_ENV = {
    "LOCAL": "cfai/attribution-exp",  # local or demo
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
}
weave_project_name = _WEAVE_PROJECT_BY_ENV.get(ENV)

# Weave initialisation is currently disabled. To re-enable it:
# if weave_project_name is not None:
#     weave.init(project_name=weave_project_name)
# Module-level objects used by AnthropicService below:
# application settings (API key, default model), prompt templates, and logger.
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
class AnthropicService(BaseAttributionService):
    """Attribution service backed by the Anthropic Messages API.

    Both public methods declare a single tool named ``extract_garment_info``
    and return the ``input`` of the first ``tool_use`` block in the response.
    """

    def __init__(self):
        """Create the async Anthropic client from the configured API key."""
        self.client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)

    @staticmethod
    def _build_image_messages(img_urls, img_paths):
        """Return the image content blocks for the user message.

        ``img_urls`` takes precedence: when it is not ``None`` the images are
        referenced by URL; otherwise ``img_paths`` are read and embedded as
        base64 data.

        Raises:
            BadRequestError: if both ``img_urls`` and ``img_paths`` are ``None``.
        """
        if img_urls is not None:
            return [
                {
                    "type": "image",
                    "source": {"type": "url", "url": img_url},
                }
                for img_url in img_urls
            ]
        if img_paths is not None:
            return [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": f"image/{get_data_format(img_path)}",
                        "data": get_image_data(img_path),
                    },
                }
                for img_path in img_paths
            ]
        # Previously this case fell through and crashed later with an
        # UnboundLocalError; fail early with a meaningful error instead.
        raise BadRequestError("Either img_urls or img_paths must be provided.")

    @weave.op
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: Union[List[str], None],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Union[List[Any], None] = None,  # do not remove, this is for weave
        img_paths: Union[List[str], None] = None,
    ) -> Dict[str, Any]:
        """Extract product attributes from images and product data.

        The JSON schema of ``attributes_model`` is sent as the tool's input
        schema; the prompt is built from ``product_taxonomy`` and
        ``product_data`` and followed by the images.

        Args:
            attributes_model: Pydantic model whose JSON schema describes the
                attributes to extract.
            ai_model: Anthropic model name passed to the API.
            img_urls: Image URLs; used when not ``None``.
            product_taxonomy: Taxonomy string inserted into the prompt.
            product_data: Product data inserted into the prompt.
            pil_images: Unused by this method; kept for weave.
            img_paths: Local image paths; used only when ``img_urls`` is
                ``None``.

        Returns:
            The ``input`` of the first ``tool_use`` block in the response.

        Raises:
            BadRequestError: if neither image source is given, or if the
                Anthropic API rejects the request as a bad request.
            VendorError: on any other API failure, when the tool input is
                empty, or when the response contains no ``tool_use`` block.
        """
        logger.info("Extracting info via Anthropic...")

        tools = [
            {
                "name": "extract_garment_info",
                "description": "Extracts key information from the image.",
                "input_schema": attributes_model.model_json_schema(),
                "cache_control": {"type": "ephemeral"},
            }
        ]
        image_messages = self._build_image_messages(img_urls, img_paths)
        system_message = [
            {"type": "text", "text": prompts.GET_PERCENTAGE_SYSTEM_MESSAGE}
        ]
        text_messages = [
            {
                "type": "text",
                "text": prompts.GET_PERCENTAGE_HUMAN_MESSAGE.format(
                    product_taxonomy=product_taxonomy,
                    product_data=product_data_to_str(product_data),
                ),
            }
        ]
        messages = [{"role": "user", "content": text_messages + image_messages}]

        try:
            response = await self.client.messages.create(
                model=ai_model,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
                # Alternatives tried previously: temperature=0.0, top_p=1e-45.
                # NOTE(review): top_k=1 is presumably meant to make the output
                # as deterministic as possible - confirm.
                top_k=1,
            )
        except anthropic.BadRequestError as e:
            raise BadRequestError(e.message) from e
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        for content in response.content:
            if content.type != "tool_use":
                continue
            # Only the first tool_use block is considered.
            if not content.input:
                raise VendorError(
                    errors.VENDOR_THROW_ERROR.format(
                        error_message="content.input is None or content.input is empty"
                    )
                )
            return content.input

        raise VendorError(
            errors.VENDOR_THROW_ERROR.format(error_message="No tool_use found")
        )

    @weave.op
    async def follow_schema(self, schema, data):
        """Ask the default Anthropic model to restructure ``data`` to ``schema``.

        Args:
            schema: Used as the tool's ``input_schema``.
            data: Inserted into the human prompt as ``json_info``.

        Returns:
            The ``json_info`` entry of the first ``tool_use`` block's input,
            or ``{"status": "ERROR: no tool_use found"}`` when the response
            has no ``tool_use`` block.

        Raises:
            VendorError: if the API call fails for any reason.
        """
        logger.info("Following structure via Anthropic...")

        tools = [
            {
                "name": "extract_garment_info",
                # NOTE(review): the unformatted prompt template is used as the
                # tool description - confirm this is intentional.
                "description": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE,
                "input_schema": schema,
                "cache_control": {"type": "ephemeral"},
            }
        ]
        text_messages = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
            }
        ]
        system_message = [
            {"type": "text", "text": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE}
        ]
        messages = [{"role": "user", "content": text_messages}]

        try:
            response = await self.client.messages.create(
                model=settings.ANTHROPIC_DEFAULT_MODEL,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
            )
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        for content in response.content:
            if content.type == "tool_use":
                # Assumes the schema defines a top-level "json_info" property;
                # a missing key raises KeyError - TODO confirm with callers.
                return content.input["json_info"]

        # NOTE(review): unlike extract_attributes, this returns an error dict
        # instead of raising VendorError; kept as-is so callers that inspect
        # the "status" key keep working.
        return {"status": "ERROR: no tool_use found"}