File size: 6,477 Bytes
8ba64a4 9bb2fc2 8ba64a4 9645c29 8ba64a4 4e55bb0 98ef152 4e55bb0 b652f9c a9d8d74 8ba64a4 b652f9c 8ba64a4 a9d8d74 b652f9c a9d8d74 8ba64a4 4e55bb0 8ba64a4 638f225 8ba64a4 b652f9c 8ba64a4 9645c29 4e55bb0 b652f9c 4e55bb0 8ba64a4 4e55bb0 8ba64a4 e85027d 8ba64a4 e85027d 8ba64a4 b652f9c 8ba64a4 9645c29 a9d8d74 4e55bb0 a9d8d74 b652f9c a9d8d74 4e55bb0 a9d8d74 98ef152 a9d8d74 98ef152 a9d8d74 98ef152 a9d8d74 4e55bb0 8ba64a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel

from app.utils.converter import to_snake_case
from app.schemas.schema_tools import (
    convert_attribute_to_model,
    validate_json_data,
    validate_json_schema,
)
def cf_style_to_pydantic_percentage_shema(
    cf_style_schema: dict,
) -> "Type[BaseModel]":
    """
    Build a Pydantic ``Product`` model that scores every allowed value of each attribute.

    For each ``attribute -> attribute_info`` entry a nested model named
    ``"Class_" + attribute.capitalize()`` is created with one required ``int``
    field per allowed value (lower-cased, spaces and hyphens replaced by
    underscores). ``Product`` gets one field per attribute, typed with that
    nested model, whose description starts with "multi-label classification"
    when ``"list"`` occurs in ``attribute_info.data_type`` and with
    "single-label classification" otherwise, followed by the attribute's own
    description.

    The models are created with ``pydantic.create_model`` instead of
    ``exec``-ing generated source text, so descriptions containing quotes or
    newlines and value names that are not valid identifiers can no longer
    break (or inject code into) the generated module, and nothing is written
    to this module's globals.

    Args:
        cf_style_schema: Mapping of attribute name to an object exposing
            ``data_type``, ``description`` and ``allowed_values``.

    Returns:
        The generated ``Product`` model class.
    """
    import logging

    from pydantic import Field, create_model

    logging.getLogger(__name__).debug("%s", cf_style_schema)

    product_fields = {}
    for attribute, attribute_info in cf_style_schema.items():
        multiple = "list" in attribute_info.data_type
        multiple_desc = (
            "multi-label classification" if multiple else "single-label classification"
        )

        # One required integer score per allowed value.
        value_fields = {
            value.lower().replace(" ", "_").replace("-", "_"): (int, ...)
            for value in attribute_info.allowed_values
        }
        values_model = create_model("Class_" + attribute.capitalize(), **value_fields)

        # The "" default mirrors the field definition the generated source used.
        product_fields[attribute] = (
            values_model,
            Field("", description=f"{multiple_desc}, {attribute_info.description}"),
        )

    return create_model("Product", **product_fields)
def build_attributes_types_prompt(attributes):
    """
    Render the attribute names and their data types as a prompt fragment.

    The result is a header line followed by one ``- <name>: <data_type>`` line
    per entry of ``attributes``, in iteration order. Each value must expose a
    ``data_type`` attribute.
    """
    header = "\n List of attributes types:\n"
    type_lines = [
        f"- {name}: {info.data_type}\n" for name, info in attributes.items()
    ]
    return header + "".join(type_lines)
