import torch
import gradio as gr
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    MobileViTFeatureExtractor,
    MobileViTForImageClassification,
)


def predict(model_type, inp):
    # Load the feature extractor and fine-tuned classifier for the selected architecture.
    if model_type == "ViT":
        model_name_or_path = './models/vit-base-garbage/'
        feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
        model = ViTForImageClassification.from_pretrained(model_name_or_path)
    elif model_type == "MobileViT":
        model_name_or_path = './models/apple/mobilevit-small-garbage/'
        feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_name_or_path)
        model = MobileViTForImageClassification.from_pretrained(model_name_or_path)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Preprocess the PIL image into model-ready tensors.
    inputs = feature_extractor(inp, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to class probabilities and map each index to its label name.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    confidences = {
        model.config.id2label[i]: float(probabilities[i])
        for i in range(len(probabilities))
    }
    return confidences


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(["ViT", "MobileViT"], label="Model Name", value="ViT"),
        gr.Image(type="pil"),
    ],
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["ViT", "paper567.jpg"],
        ["ViT", "trash105.jpg"],
        ["ViT", "plastic202.jpg"],
        ["MobileViT", "metal382.jpg"],
    ],
)
demo.launch()