Manthanx committed
Commit ba20db4 · 1 Parent(s): aa19069

Upload 4 files

Files changed (4)
  1. app.py +69 -0
  2. cityscapes_dataUNet.pth +3 -0
  3. cityscapes_dataUnet.pth +3 -0
  4. unet.py +57 -0
app.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from unet import UNet
+ from torchvision import transforms
+ from torchvision.transforms.functional import to_tensor, to_pil_image
+ import matplotlib.pyplot as plt  # only needed for the optional debug visualization below
+
+ # Use the GPU when one is available, otherwise fall back to the CPU
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load the trained model (map_location keeps loading working on CPU-only machines)
+ model_path = 'cityscapes_dataUNet.pth'
+ num_classes = 10
+ model = UNet(num_classes=num_classes)
+ model.load_state_dict(torch.load(model_path, map_location=device))
+ model.to(device)
+ model.eval()
+
+
+ # Prediction function: takes an input RGB image (H x W x 3 uint8 array) and returns the segmentation mask
+ def predict_segmentation(image):
+     # Convert the input image to a normalized PyTorch tensor of shape [1, 3, H, W]
+     image = Image.fromarray(image, 'RGB')
+     # image = transforms.functional.resize(image, (256, 256))
+     image = to_tensor(image).unsqueeze(0)
+     image = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(image)
+     image = image.to(device)
+
+     print("input shape", image.shape)  # e.g. torch.Size([1, 3, 256, 256])
+     print("input dtype", image.dtype)  # torch.float32
+
+     # Make a prediction using the model
+     with torch.no_grad():
+         output = model(image)  # [1, num_classes, H, W]
+
+     # Take the most likely class per pixel and move the mask back to the CPU
+     predicted_class = torch.argmax(output, dim=1).squeeze(0)
+     predicted_class = predicted_class.cpu().detach().numpy().astype(np.uint8)
+     print(predicted_class.dtype, predicted_class.shape)  # e.g. uint8 (256, 256)
+
+     # Optional debug visualization (disabled: plt.show() would block the Gradio worker)
+     # plt.imshow(predicted_class)
+     # plt.show()
+
+     # Return the predicted segmentation as a greyscale PIL image (one grey level per class id)
+     return to_pil_image(predicted_class)
+
+
+ # Define the Gradio interface (gr.inputs/gr.outputs have been removed from recent Gradio releases,
+ # so the plain gr.Image components are used here)
+ input_image = gr.Image(type='numpy')
+ output_image = gr.Image()
+
+ gr.Interface(fn=predict_segmentation, inputs=input_image, outputs=output_image,
+              title='UNet Image Segmentation',
+              description='Segment an image using a UNet model').launch()
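For a quick sanity check outside the Gradio UI, the same function can be called directly on an image loaded from disk, e.g. by adding the lines below to app.py just before the launch() call. This is only a sketch: the test image filename is hypothetical, and it is resized to the 256x256 resolution the shape comments above assume.

import numpy as np
from PIL import Image

# Load an RGB test image and resize it to the resolution the model was trained on
img = np.array(Image.open('test_street_scene.png').convert('RGB').resize((256, 256)))
mask = predict_segmentation(img)   # PIL image with one grey level per class id (0..num_classes-1)
mask.save('predicted_mask.png')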
cityscapes_dataUNet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f41259bc794efd66defb8c0029ea054dcf4bb0b98dcc3229e36267782132315
+ size 138216745
cityscapes_dataUnet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a9d4dec513aa5ef4a2873cfc1eba34fb5e7a15b0dad50041043ff80d23b2b30
+ size 138216745
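Both .pth entries above are Git LFS pointer files rather than the weights themselves: after cloning, the ~138 MB checkpoints they reference still need to be fetched with Git LFS (for example via git lfs pull), since app.py loads cityscapes_dataUNet.pth from the working directory.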
unet.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+
+
+ class UNet(nn.Module):
+
+     def __init__(self, num_classes):
+         super(UNet, self).__init__()
+         self.num_classes = num_classes
+         # Contracting (encoder) path: conv blocks double the channels, max-pooling halves the resolution
+         self.contracting_11 = self.conv_block(in_channels=3, out_channels=64)
+         self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.contracting_21 = self.conv_block(in_channels=64, out_channels=128)
+         self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.contracting_31 = self.conv_block(in_channels=128, out_channels=256)
+         self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.contracting_41 = self.conv_block(in_channels=256, out_channels=512)
+         self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.middle = self.conv_block(in_channels=512, out_channels=1024)
+         # Expansive (decoder) path: transposed convolutions upsample, conv blocks fuse the skip connections
+         self.expansive_11 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512)
+         self.expansive_21 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.expansive_22 = self.conv_block(in_channels=512, out_channels=256)
+         self.expansive_31 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.expansive_32 = self.conv_block(in_channels=256, out_channels=128)
+         self.expansive_41 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
+         self.expansive_42 = self.conv_block(in_channels=128, out_channels=64)
+         # 3x3 convolution producing per-pixel class logits
+         self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, stride=1, padding=1)
+
+     def conv_block(self, in_channels, out_channels):
+         # Two 3x3 convolutions, each followed by ReLU and batch normalization
+         block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
+                               nn.ReLU(),
+                               nn.BatchNorm2d(num_features=out_channels),
+                               nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
+                               nn.ReLU(),
+                               nn.BatchNorm2d(num_features=out_channels))
+         return block
+
+     def forward(self, X):
+         contracting_11_out = self.contracting_11(X)                   # [-1, 64, 256, 256]
+         contracting_12_out = self.contracting_12(contracting_11_out)  # [-1, 64, 128, 128]
+         contracting_21_out = self.contracting_21(contracting_12_out)  # [-1, 128, 128, 128]
+         contracting_22_out = self.contracting_22(contracting_21_out)  # [-1, 128, 64, 64]
+         contracting_31_out = self.contracting_31(contracting_22_out)  # [-1, 256, 64, 64]
+         contracting_32_out = self.contracting_32(contracting_31_out)  # [-1, 256, 32, 32]
+         contracting_41_out = self.contracting_41(contracting_32_out)  # [-1, 512, 32, 32]
+         contracting_42_out = self.contracting_42(contracting_41_out)  # [-1, 512, 16, 16]
+         middle_out = self.middle(contracting_42_out)                  # [-1, 1024, 16, 16]
+         expansive_11_out = self.expansive_11(middle_out)              # [-1, 512, 32, 32]
+         expansive_12_out = self.expansive_12(torch.cat((expansive_11_out, contracting_41_out), dim=1))  # [-1, 1024, 32, 32] -> [-1, 512, 32, 32]
+         expansive_21_out = self.expansive_21(expansive_12_out)        # [-1, 256, 64, 64]
+         expansive_22_out = self.expansive_22(torch.cat((expansive_21_out, contracting_31_out), dim=1))  # [-1, 512, 64, 64] -> [-1, 256, 64, 64]
+         expansive_31_out = self.expansive_31(expansive_22_out)        # [-1, 128, 128, 128]
+         expansive_32_out = self.expansive_32(torch.cat((expansive_31_out, contracting_21_out), dim=1))  # [-1, 256, 128, 128] -> [-1, 128, 128, 128]
+         expansive_41_out = self.expansive_41(expansive_32_out)        # [-1, 64, 256, 256]
+         expansive_42_out = self.expansive_42(torch.cat((expansive_41_out, contracting_11_out), dim=1))  # [-1, 128, 256, 256] -> [-1, 64, 256, 256]
+         output_out = self.output(expansive_42_out)                    # [-1, num_classes, 256, 256]
+         return output_out
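A minimal shape check of the encoder/decoder wiring above, as a sketch (the 256x256 input matches the inline shape comments; batch size and class count are arbitrary):

import torch
from unet import UNet

model = UNet(num_classes=10).eval()
x = torch.randn(1, 3, 256, 256)   # dummy batch [N, C, H, W]
with torch.no_grad():
    y = model(x)
print(y.shape)                    # expected: torch.Size([1, 10, 256, 256])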