import gradio as gr
import torch
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    MobileViTFeatureExtractor,
    MobileViTForImageClassification,
)

def predict(model_type, inp):
    # Load the selected fine-tuned checkpoint and its matching feature extractor.
    if model_type == "ViT":
        model_name_or_path = "./models/vit-base-garbage/"
        feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
        model = ViTForImageClassification.from_pretrained(model_name_or_path)
    elif model_type == "MobileViT":
        model_name_or_path = "./models/apple/mobilevit-small-garbage/"
        feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_name_or_path)
        model = MobileViTForImageClassification.from_pretrained(model_name_or_path)

    # Preprocess the PIL image into the pixel values the model expects.
    inputs = feature_extractor(inp, return_tensors="pt")

    # Run inference without tracking gradients.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to class probabilities and map each index to its label name.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    confidences = {
        model.config.id2label[i]: float(probabilities[i])
        for i in range(len(probabilities))
    }
    return confidences


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(["ViT", "MobileViT"], label="Model Name", value="ViT"),
        gr.Image(type="pil"),
    ],
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["ViT", "paper567.jpg"],
        ["ViT", "trash105.jpg"],
        ["ViT", "plastic202.jpg"],
        ["MobileViT", "metal382.jpg"],
    ],
)
demo.launch()
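
# Optional local smoke test (not run by the Space): call predict() directly on one
# of the bundled example images. This assumes the example files listed above sit
# next to app.py in the repository root.
# from PIL import Image
# print(predict("ViT", Image.open("paper567.jpg")))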