Spaces:
Sleeping
Sleeping
import numpy as np | |
import joblib | |
import warnings | |
import matplotlib.pyplot as plt | |
import matplotlib | |
import shap | |
import os | |
import tempfile | |
from config import MODEL_PATH, FEATURE_NAMES | |
# Process-wide setup: silence every warning, render plots with the
# non-interactive Agg backend (no display required), and pick the plot font.
warnings.filterwarnings('ignore')
matplotlib.use('Agg')
plt.rcParams.update({
    'font.family': ['DejaVu Sans'],
    # Draw the minus sign as a plain hyphen-minus so the font never lacks it.
    'axes.unicode_minus': False,
})
def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet):
    """Derive BMI, NLR and PLR from raw patient measurements.

    Args:
        age: Unused; kept so the signature stays unchanged for callers.
        weight: Body weight (presumably kg, per the BMI convention -- confirm
            against the input form).
        height: Height in centimetres; converted to metres for BMI.
        neutrophil: Neutrophil count.
        lymphocyte: Lymphocyte count. When it is not positive, NLR and PLR
            fall back to 0 instead of dividing by it.
        platelet: Platelet count.

    Returns:
        A ``(bmi, nlr, plr)`` tuple.

    Raises:
        ValueError: If ``height`` is not positive. Zero would divide by zero,
            and a negative height would silently produce a positive BMI
            because it is squared.
    """
    if height <= 0:
        raise ValueError(f"height must be a positive number of centimetres, got {height}")
    height_m = height / 100
    bmi = weight / (height_m ** 2)
    # 0 is used as a sentinel ratio when the denominator is unusable.
    if lymphocyte > 0:
        nlr = neutrophil / lymphocyte
        plr = platelet / lymphocyte
    else:
        nlr = plr = 0
    return bmi, nlr, plr
