Spaces:
Sleeping
Sleeping
import clip | |
import torch | |
import joblib | |
from pathlib import Path | |
from huggingface_hub import hf_hub_download | |
class ModelLoader:
    """
    Load and hold the machine learning models used by the application.

    Each instance loads a CLIP model (together with its preprocessing
    transform) and an SVM classifier downloaded from the Hugging Face Hub.
    Loading happens in ``__init__``, so every new instance loads the models
    again; create one instance and share it to load them only once.

    Attributes:
        device: Device the CLIP model was loaded on ("cuda" or "cpu" unless
            an explicit device was passed).
        clip_model: The model object returned by ``clip.load``.
        clip_preprocess: The preprocessing transform returned by ``clip.load``.
        svm_model: The object deserialized by ``joblib.load`` from the
            downloaded SVM file.
    """

    def __init__(
        self,
        clip_model_name: str,
        svm_repo_id: str,
        svm_filename: str,
        *,
        device=None,
        svm_revision=None,
    ):
        """
        Initialize the ModelLoader and load both models.

        Args:
            clip_model_name (str): The name of the CLIP model to load (e.g., 'ViT-L/14').
            svm_repo_id (str): The repository ID on Hugging Face (e.g., 'rhnsa/ai_human_image_detector').
            svm_filename (str): The name of the model file in the repository (e.g., 'model.joblib').
            device (str | torch.device | None): Device to load the CLIP model on.
                ``None`` (the default) picks "cuda" when available, otherwise "cpu".
            svm_revision (str | None): Revision (branch, tag, or commit hash) of
                the Hub repository to download from. ``None`` (the default)
                uses the repository's default branch. Pinning a commit hash makes
                the download reproducible and avoids silently picking up a
                replaced model file.

        Raises:
            Exception: Any error raised while loading either model is printed
                and then re-raised unchanged.
        """
        self.device = self._resolve_device(device)
        print(f"Using device: {self.device}")
        self.clip_model, self.clip_preprocess = self._load_clip_model(clip_model_name)
        self.svm_model = self._load_svm_model(
            repo_id=svm_repo_id, filename=svm_filename, revision=svm_revision
        )
        print("Models loaded successfully.")

    @staticmethod
    def _resolve_device(device=None):
        """Return *device* if given, else "cuda" when available, otherwise "cpu"."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_clip_model(self, model_name: str):
        """
        Load the specified CLIP model and its preprocessor onto ``self.device``.

        Args:
            model_name (str): The name of the CLIP model.

        Returns:
            A ``(model, preprocess)`` tuple as returned by ``clip.load``.

        Raises:
            Exception: Any error from ``clip.load`` is printed and re-raised.
        """
        try:
            model, preprocess = clip.load(model_name, device=self.device)
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise
        return model, preprocess

    def _load_svm_model(self, repo_id: str, filename: str, revision=None):
        """
        Download the SVM model file from a Hugging Face Hub repository and load it.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file in the repository.
            revision (str | None): Optional branch, tag, or commit hash to
                download; ``None`` uses the repository's default branch.

        Returns:
            The object loaded by ``joblib.load`` from the downloaded file.

        Raises:
            Exception: Any download or deserialization error is printed and
                re-raised.
        """
        print(f"Downloading SVM model from Hugging Face repo: {repo_id}")
        try:
            # hf_hub_download returns the path of the file in the local cache.
            model_path = hf_hub_download(
                repo_id=repo_id, filename=filename, revision=revision
            )
            print(f"SVM model downloaded to: {model_path}")
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code. Only load from repositories you trust, ideally
            # pinned to a specific commit via `revision`.
            svm_model = joblib.load(model_path)
        except Exception as e:
            print(f"Error downloading or loading SVM model from Hugging Face: {e}")
            raise
        return svm_model
# --- Global Model Instance ---
# Configuration for the shared ModelLoader built below.
CLIP_MODEL_NAME = 'ViT-L/14'  # CLIP variant handed to clip.load
SVM_REPO_ID = 'rhnsa/ai_human_image_detector'  # Hub repository that stores the SVM
SVM_FILENAME = 'svm_model_real.joblib'  # SVM file name inside that repository

# Constructed once, when this module is first imported; other modules import
# `models` to reuse the already-loaded CLIP model and SVM classifier.
models = ModelLoader(CLIP_MODEL_NAME, SVM_REPO_ID, SVM_FILENAME)