import torch
import numpy as np
import gradio as gr
from PIL import Image
from unet import UNet
from torchvision import transforms
from torchvision.transforms.functional import to_tensor, to_pil_image
import matplotlib.pyplot as plt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the trained model
model_path = 'cityscapes_dataUNet.pth'
num_classes = 10
model = UNet(num_classes=num_classes)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.to(device)
model.eval()
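# Optional startup sanity check: a minimal sketch, assuming the UNet from unet.py
# maps a 3-channel image to num_classes logit maps at the same spatial size
# (the 256x256 dummy size below is only an illustrative assumption, not required
# by the model).
with torch.no_grad():
    dummy = torch.zeros(1, 3, 256, 256, device=device)
    print("sanity-check output shape:", model(dummy).shape)  # expected: (1, num_classes, 256, 256)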
# Define the prediction function that takes an input image and returns the segmented image
def predict_segmentation(image):
    print(device)
    # Convert the input image to a PyTorch tensor and normalize it with ImageNet statistics
    image = Image.fromarray(image, 'RGB')
    # image = transforms.functional.resize(image, (256, 256))
    image = to_tensor(image).unsqueeze(0)
    image = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(image)
    image = image.to(device)
    print("input shape", image.shape)  # e.g. torch.Size([1, 3, 256, 256])
    print("input dtype", image.dtype)  # torch.float32

    # Make a prediction using the model
    with torch.no_grad():
        output = model(image)
    # print(output.shape, output.dtype)  # e.g. torch.Size([1, 10, 256, 256]) torch.float32

    # Take the most likely class per pixel
    predicted_class = torch.argmax(output, dim=1).squeeze(0)
    predicted_class = predicted_class.cpu().numpy().astype(np.uint8)
    print(predicted_class.dtype, predicted_class.shape)  # uint8, e.g. (256, 256)

    # Visualize the predicted segmentation mask (debugging aid; has no visible
    # effect on a headless server such as a Space)
    plt.imshow(predicted_class)
    plt.show()

    # predicted_class = inverse_transform(torch.from_numpy(predicted_class))
    print("predicted class", predicted_class)

    # Convert the class-index mask to a grayscale PIL image and return it
    predicted_class = to_pil_image(predicted_class)
    return predicted_class
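# Note: to_pil_image on a uint8 class-index mask yields a very dark grayscale image
# (pixel values 0-9). Below is a sketch of an optional, hypothetical helper that maps
# class indices to distinct colors for display; the palette is an arbitrary assumption,
# not the Cityscapes color convention. It could be used by returning
# colorize_mask(predicted_class) instead of to_pil_image(predicted_class).
def colorize_mask(mask, n_classes=num_classes):
    # Deterministic pseudo-random palette with one RGB color per class index
    rng = np.random.default_rng(0)
    palette = rng.integers(0, 256, size=(n_classes, 3), dtype=np.uint8)
    # Fancy indexing maps each (H, W) class index to its (H, W, 3) color
    return Image.fromarray(palette[mask], 'RGB')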
# Define the Gradio interface (the gr.inputs / gr.outputs namespaces were removed in
# newer Gradio releases; use the top-level components instead). The input is passed
# to the function as a numpy array, and the function returns a PIL image.
input_image = gr.Image(type='numpy')
output_image = gr.Image(type='pil')
gr.Interface(fn=predict_segmentation, inputs=input_image, outputs=output_image,
             title='UNet Image Segmentation',
             description='Segment an image using a UNet model').launch()