Pregnancy_Risk_Evaluator / interface.py
abreza's picture
add shap value
1c87170
import gradio as gr
from model_utils import predict_outcome
from styles import RTL_CSS, HTML_HEAD
from config import (
APP_TITLE, MODEL_ACCURACY, MODEL_AUC,
DEFAULT_VALUES, FIELD_RANGES, EXAMPLE_CASES
)
def create_patient_info_section():
with gr.Column():
gr.Markdown("### 📝 اطلاعات بیمار")
age = gr.Number(label="سن", value=DEFAULT_VALUES['age'],
minimum=FIELD_RANGES['age']['min'], maximum=FIELD_RANGES['age']['max'])
weight = gr.Number(label="وزن (کیلوگرم)", value=DEFAULT_VALUES['weight'],
minimum=FIELD_RANGES['weight']['min'], maximum=FIELD_RANGES['weight']['max'])
height = gr.Number(label="قد (سانتی‌متر)", value=DEFAULT_VALUES['height'],
minimum=FIELD_RANGES['height']['min'], maximum=FIELD_RANGES['height']['max'])
with gr.Row():
gravidity = gr.Number(label="تعداد بارداری", value=DEFAULT_VALUES['gravidity'],
minimum=FIELD_RANGES['gravidity']['min'], maximum=FIELD_RANGES['gravidity']['max'])
parity = gr.Number(label="تعداد زایمان", value=DEFAULT_VALUES['parity'],
minimum=FIELD_RANGES['parity']['min'], maximum=FIELD_RANGES['parity']['max'])
with gr.Row():
h_abortion = gr.Number(label="تعداد سقط", value=DEFAULT_VALUES['h_abortion'],
minimum=FIELD_RANGES['h_abortion']['min'], maximum=FIELD_RANGES['h_abortion']['max'])
living_child = gr.Number(label="فرزند زنده", value=DEFAULT_VALUES['living_child'],
minimum=FIELD_RANGES['living_child']['min'], maximum=FIELD_RANGES['living_child']['max'])
gestational_age = gr.Number(label="سن بارداری (هفته)", value=DEFAULT_VALUES['gestational_age'],
minimum=FIELD_RANGES['gestational_age']['min'], maximum=FIELD_RANGES['gestational_age']['max'])
return age, weight, height, gravidity, parity, h_abortion, living_child, gestational_age
def create_lab_tests_section():
with gr.Column():
gr.Markdown("### 🧪 آزمایشات خون")
hemoglobin = gr.Number(label="هموگلوبین", value=DEFAULT_VALUES['hemoglobin'],
minimum=FIELD_RANGES['hemoglobin']['min'], maximum=FIELD_RANGES['hemoglobin']['max'])
hematocrit = gr.Number(label="هماتوکریت", value=DEFAULT_VALUES['hematocrit'],
minimum=FIELD_RANGES['hematocrit']['min'], maximum=FIELD_RANGES['hematocrit']['max'])
platelet = gr.Number(label="پلاکت", value=DEFAULT_VALUES['platelet'],
minimum=FIELD_RANGES['platelet']['min'], maximum=FIELD_RANGES['platelet']['max'])
with gr.Row():
mpv = gr.Number(label="MPV", value=DEFAULT_VALUES['mpv'],
minimum=FIELD_RANGES['mpv']['min'], maximum=FIELD_RANGES['mpv']['max'])
pdw = gr.Number(label="PDW", value=DEFAULT_VALUES['pdw'],
minimum=FIELD_RANGES['pdw']['min'], maximum=FIELD_RANGES['pdw']['max'])
with gr.Row():
neutrophil = gr.Number(label="نوتروفیل", value=DEFAULT_VALUES['neutrophil'],
minimum=FIELD_RANGES['neutrophil']['min'], maximum=FIELD_RANGES['neutrophil']['max'])
lymphocyte = gr.Number(label="لنفوسیت", value=DEFAULT_VALUES['lymphocyte'],
minimum=FIELD_RANGES['lymphocyte']['min'], maximum=FIELD_RANGES['lymphocyte']['max'])
return hemoglobin, hematocrit, platelet, mpv, pdw, neutrophil, lymphocyte
def predict_with_explanation(age, weight, height, gravidity, parity, h_abortion,
living_child, gestational_age, hemoglobin, hematocrit,
platelet, mpv, pdw, neutrophil, lymphocyte):
required_fields = [age, weight, height, gravidity, parity, h_abortion,
living_child, gestational_age, hemoglobin, hematocrit,
platelet, mpv, pdw, neutrophil, lymphocyte]
if any(field is None or field == "" for field in required_fields):
return "⚠️ لطفاً تمام فیلدها را پر کنید", "برای پیش‌بینی دقیق، تمام اطلاعات مورد نیاز است.", None
result, detailed_report, shap_plot = predict_outcome(
age, weight, height, gravidity, parity, h_abortion,
living_child, gestational_age, hemoglobin, hematocrit,
platelet, mpv, pdw, neutrophil, lymphocyte
)
return result, detailed_report, shap_plot
def clear_all_fields():
return tuple([None] * 17) + ("", "", None)
def load_example(example_name):
example_data = EXAMPLE_CASES[example_name]
return tuple(example_data[key] for key in [
'age', 'weight', 'height', 'gravidity', 'parity', 'h_abortion',
'living_child', 'gestational_age', 'hemoglobin', 'hematocrit',
'platelet', 'mpv', 'pdw', 'neutrophil', 'lymphocyte'
])
def create_interface():
with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft(), css=RTL_CSS, head=HTML_HEAD) as demo:
gr.Markdown(f"""
# {APP_TITLE}
این سیستم با استفاده از مدل هوش مصنوعی **AdaBoost**، احتمال بروز عوارض در بارداری را پیش‌بینی می‌کند.
**📊 عملکرد مدل:** دقت {MODEL_ACCURACY} | AUC {MODEL_AUC}
🔍 **ویژگی‌های سیستم:**
- پیش‌بینی دقیق با استفاده از هوش مصنوعی
- تحلیل SHAP برای توضیح تأثیر هر ویژگی
- گزارش تفصیلی و قابل فهم برای پزشکان
- نمودار تصویری تأثیر پارامترها
📝 **راهنما:** تمام فیلدها را پر کنید یا از مثال‌های آماده استفاده کنید.
""")
with gr.Row():
patient_inputs = create_patient_info_section()
lab_inputs = create_lab_tests_section()
with gr.Row():
predict_btn = gr.Button("🔍 پیش‌بینی", variant="primary", size="lg")
clear_btn = gr.Button("🗑️ پاک کردن", variant="secondary")
with gr.Row():
with gr.Column(scale=2):
result_text = gr.Textbox(label="نتیجه پیش‌بینی", lines=2)
detailed_report = gr.Markdown(label="گزارش تفصیلی")
with gr.Column(scale=1):
shap_plot = gr.Image(label="نمودار SHAP - تأثیر ویژگی‌ها", type="filepath")
gr.Markdown("---")
gr.Markdown("## 📚 مثال‌های آماده")
with gr.Row():
for example_name in EXAMPLE_CASES.keys():
example_btn = gr.Button(f"📋 {example_name}", variant="secondary")
example_btn.click(
fn=lambda name=example_name: load_example(name),
outputs=list(patient_inputs) + list(lab_inputs)
)
predict_btn.click(
fn=predict_with_explanation,
inputs=list(patient_inputs) + list(lab_inputs),
outputs=[result_text, detailed_report, shap_plot]
)
clear_btn.click(
fn=clear_all_fields,
outputs=list(patient_inputs) + list(lab_inputs) + [result_text, detailed_report, shap_plot]
)
return demo