|
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel, Field, create_model

from app.schemas.schema_tools import (
    convert_attribute_to_model,
    validate_json_data,
    validate_json_schema,
)
from app.utils.converter import to_snake_case
|
|
|
|
|
|
def cf_style_to_pydantic_percentage_shema(
    cf_style_schema: dict,
) -> Type[BaseModel]:
    """Build a Pydantic ``Product`` model from a CF-style attribute schema.

    For every attribute a nested model named ``Class_<Attribute>`` is created
    with one required ``int`` field per allowed value (the value is
    lower-cased and spaces/hyphens are replaced by underscores). ``Product``
    then gets one field per attribute, typed with that nested model, whose
    description starts with ``"multi-label classification"`` when the
    attribute's ``data_type`` contains ``"list"`` and with
    ``"single-label classification"`` otherwise.

    The models are created with ``pydantic.create_model`` instead of
    generating Python source and ``exec``-ing it into the module globals:
    descriptions containing quotes or newlines can no longer break (or
    inject) code, an attribute without allowed values no longer produces an
    invalid class body, and concurrent calls no longer overwrite a shared
    global ``Product``.

    Args:
        cf_style_schema: Mapping of attribute name to an object exposing
            ``data_type``, ``description`` and ``allowed_values``.

    Returns:
        The dynamically created ``Product`` model class.
    """
    logging.getLogger(__name__).debug("CF style schema: %s", cf_style_schema)

    product_fields = {}
    for attribute, attribute_info in cf_style_schema.items():
        is_multiple = "list" in attribute_info.data_type
        label_kind = (
            "multi-label classification"
            if is_multiple
            else "single-label classification"
        )

        # One required integer score per allowed value.
        value_fields = {
            value.lower().replace(" ", "_").replace("-", "_"): (int, ...)
            for value in attribute_info.allowed_values
        }
        value_model = create_model(
            "Class_" + attribute.capitalize(), **value_fields
        )

        product_fields[attribute] = (
            value_model,
            # The "" default is kept from the previously generated code so
            # the resulting JSON schema stays the same.
            Field("", description=f"{label_kind}, {attribute_info.description}"),
        )

    return create_model("Product", **product_fields)
|
|
|
|
def build_attributes_types_prompt(attributes):
    """Render the attribute names and their data types as a prompt fragment.

    The result starts with a fixed header line followed by one
    ``- <name>: <data_type>`` line per entry of ``attributes``, in the
    mapping's iteration order. Each value must expose a ``data_type``
    attribute.
    """
    header = "\n List of attributes types:\n"
    bullet_lines = [
        f"- {name}: {info.data_type}\n" for name, info in attributes.items()
    ]
    return header + "".join(bullet_lines)
|
|
|
|
|
|
class BaseAttributionService(ABC):
    """Abstract base for services that extract product attributes.

    Subclasses implement the three abstract coroutines; this class supplies
    the orchestration around them (building the scoring model, validating
    the returned data against its JSON schema and turning per-value scores
    into selected labels).
    """

    # Inclusive minimum score for a value to be kept in a multi-label
    # attribute. Presumably a percentage (0-100) - confirm against the prompts.
    MULTI_LABEL_THRESHOLD = 60

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Return the attribute data extracted for a product.

        NOTE(review): ``extract_attributes_with_validation`` calls this with
        a fifth positional argument (the product data) and an ``img_paths``
        keyword, which this declared signature does not describe - confirm
        against the concrete implementations.
        """

    @abstractmethod
    async def reevaluate_atributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Return re-evaluated per-value scores for every model field.

        NOTE(review): ``extract_attributes_with_validation`` calls this with
        a fifth positional argument (the stringified first-pass data) and an
        ``img_paths`` keyword, which this declared signature does not
        describe - confirm against the concrete implementations.
        """

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return ``data`` reshaped so that it follows ``schema``."""

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
        appended_prompt: str = "",
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Extract attributes, validate them and select the final labels.

        Args:
            attributes: Mapping of attribute name to its definition (an
                object exposing ``data_type``, ``description`` and
                ``allowed_values``).
            ai_model: Identifier forwarded to the extraction coroutines.
            img_urls: Image URLs forwarded to the extraction coroutines.
            product_taxonomy: Taxonomy forwarded to the extraction
                coroutines; an empty string is replaced by ``"main"``.
            product_data: Product data forwarded to ``extract_attributes``.
            pil_images: Currently not used by this method.
            img_paths: Forwarded to the extraction coroutines as a keyword.
            appended_prompt: Currently not used by this method; the
                attribute-types prompt is what gets forwarded.

        Returns:
            A ``(data, labels)`` tuple: ``data`` is the validated first-pass
            result keyed by the internal aliases, ``labels`` maps each
            original attribute name to its selected value (single-label) or
            list of values (multi-label).

        Raises:
            AssertionError: If a generated field description mentions
                neither ``single-label`` nor ``multi-label``.
        """
        # Give every attribute a unique snake_case alias; the index suffix
        # keeps aliases distinct even when two names normalise identically.
        forward_mapping = {
            key: f"{to_snake_case(key)}_{index}"
            for index, key in enumerate(attributes)
        }
        reverse_mapping = {alias: key for key, alias in forward_mapping.items()}
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        attributes_types_prompt = build_attributes_types_prompt(attributes)
        attributes_percentage_model = cf_style_to_pydantic_percentage_shema(
            transformed_attributes
        )
        schema = attributes_percentage_model.model_json_schema()
        taxonomy = product_taxonomy if product_taxonomy != "" else "main"

        data = await self.extract_attributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            product_data,
            img_paths=img_paths,
            appended_prompt=attributes_types_prompt,
        )
        validate_json_data(data, schema)

        reevaluate_data = await self.reevaluate_atributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            str(data),
            img_paths=img_paths,
            appended_prompt=attributes_types_prompt,
        )

        selected = self._select_labels(attributes_percentage_model, reevaluate_data)
        reverse_data = {
            reverse_mapping[alias]: value for alias, value in selected.items()
        }
        return data, reverse_data

    def _select_labels(
        self,
        attributes_model: Type[BaseModel],
        reevaluate_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Turn per-value scores into the selected label(s) for each field."""
        log = logging.getLogger(__name__)
        selected: Dict[str, Any] = {}
        for field_name, field in attributes_model.model_fields.items():
            log.debug("%s: %s", field_name, field.description)
            description = field.description.lower()
            if "single-label" in description:
                # Highest score wins (first one on ties); a field whose
                # scores are all <= 0 gets no entry at all.
                best_value, best_score = max(
                    reevaluate_data[field_name].items(),
                    key=lambda item: item[1],
                    default=(None, 0),
                )
                if best_score > 0:
                    selected[field_name] = best_value
            elif "multi-label" in description:
                selected[field_name] = [
                    value
                    for value, score in reevaluate_data[field_name].items()
                    if score >= self.MULTI_LABEL_THRESHOLD
                ]
            else:
                # Raised explicitly so the check survives ``python -O``.
                raise AssertionError(
                    f"The description does not contain 'single-label' or 'multi-label': {field.description}"
                )
        return selected

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate ``schema``, apply ``follow_schema`` and validate the result.

        Returns:
            The data returned by ``follow_schema`` after it has been
            validated against ``schema``.
        """
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data
|
|
|