#!/usr/bin/env python3
# explain_model.py
import os
import json
import numpy as np
import pandas as pd
import torch
import joblib
import shap
import matplotlib.pyplot as plt
from safetensors.torch import load_file
from network import PricePredictor

# --- 0. Config ---
MODEL_DIR = "model"
DATA_DIR = "data"
SCALER_PATH = os.path.join(DATA_DIR, "scaler.pkl")
DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")
CONFIG_PATH = os.path.join(MODEL_DIR, "config.json")
TARGET_COLUMN = "price_will_rise_30_in_6m"

# --- 1. Load model & assets ---
with open(CONFIG_PATH, "r") as f:
    config = json.load(f)
feature_columns = config["feature_columns"]
input_size = config["input_size"]

model = PricePredictor(input_size=input_size)
model.load_state_dict(load_file(os.path.join(MODEL_DIR, "model.safetensors")))
model.eval()

scaler = joblib.load(SCALER_PATH)
full_data = pd.read_csv(DATA_PATH)

# Sanity checks
missing_cols = [c for c in feature_columns if c not in full_data.columns]
if missing_cols:
    raise ValueError(f"Missing required feature columns in CSV: {missing_cols}")

features_df = full_data[feature_columns]
if features_df.shape[1] != input_size:
    raise ValueError(
        f"Config input_size={input_size}, but CSV provides {features_df.shape[1]} features. "
        f"Ensure config['feature_columns'] matches the trained model."
    )

# --- 2. Prepare Data for SHAP ---
bg_n = min(100, len(features_df))
explain_n = min(10, len(features_df))
background_idx = features_df.sample(n=bg_n, random_state=42).index
explain_idx = features_df.sample(n=explain_n, random_state=1).index

background_data = features_df.loc[background_idx]
explain_instances = features_df.loc[explain_idx]

# Pass raw arrays to the scaler to avoid feature-name warnings
background_data_scaled = scaler.transform(background_data.values)
explain_instances_scaled = scaler.transform(explain_instances.values)

background_tensor = torch.tensor(background_data_scaled, dtype=torch.float32)  # no grad needed
explain_tensor = torch.tensor(explain_instances_scaled, dtype=torch.float32, requires_grad=True)
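# Optional smoke test: run the model once on the sample before invoking SHAP.
# (A sketch, assuming PricePredictor returns one output value per row; whether
# that value is a logit or a probability depends on the training code.)
with torch.no_grad():
    sample_preds = model(explain_tensor).squeeze()
print(f"Sample model outputs: {np.round(sample_preds.numpy(), 3)}")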
# --- Helpers ---
def get_shap_explanations(model, background_tensor, explain_tensor):
    """Try DeepExplainer, then fall back to GradientExplainer.

    Returns (explanation, explainer_used_name).
    """
    try:
        print("Initializing SHAP DeepExplainer...")
        explainer = shap.DeepExplainer(model, background_tensor)
        print("Calculating SHAP values for the sample...")
        exp = explainer(explain_tensor)
        setattr(exp, "_expected_value_hint", getattr(explainer, "expected_value", None))
        return exp, "deep"
    except Exception as e:
        print(f"[DeepExplainer failed: {e}] Falling back to GradientExplainer...")
        explain_tensor.requires_grad_(True)
        grad_explainer = shap.GradientExplainer(model, background_tensor)
        exp = grad_explainer(explain_tensor)
        setattr(exp, "_expected_value_hint", getattr(grad_explainer, "expected_value", None))
        return exp, "grad"


def compute_base_value_safe(shap_explanation, instance_idx, model, background_tensor):
    """Return a scalar base value robustly across SHAP versions."""
    bv = getattr(shap_explanation, "base_values", None)
    if bv is not None:
        try:
            return float(np.squeeze(bv[instance_idx]))
        except Exception:
            try:
                return float(np.squeeze(bv))
            except Exception:
                pass
    ev = getattr(shap_explanation, "_expected_value_hint", None)
    if ev is not None:
        try:
            return float(np.squeeze(ev))
        except Exception:
            try:
                return float(np.mean(ev))
            except Exception:
                pass
    # Last resort: evaluate the model at the background mean.
    with torch.no_grad():
        mu = background_tensor.mean(dim=0, keepdim=True)
        out = model(mu).detach().cpu().squeeze()
    return float(out.mean().item()) if out.numel() > 1 else float(out.item())


def stack_sample_shap_values(exp, n_features_expected):
    """Rebuild a (n_samples, n_features) SHAP matrix by stacking per-sample slices.

    Some SHAP versions return exp.values with shape (n_samples, 1) or other
    oddities, but exp[i].values is typically the correct 1D (n_features,)
    vector, so it is safer to iterate via the __getitem__ API.
    """
    n_samples = len(exp.values) if hasattr(exp.values, "__len__") else len(exp)
    rows = [np.asarray(exp[i].values).reshape(-1) for i in range(n_samples)]
    M = np.vstack(rows)  # (n_samples, n_features)
    if M.shape[1] != n_features_expected:
        raise RuntimeError(
            f"Rebuilt SHAP matrix has shape {M.shape}; expected n_features={n_features_expected}."
        )
    return M


# --- 3. Compute SHAP explanations ---
shap_explanation, _ = get_shap_explanations(model, background_tensor, explain_tensor)
print("Calculation complete.")

# Attach unscaled display data for readable plots
shap_explanation.display_data = explain_instances.values
shap_explanation.feature_names = feature_columns

# --- 4a. Global Feature Importance (Bar / Summary) ---
print("\nGenerating global feature importance plot (summary_plot.png)...")

# Robustly build a (n_samples, n_features) matrix by stacking per-sample vectors
shap_vals_matrix = stack_sample_shap_values(shap_explanation, n_features_expected=len(feature_columns))
mean_abs_shap = np.abs(shap_vals_matrix).mean(axis=0)  # (n_features,)

# Build a fresh Explanation with values aligned to feature_names
plot_explanation = shap.Explanation(values=mean_abs_shap, feature_names=feature_columns)

plt.figure()
shap.plots.bar(plot_explanation, show=False)
plt.xlabel("mean(|SHAP value|) (average impact on model output magnitude)")
plt.savefig("summary_plot.png", bbox_inches="tight")
plt.close()
print("Saved: summary_plot.png")
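# Optional extra: a beeswarm summary alongside the bar chart, showing per-sample
# spread and direction of each feature's effect. A sketch assuming a SHAP
# version whose Explanation constructor accepts (values, data, feature_names);
# wrapped in try/except since older releases may differ.
try:
    beeswarm_exp = shap.Explanation(
        values=shap_vals_matrix,
        data=explain_instances.values,
        feature_names=feature_columns,
    )
    plt.figure()
    shap.plots.beeswarm(beeswarm_exp, show=False)
    plt.savefig("beeswarm_plot.png", bbox_inches="tight")
    plt.close()
    print("Saved: beeswarm_plot.png")
except Exception as e:
    print(f"Beeswarm plot skipped (reason: {e})")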
# --- 4b. Local Explanation (Force Plot) ---
print("\nGenerating local explanation for one card (force_plot.html)...")

instance_to_explain_index = 0
single_explanation = shap_explanation[instance_to_explain_index]

# Some SHAP versions drop display_data on slicing; pull it directly if needed
if getattr(single_explanation, "display_data", None) is None:
    row_unscaled = explain_instances.values[instance_to_explain_index]
else:
    row_unscaled = single_explanation.display_data

features_row = np.atleast_2d(np.asarray(row_unscaled, dtype=float))
base_val = compute_base_value_safe(shap_explanation, instance_to_explain_index, model, background_tensor)
phi = np.asarray(single_explanation.values).reshape(-1)  # (n_features,)

force_plot = shap.force_plot(
    base_val,
    phi,
    features=features_row,
    feature_names=feature_columns,
)
shap.save_html("force_plot.html", force_plot)
print("Saved: force_plot.html (open in a browser)")

# --- 4c. Optional: local waterfall PNG (often clearer) ---
try:
    print("Generating local waterfall plot (waterfall_single.png)...")
    plt.figure()
    shap.plots.waterfall(single_explanation, show=False, max_display=20)
    plt.savefig("waterfall_single.png", bbox_inches="tight")
    plt.close()
    print("Saved: waterfall_single.png")
except Exception as e:
    print(f"Waterfall plot skipped (reason: {e})")

# --- 5. Print metadata for the explained card ---
original_card_data = full_data.loc[explain_idx[instance_to_explain_index]]
name_val = original_card_data.get("name", "N/A")
tcgp_val = original_card_data.get("tcgplayer_id", "N/A")
label_val = original_card_data.get(TARGET_COLUMN, None)

# Handle missing/NaN labels explicitly: a chained "x if a else y if b else z"
# is easy to misread, and a NaN label is truthy, so it would print "RISE".
if label_val is None or (isinstance(label_val, float) and np.isnan(label_val)):
    label_str = "N/A"
else:
    label_str = "RISE" if bool(label_val) else "NOT RISE"

print("\n--- Card Explained in force_plot.html / waterfall_single.png ---")
print(f"Name: {name_val}")
print(f"TCGPlayer ID: {tcgp_val}")
print(f"Actual Outcome in Dataset: {label_str}")

# TODO: convert the model into a format where I can share it on Hugging Face as
#       a model that can be pulled down and used
# TODO: include the SHAP charts force_plot.html and summary_plot.png explaining
#       the model, and compute some other evaluation metrics for the model card
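# Minimal sketches for the TODOs above, kept commented out since they require a
# Hugging Face account/token and the repo id below is hypothetical.
#
# (a) Upload the existing artifacts with huggingface_hub so others can pull
#     them down (an alternative is mixing PyTorchModelHubMixin into
#     PricePredictor and calling model.push_to_hub(...)):
#
#     from huggingface_hub import HfApi
#
#     repo_id = "your-username/pokemon-price-predictor"  # hypothetical repo id
#     api = HfApi()  # picks up the token from `huggingface-cli login`
#     api.create_repo(repo_id=repo_id, exist_ok=True)
#     for local_path, remote_path in [
#         ("model/model.safetensors", "model.safetensors"),
#         ("model/config.json", "config.json"),
#         ("summary_plot.png", "summary_plot.png"),  # chart for the model card
#         ("force_plot.html", "force_plot.html"),
#     ]:
#         api.upload_file(path_or_fileobj=local_path, path_in_repo=remote_path, repo_id=repo_id)
#
# (b) Basic evaluation metrics for the model card, computed over the full
#     labeled CSV. Assumes the model emits one raw logit per row; drop
#     torch.sigmoid if it already returns probabilities.
#
#     from sklearn.metrics import accuracy_score, roc_auc_score
#
#     with torch.no_grad():
#         X_all = torch.tensor(scaler.transform(features_df.values), dtype=torch.float32)
#         probs = torch.sigmoid(model(X_all)).squeeze().numpy()
#     y_true = full_data[TARGET_COLUMN].astype(int).values
#     print("ROC AUC:", roc_auc_score(y_true, probs))
#     print("Accuracy:", accuracy_score(y_true, probs >= 0.5))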