1523 / app.py
yaya36095's picture
Update app.py
6c7e92f verified
"""
Hugging Face Spaces app for Enhanced AI Image Detector
"""
import os
import sys
import gradio as gr
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
def _print_runtime_info():
    """Print interpreter, PyTorch and working-directory details for debugging."""
    details = [
        ("Python version", sys.version),
        ("PyTorch version", torch.__version__),
        ("Working directory", os.getcwd()),
        ("Directory contents", os.listdir('.')),
    ]
    for label, value in details:
        print(f"{label}: {value}")


# Emit the environment report once, at import time.
_print_runtime_info()
class AIDetectorModel(nn.Module):
    """Binary real-vs-AI image classifier built on torchvision's EfficientNetV2-S.

    The backbone is constructed without pretrained weights, and its stock
    classifier is replaced by a dropout + linear layer that emits 2 logits.
    """

    def __init__(self):
        super().__init__()
        backbone = models.efficientnet_v2_s(weights=None)
        head_in = backbone.classifier[1].in_features
        # Keep the (Dropout, Linear) layout so the head's parameters are named
        # base_model.classifier.1.weight / .bias in the state dict.
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(head_in, 2),  # 2 classes: real or AI-generated
        )
        self.base_model = backbone

    def forward(self, x):
        """Return the raw (pre-softmax) 2-way logits for the batch ``x``."""
        return self.base_model(x)
# Standard ImageNet channel statistics. The whole preprocessing pipeline must
# match what was used when the checkpoint was trained.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # Fixed square input; note the aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor (C, H, W) scaled to [0, 1].
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Lazily-initialised shared state: both stay None until the model is first
# loaded by load_model_if_needed().
model, device = None, None
def _adapt_state_dict_keys(state_dict, expected_keys):
    """Rename checkpoint keys so they line up with the target model's keys.

    Args:
        state_dict: Mapping of parameter/buffer name to tensor, as read from
            the checkpoint file.
        expected_keys: Collection of key names used by the target model.

    Returns:
        A new dict holding the same values. A leading ``module.`` prefix
        (added when a model is saved while wrapped in ``nn.DataParallel``) is
        stripped. A key that is not in ``expected_keys`` is renamed to
        ``base_model.<key>`` when that name is expected, so checkpoints saved
        from a bare EfficientNet also match. All other keys pass through.
    """
    adapted = {}
    for key, value in state_dict.items():
        if key.startswith("module."):
            key = key[len("module."):]
        prefixed = f"base_model.{key}"
        if key not in expected_keys and prefixed in expected_keys:
            key = prefixed
        adapted[key] = value
    return adapted


