File size: 7,424 Bytes
77a70ee
2898702
ae68879
2898702
ae68879
2898702
ae68879
77a70ee
5e81b01
2898702
 
21a55e5
 
5e81b01
 
2898702
 
 
5013d6a
ae68879
2898702
 
801bd57
2898702
 
 
 
21a55e5
ae68879
2898702
38e5a7e
ae68879
 
da6bbb2
77a70ee
ae68879
2898702
ae68879
2898702
 
ae68879
 
 
 
 
 
 
 
2898702
 
 
ae68879
2898702
 
ae68879
2898702
 
 
 
 
ae68879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e81b01
2898702
 
 
5e81b01
2898702
 
 
5e81b01
2898702
ae68879
77a70ee
2898702
38e5a7e
21a55e5
 
77a70ee
2898702
21a55e5
5e81b01
21a55e5
2898702
21a55e5
77a70ee
21a55e5
ae68879
 
21a55e5
2898702
da6bbb2
21a55e5
2898702
da6bbb2
 
5e81b01
21a55e5
ae68879
21a55e5
ae68879
 
2898702
 
 
 
ae68879
 
 
 
 
 
 
2898702
 
ae68879
3d83a20
ae68879
5e81b01
ae68879
 
5e81b01
 
ae68879
 
 
 
 
 
2898702
 
5e81b01
2898702
 
ae68879
77a70ee
 
 
5e81b01
2898702
 
 
 
5e81b01
2898702
21a55e5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
Two‑stage AI‑image detector with visual explainability
──────────────────────────────────────────────────────
Stage‑1 : haywoodsloan/ai-image-detector-deploy (Swin‑V2)  → Real vs AI
          ⟳ Grad‑CAM overlay
Stage‑2 : SuSy.pt (torchscript ResNet)                      → Generator
          ⟳ Saliency‑grad overlay (Captum)
