File size: 2,567 Bytes
0129c5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-
"""Cifar10-FourierVision-GradioApp.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1uw8cWaCxnSHf2CYhgeF_HYYdMeGP3odV
"""

import gradio as gr
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# CIFAR-10 label names. predict() looks up the model's winning output index
# in this list, so the order here has to match the order used in training.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# ResNet18 without pretrained weights; its final fully connected layer is
# replaced by a 10-way head before the fine-tuned checkpoint is loaded.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` -- confirm the pinned torchvision version before changing it.
resnet18 = models.resnet18(pretrained=False)
resnet18.fc = torch.nn.Linear(
    in_features=resnet18.fc.in_features,
    out_features=10,
)
# map_location forces every tensor in the checkpoint onto the CPU.
resnet18.load_state_dict(
    torch.load(
        "/content/sample_data/resnet18_fft_cifar10.pth",
        map_location=torch.device('cpu'),
    )
)
resnet18.eval()  # evaluation mode for inference

# Preprocessing applied to every image before classification: resize to
# 224x224, convert to a tensor, then normalise each channel as (x - 0.5) / 0.5.
# NOTE(review): presumably mirrors the preprocessing used in training -- confirm.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# FFT Visualizer
def apply_fft_visualization(image: Image.Image):
    """Plot the log-magnitude 2-D FFT spectrum of each RGB channel.

    The image is downscaled to 32x32 and scaled by 1/255. Each colour
    channel is transformed with a 2-D FFT, shifted so the zero-frequency
    term sits at the centre, and displayed as ``log1p(|F|)``.

    Args:
        image: PIL image. Any mode that Pillow can convert to RGB is
            accepted; the conversion happens after resizing.

    Returns:
        A matplotlib Figure with three side-by-side panels titled
        'Red', 'Green' and 'Blue'. The figure has already been closed in
        pyplot's registry but can still be rendered or saved.
    """
    # Resize first, then force RGB, so grayscale / palette / RGBA inputs all
    # produce a (32, 32, 3) array instead of failing on the channel index.
    rgb_image = image.resize((32, 32)).convert("RGB")
    img_np = np.asarray(rgb_image) / 255.0

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for idx, (ax, title) in enumerate(zip(axs, ('Red', 'Green', 'Blue'))):
        spectrum = np.fft.fftshift(np.fft.fft2(img_np[:, :, idx]))
        # log1p compresses the dynamic range so more than the DC peak is visible.
        ax.imshow(np.log1p(np.abs(spectrum)), cmap='inferno')
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()

    # pyplot keeps a reference to every figure it creates until it is closed;
    # without this a long-running app accumulates one figure per request.
    # The Figure object itself stays valid for the caller.
    plt.close(fig)
    return fig

# Prediction Function
def predict(img: Image.Image, mode="Raw"):
    """Classify an image or visualise its FFT, depending on ``mode``.

    Args:
        img: PIL image supplied by the UI; ``None`` when nothing was uploaded.
        mode: ``"FFT"`` returns the per-channel FFT figure; any other value
            (default ``"Raw"``) runs the ResNet18 classifier.

    Returns:
        A ``(label, figure)`` tuple matching the two Gradio outputs. In FFT
        mode ``label`` is ``None``; in classification mode ``figure`` is
        ``None`` and ``label`` is one of the names in ``classes``.

    Raises:
        gr.Error: if ``img`` is ``None``, so the UI shows a readable message
            instead of an opaque attribute error.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    if mode == "FFT":
        return None, apply_fft_visualization(img)

    # The normalisation and the network both expect three channels, so
    # grayscale / RGBA uploads are converted before the transform.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = resnet18(img_tensor)
        _, predicted = torch.max(outputs, 1)
    return classes[predicted.item()], None

# Gradio App
# Two inputs (image + mode selector) feed predict(); its (label, figure)
# result is shown in the two output components, in the same order.
app_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Radio(["Raw", "FFT"], label="Mode", value="Raw"),
]
app_outputs = [
    gr.Label(label="Prediction"),
    gr.Plot(label="FFT Visualization"),
]

demo = gr.Interface(
    fn=predict,
    inputs=app_inputs,
    outputs=app_outputs,
    title="CIFAR-10 Visual Analyzer (ResNet18)",
    description="Upload an image and choose mode: Raw classification (ResNet18) or visualize FFT of RGB channels.\n\nDisclaimer: This model is trained on CIFAR-10 and works best on low-res, centered images.",
)
# Start the web server as soon as the module is executed.
demo.launch()