thanhnt-cf's picture
fix issues in description
98ef152
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel, Field, create_model

from app.utils.converter import to_snake_case
from app.schemas.schema_tools import (
    convert_attribute_to_model,
    validate_json_data,
    validate_json_schema,
)
def cf_style_to_pydantic_percentage_shema(
    cf_style_schema: dict,
) -> "Type[BaseModel]":
    """Build a pydantic ``Product`` model that asks for one score per allowed value.

    For every attribute a nested model named ``Class_<Attribute>`` is created
    with one required ``int`` field per allowed value. Value names are
    lower-cased, with spaces and hyphens replaced by underscores. ``Product``
    gets one field per attribute, typed with that nested model and defaulting
    to ``""``. Its description is ``"multi-label classification, <description>"``
    when ``"list"`` occurs in the attribute's ``data_type``, and
    ``"single-label classification, <description>"`` otherwise.

    The models are created with ``pydantic.create_model`` instead of generating
    and ``exec``-ing source code. Descriptions and value names are therefore
    never parsed as Python, so quotes, backslashes, newlines or
    non-identifier characters in them cannot break (or inject into) the model.

    Args:
        cf_style_schema: Mapping of attribute name to an attribute-info object
            exposing ``data_type``, ``description`` and ``allowed_values``.
            Attribute names become ``Product`` field names verbatim.

    Returns:
        The dynamically created ``Product`` model class.
    """
    logging.getLogger(__name__).debug("%s", cf_style_schema)

    product_fields = {}
    for attribute, attribute_info in cf_style_schema.items():
        is_multi_label = "list" in attribute_info.data_type
        multiple_desc = (
            "multi-label classification"
            if is_multi_label
            else "single-label classification"
        )

        # One required integer score per allowed value.
        value_fields = {
            value.lower().replace(" ", "_").replace("-", "_"): (int, ...)
            for value in attribute_info.allowed_values
        }
        values_model = create_model("Class_" + attribute.capitalize(), **value_fields)

        product_fields[attribute] = (
            values_model,
            Field("", description=f"{multiple_desc}, {attribute_info.description}"),
        )

    return create_model("Product", **product_fields)
def build_attributes_types_prompt(attributes):
    """Render a prompt section listing every attribute with its declared type.

    Args:
        attributes: Mapping of attribute name to an object exposing ``data_type``.

    Returns:
        A header line followed by one ``- <name>: <data_type>`` line per
        attribute, in mapping order; each line ends with a newline.
    """
    header = "\n List of attributes types:\n"
    bullet_lines = [f"- {name}: {info.data_type}\n" for name, info in attributes.items()]
    return header + "".join(bullet_lines)
