import os

import joblib
import numpy as np
import torch
from transformers import AutoModel


class FinancialFilingClassifier:
    """Two-stage text classifier: a router picks categories, specialists pick labels.

    A text is embedded with the Jina encoder, the embedding is extended with a
    log-length feature, and the resulting vector is scored by a "router"
    classifier. For each of the top router categories a per-category
    "specialist" classifier (loaded lazily from ``model_dir``) refines the
    prediction; categories without a specialist fall back to the router result.

    The ``*_xgb.pkl`` artifacts are presumably XGBoost models and the
    ``*_le.pkl`` artifacts scikit-learn label encoders; this class only relies
    on ``predict_proba`` and ``classes_`` respectively.
    """

    # Hugging Face model id of the embedding model.
    ENCODER_NAME = "jinaai/jina-embeddings-v3"
    # ``max_length`` value forwarded to ``encoder.encode``.
    MAX_LENGTH = 8192
    # Maps characters of a category name to their file-name-safe replacement.
    _SAFE_NAME_TABLE = str.maketrans({" ": "_", "&": "and", "/": "_"})

    def __init__(self, model_dir):
        """Load the encoder and the router artifacts.

        Args:
            model_dir: Directory containing ``router_xgb.pkl``,
                ``router_le.pkl`` and the optional
                ``specialist_<name>_xgb.pkl`` / ``specialist_<name>_le.pkl``
                pairs.

        Raises:
            FileNotFoundError: If a router artifact is missing (raised by
                ``joblib.load``).
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_dir = model_dir

        print(f"Loading Jina Encoder on {self.device}...")
        # NOTE(security): trust_remote_code=True lets the model repository run
        # its own Python code on load; only use with a trusted/pinned model.
        self.encoder = AutoModel.from_pretrained(
            self.ENCODER_NAME,
            trust_remote_code=True,
            # Half precision on GPU, full precision on CPU.
            torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32,
        ).to(self.device)

        print("Loading XGBoost Cascade...")
        # NOTE(security): joblib.load unpickles arbitrary objects; model_dir
        # must only ever contain trusted artifacts.
        self.router = joblib.load(os.path.join(model_dir, "router_xgb.pkl"))
        self.router_le = joblib.load(os.path.join(model_dir, "router_le.pkl"))

        # Cache of lazily loaded specialists: safe category name -> (clf, le).
        self.specialists = {}

    def _get_vector(self, text):
        """Return the feature row for ``text``: embedding plus ``log1p(len)``.

        ``text`` is coerced to ``str`` once so that the length feature and the
        encoder input always describe the same string.

        Returns:
            The result of horizontally stacking the encoder output with a
            ``(1, 1)`` log-length column. This requires the encoder to return
            a 2-D, single-row array-like for a single input text.
        """
        text = str(text)
        log_len = np.log1p(len(text))
        with torch.no_grad():
            vec = self.encoder.encode(
                [text], task="classification", max_length=self.MAX_LENGTH
            )
        return np.hstack([vec, [[log_len]]])

    def _load_specialist(self, category):
        """Return the cached ``(classifier, label_encoder)`` pair for ``category``.

        The pair is loaded from ``model_dir`` on first use. Spaces and ``/``
        in the category name become ``_`` and ``&`` becomes ``and`` to form
        the artifact file names.

        Returns:
            ``(clf, le)`` when both artifact files exist, otherwise ``None``.
            A missing specialist is not cached, so the files are looked up
            again on the next call.
        """
        safe_name = category.translate(self._SAFE_NAME_TABLE)
        cached = self.specialists.get(safe_name)
        if cached is not None:
            return cached

        clf_path = os.path.join(self.model_dir, f"specialist_{safe_name}_xgb.pkl")
        le_path = os.path.join(self.model_dir, f"specialist_{safe_name}_le.pkl")
        try:
            # NOTE(security): joblib.load unpickles arbitrary objects; only
            # trusted artifacts may live in model_dir.
            clf = joblib.load(clf_path)
            le = joblib.load(le_path)
        except FileNotFoundError:
            # No (complete) specialist for this category; the caller falls
            # back to the router prediction.
            return None

        self.specialists[safe_name] = (clf, le)
        return self.specialists[safe_name]

    def predict(self, text, *, top_k=2):
        """Classify ``text`` and return the best-scoring candidate.

        The ``top_k`` most probable router categories are each turned into a
        candidate. With a specialist available, the candidate label is the
        specialist's most probable class and the score is the geometric mean
        of router and specialist confidence. Without one, the label is the
        category itself and the score is the router confidence.

        Args:
            text: The document text to classify.
            top_k: Number of router categories to consider (default 2).

        Returns:
            A dict with keys ``"category"``, ``"label"`` and ``"score"``
            (a Python float). On ties the candidate with the higher router
            probability wins, because candidates are built in that order.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        if top_k < 1:
            raise ValueError("top_k must be at least 1")

        vector = self._get_vector(text)
        router_probs = self.router.predict_proba(vector)[0]
        # Indices of the most probable categories, best first.
        top_indices = np.argsort(router_probs)[::-1][:top_k]

        candidates = []
        for idx in top_indices:
            category = self.router_le.classes_[idx]
            router_conf = router_probs[idx]

            specialist = self._load_specialist(category)
            if specialist is None:
                # NOTE(review): this raw router confidence is compared against
                # geometric means below; confirm the two scales are meant to
                # compete directly.
                candidates.append({
                    "category": category,
                    "label": category,
                    "score": float(router_conf),
                })
                continue

            clf, le = specialist
            spec_probs = clf.predict_proba(vector)[0]
            best_idx = np.argmax(spec_probs)
            spec_conf = spec_probs[best_idx]
            # Geometric mean of the two stage confidences.
            combined_score = np.sqrt(router_conf * spec_conf)
            candidates.append({
                "category": category,
                "label": le.classes_[best_idx],
                "score": float(combined_score),
            })

        return max(candidates, key=lambda x: x['score'])