# Author: Rohan Kumar Shah
# Commit b7c5baf: added real and forgery detection model
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from model import FFTCNN # Import the model architecture
class ModelLoader:
    """
    A class to download, load, and hold the PyTorch FFT CNN model.

    On construction it picks a device ("cuda" when available, otherwise "cpu"),
    fetches the weights file from a Hugging Face Hub repository, and stores the
    ready-to-use model on ``self.fft_model``.

    Attributes:
        device (str): "cuda" or "cpu", chosen via ``torch.cuda.is_available()``.
        fft_model: The loaded ``FFTCNN`` instance, moved to ``device`` and put
            in evaluation mode.
    """

    def __init__(self, model_repo_id: str, model_filename: str):
        """
        Initializes the ModelLoader and loads the model.

        Args:
            model_repo_id (str): The repository ID on Hugging Face.
            model_filename (str): The name of the model file (.pth) in the repository.

        Raises:
            Exception: Any error raised while downloading or loading the model
                is printed and then propagated unchanged.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
        print("FFT CNN model loaded successfully.")

    def _load_fft_model(self, repo_id: str, filename: str):
        """
        Downloads and loads the FFT CNN model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file (.pth) in the repository.

        Returns:
            The loaded ``FFTCNN`` model, on ``self.device`` and in eval mode.

        Raises:
            Exception: Re-raised after printing a diagnostic message.
        """
        print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
        try:
            # Download the model file from the Hub. It returns the cached path.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Model downloaded to: {model_path}")
            # Initialize the architecture, then fill it with the saved weights.
            return self._load_weights(FFTCNN(), model_path)
        except Exception as e:
            # Print context, then re-raise so construction fails loudly
            # instead of leaving a half-initialized loader behind.
            print(f"Error downloading or loading model from Hugging Face: {e}")
            raise

    def _load_weights(self, model, model_path):
        """
        Load a saved state_dict into ``model`` and prepare it for inference.

        Args:
            model: A ``torch.nn.Module`` whose architecture matches the checkpoint.
            model_path: Filesystem path of a file written by ``torch.save(state_dict)``.

        Returns:
            The same ``model`` object, moved to ``self.device`` and in eval mode.
        """
        # The file comes from a remote repository. weights_only=True restricts
        # unpickling to tensors and plain containers, so a tampered checkpoint
        # cannot execute arbitrary code. A plain state_dict loads unchanged.
        state_dict = torch.load(
            model_path,
            map_location=torch.device(self.device),
            weights_only=True,
        )
        model.load_state_dict(state_dict)
        model.to(self.device)
        # Set the model to evaluation mode. This affects layers such as
        # dropout / batch-norm, if the architecture contains any.
        model.eval()
        return model
# --- Global Model Instance ---
# Hugging Face Hub location of the trained FFT CNN weights.
MODEL_REPO_ID = "rhnsa/real_forged_classifier"
MODEL_FILENAME = "fft_cnn_model_78.pth"

# Built once, at import time: importing this module triggers the download
# (hf_hub_download returns a cached path) and loads the model onto the
# selected device. Other modules share this single `models` instance.
models = ModelLoader(MODEL_REPO_ID, MODEL_FILENAME)