Pregnancy_Risk_Evaluator / model_utils.py
abreza's picture
add shap value
1c87170
import numpy as np
import joblib
import warnings
import matplotlib.pyplot as plt
import matplotlib
import shap
import os
import tempfile
from config import MODEL_PATH, FEATURE_NAMES
# Suppress every warning category for the whole process (not only this module).
warnings.filterwarnings('ignore')
# Select the non-interactive Agg backend so figures are rendered straight to
# files and no display/GUI toolkit is required.
matplotlib.use('Agg')
# Font family applied to all text drawn on figures.
plt.rcParams['font.family'] = ['DejaVu Sans']
# Render minus signs with the ASCII hyphen rather than the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet):
    """Derive BMI, NLR and PLR from the raw measurements.

    ``height`` is divided by 100 before being squared for the BMI, i.e. it is
    treated as centimetres. ``age`` is accepted for signature compatibility
    but is not used. Both ratios are reported as 0 when ``lymphocyte`` is not
    strictly positive.

    Returns:
        A ``(bmi, nlr, plr)`` tuple.

    Raises:
        ZeroDivisionError: if ``height`` is 0.
    """
    height_in_metres = height / 100
    bmi = weight / (height_in_metres ** 2)

    if lymphocyte > 0:
        nlr = neutrophil / lymphocyte
        plr = platelet / lymphocyte
    else:
        # No meaningful ratio without a positive denominator.
        nlr = plr = 0

    return bmi, nlr, plr
