Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from torchvision import models, transforms | |
| import time | |
| import os | |
| import copy | |
| import pickle | |
| from PIL import Image | |
| import datetime | |
| import gdown | |
| import urllib.request | |
| import gradio as gr | |
| import markdown | |
# Model artifacts, expected next to this script: the pickled sequence of
# class-folder names and the ResNet18 state dict trained on the Leeds
# Butterfly Dataset. Both were originally fetched from Google Drive with
# gdown; the source links are kept here so the files can be re-fetched by
# hand if they go missing.
#   class names: https://drive.google.com/file/d/1m9C-WMfKRDCmScxTh8JmcoFtymxAqjS3/view?usp=sharing
#   model state: https://drive.google.com/file/d/1qxaWnYwLIwWGrGg9uehG7h2W227SXGKq/view?usp=sharing
path_class_names = "./class_names_restnet_leeds_butterfly.pkl"
path_model = "./model_state_restnet_leeds_butterfly.pth"
# Example images offered under the demo, as (source URL, local path) pairs.
_EXAMPLE_IMAGES = [
    (
        "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f8/Red_postman_butterfly_%28Heliconius_erato%29.jpg/1599px-Red_postman_butterfly_%28Heliconius_erato%29.jpg",
        "./h_erato.jpg",
    ),
    (
        "https://upload.wikimedia.org/wikipedia/commons/thumb/6/63/Monarch_In_May.jpg/1024px-Monarch_In_May.jpg",
        "./d_plexippus.jpg",
    ),
    ("https://drive.google.com/uc?id=1A7WgDrQ_RLO6JOQiYhkH_hj_EKcbpmOl", "./v_cardui.jpg"),
    ("https://drive.google.com/uc?id=1CiWShQYIm2N0fkVaWJpftlXZFqwjsXhA", "./p_cresphontes.jpg"),
    ("https://drive.google.com/uc?id=1r8rbkUwTSIZL0MQVgU-WjDGwvLXuwYPG", "./p_rapae.jpg"),
]

# Wikimedia's User-Agent policy asks clients to send an identifying
# User-Agent with contact information; requests carrying urllib's generic
# default agent may be refused (HTTP 403), which would abort start-up.
_DOWNLOAD_HEADERS = {"User-Agent": "LeedsButterflyDemo/1.0 (https://github.com/ttheland)"}

# Fetch each example once. Files already on disk (e.g. from an earlier run)
# are reused. Each download is written to a ".part" file and renamed only on
# success, so an interrupted transfer never leaves a truncated image that a
# later run would mistake for a complete one. Network errors still propagate.
# As before, `url` and `path_input` stay bound to the last pair afterwards.
for url, path_input in _EXAMPLE_IMAGES:
    if os.path.exists(path_input):
        continue
    partial_path = path_input + ".part"
    request = urllib.request.Request(url, headers=_DOWNLOAD_HEADERS)
    with urllib.request.urlopen(request) as response:
        payload = response.read()
    with open(partial_path, "wb") as out_file:
        out_file.write(payload)
    os.replace(partial_path, path_input)
# Evaluation-time preprocessing applied to every input image: resize so the
# shorter side is 256 px, take the central 224x224 crop, convert to a float
# tensor, then normalise each channel with the standard ImageNet mean/std.
data_transforms_test = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Class labels: position i names the model's output logit i.
# NOTE: unpickling can execute arbitrary code. That is acceptable only
# because this is expected to be a trusted artifact shipped with the app,
# never user-supplied data.
with open(path_class_names, "rb") as class_names_file:
    class_names = pickle.load(class_names_file)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture: a ResNet18 whose final fully-connected layer is
# resized to the number of butterfly classes. ImageNet weights are not
# requested. load_state_dict (strict by default) overwrites every parameter
# and buffer anyway, so downloading them only added start-up time and a
# network dependency (and `pretrained=` is deprecated in newer torchvision).
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
model_ft.load_state_dict(torch.load(path_model, map_location=device))
# Inference only: disable dropout/batch-norm updates.
model_ft.eval()
# Proper labeling: map each raw class-folder name to a display string of the
# form "Genus species - Common Name". The keys must match the entries of
# class_names, because inference looks predictions up here. The binomial half
# is derived from the folder name itself, so the two can never drift apart.
_COMMON_NAMES = {
    '001_Danaus Plexippus': 'Monarch',
    '002_Heliconius Charitonius': 'Zebra Longwing',
    '003_Heliconius Erato': 'Red Postman',
    '004_Junonia Coenia': 'Common Buckeye',
    '005_Lycaena Phlaeas': 'Small Copper',
    '006_Nymphalis Antiopa': 'Mourning Cloak',
    '007_Papilio Cresphontes': 'Giant Swallowtail',
    '008_Pieris Rapae': 'Cabbage White',
    '009_Vanessa Atalanta': 'Red Admiral',
    '010_Vanessa Cardui': 'Painted Lady',
}


