"""Custom inference handler: Jina text embeddings feeding a router/specialist cascade."""

import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import joblib
import numpy as np
import torch
from transformers import AutoModel

# Setup logging
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Classify a text with a two-stage cascade on top of a Jina embedding.

    Stage 1 (router) scores coarse categories. For each of the top
    ``TOP_K_CATEGORIES`` categories, stage 2 (a per-category specialist, loaded
    lazily from ``path``) picks a fine-grained label. The candidate with the
    highest combined score is returned.
    """

    # Number of top router categories that are forwarded to the specialist stage.
    TOP_K_CATEGORIES = 2

    def __init__(self, path=""):
        """Load the encoder and the router artifacts.

        Args:
            path: Directory holding the model artifacts (``router_xgb.pkl``,
                ``router_le.pkl`` and the optional ``specialist_*`` files).
                Hugging Face downloads the repo to this directory automatically.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("🚀 Loading Jina Encoder on %s...", self.device)

        # 1. Load Jina (the heavy lifter). Half precision is only used on GPU.
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # with the model repository; pin a revision if that repo is not trusted.
        self.encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v3",
            trust_remote_code=True,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)

        # 2. Load router now; specialists are loaded on first use.
        # NOTE(security): joblib.load unpickles arbitrary objects, so the
        # artifacts under `path` must come from a trusted source.
        logger.info("loading XGBoost Cascade...")
        self.model_dir = path
        self.router = joblib.load(os.path.join(path, "router_xgb.pkl"))
        self.router_le = joblib.load(os.path.join(path, "router_le.pkl"))
        # Cache of loaded specialists: sanitized category name -> (classifier, label encoder).
        self.specialists = {}

    def _get_vector(self, text):
        """Return the feature row for ``text``: its embedding plus ``log1p(len(text))``.

        The text is converted to ``str`` once so that the length feature and
        the encoder input always refer to the same string.
        """
        text = str(text)
        log_len = np.log1p(len(text))
        with torch.no_grad():
            vec = self.encoder.encode([text], task="classification", max_length=8192)
        # Per the original author's note, Jina returns (1, 1024); appending
        # log_len as one extra column gives (1, 1025).
        return np.hstack([vec, [[log_len]]])

    def _load_specialist(self, category: str) -> Optional[Tuple[Any, Any]]:
        """Return ``(classifier, label_encoder)`` for ``category``, or ``None``.

        Specialists are loaded lazily to keep memory usage low at startup and
        are cached after the first successful load. ``None`` is returned when
        the classifier file does not exist or when loading fails (the failure
        is logged with its traceback); neither outcome is cached, so a later
        call checks the disk again.
        """
        safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")

        cached = self.specialists.get(safe_name)
        if cached is not None:
            return cached

        clf_path = os.path.join(self.model_dir, f"specialist_{safe_name}_xgb.pkl")
        le_path = os.path.join(self.model_dir, f"specialist_{safe_name}_le.pkl")
        if not os.path.exists(clf_path):
            return None

        # Keep the try body down to the two loads that can actually fail.
        try:
            specialist = (joblib.load(clf_path), joblib.load(le_path))
        except Exception as e:  # best effort: fall back to the router-only result
            logger.exception("Failed to load specialist %s: %s", safe_name, e)
            return None

        self.specialists[safe_name] = specialist
        return specialist

    @staticmethod
    def _extract_text(data: Dict[str, Any]) -> str:
        """Pull the text to classify out of the request payload without mutating it."""
        # Falls back to the payload itself when there is no "inputs" key.
        inputs = data.get("inputs", data)
        if isinstance(inputs, list):
            if not inputs:
                raise ValueError('"inputs" must not be an empty list')
            # Only the first element of a list payload is classified.
            inputs = inputs[0]
        return str(inputs)

    @staticmethod
    def _to_native(value: Any) -> Any:
        """Convert a NumPy scalar to the matching built-in type for JSON output."""
        return value.item() if isinstance(value, np.generic) else value

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the cascade on one text; this is the method called by the API.

        Expected JSON input: ``{"inputs": "text content here..."}``. A list is
        also accepted, in which case only its first element is classified.

        Args:
            data: Request payload. It is not modified.

        Returns:
            A one-element list holding the winning candidate as a dict with the
            keys ``category``, ``label``, ``score`` and ``confidence`` (the last
            two carry the same value).

        Raises:
            ValueError: If ``"inputs"`` is an empty list.
        """
        text = self._extract_text(data)
        vector = self._get_vector(text)

        # 1. Router prediction: keep the most probable categories, best first.
        router_probs = self.router.predict_proba(vector)[0]
        top_indices = np.argsort(router_probs)[::-1][: self.TOP_K_CATEGORIES]

        candidates = []
        for idx in top_indices:
            category = self._to_native(self.router_le.classes_[idx])
            router_conf = router_probs[idx]

            # 2. Specialist prediction (falls back to the router when unavailable).
            specialist = self._load_specialist(category)
            if specialist is None:
                label = category
                score = float(router_conf)
            else:
                clf, le = specialist
                spec_probs = clf.predict_proba(vector)[0]
                best_idx = np.argmax(spec_probs)
                label = self._to_native(le.classes_[best_idx])
                # Soft score: geometric mean of router and specialist confidence.
                score = float(np.sqrt(router_conf * spec_probs[best_idx]))

            candidates.append(
                {
                    "category": category,
                    "label": label,
                    "score": score,
                    "confidence": score,
                }
            )

        # Winner take all; on a tie, max() keeps the earlier (higher router prob) candidate.
        best_match = max(candidates, key=lambda x: x["score"])
        return [best_match]