File size: 3,528 Bytes
8bb8257
ab120a7
 
 
 
 
4f3fb6c
588ecb6
4f3fb6c
ab120a7
 
 
 
 
 
 
4f3fb6c
ab120a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f3fb6c
 
 
 
 
ab120a7
588ecb6
ab120a7
 
4f3fb6c
ab120a7
 
 
 
 
 
 
 
 
4f3fb6c
ab120a7
 
 
 
 
 
2c1ab36
4f3fb6c
ab120a7
 
 
4f3fb6c
ab120a7
 
 
 
 
 
 
 
 
 
 
 
 
 
4f3fb6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab120a7
 
4f3fb6c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image
from datetime import datetime

# Define HardSwish activation
class HardSwish(nn.Module):
    """Hard-swish activation, applied elementwise: ``x * clamp(x + 3, 0, 6) / 6``.

    Parameter-free, so it contributes no entries to the module's state dict.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Piecewise-linear gate: 0 for x <= -3, 1 for x >= 3, linear in between.
        gate = (x + 3).clamp(min=0, max=6) / 6
        return x * gate

# Define custom EfficientNet model
class CustomEfficientNet(nn.Module):
    """EfficientNet-B3 backbone whose final ``_fc`` layer is replaced by a small MLP head.

    The head is Linear(in_features, 512) -> HardSwish -> Dropout(0.4) ->
    Linear(512, num_classes) and emits one logit per class. The backbone is
    built with ``EfficientNet.from_name``; there is no ``from_pretrained``
    call here, so weights are expected to come from a checkpoint loaded later.

    Args:
        num_classes: number of output logits produced by the head.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = EfficientNet.from_name('efficientnet-b3')
        in_features = backbone._fc.in_features
        # Attribute names (`model`, `_fc`) and the Sequential layer order fix the
        # state-dict keys (e.g. model._fc.0.weight, model._fc.3.weight) that a
        # strict load_state_dict must match, so keep them exactly as-is.
        backbone._fc = nn.Sequential(
            nn.Linear(in_features, 512),
            HardSwish(),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        """Run the backbone (with its replaced head) and return the raw logits."""
        logits = self.model(x)
        return logits

# Class names.
# Order matters: predict() pairs model output i with class_names[i].
class_names = [
    'No Finding',
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
]

# Device configuration: prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model.
# Size the head from class_names so the label list and the output layer
# cannot drift apart (predict() indexes one with the other).
model = CustomEfficientNet(num_classes=len(class_names))
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
# Accept either a training checkpoint with a 'state_dict' entry or a bare
# state dict.
model.load_state_dict(checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)
# load_state_dict copies the tensors into the model. Drop the checkpoint, which
# was mapped onto `device` and may also carry other training state, so that a
# duplicate of the weights is not kept alive for the whole app lifetime.
del checkpoint
model = model.to(device)
model.eval()  # inference mode: disables dropout in the head

# Transformations.
# Preprocessing applied to every uploaded image before inference: resize to a
# fixed 300x300 (aspect ratio is not preserved), convert the PIL image to a
# tensor, then normalise each RGB channel with the widely used ImageNet
# channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Scan-date formatting helper
def _format_scan_date(scan_date):
    """Return *scan_date* rendered as ``YYYY-MM-DD``, or ``"N/A"`` when it is empty.

    Depending on the Gradio version and the date component's configuration,
    the widget may deliver a datetime/date, a numeric timestamp, a string, or
    ``None``. Each case is handled instead of assuming ``.strftime`` exists.
    """
    if scan_date is None:
        return "N/A"
    if isinstance(scan_date, (int, float)):
        # Presumably seconds since the Unix epoch; confirm against the widget config.
        return datetime.fromtimestamp(scan_date).strftime('%Y-%m-%d')
    if isinstance(scan_date, str):
        text = scan_date.strip()
        if not text:
            return "N/A"
        try:
            return datetime.fromisoformat(text).strftime('%Y-%m-%d')
        except ValueError:
            return text  # not ISO-formatted: show what was entered verbatim
    return scan_date.strftime('%Y-%m-%d')


# Prediction function
def predict(patient_name, scan_date, image):
    """Classify a chest X-ray and build the UI summary plus the top-5 scores.

    Args:
        patient_name: free-text name from the textbox; shown as "N/A" if blank.
        scan_date: value from the scan-date widget (datetime/date, numeric
            timestamp, date string, or None).
        image: the uploaded X-ray as a PIL image or a NumPy array.

    Returns:
        A ``(summary, top5)`` tuple. ``summary`` is a Markdown string with the
        patient name and scan date. ``top5`` maps the five highest-scoring
        class names to their sigmoid probability, a float in [0, 1], which is
        the numeric form ``gr.Label`` expects.

    Raises:
        ValueError: if no image was uploaded.
    """
    if image is None:
        raise ValueError("❌ Error: No image uploaded.")

    # Accept NumPy arrays as well as PIL images, and force 3 channels so the
    # input matches the 3-channel Normalize step in `transform`.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    img = transform(image).unsqueeze(0).to(device)  # add a batch dimension of 1

    with torch.no_grad():
        # Independent sigmoid per class (multi-label scoring), not a softmax.
        probs = torch.sigmoid(model(img)).cpu().numpy()[0]

    # Rank classes by probability, highest first. Values stay as floats:
    # gr.Label renders confidences as percentages itself, and pre-formatted
    # "12.34%" strings break its validation and sorting.
    ranked = sorted(zip(class_names, probs.tolist()), key=lambda pair: pair[1], reverse=True)
    top5 = dict(ranked[:5])

    name = (patient_name or "").strip() or "N/A"
    summary = (
        f"📋 **Patient Name**: {name}\n"
        f"📅 **Scan Date**: {_format_scan_date(scan_date)}\n\n"
        "### Top 5 Predictions"
    )
    return summary, top5

# Gradio UI: inputs in the left column, prediction outputs in the right column.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🩺 Chest X-ray Disease Classifier
        Upload a chest X-ray and get the top 5 predicted diseases with probability scores.
        """
    )
    with gr.Row():
        with gr.Column():
            patient_name = gr.Textbox(label="Patient Name", placeholder="Enter full name...")
            # Gradio has no `gr.Date` component, so the original call raised
            # AttributeError while the UI was being built. gr.DateTime is the
            # date picker (it needs a Gradio release that ships it).
            # type="datetime" passes predict() a datetime object, so
            # `.strftime` works. include_time=False restricts input to a date.
            # The callable default is evaluated on each page load, so "today"
            # stays current.
            scan_date = gr.DateTime(
                label="Scan Date",
                value=datetime.today,
                include_time=False,
                type="datetime",
            )
            image = gr.Image(label="Chest X-ray Image", type="pil")
            predict_button = gr.Button("🔍 Predict")
        with gr.Column():
            summary = gr.Markdown()
            output = gr.Label(num_top_classes=5)

    # Wire the button: predict(name, date, image) -> (markdown summary, label scores).
    predict_button.click(
        predict,
        inputs=[patient_name, scan_date, image],
        outputs=[summary, output]
    )

if __name__ == "__main__":
    demo.launch()