from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel

from app.utils.converter import to_snake_case
from app.schemas.schema_tools import (
    convert_attribute_to_model,
    validate_json_data,
    validate_json_schema,
)


class BaseAttributionService(ABC):
    """Abstract base for attribute-extraction services.

    Subclasses implement :meth:`extract_attributes` and :meth:`follow_schema`.
    The ``*_with_validation`` wrappers defined here add JSON validation (via
    ``validate_json_data`` / ``validate_json_schema``) around those calls.
    """

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Extract attribute values for a product.

        The signature mirrors how :meth:`extract_attributes_with_validation`
        invokes this method: ``product_data`` is passed as the 5th positional
        argument and ``img_paths`` as a keyword argument, so implementations
        must accept both.

        Args:
            attributes_model: Pydantic model class describing the attributes
                to extract. Its field names are the keys expected in the result.
            ai_model: Identifier of the AI model to use.
            img_urls: Product image URLs.
            product_taxonomy: Product taxonomy label.
            product_data: Additional product information (string or list of
                strings per key).
            pil_images: Optional in-memory images. The validating wrapper
                currently does not forward these.
            img_paths: Optional local image paths.

        Returns:
            A dict keyed by the field names of ``attributes_model``.
        """

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Produce data for ``schema`` from ``data``.

        NOTE(review): the exact semantics are implementation-defined. The
        result presumably conforms to ``schema``; the validating wrapper
        checks that with ``validate_json_data``.
        """

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
        schema: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Extract attributes, validate them, and return them under the original names.

        Attribute names are first rewritten into generated keys, because
        they are used as field names of a dynamically created model. The
        extraction result is validated against that model's JSON schema.
        Finally the keys are mapped back to the caller's names.

        Args:
            attributes: Mapping of attribute name -> attribute definition. It
                is passed, with renamed keys, to ``convert_attribute_to_model``.
            ai_model: Identifier of the AI model to use.
            img_urls: Product image URLs.
            product_taxonomy: Product taxonomy label. An empty value (``""``
                or ``None``) is replaced with ``"main"``.
            product_data: Additional product information forwarded to
                :meth:`extract_attributes`.
            pil_images: Currently ignored. It is deliberately not forwarded
                (removed to save cost).
            img_paths: Forwarded to :meth:`extract_attributes`.
            schema: Currently ignored. Validation uses the schema generated
                from the attributes model instead. Kept for API compatibility.

        Returns:
            The extracted data, keyed by the original attribute names.

        Raises:
            KeyError: if the extraction returns a key that does not correspond
                to any requested attribute.
            Any exception raised by ``validate_json_data`` when the data does
            not match the generated schema.
        """
        # Rename keys so they follow Python variable naming conventions.
        # The positional index suffix keeps the generated names unique, even
        # if two attribute names snake-case to the same string.
        forward_mapping = {
            key: f"{to_snake_case(key)}_{index}"
            for index, key in enumerate(attributes)
        }
        reverse_mapping = {safe: key for key, safe in forward_mapping.items()}
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        attributes_model = convert_attribute_to_model(transformed_attributes)
        # The caller-supplied ``schema`` is intentionally not used.
        # Validation runs against the schema of the generated model.
        model_schema = attributes_model.model_json_schema()

        data = await self.extract_attributes(
            attributes_model,
            ai_model,
            img_urls,
            product_taxonomy or "main",
            product_data,
            # pil_images deliberately not forwarded (temporarily removed to save cost).
            img_paths=img_paths,
        )
        validate_json_data(data, model_schema)

        unexpected = data.keys() - reverse_mapping.keys()
        if unexpected:
            raise KeyError(
                f"extract_attributes returned unexpected keys: {sorted(unexpected, key=str)}"
            )
        # Map the generated keys back to the caller's original attribute names.
        return {reverse_mapping[key]: value for key, value in data.items()}

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate ``schema``, run :meth:`follow_schema`, then validate its output.

        Raises:
            Whatever ``validate_json_schema`` raises for an invalid schema, or
            whatever ``validate_json_data`` raises for non-conforming output.
        """
        validate_json_schema(schema)
        result = await self.follow_schema(schema, data)
        validate_json_data(result, schema)
        return result