Spaces:
Sleeping
Sleeping
rag added
#1
by
Sangyog10
- opened
- features/real_forged_classifier/controller.py +0 -36
- features/real_forged_classifier/inferencer.py +0 -52
- features/real_forged_classifier/main.py +0 -26
- features/real_forged_classifier/model.py +0 -34
- features/real_forged_classifier/model_loader.py +0 -60
- features/real_forged_classifier/preprocessor.py +0 -67
- features/real_forged_classifier/routes.py +0 -37
features/real_forged_classifier/controller.py
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
from typing import IO
|
2 |
-
from preprocessor import preprocessor
|
3 |
-
from inferencer import interferencer
|
4 |
-
|
5 |
-
class ClassificationController:
|
6 |
-
"""
|
7 |
-
Controller to handle the image classification logic.
|
8 |
-
"""
|
9 |
-
def classify_image(self, image_file: IO) -> dict:
|
10 |
-
"""
|
11 |
-
Orchestrates the classification of a single image file.
|
12 |
-
|
13 |
-
Args:
|
14 |
-
image_file (IO): The image file to classify.
|
15 |
-
|
16 |
-
Returns:
|
17 |
-
dict: The classification result.
|
18 |
-
"""
|
19 |
-
try:
|
20 |
-
# Step 1: Preprocess the image
|
21 |
-
image_tensor = preprocessor.process(image_file)
|
22 |
-
|
23 |
-
# Step 2: Perform inference
|
24 |
-
result = interferencer.predict(image_tensor)
|
25 |
-
|
26 |
-
return result
|
27 |
-
except ValueError as e:
|
28 |
-
# Handle specific errors like invalid images
|
29 |
-
return {"error": str(e)}
|
30 |
-
except Exception as e:
|
31 |
-
# Handle unexpected errors
|
32 |
-
print(f"An unexpected error occurred: {e}")
|
33 |
-
return {"error": "An internal error occurred during classification."}
|
34 |
-
|
35 |
-
# Create a single instance of the controller
|
36 |
-
controller = ClassificationController()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
features/real_forged_classifier/inferencer.py
DELETED
@@ -1,52 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn.functional as F
|
3 |
-
import numpy as np
|
4 |
-
|
5 |
-
# Import the globally loaded models instance
|
6 |
-
from model_loader import models
|
7 |
-
|
8 |
-
class Interferencer:
|
9 |
-
"""
|
10 |
-
Performs inference using the FFT CNN model.
|
11 |
-
"""
|
12 |
-
def __init__(self):
|
13 |
-
"""
|
14 |
-
Initializes the interferencer with the loaded model.
|
15 |
-
"""
|
16 |
-
self.fft_model = models.fft_model
|
17 |
-
|
18 |
-
@torch.no_grad()
|
19 |
-
def predict(self, image_tensor: torch.Tensor) -> dict:
|
20 |
-
"""
|
21 |
-
Takes a preprocessed image tensor and returns the classification result.
|
22 |
-
|
23 |
-
Args:
|
24 |
-
image_tensor (torch.Tensor): The preprocessed image tensor.
|
25 |
-
|
26 |
-
Returns:
|
27 |
-
dict: A dictionary containing the classification label and confidence score.
|
28 |
-
"""
|
29 |
-
# 1. Get model outputs (logits)
|
30 |
-
outputs = self.fft_model(image_tensor)
|
31 |
-
|
32 |
-
# 2. Apply softmax to get probabilities
|
33 |
-
probabilities = F.softmax(outputs, dim=1)
|
34 |
-
|
35 |
-
# 3. Get the confidence and the predicted class index
|
36 |
-
confidence, predicted_idx = torch.max(probabilities, 1)
|
37 |
-
|
38 |
-
prediction = predicted_idx.item()
|
39 |
-
|
40 |
-
# 4. Map the prediction to a human-readable label
|
41 |
-
# Ensure this mapping matches the labels used during training
|
42 |
-
# Typically: 0 -> fake, 1 -> real
|
43 |
-
label_map = {0: 'fake', 1: 'real'}
|
44 |
-
classification_label = label_map.get(prediction, "unknown")
|
45 |
-
|
46 |
-
return {
|
47 |
-
"classification": classification_label,
|
48 |
-
"confidence": confidence.item()
|
49 |
-
}
|
50 |
-
|
51 |
-
# Create a single instance of the interferencer
|
52 |
-
interferencer = Interferencer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
features/real_forged_classifier/main.py
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
from fastapi import FastAPI
|
2 |
-
from routes import router as api_router
|
3 |
-
|
4 |
-
# Initialize the FastAPI app
|
5 |
-
app = FastAPI(
|
6 |
-
title="Real vs. Fake Image Classification API",
|
7 |
-
description="An API to classify images as real or forged using FFT and cnn.",
|
8 |
-
version="1.0.0"
|
9 |
-
)
|
10 |
-
|
11 |
-
# Include the API router
|
12 |
-
# All routes defined in routes.py will be available under the /api prefix
|
13 |
-
app.include_router(api_router, prefix="/api", tags=["Classification"])
|
14 |
-
|
15 |
-
@app.get("/", tags=["Root"])
|
16 |
-
async def read_root():
|
17 |
-
"""
|
18 |
-
A simple root endpoint to confirm the API is running.
|
19 |
-
"""
|
20 |
-
return {"message": "Welcome to the Image Classification API. Go to /docs for the API documentation."}
|
21 |
-
|
22 |
-
# To run this application:
|
23 |
-
# 1. Make sure you have all dependencies from requirements.txt installed.
|
24 |
-
# 2. Make sure the 'svm_model.joblib' file is in the same directory.
|
25 |
-
# 3. Run the following command in your terminal:
|
26 |
-
# uvicorn main:app --reload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
features/real_forged_classifier/model.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
|
5 |
-
class FFTCNN(nn.Module):
|
6 |
-
"""
|
7 |
-
Defines the Convolutional Neural Network architecture.
|
8 |
-
This structure must match the model that was trained and saved.
|
9 |
-
"""
|
10 |
-
def __init__(self):
|
11 |
-
super(FFTCNN, self).__init__()
|
12 |
-
# Ensure 'self.' is used here to define the layers as instance attributes
|
13 |
-
self.conv_layers = nn.Sequential(
|
14 |
-
nn.Conv2d(1, 16, kernel_size=3, padding=1),
|
15 |
-
nn.ReLU(),
|
16 |
-
nn.MaxPool2d(kernel_size=2, stride=2),
|
17 |
-
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
18 |
-
nn.ReLU(),
|
19 |
-
nn.MaxPool2d(kernel_size=2, stride=2)
|
20 |
-
)
|
21 |
-
|
22 |
-
# Ensure 'self.' is used here as well
|
23 |
-
self.fc_layers = nn.Sequential(
|
24 |
-
nn.Linear(32 * 56 * 56, 128), # This size depends on your 224x224 input
|
25 |
-
nn.ReLU(),
|
26 |
-
nn.Linear(128, 2) # 2 output classes
|
27 |
-
)
|
28 |
-
|
29 |
-
def forward(self, x):
|
30 |
-
# Now, 'self.conv_layers' can be found because it was defined correctly
|
31 |
-
x = self.conv_layers(x)
|
32 |
-
x = x.view(x.size(0), -1) # Flatten the feature maps
|
33 |
-
x = self.fc_layers(x)
|
34 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
features/real_forged_classifier/model_loader.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from pathlib import Path
|
3 |
-
from huggingface_hub import hf_hub_download
|
4 |
-
from model import FFTCNN # Import the model architecture
|
5 |
-
|
6 |
-
class ModelLoader:
|
7 |
-
"""
|
8 |
-
A class to load and hold the PyTorch CNN model.
|
9 |
-
"""
|
10 |
-
def __init__(self, model_repo_id: str, model_filename: str):
|
11 |
-
"""
|
12 |
-
Initializes the ModelLoader and loads the model.
|
13 |
-
|
14 |
-
Args:
|
15 |
-
model_repo_id (str): The repository ID on Hugging Face.
|
16 |
-
model_filename (str): The name of the model file (.pth) in the repository.
|
17 |
-
"""
|
18 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
-
print(f"Using device: {self.device}")
|
20 |
-
|
21 |
-
self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
|
22 |
-
print("FFT CNN model loaded successfully.")
|
23 |
-
|
24 |
-
def _load_fft_model(self, repo_id: str, filename: str):
|
25 |
-
"""
|
26 |
-
Downloads and loads the FFT CNN model from a Hugging Face Hub repository.
|
27 |
-
|
28 |
-
Args:
|
29 |
-
repo_id (str): The repository ID on Hugging Face.
|
30 |
-
filename (str): The name of the model file (.pth) in the repository.
|
31 |
-
|
32 |
-
Returns:
|
33 |
-
The loaded PyTorch model object.
|
34 |
-
"""
|
35 |
-
print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
|
36 |
-
try:
|
37 |
-
# Download the model file from the Hub. It returns the cached path.
|
38 |
-
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
39 |
-
print(f"Model downloaded to: {model_path}")
|
40 |
-
|
41 |
-
# Initialize the model architecture
|
42 |
-
model = FFTCNN()
|
43 |
-
|
44 |
-
# Load the saved weights (state_dict) into the model
|
45 |
-
model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
|
46 |
-
|
47 |
-
# Set the model to evaluation mode
|
48 |
-
model.to(self.device)
|
49 |
-
model.eval()
|
50 |
-
|
51 |
-
return model
|
52 |
-
except Exception as e:
|
53 |
-
print(f"Error downloading or loading model from Hugging Face: {e}")
|
54 |
-
raise
|
55 |
-
|
56 |
-
# --- Global Model Instance ---
|
57 |
-
MODEL_REPO_ID = 'rhnsa/real_forged_classifier'
|
58 |
-
MODEL_FILENAME = 'fft_cnn_model_78.pth'
|
59 |
-
models = ModelLoader(model_repo_id=MODEL_REPO_ID, model_filename=MODEL_FILENAME)
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
features/real_forged_classifier/preprocessor.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
from PIL import Image
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
from typing import IO
|
5 |
-
import cv2
|
6 |
-
from torchvision import transforms
|
7 |
-
|
8 |
-
# Import the globally loaded models instance
|
9 |
-
from model_loader import models
|
10 |
-
|
11 |
-
class ImagePreprocessor:
|
12 |
-
"""
|
13 |
-
Handles preprocessing of images for the FFT CNN model.
|
14 |
-
"""
|
15 |
-
def __init__(self):
|
16 |
-
"""
|
17 |
-
Initializes the preprocessor.
|
18 |
-
"""
|
19 |
-
self.device = models.device
|
20 |
-
# Define the image transformations, matching the training process
|
21 |
-
self.transform = transforms.Compose([
|
22 |
-
transforms.ToPILImage(),
|
23 |
-
transforms.Resize((224, 224)),
|
24 |
-
transforms.ToTensor(),
|
25 |
-
])
|
26 |
-
|
27 |
-
def process(self, image_file: IO) -> torch.Tensor:
|
28 |
-
"""
|
29 |
-
Opens an image file, applies FFT, preprocesses it, and returns a tensor.
|
30 |
-
|
31 |
-
Args:
|
32 |
-
image_file (IO): The image file object (e.g., from a file upload).
|
33 |
-
|
34 |
-
Returns:
|
35 |
-
torch.Tensor: The preprocessed image as a tensor, ready for the model.
|
36 |
-
"""
|
37 |
-
try:
|
38 |
-
# Read the image file into a numpy array
|
39 |
-
image_np = np.frombuffer(image_file.read(), np.uint8)
|
40 |
-
# Decode the image as grayscale
|
41 |
-
img = cv2.imdecode(image_np, cv2.IMREAD_GRAYSCALE)
|
42 |
-
except Exception as e:
|
43 |
-
print(f"Error reading or decoding image: {e}")
|
44 |
-
raise ValueError("Invalid or corrupted image file.")
|
45 |
-
|
46 |
-
if img is None:
|
47 |
-
raise ValueError("Could not decode image. File may be empty or corrupted.")
|
48 |
-
|
49 |
-
# 1. Apply Fast Fourier Transform (FFT)
|
50 |
-
f = np.fft.fft2(img)
|
51 |
-
fshift = np.fft.fftshift(f)
|
52 |
-
magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1) # Add 1 to avoid log(0)
|
53 |
-
|
54 |
-
# Normalize the magnitude spectrum to be in the range [0, 255]
|
55 |
-
magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)
|
56 |
-
magnitude_spectrum = np.uint8(magnitude_spectrum)
|
57 |
-
|
58 |
-
# 2. Apply torchvision transforms
|
59 |
-
image_tensor = self.transform(magnitude_spectrum)
|
60 |
-
|
61 |
-
# Add a batch dimension and move to the correct device
|
62 |
-
image_tensor = image_tensor.unsqueeze(0).to(self.device)
|
63 |
-
|
64 |
-
return image_tensor
|
65 |
-
|
66 |
-
# Create a single instance of the preprocessor
|
67 |
-
preprocessor = ImagePreprocessor()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
features/real_forged_classifier/routes.py
DELETED
@@ -1,37 +0,0 @@
|
|
1 |
-
from fastapi import APIRouter, File, UploadFile, HTTPException, status
|
2 |
-
from fastapi.responses import JSONResponse
|
3 |
-
|
4 |
-
# Import the controller instance
|
5 |
-
from controller import controller
|
6 |
-
|
7 |
-
# Create an API router
|
8 |
-
router = APIRouter()
|
9 |
-
|
10 |
-
@router.post("/classify_forgery", summary="Classify an image as Real or Fake")
|
11 |
-
async def classify_image_endpoint(image: UploadFile = File(...)):
|
12 |
-
"""
|
13 |
-
Accepts an image file and classifies it as 'real' or 'fake'.
|
14 |
-
|
15 |
-
- **image**: The image file to be classified (e.g., JPEG, PNG).
|
16 |
-
|
17 |
-
Returns a JSON object with the classification and a confidence score.
|
18 |
-
"""
|
19 |
-
# Check for a valid image content type
|
20 |
-
if not image.content_type.startswith("image/"):
|
21 |
-
raise HTTPException(
|
22 |
-
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
|
23 |
-
detail="Unsupported file type. Please upload an image (e.g., JPEG, PNG)."
|
24 |
-
)
|
25 |
-
|
26 |
-
# The controller expects a file-like object, which `image.file` provides
|
27 |
-
result = controller.classify_image(image.file)
|
28 |
-
|
29 |
-
if "error" in result:
|
30 |
-
# If the controller returned an error, forward it as an HTTP exception
|
31 |
-
raise HTTPException(
|
32 |
-
status_code=status.HTTP_400_BAD_REQUEST,
|
33 |
-
detail=result["error"]
|
34 |
-
)
|
35 |
-
|
36 |
-
return JSONResponse(content=result, status_code=status.HTTP_200_OK)
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|