lopesdri commited on
Commit
acc2d02
·
1 Parent(s): e2cbcc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -26
app.py CHANGED
@@ -1,29 +1,30 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../drive/MyDrive/Colab Notebooks/cats_vs_dogs.ipynb.
 
2
 
3
- # %% auto 0
4
- __all__ = ['learn', 'categories', 'image', 'label', 'examples', 'intf', 'classify_image']
 
 
 
5
 
6
- # %% ../drive/MyDrive/Colab Notebooks/cats_vs_dogs.ipynb 10
7
- from fastai.vision.all import *
8
  import gradio as gr
9
-
10
- def is_cat(x): return x[0].isupper()
11
-
12
-
13
- # %% ../drive/MyDrive/Colab Notebooks/cats_vs_dogs.ipynb 12
14
- learn = load_learner('pets.pkl')
15
-
16
- # %% ../drive/MyDrive/Colab Notebooks/cats_vs_dogs.ipynb 14
17
- categories = ('Dog', 'Cat')
18
-
19
- def classify_image(img):
20
- pred,idx,probs = learn.predict(img)
21
- return dict(zip(categories, map(float,probs)))
22
-
23
- # %% ../drive/MyDrive/Colab Notebooks/cats_vs_dogs.ipynb 15
24
- image = gr.inputs.Image(shape=(192, 192))
25
- label = gr.outputs.Label()
26
- examples = ['dog.jpg']
27
-
28
- intf = gr.Interface(fn=classify_image, inputs=image, outputs=label, examples=examples)
29
- intf.launch(inline=False)
 
1
+ import torch
2
+ import torchvision
3
 
4
+ model = torchvision.models.resnet50(pretrained=False)
5
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
6
+ model.load_state_dict(torch.load("model.pth"))
7
+ model.to(device)
8
+ model.eval()
9
 
 
 
10
  import gradio as gr
11
+ from PIL import Image
12
+
13
+ # Define the function to make predictions
14
+ def predict(image):
15
+ image = transform(image).unsqueeze(0).to(device)
16
+ model.eval()
17
+ with torch.no_grad():
18
+ output = model(image)
19
+ _, predicted = torch.max(output.data, 1)
20
+ return dataset.classes[predicted.item()]
21
+
22
+ # Define the input and output components
23
+ image_input = gr.inputs.Image(type="pil", label="Upload Image")
24
+ label_output = gr.outputs.Label()
25
+
26
+ # Create the interface
27
+ interface = gr.Interface(fn=predict, inputs=image_input, outputs=label_output)
28
+
29
+ # Launch the interface
30
+ interface.launch()