|
import json |
|
import os |
|
from typing import Any, Dict, List, Type, Union |
|
|
|
import openai |
|
import weave |
|
from openai import AsyncOpenAI |
|
from pydantic import BaseModel |
|
|
|
from app.utils.converter import product_data_to_str |
|
from app.utils.image_processing import ( |
|
get_data_format, |
|
get_image_base64_and_type, |
|
get_image_data, |
|
) |
|
from app.utils.logger import exception_to_str, setup_logger |
|
|
|
from ..config import get_settings |
|
from ..core import errors |
|
from ..core.errors import BadRequestError, VendorError |
|
from ..core.prompts import get_prompts |
|
from .base import BaseAttributionService |
|
|
|
ENV = os.getenv("ENV", "LOCAL")

# Weave project per deployment environment. PROD intentionally gets no
# project name here (presumably weave tracing is not configured there —
# TODO(review): confirm); unknown ENV values likewise leave it unset.
_WEAVE_PROJECT_BY_ENV = {
    "LOCAL": "cfai/attribution-exp",
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
}

if ENV in _WEAVE_PROJECT_BY_ENV:
    weave_project_name = _WEAVE_PROJECT_BY_ENV[ENV]


settings = get_settings()

prompts = get_prompts()

logger = setup_logger(__name__)
|
|
|
|
|
def get_response_format(json_schema: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a JSON schema into OpenAI's strict ``json_schema`` response format.

    Strict structured output requires ``additionalProperties: false`` on the
    root schema and on every nested model under ``$defs``; this function sets
    those flags (mutating ``json_schema`` in place) and wraps the schema in
    the payload shape expected by ``chat.completions.parse``.

    Args:
        json_schema: A JSON-schema dict (e.g. from Pydantic's
            ``model_json_schema()``). Mutated in place.

    Returns:
        The ``response_format`` dict to pass to the OpenAI client.
    """
    # Note: the original annotations used the builtin ``any`` instead of
    # ``typing.Any``; fixed here.
    json_schema["additionalProperties"] = False

    # Nested models are emitted under "$defs"; each must also forbid extra
    # properties for strict mode to be accepted by the API.
    for def_schema in json_schema.get("$defs", {}).values():
        def_schema["additionalProperties"] = False

    return {
        "type": "json_schema",
        "json_schema": {
            "strict": True,
            "name": "GarmentSchema",
            "schema": json_schema,
        },
    }
|
|
|
|
|
class OpenAIService(BaseAttributionService):
    """Attribution service backed by the OpenAI structured-output API.

    All calls go through ``client.beta.chat.completions.parse`` with
    multimodal (text + inline base64 image) user messages.
    """

    def __init__(self):
        # Async client; the API key comes from application settings.
        self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)

    @staticmethod
    def _build_image_content(
        img_urls: List[str], img_paths: List[str]
    ) -> List[Dict[str, Any]]:
        """Build ``image_url`` message parts from URLs or local paths.

        URLs take precedence over paths; every image is inlined as a base64
        data URI. Returns ``[]`` when neither source is given — the previous
        implementation left the variable undefined in that case, which raised
        ``NameError`` downstream.
        """
        if img_urls is not None:
            parts: List[Dict[str, Any]] = []
            for img_url in img_urls:
                base64_data, data_format = get_image_base64_and_type(img_url)
                parts.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/{data_format};base64,{base64_data}",
                        },
                    }
                )
            return parts
        if img_paths is not None:
            return [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{get_data_format(img_path)};base64,{get_image_data(img_path)}",
                    },
                }
                for img_path in img_paths
            ]
        return []

    async def _complete_and_parse(
        self,
        ai_model: str,
        system_message: str,
        user_content: List[Dict[str, Any]],
        response_format: Any,
    ) -> Dict[str, Any]:
        """Call the structured-output endpoint and return the parsed JSON dict.

        Shared by ``extract_attributes`` and ``reevaluate_atributes``.

        Raises:
            BadRequestError: the vendor rejected the request (4xx).
            VendorError: any other vendor failure, or a response whose
                content is not valid JSON.
        """
        try:
            logger.info("Extracting info via OpenAI...")
            response = await self.client.beta.chat.completions.parse(
                model=ai_model,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_content},
                ],
                max_tokens=1000,
                response_format=response_format,
                logprobs=False,
                # Near-zero top_p makes sampling effectively greedy, so the
                # output is (close to) deterministic.
                top_p=1e-45,
            )
        except openai.BadRequestError as e:
            raise BadRequestError(exception_to_str(e)) from e
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        try:
            content = response.choices[0].message.content
            return json.loads(content)
        except Exception as e:  # content may be None (TypeError) or malformed JSON
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON) from e

    @weave.op
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: List[Any] = None,
        img_paths: List[str] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Extract structured attributes for a product from text and images.

        Args:
            attributes_model: Pydantic model used as the response schema.
            ai_model: OpenAI model identifier.
            img_urls: Image URLs (take precedence over ``img_paths``).
            product_taxonomy: Taxonomy string interpolated into the prompt.
            product_data: Product fields rendered via ``product_data_to_str``.
            pil_images: Accepted for interface compatibility but unused here.
            img_paths: Local image paths, used only when ``img_urls`` is None.
            appended_prompt: Extra text appended to the human message.

        Returns:
            The model's structured response as a plain dict.
        """
        prompt_text = (
            prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
                product_taxonomy=product_taxonomy,
                product_data=product_data_to_str(product_data),
            )
            + appended_prompt
        )
        # Log the prompt actually sent. (The old debug print rendered
        # GET_PERCENTAGE_HUMAN_MESSAGE, which is not the message used here.)
        logger.debug("Prompt: %s", prompt_text)

        text_content = [{"type": "text", "text": prompt_text}]
        image_content = self._build_image_content(img_urls, img_paths)

        return await self._complete_and_parse(
            ai_model,
            prompts.GET_PERCENTAGE_SYSTEM_MESSAGE,
            text_content + image_content,
            attributes_model,
        )

    async def reevaluate_atributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: str,
        pil_images: List[Any] = None,
        img_paths: List[str] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Re-evaluate previously extracted attributes for a product.

        Same flow as ``extract_attributes`` but with the re-evaluation
        prompts, and ``product_data`` is already a string.

        NOTE(review): the method name keeps the historical "atributes" typo
        for backward compatibility with existing callers.
        """
        prompt_text = (
            prompts.REEVALUATE_HUMAN_MESSAGE.format(
                product_taxonomy=product_taxonomy,
                product_data=product_data,
            )
            + appended_prompt
        )
        logger.debug("Prompt: %s", prompt_text)

        text_content = [{"type": "text", "text": prompt_text}]
        image_content = self._build_image_content(img_urls, img_paths)

        return await self._complete_and_parse(
            ai_model,
            prompts.REEVALUATE_SYSTEM_MESSAGE,
            text_content + image_content,
            attributes_model,
        )

    @weave.op
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Coerce ``data`` into ``schema`` via a strict structured-output call.

        Uses a pinned model and ``temperature=0.0`` for reproducibility.

        Returns:
            The parsed JSON dict, or ``{"status": "refused"}`` when the model
            refuses to answer.

        Raises:
            VendorError: the vendor call failed.
            ValueError: the response content was not valid JSON.
        """
        logger.info("Following structure via OpenAI...")
        text_content = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
            },
        ]

        try:
            response = await self.client.beta.chat.completions.parse(
                model="gpt-4o-2024-11-20",
                messages=[
                    {
                        "role": "system",
                        "content": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE,
                    },
                    {
                        "role": "user",
                        "content": text_content,
                    },
                ],
                max_tokens=1000,
                response_format=get_response_format(schema),
                logprobs=False,
                temperature=0.0,
            )
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        if response.choices[0].message.refusal:
            logger.info("OpenAI refused to respond to the request")
            return {"status": "refused"}

        try:
            content = response.choices[0].message.content
            parsed_data = json.loads(content)
        except Exception as e:
            # NOTE(review): this path raises ValueError while the other
            # methods raise VendorError — kept as-is for caller compatibility.
            raise ValueError(errors.VENDOR_ERROR_INVALID_JSON) from e

        return parsed_data
|
|