Spaces:
Runtime error
Runtime error
File size: 3,528 Bytes
8bb8257 ab120a7 4f3fb6c 588ecb6 4f3fb6c ab120a7 4f3fb6c ab120a7 4f3fb6c ab120a7 588ecb6 ab120a7 4f3fb6c ab120a7 4f3fb6c ab120a7 2c1ab36 4f3fb6c ab120a7 4f3fb6c ab120a7 4f3fb6c ab120a7 4f3fb6c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 | import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image
from datetime import datetime
# Hard-swish activation used inside the classifier head.
class HardSwish(nn.Module):
    """Element-wise hard-swish: ``x * clamp(x + 3, 0, 6) / 6``.

    The module registers no parameters or buffers, so it adds nothing to the
    model's state dict.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Piecewise-linear gate in [0, 1]: 0 for x <= -3, 1 for x >= 3.
        gate = torch.clamp(x + 3, min=0, max=6) / 6
        return x * gate
# EfficientNet-B3 backbone with a custom classification head.
class CustomEfficientNet(nn.Module):
    """EfficientNet-B3 whose final ``_fc`` layer is replaced by a small MLP head.

    The backbone is created with ``EfficientNet.from_name`` (architecture only;
    weights are expected to come from a checkpoint loaded afterwards).
    The head is Linear(in_features, 512) -> HardSwish -> Dropout(0.4) ->
    Linear(512, num_classes) and outputs raw scores (no activation).

    Args:
        num_classes: Number of output units of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = EfficientNet.from_name('efficientnet-b3')
        in_features = backbone._fc.in_features
        # The submodule positions (0..3) become part of the state-dict keys
        # ('model._fc.0.*', 'model._fc.3.*'), so this order must match the
        # checkpoint that is loaded into the model.
        backbone._fc = nn.Sequential(
            nn.Linear(in_features, 512),
            HardSwish(),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        return self.model(x)
# Output labels; position i names the model's i-th output unit, so this order
# must match the order used when the model was trained.
class_names = [
    'No Finding',
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
]
# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model. One output unit per entry of class_names (14).
model = CustomEfficientNet(num_classes=len(class_names))
# NOTE(review): torch.load unpickles the file, which is acceptable only because
# the checkpoint ships with the app. On torch>=2.6 torch.load defaults to
# weights_only=True, which rejects checkpoints holding arbitrary Python
# objects -- confirm this checkpoint loads under the installed torch version.
checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
# The file may hold either a wrapper dict with a 'state_dict' entry or the
# state dict itself.
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
# Weights saved from nn.DataParallel / DistributedDataParallel prefix every key
# with 'module.'; strip it so the keys match this unwrapped model. Keys without
# the prefix are left unchanged.
_PREFIX = 'module.'
state_dict = {
    (key[len(_PREFIX):] if key.startswith(_PREFIX) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model = model.to(device)
# Inference mode: disables dropout in the classifier head.
model.eval()
# Preprocessing applied to every uploaded image:
# resize to 300x300, convert to a float tensor, then normalize each RGB channel
# with the standard ImageNet mean/std (presumably the statistics used during
# training -- confirm against the training pipeline).
transform = transforms.Compose(
    [
        transforms.Resize(size=(300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function
def _format_scan_date(scan_date):
    """Return *scan_date* as display text, tolerating the value types a date input may send."""
    if scan_date is None or scan_date == "":
        return "N/A"
    if hasattr(scan_date, "strftime"):  # datetime / date objects
        return scan_date.strftime('%Y-%m-%d')
    if isinstance(scan_date, (int, float)):  # POSIX timestamp
        return datetime.fromtimestamp(scan_date).strftime('%Y-%m-%d')
    return str(scan_date)


def predict(patient_name, scan_date, image):
    """Classify one chest X-ray and build the summary and top-5 outputs.

    Args:
        patient_name: Free-text name shown in the summary.
        scan_date: Scan date from the date input. A datetime/date, a POSIX
            timestamp, a string or None are all accepted for display.
        image: A PIL image, or an array accepted by ``Image.fromarray``.

    Returns:
        ``(summary, top5)``. *summary* is Markdown text. *top5* maps the five
        most probable class names to their sigmoid probabilities as floats in
        [0, 1], the numeric form ``gr.Label`` requires.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("❌ Error: No image uploaded.")
    # Ensure image is RGB: the normalization step expects three channels.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Independent sigmoid per class (multi-label), for the single batch item.
    probs = torch.sigmoid(logits).cpu().numpy()[0]

    results = {name: float(p) for name, p in zip(class_names, probs)}
    # Keep numeric values: gr.Label formats confidences as percentages itself
    # and rejects pre-formatted strings.
    top5 = dict(sorted(results.items(), key=lambda item: item[1], reverse=True)[:5])

    display_name = (patient_name or "").strip() or "N/A"
    # Blank lines between fields: a single newline is not a line break in Markdown.
    summary = (
        f"📋 **Patient Name**: {display_name}\n\n"
        f"📅 **Scan Date**: {_format_scan_date(scan_date)}\n\n"
        "### Top 5 Predictions"
    )
    return summary, top5
# Gradio UI
def _today_str():
    """Return today's local date as ``YYYY-MM-DD``; called on each page load."""
    return datetime.today().strftime('%Y-%m-%d')


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🩺 Chest X-ray Disease Classifier
        Upload a chest X-ray and get the top 5 predicted diseases with probability scores.
        """
    )
    with gr.Row():
        with gr.Column():
            patient_name = gr.Textbox(label="Patient Name", placeholder="Enter full name...")
            # Gradio has no `gr.Date` component (using it raises AttributeError at
            # startup). gr.DateTime without a time part is the date picker, and
            # type="datetime" hands predict() a datetime object.
            scan_date = gr.DateTime(
                label="Scan Date",
                value=_today_str,
                include_time=False,
                type="datetime",
            )
            image = gr.Image(label="Chest X-ray Image", type="pil")
            predict_button = gr.Button("🔍 Predict")
        with gr.Column():
            summary = gr.Markdown()
            output = gr.Label(num_top_classes=5)
    # Event listeners must be registered inside the Blocks context.
    predict_button.click(
        predict,
        inputs=[patient_name, scan_date, image],
        outputs=[summary, output],
    )

if __name__ == "__main__":
    demo.launch()
|