import json
import urllib.request

import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModel, AutoTokenizer

# Timeout (seconds) for remote downloads, so a stalled server cannot hang
# app start-up or an individual prediction request indefinitely.
REQUEST_TIMEOUT = 30

# Load the label-to-class mapping from the Hugging Face repository.
# Keys are looked up below as str(class_index).
label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
_label_map_response = requests.get(label_map_url, timeout=REQUEST_TIMEOUT)
# Fail loudly on HTTP errors instead of trying to parse an error page as JSON.
_label_map_response.raise_for_status()
label_to_class = _label_map_response.json()

# Load the model and tokenizer from the Hugging Face repository.
# NOTE(review): the model is called as model(image, input_ids=..., attention_mask=...),
# which looks like a custom architecture; if loading fails it may require
# trust_remote_code=True -- confirm against the repository before enabling it.
model = AutoModel.from_pretrained("Maverick98/EcommerceClassifier")
# Inference only: switch to eval mode once at load time rather than on every request.
model.eval()
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")

# Image preprocessing: fixed 224x224 resize, tensor conversion, per-channel normalization.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def load_image(image_path_or_url: str) -> torch.Tensor:
    """Load an image from a URL or local path and preprocess it.

    Args:
        image_path_or_url: An ``http://``/``https://`` URL, or a local file path.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``: the RGB image resized and
        normalized by ``transform``, with a leading batch dimension.
    """
    if image_path_or_url.startswith(("http://", "https://")):
        with urllib.request.urlopen(image_path_or_url, timeout=REQUEST_TIMEOUT) as response:
            image = Image.open(response).convert("RGB")
    else:
        # NOTE(review): on a public deployment this lets users open any image file
        # readable by the server process -- consider restricting to URLs only.
        with Image.open(image_path_or_url) as raw_image:
            image = raw_image.convert("RGB")
    return transform(image).unsqueeze(0)  # add batch dimension


def predict(image_path_or_url: str, title: str, threshold: float = 0.7, top_k: int = 3) -> dict:
    """Predict the top categories for a product image and title.

    Args:
        image_path_or_url: Image URL or local path (see ``load_image``);
            surrounding whitespace is ignored.
        title: Product title text.
        threshold: If the highest class probability is below this value, an
            "Others" entry scored ``1 - highest probability`` is prepended.
        top_k: Number of model classes to report, clamped to
            ``[1, number of classes]``.

    Returns:
        Mapping of class name to probability, ordered by descending model
        probability and optionally preceded by "Others".

    Raises:
        gr.Error: If no image URL or path was supplied.
    """
    image_path_or_url = (image_path_or_url or "").strip()
    if not image_path_or_url:
        raise gr.Error("Please provide an image URL or local path.")

    image = load_image(image_path_or_url)

    # Tokenize the title, padded/truncated to exactly 32 tokens.
    encoding = tokenizer(
        title or "",
        padding="max_length",
        max_length=32,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(
            image,
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
        )
    # NOTE(review): assumes the model returns raw class scores of shape
    # (1, num_classes) -- softmax and topk below operate on dim=1.
    probabilities = torch.nn.functional.softmax(logits, dim=1)

    k = max(1, min(top_k, probabilities.size(1)))
    top_probs, top_indices = torch.topk(probabilities, k, dim=1)
    probs = top_probs[0].tolist()
    classes = [label_to_class[str(index)] for index in top_indices[0].tolist()]

    predictions = list(zip(classes, probs))
    # Low confidence in the best class: surface an "Others" bucket first.
    if probs[0] < threshold:
        predictions.insert(0, ("Others", 1.0 - probs[0]))
    return dict(predictions)


# Define the Gradio interface (component classes live at the top level of
# the gradio namespace; gr.inputs / gr.outputs were removed in Gradio 4).
title_input = gr.Textbox(label="Product Title", placeholder="Enter the product title here...")
image_input = gr.Textbox(label="Image URL or Path", placeholder="Enter image URL or local path here...")
output = gr.JSON(label="Top 3 Predictions with Probabilities")

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, title_input],
    outputs=output,
    title="Ecommerce Classifier",
    description="This model classifies ecommerce products into one of 434 categories. If the model is unsure, it outputs 'Others'.",
)
demo.launch()