def create_shap_plot(shap_values, feature_values, feature_names, prediction_proba):
    """Render per-feature SHAP contributions as a horizontal bar chart PNG.

    Args:
        shap_values: SHAP output for a single explained sample. Accepted
            layouts are an array indexed ``(sample, feature, class)`` or the
            legacy per-class list, which stacks to ``(class, sample, feature)``.
            In both cases the values for class index 1 of the first sample
            are plotted.
        feature_values: Input values of the explained sample, one per feature.
        feature_names: Display names, in the same order as ``feature_values``.
        prediction_proba: Class probabilities; index 1 is shown in the title
            as the complication risk.

    Returns:
        Path of a newly created temporary ``.png`` file. The file is not
        deleted here; the caller is responsible for removing it.

    Raises:
        Any exception raised while plotting or saving is propagated after the
        figure has been closed and the partially written file removed.
    """
    values = np.asarray(shap_values)
    n_features = len(feature_names)
    if (values.ndim == 3
            and values.shape[2] == n_features
            and values.shape[1] != n_features):
        # Legacy layout: one (n_samples, n_features) array per class. Indexing
        # it like the array layout would silently pick class 0 / feature 1.
        shap_vals = values[1, 0]
    else:
        # Array layout: first sample, every feature, class 1.
        shap_vals = values[0][:, 1]

    # Accept plain sequences as well as arrays for the fancy indexing below.
    feature_values = np.asarray(feature_values)

    # Ascending by magnitude so the most influential feature ends up on top.
    sorted_indices = np.argsort(np.abs(shap_vals))
    sorted_shap_vals = shap_vals[sorted_indices]
    sorted_feature_names = [feature_names[i] for i in sorted_indices]
    sorted_feature_values = feature_values[sorted_indices]
    colors = ['red' if val > 0 else 'blue' for val in sorted_shap_vals]

    fig, ax = plt.subplots(figsize=(10, 12))
    temp_filename = None
    try:
        positions = range(len(sorted_shap_vals))
        ax.barh(positions, sorted_shap_vals, color=colors, alpha=0.7)
        ax.set_yticks(positions)
        ax.set_yticklabels([
            f"{name} = {val:.2f}"
            for name, val in zip(sorted_feature_names, sorted_feature_values)
        ])
        ax.set_xlabel('SHAP Value (Impact on Prediction)', fontsize=12)
        ax.set_title(
            f'Feature Impact Analysis\nComplication Risk: {prediction_proba[1]*100:.1f}%',
            fontsize=14, pad=20)
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.3)

        # Numeric label just outside the tip of every non-zero bar.
        for i, val in enumerate(sorted_shap_vals):
            if val != 0:
                ax.text(val + (0.001 if val > 0 else -0.001), i, f'{val:.3f}',
                        va='center', ha='left' if val > 0 else 'right', fontsize=9)

        ax.text(0.02, 0.98, 'Red: Increases risk\nBlue: Decreases risk',
                transform=ax.transAxes, va='top', ha='left',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        fig.tight_layout()

        # Reserve the output path only once the figure is ready, so a plotting
        # failure cannot leave an empty file behind.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            temp_filename = tmp.name
        fig.savefig(temp_filename, dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
    except Exception:
        # Do not leave a partially written image in the temp directory.
        if temp_filename is not None:
            try:
                os.remove(temp_filename)
            except OSError:
                pass
        raise
    finally:
        # Always release the figure; pyplot keeps figures alive until closed.
        plt.close(fig)

    return temp_filename
def get_shap_explainer_and_values(model, input_data):
    """Compute SHAP values for ``input_data`` with a kernel explainer.

    A single fixed reference row is used as the explainer's background data
    and ``model.predict_proba`` is the function being explained. Despite the
    function name, only the SHAP values are returned; the explainer object
    itself is discarded.
    """
    # One reference row of 18 values. Presumably the columns follow the same
    # feature order as the vector built in predict_outcome -- verify whenever
    # that order changes.
    reference_row = [
        28, 65, 162, 24.7, 2, 1, 0, 1, 28,
        11.5, 34.0, 250, 8.5, 12.0, 6.0, 1.8, 3.33, 139,
    ]
    baseline = np.array([reference_row])

    kernel_explainer = shap.KernelExplainer(model.predict_proba, baseline)
    return kernel_explainer.shap_values(input_data, nsamples=100)
def predict_outcome(age, weight, height, gravidity, parity, h_abortion,
                    living_child, gestational_age, hemoglobin, hematocrit,
                    platelet, mpv, pdw, neutrophil, lymphocyte):
    """Run the risk model on one set of inputs and build the user-facing output.

    The derived features (BMI, NLR, PLR) are computed first and combined with
    the raw inputs into a single-row feature matrix, which is passed to the
    model's ``predict_proba`` and ``predict``. A SHAP bar chart is generated
    for the same row.

    Returns:
        A ``(result, detailed_report, shap_plot_path)`` tuple. When the model
        is unavailable or any step raises, the first element is an error
        message, the report is an empty string and the plot path is ``None``.
    """
    clf = get_model()
    if clf is None:
        return "خطا: مدل بارگذاری نشده است", "", None

    try:
        bmi, nlr, plr = calculate_derived_features(
            age, weight, height, neutrophil, lymphocyte, platelet
        )

        feature_row = [
            age, weight, height, bmi, gravidity, parity, h_abortion,
            living_child, gestational_age, hemoglobin, hematocrit, platelet,
            mpv, pdw, neutrophil, lymphocyte, nlr, plr,
        ]
        input_data = np.array([feature_row])

        prediction_proba = clf.predict_proba(input_data)[0]
        predicted_class = clf.predict(input_data)[0]

        if predicted_class != 0:
            result = f"🔴 پیش‌بینی: پرخطر (احتمال عوارض: {prediction_proba[1]*100:.1f}%)"
            risk_level = "بالا"
        else:
            result = f"🟢 پیش‌بینی: سالم (احتمال سالم بودن: {prediction_proba[0]*100:.1f}%)"
            risk_level = "کم"

        # The report body stays at column 0: it is part of the string literal.
        detailed_report = f"""
📊 **گزارش تفصیلی پیش‌بینی**
**نتیجه کلی:** {result}
**سطح ریسک:** {risk_level}
**ویژگی‌های محاسبه شده:**
- BMI: {bmi:.2f}
- NLR (نسبت نوتروفیل به لنفوسیت): {nlr:.2f}
- PLR (نسبت پلاکت به لنفوسیت): {plr:.2f}
**احتمالات تفصیلی:**
- احتمال سالم بودن: {prediction_proba[0]*100:.1f}%
- احتمال بروز عوارض: {prediction_proba[1]*100:.1f}%
⚠️ **توجه:** این پیش‌بینی صرفاً جهت کمک به تشخیص است و نباید جایگزین نظر پزشک شود.
"""

        explanation = get_shap_explainer_and_values(clf, input_data)
        shap_plot_path = create_shap_plot(
            explanation, input_data[0], FEATURE_NAMES, prediction_proba
        )
    except Exception as e:
        # Boundary handler: report the failure as text instead of raising.
        return f"خطا در پردازش: {str(e)}", "", None

    return result, detailed_report, shap_plot_path
# Module-level cache for the loaded estimator; populated on first use.
model = None


def get_model():
    """Return the cached model, loading it from ``MODEL_PATH`` on first use.

    A failed load is reported on stdout and yields ``None``; nothing is
    cached in that case, so the next call attempts the load again.
    """
    global model

    if model is not None:
        return model

    try:
        model = joblib.load(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    return model
def cleanup_temp_files():
    """Best-effort removal of leftover temporary PNG files.

    Scans the system temporary directory and deletes every entry whose name
    ends with ``.png`` and contains ``tmp``. Filesystem errors (missing
    directory, entry already gone, entry in use, entry is a directory) are
    ignored so that cleanup never interrupts the caller. Non-filesystem
    exceptions such as ``KeyboardInterrupt`` are no longer swallowed.

    NOTE: the filter looks at the file name only, so it can also match PNG
    files created by other programs that share the same temp directory.
    """
    try:
        temp_dir = tempfile.gettempdir()
        candidates = os.listdir(temp_dir)
    except OSError:
        # No usable/readable temp directory: nothing to clean up.
        return

    for filename in candidates:
        if filename.endswith('.png') and 'tmp' in filename:
            try:
                os.remove(os.path.join(temp_dir, filename))
            except OSError:
                # Skip entries that cannot be removed and keep going.
                continue