Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
# Import the globally loaded models instance | |
from model_loader import models | |
class Interferencer:
    """
    Performs inference using the FFT CNN model.

    The public name is kept as-is so existing imports keep working.
    """

    # Maps the model's predicted class index to a human-readable label.
    # Must match the label order used during training (assumed 0 -> fake,
    # 1 -> real -- TODO confirm against the training pipeline).
    LABEL_MAP = {0: "fake", 1: "real"}

    def __init__(self):
        """
        Bind the FFT model from the globally loaded ``models`` instance.
        """
        # NOTE(review): nothing here or in predict() calls .eval() on the
        # model; presumably model_loader does -- confirm, otherwise dropout /
        # batch-norm layers would make predictions non-deterministic.
        self.fft_model = models.fft_model

    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Classify a preprocessed image tensor.

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor, passed
                to the model unchanged. Assumed to be a batch holding exactly
                one image (e.g. shape ``(1, C, H, W)`` -- TODO confirm); the
                ``.item()`` calls below raise ``RuntimeError`` when the batch
                holds more than one image.

        Returns:
            dict: ``{"classification": str, "confidence": float}``.
            ``classification`` is the ``LABEL_MAP`` entry for the predicted
            class index, or ``"unknown"`` if the index is not in the map.
            ``confidence`` is the softmax probability of that class.
        """
        # Inference only: disable autograd so no computation graph is built
        # for each request (saves memory and time).
        with torch.no_grad():
            logits = self.fft_model(image_tensor)
            # Softmax over dim 1 (the class dimension) turns logits into
            # probabilities; max then yields the top probability and its index.
            probabilities = F.softmax(logits, dim=1)
            confidence, predicted_idx = torch.max(probabilities, dim=1)

        prediction = predicted_idx.item()
        return {
            "classification": self.LABEL_MAP.get(prediction, "unknown"),
            "confidence": confidence.item(),
        }
# Module-level instance, constructed once when this module is imported;
# construction binds ``models.fft_model`` from model_loader.
interferencer: Interferencer = Interferencer()