import os, json, traceback
import numpy as np
import pandas as pd
import xgboost as xgb
import gradio as gr
from typing import Dict, Any, List, Union
# ---------- Config via env ----------
MODEL_LOCAL_PATH = os.getenv("MODEL_LOCAL_PATH", "xgb_model.json")
# Leave HF_MODEL_REPO empty ("") to load locally.
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "")  # e.g., "mjpsm/Entrepreneurial-Readiness-XGB"
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "xgb_model.json")  # include subfolders if needed
HF_TOKEN = os.getenv("HF_TOKEN", None)
FEATURES_PATH = os.getenv("FEATURES_PATH", "feature_order.json")
LABELS_PATH = os.getenv("LABELS_PATH", "label_map.json")  # ensure filename matches your repo
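# Example configuration (illustrative only -- the repo id and token below are placeholders,
# not real deployment values). To pull a private model from the Hub instead of a local file,
# the Space could set these environment variables / secrets:
#   HF_MODEL_REPO=your-username/your-xgb-model-repo
#   MODEL_FILENAME=xgb_model.json
#   HF_TOKEN=<read-access token stored as a Space secret>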
# ---------- Load model ----------
def _load_model() -> xgb.Booster:
    booster = xgb.Booster()
    if HF_MODEL_REPO:
        from huggingface_hub import hf_hub_download
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=MODEL_FILENAME,
            token=HF_TOKEN,  # None is fine for public repos; required for private ones
        )
        booster.load_model(model_path)
        print(f"[BOOT] Loaded model from Hub: {HF_MODEL_REPO}/{MODEL_FILENAME}")
    else:
        if not os.path.exists(MODEL_LOCAL_PATH):
            raise FileNotFoundError(
                f"Local model not found at '{MODEL_LOCAL_PATH}'. "
                "Either upload the file or set HF_MODEL_REPO/MODEL_FILENAME to download from the Hub."
            )
        booster.load_model(MODEL_LOCAL_PATH)
        print(f"[BOOT] Loaded local model: {MODEL_LOCAL_PATH}")
    return booster

model = _load_model()
# ---------- Features & Labels ---------- | |
# If present, these files define exact column order and class names | |
feature_order: List[str] = json.load(open(FEATURES_PATH)) if os.path.exists(FEATURES_PATH) else [] | |
if os.path.exists(LABELS_PATH): | |
raw_map = json.load(open(LABELS_PATH)) | |
if isinstance(raw_map, dict): | |
# invert mapping β sort by value β get list of keys | |
labels: List[str] = [k.capitalize() for k, v in sorted(raw_map.items(), key=lambda x: x[1])] | |
elif isinstance(raw_map, list): | |
labels: List[str] = raw_map | |
else: | |
labels: List[str] = [] | |
else: | |
labels: List[str] = ["Low", "Medium", "High"] # fallback | |
# Fallback feature order (edit if needed; should match the training feature names exactly)
fallback_features = [
    "savings",
    "monthly_income",
    "monthly_bills",
    "monthly_entertainment_spend",
    "sales_skills_1to10",
    "age",
    "dependents_count",
    "assets",
    "risk_tolerance_1to10",
    "confidence_1to10",
    "idea_difficulty_1to10",
    "runway_months",
    "savings_to_expense_ratio",
    "prior_businesses_started_",
    "prior_exits",
    "time_available_hours_per_week",
]
def _ensure_feature_matrix(rows: List[Dict[str, Any]]):
    cols = feature_order or fallback_features
    # Build a DataFrame so XGBoost sees column names during predict().
    # Missing keys default to 0.0; keys not in `cols` are ignored.
    data = [[float(r.get(k, 0)) for k in cols] for r in rows]
    df = pd.DataFrame(data, columns=cols).astype(np.float32)
    return df, cols
def _predict_core(rows: List[Dict[str, Any]]):
    df, _ = _ensure_feature_matrix(rows)
    dmat = xgb.DMatrix(df)       # feature names preserved via the DataFrame
    preds = model.predict(dmat)  # (n,) for binary:logistic; (n, C) for multi:softprob
    outputs = []
    for p in preds:
        # Binary logistic returns a scalar probability of class 1
        if isinstance(p, (np.floating, float, np.float32, np.float64)):
            prob1 = float(p)
            probs = [1.0 - prob1, prob1]
            idx = int(prob1 >= 0.5)
        else:
            p = np.asarray(p)
            probs = p.tolist()
            idx = int(np.argmax(p))
        if labels and idx < len(labels):
            pred_label = labels[idx]
            prob_map = {
                str(labels[j]) if j < len(labels) else str(j): float(probs[j])
                for j in range(len(probs))
            }
        else:
            pred_label = str(idx)
            prob_map = {str(j): float(probs[j]) for j in range(len(probs))}
        outputs.append({"prediction": pred_label, "probabilities": prob_map})
    return outputs
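# Each element of the returned list has this shape (values below are illustrative only,
# not real model output):
#   {"prediction": "High", "probabilities": {"Low": 0.08, "Medium": 0.21, "High": 0.71}}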
def predict_batch(payload: Union[Dict[str, Any], List[Dict[str, Any]]]):
    """Accepts a single object or a list of objects; returns prediction(s)."""
    try:
        if isinstance(payload, dict):
            return _predict_core([payload])[0]
        elif isinstance(payload, list):
            return _predict_core(payload)
        else:
            raise ValueError("Input must be an object or a list of objects.")
    except Exception as e:
        tb = traceback.format_exc()
        print("[ERROR]", tb)
        return {"error": str(e), "traceback": tb}
def health():
    return {
        "model_loaded": True,
        "source": "hub" if HF_MODEL_REPO else "local",
        "features": feature_order or fallback_features,
    }
# ---------- Gradio UI / API ----------
example_row = {
    "savings": 2500,
    "monthly_income": 2200,
    "monthly_bills": 1500,
    "monthly_entertainment_spend": 200,
    "sales_skills_1to10": 6,
    "age": 21,
    "dependents_count": 0,
    "assets": 0,
    "risk_tolerance_1to10": 7,
    "confidence_1to10": 7,
    "idea_difficulty_1to10": 5,
    "runway_months": 3,
    "savings_to_expense_ratio": 1.3,
    "prior_businesses_started_": 0,
    "prior_exits": 0,
    "time_available_hours_per_week": 15,
}
with gr.Blocks() as demo:
    gr.Markdown("# Entrepreneurial Readiness – XGBoost Classifier (API + UI)")
    gr.Markdown("Submit a JSON object **or** a list of objects. Returns `prediction` and `probabilities`.")
    json_in = gr.JSON(label="Input JSON", value=example_row)
    json_out = gr.JSON(label="Output")
    # UI button and API endpoint
    gr.Button("Predict").click(predict_batch, inputs=json_in, outputs=json_out, api_name="predict")
    # Optional: simple health-check endpoint
    gr.Button("Health").click(fn=health, inputs=None, outputs=json_out, api_name="health")

demo.launch()
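# Example client call (illustrative sketch, assuming this file runs as a Gradio Space;
# the Space id below is a placeholder -- replace it with the real owner/space name):
#
#   from gradio_client import Client
#   client = Client("your-username/your-space-name")
#   result = client.predict(example_row, api_name="/predict")
#   print(result)  # -> {"prediction": ..., "probabilities": {...}}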