# Author: silashundhausen
# "Upload folder using huggingface_hub" (commit 8672435, verified)
import logging
from typing import Dict, Any, List
import numpy as np
import joblib
import torch
from transformers import AutoModel
import os
# Module logger: handlers and levels are inherited from whatever logging
# configuration the hosting process installs.
logger = logging.getLogger(name=__name__)
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for a two-stage XGBoost cascade.

    Each input text is embedded with ``jinaai/jina-embeddings-v3`` and a
    log-length feature is appended. A router classifier then scores the
    categories. For each of the top ``ROUTER_TOP_K`` categories, a
    per-category "specialist" classifier (when one exists on disk) predicts
    a label. The best-scoring candidate wins.
    """

    # Number of top router categories that are scored by their specialists.
    ROUTER_TOP_K = 2

    def __init__(self, path=""):
        """Load the encoder, the router and the router's label encoder.

        Args:
            path: Directory holding the model artifacts (``router_xgb.pkl``,
                ``router_le.pkl`` and optional ``specialist_*`` pickles).
                Hugging Face downloads the repo to this directory
                automatically.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("🚀 Loading Jina Encoder on %s...", self.device)

        # 1. Load Jina (the heavy lifter): float16 on GPU, float32 on CPU.
        self.encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v3",
            trust_remote_code=True,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)

        # 2. Load the router now. Specialists are loaded lazily by
        # _load_specialist. NOTE: joblib.load unpickles arbitrary objects,
        # so `path` must only contain trusted artifacts.
        logger.info("loading XGBoost Cascade...")
        self.model_dir = path
        self.router = joblib.load(os.path.join(path, "router_xgb.pkl"))
        self.router_le = joblib.load(os.path.join(path, "router_le.pkl"))
        # safe_name -> (classifier, label_encoder), or None when no
        # specialist file exists for that category.
        self.specialists = {}

    @staticmethod
    def _to_builtin(value):
        """Convert a NumPy scalar (e.g. ``np.str_``, ``np.int64``) to the
        equivalent Python object so the response can be serialised with the
        stdlib ``json`` module. Other values are returned unchanged."""
        return value.item() if isinstance(value, np.generic) else value

    def _get_vector(self, text):
        """Embed ``text`` and append ``log1p(len(str(text)))`` as an extra column.

        Returns a 2-D array with a single row: the Jina embedding followed
        by the log-length feature.
        """
        log_len = np.log1p(len(str(text)))
        with torch.no_grad():
            vec = self.encoder.encode([text], task="classification", max_length=8192)
        # Jina returns (1, 1024), we append log_len -> (1, 1025)
        return np.hstack([vec, [[log_len]]])

    def _load_specialist(self, category):
        """Return ``(classifier, label_encoder)`` for ``category``, or ``None``.

        Specialists are loaded on first use to keep start-up memory low, and
        the result is cached. ``None`` is returned when no specialist file
        exists for the category, or when loading it raises.
        """
        safe_name = str(category).replace(" ", "_").replace("&", "and").replace("/", "_")
        if safe_name in self.specialists:
            return self.specialists[safe_name]

        clf_path = os.path.join(self.model_dir, f"specialist_{safe_name}_xgb.pkl")
        le_path = os.path.join(self.model_dir, f"specialist_{safe_name}_le.pkl")
        if not os.path.exists(clf_path):
            # Cache the miss so later requests skip the filesystem check.
            self.specialists[safe_name] = None
            return None
        try:
            specialist = (joblib.load(clf_path), joblib.load(le_path))
        except Exception as e:
            # Deliberately not cached: a later request will retry the load.
            logger.exception("Failed to load specialist %s: %s", safe_name, e)
            return None
        self.specialists[safe_name] = specialist
        return specialist

    def _classify_text(self, text):
        """Run one text through the router/specialist cascade.

        Returns the best candidate as a dict with ``category``, ``label``,
        ``score`` and ``confidence`` (the last two are equal).
        """
        vector = self._get_vector(text)

        # 1. Router prediction: keep the top-k categories by probability.
        router_probs = self.router.predict_proba(vector)[0]
        top_indices = np.argsort(router_probs)[::-1][: self.ROUTER_TOP_K]

        candidates = []
        for idx in top_indices:
            category = self.router_le.classes_[idx]
            router_conf = router_probs[idx]

            # 2. Specialist prediction, falling back to the router category.
            specialist = self._load_specialist(category)
            if specialist is None:
                label = category
                score = float(router_conf)
            else:
                clf, le = specialist
                spec_probs = clf.predict_proba(vector)[0]
                best_idx = np.argmax(spec_probs)
                label = le.classes_[best_idx]
                spec_conf = spec_probs[best_idx]
                # Soft score: geometric mean of router and specialist confidence.
                score = float(np.sqrt(router_conf * spec_conf))

            candidates.append({
                "category": self._to_builtin(category),
                "label": self._to_builtin(label),
                "score": score,
                "confidence": score,
            })

        # Winner take all.
        return max(candidates, key=lambda c: c["score"])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        The main inference method called by the API.
        Expected JSON input: {"inputs": "text content here..."}
        or {"inputs": ["text 1", "text 2", ...]}.

        Returns one result dict per input text. A single string yields a
        one-element list, and an empty list yields an empty list. A payload
        without an "inputs" key is classified as str(payload).
        """
        # Read without mutating the caller's payload.
        inputs = data.get("inputs", data)
        texts = inputs if isinstance(inputs, list) else [inputs]
        return [self._classify_text(str(text)) for text in texts]