def _class1_shap_vector(shap_values):
    """Return the class-1 SHAP values of the first explained row as a 1-D array.

    shap's ``KernelExplainer.shap_values`` has returned two layouts for a
    multi-output ``predict_proba``. Older releases return a list with one
    ``(n_samples, n_features)`` array per class. Newer releases return a
    single ``(n_samples, n_features, n_classes)`` array.

    Raises:
        ValueError: If ``shap_values`` matches neither layout.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[1])[0]
    values = np.asarray(shap_values)
    if values.ndim != 3:
        raise ValueError(
            f"Unexpected SHAP values shape {values.shape}; "
            "expected (n_samples, n_features, n_classes)"
        )
    return values[0, :, 1]


def create_shap_plot(shap_values, feature_values, feature_names, prediction_proba):
    """Render a horizontal SHAP bar chart for one prediction and save it as a PNG.

    Args:
        shap_values: ``KernelExplainer.shap_values`` output for the explained
            row. Both the list-per-class and the 3-D array layouts are accepted.
        feature_values: The explained row's feature values, aligned with
            ``feature_names`` (any array-like).
        feature_names: Display name for each feature.
        prediction_proba: Class probabilities. Index 1 is shown in the title
            as the complication risk.

    Returns:
        Path of a new temporary PNG file. It is created with ``delete=False``,
        so the caller owns the file and is responsible for removing it.

    If drawing or saving fails, the figure is still closed and any
    half-written temp file is removed before the exception propagates.
    """
    shap_vals = _class1_shap_vector(shap_values)
    feature_values = np.asarray(feature_values)

    # Ascending |SHAP|: barh puts index 0 at the bottom, so the most
    # influential feature ends up at the top of the chart.
    order = np.argsort(np.abs(shap_vals))
    sorted_shap_vals = shap_vals[order]
    sorted_feature_names = [feature_names[i] for i in order]
    sorted_feature_values = feature_values[order]
    colors = ['red' if val > 0 else 'blue' for val in sorted_shap_vals]
    positions = range(len(sorted_shap_vals))

    fig, ax = plt.subplots(figsize=(10, 12))
    try:
        ax.barh(positions, sorted_shap_vals, color=colors, alpha=0.7)
        ax.set_yticks(positions)
        ax.set_yticklabels([f"{name} = {val:.2f}" for name, val in zip(sorted_feature_names, sorted_feature_values)])
        ax.set_xlabel('SHAP Value (Impact on Prediction)', fontsize=12)
        ax.set_title(f'Feature Impact Analysis\nComplication Risk: {prediction_proba[1]*100:.1f}%',
                     fontsize=14, pad=20)
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.3)
        for i, val in enumerate(sorted_shap_vals):
            if val != 0:
                # Pad each value label by a fixed number of points rather than
                # a fixed data offset, so the gap looks the same whatever the
                # magnitude of the SHAP values.
                ax.annotate(f'{val:.3f}', xy=(val, i),
                            xytext=(3 if val > 0 else -3, 0), textcoords='offset points',
                            va='center', ha='left' if val > 0 else 'right', fontsize=9)
        ax.text(0.02, 0.98, 'Red: Increases risk\nBlue: Decreases risk',
                transform=ax.transAxes, va='top', ha='left',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        fig.tight_layout()

        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        temp_filename = temp_file.name
        temp_file.close()
        try:
            fig.savefig(temp_filename, dpi=300, bbox_inches='tight',
                        facecolor='white', edgecolor='none')
        except Exception:
            # Don't leave an empty/partial PNG behind on failure.
            try:
                os.remove(temp_filename)
            except OSError:
                pass
            raise
    finally:
        # Close this specific figure even on error. pyplot keeps figures
        # alive until closed, so a long-running server would accumulate them.
        plt.close(fig)
    return temp_filename
# Reference row that KernelSHAP attributions are measured against.
# NOTE(review): presumably one value per model feature, in the same order as
# the input row (age, weight, height, BMI, ..., NLR, PLR) -- verify against
# FEATURE_NAMES if the feature set changes.
_DEFAULT_SHAP_BACKGROUND = np.array([[
    28, 65, 162, 24.7, 2, 1, 0, 1, 28, 11.5, 34.0,
    250, 8.5, 12.0, 6.0, 1.8, 3.33, 139,
]])


def get_shap_explainer_and_values(model, input_data, background_data=None, nsamples=100):
    """Compute KernelSHAP values of ``model.predict_proba`` for ``input_data``.

    Args:
        model: Fitted classifier exposing ``predict_proba``.
        input_data: 2-D array of rows to explain.
        background_data: Reference rows for the explainer. Defaults to the
            single hard-coded ``_DEFAULT_SHAP_BACKGROUND`` row.
        nsamples: Number of model re-evaluations KernelExplainer uses when
            explaining each row (default 100, as before).

    Returns:
        The value returned by ``explainer.shap_values``. Its layout (list per
        class vs. one 3-D array) depends on the installed shap version.
    """
    if background_data is None:
        background_data = _DEFAULT_SHAP_BACKGROUND
    explainer = shap.KernelExplainer(model.predict_proba, background_data)
    return explainer.shap_values(input_data, nsamples=nsamples)
def predict_outcome(age, weight, height, gravidity, parity, h_abortion,
                    living_child, gestational_age, hemoglobin, hematocrit,
                    platelet, mpv, pdw, neutrophil, lymphocyte):
    """Predict complication risk for one patient and explain it with SHAP.

    BMI, NLR and PLR are derived from the raw inputs and inserted into an
    18-column feature row in the order built below.

    Returns:
        A ``(result, detailed_report, shap_plot_path)`` tuple:

        - ``result`` and ``detailed_report`` are Persian, Markdown-formatted
          strings.
        - ``shap_plot_path`` is the PNG path from ``create_shap_plot``, or None.
        - If the model is unavailable or the prediction fails, ``result`` holds
          an error message, the report is empty and the path is None.
        - If only the SHAP explanation fails, the prediction and report are
          still returned, with a None plot path.
    """
    import traceback  # local import: only used for server-side error logs

    model = get_model()
    if model is None:
        return "خطا: مدل بارگذاری نشده است", "", None
    try:
        bmi, nlr, plr = calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet)
        input_data = np.array([[
            age, weight, height, bmi, gravidity, parity, h_abortion,
            living_child, gestational_age, hemoglobin, hematocrit, platelet,
            mpv, pdw, neutrophil, lymphocyte, nlr, plr
        ]])
        prediction_proba = model.predict_proba(input_data)[0]
        prediction = model.predict(input_data)[0]
        if prediction == 0:
            result = f"🟢 پیشبینی: سالم (احتمال سالم بودن: {prediction_proba[0]*100:.1f}%)"
            risk_level = "کم"
        else:
            result = f"🔴 پیشبینی: پرخطر (احتمال عوارض: {prediction_proba[1]*100:.1f}%)"
            risk_level = "بالا"
        detailed_report = f"""
