# Spaces: Sleeping
from pathlib import Path
from typing import Optional

import torch
from huggingface_hub import hf_hub_download

from model import FFTCNN  # Import the model architecture
class ModelLoader:
    """
    A class to load and hold the PyTorch CNN model.

    On construction it downloads the FFT CNN weights from the Hugging Face Hub
    and keeps the ready-to-use model on ``self.fft_model``.
    """

    def __init__(
        self,
        model_repo_id: str,
        model_filename: str,
        device: Optional[str] = None,
    ):
        """
        Initializes the ModelLoader and loads the model.

        Args:
            model_repo_id (str): The repository ID on Hugging Face.
            model_filename (str): The name of the model file (.pth) in the repository.
            device (str, optional): Torch device string to load the model onto
                (e.g. "cpu", "cuda", "cuda:1"). When None (the default), "cuda"
                is used if CUDA is available, otherwise "cpu".
        """
        self.device = self._resolve_device(device)
        print(f"Using device: {self.device}")
        self.fft_model = self._load_fft_model(
            repo_id=model_repo_id, filename=model_filename
        )
        print("FFT CNN model loaded successfully.")

    @staticmethod
    def _resolve_device(device: Optional[str] = None) -> str:
        """Return *device* unchanged, or auto-detect "cuda"/"cpu" when it is None."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_fft_model(self, repo_id: str, filename: str) -> torch.nn.Module:
        """
        Downloads and loads the FFT CNN model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file (.pth) in the repository.

        Returns:
            The FFTCNN model with the downloaded weights applied, moved to
            ``self.device`` and switched to evaluation mode.

        Raises:
            Exception: Any error from downloading, deserializing, or applying
                the weights is printed and then re-raised unchanged.
        """
        print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
        try:
            # Download the model file from the Hub. It returns the cached path.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Model downloaded to: {model_path}")

            # Initialize the model architecture
            model = FFTCNN()

            # The checkpoint comes from a remote repository. A state_dict only
            # needs tensors and plain containers, so weights_only=True is enough
            # and it blocks arbitrary code execution from a tampered pickle
            # (older torch versions default to full unpickling).
            state_dict = torch.load(
                model_path,
                map_location=torch.device(self.device),
                weights_only=True,
            )
            # Load the saved weights (state_dict) into the model
            model.load_state_dict(state_dict)

            # Move the model to the target device, then set evaluation mode so
            # training-only layer behavior (e.g. dropout / batch-norm updates,
            # if FFTCNN has such layers) is disabled for inference.
            model.to(self.device)
            model.eval()
            return model
        except Exception as e:
            print(f"Error downloading or loading model from Hugging Face: {e}")
            raise
# --- Global Model Instance ---
# Hugging Face Hub location of the trained FFT CNN weights.
MODEL_REPO_ID = "rhnsa/real_forged_classifier"
MODEL_FILENAME = "fft_cnn_model_78.pth"

# Shared loader built at import time: importing this module triggers the Hub
# download (or cache lookup) and the weight loading.
models = ModelLoader(
    model_repo_id=MODEL_REPO_ID,
    model_filename=MODEL_FILENAME,
)