# Sleepyriizi (commit author)
# Commit message: CAM layers fixed
# Commit: ae68879
"""
Two‑stage AI‑image detector with visual explainability
──────────────────────────────────────────────────────
Stage‑1 : haywoodsloan/ai-image-detector-deploy (Swin‑V2) → Real vs AI
⟳ Grad‑CAM overlay
Stage‑2 : SuSy.pt (TorchScript ResNet) → Generator
⟳ Saliency‑grad overlay (Captum)
"""
# ───────────────────── Imports ────────────────────────────────────────
import torch, numpy as np, pandas as pd, matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchcam.methods import GradCAM
from captum.attr import Saliency
from skimage.feature import graycomatrix, graycoprops
import gradio as gr
# ─────────────────── Runtime / models ────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
plt.set_loglevel("ERROR")

# Stage 1: binary AI-vs-real image classifier loaded from the Hugging Face Hub.
BIN_ID = "haywoodsloan/ai-image-detector-deploy"
bin_proc = AutoImageProcessor.from_pretrained(BIN_ID)
bin_mod = AutoModelForImageClassification.from_pretrained(BIN_ID).to(device).eval()
# Grad-CAM target for stage 1. The name may be a suffix of the fully-qualified
# module name; the overlay helper resolves it by suffix match.
CAM_LAYER_BIN = "encoder.layers.3.blocks.1.layernorm_after"

# Stage 2: TorchScript generator classifier. map_location makes a checkpoint
# that was serialized on a GPU loadable on a CPU-only host as well.
susy_mod = torch.jit.load("SuSy.pt", map_location=device).to(device).eval()
# TorchScript models are explained with Captum Saliency, which does not use a
# layer name; kept for the overlay helper's uniform signature.
CAM_LAYER_SUSY = "feature_extractor.resnet_model.layer4.1.relu"

# Display names for SuSy output columns 1..N (column 0 is dropped by
# susy_predict). NOTE(review): this order must match SuSy's training label
# order — verify against the SuSy model card.
GEN_CLASSES = ["Stable Diffusion 1.x", "DALL·E 3",
               "MJ V5/V6", "Stable Diffusion XL", "MJ V1/V2"]
# PATCH: side length (px) of each SuSy input crop; TOP: number of
# highest-contrast crops that are scored and averaged.
PATCH, TOP = 224, 5
# ─────────────── Universal overlay helper ────────────────────────────
def _fresh_leaf(t):
    """Return a detached copy of *t*; float tensors are marked requires_grad."""
    # requires_grad_(True) raises on integer tensors, so only floats get it.
    return t.clone().detach().requires_grad_(t.is_floating_point())


def _resolve_layer(model, target_layer):
    """Return the sub-module of *model* named *target_layer*.

    An exact ``named_modules()`` match wins; otherwise the first module whose
    qualified name ends with *target_layer* is used (so a name missing a
    wrapper prefix still resolves).

    Raises:
        ValueError: if no module matches.
    """
    modules = dict(model.named_modules())
    if target_layer in modules:
        return modules[target_layer]
    for name, module in modules.items():
        if name.endswith(target_layer):
            return module
    raise ValueError(f"layer {target_layer!r} not found in {type(model).__name__}")


def _grad_cam_mask(model, forward_inputs, layer, class_idx):
    """Compute a 2-D Grad-CAM map for *class_idx* from *layer*'s output.

    Supports two activation layouts (first batch element only):
      * conv feature maps ``(B, C, H, W)``: channel weights are the spatially
        averaged gradients; the map is ``H × W``;
      * token sequences ``(B, L, C)``: channel weights are the token-averaged
        gradients; the per-token map is folded to ``√L × √L`` when L is a
        perfect square (assumes row-major token order), else kept as ``1 × L``.
    A ReLU is applied, as in standard Grad-CAM.
    """
    captured = []

    def _keep_output(_module, _args, output):
        # Some blocks return tuples; the first element is the hidden state.
        captured.append(output[0] if isinstance(output, tuple) else output)

    handle = layer.register_forward_hook(_keep_output)
    try:
        with torch.enable_grad():
            outputs = (model(forward_inputs) if torch.is_tensor(forward_inputs)
                       else model(**forward_inputs))
            logits = getattr(outputs, "logits", outputs)
            if not captured:
                raise RuntimeError("target layer did not run during the forward pass")
            # Differentiate w.r.t. the activation only, so no parameter .grad
            # buffers accumulate on the shared, long-lived model.
            (grads,) = torch.autograd.grad(logits[0, class_idx], captured[-1])
    finally:
        # Always detach the hook; a failure must not leave it on the model.
        handle.remove()

    acts = captured[-1].detach()[0].float()
    grads = grads.detach()[0].float()
    if acts.ndim == 3:  # (C, H, W) conv feature map
        weights = grads.mean(dim=(1, 2))
        cam = (weights[:, None, None] * acts).sum(0)
    elif acts.ndim == 2:  # (L, C) token sequence
        weights = grads.mean(dim=0)
        cam = acts @ weights
        side = int(round(cam.numel() ** 0.5))
        cam = cam.reshape(side, side) if side * side == cam.numel() else cam.unsqueeze(0)
    else:
        raise ValueError(f"unsupported activation shape {tuple(captured[-1].shape)}")
    return torch.relu(cam).cpu().numpy()


