"""
Hugging Face Spaces app for Enhanced AI Image Detector.
"""

import os
import sys

import gradio as gr
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"Working directory: {os.getcwd()}")
print(f"Directory contents: {os.listdir('.')}")


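# EfficientNetV2-S backbone with its ImageNet classification head swapped for
# a two-class (real vs. AI-generated) head behind dropout.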
class AIDetectorModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = models.efficientnet_v2_s(weights=None)

        num_features = self.base_model.classifier[1].in_features
        self.base_model.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(num_features, 2)
        )

    def forward(self, x):
        return self.base_model(x)


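# Standard preprocessing: resize to the backbone's 224x224 input and normalize
# with the usual ImageNet mean/std statistics.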
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


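# Module-level globals, initialized lazily on first use (see
# load_model_if_needed) so the Space can start before the model is loaded.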
model = None
device = None


def load_model_if_needed():
    global model, device

    if model is not None:
        return model, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = AIDetectorModel()
    model_path = "best_model_improved.pth"

    try:
        if not os.path.exists(model_path):
            print(f"Model file not found at {model_path}")
            print(f"Searching in current directory: {os.listdir('.')}")
            raise FileNotFoundError(f"Model file not found: {model_path}")

        print(f"Loading model from {model_path}")
        state_dict = torch.load(model_path, map_location=device)
        print(f"Model state keys: {list(state_dict.keys())[:5]}... (showing first 5)")
        print(f"Model state contains {len(state_dict)} parameters")

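        # Adapt checkpoint keys to this model's layout: strip the 'module.'
        # prefix that nn.DataParallel adds, and remap classifier Sequential
        # indices (the // 2 assumes the training-time head had roughly twice
        # as many layers as the two-layer head used here).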
        adapted_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith('module.'):
                adapted_key = key[7:]
            else:
                adapted_key = key

            if adapted_key.startswith('base_model.classifier.') and 'base_model.classifier.' not in str(model.state_dict().keys()):
                classifier_part = adapted_key.split('base_model.classifier.')[1]
                if '.' in classifier_part:
                    layer_idx, param_type = classifier_part.split('.')
                    new_key = f"base_model.classifier.{int(layer_idx) // 2}.{param_type}"
                    adapted_state_dict[new_key] = value
                else:
                    adapted_state_dict[adapted_key] = value
            else:
                adapted_state_dict[adapted_key] = value

        try:
            model.load_state_dict(adapted_state_dict)
            print("Model loaded successfully with adapted state dict")
        except Exception as e:
            print(f"Error with adapted state dict: {str(e)}")
            model.load_state_dict(state_dict, strict=False)
            print("Model loaded with original state dict and strict=False")

    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print("Trying with strict=False...")
        try:
            model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
            print("Model loaded with strict=False")
        except Exception as e2:
            print(f"Failed to load model with strict=False: {str(e2)}")

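            # Last resort: fall back to ImageNet-pretrained weights so the app
            # still runs, although detection results will not be meaningful.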
            print("Initializing with pretrained weights as fallback")
            base_model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)
            model.base_model = base_model
            model.base_model.classifier = nn.Sequential(
                nn.Dropout(p=0.2),
                nn.Linear(model.base_model.classifier[1].in_features, 2)
            )

    model.to(device)
    model.eval()

    return model, device


def analyze_image(image):
    """Analyze an image to determine whether it is real or AI-generated."""
    if image is None:
        return "Please upload an image", ""

    try:
        global model, device
        if model is None:
            model, device = load_model_if_needed()

        # Accept both numpy arrays and PIL images, and force 3-channel RGB so
        # RGBA or grayscale uploads do not break the normalization step.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert('RGB')

        image_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(image_tensor)
            print(f"Raw model outputs: {outputs}")

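        # Temperature scaling (T > 1) flattens the softmax distribution,
        # tempering overconfident logits before they are read as confidences.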
        temperature = 2.0
        scaled_outputs = outputs / temperature

        probabilities = torch.nn.functional.softmax(scaled_outputs, dim=1)
        print(f"Softmax probabilities (with temperature scaling): {probabilities}")

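        # The checkpoint does not record which class index means "AI", so the
        # code below heuristically treats whichever class probability sits
        # further from 0.5 as the AI score.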
        ai_score_option1 = probabilities[0, 1].item()
        ai_score_option2 = probabilities[0, 0].item()

        if abs(ai_score_option1 - 0.5) > abs(ai_score_option2 - 0.5):
            ai_score = ai_score_option1
            real_score = probabilities[0, 0].item()
            print(f"Using class 1 as AI: AI score: {ai_score:.4f}, Real score: {real_score:.4f}")
        else:
            ai_score = ai_score_option2
            real_score = probabilities[0, 1].item()
            print(f"Using class 0 as AI: AI score: {ai_score:.4f}, Real score: {real_score:.4f}")

        is_ai_generated = ai_score > 0.5

        if is_ai_generated:
            message = f"🤖 This image is likely AI-generated (Confidence: {ai_score:.2f})"
        else:
            message = f"📷 This image is likely authentic (Confidence: {real_score:.2f})"

        detailed_analysis = f"""
### Detailed Analysis:

| Property | Value |
|----------|-------|
| AI Score | {ai_score:.4f} |
| Real Score | {real_score:.4f} |
| Prediction | {'AI-generated' if is_ai_generated else 'Real'} |
| Model | Enhanced AI Image Detector |
| Architecture | EfficientNetV2-S |
"""

        return message, detailed_analysis

    except Exception as e:
        error_message = f"Error analyzing image: {str(e)}"
        print(f"Exception in analyze_image: {str(e)}")
        import traceback
        traceback.print_exc()
        return error_message, ""


with gr.Blocks(title="Enhanced AI Image Detector") as demo:
    gr.Markdown("# 🔍 Enhanced AI Image Detector")
    gr.Markdown("Upload an image to determine whether it is real or AI-generated.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            analyze_button = gr.Button("Analyze Image", variant="primary")

        with gr.Column(scale=1):
            result_text = gr.Textbox(label="Result")
            detailed_output = gr.Markdown(label="Detailed Analysis")

    analyze_button.click(
        fn=analyze_image,
        inputs=[input_image],
        outputs=[result_text, detailed_output]
    )

    gr.Markdown("""
## How it works

This app uses a trained PyTorch neural network (EfficientNetV2-S) to detect AI-generated images. The model was trained on a large dataset of real and AI-generated images to learn the subtle differences between them.

It can pick up patterns that are often invisible to the human eye, including:

1. **Noise and artifact patterns** specific to AI generation methods
2. **Texture inconsistencies** that appear in AI-generated content
3. **Color and lighting anomalies** common in synthetic images
4. **Structural patterns** that differ from natural photographs

## Limitations

- The model may struggle with highly realistic images from newer generative models
- Some real images with unusual characteristics may be misclassified
- Performance depends on image quality and resolution
- The model works best with images similar to those in its training dataset
""")


demo.launch()