class BaseAttributionService(ABC):
    """
    Abstract base for services that extract product attributes with an AI model.

    Subclasses implement the three abstract coroutines. The
    ``*_with_validation`` methods wrap them with key normalisation, schema
    validation and post-processing of the returned scores.

    NOTE(review): ``extract_attributes_with_validation`` calls
    ``extract_attributes`` and ``reevaluate_atributes`` with a fifth positional
    argument (product data) and an ``img_paths`` keyword that the abstract
    signatures below do not declare -- confirm against the concrete subclasses.
    """

    # Minimum score a value must reach to be kept for a multi-label attribute.
    MULTI_LABEL_THRESHOLD = 60

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Run the first extraction pass; implemented by subclasses."""
        pass

    @abstractmethod
    async def reevaluate_atributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Run the re-evaluation pass; implemented by subclasses."""
        pass

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Produce data for ``schema`` from ``data``; implemented by subclasses."""
        pass

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
        appended_prompt: str = "",
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Extract attributes, re-evaluate them and reduce the scores to labels.

        Steps:
          1. Map every attribute key to a unique snake_case alias
             (``<snake_case>_<index>``) so it can be used as a model field name.
          2. Build the score model and call ``extract_attributes``; the result
             is validated against the model's JSON schema.
          3. Call ``reevaluate_atributes`` with ``str(data)`` from step 2 in
             place of the product data.
          4. For each field keep the best-scoring value (single-label) or all
             values scoring at least ``MULTI_LABEL_THRESHOLD`` (multi-label).
          5. Map the aliases back to the original attribute keys.

        Args:
            attributes: Attribute definitions keyed by their original names.
            ai_model: Identifier of the AI model, forwarded unchanged.
            img_urls: Image URLs, forwarded unchanged.
            product_taxonomy: Taxonomy name; an empty string is replaced by
                ``"main"``.
            product_data: Product data forwarded to the first pass.
            pil_images: Currently not forwarded (temporarily removed to save
                cost).
            img_paths: Image paths, forwarded by keyword.
            appended_prompt: Currently unused; the prompt built from
                ``attributes`` is sent instead.

        Returns:
            A tuple ``(data, reverse_data)``: the validated first-pass result
            and the selected labels keyed by the original attribute names.

        Raises:
            AssertionError: If a model field description contains neither
                "single-label" nor "multi-label".
        """
        # validate_json_schema(schema)
        # Attribute keys need not be valid Python variable names: give each a
        # unique alias and remember how to map it back.
        forward_mapping = {
            key: f"{to_snake_case(key)}_{i}" for i, key in enumerate(attributes)
        }
        reverse_mapping = {alias: key for key, alias in forward_mapping.items()}
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        attributes_types_prompt = build_attributes_types_prompt(attributes)
        # attributes_model = convert_attribute_to_model(transformed_attributes)
        attributes_percentage_model = cf_style_to_pydantic_percentage_shema(
            transformed_attributes
        )
        schema = attributes_percentage_model.model_json_schema()
        taxonomy = product_taxonomy if product_taxonomy != "" else "main"

        data = await self.extract_attributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            product_data,
            # pil_images=pil_images,  # temporarily removed to save cost
            img_paths=img_paths,
            appended_prompt=attributes_types_prompt,
        )
        validate_json_data(data, schema)

        reevaluate_data = await self.reevaluate_atributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            str(data),
            # pil_images=pil_images,  # temporarily removed to save cost
            img_paths=img_paths,
            appended_prompt=attributes_types_prompt,
        )

        selected = self._select_labels(attributes_percentage_model, reevaluate_data)

        # Reverse the key mapping to the original keys.
        reverse_data = {
            reverse_mapping[alias]: value for alias, value in selected.items()
        }
        return data, reverse_data

    def _select_labels(
        self, attributes_model: Type[BaseModel], reevaluate_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Reduce the per-value scores of every model field to its chosen label(s)."""
        import logging

        logger = logging.getLogger(__name__)
        selected: Dict[str, Any] = {}
        for field_name, field in attributes_model.model_fields.items():  # type: ignore
            logger.debug("%s: %s", field_name, field.description)
            description = field.description.lower()
            if "single-label" in description:
                # First value with the highest score wins; a field whose best
                # score is not above 0 is left out of the result.
                best_label, best_score = max(
                    reevaluate_data[field_name].items(),
                    key=lambda item: item[1],
                    default=(None, 0),
                )
                if best_score > 0:
                    selected[field_name] = best_label
            elif "multi-label" in description:
                selected[field_name] = [
                    label
                    for label, score in reevaluate_data[field_name].items()
                    if score >= self.MULTI_LABEL_THRESHOLD
                ]
            else:
                # Raised explicitly so the check survives ``python -O``.
                raise AssertionError(
                    f"The description does not contain 'single-label' or 'multi-label': {field.description}"
                )
        return selected

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Validate ``schema``, run ``follow_schema`` and validate its result.

        Returns:
            The data returned by ``follow_schema`` after it passed
            ``validate_json_data`` against ``schema``.
        """
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data
|