"""Model loading, prediction, and SHAP-based explanation plots."""

import os
import tempfile
import warnings

import joblib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import shap

from config import MODEL_PATH, FEATURE_NAMES

warnings.filterwarnings('ignore')
matplotlib.use('Agg')  # headless backend: plots are only written to files
plt.rcParams['font.family'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# Prefix for every SHAP plot this module writes to the temp dir. It lets
# cleanup_temp_files() delete only our own files instead of every PNG.
_PLOT_PREFIX = 'shap_plot_'

# Single-row background dataset for KernelExplainer. The column order matches
# the input vector built in predict_outcome():
# age, weight, height, bmi, gravidity, parity, h_abortion, living_child,
# gestational_age, hemoglobin, hematocrit, platelet, mpv, pdw, neutrophil,
# lymphocyte, nlr, plr.
# NOTE(review): presumably a "typical patient" reference row — confirm origin.
_SHAP_BACKGROUND = np.array([[
    28, 65, 162, 24.7, 2, 1, 0, 1, 28, 11.5, 34.0, 250, 8.5, 12.0, 6.0,
    1.8, 3.33, 139,
]])


def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet):
    """Compute BMI, NLR and PLR from the raw inputs.

    Args:
        age: Not used in any calculation; kept for interface compatibility.
        weight: Body weight (the BMI formula implies kg).
        height: Height in centimetres (divided by 100 to get metres).
        neutrophil: Neutrophil count.
        lymphocyte: Lymphocyte count.
        platelet: Platelet count.

    Returns:
        Tuple ``(bmi, nlr, plr)``. NLR and PLR are 0 when ``lymphocyte``
        is not positive, to avoid dividing by zero.

    Raises:
        ValueError: If ``height`` is not positive.
    """
    if height <= 0:
        raise ValueError("height must be positive")
    height_m = height / 100
    bmi = weight / (height_m ** 2)
    if lymphocyte > 0:
        nlr = neutrophil / lymphocyte
        plr = platelet / lymphocyte
    else:
        nlr = plr = 0
    return bmi, nlr, plr


