"""
Two-stage AI-image detector with visual explainability
──────────────────────────────────────────────────────
Stage 1 : haywoodsloan/ai-image-detector-deploy (Swin V2) → Real vs AI
          ↳ Grad-CAM overlay
Stage 2 : SuSy.pt (torchscript ResNet) → Generator
          ↳ Saliency-grad overlay (Captum)
"""
# ───────────────────── Imports ────────────────────────────────────────
import torch, numpy as np, pandas as pd, matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchcam.methods import GradCAM
from captum.attr import Saliency
from skimage.feature import graycomatrix, graycoprops
import gradio as gr
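# Runtime dependencies implied by the imports above (for the Space's requirements.txt);
# exact version pins are not specified here and would need to match the torchcam /
# captum call signatures used below:
#   torch, torchvision, transformers, torchcam, captum, scikit-image, pandas,
#   matplotlib, gradio, Pillow, numpy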
# ─────────────────── Runtime / models ─────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
plt.set_loglevel("ERROR")
BIN_ID = "haywoodsloan/ai-image-detector-deploy"
bin_proc = AutoImageProcessor.from_pretrained(BIN_ID)
bin_mod = AutoModelForImageClassification.from_pretrained(BIN_ID).to(device).eval()
CAM_LAYER_BIN = "encoder.layers.3.blocks.1.layernorm_after"
susy_mod = torch.jit.load("SuSy.pt").to(device).eval()
CAM_LAYER_SUSY = "feature_extractor.resnet_model.layer4.1.relu"
GEN_CLASSES = ["Stable Diffusion 1.x", "DALL·E 3",
               "MJ V5/V6", "Stable Diffusion XL", "MJ V1/V2"]
PATCH, TOP = 224, 5
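# NOTE (assumption): the code below treats class index 0 of the binary detector as
# "artificial" and index 1 as "real", and SuSy's class 0 as "authentic" (so GEN_CLASSES
# maps to indices 1..5). If either checkpoint orders its labels differently, check
# bin_mod.config.id2label and the SuSy model card, and adjust the indexing in
# pipeline() and susy_predict() accordingly.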
# ─────────────── Universal overlay helper ─────────────────────────────
def overlay_explanation(model, model_inputs, target_layer, class_idx, base_img):
    """Return a heat-map PIL.Image blended on top of base_img."""
    is_script = isinstance(model, torch.jit.ScriptModule)
    # clone & ensure gradients
    if torch.is_tensor(model_inputs):
        forward_inputs = model_inputs.clone().detach().requires_grad_(True)
    else:
        forward_inputs = {
            k: v.clone().detach().requires_grad_(True)
            for k, v in model_inputs.items()
        }
    if is_script:
        # TorchScript module: torchcam hooks are unreliable here, so fall back to
        # Captum input-gradient saliency.
        model.zero_grad(set_to_none=True)
        sal = Saliency(model)
        grads = sal.attribute(forward_inputs, target=class_idx).abs().mean(1, keepdim=True)
        mask = grads.squeeze().detach().cpu().numpy()
    else:
        mods = dict(model.named_modules())
        tgt = mods.get(target_layer) or next(m for n, m in mods.items() if n.endswith(target_layer))
        cam = GradCAM(model, target_layer=tgt)
        with torch.enable_grad():
            outputs = (model(forward_inputs) if torch.is_tensor(forward_inputs)
                       else model(**forward_inputs))
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
            # torchcam 0.7 → scores=..., earlier → logits=...
            try:
                cam_result = cam(class_idx, scores=logits)
            except TypeError:
                cam_result = cam(class_idx, logits=logits)
        mask = cam_result[0].squeeze().detach().cpu().numpy()  # drop any leading batch dim
        # clean up hook handles
        if hasattr(cam, "remove_hooks"):
            cam.remove_hooks()
        elif hasattr(cam, "clear_hooks"):
            cam.clear_hooks()
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
    heat = Image.fromarray((plt.cm.jet(mask)[:, :, :3] * 255).astype(np.uint8)) \
        .resize(base_img.size, Image.BICUBIC)
    return Image.blend(base_img.convert("RGBA"), heat.convert("RGBA"), alpha=0.45)
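# Usage sketch: a Grad-CAM overlay for the binary detector on a PIL image `some_img`
# (hypothetical variable), explaining class 0; assumes the processor output can be
# passed as a plain dict of tensors, as pipeline() does below.
#   inputs = {k: v for k, v in bin_proc(images=some_img, return_tensors="pt").to(device).items()}
#   cam_img = overlay_explanation(bin_mod, inputs, CAM_LAYER_BIN, 0, some_img)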
# ───────────── SuSy patch-ranking helper ──────────────────────────────
to_tensor = transforms.ToTensor()
to_gray = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
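# susy_predict: tile the image into non-overlapping PATCH-sized crops, rank the crops
# by GLCM contrast (graycomatrix/graycoprops) so the most textured regions are kept,
# run SuSy on the TOP patches, average the softmax over those patches, and drop index 0
# (assumed "authentic") so only per-generator probabilities are returned.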
def susy_predict(img: Image.Image):
    w, h = img.size
    npx, npy = max(1, w // PATCH), max(1, h // PATCH)
    patches = np.zeros((npx * npy, PATCH, PATCH, 3), dtype=np.uint8)
    for i in range(npx):
        for j in range(npy):
            x, y = i * PATCH, j * PATCH
            patches[i*npy+j] = np.array(img.crop((x, y, x+PATCH, y+PATCH)).resize((PATCH, PATCH)))
    contrasts = []
    for p in patches:
        g = to_gray(Image.fromarray(p)).squeeze(0).numpy()
        glcm = graycomatrix(g, [5], [0], 256, symmetric=True, normed=True)
        contrasts.append(graycoprops(glcm, "contrast")[0, 0])
    idx = np.argsort(contrasts)[::-1][:TOP]
    tens = torch.from_numpy(patches[idx].transpose(0, 3, 1, 2)).float() / 255.0
    with torch.no_grad():
        probs = susy_mod(tens.to(device)).softmax(-1).mean(0).cpu().numpy()[1:]
    return dict(zip(GEN_CLASSES, probs))
# ───────────────────── Pipeline ───────────────────────────────────────
def pipeline(img_arr):
    img = Image.fromarray(img_arr) if isinstance(img_arr, np.ndarray) else img_arr
    heatmaps = []
    # Stage 1 classification (no grad)
    with torch.no_grad():
        inp_bin = bin_proc(images=img, return_tensors="pt").to(device)
        logits = bin_mod(**inp_bin).logits.softmax(-1)[0]
    ai_conf, real_conf = logits
    winner_idx = 0 if ai_conf >= real_conf else 1
    # Stage 1 heat-map
    inp_bin_heat = {k: v.clone().detach().requires_grad_(True) for k, v in inp_bin.items()}
    heatmaps.append(
        overlay_explanation(bin_mod, inp_bin_heat, CAM_LAYER_BIN, winner_idx, img)
    )
    verdict = f"Authentic ({real_conf*100:.1f} %)"
    bar_df, show_bar = None, False
    # Stage 2 if AI
    if ai_conf > real_conf:
        verdict = f"AI-generated ({ai_conf*100:.1f} %)"
        gen_probs = susy_predict(img)
        bar_df = pd.DataFrame({"class": gen_probs.keys(), "prob": gen_probs.values()})
        show_bar = True
        with torch.no_grad():
            susy_in = to_tensor(img.resize((224, 224))).unsqueeze(0).to(device)
            g_idx = susy_mod(susy_in)[0, 1:].argmax().item() + 1
        heatmaps.append(
            overlay_explanation(susy_mod, susy_in, CAM_LAYER_SUSY, g_idx, img)
        )
    return verdict, gr.update(value=bar_df, visible=show_bar), heatmaps
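# pipeline() returns (verdict text, a gr.update for the generator bar plot, list of
# heat-map overlays), in the same order as the outputs wired to btn.click below.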
# ───────────────────────── UI ─────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🖼️ Two-Stage AI Fake Detector – Explained with Heat-maps")
    with gr.Row():
        img_in = gr.Image(type="numpy", label="Upload image")
        btn = gr.Button("Detect")
    txt_out = gr.Textbox(label="Verdict", interactive=False)
    bar_out = gr.BarPlot(x="class", y="prob", title="Likely generator",
                         y_label="probability", visible=False)
    gal_out = gr.Gallery(label="Heat-maps", columns=2, height=320)
    btn.click(pipeline, inputs=img_in, outputs=[txt_out, bar_out, gal_out])
demo.launch()