import torch, torchvision from torchvision import transforms import numpy as np import gradio as gr from PIL import Image from pytorch_grad_cam import GradCAM from pytorch_grad_cam.utils.image import show_cam_on_image import gradio as gr import os from helper import CifarAlbumentations, get_train_transforms, get_test_transforms from resnet import CustomResNet config = { 'batch_size': 512, 'data_dir': './data', 'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], 'num_classes': 10, 'lr': 0.01, 'max_lr': 0.1, 'max_lr_epoch': 5, 'dropout' : 0.01, 'LEARNING_RATE' : 1e-5, 'WEIGHT_DECAY' : 1e-4, 'NUM_EPOCHS' : 100 } train_transforms = get_train_transforms() test_transforms = get_test_transforms() model = CustomResNet(config, config['dropout'], train_transforms, test_transforms) model.load_state_dict(torch.load("resnet_model_v7.pth", map_location=torch.device('cpu')), strict=False) model.setup(stage="test") inv_normalize = transforms.Normalize( mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23], std=[1/0.23, 1/0.23, 1/0.23] ) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') classes_for_categorize = {0: 'plane', 1: 'car', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'} def inference(input_img, transparency=0.5, target_layer_number=-1, top_classes=10): transform = transforms.ToTensor() org_img = input_img input_img = transform(input_img) input_img = input_img input_img = input_img.unsqueeze(0) outputs = model(input_img) softmax = torch.nn.Softmax(dim=0) o = softmax(outputs.flatten()) confidences = {classes[i]: float(o[i]) for i in range(10)} sorted_classes = sorted(confidences.items(), key=lambda x: x[1], reverse=True) top_classes = sorted_classes[:top_classes] top_classes_dict = {cls: conf for cls, conf in top_classes} _, prediction = torch.max(outputs, 1) target_layers = [model.conv2] cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False) grayscale_cam = cam(input_tensor=input_img, targets=None) grayscale_cam = grayscale_cam[0, :] img = input_img.squeeze(0) img = inv_normalize(img) rgb_img = np.transpose(img, (1, 2, 0)) rgb_img = rgb_img.numpy() visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency) return top_classes_dict, visualization def show_misclassified_images_wrap(num_images=10, use_gradcam=False, gradcam_layer=-1, transparency=0.5): transparency = float(transparency) num_images = int(num_images) if use_gradcam == "Yes": use_gradcam = True else: use_gradcam = False return model.show_misclassified_images(num_images, use_gradcam, gradcam_layer, transparency) title = "CIFAR10 Image Classification" description = "Upload an Image or Choose from Examples Below" images_folder = "examples" # Define the examples list with full paths examples = [[os.path.join(images_folder, "plane.jpg"), 0.5, -1,10], [os.path.join(images_folder, "car.jpg"), 0.5, -1,5], [os.path.join(images_folder, "bird.jpg"), 0.5, -1,3], [os.path.join(images_folder, "cat.jpg"), 0.5, -1, 5], [os.path.join(images_folder, "deer.jpg"), 0.5, -1,7], [os.path.join(images_folder, "dog.jpg"), 0.5, -1,6], [os.path.join(images_folder, "frog.jpg"), 0.5, -1,2], [os.path.join(images_folder, "horse.jpg"), 0.5, -1,10], [os.path.join(images_folder, "ship.jpg"), 0.5, -1,10], [os.path.join(images_folder, "truck.jpeg"), 0.5, -1,10]] # Create the input interface with the modified template input_interface = gr.Interface( inference, inputs=[ gr.Image(shape=(32, 32), label="Input Image"), gr.Slider(0, 1, value=0.5, label="Transparency", info="Set the Opacity of CAM"), gr.Slider(-2, -1, value=-2, step=1, label="Network Layer", info="GradCAM Network Layer"), gr.Slider(1, 10, step=1, value=10, label="Top Classes", info="How many top classes do you want to view") ], outputs=[ gr.Label(num_top_classes=10), gr.Image(shape=(32, 32), label="Model Prediction").style(width=300, height=300) ], description=description, examples=[[f'examples/{k}.jpg'] for k in classes_for_categorize.values()],) mislclassified_description = "Misclassified Image for Custom Resnet" icon_html = '' title_with_icon = f"""