Spaces:
Sleeping
Sleeping
File size: 1,634 Bytes
b7c5baf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import torch
import torch.nn.functional as F
import numpy as np
# Import the globally loaded models instance
from model_loader import models
class Interferencer:
    """
    Performs inference using the FFT CNN model.

    The model is read from the shared ``models`` object when an instance is
    created and is kept on ``self.fft_model``.
    """

    # Class index -> human-readable label. Ensure this mapping matches the
    # labels used during training (typically: 0 -> fake, 1 -> real).
    # Kept at class level so it is built once rather than on every call.
    LABEL_MAP = {0: 'fake', 1: 'real'}

    def __init__(self):
        """
        Initializes the interferencer with the loaded model.
        """
        self.fft_model = models.fft_model

    @torch.no_grad()
    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Takes a preprocessed image tensor and returns the classification result.

        The tensor is passed to ``self.fft_model`` unchanged; the model output
        is treated as logits and softmax is applied along dimension 1.
        NOTE(review): this presumes the model returns (batch, num_classes)
        for a batch containing a single image -- confirm against the model.

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor,
                holding exactly one sample.

        Returns:
            dict: ``"classification"`` is the label looked up in
            ``LABEL_MAP`` (``"unknown"`` when the predicted index has no
            entry) and ``"confidence"`` is the highest softmax probability
            as a Python float.

        Raises:
            ValueError: If the model yields more than one prediction, i.e.
                the input was not a single sample.
        """
        # 1. Get model outputs (logits)
        outputs = self.fft_model(image_tensor)

        # 2. Apply softmax to get probabilities
        probabilities = F.softmax(outputs, dim=1)

        # 3. Get the confidence and the predicted class index
        confidence, predicted_idx = torch.max(probabilities, 1)

        # .item() only works on one-element tensors; fail with a clear
        # message instead of an opaque scalar-conversion error.
        if predicted_idx.numel() != 1:
            raise ValueError(
                "predict() expects a single sample, but the model returned "
                f"{predicted_idx.numel()} predictions"
            )

        # 4. Map the prediction to a human-readable label
        classification_label = self.LABEL_MAP.get(predicted_idx.item(), "unknown")

        return {
            "classification": classification_label,
            "confidence": confidence.item(),
        }
# Create a single module-level instance of the interferencer for other
# modules to import. This runs at import time, and constructing it reads
# `models.fft_model` from `model_loader`.
interferencer = Interferencer()
|