📊 **گزارش تفصیلی پیشبینی**
**نتیجه کلی:** {result}
**سطح ریسک:** {risk_level}
**ویژگیهای محاسبه شده:**
- BMI: {bmi:.2f}
- NLR (نسبت نوتروفیل به لنفوسیت): {nlr:.2f}
- PLR (نسبت پلاکت به لنفوسیت): {plr:.2f}
**احتمالات تفصیلی:**
- احتمال سالم بودن: {prediction_proba[0]*100:.1f}%
- احتمال بروز عوارض: {prediction_proba[1]*100:.1f}%
⚠️ **توجه:** این پیشبینی صرفاً جهت کمک به تشخیص است و نباید جایگزین نظر پزشک شود.
"""
    except Exception as e:
        # UI boundary: show the failure to the user, keep the traceback in logs.
        traceback.print_exc()
        return f"خطا در پردازش: {str(e)}", "", None

    # The SHAP explanation is supplementary. Failing to compute or draw it
    # must not throw away a prediction that already succeeded.
    shap_plot_path = None
    try:
        shap_values = get_shap_explainer_and_values(model, input_data)
        shap_plot_path = create_shap_plot(
            shap_values,
            input_data[0],
            FEATURE_NAMES,
            prediction_proba
        )
    except Exception:
        traceback.print_exc()
    return result, detailed_report, shap_plot_path
# Process-wide cache for the classifier; filled lazily by get_model().
model = None


def get_model():
    """Return the cached classifier, loading it from ``MODEL_PATH`` on first use.

    Returns:
        The loaded model, or None if loading failed. The error is printed,
        and the cache stays None, so the next call retries the load.

    NOTE: ``joblib.load`` unpickles the file, so ``MODEL_PATH`` must point to
    a trusted artifact.
    """
    global model
    if model is not None:
        return model
    try:
        model = joblib.load(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    return model
def cleanup_temp_files():
    """Best-effort removal of leftover plot PNGs from the system temp directory.

    Deletes top-level entries of ``tempfile.gettempdir()`` named like default
    ``NamedTemporaryFile`` PNGs, i.e. starting with ``tempfile.gettempprefix()``
    and ending with ``.png``. Filesystem errors are ignored so cleanup never
    breaks the caller.
    """
    prefix = tempfile.gettempprefix()
    try:
        temp_dir = tempfile.gettempdir()
        filenames = os.listdir(temp_dir)
    except OSError:
        return
    for filename in filenames:
        # Match the default temp-file naming (prefix + random + suffix) rather
        # than any name merely containing "tmp", so unrelated files such as
        # "report_tmp.png" are left alone.
        # NOTE(review): this can still remove tmp*.png files created by other
        # code sharing the temp dir; a dedicated prefix would be safer.
        if not (filename.startswith(prefix) and filename.endswith('.png')):
            continue
        try:
            os.remove(os.path.join(temp_dir, filename))
        except OSError:
            pass