""" Two‑stage AI‑image detector with visual explainability ────────────────────────────────────────────────────── Stage‑1 : haywoodsloan/ai-image-detector-deploy (Swin‑V2) → Real vs AI ⟳ Grad‑CAM overlay Stage‑2 : SuSy.pt (torchscript ResNet) → Generator ⟳ Saliency‑grad overlay (Captum) """ # ───────────────────── Imports ──────────────────────────────────────── import torch, numpy as np, pandas as pd, matplotlib.pyplot as plt from PIL import Image from torchvision import transforms from transformers import AutoImageProcessor, AutoModelForImageClassification from torchcam.methods import GradCAM from captum.attr import Saliency from skimage.feature import graycomatrix, graycoprops import gradio as gr # ─────────────────── Runtime / models ──────────────────────────────── device = torch.device("cuda" if torch.cuda.is_available() else "cpu") plt.set_loglevel("ERROR") BIN_ID = "haywoodsloan/ai-image-detector-deploy" bin_proc = AutoImageProcessor.from_pretrained(BIN_ID) bin_mod = AutoModelForImageClassification.from_pretrained(BIN_ID).to(device).eval() CAM_LAYER_BIN = "encoder.layers.3.blocks.1.layernorm_after" susy_mod = torch.jit.load("SuSy.pt").to(device).eval() CAM_LAYER_SUSY = "feature_extractor.resnet_model.layer4.1.relu" GEN_CLASSES = ["Stable Diffusion 1.x", "DALL·E 3", "MJ V5/V6", "Stable Diffusion XL", "MJ V1/V2"] PATCH, TOP = 224, 5 # ─────────────── Universal overlay helper ──────────────────────────── def overlay_explanation(model, model_inputs, target_layer, class_idx, base_img): """Return heat‑map PIL.Image blended on top of base_img.""" is_script = isinstance(model, torch.jit.ScriptModule) # clone & ensure gradients if torch.is_tensor(model_inputs): forward_inputs = model_inputs.clone().detach().requires_grad_(True) else: forward_inputs = { k: v.clone().detach().requires_grad_(True) for k, v in model_inputs.items() } if is_script: model.zero_grad(set_to_none=True) sal = Saliency(model) grads = sal.attribute(forward_inputs, target=class_idx).abs().mean(1, 
keepdim=True) mask = grads.squeeze().detach().cpu().numpy() else: mods = dict(model.named_modules()) tgt = mods.get(target_layer) or next(m for n, m in mods.items() if n.endswith(target_layer)) cam = GradCAM(model, target_layer=tgt) with torch.enable_grad(): outputs = (model(forward_inputs) if torch.is_tensor(forward_inputs) else model(**forward_inputs)) logits = outputs.logits if hasattr(outputs, "logits") else outputs # torchcam 0.7 → scores=…, earlier → logits=… try: cam_result = cam(class_idx, scores=logits) except TypeError: cam_result = cam(class_idx, logits=logits) mask = cam_result[0].detach().cpu().numpy() # clean up handles if hasattr(cam, "remove_hooks"): cam.remove_hooks() elif hasattr(cam, "clear_hooks"): cam.clear_hooks() mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6) heat = Image.fromarray((plt.cm.jet(mask)[:, :, :3] * 255).astype(np.uint8))\ .resize(base_img.size, Image.BICUBIC) return Image.blend(base_img.convert("RGBA"), heat.convert("RGBA"), alpha=0.45) # ───────────── SuSy patch‑ranking helper ────────────────────────────── to_tensor = transforms.ToTensor() to_gray = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()]) def susy_predict(img: Image.Image): w, h = img.size npx, npy = max(1, w // PATCH), max(1, h // PATCH) patches = np.zeros((npx * npy, PATCH, PATCH, 3), dtype=np.uint8) for i in range(npx): for j in range(npy): x, y = i * PATCH, j * PATCH patches[i*npy+j] = np.array(img.crop((x, y, x+PATCH, y+PATCH)).resize((PATCH, PATCH))) contrasts = [] for p in patches: g = to_gray(Image.fromarray(p)).squeeze(0).numpy() glcm = graycomatrix(g, [5], [0], 256, symmetric=True, normed=True) contrasts.append(graycoprops(glcm, "contrast")[0, 0]) idx = np.argsort(contrasts)[::-1][:TOP] tens = torch.from_numpy(patches[idx].transpose(0, 3, 1, 2)).float() / 255.0 with torch.no_grad(): probs = susy_mod(tens.to(device)).softmax(-1).mean(0).cpu().numpy()[1:] return dict(zip(GEN_CLASSES, probs)) # ───────────────────── Pipeline 
# ───────────────────── Pipeline ───────────────────────────────────────
def pipeline(img_arr):
    """Run both detection stages on an uploaded image.

    Stage‑1 classifies the image as AI vs. real and always yields one heat‑map.
    Stage‑2 runs only when Stage‑1 favours "AI": it estimates the generator
    and appends a second heat‑map.

    Args:
        img_arr: ``numpy.ndarray`` (as delivered by ``gr.Image(type="numpy")``),
            a PIL image, or ``None`` when nothing was uploaded.

    Returns:
        Tuple ``(verdict, bar_update, heatmaps)``: the verdict text, a
        ``gr.update`` for the bar plot (hidden unless Stage‑2 ran), and the
        list of heat‑map images for the gallery.
    """
    # Gradio passes None when "Detect" is clicked without an upload.
    if img_arr is None:
        return "Please upload an image first.", gr.update(value=None, visible=False), []

    img = Image.fromarray(img_arr) if isinstance(img_arr, np.ndarray) else img_arr
    # The SuSy tensor path expects three channels.
    img = img.convert("RGB")
    heatmaps = []

    # Stage‑1 classification (no grad)
    with torch.no_grad():
        inp_bin = bin_proc(images=img, return_tensors="pt").to(device)
        probs = bin_mod(**inp_bin).logits.softmax(-1)[0]
    # The code treats output index 0 as "AI" and index 1 as "real".
    ai_conf, real_conf = probs.tolist()

    # One decision drives both the verdict and the explained class, so an exact
    # tie can no longer show an "Authentic" verdict next to an "AI" heat‑map.
    is_ai = ai_conf > real_conf
    winner_idx = 0 if is_ai else 1

    # Stage‑1 heat‑map (overlay_explanation makes its own grad-enabled copies)
    heatmaps.append(
        overlay_explanation(bin_mod, dict(inp_bin), CAM_LAYER_BIN, winner_idx, img)
    )

    verdict = f"Authentic ({real_conf*100:.1f} %)"
    bar_df, show_bar = None, False

    # Stage‑2 if AI
    if is_ai:
        verdict = f"AI‑generated ({ai_conf*100:.1f} %)"
        gen_probs = susy_predict(img)
        bar_df = pd.DataFrame({
            "class": list(gen_probs.keys()),
            "prob": [float(p) for p in gen_probs.values()],
        })
        show_bar = True

        # The heat‑map explains the whole image resized to 224×224, whereas
        # susy_predict scores individual patches.
        with torch.no_grad():
            susy_in = to_tensor(img.resize((224, 224))).unsqueeze(0).to(device)
            # Skip output 0, then shift back to the model's own class index.
            g_idx = susy_mod(susy_in)[0, 1:].argmax().item() + 1
        heatmaps.append(
            overlay_explanation(susy_mod, susy_in, CAM_LAYER_SUSY, g_idx, img)
        )

    return verdict, gr.update(value=bar_df, visible=show_bar), heatmaps


# ───────────────────────── UI ─────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🖼️ Two‑Stage AI Fake Detector — Explained with Heat‑maps")
    with gr.Row():
        img_in = gr.Image(type="numpy", label="Upload image")
        btn = gr.Button("Detect")
    txt_out = gr.Textbox(label="Verdict", interactive=False)
    bar_out = gr.BarPlot(x="class", y="prob", title="Likely generator",
                         y_label="probability", visible=False)
    gal_out = gr.Gallery(label="Heat‑maps", columns=2, height=320)
    btn.click(pipeline, inputs=img_in, outputs=[txt_out, bar_out, gal_out])

# Launched at import time, as in the original script.
demo.launch()