Spaces:
Sleeping
Sleeping
import logging
from typing import IO

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Import the globally loaded models instance
from model_loader import models
class ImagePreprocessor:
    """
    Turns an uploaded image file into the input tensor for the FFT CNN model.

    Pipeline (the transforms are stated to match the training process):
      1. decode the raw bytes as a grayscale image with OpenCV;
      2. take the 2-D FFT, shift the zero-frequency term to the centre and
         compute the log-magnitude spectrum ``20 * log(|F| + 1)``;
      3. min-max normalise that spectrum to ``[0, 255]`` and cast to uint8;
      4. run the torchvision transform chain (to PIL, resize to 224x224,
         to tensor);
      5. add a leading batch dimension and move the tensor to the device
         taken from the globally loaded ``models`` instance.
    """

    def __init__(self):
        """
        Set up the target device and the torchvision transform chain.

        The device is read from ``models.device`` (presumably the device the
        model itself runs on) so that tensors returned by ``process`` can be
        fed to it directly.
        """
        self.device = models.device
        # Define the image transformations, matching the training process.
        # They are applied to the 2-D uint8 spectrum built in process().
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def process(self, image_file: IO[bytes]) -> torch.Tensor:
        """
        Read an image file, convert it to an FFT magnitude spectrum and
        return it as a model-ready tensor.

        Args:
            image_file: A binary file-like object (e.g. from a file upload).
                It is read with ``read()``, i.e. from its current position
                to EOF.

        Returns:
            torch.Tensor: The transformed spectrum with a leading batch
            dimension added, located on ``self.device``.

        Raises:
            ValueError: If the file cannot be read, or its bytes cannot be
                decoded as an image. For read/decode errors the underlying
                exception is chained as ``__cause__``.
        """
        img = self._decode_grayscale(image_file)
        magnitude_spectrum = self._fft_magnitude_spectrum(img)
        # Apply the torchvision transforms, then add a batch dimension and
        # move the tensor to the correct device.
        image_tensor = self.transform(magnitude_spectrum)
        return image_tensor.unsqueeze(0).to(self.device)

    @staticmethod
    def _decode_grayscale(image_file: IO[bytes]) -> np.ndarray:
        """Read all remaining bytes of *image_file* and decode them as a grayscale image."""
        try:
            image_np = np.frombuffer(image_file.read(), np.uint8)
            img = cv2.imdecode(image_np, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            # Boundary between uploaded data and the pipeline: record the root
            # cause with its traceback, then surface a single ValueError type
            # to callers while keeping the original error chained.
            logging.getLogger(__name__).exception(
                "Error reading or decoding image: %s", e
            )
            raise ValueError("Invalid or corrupted image file.") from e
        # cv2.imdecode can also fail without raising, by returning None; this
        # check sits outside the try so its ValueError is not re-wrapped.
        if img is None:
            raise ValueError("Could not decode image. File may be empty or corrupted.")
        return img

    @staticmethod
    def _fft_magnitude_spectrum(img: np.ndarray) -> np.ndarray:
        """Return the centred log-magnitude FFT spectrum of *img*, min-max scaled to uint8."""
        f = np.fft.fft2(img)
        # Move the zero-frequency component to the centre of the spectrum.
        fshift = np.fft.fftshift(f)
        # Adding 1 keeps log() finite where the magnitude is exactly zero.
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)
        # Stretch to [0, 255] so the spectrum can be handled as an 8-bit
        # image by the transform chain; np.uint8 then truncates to integers.
        magnitude_spectrum = cv2.normalize(
            magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX
        )
        return np.uint8(magnitude_spectrum)
# Module-level singleton: constructed once when this module is first imported,
# picking up the device from the globally loaded `models` instance.
preprocessor: ImagePreprocessor = ImagePreprocessor()