Spaces:
Sleeping
Sleeping
File size: 6,055 Bytes
e329104 f9d3994 e329104 f9d3994 a8787f9 ad4a317 92bf9c1 ad4a317 e329104 f9d3994 ad4a317 e329104 f9d3994 ad4a317 e329104 f9d3994 e329104 ad4a317 03482eb a8787f9 e329104 a8787f9 e329104 a8787f9 e329104 a8787f9 e329104 f9d3994 e329104 ad4a317 f9d3994 e329104 f9d3994 ad4a317 f9d3994 e329104 f9d3994 e329104 f9d3994 ad4a317 f9d3994 e329104 f9d3994 ad4a317 f9d3994 e329104 f9d3994 e329104 f9d3994 a8787f9 e329104 ad4a317 e329104 ad4a317 e329104 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import os, json, traceback
import numpy as np
import pandas as pd
import xgboost as xgb
import gradio as gr
from typing import Dict, Any, List, Union
# ---------- Config via env ----------
# Each setting is read from the environment variable of the same name; the
# second argument (where given) is the value used when the variable is unset.

# Hub source. Leave HF_MODEL_REPO empty ("") to load locally.
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "")  # e.g., "mjpsm/Entrepreneurial-Readiness-XGB"
MODEL_FILENAME = os.environ.get("MODEL_FILENAME", "xgb_model.json")  # include subfolders if needed
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when unset

# Local model file.
MODEL_LOCAL_PATH = os.environ.get("MODEL_LOCAL_PATH", "xgb_model.json")

# Optional JSON side files.
FEATURES_PATH = os.environ.get("FEATURES_PATH", "feature_order.json")
LABELS_PATH = os.environ.get("LABELS_PATH", "label_map.json")  # ensure filename matches your repo
# ---------- Load model ----------
def _load_model() -> xgb.Booster:
    """Create an XGBoost Booster and load the model from the Hub or from disk.

    When ``HF_MODEL_REPO`` is non-empty, ``MODEL_FILENAME`` is downloaded from
    that repository via ``hf_hub_download`` (authenticated with ``HF_TOKEN``)
    and loaded.  Otherwise the model is read from ``MODEL_LOCAL_PATH``.

    Returns:
        The Booster with the model loaded into it.

    Raises:
        FileNotFoundError: If no Hub repository is configured and
            ``MODEL_LOCAL_PATH`` does not exist.
    """
    booster = xgb.Booster()

    if HF_MODEL_REPO:
        # Imported lazily so huggingface_hub is only required for Hub loading.
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=MODEL_FILENAME,
            token=HF_TOKEN,  # None is fine for public; required for private
        )
        booster.load_model(model_path)
        # Log only after load_model() succeeds so the message is never
        # printed for a model that failed to load.
        print(f"[BOOT] Loaded model from Hub: {HF_MODEL_REPO}/{MODEL_FILENAME}")
        return booster

    if not os.path.exists(MODEL_LOCAL_PATH):
        raise FileNotFoundError(
            f"Local model not found at '{MODEL_LOCAL_PATH}'. "
            "Either upload the file or set HF_MODEL_REPO/MODEL_FILENAME to download from the Hub."
        )
    booster.load_model(MODEL_LOCAL_PATH)
    print(f"[BOOT] Loaded local model: {MODEL_LOCAL_PATH}")
    return booster


