import logging
from typing import IO

import cv2
import numpy as np
import torch
from PIL import Image  # NOTE(review): not used in this module; kept in case other code relies on it
from torchvision import transforms

# Import the globally loaded models instance
from model_loader import models

logger = logging.getLogger(__name__)


class ImagePreprocessor:
    """
    Handles preprocessing of images for the FFT CNN model.

    An image file is decoded as 8-bit grayscale, converted to its centred,
    log-scaled FFT magnitude spectrum, min-max rescaled to uint8, resized to
    224x224 and returned as a batched tensor on the models' device.
    """

    def __init__(self):
        """
        Initializes the preprocessor.

        Captures the target device from the shared ``models`` instance and
        builds the torchvision pipeline applied to the magnitude spectrum.
        """
        self.device = models.device
        # Define the image transformations, matching the training process
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _decode_grayscale(image_file: IO) -> np.ndarray:
        """Read ``image_file`` to EOF and decode it as an 8-bit grayscale image.

        Raises:
            ValueError: If reading/decoding fails or OpenCV cannot decode the bytes.
        """
        try:
            # Read the image file into a numpy array
            image_np = np.frombuffer(image_file.read(), np.uint8)
            # Decode the image as grayscale
            img = cv2.imdecode(image_np, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            # Broad on purpose: this is the boundary where any read/decode failure
            # (OSError, TypeError, cv2.error on an empty buffer, ...) becomes the
            # single ValueError callers handle. Chain it so the cause is kept.
            logger.warning("Error reading or decoding image: %s", e)
            raise ValueError("Invalid or corrupted image file.") from e

        # imdecode signals "not a decodable image" by returning None rather than raising.
        if img is None:
            raise ValueError("Could not decode image. File may be empty or corrupted.")
        return img

    @staticmethod
    def _fft_magnitude_uint8(img: np.ndarray) -> np.ndarray:
        """Return the centred log-magnitude FFT spectrum of ``img`` rescaled to uint8 [0, 255]."""
        # 1. Apply Fast Fourier Transform (FFT) and shift the zero-frequency
        #    component to the centre of the spectrum.
        # NOTE(review): fft2 allocates a complex array the size of the decoded
        # image; very large uploads may be memory-heavy — consider an upstream size limit.
        fshift = np.fft.fftshift(np.fft.fft2(img))
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)

        # Normalize the magnitude spectrum to be in the range [0, 255], then
        # truncate to uint8 so ToPILImage treats it as an 8-bit grayscale image.
        magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)
        return magnitude_spectrum.astype(np.uint8)

    def process(self, image_file: IO) -> torch.Tensor:
        """
        Opens an image file, applies FFT, preprocesses it, and returns a tensor.

        Args:
            image_file (IO): The image file object (e.g., from a file upload).
                It is read from its current position to EOF.

        Returns:
            torch.Tensor: The preprocessed spectrum with shape (1, 1, 224, 224)
            and float values in [0, 1], moved to ``self.device``, ready for the model.

        Raises:
            ValueError: If the file cannot be read or decoded as an image.
        """
        img = self._decode_grayscale(image_file)
        magnitude_spectrum = self._fft_magnitude_uint8(img)

        # 2. Apply torchvision transforms -> (1, 224, 224) float tensor
        image_tensor = self.transform(magnitude_spectrum)

        # Add a batch dimension and move to the correct device
        return image_tensor.unsqueeze(0).to(self.device)


# Create a single instance of the preprocessor
preprocessor = ImagePreprocessor()