def _saliency_mask(model, forward_inputs, class_idx):
    """Return the channel-averaged |input gradient| map from Captum Saliency."""
    model.zero_grad(set_to_none=True)
    grads = Saliency(model).attribute(forward_inputs, target=class_idx)
    return grads.abs().mean(1, keepdim=True).squeeze().detach().cpu().numpy()


def _blend_heatmap(mask, base_img, alpha=0.45):
    """Min-max normalise *mask*, colour it with ``jet`` and blend onto *base_img*."""
    # Drop singleton batch/channel axes so the colormap sees exactly (H, W).
    mask = np.atleast_2d(np.squeeze(np.asarray(mask, dtype=np.float64)))
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D mask, got shape {mask.shape}")
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
    rgb = (plt.cm.jet(mask)[:, :, :3] * 255).astype(np.uint8)
    heat = Image.fromarray(rgb).resize(base_img.size, Image.BICUBIC)
    return Image.blend(base_img.convert("RGBA"), heat.convert("RGBA"), alpha=alpha)


def overlay_explanation(model, model_inputs, target_layer, class_idx, base_img):
    """Return a heat-map PIL.Image (RGBA) blended on top of *base_img*.

    Args:
        model: a TorchScript module (explained with Captum Saliency, which
            ignores *target_layer*) or a regular ``nn.Module`` (explained with
            Grad-CAM on *target_layer*).
        model_inputs: a tensor, or a mapping of keyword tensors passed as
            ``model(**inputs)``. Copies are used; the originals are untouched.
        target_layer: module name (or suffix of one) for Grad-CAM.
        class_idx: output column to explain.
        base_img: PIL image the heat-map is resized to and blended over.
    """
    if torch.is_tensor(model_inputs):
        forward_inputs = _fresh_leaf(model_inputs)
    else:
        forward_inputs = {k: _fresh_leaf(v) for k, v in model_inputs.items()}

    if isinstance(model, torch.jit.ScriptModule):
        mask = _saliency_mask(model, forward_inputs, class_idx)
    else:
        layer = _resolve_layer(model, target_layer)
        mask = _grad_cam_mask(model, forward_inputs, layer, class_idx)
    return _blend_heatmap(mask, base_img)
# ───────────── SuSy patch‑ranking helper ──────────────────────────────
to_tensor = transforms.ToTensor()
to_gray = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])


