lopesdri commited on
Commit
ad03cdd
·
1 Parent(s): 3906d80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -16
app.py CHANGED
@@ -1,30 +1,22 @@
1
- import gradio as gr
2
- import torch
3
- from torchvision.transforms import ToTensor
4
- from PIL import Image
5
-
6
  # Load your PyTorch model
7
- model = torch.load("model.pth",map_location=torch.device('cpu'))
 
 
8
 
9
  classes = ['bom', 'ruim']
10
 
11
- def preprocess(image):
12
- image = image.resize((224, 224))
13
- image_tensor = ToTensor()(image)
14
- image_tensor = image_tensor.unsqueeze(0)
15
- return image_tensor
16
-
17
  # Define the function for image classification
18
  def classify_image(image):
19
- image_tensor = preprocess(image)
20
 
21
  # Perform inference using your PyTorch model
22
  with torch.no_grad():
23
  model.eval()
24
  outputs = model(image_tensor)
25
 
26
- predicted_labels = outputs.argmax(dim=1).tolist()
27
- return predicted_labels
 
28
 
29
  # Define the Gradio interface
30
  inputs = gr.Image()
@@ -32,5 +24,4 @@ outputs = gr.Label(num_top_classes=1)
32
 
33
  interface = gr.Interface(fn=classify_image, inputs=inputs, outputs=outputs)
34
 
35
- # Launch the interface
36
  interface.launch(debug=True)
 
 
 
 
 
 
1
  # Load your PyTorch model
2
+ model = resnet50(pretrained=False)
3
+ model.fc = nn.Linear(model.fc.in_features, 2)
4
+ model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')))
5
 
6
  classes = ['bom', 'ruim']
7
 
 
 
 
 
 
 
8
  # Define the function for image classification
9
  def classify_image(image):
10
+ image_tensor = ToTensor()(image).unsqueeze(0)
11
 
12
  # Perform inference using your PyTorch model
13
  with torch.no_grad():
14
  model.eval()
15
  outputs = model(image_tensor)
16
 
17
+ _, predicted = torch.max(outputs.data, 1)
18
+ return classes[predicted.item()]
19
+
20
 
21
  # Define the Gradio interface
22
  inputs = gr.Image()
 
24
 
25
  interface = gr.Interface(fn=classify_image, inputs=inputs, outputs=outputs)
26
 
 
27
  interface.launch(debug=True)