File size: 3,043 Bytes
bfb2e8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import clip
import torch
import joblib
from pathlib import Path
from huggingface_hub import hf_hub_download

class ModelLoader:
    """
    Load and hold the CLIP feature extractor and the SVM classifier.

    Both models are loaded eagerly in ``__init__``, so creating one instance
    and sharing it means the (expensive) loading work happens only once.
    """
    def __init__(
        self,
        clip_model_name: str,
        svm_repo_id: str,
        svm_filename: str,
        device: "str | None" = None,
        svm_revision: "str | None" = None,
    ):
        """
        Initializes the ModelLoader and loads the models.

        Args:
            clip_model_name (str): The name of the CLIP model to load (e.g., 'ViT-L/14').
            svm_repo_id (str): The repository ID on Hugging Face (e.g., 'rhnsa/ai_human_image_detector').
            svm_filename (str): The name of the model file in the repository (e.g., 'model.joblib').
            device (str | None): Device to load the CLIP model on (e.g. 'cpu',
                'cuda', 'cuda:1'). ``None`` (default) picks 'cuda' when
                ``torch.cuda.is_available()`` is true, otherwise 'cpu'.
            svm_revision (str | None): Git revision (branch, tag or commit hash)
                of the Hub repository to download the SVM file from. ``None``
                (default) uses the Hub's default revision. Pinning a commit hash
                is recommended because the file is unpickled on load.

        Raises:
            Exception: Whatever CLIP / Hub download / joblib raised while loading
                (the error is printed and then re-raised unchanged).
        """
        self.device = self._resolve_device(device)
        print(f"Using device: {self.device}")

        self.clip_model, self.clip_preprocess = self._load_clip_model(clip_model_name)
        self.svm_model = self._load_svm_model(
            repo_id=svm_repo_id, filename=svm_filename, revision=svm_revision
        )
        print("Models loaded successfully.")

    @staticmethod
    def _resolve_device(device: "str | None" = None) -> str:
        """Return *device* if given, else 'cuda' when available, otherwise 'cpu'."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_clip_model(self, model_name: str):
        """
        Loads the specified CLIP model and its preprocessor onto ``self.device``.

        Args:
            model_name (str): The name of the CLIP model.

        Returns:
            A tuple ``(model, preprocess)`` as returned by ``clip.load``.

        Raises:
            Exception: Any error from ``clip.load``, after printing it.
        """
        try:
            model, preprocess = clip.load(model_name, device=self.device)
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise
        return model, preprocess

    def _load_svm_model(self, repo_id: str, filename: str, revision: "str | None" = None):
        """
        Downloads and loads the SVM model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file in the repository.
            revision (str | None): Optional Git revision (branch, tag or commit
                hash) passed to ``hf_hub_download``; ``None`` keeps its default.

        Returns:
            The object deserialized by ``joblib.load``.

        Raises:
            Exception: Any download or deserialization error, after printing it.
        """
        print(f"Downloading SVM model from Hugging Face repo: {repo_id}")
        try:
            # hf_hub_download returns the local (cached) path of the file.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
            print(f"SVM model downloaded to: {model_path}")

            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code. Only load from trusted repositories, and prefer
            # pinning `revision` to a specific commit hash.
            return joblib.load(model_path)
        except Exception as e:
            print(f"Error downloading or loading SVM model from Hugging Face: {e}")
            raise

# --- Global Model Instance ---
# Identifiers of the models to load; shared with any module that imports them.
CLIP_MODEL_NAME = 'ViT-L/14'
SVM_REPO_ID = 'rhnsa/ai_human_image_detector'
SVM_FILENAME = 'svm_model_real.joblib'  # file name of the SVM inside the Hub repo above

# Built once, at import time, so every importer reuses the same loaded models.
models = ModelLoader(CLIP_MODEL_NAME, SVM_REPO_ID, SVM_FILENAME)