class BaseAttributionService(ABC):
    """Base class for AI-backed product attribute extraction services.

    Subclasses implement the model calls (``extract_attributes``,
    ``reevaluate_atributes`` and ``follow_schema``). This class wraps them with
    schema generation, attribute-key mapping, JSON validation and label
    selection.
    """

    # Minimum score a value must receive in the re-evaluation pass to be kept
    # for a multi-label attribute. Scores are presumably percentages (0-100);
    # confirm against the concrete ``reevaluate_atributes`` implementation.
    # Subclasses may override this.
    MULTI_LABEL_THRESHOLD = 60

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Run the first extraction pass.

        The result is validated against ``attributes_model.model_json_schema()``
        by ``extract_attributes_with_validation``.

        NOTE(review): ``extract_attributes_with_validation`` calls this with the
        product data as a fifth positional argument and with an ``img_paths``
        keyword, neither of which appears in this declaration. Concrete
        implementations must accept them; consider aligning this signature.
        """

    @abstractmethod
    async def reevaluate_atributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Re-score every allowed value given the first-pass result.

        The result is read as ``{field_name: {value_name: score}}`` by
        ``extract_attributes_with_validation``.

        NOTE(review): that caller passes ``str(first_pass_result)`` as a fifth
        positional argument and an ``img_paths`` keyword, neither of which
        appears in this declaration; consider aligning this signature.
        """

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return *data* reshaped to follow the JSON *schema*.

        The result is validated against *schema* by
        ``follow_schema_with_validation``.
        """

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
        appended_prompt: str = "",
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Extract attributes in two passes and pick the final value(s) per attribute.

        Steps:

        1. Attribute names are mapped to identifier-safe keys, and a pydantic
           model asking for one score per allowed value is built from them.
        2. ``extract_attributes`` produces a first answer, which is validated
           against that model's JSON schema.
        3. ``reevaluate_atributes`` re-scores every value given the first answer.
        4. Final values are picked per attribute:
           - single-label: the highest-scoring value is kept (the first one on
             ties); the attribute is omitted if no value scores above zero.
           - multi-label: every value scoring at least
             ``MULTI_LABEL_THRESHOLD`` is kept.

        Args:
            attributes: Mapping of attribute name to attribute info exposing
                ``data_type``, ``description`` and ``allowed_values``. An
                attribute is multi-label when ``"list"`` occurs in its
                ``data_type``.
            ai_model: Model identifier forwarded to the subclass calls.
            img_urls: Image URLs forwarded to the subclass calls.
            product_taxonomy: Forwarded to the subclass calls; an empty string
                is replaced by ``"main"``.
            product_data: Forwarded to ``extract_attributes``.
            pil_images: Currently not forwarded (temporarily removed to save cost).
            img_paths: Forwarded to both subclass calls.
            appended_prompt: Extra prompt text appended after the generated
                attribute-types listing.

        Returns:
            ``(data, selected)``. ``data`` is the validated first-pass result,
            keyed by the internal identifier-safe names. ``selected`` maps each
            original attribute name to its chosen value key (single-label) or
            list of value keys (multi-label), as keyed in the re-evaluation
            result.

        Raises:
            KeyError: if the re-evaluation result lacks an attribute.
        """
        log = logging.getLogger(__name__)

        # Attribute names may not be valid Python identifiers, so each one is
        # mapped to "<snake_case>_<index>"; the index keeps the names unique.
        forward_mapping = {
            key: f"{to_snake_case(key)}_{i}" for i, key in enumerate(attributes)
        }
        reverse_mapping = {mapped: key for key, mapped in forward_mapping.items()}
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        prompt_suffix = build_attributes_types_prompt(attributes) + appended_prompt
        attributes_percentage_model = cf_style_to_pydantic_percentage_shema(
            transformed_attributes
        )
        schema = attributes_percentage_model.model_json_schema()
        taxonomy = product_taxonomy if product_taxonomy != "" else "main"

        # pil_images is intentionally not forwarded (temporarily removed to save cost).
        data = await self.extract_attributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            product_data,
            img_paths=img_paths,
            appended_prompt=prompt_suffix,
        )
        validate_json_data(data, schema)

        reevaluate_data = await self.reevaluate_atributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            str(data),
            img_paths=img_paths,
            appended_prompt=prompt_suffix,
        )

        threshold = self.MULTI_LABEL_THRESHOLD
        selected: Dict[str, Any] = {}
        for field_name, field in attributes_percentage_model.model_fields.items():  # type: ignore
            log.debug("%s: %s", field_name, field.description)
            original_key = reverse_mapping[field_name]
            scores = reevaluate_data[field_name]
            # Decide single/multi-label from data_type, the same test used to
            # build the model. Searching the field description would misfire
            # when the attribute's own description mentions "single-label".
            if "list" in attributes[original_key].data_type:
                selected[original_key] = [
                    label for label, score in scores.items() if score >= threshold
                ]
            else:
                # max() returns the first maximal key, matching a strict ">" scan.
                best_label = max(scores, key=scores.get, default=None)
                if best_label is not None and scores[best_label] > 0:
                    selected[original_key] = best_label
        return data, selected

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate *schema*, run ``follow_schema`` and validate its result against *schema*."""
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data