def _display_name(folder_name, common_name):
    """Return 'Genus species - Common' for a folder named like '001_Genus Species'."""
    genus, species = folder_name.split('_', 1)[1].split(' ', 1)
    return genus + ' ' + species.lower() + ' - ' + common_name


id_to_name = {
    folder: _display_name(folder, common)
    for folder, common in _COMMON_NAMES.items()
}
# Top-1 probability below which the image is reported as "No butterfly".
_NO_BUTTERFLY_THRESHOLD = 0.5


def do_inference(img):
    """Classify a butterfly photo with the ResNet18 model.

    Args:
        img: a PIL image (the Gradio input component uses type="pil").
            Non-RGB modes such as RGBA or greyscale are converted to RGB
            first, because the normalisation step expects three channels.

    Returns:
        A dict mapping display name -> probability rounded to 2 decimals,
        covering every class in descending order of probability. Returns
        {"No butterfly": 1.0} instead when the top class scores below
        _NO_BUTTERFLY_THRESHOLD.

    Raises:
        ValueError: if no image was provided (img is None).
    """
    if img is None:
        raise ValueError("No image provided: please upload a butterfly photo.")
    batch_t = torch.unsqueeze(data_transforms_test(img.convert("RGB")), 0)
    model_ft.eval()
    # Gradients are not needed at inference time; no_grad saves memory.
    with torch.no_grad():
        output = model_ft(batch_t.to(device))
    # Class probabilities for the single image, sorted high -> low.
    probs = torch.nn.functional.softmax(output, dim=1)[0].cpu()
    sorted_probs, order = torch.sort(probs, descending=True)
    top_probs = sorted_probs.tolist()
    if top_probs[0] < _NO_BUTTERFLY_THRESHOLD:
        return {"No butterfly": 1.0}
    return {
        id_to_name[class_names[idx]]: round(float(p), 2)
        for p, idx in zip(top_probs, order.tolist())
    }
# Image upload component. Its value is handed to do_inference as a PIL image
# (type="pil") in RGB mode. shape=(512, 512) presumably resizes/crops the
# upload in Gradio before inference; confirm against the pinned Gradio docs.
# NOTE(review): `gr.inputs` is Gradio's legacy component namespace and is
# not available in newer Gradio releases. Confirm the version this app pins.
im = gr.inputs.Image(
    type="pil",
    image_mode='RGB',
    shape=(512, 512),
    source="upload",
    invert_colors=False,
)
# Text shown around the demo.
title = "Butterfly Classification Demo"
description = "A pretrained ResNet18 CNN trained on the Leeds Butterfly Dataset. Libraries: PyTorch, Gradio."
# Example inputs, one per row. The paths match the files downloaded at
# start-up and use the same "./" prefix throughout.
examples = [['./h_erato.jpg'], ['./d_plexippus.jpg'], ['./v_cardui.jpg'], ['./p_cresphontes.jpg'], ['./p_rapae.jpg']]
# Article rendered below the interface. The body is written as raw HTML.
# Species are listed in class-id order (001-010), matching id_to_name.
article_text = markdown.markdown('''
<h1 style="color:white">PyTorch image classification - A pretrained ResNet18 CNN trained on the <a href="http://www.josiahwang.com/dataset/leedsbutterfly/" target="_blank">Leeds Butterfly Dataset</a></h1>
<br>
<p>The Leeds Butterfly Dataset consists of 832 images in 10 classes:</p>
<ul>
<li>Danaus plexippus - Monarch</li>
<li>Heliconius charitonius - Zebra Longwing</li>
<li>Heliconius erato - Red Postman</li>
<li>Junonia coenia - Common Buckeye</li>
<li>Lycaena phlaeas - Small Copper</li>
<li>Nymphalis antiopa - Mourning Cloak</li>
<li>Papilio cresphontes - Giant Swallowtail</li>
<li>Pieris rapae - Cabbage White</li>
<li>Vanessa atalanta - Red Admiral</li>
<li>Vanessa cardui - Painted Lady</li>
</ul>
<br>
<p>Part of a dissertation project. Author: <a href="https://github.com/ttheland" target="_blank">ttheland</a></p>
''')
# Assemble the demo: one uploaded image in, a label panel showing the three
# most probable classes out.
label_output = gr.outputs.Label(num_top_classes=3)
iface = gr.Interface(
    do_inference,
    im,
    label_output,
    title=title,
    description=description,
    article=article_text,
    examples=examples,
    live=False,
    interpretation=None,
    theme="dark-peach",
)
# NOTE(review): test_launch(), the "dark-peach" theme, `interpretation`,
# `enable_queue` and `gr.outputs` belong to older Gradio APIs. They are
# presumably valid for the Gradio version this Space pins; confirm before
# upgrading Gradio. share=True has no effect when already hosted on Spaces.
iface.test_launch()
iface.launch(share=True, enable_queue=True)