Spaces: Runtime error
Upload 4 files
- app.py +69 -0
- cityscapes_dataUNet.pth +3 -0
- cityscapes_dataUnet.pth +3 -0
- unet.py +57 -0
app.py
ADDED
@@ -0,0 +1,69 @@
import torch

import numpy as np
import gradio as gr
from PIL import Image
from unet import UNet
from torchvision import transforms
from torchvision.transforms.functional import to_tensor, to_pil_image
import matplotlib.pyplot as plt

device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device)

# Load the trained model (map_location lets a GPU-trained checkpoint load on a CPU-only Space)
model_path = 'cityscapes_dataUNet.pth'
num_classes = 10
model = UNet(num_classes=num_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()


# Define the prediction function that takes an input image and returns the segmented image
def predict_segmentation(image):
    print(device)
    # Convert the input image to a PyTorch tensor and normalize it
    image = Image.fromarray(image, 'RGB')
    # image = transforms.functional.resize(image, (256, 256))
    image = to_tensor(image).unsqueeze(0)
    image = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(image)
    image = image.to(device)

    print("input shape", image.shape)  # e.g. torch.Size([1, 3, 256, 256])
    print("input dtype", image.dtype)  # torch.float32

    # Make a prediction using the model
    with torch.no_grad():
        output = model(image)
        # print(output.shape, output.dtype)  # torch.Size([1, 10, 256, 256]) torch.float32

    # Per-pixel class index
    predicted_class = torch.argmax(output, dim=1).squeeze(0)
    predicted_class = predicted_class.cpu().detach().numpy().astype(np.uint8)
    print(predicted_class.dtype, predicted_class.shape)  # uint8 (256, 256)

    # Visualize the predicted segmentation mask (debug only; has no visible effect on a headless Space)
    plt.imshow(predicted_class)
    plt.show()
    # Apply the inverse transform to convert the normalized image back to RGB
    # predicted_class = inverse_transform(torch.from_numpy(predicted_class))

    print("predicted class ", predicted_class)

    # Raw class ids 0..9, so the returned grayscale mask looks almost black unless rescaled or colorized
    predicted_class = to_pil_image(predicted_class)

    # Return the predicted segmentation
    return predicted_class


# Define the Gradio interface (the legacy gr.inputs / gr.outputs namespaces were removed in Gradio 4.x)
input_image = gr.Image(type='numpy')
output_image = gr.Image(type='pil')

gr.Interface(fn=predict_segmentation, inputs=input_image, outputs=output_image,
             title='UNet Image Segmentation',
             description='Segment an image using a UNet model').launch()
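Not part of the commit, but a useful sanity check while the Space reports a runtime error: the sketch below mirrors app.py's preprocessing and forward pass on a random image, outside Gradio. It assumes cityscapes_dataUNet.pth and unet.py sit in the working directory.

import torch
import numpy as np
from torchvision.transforms.functional import to_tensor, normalize
from unet import UNet

# Load the checkpoint the same way app.py does
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(num_classes=10)
model.load_state_dict(torch.load("cityscapes_dataUNet.pth", map_location=device))
model.to(device).eval()

# A random 256x256 RGB frame stands in for the Gradio upload
dummy = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
x = to_tensor(dummy).unsqueeze(0)                                    # [1, 3, 256, 256]
x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to(device)

with torch.no_grad():
    out = model(x)                                                   # [1, 10, 256, 256]
mask = out.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)   # per-pixel class id
print(out.shape, mask.shape, mask.min(), mask.max())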
cityscapes_dataUNet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9f41259bc794efd66defb8c0029ea054dcf4bb0b98dcc3229e36267782132315
size 138216745
cityscapes_dataUnet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a9d4dec513aa5ef4a2873cfc1eba34fb5e7a15b0dad50041043ff80d23b2b30
size 138216745
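Both .pth entries are Git LFS pointer files: the three lines record only the spec version, the SHA-256 of the payload, and its size (about 138 MB), while the weights themselves live in LFS storage. A plain git clone therefore needs git lfs pull before app.py can load the checkpoint, or the file can be fetched through huggingface_hub as sketched below; the repo id is a placeholder, since the Space name is not shown in this commit.

from huggingface_hub import hf_hub_download

# "user/space-name" is a stand-in for the actual Space id
ckpt_path = hf_hub_download(repo_id="user/space-name",
                            filename="cityscapes_dataUNet.pth",
                            repo_type="space")
print(ckpt_path)  # local cache path of the downloaded checkpoint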
unet.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms


class UNet(nn.Module):

    def __init__(self, num_classes):
        super(UNet, self).__init__()
        self.num_classes = num_classes
        self.contracting_11 = self.conv_block(in_channels=3, out_channels=64)
        self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_21 = self.conv_block(in_channels=64, out_channels=128)
        self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_31 = self.conv_block(in_channels=128, out_channels=256)
        self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_41 = self.conv_block(in_channels=256, out_channels=512)
        self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.middle = self.conv_block(in_channels=512, out_channels=1024)
        self.expansive_11 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512)
        self.expansive_21 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_22 = self.conv_block(in_channels=512, out_channels=256)
        self.expansive_31 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_32 = self.conv_block(in_channels=256, out_channels=128)
        self.expansive_41 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_42 = self.conv_block(in_channels=128, out_channels=64)
        self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, stride=1, padding=1)

    def conv_block(self, in_channels, out_channels):
        # Two 3x3 convolutions, each followed by ReLU and BatchNorm
        block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                              nn.ReLU(),
                              nn.BatchNorm2d(num_features=out_channels),
                              nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                              nn.ReLU(),
                              nn.BatchNorm2d(num_features=out_channels))
        return block

    def forward(self, X):
        # Shape comments assume a 256x256 input
        contracting_11_out = self.contracting_11(X)  # [-1, 64, 256, 256]
        contracting_12_out = self.contracting_12(contracting_11_out)  # [-1, 64, 128, 128]
        contracting_21_out = self.contracting_21(contracting_12_out)  # [-1, 128, 128, 128]
        contracting_22_out = self.contracting_22(contracting_21_out)  # [-1, 128, 64, 64]
        contracting_31_out = self.contracting_31(contracting_22_out)  # [-1, 256, 64, 64]
        contracting_32_out = self.contracting_32(contracting_31_out)  # [-1, 256, 32, 32]
        contracting_41_out = self.contracting_41(contracting_32_out)  # [-1, 512, 32, 32]
        contracting_42_out = self.contracting_42(contracting_41_out)  # [-1, 512, 16, 16]
        middle_out = self.middle(contracting_42_out)  # [-1, 1024, 16, 16]
        expansive_11_out = self.expansive_11(middle_out)  # [-1, 512, 32, 32]
        expansive_12_out = self.expansive_12(torch.cat((expansive_11_out, contracting_41_out), dim=1))  # [-1, 1024, 32, 32] -> [-1, 512, 32, 32]
        expansive_21_out = self.expansive_21(expansive_12_out)  # [-1, 256, 64, 64]
        expansive_22_out = self.expansive_22(torch.cat((expansive_21_out, contracting_31_out), dim=1))  # [-1, 512, 64, 64] -> [-1, 256, 64, 64]
        expansive_31_out = self.expansive_31(expansive_22_out)  # [-1, 128, 128, 128]
        expansive_32_out = self.expansive_32(torch.cat((expansive_31_out, contracting_21_out), dim=1))  # [-1, 256, 128, 128] -> [-1, 128, 128, 128]
        expansive_41_out = self.expansive_41(expansive_32_out)  # [-1, 64, 256, 256]
        expansive_42_out = self.expansive_42(torch.cat((expansive_41_out, contracting_11_out), dim=1))  # [-1, 128, 256, 256] -> [-1, 64, 256, 256]
        output_out = self.output(expansive_42_out)  # [-1, num_classes, 256, 256]
        return output_out
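The shape comments in forward assume 256x256 inputs (more generally, sides divisible by 16, so the four pooling stages and the skip concatenations line up). A minimal shape check with a randomly initialised model, added here for illustration:

import torch
from unet import UNet

net = UNet(num_classes=10).eval()   # random weights are enough for a shape check
x = torch.randn(1, 3, 256, 256)     # dummy batch
with torch.no_grad():
    y = net(x)
print(y.shape)                      # expected: torch.Size([1, 10, 256, 256])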