import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel, Field, create_model

from app.utils.converter import to_snake_case
from app.schemas.schema_tools import (
    convert_attribute_to_model,
    validate_json_data,
    validate_json_schema,
)

logger = logging.getLogger(__name__)


def cf_style_to_pydantic_percentage_shema(
    cf_style_schema: dict,
) -> Type[BaseModel]:
    """Build a ``Product`` Pydantic model from a CF-style attribute schema.

    For every attribute a nested model named ``Class_<Attribute>`` is created
    with one required ``int`` field per allowed value (lower-cased, spaces and
    hyphens replaced by underscores).  The returned ``Product`` model has one
    field per attribute, typed with that nested model, defaulting to ``""``
    and described as ``"<single|multi>-label classification, <description>"``.

    The models are built with ``pydantic.create_model`` rather than by
    generating source code and ``exec``-ing it into the module globals, so
    quotes in descriptions or an empty list of allowed values can no longer
    break the build, and nothing is leaked into the module namespace.

    Args:
        cf_style_schema: Mapping of attribute name to an object exposing
            ``data_type``, ``description`` and ``allowed_values``.

    Returns:
        The dynamically created ``Product`` model class.
    """
    logger.debug("Building percentage model from CF schema: %s", cf_style_schema)

    product_fields: Dict[str, Any] = {}
    for attribute, attribute_info in cf_style_schema.items():
        # A data type mentioning "list" means several values may apply at once.
        multiple = "list" in attribute_info.data_type
        multiple_desc = (
            "multi-label classification"
            if multiple
            else "single-label classification"
        )

        # One required integer score per allowed value.
        value_fields: Dict[str, Any] = {
            value.lower().replace(" ", "_").replace("-", "_"): (int, ...)
            for value in attribute_info.allowed_values
        }
        value_model = create_model(
            "Class_" + attribute.capitalize(), **value_fields
        )

        product_fields[attribute] = (
            value_model,
            Field(
                "",
                description=f"{multiple_desc}, {attribute_info.description}",
            ),
        )

    return create_model("Product", **product_fields)


def build_attributes_types_prompt(attributes: Dict[str, Any]) -> str:
    """Return a prompt fragment listing each attribute with its data type.

    Args:
        attributes: Mapping of attribute name to an object exposing
            ``data_type``.

    Returns:
        A header line followed by one ``- <name>: <data_type>`` line per
        attribute, each terminated by a newline.
    """
    type_lines = [
        f"- {key}: {value.data_type}\n" for key, value in attributes.items()
    ]
    return "\n List of attributes types:\n" + "".join(type_lines)


class BaseAttributionService(ABC):
    """Abstract base for services that extract product attributes with an AI model.

    Subclasses implement the three abstract coroutines; the ``*_with_validation``
    methods wrap them with schema construction, validation and post-processing.
    """

    # Minimum score a value must reach to be kept for a multi-label attribute.
    MULTI_LABEL_THRESHOLD = 60

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Extract attribute data shaped like ``attributes_model``.

        NOTE(review): ``extract_attributes_with_validation`` calls this with a
        fifth positional argument (``product_data``) and an ``img_paths``
        keyword that this signature does not declare -- confirm against the
        concrete implementations.
        """

    @abstractmethod
    async def reevaluate_atributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Re-evaluate previously extracted attribute data.

        NOTE(review): ``extract_attributes_with_validation`` calls this with a
        fifth positional argument (the stringified first-pass data) and an
        ``img_paths`` keyword that this signature does not declare -- confirm
        against the concrete implementations.
        """

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return ``data`` reshaped so that it follows ``schema``."""

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
        appended_prompt: str = "",
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Extract attributes, re-evaluate them and pick the final labels.

        Args:
            attributes: Mapping of original attribute name to its CF-style
                definition (``data_type``, ``description``, ``allowed_values``).
            ai_model: Identifier forwarded to the extraction calls.
            img_urls: Image URLs forwarded to the extraction calls.
            product_taxonomy: Taxonomy forwarded to the extraction calls; an
                empty string is replaced by ``"main"``.
            product_data: Product data forwarded to ``extract_attributes``.
            pil_images: Currently not forwarded (temporarily removed to save
                cost).
            img_paths: Image paths forwarded to the extraction calls.
            appended_prompt: Currently unused; the attribute-types prompt built
                from ``attributes`` is sent instead.

        Returns:
            A tuple ``(data, labels)``: ``data`` is the validated first-pass
            result keyed by the transformed attribute names, ``labels`` maps
            each original attribute name to the chosen value (single-label) or
            list of values (multi-label).  A single-label attribute with no
            score above 0 is omitted from ``labels``.

        Raises:
            ValueError: If a model field description starts with neither
                ``single-label`` nor ``multi-label``.
        """
        # Attribute names become Python-safe field names; the index suffix
        # keeps them unique so the mapping can be reversed afterwards.
        forward_mapping = {
            key: f"{to_snake_case(key)}_{i}" for i, key in enumerate(attributes)
        }
        reverse_mapping = {safe: key for key, safe in forward_mapping.items()}
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        attributes_types_prompt = build_attributes_types_prompt(attributes)
        attributes_percentage_model = cf_style_to_pydantic_percentage_shema(
            transformed_attributes
        )
        schema = attributes_percentage_model.model_json_schema()
        taxonomy = product_taxonomy if product_taxonomy != "" else "main"

        # pil_images is deliberately not forwarded (temporarily removed to
        # save cost).
        data = await self.extract_attributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            product_data,
            img_paths=img_paths,
            appended_prompt=attributes_types_prompt,
        )
        validate_json_data(data, schema)

        reevaluate_data = await self.reevaluate_atributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            str(data),
            img_paths=img_paths,
            appended_prompt=attributes_types_prompt,
        )

        selected: Dict[str, Any] = {}
        for field_name, field in attributes_percentage_model.model_fields.items():
            logger.debug("%s: %s", field_name, field.description)
            # The model builder always puts the label kind first, so match on
            # the prefix: free-text descriptions may mention either word.
            description = (field.description or "").lower()
            if description.startswith("single-label"):
                max_percentage = 0
                for label, score in reevaluate_data[field_name].items():
                    # Strict ">" keeps the first label among equal top scores.
                    if score > max_percentage:
                        max_percentage = score
                        selected[field_name] = label
            elif description.startswith("multi-label"):
                selected[field_name] = [
                    label
                    for label, score in reevaluate_data[field_name].items()
                    if score >= self.MULTI_LABEL_THRESHOLD
                ]
            else:
                raise ValueError(
                    "The description does not contain 'single-label' or "
                    f"'multi-label': {field.description}"
                )

        # Map the Python-safe field names back to the caller's original keys.
        reverse_data = {
            reverse_mapping[key]: value for key, value in selected.items()
        }
        return data, reverse_data

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate ``schema``, apply ``follow_schema`` and validate the result.

        Args:
            schema: JSON schema the data must follow.
            data: Input data to reshape.

        Returns:
            The data returned by ``follow_schema`` after it has been validated
            against ``schema``.
        """
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data