def susy_predict(img: Image.Image):
    """Score *img* against SuSy's generator classes using its most textured crops.

    The image is tiled into non-overlapping PATCH×PATCH crops (always at least
    one; PIL zero-pads crops that extend past the image edge). Each crop is
    ranked by GLCM contrast (distance 5, angle 0, 256 levels) of its grayscale
    version. The TOP highest-contrast crops are fed to SuSy as one batch, and
    their softmax outputs are averaged.

    Returns:
        dict mapping each name in GEN_CLASSES to its mean probability. SuSy's
        output column 0 is skipped (presumably the "authentic" class — confirm
        against the model card).
    """
    # RGBA / L / P crops would not fit the (..., 3) patch layout used below.
    img = img.convert("RGB")
    w, h = img.size
    npx, npy = max(1, w // PATCH), max(1, h // PATCH)
    # Column-major tiling: patch index = i * npy + j.
    patches = np.stack([
        np.asarray(img.crop((i * PATCH, j * PATCH, (i + 1) * PATCH, (j + 1) * PATCH)))
        for i in range(npx)
        for j in range(npy)
    ])

    contrasts = []
    for patch in patches:
        gray = to_gray(Image.fromarray(patch)).squeeze(0).numpy()
        glcm = graycomatrix(gray, [5], [0], 256, symmetric=True, normed=True)
        contrasts.append(graycoprops(glcm, "contrast")[0, 0])
    top = np.argsort(contrasts)[::-1][:TOP]

    # (k, H, W, 3) uint8 -> (k, 3, H, W) float in [0, 1].
    batch = torch.from_numpy(patches[top].transpose(0, 3, 1, 2)).float() / 255.0
    with torch.no_grad():
        probs = susy_mod(batch.to(device)).softmax(-1).mean(0).cpu().numpy()[1:]
    return dict(zip(GEN_CLASSES, probs))
# ───────────────────── Pipeline ───────────────────────────────────────
def pipeline(img_arr):
    """Run the two-stage detector on one uploaded image.

    Stage 1 classifies the image as AI-generated vs authentic with ``bin_mod``
    and overlays a Grad-CAM heat-map for the class the verdict reports. If the
    image is judged AI-generated, stage 2 ranks generators with
    ``susy_predict`` and adds a saliency heat-map for the generator the bar
    plot ranks highest.

    Args:
        img_arr: ``np.ndarray`` (as sent by ``gr.Image(type="numpy")``), a PIL
            image, or ``None`` when Detect is pressed with no upload.

    Returns:
        ``(verdict_text, bar_plot_update, heatmaps)`` for the three UI outputs.
    """
    if img_arr is None:
        return "Please upload an image first.", gr.update(value=None, visible=False), []
    img = Image.fromarray(img_arr) if isinstance(img_arr, np.ndarray) else img_arr
    # Both models expect 3-channel input; normalise RGBA / L / P images.
    img = img.convert("RGB")
    heatmaps = []

    # Stage-1 classification (no grad)
    with torch.no_grad():
        inp_bin = bin_proc(images=img, return_tensors="pt").to(device)
        probs = bin_mod(**inp_bin).logits.softmax(-1)[0]
    # NOTE(review): index 0 is taken as "AI" and index 1 as "real"; confirm
    # against bin_mod.config.id2label.
    ai_conf, real_conf = probs.tolist()
    is_ai = ai_conf > real_conf
    # Explain the same class the verdict reports (a tie counts as authentic).
    winner_idx = 0 if is_ai else 1
    heatmaps.append(
        overlay_explanation(bin_mod, dict(inp_bin.items()), CAM_LAYER_BIN, winner_idx, img)
    )

    if not is_ai:
        verdict = f"Authentic ({real_conf * 100:.1f}\u202f%)"
        return verdict, gr.update(value=None, visible=False), heatmaps

    # Stage 2: generator attribution.
    gen_probs = susy_predict(img)
    bar_df = pd.DataFrame({"class": list(gen_probs.keys()), "prob": list(gen_probs.values())})
    # Explain the generator the bar plot ranks highest; +1 re-adds the SuSy
    # output column 0 that susy_predict drops.
    g_idx = int(np.argmax(list(gen_probs.values()))) + 1
    susy_in = to_tensor(img.resize((224, 224))).unsqueeze(0).to(device)
    heatmaps.append(
        overlay_explanation(susy_mod, susy_in, CAM_LAYER_SUSY, g_idx, img)
    )
    verdict = f"AI‑generated ({ai_conf * 100:.1f}\u202f%)"
    return verdict, gr.update(value=bar_df, visible=True), heatmaps
# ───────────────────────── UI ─────────────────────────────────────────
# Gradio UI: image upload + Detect button. Outputs are the verdict text, a bar
# plot of generator probabilities (shown only for AI verdicts) and a gallery
# of heat-map overlays, all produced by `pipeline`.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🖼️ Two‑Stage AI Fake Detector — Explained with Heat‑maps")
    with gr.Row():
        img_in = gr.Image(type="numpy", label="Upload image")
        btn = gr.Button("Detect")
    txt_out = gr.Textbox(label="Verdict", interactive=False)
    bar_out = gr.BarPlot(x="class", y="prob", title="Likely generator",
                         y_label="probability", visible=False)
    gal_out = gr.Gallery(label="Heat‑maps", columns=2, height=320)
    # Event listeners must be registered inside the Blocks context.
    btn.click(pipeline, inputs=img_in, outputs=[txt_out, bar_out, gal_out])
demo.launch()