File size: 1,634 Bytes
b7c5baf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn.functional as F
import numpy as np

# Import the globally loaded models instance
from model_loader import models

class Interferencer:
    """
    Thin inference wrapper around the shared FFT CNN model.

    The model itself is not loaded here; it is taken from the globally
    loaded ``models`` object when an instance is constructed.
    """
    def __init__(self):
        """
        Binds the already-loaded FFT model to this instance.
        """
        self.fft_model = models.fft_model

    @torch.no_grad()
    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Classifies one preprocessed image and reports how confident the model is.

        Gradient tracking is disabled for the whole call.

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor, passed
                to the model unchanged. The model output is softmaxed along
                dim 1 and then reduced with ``.item()``, so it must describe
                exactly one sample (presumably shape ``(1, num_classes)`` --
                confirm against the preprocessing code).

        Returns:
            dict: ``"classification"`` holds the label (``'fake'``, ``'real'``,
            or ``"unknown"`` for any other class index) and ``"confidence"``
            holds the softmax probability of the winning class as a float.
        """
        # Index -> label. This ordering has to agree with the label encoding
        # used during training (typically 0 -> fake, 1 -> real).
        class_names = ('fake', 'real')

        # Raw scores (logits) from the model, normalised into probabilities.
        logits = self.fft_model(image_tensor)
        class_probs = F.softmax(logits, dim=1)

        # Highest probability along the class axis and the index it came from.
        top_prob, top_idx = class_probs.max(dim=1)
        class_index = top_idx.item()

        # Indices outside the known labels fall back to "unknown".
        if class_index < len(class_names):
            classification_label = class_names[class_index]
        else:
            classification_label = "unknown"

        result = {}
        result["classification"] = classification_label
        result["confidence"] = top_prob.item()
        return result

# Create a single, module-level instance of the interferencer for importers to share.
# NOTE: this runs at import time, and Interferencer.__init__ reads
# `models.fft_model`, so presumably the model must already be loaded by
# `model_loader` before this module is imported -- confirm.
interferencer = Interferencer()