Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torchvision.models import resnext101_64x4d | |
from torchvision import transforms | |
MODEL_NAME = 'ResNeXt-101-64x4d'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Per-channel normalisation statistics used by the preprocessing transform
# (these are the standard ImageNet mean/std values).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# Number of output units of the fine-tuned classification head; must match
# the checkpoint being loaded below.
NUM_CLASSES = 88

# Build the ResNeXt-101 64x4d architecture without pretrained weights, then
# swap its final fully connected layer for a NUM_CLASSES-way head so the
# saved state dict's shapes line up.
model = resnext101_64x4d()
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
# map_location remaps every stored tensor onto DEVICE, so a single call works
# on both CUDA and CPU-only hosts.
model.load_state_dict(torch.load(f"{MODEL_NAME}-model-1.pt", map_location=DEVICE))
model = model.to(DEVICE)
# The model is only used for inference in this app: switch BatchNorm/Dropout
# layers to evaluation behaviour once, up front.
model.eval()
# Class names, index-aligned with the classifier's output units (predict()
# pairs labels[i] with probability i).
# NOTE(review): `[...]` is a list holding a single Ellipsis, not real class
# names; the full list (one entry per head output) appears to have been
# elided from this copy. TODO restore.
labels = [...]

# Preprocessing applied to each uploaded image before inference:
# convert it to a CHW float tensor (scaled to [0, 1] for uint8 PIL/NumPy
# input), then standardise each channel with MEAN/STD.
_predict_steps = [
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
]
predictTransform = transforms.Compose(_predict_steps)
def predict(img):
    """Classify a plant-leaf image and return per-class confidences.

    Args:
        img: The image delivered by the Gradio input component (anything
            ``predictTransform`` accepts, e.g. a PIL image or an HxWxC
            array), or ``None`` when the user submits without an image.

    Returns:
        A dict mapping each entry of ``labels`` to its softmax probability
        as a Python float. Returns an empty dict when ``img`` is ``None``.
    """
    # Gradio passes None when no image was provided; answer with no
    # confidences instead of crashing inside the transform.
    if img is None:
        return {}
    batch = predictTransform(img).unsqueeze(0).to(DEVICE)  # add batch dim
    model.eval()  # cheap; guarantees inference-mode BatchNorm/Dropout
    with torch.no_grad():
        probs = F.softmax(model(batch)[0], dim=0)
    # One device->host transfer via tolist() instead of a float() sync per
    # class; zip pairs each class name with the probability at its index.
    return dict(zip(labels, probs.tolist()))
title = "Plant Disease Classifier"
description = "Please upload a photo containing a plant leaf."

# predict() returns a {label: confidence} dict, which is exactly the format
# gr.Label renders as a ranked list of classes; a "text" output would only
# display the dict's repr.
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs=gr.Label(num_top_classes=7),
    examples=[
        ['examples/PotatoEarlyBlight4.JPG'],
        ['examples/TomatoYellowCurlVirus4.JPG'],
    ],
    title=title,
    description=description,
)

if __name__ == "__main__":
    iface.launch()