# FourierVision / cifar10_fouriervision_gradioapp.py
# Uploaded by GenAIDevTOProd ("Upload 2 files", commit 0129c5b, verified).
# -*- coding: utf-8 -*-
"""Cifar10-FourierVision-GradioApp.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1uw8cWaCxnSHf2CYhgeF_HYYdMeGP3odV
"""
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib.figure import Figure
from PIL import Image
# CIFAR-10 class names.
# predict() turns the index of the highest logit into a name by indexing this
# list, so the order must match the label order the checkpoint was trained on.
classes = (
    "airplane automobile bird cat deer "
    "dog frog horse ship truck"
).split()
# Load ResNet18 with a 10-way CIFAR-10 head and restore the trained weights on CPU.
#
# The checkpoint path defaults to the original Colab location. It can be
# overridden with the FOURIERVISION_WEIGHTS environment variable, e.g. when the
# app runs outside Colab where /content/sample_data does not exist.
WEIGHTS_PATH = os.environ.get(
    "FOURIERVISION_WEIGHTS", "/content/sample_data/resnet18_fft_cifar10.pth"
)
# num_classes=10 yields the same architecture as swapping `fc` for a
# Linear(in_features, 10) after construction. It also avoids the `pretrained=`
# keyword, which torchvision has deprecated since 0.13. No pretrained weights
# are requested; everything comes from the checkpoint below.
resnet18 = models.resnet18(num_classes=10)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute code on load. It is supported since
# torch 1.13 and is the default from torch 2.6.
# NOTE(review): the checkpoint name mentions "fft". Confirm it was trained on
# plain RGB images preprocessed like `transform`, because predict() feeds the
# model raw pixels, not spectra.
resnet18.load_state_dict(
    torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"), weights_only=True)
)
resnet18.eval()  # switch BatchNorm/Dropout layers to inference behaviour
# Image transform applied to uploads before classification:
#   1. resize to a fixed 224x224,
#   2. convert the PIL image to a float tensor,
#   3. rescale every channel with (x - 0.5) / 0.5.
# NOTE(review): the earlier comment said ResNet18 "expects" 224x224. torchvision's
# ResNet pools adaptively, so what really matters is matching the resolution
# and normalization used when the checkpoint was trained. Confirm against the
# training code.
_CHANNEL_STATS = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_STATS, _CHANNEL_STATS),
    ]
)
# FFT Visualizer
def apply_fft_visualization(image: Image.Image, size: int = 32):
    """Plot the centred log-magnitude 2-D FFT spectrum of each RGB channel.

    The image is converted to RGB and resized to ``size`` x ``size`` pixels
    before transforming. The default of 32 keeps the original behaviour.

    Args:
        image: Input PIL image. Any mode (grayscale, palette, RGBA, ...) is
            accepted because it is converted to RGB first.
        size: Side length in pixels of the square resize applied before the FFT.

    Returns:
        A matplotlib ``Figure`` with three panels titled Red, Green and Blue.
        Each panel shows ``log(1 + |fftshift(fft2(channel))|)`` with the
        ``inferno`` colormap.
    """
    # convert("RGB") guarantees an (H, W, 3) array. Grayscale or palette images
    # would otherwise give a 2-D array (IndexError on [:, :, i]), and RGBA
    # images would carry an extra alpha channel.
    rgb = image.convert("RGB").resize((size, size))
    pixels = np.asarray(rgb, dtype=np.float64) / 255.0

    # Transform the two spatial axes of all channels in one call. fftshift on
    # the same axes moves the zero-frequency term to the centre of each spectrum.
    spectrum = np.fft.fftshift(np.fft.fft2(pixels, axes=(0, 1)), axes=(0, 1))
    # log1p compresses the large dynamic range between low and high frequencies.
    magnitude = np.log1p(np.abs(spectrum))

    # Build the Figure directly instead of through pyplot. pyplot keeps every
    # figure it creates alive in global state (the original never closed them),
    # and that state is not thread-safe when Gradio runs handlers concurrently.
    fig = Figure(figsize=(12, 4))
    axes = fig.subplots(1, 3)
    for idx, (ax, name) in enumerate(zip(axes, ("Red", "Green", "Blue"))):
        ax.imshow(magnitude[:, :, idx], cmap="inferno")
        ax.set_title(name)
        ax.axis("off")
    fig.tight_layout()
    return fig
# Prediction Function
def predict(img: Image.Image, mode="Raw"):
    """Classify an image with the CIFAR-10 ResNet18, or visualise its FFT.

    Args:
        img: Uploaded PIL image. Gradio passes ``None`` when nothing was uploaded.
        mode: ``"FFT"`` plots the per-channel FFT magnitude spectra. Any other
            value, including the default ``"Raw"``, classifies the image.

    Returns:
        A ``(label, figure)`` pair that fills the interface's two outputs in
        order. In ``"FFT"`` mode this is ``(None, figure)``; otherwise it is
        ``(class_name, None)``.

    Raises:
        gr.Error: If no image was provided. Gradio then shows a readable
            message instead of an internal traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # Normalize to 3 channels up front. RGBA, grayscale or palette uploads
    # would otherwise break Normalize's 3-channel statistics and the per-channel
    # FFT indexing.
    rgb = img.convert("RGB")
    if mode == "FFT":
        return None, apply_fft_visualization(rgb)
    img_tensor = transform(rgb).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        logits = resnet18(img_tensor)
    predicted = int(logits.argmax(dim=1).item())
    return classes[predicted], None
# Gradio App.
# Inputs: one image and a mode selector. Outputs: a class label and an FFT plot.
# predict() returns a (label, figure) pair that fills the two outputs in order.
image_input = gr.Image(type="pil", label="Upload Image")
mode_input = gr.Radio(["Raw", "FFT"], label="Mode", value="Raw")
label_output = gr.Label(label="Prediction")
plot_output = gr.Plot(label="FFT Visualization")

app_description = (
    "Upload an image and choose mode: Raw classification (ResNet18) "
    "or visualize FFT of RGB channels.\n\n"
    "Disclaimer: This model is trained on CIFAR-10 and works best on "
    "low-res, centered images."
)

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, mode_input],
    outputs=[label_output, plot_output],
    title="CIFAR-10 Visual Analyzer (ResNet18)",
    description=app_description,
)
# Launch immediately when the module runs, as before.
demo.launch()