"""Gradio app that classifies plant-leaf photos with a fine-tuned ResNeXt-101-64x4d."""

import gradio as gr
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import transforms
from torchvision.models import resnext101_64x4d

MODEL_NAME = "ResNeXt-101-64x4d"
NUM_CLASSES = 88  # output size of the replaced classification head
TOP_K = 7  # number of top classes displayed by the output label
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Standard ImageNet per-channel mean/std; presumably the same normalization
# used when the checkpoint was trained — confirm.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

model = resnext101_64x4d()
# Swap the backbone's default head for one sized to this dataset's classes,
# so the checkpoint's fc weights fit.
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
# Always deserialize onto the CPU (works whether the checkpoint was saved on
# GPU or CPU), then move the whole model to DEVICE once. Loading inline avoids
# keeping a second module-level copy of the weights alive.
model.load_state_dict(
    torch.load(f"{MODEL_NAME}-model-1.pt", map_location="cpu")
)
model = model.to(DEVICE)
# Inference only: switch to eval mode once at startup (batch-norm layers then
# use their running statistics) instead of on every request.
model.eval()

# NOTE(review): placeholder in this copy of the file. It must list the class
# names in the order of the model's output indices (presumably NUM_CLASSES
# entries); predict() pairs names with scores positionally — confirm.
labels = [...]

# Converts an HxWxC image (PIL image or uint8 array) to a CxHxW float tensor
# scaled to [0, 1], then normalizes each channel with MEAN/STD.
predictTransform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])


def predict(img):
    """Classify a single image and return per-class probabilities.

    Args:
        img: The image delivered by the Gradio ``"image"`` input (a numpy
            array by default), or ``None`` when the user submits without
            uploading anything.

    Returns:
        A dict mapping each entry of ``labels`` to its softmax probability,
        in the form ``gr.Label`` expects. Empty when ``img`` is ``None``.
    """
    if img is None:
        # Nothing to classify; an empty mapping clears the label output
        # instead of raising inside ToTensor.
        return {}
    batch = predictTransform(img).unsqueeze(0).to(DEVICE)  # add batch dim
    with torch.no_grad():
        logits = model(batch)[0]  # scores for the single image in the batch
        probabilities = F.softmax(logits, dim=0)
    # Pair names with scores positionally; tolist() yields plain Python floats.
    return dict(zip(labels, probabilities.tolist()))


title = "Plant Disease Classifier"
description = "Please upload a photo containing a plant leaf."

iface = gr.Interface(
    fn=predict,
    inputs="image",
    # predict() returns {class name: probability}; gr.Label renders that as
    # ranked confidences, whereas a text output would show the raw dict repr.
    outputs=gr.Label(num_top_classes=TOP_K),
    title=title,
    description=description,
    examples=[
        ["examples/PotatoEarlyBlight4.JPG"],
        ["examples/TomatoYellowCurlVirus4.JPG"],
    ],
)

if __name__ == "__main__":
    iface.launch()