def _positive_class_shap(shap_values):
    """Return the 1-D SHAP vector for class 1 of the first explained sample.

    ``KernelExplainer.shap_values`` on a multi-output model returns either a
    list with one ``(n_samples, n_features)`` array per class (older shap)
    or one ``(n_samples, n_features, n_classes)`` array (newer shap).
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[1])[0]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[0, :, 1]
    raise ValueError(f"Unexpected SHAP values shape: {values.shape}")


def create_shap_plot(shap_values, feature_values, feature_names, prediction_proba):
    """Render a horizontal bar chart of per-feature SHAP values to a PNG.

    Args:
        shap_values: Output of ``get_shap_explainer_and_values`` for a single
            sample (either shap output layout is accepted).
        feature_values: The sample's feature values, in the same order as
            ``feature_names``.
        feature_names: Display names for each feature.
        prediction_proba: Class probabilities; index 1 is shown in the title.

    Returns:
        Path of the written PNG. The file is created in the system temp
        directory with ``delete=False``; remove it with ``cleanup_temp_files``.
    """
    shap_vals = _positive_class_shap(shap_values)
    feature_values = np.asarray(feature_values)

    temp_file = tempfile.NamedTemporaryFile(delete=False, prefix=_PLOT_PREFIX, suffix='.png')
    temp_filename = temp_file.name
    temp_file.close()

    fig, ax = plt.subplots(figsize=(10, 12))
    try:
        # Ascending |SHAP| so the most influential feature ends up at the top.
        sorted_indices = np.argsort(np.abs(shap_vals))
        sorted_shap_vals = shap_vals[sorted_indices]
        sorted_feature_names = [feature_names[i] for i in sorted_indices]
        sorted_feature_values = feature_values[sorted_indices]

        colors = ['red' if val > 0 else 'blue' for val in sorted_shap_vals]
        ax.barh(range(len(sorted_shap_vals)), sorted_shap_vals, color=colors, alpha=0.7)

        ax.set_yticks(range(len(sorted_feature_names)))
        ax.set_yticklabels([f"{name} = {val:.2f}" for name, val in zip(sorted_feature_names, sorted_feature_values)])
        ax.set_xlabel('SHAP Value (Impact on Prediction)', fontsize=12)
        ax.set_title(f'Feature Impact Analysis\nComplication Risk: {prediction_proba[1]*100:.1f}%', fontsize=14, pad=20)
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.3)

        # Numeric label just past the end of each non-zero bar.
        for i, val in enumerate(sorted_shap_vals):
            if val != 0:
                ax.text(val + (0.001 if val > 0 else -0.001), i, f'{val:.3f}',
                        va='center', ha='left' if val > 0 else 'right', fontsize=9)

        ax.text(0.02, 0.98, 'Red: Increases risk\nBlue: Decreases risk',
                transform=ax.transAxes, va='top', ha='left',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        fig.tight_layout()
        fig.savefig(temp_filename, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
    except Exception:
        # Don't leave an empty or partial PNG behind when plotting fails.
        try:
            os.remove(temp_filename)
        except OSError:
            pass
        raise
    finally:
        plt.close(fig)
    return temp_filename


def get_shap_explainer_and_values(model, input_data, nsamples=100):
    """Explain ``model.predict_proba`` on ``input_data`` with KernelExplainer.

    Args:
        model: Fitted estimator exposing ``predict_proba``.
        input_data: 2-D array of samples with the 18 model features.
        nsamples: Number of model evaluations per explanation (default 100,
            as before).

    Returns:
        The raw result of ``explainer.shap_values``; its layout depends on
        the installed shap version.
    """
    explainer = shap.KernelExplainer(model.predict_proba, _SHAP_BACKGROUND)
    return explainer.shap_values(input_data, nsamples=nsamples)


# Lazily loaded estimator, shared by all calls (see get_model()).
model = None


def get_model():
    """Return the cached model, loading it from MODEL_PATH on first use.

    Returns ``None`` if loading fails; the next call retries the load.
    NOTE: joblib.load unpickles the file, so MODEL_PATH must be trusted.
    """
    global model
    if model is None:
        try:
            model = joblib.load(MODEL_PATH)
        except Exception as e:
            print(f"Error loading model: {e}")
            return None
        print("Model loaded successfully!")
    return model


def predict_outcome(age, weight, height, gravidity, parity, h_abortion, living_child,
                    gestational_age, hemoglobin, hematocrit, platelet, mpv, pdw,
                    neutrophil, lymphocyte):
    """Predict complication risk and build a report plus a SHAP plot.

    Returns:
        ``(result, detailed_report, shap_plot_path)``.
        - On failure to load the model or to predict: an error message,
          ``""`` and ``None``.
        - If only the SHAP explanation fails: the prediction and report are
          still returned, with ``shap_plot_path`` set to ``None``.
    """
    loaded_model = get_model()
    if loaded_model is None:
        return "خطا: مدل بارگذاری نشده است", "", None

    try:
        bmi, nlr, plr = calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet)
        # Column order must match _SHAP_BACKGROUND and (presumably)
        # FEATURE_NAMES — TODO confirm against config.
        input_data = np.array([[
            age, weight, height, bmi, gravidity, parity, h_abortion, living_child,
            gestational_age, hemoglobin, hematocrit, platelet, mpv, pdw,
            neutrophil, lymphocyte, nlr, plr,
        ]])

        prediction_proba = loaded_model.predict_proba(input_data)[0]
        prediction = loaded_model.predict(input_data)[0]

        if prediction == 0:
            result = f"🟢 پیش‌بینی: سالم (احتمال سالم بودن: {prediction_proba[0]*100:.1f}%)"
            risk_level = "کم"
        else:
            result = f"🔴 پیش‌بینی: پرخطر (احتمال عوارض: {prediction_proba[1]*100:.1f}%)"
            risk_level = "بالا"

        detailed_report = "\n".join([
            "",
            "📊 **گزارش تفصیلی پیش‌بینی**",
            "",
            f"**نتیجه کلی:** {result}",
            "",
            f"**سطح ریسک:** {risk_level}",
            "",
            "**ویژگی‌های محاسبه شده:**",
            f"- BMI: {bmi:.2f}",
            f"- NLR (نسبت نوتروفیل به لنفوسیت): {nlr:.2f}",
            f"- PLR (نسبت پلاکت به لنفوسیت): {plr:.2f}",
            "",
            "**احتمالات تفصیلی:**",
            f"- احتمال سالم بودن: {prediction_proba[0]*100:.1f}%",
            f"- احتمال بروز عوارض: {prediction_proba[1]*100:.1f}%",
            "",
            "⚠️ **توجه:** این پیش‌بینی صرفاً جهت کمک به تشخیص است و نباید جایگزین نظر پزشک شود.",
            "",
        ])
    except Exception as e:
        return f"خطا در پردازش: {str(e)}", "", None

    try:
        shap_values = get_shap_explainer_and_values(loaded_model, input_data)
        shap_plot_path = create_shap_plot(shap_values, input_data[0], FEATURE_NAMES, prediction_proba)
    except Exception as e:
        # The prediction already succeeded; don't hide it because the
        # explanation failed.
        print(f"SHAP explanation failed: {e}")
        shap_plot_path = None

    return result, detailed_report, shap_plot_path


def cleanup_temp_files():
    """Best-effort removal of SHAP plot PNGs written by this module.

    Only files in the system temp dir whose names start with _PLOT_PREFIX
    and end in ``.png`` are deleted. Files that cannot be removed (already
    gone, still in use, no permission) are skipped.
    """
    temp_dir = tempfile.gettempdir()
    try:
        filenames = os.listdir(temp_dir)
    except OSError as e:
        print(f"Could not list temp dir {temp_dir}: {e}")
        return
    for filename in filenames:
        if filename.startswith(_PLOT_PREFIX) and filename.endswith('.png'):
            try:
                os.remove(os.path.join(temp_dir, filename))
            except OSError:
                pass