"""

# ───────────────────── Imports ────────────────────────────────────────
import torch, numpy as np, pandas as pd, matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchcam.methods import GradCAM
from captum.attr import Saliency
from skimage.feature import graycomatrix, graycoprops
import gradio as gr

# ─────────────────── Runtime / models ────────────────────────────────
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
plt.set_loglevel("ERROR")

# Stage-1 binary detector (Hugging Face checkpoint) and its image processor.
BIN_ID   = "haywoodsloan/ai-image-detector-deploy"
bin_proc = AutoImageProcessor.from_pretrained(BIN_ID)
bin_mod  = AutoModelForImageClassification.from_pretrained(BIN_ID).to(device).eval()
# Module name that the Stage-1 Grad-CAM hooks onto.
CAM_LAYER_BIN = "encoder.layers.3.blocks.1.layernorm_after"

# Stage-2 TorchScript model.  ``map_location`` remaps the archived tensors
# while loading, so an archive saved from a CUDA device can still be opened on
# a CPU-only host (a post-hoc ``.to(device)`` would run too late for that).
susy_mod = torch.jit.load("SuSy.pt", map_location=device).eval()
CAM_LAYER_SUSY = "feature_extractor.resnet_model.layer4.1.relu"

# Display labels zipped against SuSy outputs 1..5 (output 0 is skipped by the
# callers).  Written with escapes so the middle dot (U+00B7) and no-break
# spaces (U+00A0) cannot be mangled by a wrong file encoding.
# NOTE(review): the label order must match the model's output order - confirm
# against the SuSy model card.
GEN_CLASSES = ["Stable Diffusion 1.x", "DALL\u00b7E\u00a03",
               "MJ\u00a0V5/V6", "Stable Diffusion\u00a0XL", "MJ\u00a0V1/V2"]
# Patch edge length in pixels, and how many top-contrast patches are scored.
PATCH, TOP = 224, 5

# ─────────────── Universal overlay helper ────────────────────────────
def overlay_explanation(model, model_inputs, target_layer, class_idx, base_img):
    """Return a heat-map PIL.Image blended on top of ``base_img``.

    TorchScript modules are explained with Captum ``Saliency`` (absolute input
    gradients averaged over dim 1).  Any other module is explained with
    torchcam ``GradCAM`` attached to ``target_layer``.

    Args:
        model: The module to explain.
        model_inputs: Either a single tensor or a mapping of keyword name to
            tensor.  The inputs are cloned, so the caller's tensors are not
            modified.
        target_layer: Module name (or a suffix of one) used for Grad-CAM.
            Ignored for TorchScript modules.
        class_idx: Index of the output class to explain.
        base_img: PIL image the heat-map is resized to and blended over.

    Returns:
        An RGBA ``PIL.Image`` the size of ``base_img``.

    Raises:
        ValueError: If ``target_layer`` matches no module of a non-script
            model.
    """
    is_script = isinstance(model, torch.jit.ScriptModule)

    # Work on detached clones.  Only floating-point tensors can carry
    # gradients, so integer tensors are cloned without requesting one.
    if torch.is_tensor(model_inputs):
        forward_inputs = model_inputs.clone().detach().requires_grad_(True)
    else:
        forward_inputs = {
            k: v.clone().detach().requires_grad_(v.is_floating_point())
            for k, v in model_inputs.items()
        }

    if is_script:
        model.zero_grad(set_to_none=True)
        sal   = Saliency(model)
        grads = sal.attribute(forward_inputs, target=class_idx).abs().mean(1, keepdim=True)
        mask  = grads.squeeze().detach().cpu().numpy()

    else:
        mods = dict(model.named_modules())
        # Exact name first, then a suffix match.  Compare against None rather
        # than relying on module truthiness (containers define __len__).
        tgt = mods.get(target_layer)
        if tgt is None:
            tgt = next((m for n, m in mods.items() if n.endswith(target_layer)), None)
        if tgt is None:
            raise ValueError(f"target layer {target_layer!r} not found in model")

        cam = GradCAM(model, target_layer=tgt)
        try:
            with torch.enable_grad():
                outputs = (model(forward_inputs) if torch.is_tensor(forward_inputs)
                           else model(**forward_inputs))
                logits  = outputs.logits if hasattr(outputs, "logits") else outputs

                # The keyword name differs between torchcam releases:
                # try ``scores=`` first and fall back to ``logits=``.
                try:
                    cam_result = cam(class_idx, scores=logits)
                except TypeError:
                    cam_result = cam(class_idx, logits=logits)

                mask = cam_result[0].detach().cpu().numpy()
        finally:
            # Always detach the hooks: the model is shared between requests,
            # so hooks left behind by a failed call would pile up on it.
            if hasattr(cam, "remove_hooks"):
                cam.remove_hooks()
            elif hasattr(cam, "clear_hooks"):
                cam.clear_hooks()

    # Min-max normalise (epsilon avoids 0/0 on a constant mask), colour with
    # the jet colormap, drop its alpha channel, and blend over the base image.
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
    heat = Image.fromarray((plt.cm.jet(mask)[:, :, :3] * 255).astype(np.uint8))\
               .resize(base_img.size, Image.BICUBIC)
    return Image.blend(base_img.convert("RGBA"), heat.convert("RGBA"), alpha=0.45)

# ───────────── SuSy patch‑ranking helper ──────────────────────────────
# PIL image -> tensor; used by ``pipeline`` to build the Stage-2 input.
to_tensor = transforms.ToTensor()

# PIL image -> tensor -> grayscale; used by ``susy_predict`` to build the
# single-channel array handed to ``graycomatrix``.
to_gray = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.Grayscale(),
    ]
)

def susy_predict(img: Image.Image):
    """Score the most textured patches of ``img`` with the SuSy model.

    The image is cut into a grid of ``PATCH`` x ``PATCH`` tiles (at least one
    tile per axis).  Each tile is ranked by its GLCM contrast, and the ``TOP``
    highest-contrast tiles are classified.  Their softmax outputs are averaged
    and output 0 is dropped before pairing with ``GEN_CLASSES``.

    Args:
        img: Input PIL image in any mode; it is converted to RGB first.

    Returns:
        A dict mapping each ``GEN_CLASSES`` label to its mean probability.
    """
    # The patch buffer is 3-channel, so RGBA / grayscale / palette inputs
    # would fail on assignment below unless converted.
    img = img.convert("RGB")
    w, h   = img.size
    npx, npy = max(1, w // PATCH), max(1, h // PATCH)
    patches  = np.zeros((npx * npy, PATCH, PATCH, 3), dtype=np.uint8)

    for i in range(npx):
        for j in range(npy):
            x, y = i * PATCH, j * PATCH
            # The crop box is exactly PATCH x PATCH, so no resize is needed.
            patches[i * npy + j] = np.asarray(img.crop((x, y, x + PATCH, y + PATCH)))

    # Texture score per patch: GLCM contrast at distance 5, angle 0.
    contrasts = []
    for p in patches:
        g = to_gray(Image.fromarray(p)).squeeze(0).numpy()
        glcm = graycomatrix(g, [5], [0], 256, symmetric=True, normed=True)
        contrasts.append(graycoprops(glcm, "contrast")[0, 0])

    # Highest-contrast patches first; NHWC uint8 -> NCHW float in [0, 1].
    idx   = np.argsort(contrasts)[::-1][:TOP]
    tens  = torch.from_numpy(patches[idx].transpose(0, 3, 1, 2)).float() / 255.0
    with torch.no_grad():
        probs = susy_mod(tens.to(device)).softmax(-1).mean(0).cpu().numpy()[1:]
    return dict(zip(GEN_CLASSES, probs))

# ───────────────────── Pipeline ───────────────────────────────────────
def pipeline(img_arr):
    """Run both detection stages on one image and collect the explanations.

    Stage 1 classifies the image with ``bin_mod`` and always adds a heat-map
    for the winning class.  Stage 2 runs only when the AI score is strictly
    greater than the real score: it adds the generator bar chart from
    ``susy_predict`` and a second heat-map from ``susy_mod``.

    Args:
        img_arr: A numpy image array, a PIL image, or ``None`` when nothing
            was uploaded.

    Returns:
        A tuple ``(verdict, bar_update, heatmaps)``: the verdict text, a
        ``gr.update`` carrying the bar-plot data and visibility, and the list
        of heat-map images.
    """
    if img_arr is None:
        # Button pressed without an upload: report it instead of crashing
        # inside the image processor.
        return "No image provided", gr.update(value=None, visible=False), []

    img = Image.fromarray(img_arr) if isinstance(img_arr, np.ndarray) else img_arr
    # Stage 2 builds 3-channel tensors, so normalise RGBA / grayscale input.
    img = img.convert("RGB")
    heatmaps = []

    # Stage-1 classification (no grad)
    with torch.no_grad():
        inp_bin = bin_proc(images=img, return_tensors="pt").to(device)
        probs   = bin_mod(**inp_bin).logits.softmax(-1)[0]

    # NOTE(review): assumes output 0 is "AI" and output 1 is "real" - confirm
    # against the checkpoint's id2label.
    ai_conf, real_conf = probs.tolist()
    # One decision drives both the heat-map class and the verdict, so an exact
    # tie is treated as authentic in both places.
    is_ai = ai_conf > real_conf
    winner_idx = 0 if is_ai else 1

    # Stage-1 heat-map (overlay_explanation clones its inputs itself)
    heatmaps.append(
        overlay_explanation(bin_mod, inp_bin, CAM_LAYER_BIN, winner_idx, img)
    )

    # \u202f is a narrow no-break space before the percent sign.
    verdict = f"Authentic ({real_conf*100:.1f}\u202f%)"
    bar_df, show_bar = None, False

    # Stage-2 if AI
    if is_ai:
        verdict = f"AI‑generated ({ai_conf*100:.1f}\u202f%)"
        gen_probs = susy_predict(img)
        bar_df    = pd.DataFrame({
            "class": list(gen_probs.keys()),
            "prob": [float(p) for p in gen_probs.values()],
        })
        show_bar  = True

        # The explained class comes from the whole image resized to 224x224;
        # output 0 is skipped, hence the +1 to get back to a model index.
        with torch.no_grad():
            susy_in = to_tensor(img.resize((224, 224))).unsqueeze(0).to(device)
            g_idx   = susy_mod(susy_in)[0, 1:].argmax().item() + 1

        heatmaps.append(
            overlay_explanation(susy_mod, susy_in, CAM_LAYER_SUSY, g_idx, img)
        )

    return verdict, gr.update(value=bar_df, visible=show_bar), heatmaps

# ───────────────────────── UI ─────────────────────────────────────────
# Layout: upload + button on one row, then verdict text, generator bar plot
# (hidden until Stage 2 runs) and the heat-map gallery.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Title written with escapes (framed-picture emoji, no-break spaces, em
    # dash) so it cannot be corrupted by a wrong file encoding.
    gr.Markdown(
        "## \U0001F5BC\uFE0F\u00a0Two‑Stage AI Fake Detector"
        "\u00a0\u2014\u00a0Explained with Heat‑maps"
    )
    with gr.Row():
        img_in = gr.Image(type="numpy", label="Upload image")
        btn    = gr.Button("Detect")

    txt_out = gr.Textbox(label="Verdict", interactive=False)
    bar_out = gr.BarPlot(x="class", y="prob", title="Likely generator",
                         y_label="probability", visible=False)
    gal_out = gr.Gallery(label="Heat‑maps", columns=2, height=320)

    # pipeline returns (verdict, bar-plot update, heat-map list) in this order.
    btn.click(pipeline, inputs=img_in, outputs=[txt_out, bar_out, gal_out])

demo.launch()