import torch
import torch.nn.functional as F
import numpy as np

# Import the globally loaded models instance
from model_loader import models


class Interferencer:
    """
    Performs inference using the FFT CNN model.

    The class name ``Interferencer`` is kept as-is for compatibility with
    existing callers; it is an inference wrapper around ``models.fft_model``.
    """

    # Maps the model's predicted class index to a human-readable label.
    # Ensure this mapping matches the labels used during training.
    # Built once at class definition instead of on every predict() call.
    _LABEL_MAP = {0: 'fake', 1: 'real'}

    def __init__(self):
        """
        Initializes the inferencer by binding the FFT model from the
        globally loaded ``models`` instance.
        """
        # NOTE(review): nothing here calls ``.eval()`` on the model; confirm
        # that model_loader puts it in eval mode, otherwise layers such as
        # dropout/batch-norm would run in training behavior during inference.
        self.fft_model = models.fft_model

    @torch.no_grad()
    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Takes a preprocessed image tensor and returns the classification result.

        Gradient tracking is disabled for the whole call via ``torch.no_grad``.

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor. The
                model must produce logits for exactly one sample from it
                (presumably a batch of size 1 — confirm against the
                preprocessing code).

        Returns:
            dict: ``{"classification": str, "confidence": float}`` where
            ``classification`` is ``'fake'``, ``'real'``, or ``"unknown"``
            for an index outside the label map, and ``confidence`` is the
            softmax probability of the predicted class.

        Raises:
            ValueError: If the model output contains more than one sample,
                since a single label/confidence pair is returned.
        """
        # 1. Get model outputs (logits); class scores are along dim 1.
        outputs = self.fft_model(image_tensor)

        # A single result is returned, so reject multi-sample output with a
        # clear message instead of the opaque error ``.item()`` would raise.
        if outputs.dim() >= 1 and outputs.shape[0] != 1:
            raise ValueError(
                f"predict() expects a single sample, but the model returned "
                f"{outputs.shape[0]} rows of logits"
            )

        # 2. Apply softmax over the class dimension to get probabilities.
        probabilities = F.softmax(outputs, dim=1)

        # 3. Highest probability and its class index.
        confidence, predicted_idx = torch.max(probabilities, 1)
        prediction = predicted_idx.item()

        # 4. Map the prediction to a human-readable label.
        classification_label = self._LABEL_MAP.get(prediction, "unknown")

        return {
            "classification": classification_label,
            "confidence": confidence.item(),
        }


# Create a single instance of the interferencer (instantiated at import time).
interferencer = Interferencer()