File size: 1,398 Bytes
2514593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aecbf3a
2514593
 
 
3a51bc4
 
 
 
 
 
 
 
aecbf3a
2514593
 
aecbf3a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from torch import nn
import torchvision
import gradio as gr

# Build a ResNet-50 and swap its classification head for a 5-class one.
model = torchvision.models.resnet50()
num_ftrs = model.fc.in_features
# New head: 50% dropout followed by a linear projection of the backbone
# features onto the 5 target classes.
model.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(num_ftrs, 5))

# Restore the trained parameters onto the CPU. This must happen after the
# head has been replaced so the checkpoint keys match the module layout.
model.load_state_dict(
    torch.load(
        "model/final_model_state_dict.pth",
        map_location=torch.device("cpu"),
    )
)
# Inference mode: disables the dropout layer during prediction.
model.eval()

# Class names; predict() indexes this list with the arg-max of the model output.
labels = ["bird", "cat", "dog", "horse", "sheep"]


# Define the predict function
def predict(inp):
    """Classify an image and return the name of the most likely class.

    Args:
        inp: A PIL image (what the ``gr.Image(type="pil")`` input delivers)
            or any other input accepted by
            ``torchvision.transforms.ToTensor``. ``None`` is tolerated, since
            the UI may invoke the function without an image.

    Returns:
        The entry of ``labels`` with the highest model score, or ``None``
        when ``inp`` is ``None``. Only the top-1 label is returned; no
        per-class confidences are produced.
    """
    # Nothing to classify: bail out instead of crashing inside ToTensor.
    if inp is None:
        return None

    # The ResNet stem expects 3 input channels. PIL images expose ``mode``;
    # inputs without that attribute (e.g. arrays) are passed through as-is.
    if getattr(inp, "mode", "RGB") != "RGB":
        inp = inp.convert("RGB")

    # Convert to a tensor and add the leading batch dimension.
    batch = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)

    # The batch holds a single image; turn the arg-max of its class scores
    # into a plain int before indexing the label list.
    best_idx = int(logits[0].argmax().item())
    return labels[best_idx]


# Assemble the Gradio demo: one image input, one label output.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=5)

# One clickable example image per class; each inner list is one example row.
example_images = [
    ["input_imgs/bird.jpeg"],
    ["input_imgs/cat.jpeg"],
    ["input_imgs/dog.jpeg"],
    ["input_imgs/horse.jpeg"],
    ["input_imgs/sheep.jpeg"],
]

interface = gr.Interface(
    title="Image Object Classifier",
    description="This is a demo of a resnet50 model trained on COCO dataset, which can classify 5 classes: bird, cat, dog, horse, sheep.",
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    examples=example_images,
)

# Start the web UI only when the file is run as a script, not on import.
if __name__ == "__main__":
    interface.launch()