import gradio as gr from transformers import AutoModel, AutoTokenizer import torch import json import requests from PIL import Image from torchvision import transforms import urllib.request from torchvision import models import torch.nn as nn # --- Define the Model --- class FineGrainedClassifier(nn.Module): def __init__(self, num_classes=434): # Updated to 434 classes super(FineGrainedClassifier, self).__init__() self.image_encoder = models.resnet50(pretrained=True) self.image_encoder.fc = nn.Identity() self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en') self.classifier = nn.Sequential( nn.Linear(2048 + 768, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, num_classes) # Updated to 434 classes ) def forward(self, image, input_ids, attention_mask): image_features = self.image_encoder(image) text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask) text_features = text_output.last_hidden_state[:, 0, :] combined_features = torch.cat((image_features, text_features), dim=1) output = self.classifier(combined_features) return output # --- Data Augmentation Setup --- transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Load the label-to-class mapping from your Hugging Face repository label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json" label_to_class = requests.get(label_map_url).json() # Load your custom model from Hugging Face model = FineGrainedClassifier(num_classes=len(label_to_class)) checkpoint_url = f"https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth" checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu')) # Strip the "module." prefix from the keys in the state_dict if they exist new_state_dict = {} for k, v in checkpoint.items(): if k.startswith("module."): new_state_dict[k[7:]] = v # Remove "module." prefix else: new_state_dict[k] = v model.load_state_dict(new_state_dict) # Load the tokenizer from Jina tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en") def load_image(image): """ Preprocess the uploaded image. """ image = transform(image) image = image.unsqueeze(0) # Add batch dimension return image def predict(image, title, threshold=0.7): """ Predict the top 3 categories for the given image and title. Includes "Others" if the confidence of the top prediction is below the threshold. """ # Preprocess the image image = load_image(image) # Tokenize the title title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt') input_ids = title_encoding['input_ids'] attention_mask = title_encoding['attention_mask'] # Predict model.eval() with torch.no_grad(): output = model(image, input_ids=input_ids, attention_mask=attention_mask) probabilities = torch.nn.functional.softmax(output, dim=1) top3_probabilities, top3_indices = torch.topk(probabilities, 3, dim=1) # Map the top 3 indices to class names top3_classes = [label_to_class[str(idx.item())] for idx in top3_indices[0]] # Check if the highest probability is below the threshold if top3_probabilities[0][0].item() < threshold: top3_classes.insert(0, "Others") top3_probabilities = torch.cat((torch.tensor([[1.0 - top3_probabilities[0][0].item()]]), top3_probabilities), dim=1) # Prepare the output as a dictionary results = {} for i in range(len(top3_classes)): results[top3_classes[i]] = top3_probabilities[0][i].item() return results # Define the Gradio interface title_input = gr.inputs.Textbox(label="Product Title", placeholder="Enter the product title here...") image_input = gr.inputs.Image(type="pil", label="Upload Image") output = gr.outputs.JSON(label="Top 3 Predictions with Probabilities") gr.Interface( fn=predict, inputs=[image_input, title_input], outputs=output, title="Ecommerce Classifier", description="This model classifies ecommerce products into one of 434 categories. If the model is unsure, it outputs 'Others'.", ).launch()