# AI_API/features/text_classifier/model_loader.py
# Commit 49fe170 (Pujan-Dev): "fixed the testing error"
import json
import logging
import pickle
import shutil
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from config import Config
# Hugging Face repo id that hosts the language/text-classifier model artifacts.
REPO_ID = Config.REPO_ID_LANG
# Local directory the artifacts are copied into; None when LANG_MODEL is unset.
MODEL_DIR = Path(Config.LANG_MODEL) if Config.LANG_MODEL else None
# Hub access token for the download (presumably needed for private repos — TODO confirm).
HF_TOKEN = Config.HF_TOKEN
# Subdirectory (in both the Hub snapshot and the local dir) that may nest the artifacts.
ENGLISH_SUBDIR = "English_model"
# NOTE(review): `device` is never used in this module as shown — confirm before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Every file that must exist for a directory to count as a complete artifact set.
REQUIRED_FILES = (
    "classifier.pkl",
    "scaler.pkl",
    "word_vectorizer.pkl",
    "char_vectorizer.pkl",
    "feature_names.json",
    "metadata.json",
)
def _patch_legacy_logistic_model(model):
    """Backfill the ``multi_class`` attribute on logistic models from old pickles.

    Newer sklearn versions expect the attribute to be present; pickles created
    by older releases may lack it, so default it to ``"auto"``. Non-logistic
    models pass through untouched.
    """
    is_logistic = isinstance(model, (LogisticRegression, LogisticRegressionCV))
    if is_logistic and not hasattr(model, "multi_class"):
        model.multi_class = "auto"
    return model
def _has_required_artifacts(model_dir: Path) -> bool:
if not model_dir.exists() or not model_dir.is_dir():
return False
return all((model_dir / filename).exists() for filename in REQUIRED_FILES)
def _resolve_artifact_dir(base_dir: Path) -> Path | None:
    """Return the directory holding the artifacts: the base dir itself, its
    English_model subdirectory, or None when neither is complete."""
    if _has_required_artifacts(base_dir):
        return base_dir
    nested = base_dir / ENGLISH_SUBDIR
    if _has_required_artifacts(nested):
        return nested
    return None
def warmup():
    """Ensure the model artifacts are available locally, downloading if absent.

    Raises:
        ValueError: when LANG_MODEL is not configured.
    """
    logging.info("Warming up model...")
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if _resolve_artifact_dir(MODEL_DIR) is not None:
        logging.info("Model artifacts already exist, skipping download.")
        return
    download_model_repo()
def download_model_repo():
    """Download the model snapshot from the Hub and copy it into MODEL_DIR.

    Skips the download entirely when a complete artifact set already exists.

    Raises:
        ValueError: when LANG_MODEL or the repo id is not configured.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if not REPO_ID:
        raise ValueError("English_model repo id is not configured")
    if _resolve_artifact_dir(MODEL_DIR) is not None:
        logging.info("Model artifacts already exist, skipping download.")
        return
    snapshot_path = Path(snapshot_download(repo_id=REPO_ID, token=HF_TOKEN))
    # Some snapshots nest the artifacts under English_model/, others keep them at the root.
    nested = snapshot_path / ENGLISH_SUBDIR
    source_dir = nested if nested.is_dir() else snapshot_path
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    shutil.copytree(source_dir, MODEL_DIR, dirs_exist_ok=True)
def load_model():
    """Load all English-model artifacts, downloading them first when absent.

    Returns:
        Tuple of (classifier, scaler, word_vectorizer, char_vectorizer,
        feature_names, metadata), in that order.

    Raises:
        ValueError: when LANG_MODEL is not configured.
        FileNotFoundError: when artifacts are still missing after a download attempt.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        logging.info("Model artifacts missing in %s, downloading now.", MODEL_DIR)
        download_model_repo()
        artifact_dir = _resolve_artifact_dir(MODEL_DIR)
        if artifact_dir is None:
            raise FileNotFoundError(
                f"Required model artifacts not found in {MODEL_DIR}. Expected files: {', '.join(REQUIRED_FILES)}"
            )

    def _load_pickle(filename):
        # NOTE: artifacts come from our own configured repo; pickle.load on
        # untrusted data would be unsafe.
        with open(artifact_dir / filename, "rb") as f:
            return pickle.load(f)

    def _load_json(filename):
        # JSON is UTF-8 by spec; be explicit rather than relying on the locale default.
        with open(artifact_dir / filename, "r", encoding="utf-8") as f:
            return json.load(f)

    # Older pickles may predate the `multi_class` attribute expected by newer sklearn.
    classifier = _patch_legacy_logistic_model(_load_pickle("classifier.pkl"))
    return (
        classifier,
        _load_pickle("scaler.pkl"),
        _load_pickle("word_vectorizer.pkl"),
        _load_pickle("char_vectorizer.pkl"),
        _load_json("feature_names.json"),
        _load_json("metadata.json"),
    )