# Author: Rohan Kumar Shah
# Change: added real and forgery detection model
# Commit: b7c5baf
import torch
import torch.nn.functional as F
import numpy as np
# Import the globally loaded models instance
from model_loader import models
class Interferencer:
    """
    Performs inference using the FFT CNN model.

    Wraps ``models.fft_model`` and turns its logits into a single
    human-readable classification label plus a confidence score.
    """

    # Maps a predicted class index to a human-readable label. This must match
    # the label encoding used during training (assumed 0 -> fake, 1 -> real;
    # NOTE(review): confirm against the training code). It is defined once at
    # class level instead of being rebuilt on every predict() call.
    LABEL_MAP = {0: "fake", 1: "real"}

    def __init__(self):
        """
        Initializes the interferencer with the globally loaded FFT model
        (``models.fft_model`` from ``model_loader``).
        """
        self.fft_model = models.fft_model

    @torch.no_grad()
    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Takes a preprocessed image tensor and returns the classification result.

        Gradient tracking is disabled for the whole call via ``torch.no_grad``.

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor, passed
                unchanged to the model. NOTE(review): presumably a single-image
                batch such as (1, C, H, W). The model output must contain
                exactly one sample, because ``.item()`` is used below.

        Returns:
            dict: ``{"classification": str, "confidence": float}``.
            ``classification`` is "fake" or "real", or "unknown" if the
            predicted index is not in ``LABEL_MAP``. ``confidence`` is the
            softmax probability of the predicted class.

        Raises:
            RuntimeError: raised by ``Tensor.item()`` when the model output
                holds more than one sample.
        """
        # Raw logits. Softmax over dim 1 assumes a (batch, num_classes) layout.
        logits = self.fft_model(image_tensor)
        probabilities = F.softmax(logits, dim=1)

        # Highest class probability and its index along the class dimension.
        confidence, predicted_idx = torch.max(probabilities, dim=1)

        # Map the index to a label; indices outside LABEL_MAP become "unknown".
        classification_label = self.LABEL_MAP.get(predicted_idx.item(), "unknown")

        return {
            "classification": classification_label,
            "confidence": confidence.item(),
        }
# Create a single module-level instance of the interferencer at import time.
# Construction reads ``models.fft_model``, so ``model_loader.models`` must
# already expose ``fft_model`` when this module is imported.
interferencer: Interferencer = Interferencer()