File size: 6,477 Bytes
8ba64a4
9bb2fc2
8ba64a4
 
9645c29
8ba64a4
 
 
 
 
 
 
 
4e55bb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98ef152
4e55bb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b652f9c
 
 
 
 
 
a9d8d74
8ba64a4
 
 
 
 
 
 
 
 
b652f9c
8ba64a4
 
 
a9d8d74
 
 
 
 
 
 
 
b652f9c
a9d8d74
 
 
8ba64a4
 
 
 
 
 
 
 
4e55bb0
8ba64a4
 
 
638f225
8ba64a4
 
b652f9c
8ba64a4
 
9645c29
 
4e55bb0
 
 
 
 
 
 
 
 
 
b652f9c
 
4e55bb0
 
 
8ba64a4
4e55bb0
8ba64a4
 
e85027d
8ba64a4
e85027d
8ba64a4
b652f9c
8ba64a4
 
9645c29
a9d8d74
 
4e55bb0
a9d8d74
 
 
 
 
 
b652f9c
a9d8d74
 
 
4e55bb0
a9d8d74
98ef152
a9d8d74
 
 
 
 
98ef152
a9d8d74
 
 
 
 
98ef152
 
a9d8d74
4e55bb0
 
 
 
 
8ba64a4
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel, Field, create_model

from app.schemas.schema_tools import (
    convert_attribute_to_model,
    validate_json_data,
    validate_json_schema,
)
from app.utils.converter import to_snake_case


def cf_style_to_pydantic_percentage_shema(
    cf_style_schema: dict,
) -> Type[BaseModel]:
    """Build a Pydantic ``Product`` model from a CF-style attribute schema.

    For every ``attribute -> attribute_info`` entry, a nested model named
    ``"Class_" + attribute.capitalize()`` is created with one required ``int``
    field per allowed value (value lower-cased, spaces and hyphens replaced by
    underscores). ``Product`` then gets a field ``attribute`` of that nested
    type, defaulting to ``""``, whose description is
    ``"<multi|single>-label classification, <attribute_info.description>"``.
    The attribute is multi-label when ``"list"`` occurs in
    ``attribute_info.data_type``.

    The models are built with ``pydantic.create_model`` instead of generating
    and executing source code. As a result:

    - descriptions containing quotes, newlines or backslashes are kept verbatim;
    - value names that are Python keywords (e.g. ``in``) do not break model
      creation;
    - an attribute with no allowed values, or an empty schema, yields an empty
      model;
    - nothing is written into this module's globals.

    Args:
        cf_style_schema: Mapping of attribute name (must be usable as a
            Pydantic field name) to an object exposing ``data_type``,
            ``description`` and ``allowed_values``.

    Returns:
        The generated ``Product`` model class.
    """
    product_fields: Dict[str, Any] = {}
    for attribute, attribute_info in cf_style_schema.items():
        multiple = "list" in attribute_info.data_type
        multiple_desc = (
            "multi-label classification" if multiple else "single-label classification"
        )

        value_fields: Dict[str, Any] = {}
        for value in attribute_info.allowed_values:
            value_name = value.lower().replace(" ", "_").replace("-", "_")
            # Pydantic reserves leading-underscore names for private attributes,
            # so such names (and empty ones) cannot become schema fields.
            if not value_name or value_name.startswith("_"):
                continue
            # Required int per allowed value; a repeated name keeps one field.
            value_fields[value_name] = (int, ...)

        value_model = create_model("Class_" + attribute.capitalize(), **value_fields)
        product_fields[attribute] = (
            value_model,
            Field("", description=f"{multiple_desc}, {attribute_info.description}"),
        )

    return create_model("Product", **product_fields)

def build_attributes_types_prompt(attributes):
    """Render a prompt fragment listing each attribute with its data type.

    The fragment is a fixed header line followed by one ``- name: data_type``
    bullet per entry of *attributes* (in iteration order), each ending in a
    newline. Every value must expose a ``data_type`` attribute.
    """
    header = "\n List of attributes types:\n"
    bullets = [f"- {name}: {info.data_type}\n" for name, info in attributes.items()]
    return header + "".join(bullets)