# Loaded once, when the module is first executed.
model = _load_model()
# ---------- Features & Labels ----------
# If present, these files define exact column order and class names
def _read_json_file(path: str) -> Any:
    """Parse the JSON file at *path* and close the file handle afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Empty list when the file is absent.
feature_order: List[str] = (
    _read_json_file(FEATURES_PATH) if os.path.exists(FEATURES_PATH) else []
)

labels: List[str]
if os.path.exists(LABELS_PATH):
    raw_map = _read_json_file(LABELS_PATH)
    if isinstance(raw_map, dict):
        # invert mapping → sort by value → get list of keys
        labels = [
            name.capitalize()
            for name, _ in sorted(raw_map.items(), key=lambda item: item[1])
        ]
    elif isinstance(raw_map, list):
        # A list is taken as-is (no capitalisation applied).
        labels = raw_map
    else:
        # Unsupported JSON shape: fall back to no labels.
        labels = []
else:
    labels = ["Low", "Medium", "High"]  # fallback
# Fallback feature order (edit if needed; should match training names exactly)
# The order of this list is the column order, so do not re-sort it.
# NOTE(review): the trailing underscore in "prior_businesses_started_" looks
# unusual — confirm it matches the training column name before changing it.
fallback_features: List[str] = [
    "savings", "monthly_income", "monthly_bills",
    "monthly_entertainment_spend", "sales_skills_1to10", "age",
    "dependents_count", "assets", "risk_tolerance_1to10",
    "confidence_1to10", "idea_difficulty_1to10", "runway_months",
    "savings_to_expense_ratio", "prior_businesses_started_", "prior_exits",
    "time_available_hours_per_week",
]
def _ensure_feature_matrix(rows: List[Dict[str, Any]]):
    """Turn request rows into a float32 DataFrame in the model's column order.

    Columns come from ``feature_order`` when it is non-empty, otherwise from
    ``fallback_features``.  A feature missing from a row is filled with ``0``.

    Args:
        rows: One mapping of feature name to numeric value per sample.

    Returns:
        A ``(df, cols)`` tuple: the float32 DataFrame and the list of column
        names used to build it.

    Raises:
        TypeError, ValueError: If a supplied value cannot be converted to
            ``float`` (for example ``None`` or a non-numeric string).  The
            exception class is the one ``float()`` raised; the message names
            the row index and the feature.
    """
    cols = feature_order or fallback_features
    data = []
    for row_idx, row in enumerate(rows):
        values = []
        for name in cols:
            raw = row.get(name, 0)
            try:
                values.append(float(raw))
            except (TypeError, ValueError) as exc:
                # Re-raise with the same class but a message that tells the
                # caller exactly which input was bad.
                raise type(exc)(
                    f"Row {row_idx}: feature '{name}' has non-numeric value {raw!r}."
                ) from exc
        data.append(values)
    # Build a DataFrame so XGBoost sees column names during predict()
    df = pd.DataFrame(data, columns=cols).astype(np.float32)
    return df, cols
def _predict_core(rows: List[Dict[str, Any]]):
    """Score *rows* with the loaded model and format one result per row.

    Args:
        rows: One mapping of feature name to value per sample.

    Returns:
        A list with one ``{"prediction": ..., "probabilities": {...}}`` dict
        per input row.  Class names come from ``labels`` when the predicted
        index is covered by it; otherwise classes are named by their index.
        An empty *rows* yields an empty list.
    """
    if not rows:
        # Nothing to score; skip building an empty DMatrix.
        return []

    df, _ = _ensure_feature_matrix(rows)
    dmat = xgb.DMatrix(df)  # names preserved via DataFrame
    preds = model.predict(dmat)  # (n,) for binary logistic; (n, C) for multi:softprob

    outputs = []
    for p in preds:
        # Binary logistic returns a scalar probability of class 1
        # (np.float32 / np.float64 are both np.floating).
        if isinstance(p, (float, np.floating)):
            prob1 = float(p)
            probs = [1.0 - prob1, prob1]
            idx = int(prob1 >= 0.5)
        else:
            p = np.asarray(p)
            probs = p.tolist()
            idx = int(np.argmax(p))

        if labels and idx < len(labels):
            pred_label = labels[idx]
            # Classes beyond the end of `labels` keep their numeric index.
            prob_map = {
                (str(labels[j]) if j < len(labels) else str(j)): float(prob)
                for j, prob in enumerate(probs)
            }
        else:
            pred_label = str(idx)
            prob_map = {str(j): float(prob) for j, prob in enumerate(probs)}

        outputs.append({"prediction": pred_label, "probabilities": prob_map})
    return outputs
def predict_batch(payload: Union[Dict[str, Any], List[Dict[str, Any]]]):
    """Accepts a single object or a list of objects; returns prediction(s).

    Args:
        payload: Either one feature mapping or a list of feature mappings.

    Returns:
        For a dict payload, a single prediction dict; for a list payload, a
        list of prediction dicts.  On any failure, a dict with ``"error"``
        (the exception message) and ``"traceback"`` (the formatted traceback,
        which is also printed to stdout).
    """
    try:
        if isinstance(payload, dict):
            return _predict_core([payload])[0]
        # Validate list items here so a stray non-object produces the same
        # clear message instead of an AttributeError from feature extraction.
        if isinstance(payload, list) and all(isinstance(item, dict) for item in payload):
            return _predict_core(payload)
        raise ValueError("Input must be an object or a list of objects.")
    except Exception as e:
        tb = traceback.format_exc()
        print("[ERROR]", tb)
        # NOTE(review): the traceback is returned to the caller, which exposes
        # internal paths and code structure — consider dropping it for public
        # deployments. Kept here because it is part of the current response.
        return {"error": str(e), "traceback": tb}
def health():
    """Return a small status dict for the health-check endpoint.

    Keys: ``model_loaded`` (always ``True``), ``source`` (``"hub"`` when
    ``HF_MODEL_REPO`` is set, else ``"local"``) and ``features`` (the feature
    list in use: ``feature_order`` if non-empty, else ``fallback_features``).
    """
    if HF_MODEL_REPO:
        origin = "hub"
    else:
        origin = "local"
    active_features = feature_order or fallback_features
    return {"model_loaded": True, "source": origin, "features": active_features}
# ---------- Gradio UI / API ----------
# Sample input row: one value for each of the sixteen feature names.
example_row = {
    "savings": 2500, "monthly_income": 2200, "monthly_bills": 1500,
    "monthly_entertainment_spend": 200, "sales_skills_1to10": 6, "age": 21,
    "dependents_count": 0, "assets": 0, "risk_tolerance_1to10": 7,
    "confidence_1to10": 7, "idea_difficulty_1to10": 5, "runway_months": 3,
    "savings_to_expense_ratio": 1.3, "prior_businesses_started_": 0,
    "prior_exits": 0, "time_available_hours_per_week": 15,
}
# Build the page: a JSON input pre-filled with example_row, a JSON output,
# and two buttons that double as named API endpoints.
with gr.Blocks() as demo:
    # Title dash restored from a mis-encoded character.
    gr.Markdown("# Entrepreneurial Readiness — XGBoost Classifier (API + UI)")
    gr.Markdown("Submit a JSON object **or** a list of objects. Returns `prediction` and `probabilities`.")
    json_in = gr.JSON(label="Input JSON", value=example_row)
    json_out = gr.JSON(label="Output")
    # UI button and API endpoint
    gr.Button("Predict").click(predict_batch, inputs=json_in, outputs=json_out, api_name="predict")
    # Optional: simple health check endpoint
    gr.Button("Health").click(fn=health, inputs=None, outputs=json_out, api_name="health")

# Launched at module level (unchanged) so running the file starts the app.
demo.launch()
|