Spaces:
Sleeping
Sleeping
from abc import ABC, abstractmethod | |
from typing import Any, Dict, List, Type, Union | |
from pydantic import BaseModel | |
from app.utils.converter import to_snake_case | |
from app.schemas.schema_tools import ( | |
convert_attribute_to_model, | |
validate_json_data, | |
validate_json_schema, | |
) | |
class BaseAttributionService(ABC):
    """Base class for AI-backed product attribute extraction services.

    Subclasses provide the model-specific work by overriding
    :meth:`extract_attributes` and :meth:`follow_schema`. This base class wraps
    those hooks with key normalisation and JSON validation.

    NOTE(review): ``abstractmethod`` is imported by this module, but the hooks
    are not decorated with it (adding it could break subclasses that override
    only one hook). A subclass that forgets an override still instantiates and
    only fails, with ``NotImplementedError``, when the hook is called.
    """

    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Union[Dict[str, Union[str, List[str]]], None] = None,
        pil_images: Union[List[Any], None] = None,
        img_paths: Union[List[str], None] = None,
    ) -> Dict[str, Any]:
        """Extract attribute values described by ``attributes_model``.

        Must be overridden by subclasses. The signature mirrors how
        :meth:`extract_attributes_with_validation` invokes it: ``product_data``
        is passed as the fifth positional argument and ``img_paths`` by keyword.
        The previous base declaration lacked both, so calling it raised
        ``TypeError`` for ``img_paths``.

        Args:
            attributes_model: Pydantic model whose fields name the attributes
                to extract.
            ai_model: Identifier of the AI model to use.
            img_urls: Image URLs for the product.
            product_taxonomy: Taxonomy/category string for the product.
            product_data: Additional product fields.
            pil_images: Optional in-memory images.
            img_paths: Optional local image paths.

        Returns:
            A dict that the caller validates against
            ``attributes_model.model_json_schema()``. It is expected to be keyed
            by the model's field names, since the caller maps those keys back to
            the original attribute names.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement extract_attributes()"
        )

    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Reshape ``data`` so that it conforms to ``schema``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement follow_schema()"
        )

    @staticmethod
    def _build_key_mappings(attributes: Dict[str, Any]) -> tuple:
        """Map each attribute name to a snake_case identifier and back.

        The enumeration index is appended to each name (``<snake>_<i>``), so two
        attribute names that snake-case to the same string still get distinct
        identifiers.

        Returns:
            ``(forward_mapping, reverse_mapping)``: original name -> identifier,
            and identifier -> original name.
        """
        forward_mapping: Dict[str, str] = {}
        reverse_mapping: Dict[str, str] = {}
        for index, key in enumerate(attributes):
            safe_key = f"{to_snake_case(key)}_{index}"
            forward_mapping[key] = safe_key
            reverse_mapping[safe_key] = key
        return forward_mapping, reverse_mapping

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: List[Any] = None,
        img_paths: List[str] = None,
        schema: Dict[str, Any] = None,
    ) -> Dict[str, Any]:
        """Extract attributes, validate them, and return them under their original names.

        Attribute names are first rewritten into snake_case identifiers suitable
        for a generated Pydantic model. ``extract_attributes`` is then run
        against that model, its output is validated against the model's JSON
        schema, and finally the keys are translated back.

        Args:
            attributes: Mapping of attribute name to attribute definition. The
                definition format is whatever ``convert_attribute_to_model``
                accepts (TODO confirm against that helper).
            ai_model: Forwarded to ``extract_attributes``.
            img_urls: Forwarded to ``extract_attributes``.
            product_taxonomy: Forwarded to ``extract_attributes``; an empty
                string is replaced with ``"main"``.
            product_data: Forwarded positionally to ``extract_attributes``.
            pil_images: Accepted but currently NOT forwarded (disabled upstream
                "to save cost").
            img_paths: Forwarded to ``extract_attributes`` by keyword.
            schema: Ignored. Validation always uses the schema generated from
                the attribute model. Kept for interface compatibility.

        Returns:
            The extracted data, keyed by the original attribute names.

        Raises:
            KeyError: If the extracted data contains a key that was not produced
                by the name mapping.
            Any error raised by ``validate_json_data`` when the extracted data
            does not match the generated schema.
        """
        forward_mapping, reverse_mapping = self._build_key_mappings(attributes)
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        attributes_model = convert_attribute_to_model(transformed_attributes)
        # Use a separate local rather than rebinding the caller's ``schema``:
        # the generated schema is always the one that matches the model.
        model_schema = attributes_model.model_json_schema()

        data = await self.extract_attributes(
            attributes_model,
            ai_model,
            img_urls,
            product_taxonomy if product_taxonomy != "" else "main",
            # Passed positionally so subclass overrides may name it freely.
            product_data,
            # pil_images is intentionally not forwarded (temporarily removed to
            # save cost).
            img_paths=img_paths,
        )
        validate_json_data(data, model_schema)

        # Translate the identifier keys back to the caller's attribute names.
        return {reverse_mapping[key]: value for key, value in data.items()}

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate ``schema``, apply :meth:`follow_schema`, then validate the result.

        Args:
            schema: JSON schema the output must satisfy. It is itself checked
                with ``validate_json_schema`` first.
            data: Input data passed to ``follow_schema``.

        Returns:
            The output of ``follow_schema`` after it has passed
            ``validate_json_data`` against ``schema``.
        """
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data