class BaseAttributionService(ABC):
    """Base class for AI-backed product attribute extraction services.

    Subclasses implement the model-facing coroutines ``extract_attributes``,
    ``reevaluate_atributes`` and ``follow_schema``. The ``*_with_validation``
    helpers here wrap them with schema construction and JSON validation.
    """

    # Minimum re-evaluation score a label needs to be kept for a multi-label
    # attribute (scores are presumably percentages 0-100 — confirm with the
    # model prompts). Subclasses may override.
    MULTI_LABEL_THRESHOLD = 60

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Ask ``ai_model`` to fill ``attributes_model`` for a product.

        NOTE(review): ``extract_attributes_with_validation`` passes
        ``product_data`` as the fifth positional argument and an ``img_paths``
        keyword, neither of which matches this declaration. Confirm against
        the concrete subclasses and align this signature.
        """

    @abstractmethod
    async def reevaluate_atributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Re-score a previous extraction against ``attributes_model``.

        NOTE(review): like ``extract_attributes``, this is called with a fifth
        positional argument (the ``str()`` of the first-pass result) and an
        ``img_paths`` keyword that this declaration does not list.
        """

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Produce a version of ``data`` that should validate against ``schema``.

        ``follow_schema_with_validation`` checks the result with
        ``validate_json_data``.
        """

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
        appended_prompt: str = "",
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Extract attributes in two passes and reduce the scores to labels.

        Steps:

        1. Rename attribute keys to identifier-safe names and build a scoring
           model with ``cf_style_to_pydantic_percentage_shema``.
        2. Run ``extract_attributes`` and validate the result against the
           model's JSON schema.
        3. Feed ``str()`` of that result to ``reevaluate_atributes`` and
           validate its answer against the same schema.
        4. Reduce the scores to labels:
           - single-label attributes keep the first highest-scoring label,
             and are omitted when no label scores above 0;
           - multi-label attributes keep every label scoring at least
             ``MULTI_LABEL_THRESHOLD``.

        Args:
            attributes: Attribute name -> object exposing ``data_type``,
                ``description`` and ``allowed_values``.
            ai_model: Model identifier forwarded to the subclass hooks.
            img_urls: Image URLs forwarded to the subclass hooks.
            product_taxonomy: Taxonomy label; ``""`` is replaced by ``"main"``.
            product_data: Product data forwarded to ``extract_attributes``.
            pil_images: Currently unused (not forwarded, to save cost).
            img_paths: Image paths forwarded to the subclass hooks.
            appended_prompt: Currently unused. The prompt built by
                ``build_attributes_types_prompt`` is sent instead.

        Returns:
            A tuple ``(data, labels)``:
            - ``data`` is the validated first-pass output, keyed by the
              transformed field names;
            - ``labels`` maps the original attribute names to the chosen
              label (single-label) or list of labels (multi-label).

        Raises:
            ValueError: If a model field's description does not start with
                ``single-label`` or ``multi-label``.
            Whatever ``validate_json_data`` raises for invalid model output.
        """
        # Identifier-safe, unique field names. The index suffix keeps keys
        # that snake-case to the same string distinct.
        forward_mapping = {
            key: f"{to_snake_case(key)}_{i}" for i, key in enumerate(attributes.keys())
        }
        reverse_mapping = {safe: key for key, safe in forward_mapping.items()}
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        attributes_types_prompt = build_attributes_types_prompt(attributes)
        attributes_percentage_model = cf_style_to_pydantic_percentage_shema(
            transformed_attributes
        )
        schema = attributes_percentage_model.model_json_schema()
        taxonomy = product_taxonomy if product_taxonomy != "" else "main"

        data = await self.extract_attributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            product_data,
            # pil_images deliberately not forwarded, to save cost.
            img_paths=img_paths,
            # NOTE(review): the caller's appended_prompt is not included here.
            appended_prompt=attributes_types_prompt,
        )
        validate_json_data(data, schema)

        reevaluate_data = await self.reevaluate_atributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            str(data),
            img_paths=img_paths,
            appended_prompt=attributes_types_prompt,
        )
        # The loop below indexes reevaluate_data per field and compares scores,
        # so check it against the same schema as the first pass.
        validate_json_data(reevaluate_data, schema)

        labels: Dict[str, Any] = {}
        for field_name, field in attributes_percentage_model.model_fields.items():
            description = (field.description or "").lower()
            # Classify by the fixed prefix the schema builder writes. The
            # user-supplied description after it may mention either term.
            if description.startswith("single-label"):
                best_label, best_score = None, 0
                for label, score in reevaluate_data[field_name].items():
                    if score > best_score:  # strict: the first maximum wins
                        best_label, best_score = label, score
                if best_label is not None:
                    labels[field_name] = best_label
            elif description.startswith("multi-label"):
                labels[field_name] = [
                    label
                    for label, score in reevaluate_data[field_name].items()
                    if score >= self.MULTI_LABEL_THRESHOLD
                ]
            else:
                raise ValueError(
                    "The description does not start with 'single-label' or "
                    f"'multi-label': {field.description}"
                )

        # Map the transformed field names back to the caller's attribute keys.
        reverse_data = {reverse_mapping[name]: value for name, value in labels.items()}
        return data, reverse_data

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate ``schema``, run ``follow_schema``, then validate its output."""
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data