# Mrhuman1's picture
# Update app.py
# 4f3fb6c verified
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image
from datetime import datetime
# Define HardSwish activation
class HardSwish(nn.Module):
    """Hard-swish activation: ``x * clamp(x + 3, 0, 6) / 6``.

    The module holds no parameters or buffers, so it contributes no
    entries to a model's ``state_dict``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the activation element-wise.

        Args:
            x: Input tensor of any shape.

        Returns:
            A tensor of the same shape as ``x``.
        """
        # The clamped term is a piecewise-linear gate in [0, 1]:
        # 0 for x <= -3, 1 for x >= 3, and linear in between.
        return x * (torch.clamp(x + 3, 0, 6) / 6)
# Define custom EfficientNet model
class CustomEfficientNet(nn.Module):
    """EfficientNet backbone with a small two-layer classification head.

    The backbone is created with ``EfficientNet.from_name`` and its ``_fc``
    layer is replaced by ``Linear -> HardSwish -> Dropout -> Linear``.
    Weights are expected to come from a checkpoint loaded by the caller.

    Args:
        num_classes: Number of outputs produced by the final linear layer.
        model_name: Architecture name passed to ``EfficientNet.from_name``.
        hidden_dim: Width of the intermediate linear layer.
        dropout: Drop probability of the ``nn.Dropout`` layer in the head.
    """

    def __init__(self, num_classes, *, model_name='efficientnet-b3',
                 hidden_dim=512, dropout=0.4):
        super().__init__()
        # The attribute name ``model`` and the order of layers inside the
        # Sequential determine the state_dict keys; keep both stable so that
        # existing checkpoints continue to load.
        self.model = EfficientNet.from_name(model_name)
        num_ftrs = self.model._fc.in_features
        self.model._fc = nn.Sequential(
            nn.Linear(num_ftrs, hidden_dim),
            HardSwish(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and head; returns raw logits.

        No activation is applied to the output here; callers apply their own.
        """
        return self.model(x)
# Class names
# Index i of this list labels output i of the model: `predict` pairs
# class_names[i] with probs[i], so the order must not be changed.
# NOTE(review): the names look like the CheXpert label set — confirm against
# the training configuration.
class_names = [
    'No Finding',                  # 0
    'Enlarged Cardiomediastinum',  # 1
    'Cardiomegaly',                # 2
    'Lung Opacity',                # 3
    'Lung Lesion',                 # 4
    'Edema',                       # 5
    'Consolidation',               # 6
    'Pneumonia',                   # 7
    'Atelectasis',                 # 8
    'Pneumothorax',                # 9
    'Pleural Effusion',            # 10
    'Pleural Other',               # 11
    'Fracture',                    # 12
    'Support Devices',             # 13
]
# Device configuration
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load model
# The head size is derived from the label list so the two cannot drift apart.
model = CustomEfficientNet(num_classes=len(class_names))
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
# Accept both layouts: weights wrapped under a 'state_dict' key, or a bare
# state_dict saved directly.
model.load_state_dict(
    checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
)
model = model.to(device)
# Inference mode: turns off the Dropout layer in the classification head.
model.eval()
# Transformations
# Preprocessing applied to every uploaded image before it reaches the model:
# resize to a fixed 300x300, convert to a tensor, then normalise per channel.
# The mean/std values match the commonly used ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize(size=(300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Prediction function
def _format_scan_date(scan_date):
    """Render the scan date as ``YYYY-MM-DD`` for the value types a date widget may send."""
    if scan_date is None or scan_date == "":
        return "N/A"
    if hasattr(scan_date, "strftime"):
        # datetime / date objects
        return scan_date.strftime('%Y-%m-%d')
    if isinstance(scan_date, (int, float)):
        # Numeric values are treated as POSIX timestamps.
        return datetime.fromtimestamp(scan_date).strftime('%Y-%m-%d')
    # Anything else (e.g. an already formatted string) is shown as-is.
    return str(scan_date)


def predict(patient_name, scan_date, image):
    """Classify a chest X-ray and return a summary plus the five top scores.

    Args:
        patient_name: Text shown verbatim in the summary.
        scan_date: Date of the scan. A ``datetime``/``date``, a numeric
            timestamp, a string, or ``None`` are all accepted.
        image: The uploaded image as a PIL image, or an array that
            ``Image.fromarray`` can convert.

    Returns:
        A ``(summary, top5)`` tuple. ``summary`` is a Markdown string and
        ``top5`` maps the five highest-scoring class names to their float
        probabilities in ``[0, 1]``, ordered from highest to lowest.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        raise ValueError("❌ Error: No image uploaded.")

    # Ensure image is RGB
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # unsqueeze(0) adds the batch dimension the model expects.
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img)
        # Sigmoid is applied per class, so the scores are independent and do
        # not sum to 1.
        probs = torch.sigmoid(outputs).cpu().numpy()[0]

    results = {name: float(p) for name, p in zip(class_names, probs)}
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    # Values stay as floats: gr.Label expects numeric confidences and renders
    # the percentages itself; pre-formatted strings such as "12.34%" are not
    # valid confidences.
    top5 = dict(ranked[:5])

    summary = (
        f"📋 **Patient Name**: {patient_name}\n"
        f"📅 **Scan Date**: {_format_scan_date(scan_date)}\n\n"
        "### Top 5 Predictions"
    )
    return summary, top5
# Gradio UI
# Two-column layout: inputs (name, date, image, button) on the left, the
# Markdown summary and the label widget with the top scores on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # The Markdown text is kept flush-left so the string passed to Gradio is
    # unchanged.
    gr.Markdown(
        """
# 🩺 Chest X-ray Disease Classifier
Upload a chest X-ray and get the top 5 predicted diseases with probability scores.
"""
    )
    with gr.Row():
        with gr.Column():
            patient_name = gr.Textbox(label="Patient Name", placeholder="Enter full name...")
            # Gradio has no `gr.Date` component; `gr.DateTime` with
            # include_time=False is the date-only picker (requires a Gradio
            # release that ships DateTime). type="datetime" makes the callback
            # receive a datetime object, which is what `predict` formats.
            # Passing the callable (not its result) lets Gradio re-evaluate
            # the default on each page load.
            scan_date = gr.DateTime(
                label="Scan Date",
                value=datetime.today,
                include_time=False,
                type="datetime",
            )
            image = gr.Image(label="Chest X-ray Image", type="pil")
            predict_button = gr.Button("🔍 Predict")
        with gr.Column():
            summary = gr.Markdown()
            output = gr.Label(num_top_classes=5)
    predict_button.click(
        predict,
        inputs=[patient_name, scan_date, image],
        outputs=[summary, output],
    )
# Start the Gradio server only when this file is executed as a script, so
# importing the module (e.g. for testing) does not launch the app.
if __name__ == "__main__":
    demo.launch()