|
import json
|
|
import os
|
|
from typing import Any, Dict, List, Type, Union
|
|
|
|
import anthropic
|
|
import weave
|
|
from anthropic import APIStatusError, AsyncAnthropic
|
|
from pydantic import BaseModel
|
|
|
|
from app.config import get_settings
|
|
from app.core import errors
|
|
from app.core.errors import BadRequestError, VendorError
|
|
from app.core.prompts import get_prompts
|
|
from app.services.base import BaseAttributionService
|
|
from app.utils.converter import product_data_to_str
|
|
from app.utils.image_processing import get_data_format, get_image_data
|
|
from app.utils.logger import exception_to_str, setup_logger
|
|
|
|
# Deployment environment name, read from the ENV variable (defaults to "LOCAL").
ENV = os.getenv("ENV", "LOCAL")

# Weave project per environment. PROD and any unrecognised ENV value map to
# None, so `weave_project_name` is always bound. (The previous if/elif chain used
# a bare `pass` for PROD, leaving the name undefined for any later reference.)
# NOTE(review): the name is not referenced elsewhere in this module; presumably
# it is consumed by a `weave.init(...)` call elsewhere — verify.
_WEAVE_PROJECTS = {
    "LOCAL": "cfai/attribution-exp",
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
}
weave_project_name = _WEAVE_PROJECTS.get(ENV)

# Module-level singletons shared by the service class below.
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
|
|
|
|
|
|
class AnthropicService(BaseAttributionService):
    """Attribution service backed by Anthropic's Messages API with tool use.

    Both public methods offer the model a single tool (``extract_garment_info``)
    whose ``input_schema`` is the requested JSON schema. They then read the
    input of the first ``tool_use`` block in the response as the structured
    result.
    """

    def __init__(self):
        # Async client authenticated with the key from application settings.
        self.client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)

    @staticmethod
    def _build_image_messages(
        img_urls: List[str] = None, img_paths: List[str] = None
    ) -> List[Dict[str, Any]]:
        """Build Anthropic image content blocks from URLs or local file paths.

        URLs take precedence: if ``img_urls`` is not None, ``img_paths`` is
        ignored. When both are None, an empty list is returned so the request
        is text-only. (Previously this case left ``image_messages`` unbound and
        crashed with UnboundLocalError.)
        """
        if img_urls is not None:
            return [
                {"type": "image", "source": {"type": "url", "url": img_url}}
                for img_url in img_urls
            ]
        if img_paths is not None:
            # Local files are sent inline as base64 with a media type derived
            # from the file via get_data_format.
            return [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": f"image/{get_data_format(img_path)}",
                        "data": get_image_data(img_path),
                    },
                }
                for img_path in img_paths
            ]
        return []

    @weave.op
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: List[Any] = None,
        img_paths: List[str] = None,
    ) -> Dict[str, Any]:
        """Extract product attributes from images plus product text.

        Args:
            attributes_model: Pydantic model whose JSON schema becomes the
                tool's ``input_schema``.
            ai_model: Anthropic model name passed to ``messages.create``.
            img_urls: Image URLs. When not None, these are used and
                ``img_paths`` is ignored.
            product_taxonomy: Substituted into the human prompt template.
            product_data: Product fields, rendered with ``product_data_to_str``
                into the human prompt template.
            pil_images: Accepted for interface compatibility; unused here.
            img_paths: Local image paths sent base64-encoded. Used only when
                ``img_urls`` is None.

        Returns:
            The input of the first ``tool_use`` block in the response.

        Raises:
            BadRequestError: Anthropic rejected the request
                (``anthropic.BadRequestError``).
            VendorError: Any other API failure, an empty tool input, or a
                response with no ``tool_use`` block.
        """
        logger.info("Extracting info via Anthropic...")

        tools = [
            {
                "name": "extract_garment_info",
                "description": "Extracts key information from the image.",
                "input_schema": attributes_model.model_json_schema(),
                # Marks the tool definition for prompt caching (see beta header below).
                "cache_control": {"type": "ephemeral"},
            }
        ]

        image_messages = self._build_image_messages(img_urls, img_paths)

        system_message = [{"type": "text", "text": prompts.GET_PERCENTAGE_SYSTEM_MESSAGE}]

        text_messages = [
            {
                "type": "text",
                "text": prompts.GET_PERCENTAGE_HUMAN_MESSAGE.format(
                    product_taxonomy=product_taxonomy,
                    product_data=product_data_to_str(product_data),
                ),
            }
        ]

        messages = [{"role": "user", "content": text_messages + image_messages}]

        try:
            response = await self.client.messages.create(
                model=ai_model,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
                # Sample only from the single most likely token (near-greedy).
                top_k=1,
            )
        except anthropic.BadRequestError as e:
            raise BadRequestError(e.message) from e
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        for content in response.content:
            if content.type == "tool_use":
                # `not` covers both None and an empty input.
                if not content.input:
                    raise VendorError(
                        errors.VENDOR_THROW_ERROR.format(
                            error_message="content.input is None or content.input is empty"
                        )
                    )
                return content.input

        raise VendorError(
            errors.VENDOR_THROW_ERROR.format(error_message="No tool_use found")
        )

    @weave.op
    async def follow_schema(self, schema, data):
        """Ask the model to restructure ``data`` to match ``schema``.

        Args:
            schema: JSON schema used as the tool's ``input_schema``. It is
                presumably expected to have a top-level ``json_info``
                property — verify against callers.
            data: Substituted as ``json_info`` into the human prompt template.

        Returns:
            The ``"json_info"`` entry of the first ``tool_use`` block's input,
            or ``{"status": "ERROR: no tool_use found"}`` when the response
            contains no ``tool_use`` block.

        Raises:
            VendorError: Any API failure, or a ``tool_use`` input without a
                ``"json_info"`` key. (Previously this leaked a bare
                KeyError/TypeError.)
        """
        logger.info("Following structure via Anthropic...")

        tools = [
            {
                "name": "extract_garment_info",
                # NOTE(review): the raw human-message template (presumably still
                # containing its {json_info} placeholder) doubles as the tool
                # description — confirm this is intended.
                "description": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE,
                "input_schema": schema,
                "cache_control": {"type": "ephemeral"},
            }
        ]

        text_messages = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
            }
        ]

        system_message = [
            {"type": "text", "text": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE}
        ]

        messages = [{"role": "user", "content": text_messages}]

        try:
            response = await self.client.messages.create(
                model=settings.ANTHROPIC_DEFAULT_MODEL,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
            )
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        for content in response.content:
            if content.type == "tool_use":
                tool_input = content.input
                if not isinstance(tool_input, dict) or "json_info" not in tool_input:
                    raise VendorError(
                        errors.VENDOR_THROW_ERROR.format(
                            error_message='tool_use input has no "json_info" key'
                        )
                    )
                return tool_input["json_info"]

        return {"status": "ERROR: no tool_use found"}
|
|
|