# Author: Rohan Kumar Shah
# Commit: bfb2e8a ("ai_human_classifier_added")
# Hugging Face file-viewer header residue: raw | history blame | 3.04 kB
from pathlib import Path
from typing import Optional

import clip
import joblib
import torch
from huggingface_hub import hf_hub_download
class ModelLoader:
    """
    Load and hold the models used by the AI-vs-human image classifier.

    Both models are loaded once, in ``__init__``, and then exposed as
    attributes so callers can reuse them without paying the loading cost again.

    Attributes:
        device (str): Torch device string the CLIP model was loaded onto.
        clip_model: The CLIP model returned by ``clip.load``.
        clip_preprocess: The preprocessing callable returned by ``clip.load``.
        svm_model: The classifier object deserialized with ``joblib.load``.
    """

    def __init__(
        self,
        clip_model_name: str,
        svm_repo_id: str,
        svm_filename: str,
        *,
        device: Optional[str] = None,
        svm_revision: Optional[str] = None,
    ):
        """
        Initialize the loader and load both models.

        Args:
            clip_model_name (str): Name of the CLIP model to load (e.g. 'ViT-L/14').
            svm_repo_id (str): Repository ID on Hugging Face
                (e.g. 'rhnsa/ai_human_image_detector').
            svm_filename (str): Name of the model file in the repository
                (e.g. 'model.joblib').
            device (str, optional): Torch device string to load CLIP onto.
                When None (the default), 'cuda' is used if available, else 'cpu'.
            svm_revision (str, optional): Branch, tag or commit hash to download
                the SVM file from. None (the default) uses the repository's
                default revision. Pinning a commit hash makes the downloaded
                artifact reproducible.

        Raises:
            Exception: Whatever ``clip.load``, ``hf_hub_download`` or
                ``joblib.load`` raised. The error is printed, then re-raised.
        """
        self.device = self._resolve_device(device)
        print(f"Using device: {self.device}")
        self.clip_model, self.clip_preprocess = self._load_clip_model(clip_model_name)
        self.svm_model = self._load_svm_model(
            repo_id=svm_repo_id, filename=svm_filename, revision=svm_revision
        )
        print("Models loaded successfully.")

    @staticmethod
    def _resolve_device(device: Optional[str] = None) -> str:
        """Return *device* if given, otherwise 'cuda' when available, else 'cpu'."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_clip_model(self, model_name: str):
        """
        Load the specified CLIP model and its preprocessor onto ``self.device``.

        Args:
            model_name (str): The name of the CLIP model.

        Returns:
            A ``(model, preprocess)`` tuple as returned by ``clip.load``.

        Raises:
            Exception: Re-raised from ``clip.load`` after printing the error.
        """
        try:
            model, preprocess = clip.load(model_name, device=self.device)
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise
        return model, preprocess

    def _load_svm_model(self, repo_id: str, filename: str, revision: Optional[str] = None):
        """
        Download the SVM model from a Hugging Face Hub repository and load it.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file in the repository.
            revision (str, optional): Branch, tag or commit hash to download
                from. None uses the repository's default revision.

        Returns:
            The object deserialized by ``joblib.load``.

        Raises:
            Exception: Re-raised from the download or load step after
                printing the error.
        """
        print(f"Downloading SVM model from Hugging Face repo: {repo_id}")
        try:
            # Download the model file from the Hub. It returns the cached path.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
            print(f"SVM model downloaded to: {model_path}")
            # SECURITY: joblib.load unpickles the file, and unpickling can execute
            # arbitrary code. Only load from repositories you trust, and pin
            # `revision` to a commit hash so the artifact cannot silently change.
            svm_model = joblib.load(model_path)
        except Exception as e:
            print(f"Error downloading or loading SVM model from Hugging Face: {e}")
            raise
        return svm_model
# ---------------------------------------------------------------------------
# Shared model instance
# ---------------------------------------------------------------------------
# Other modules import `models` from here. The instance is built while this
# module is being imported, so the slow model loading runs a single time.
CLIP_MODEL_NAME: str = "ViT-L/14"  # CLIP model name handed to ModelLoader
SVM_REPO_ID: str = "rhnsa/ai_human_image_detector"  # Hugging Face repo holding the SVM
SVM_FILENAME: str = "svm_model_real.joblib"  # SVM model file inside SVM_REPO_ID

models = ModelLoader(CLIP_MODEL_NAME, SVM_REPO_ID, SVM_FILENAME)