File size: 4,256 Bytes
8672435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import logging
from typing import Dict, Any, List
import numpy as np
import joblib
import torch
from transformers import AutoModel
import os

# Setup logging
logger = logging.getLogger(__name__)

class EndpointHandler:
    """Two-stage ("router" then "specialist") text classifier behind one callable.

    A request's text is embedded with the Jina encoder, the log-length of the
    text is appended as one extra feature, and the resulting vector is scored
    by a router classifier. For each of the router's two most probable
    categories a per-category specialist classifier is consulted when its
    artifacts exist under the model directory; the single best-scoring
    candidate is returned.
    """

    # Single-pass rewrite of the characters in a category name that are not
    # wanted in artifact file names (" " -> "_", "&" -> "and", "/" -> "_").
    _SAFE_NAME_TABLE = str.maketrans({" ": "_", "&": "and", "/": "_"})

    def __init__(self, path=""):
        """Load the encoder and the router; specialists are loaded on demand.

        Args:
            path: Directory containing ``router_xgb.pkl``, ``router_le.pkl``
                and the optional ``specialist_<name>_xgb.pkl`` /
                ``specialist_<name>_le.pkl`` artifacts.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info("🚀 Loading Jina Encoder on %s...", self.device)

        # 1. Load Jina (the heavy lifter). Half precision is only used on GPU.
        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the model repository -- consider pinning a revision.
        self.encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v3",
            trust_remote_code=True,
            torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32
        ).to(self.device)

        # 2. Load router & prepare the specialist cache.
        # NOTE(review): joblib.load unpickles the files, so the artifacts in
        # `path` must come from a trusted source.
        logger.info("loading XGBoost Cascade...")
        self.model_dir = path
        self.router = joblib.load(os.path.join(path, "router_xgb.pkl"))
        self.router_le = joblib.load(os.path.join(path, "router_le.pkl"))
        # safe category name -> (classifier, label encoder); filled lazily.
        self.specialists = {}

    @staticmethod
    def _to_native(value):
        """Return ``value`` as a plain Python object if it is a NumPy scalar.

        Entries of a ``classes_`` array are NumPy scalars; integer ones
        (e.g. ``np.int64``) cannot be serialized by the ``json`` module.
        """
        return value.item() if isinstance(value, np.generic) else value

    def _get_vector(self, text):
        """Embed ``text`` and append ``log1p(len(text))`` as a final feature.

        Returns:
            A 2-D array with one row: the encoder output followed by the
            log-length feature.
        """
        log_len = np.log1p(len(str(text)))
        with torch.no_grad():
            vec = self.encoder.encode([text], task="classification", max_length=8192)
        # Original author's note: Jina returns (1, 1024), so appending
        # log_len gives (1, 1025) -- TODO confirm against the trained models.
        return np.hstack([vec, [[log_len]]])

    def _load_specialist(self, category):
        """Return the ``(classifier, label_encoder)`` pair for ``category``.

        Specialists are loaded lazily to keep memory usage low at startup and
        cached in ``self.specialists`` after the first successful load.

        Returns:
            The cached or freshly loaded pair, or ``None`` when the classifier
            file does not exist or loading fails (the failure is logged).
        """
        safe_name = str(category).translate(self._SAFE_NAME_TABLE)

        cached = self.specialists.get(safe_name)
        if cached is not None:
            return cached

        clf_path = os.path.join(self.model_dir, f"specialist_{safe_name}_xgb.pkl")
        le_path = os.path.join(self.model_dir, f"specialist_{safe_name}_le.pkl")
        if not os.path.exists(clf_path):
            return None

        try:
            loaded = (joblib.load(clf_path), joblib.load(le_path))
        except Exception as e:
            # Best effort: a broken specialist must not take the endpoint
            # down; the caller falls back to the router's own prediction.
            logger.exception("Failed to load specialist %s: %s", safe_name, e)
            return None

        self.specialists[safe_name] = loaded
        return loaded

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one text and return the best candidate.

        Expected JSON input: ``{"inputs": "text content here..."}``. When
        ``"inputs"`` is a list only its first element is classified; when the
        key is missing the whole payload is stringified and classified.

        Args:
            data: Request payload. It is not modified.

        Returns:
            A one-element list holding a dict with ``category``, ``label``,
            ``score`` and ``confidence`` (``confidence`` equals ``score``),
            or an empty list when ``"inputs"`` is an empty list.
        """
        # .get rather than .pop so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        if isinstance(inputs, list):
            if not inputs:
                return []
            inputs = inputs[0]

        text = str(inputs)
        vector = self._get_vector(text)

        # 1. Router prediction: keep the two most probable categories.
        router_probs = self.router.predict_proba(vector)[0]
        top_indices = np.argsort(router_probs)[::-1][:2]

        candidates = []
        for idx in top_indices:
            category = self._to_native(self.router_le.classes_[idx])
            router_conf = router_probs[idx]

            # 2. Specialist prediction (falls back to the router alone).
            specialist = self._load_specialist(category)
            if specialist is None:
                label = category
                score = float(router_conf)
            else:
                clf, le = specialist
                spec_probs = clf.predict_proba(vector)[0]
                best_idx = np.argmax(spec_probs)
                label = self._to_native(le.classes_[best_idx])
                # Soft score: geometric mean of router and specialist confidence.
                score = float(np.sqrt(router_conf * spec_probs[best_idx]))

            candidates.append({
                "category": category,
                "label": label,
                "score": score,
                "confidence": score,
            })

        # Winner take all (first candidate wins on ties).
        best_match = max(candidates, key=lambda candidate: candidate["score"])
        return [best_match]