def _read_state_dict(model_path, map_location):
    """Load ``model_path`` and return its ``{name: tensor}`` mapping.

    Accepts a raw state dict, or a training checkpoint that nests the weights
    under ``"state_dict"`` or ``"model_state_dict"``, or a whole pickled
    ``nn.Module``.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        print(f"Model file not found at {model_path}")
        print(f"Searching in current directory: {os.listdir('.')}")
        raise FileNotFoundError(f"Model file not found: {model_path}")
    print(f"Loading model from {model_path}")
    # NOTE: torch.load unpickles the file; only point it at a trusted checkpoint.
    checkpoint = torch.load(model_path, map_location=map_location)
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            nested = checkpoint.get(wrapper_key)
            if isinstance(nested, dict):
                return nested
    return checkpoint


def _load_weights(target_model, state_dict):
    """Copy ``state_dict`` into ``target_model``, as strictly as possible.

    A strict load is tried first. If it fails (missing/unexpected keys or a
    shape mismatch), only tensors whose name AND shape match the model are
    loaded. This lets e.g. the backbone load even when the classifier head
    differs, instead of raising. The number of restored tensors is reported.
    """
    print(f"Model state keys: {list(state_dict.keys())[:5]}... (showing first 5)")
    print(f"Model state contains {len(state_dict)} parameters")
    expected = target_model.state_dict()
    adapted = _adapt_state_dict_keys(state_dict, expected.keys())
    try:
        target_model.load_state_dict(adapted)
        print("Model loaded successfully with adapted state dict")
        return
    except RuntimeError as e:
        print(f"Error with adapted state dict: {str(e)}")

    # strict=False still raises on size mismatches, so drop those entries first.
    compatible = {
        key: value
        for key, value in adapted.items()
        if key in expected and getattr(value, "shape", None) == expected[key].shape
    }
    target_model.load_state_dict(compatible, strict=False)
    print(
        f"Model loaded with strict=False: {len(compatible)}/{len(expected)} "
        f"model tensors restored, {len(adapted) - len(compatible)} checkpoint "
        "entries skipped"
    )
    if not compatible:
        print("WARNING: no checkpoint tensors matched the model; predictions will be unreliable")


def _use_pretrained_backbone(target_model):
    """Replace ``target_model.base_model`` with an ImageNet-pretrained backbone.

    This is a last-resort fallback for when no checkpoint could be loaded.
    The new 2-way head is freshly (randomly) initialised, so its predictions
    carry no trained real-vs-AI signal.
    """
    print("Initializing with pretrained weights as fallback")
    print("WARNING: classifier head is untrained; predictions are not meaningful")
    backbone = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)
    num_features = backbone.classifier[1].in_features
    backbone.classifier = nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(num_features, 2),
    )
    target_model.base_model = backbone


def load_model_if_needed():
    """Return the shared ``(model, device)`` pair, building it on first use.

    On the first call:
      - picks CUDA when available, else CPU;
      - builds ``AIDetectorModel``;
      - restores weights from ``best_model_improved.pth``.

    If the checkpoint is missing or cannot be read or loaded, it falls back
    to an ImageNet-pretrained backbone with an untrained head, and logs a
    warning.

    The module-level ``model`` and ``device`` are assigned only after the
    model is fully initialised and in eval mode. If the fallback itself
    raises (e.g. the pretrained weights cannot be fetched), the exception
    propagates, the globals stay None, and the next call retries. A
    half-built model is never cached.

    Returns:
        tuple: ``(model, device)``, with the model on ``device`` and in eval mode.
    """
    global model, device
    if model is not None:
        return model, device

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {target_device}")
    candidate = AIDetectorModel()
    model_path = "best_model_improved.pth"
    try:
        _load_weights(candidate, _read_state_dict(model_path, target_device))
    except Exception as e:  # best-effort: any read/load failure -> fallback weights
        print(f"Error loading model: {str(e)}")
        _use_pretrained_backbone(candidate)

    candidate.to(target_device)
    candidate.eval()
    # Publish only now that the model is complete (see docstring).
    model, device = candidate, target_device
    return model, device
# Logit temperature applied before softmax. Dividing logits by T > 1 pulls the
# probabilities toward 0.5 (softer, LESS extreme); T < 1 would sharpen them.
# It never changes which class wins, only the reported confidence.
TEMPERATURE = 2.0

# Output index treated as "AI-generated"; the other index is "real".
# The previous "pick the more extreme mapping" heuristic could not work: with
# two classes p0 + p1 == 1, so |p1 - 0.5| == |p0 - 0.5| up to float rounding,
# and the chosen mapping (hence the verdict) flipped on rounding noise.
# NOTE(review): 0 matches the old tie-break choice, but the true mapping depends
# on how the training labels were ordered — confirm against the training code.
AI_CLASS_INDEX = 0


def analyze_image(image):
    """Analyze an image to determine if it's real or AI-generated.

    Args:
        image: A PIL image (the UI uses ``gr.Image(type="pil")``), a NumPy
            array, or None when nothing was uploaded.

    Returns:
        ``(message, detailed_analysis_markdown)``. On missing input or any
        error, the second element is an empty string and the first element
        explains the problem.
    """
    if image is None:
        return "Please upload an image", ""
    try:
        # Builds the model on first use; cheap afterwards.
        current_model, current_device = load_model_if_needed()

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Always force 3-channel RGB. Otherwise RGBA / grayscale / palette
        # uploads give a tensor with the wrong channel count for Normalize and
        # for the network's first convolution.
        image = image.convert('RGB')

        image_tensor = transform(image).unsqueeze(0).to(current_device)
        with torch.no_grad():
            outputs = current_model(image_tensor)
        print(f"Raw model outputs: {outputs}")

        probabilities = torch.nn.functional.softmax(outputs / TEMPERATURE, dim=1)
        print(f"Softmax probabilities (with temperature scaling): {probabilities}")

        ai_score = probabilities[0, AI_CLASS_INDEX].item()
        real_score = probabilities[0, 1 - AI_CLASS_INDEX].item()
        print(f"Using class {AI_CLASS_INDEX} as AI: AI score: {ai_score:.4f}, Real score: {real_score:.4f}")

        is_ai_generated = ai_score > 0.5
        if is_ai_generated:
            message = f"🤖 This image is likely AI-generated (Confidence: {ai_score:.2f})"
        else:
            message = f"📷 This image is likely authentic (Confidence: {real_score:.2f})"

        prediction = 'AI-generated' if is_ai_generated else 'Real'
        detailed_analysis = (
            "\n"
            "### Detailed Analysis:\n"
            "| Property | Value |\n"
            "|----------|-------|\n"
            f"| AI Score | {ai_score:.4f} |\n"
            f"| Real Score | {real_score:.4f} |\n"
            f"| Prediction | {prediction} |\n"
            "| Model | Enhanced AI Image Detector |\n"
            "| Architecture | EfficientNetV2-S |\n"
        )
        return message, detailed_analysis
    except Exception as e:  # UI boundary: report the failure instead of crashing
        import traceback

        error_message = f"Error analyzing image: {str(e)}"
        print(f"Exception in analyze_image: {str(e)}")
        traceback.print_exc()
        return error_message, ""
# Build the Gradio UI. The left column holds the image input and the Analyze
# button; the right column holds the verdict text and a Markdown details table.
with gr.Blocks(title="Enhanced AI Image Detector") as demo:
    gr.Markdown("# 🔍 Enhanced AI Image Detector")
    gr.Markdown("Upload an image to determine if it's real or AI-generated.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            analyze_button = gr.Button("Analyze Image", variant="primary")
            # Example images can be added with
            # gr.Examples(examples=[<paths>], inputs=input_image). They are
            # omitted because placeholder paths produce 404 errors.
        with gr.Column(scale=1):
            result_text = gr.Textbox(label="Result")
            detailed_output = gr.Markdown(label="Detailed Analysis")
    analyze_button.click(
        fn=analyze_image,
        inputs=[input_image],
        outputs=[result_text, detailed_output],
    )
    gr.Markdown("""
## How it works
This model uses a trained PyTorch neural network (EfficientNetV2-S) to detect AI-generated images. The model has been trained on a large dataset of real and AI-generated images to learn the subtle differences between them.
The model can detect patterns that are often invisible to the human eye, including:
1. **Noise and artifact patterns** specific to AI generation methods
2. **Texture inconsistencies** that appear in AI-generated content
3. **Color and lighting anomalies** common in synthetic images
4. **Structural patterns** that differ from natural photographs
## Limitations
- The model may struggle with highly realistic AI-generated images from newer generation models
- Some real images with unusual characteristics may be misclassified
- Performance depends on image quality and resolution
- The model works best with images similar to those in its training dataset
""")

# Launch only when executed as a script (`python app.py`). Importing this
# module, e.g. from tests or tooling, then does not start a web server.
if __name__ == "__main__":
    demo.launch()