"""Gradio demo: CIFAR-10 classification with ResNet18 plus an FFT visualizer.

The app offers two modes for an uploaded image:

* ``Raw`` - classify the image with a ResNet18 whose final layer has 10 outputs
  and whose weights come from ``resnet18_fft_cifar10.pth``.
* ``FFT`` - plot the log-magnitude 2-D FFT spectrum of each RGB channel.
"""
from typing import Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# CIFAR-10 class names, indexed by the model's output position.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Build ResNet18 with a 10-way final layer. No pretrained ImageNet weights are
# requested: the strict load_state_dict() call below replaces every parameter
# and buffer anyway, so downloading them would only add a network dependency.
resnet18 = models.resnet18()
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, 10)  # Replace final layer
# NOTE: torch.load unpickles the file; only use checkpoints from a trusted source.
resnet18.load_state_dict(
    torch.load("resnet18_fft_cifar10.pth", map_location=torch.device('cpu'))
)
resnet18.eval()

# Image transform applied before classification.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet18 expects 224x224
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def apply_fft_visualization(image: Image.Image) -> plt.Figure:
    """Plot the log-magnitude FFT spectrum of each RGB channel.

    The image is converted to RGB, resized to 32x32 and scaled to [0, 1].
    For each channel the 2-D FFT is computed, shifted so the zero frequency
    sits in the centre, and displayed as ``log1p(|FFT|)``.

    Args:
        image: Input PIL image (any mode; it is converted to RGB).

    Returns:
        A matplotlib figure with three side-by-side panels (Red, Green, Blue).
    """
    # convert("RGB") guarantees three channels, so grayscale/RGBA inputs work.
    pixels = np.asarray(image.convert("RGB").resize((32, 32))) / 255.0

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for idx, (ax, title) in enumerate(zip(axs, ('Red', 'Green', 'Blue'))):
        spectrum = np.fft.fftshift(np.fft.fft2(pixels[:, :, idx]))
        ax.imshow(np.log1p(np.abs(spectrum)), cmap='inferno')
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()

    # Unregister the figure from pyplot so figures do not pile up across
    # requests in a long-running server; the Figure object itself stays usable
    # and can still be rendered by the caller.
    plt.close(fig)
    return fig


def predict(
    img: Image.Image, mode: str = "Raw"
) -> Tuple[Optional[str], Optional[plt.Figure]]:
    """Classify the image or visualize its FFT, depending on ``mode``.

    Args:
        img: Uploaded PIL image.
        mode: ``"FFT"`` to return the FFT plot; any other value (default
            ``"Raw"``) runs the classifier.

    Returns:
        A ``(label, figure)`` pair. In FFT mode the label is ``None``; in
        classification mode the figure is ``None``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")

    if mode == "FFT":
        return None, apply_fft_visualization(img)

    # Normalize() is configured for exactly three channels.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = resnet18(batch)
    label = classes[logits.argmax(dim=1).item()]
    return label, None


# Gradio App
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(["Raw", "FFT"], label="Mode", value="Raw"),
    ],
    outputs=[
        gr.Label(label="Prediction"),
        gr.Plot(label="FFT Visualization"),
    ],
    title="CIFAR-10 Visual Analyzer (ResNet18)",
    description=(
        "Upload an image and choose mode: Raw classification (ResNet18) "
        "or visualize FFT of RGB channels.\n\n"
        "Disclaimer: This model is trained on CIFAR-10 and works best on "
        "low-res, centered images."
    ),
)
demo.launch()