"""Download and load language-model artifacts (classifier, scalers, vectorizers).

Artifacts are fetched from a Hugging Face Hub repo into MODEL_DIR and then
unpickled / parsed by :func:`load_model`.
"""

import json
import logging
import pickle
import shutil
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV

from config import Config

REPO_ID = Config.REPO_ID_LANG
# MODEL_DIR is None when LANG_MODEL is unset; every entry point guards on this.
MODEL_DIR = Path(Config.LANG_MODEL) if Config.LANG_MODEL else None
HF_TOKEN = Config.HF_TOKEN
ENGLISH_SUBDIR = "English_model"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Every one of these must be present for a directory to count as a usable model.
REQUIRED_FILES = (
    "classifier.pkl",
    "scaler.pkl",
    "word_vectorizer.pkl",
    "char_vectorizer.pkl",
    "feature_names.json",
    "metadata.json",
)


def _patch_legacy_logistic_model(model):
    """Backfill attributes expected by newer sklearn versions.

    Older pickled LogisticRegression / LogisticRegressionCV objects may lack
    the ``multi_class`` attribute that newer sklearn code paths read; default
    it to ``"auto"`` so the unpickled model keeps working.
    """
    if isinstance(model, (LogisticRegression, LogisticRegressionCV)) and not hasattr(
        model, "multi_class"
    ):
        model.multi_class = "auto"
    return model


def _has_required_artifacts(model_dir: Path) -> bool:
    """Return True if *model_dir* is a directory containing every required artifact."""
    # Path.is_dir() is already False for nonexistent paths, so no separate
    # exists() check is needed.
    if not model_dir.is_dir():
        return False
    return all((model_dir / filename).exists() for filename in REQUIRED_FILES)


def _resolve_artifact_dir(base_dir: Path) -> Path | None:
    """Return the first directory under *base_dir* holding all artifacts, or None.

    Artifacts may live directly in *base_dir* or inside its ENGLISH_SUBDIR
    subdirectory, depending on how the snapshot was copied.
    """
    for candidate in (base_dir, base_dir / ENGLISH_SUBDIR):
        if _has_required_artifacts(candidate):
            return candidate
    return None


def warmup():
    """Ensure model artifacts exist locally, downloading them if necessary.

    Raises:
        ValueError: if LANG_MODEL is not configured.
    """
    logging.info("Warming up model...")
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if _resolve_artifact_dir(MODEL_DIR):
        logging.info("Model artifacts already exist, skipping download.")
        return
    download_model_repo()


def download_model_repo():
    """Download the model repo snapshot from the Hub into MODEL_DIR.

    No-op when the artifacts are already present. Copies from the snapshot's
    ENGLISH_SUBDIR when that layout is used, otherwise from the snapshot root.

    Raises:
        ValueError: if LANG_MODEL or the repo id is not configured.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    if not REPO_ID:
        raise ValueError("English_model repo id is not configured")
    if _resolve_artifact_dir(MODEL_DIR):
        logging.info("Model artifacts already exist, skipping download.")
        return
    snapshot_path = Path(snapshot_download(repo_id=REPO_ID, token=HF_TOKEN))
    # Build the candidate subdir path once instead of twice.
    english_dir = snapshot_path / ENGLISH_SUBDIR
    source_dir = english_dir if english_dir.is_dir() else snapshot_path
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    shutil.copytree(source_dir, MODEL_DIR, dirs_exist_ok=True)


def _load_pickle(path: Path):
    """Unpickle *path*.

    NOTE(review): pickle.load executes arbitrary code on untrusted data; this
    is only safe because the artifacts come from our own configured repo.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_json(path: Path):
    """Parse *path* as UTF-8 JSON."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_model():
    """Load all model artifacts, downloading them first if missing.

    Returns:
        Tuple of (classifier, scaler, word_vectorizer, char_vectorizer,
        feature_names, metadata) — same order as before.

    Raises:
        ValueError: if LANG_MODEL is not configured.
        FileNotFoundError: if artifacts are still missing after a download attempt.
    """
    if MODEL_DIR is None:
        raise ValueError("LANG_MODEL is not configured")
    artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        logging.info("Model artifacts missing in %s, downloading now.", MODEL_DIR)
        download_model_repo()
        artifact_dir = _resolve_artifact_dir(MODEL_DIR)
    if artifact_dir is None:
        raise FileNotFoundError(
            f"Required model artifacts not found in {MODEL_DIR}. Expected files: {', '.join(REQUIRED_FILES)}"
        )
    loaded_classifier = _patch_legacy_logistic_model(
        _load_pickle(artifact_dir / "classifier.pkl")
    )
    loaded_scaler = _load_pickle(artifact_dir / "scaler.pkl")
    loaded_word_vectorizer = _load_pickle(artifact_dir / "word_vectorizer.pkl")
    loaded_char_vectorizer = _load_pickle(artifact_dir / "char_vectorizer.pkl")
    loaded_features = _load_json(artifact_dir / "feature_names.json")
    loaded_metadata = _load_json(artifact_dir / "metadata.json")
    return (
        loaded_classifier,
        loaded_scaler,
        loaded_word_vectorizer,
        loaded_char_vectorizer,
        loaded_features,
        loaded_metadata,
    )