Rohan Kumar Shah
added real and forgery detection model
b7c5baf
import logging
from typing import IO

from PIL import Image
import torch
import numpy as np
import cv2
from torchvision import transforms

# Import the globally loaded models instance
from model_loader import models
class ImagePreprocessor:
    """
    Handles preprocessing of images for the FFT CNN model.

    Pipeline: decode the upload as grayscale, take the centred log-magnitude
    FFT spectrum, min-max scale it to uint8, resize to 224x224, and convert it
    to a float tensor with a leading batch dimension on ``models.device``.
    """

    def __init__(self):
        """
        Initializes the preprocessor.

        Captures the target device from the shared ``models`` instance and
        builds the torchvision transform pipeline.
        """
        self.device = models.device
        # Define the image transformations, matching the training process
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _decode_grayscale(image_file: IO[bytes]) -> np.ndarray:
        """
        Read all bytes from ``image_file`` and decode them as a grayscale image.

        Args:
            image_file (IO): Binary file-like object; it is read to the end.
        Returns:
            np.ndarray: The decoded single-channel image.
        Raises:
            ValueError: If reading or decoding fails (the underlying error is
                chained as ``__cause__``), or if OpenCV cannot decode the data.
        """
        try:
            buffer = np.frombuffer(image_file.read(), np.uint8)
            img = cv2.imdecode(buffer, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            # Boundary for uploaded data: any read/decode failure is reported
            # to callers as a ValueError, keeping the original error chained.
            logging.getLogger(__name__).warning(
                "Error reading or decoding image: %s", e
            )
            raise ValueError("Invalid or corrupted image file.") from e
        # imdecode reports undecodable data by returning None, not by raising.
        if img is None:
            raise ValueError("Could not decode image. File may be empty or corrupted.")
        return img

    @staticmethod
    def _log_magnitude_spectrum(img: np.ndarray) -> np.ndarray:
        """
        Return ``20 * ln(|fftshift(fft2(img))| + 1)`` as a float array.

        The zero-frequency term is shifted to the centre of the spectrum; the
        ``+ 1`` keeps the logarithm finite where the magnitude is 0.
        """
        shifted = np.fft.fftshift(np.fft.fft2(img))
        return 20 * np.log(np.abs(shifted) + 1)

    def process(self, image_file: IO[bytes]) -> torch.Tensor:
        """
        Opens an image file, applies FFT, preprocesses it, and returns a tensor.

        Args:
            image_file (IO): The image file object (e.g., from a file upload).
                It is read to the end; the stream position is not restored.
        Returns:
            torch.Tensor: The preprocessed image as a tensor, ready for the
                model, with a leading batch dimension, placed on ``self.device``.
        Raises:
            ValueError: If the file cannot be read or decoded as an image.
        """
        img = self._decode_grayscale(image_file)
        # 1. Apply Fast Fourier Transform (FFT)
        spectrum = self._log_magnitude_spectrum(img)
        # Min-max scale into [0, 255], then cast to uint8 (truncating) so
        # ToPILImage produces an 8-bit grayscale image, as in training.
        spectrum = cv2.normalize(spectrum, None, 0, 255, cv2.NORM_MINMAX)
        spectrum = spectrum.astype(np.uint8)
        # 2. Apply torchvision transforms
        image_tensor = self.transform(spectrum)
        # Add a batch dimension and move to the correct device
        return image_tensor.unsqueeze(0).to(self.device)
# Shared preprocessor instance, constructed once when this module is imported.
preprocessor: ImagePreprocessor = ImagePreprocessor()