File size: 2,252 Bytes
b7c5baf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from model import FFTCNN # Import the model architecture

class ModelLoader:
    """
    A class to load and hold the PyTorch CNN model.

    On construction, the FFT CNN weights file is downloaded from a Hugging Face
    Hub repository and loaded into a freshly built ``FFTCNN`` instance.

    Attributes:
        device (str): ``"cuda"`` when ``torch.cuda.is_available()`` is true,
            otherwise ``"cpu"``.
        fft_model (torch.nn.Module): The loaded model, moved to ``device`` and
            switched to evaluation mode.
    """
    def __init__(self, model_repo_id: str, model_filename: str):
        """
        Initializes the ModelLoader and loads the model.

        Args:
            model_repo_id (str): The repository ID on Hugging Face.
            model_filename (str): The name of the model file (.pth) in the repository.

        Raises:
            Exception: Any error raised while downloading or loading the model
                is printed and then re-raised unchanged.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
        print("FFT CNN model loaded successfully.")

    def _load_fft_model(self, repo_id: str, filename: str) -> torch.nn.Module:
        """
        Downloads and loads the FFT CNN model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file (.pth) in the repository.

        Returns:
            The loaded PyTorch model object, on ``self.device`` and in eval mode.

        Raises:
            Exception: Whatever the download, unpickling or ``load_state_dict``
                raised; the message is printed before re-raising.
        """
        print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
        try:
            # hf_hub_download returns the path of the locally cached copy.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Model downloaded to: {model_path}")

            # Build the bare architecture; weights are filled in below.
            model = FFTCNN()

            # SECURITY: the checkpoint comes from a remote repository, and a
            # plain torch.load unpickles arbitrary objects, so a tampered
            # file could run code. The file is expected to hold only a
            # state_dict, so restrict unpickling to tensors and primitive
            # containers. (weights_only needs torch >= 1.13; it is the
            # default from torch 2.6 onward.)
            state_dict = torch.load(
                model_path,
                map_location=torch.device(self.device),
                weights_only=True,
            )
            model.load_state_dict(state_dict)

            # Move parameters to the target device and set evaluation mode
            # (affects layers such as dropout / batch-norm, if present).
            model.to(self.device)
            model.eval()
        except Exception as e:
            print(f"Error downloading or loading model from Hugging Face: {e}")
            raise

        return model

# --- Global Model Instance ---
# Hugging Face Hub location of the pretrained FFT CNN weights.
MODEL_REPO_ID: str = 'rhnsa/real_forged_classifier'
MODEL_FILENAME: str = 'fft_cnn_model_78.pth'

# Single shared loader, created when this module is first imported.
# Constructing it downloads the weights and builds the model as an
# import-time side effect; the network is exposed as `models.fft_model`.
models = ModelLoader(
    model_repo_id=MODEL_REPO_ID,
    model_filename=MODEL_FILENAME,
)