# Spaces:
# Sleeping
# Sleeping
# -*- coding: utf-8 -*-
"""Cifar10-FourierVision-GradioApp.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1uw8cWaCxnSHf2CYhgeF_HYYdMeGP3odV
"""
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib.figure import Figure
from PIL import Image
# CIFAR-10 class names. Position i is the label reported for output logit i
# of the classifier head.
classes = "airplane automobile bird cat deer dog frog horse ship truck".split()
# Load ResNet18 and adapt its final layer for CIFAR-10.
# The checkpoint location defaults to the original Colab path. It can be
# overridden with the RESNET18_CIFAR10_WEIGHTS environment variable, so the
# same script also runs outside Colab.
MODEL_WEIGHTS_PATH = os.environ.get(
    "RESNET18_CIFAR10_WEIGHTS",
    "/content/sample_data/resnet18_fft_cifar10.pth",
)
# weights=None -> randomly initialised network (the non-deprecated spelling of
# pretrained=False). The trained parameters come from the checkpoint below.
resnet18 = models.resnet18(weights=None)
# Replace the default 1000-class head with one output per CIFAR-10 class.
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, len(classes))
# weights_only=True restricts unpickling to tensors and primitive containers,
# which is all a state_dict needs. A tampered checkpoint therefore cannot run
# arbitrary code on load.
state_dict = torch.load(
    MODEL_WEIGHTS_PATH, map_location=torch.device("cpu"), weights_only=True
)
resnet18.load_state_dict(state_dict)
resnet18.eval()  # inference mode: BatchNorm uses its running statistics
# Preprocessing applied before classification:
# resize to 224x224, convert to a float tensor in [0, 1], then shift and
# scale every channel with mean 0.5 / std 0.5 so values land in [-1, 1].
_preprocess_steps = [
    transforms.Resize((224, 224)),  # ResNet18's standard input resolution
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocess_steps)
# FFT Visualizer
def apply_fft_visualization(image: "Image.Image"):
    """Plot the centred log-magnitude 2-D FFT spectrum of each RGB channel.

    The image is first converted to RGB, so grayscale, palette and RGBA inputs
    all yield exactly three channels. It is then downsampled to 32x32 and
    scaled to [0, 1] before transforming.

    Args:
        image: PIL image to analyse.

    Returns:
        A matplotlib ``Figure`` with three panels titled Red, Green and Blue.
    """
    pixels = np.asarray(
        image.convert("RGB").resize((32, 32)), dtype=np.float64
    ) / 255.0
    # 2-D FFT over the spatial axes (0, 1) for all three channels at once,
    # with the zero-frequency term shifted to the centre of each spectrum.
    spectrum = np.fft.fftshift(np.fft.fft2(pixels, axes=(0, 1)), axes=(0, 1))
    # log1p compresses the large dynamic range so that the DC term does not
    # wash out every other frequency.
    magnitude = np.log1p(np.abs(spectrum))

    # Build the Figure directly rather than through pyplot. pyplot registers
    # every figure in global state, and nothing here closes them, so memory
    # grows with every request. pyplot is also not thread-safe.
    fig = Figure(figsize=(12, 4))
    axes = fig.subplots(1, 3)
    for channel, (ax, name) in enumerate(zip(axes, ("Red", "Green", "Blue"))):
        ax.imshow(magnitude[:, :, channel], cmap="inferno")
        ax.set_title(name)
        ax.axis("off")
    fig.tight_layout()
    return fig
# Prediction Function
def predict(img: "Image.Image", mode="Raw"):
    """Classify an uploaded image, or visualise its Fourier spectrum.

    Args:
        img: PIL image from the upload widget, or ``None`` when the form is
            submitted without an image.
        mode: ``"FFT"`` returns the per-channel spectrum plot. Any other value
            (the UI offers ``"Raw"``) runs ResNet18 classification.

    Returns:
        A ``(label, figure)`` pair feeding the Label and Plot outputs. The slot
        that does not apply to the chosen mode is ``None``, and both are
        ``None`` when no image was provided.
    """
    if img is None:
        # Nothing uploaded: clear both outputs instead of crashing on None.
        return None, None
    if mode == "FFT":
        return None, apply_fft_visualization(img)
    # Force exactly three channels. Grayscale or RGBA images would otherwise
    # break the 3-channel Normalize step and the model's first convolution.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = resnet18(batch)
    best = int(logits.argmax(dim=1).item())
    return classes[best], None
# Gradio App: an image upload plus a mode selector go in; the predicted label
# and the FFT spectrum plot come out. Launched as soon as the module runs.
_DESCRIPTION = (
    "Upload an image and choose mode: Raw classification (ResNet18) or "
    "visualize FFT of RGB channels.\n\n"
    "Disclaimer: This model is trained on CIFAR-10 and works best on "
    "low-res, centered images."
)

_input_widgets = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Radio(["Raw", "FFT"], label="Mode", value="Raw"),
]
_output_widgets = [
    gr.Label(label="Prediction"),
    gr.Plot(label="FFT Visualization"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="CIFAR-10 Visual Analyzer (ResNet18)",
    description=_DESCRIPTION,
)
demo.launch()