import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image
from datetime import date, datetime

# Number of highest-probability findings shown in the summary and the label widget.
TOP_K = 5


# Define HardSwish activation
class HardSwish(nn.Module):
    """Hard-swish activation: ``x * clamp(x + 3, 0, 6) / 6``.

    Has no learnable parameters, so it adds no entries to the model's state dict.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * (torch.clamp(x + 3, 0, 6) / 6)


# Define custom EfficientNet model
class CustomEfficientNet(nn.Module):
    """EfficientNet-B3 backbone whose final ``_fc`` layer is replaced by a small MLP head.

    The head layout (Linear -> HardSwish -> Dropout -> Linear) must stay exactly as is:
    the checkpoint's ``_fc.0.*`` / ``_fc.3.*`` keys depend on these positions.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # from_name builds the architecture only; weights come from the checkpoint below.
        self.model = EfficientNet.from_name('efficientnet-b3')
        num_ftrs = self.model._fc.in_features
        self.model._fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            HardSwish(),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.model(x)


# Class names (order must match the output units of the trained head)
class_names = [
    'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
    'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
    'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices',
]

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model; the head size is derived from class_names so the two cannot drift apart.
model = CustomEfficientNet(num_classes=len(class_names))
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this checkpoint
# stores objects beyond tensors/primitive containers, loading will fail there — confirm.
checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)
# Training checkpoints often wrap the weights as {'state_dict': ...}; bare state dicts are used as-is.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Transformations: resize to 300x300 and normalize with ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _format_scan_date(scan_date):
    """Render the scan date as ``YYYY-MM-DD``, tolerating the value types a UI may send.

    Accepts a ``date``/``datetime``, a POSIX timestamp (int/float), a string (returned
    unchanged), or ``None``/empty (rendered as "Not provided") instead of crashing.
    """
    if scan_date is None or scan_date == "":
        return "Not provided"
    if isinstance(scan_date, date):  # also covers datetime, a subclass of date
        return scan_date.strftime('%Y-%m-%d')
    if isinstance(scan_date, (int, float)):
        return datetime.fromtimestamp(scan_date).strftime('%Y-%m-%d')
    return str(scan_date)


# Prediction function
def predict(patient_name, scan_date, image):
    """Classify a chest X-ray and build the patient summary.

    Args:
        patient_name: free-text name from the textbox; blank is shown as "N/A".
        scan_date: value from the date picker (see ``_format_scan_date``).
        image: PIL image (or array convertible by ``Image.fromarray``).

    Returns:
        (summary_markdown, top_k) where ``top_k`` maps the TOP_K most probable class
        names to float probabilities in [0, 1], the format ``gr.Label`` expects.

    Raises:
        gr.Error: when no image was uploaded (Gradio shows this message to the user).
    """
    if image is None:
        raise gr.Error("❌ Error: No image uploaded.")

    # Ensure image is RGB
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Multi-label task: independent sigmoid per class rather than softmax.
        probs = torch.sigmoid(model(img))[0].cpu().tolist()

    ranked = sorted(zip(class_names, probs), key=lambda item: item[1], reverse=True)
    # Floats (not pre-formatted "xx%" strings): gr.Label renders confidences itself.
    top_k = dict(ranked[:TOP_K])

    name = patient_name.strip() if patient_name and patient_name.strip() else "N/A"
    summary = (
        f"📋 **Patient Name**: {name}\n"
        f"📅 **Scan Date**: {_format_scan_date(scan_date)}\n\n"
        f"### Top {TOP_K} Predictions"
    )
    return summary, top_k


# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🩺 Chest X-ray Disease Classifier\n"
        "Upload a chest X-ray and get the top 5 predicted diseases with probability scores."
    )
    with gr.Row():
        with gr.Column():
            patient_name = gr.Textbox(label="Patient Name", placeholder="Enter full name...")
            # Gradio has no gr.Date; DateTime without a time part is its date picker.
            # Passing the callable re-evaluates "today" on every page load.
            scan_date = gr.DateTime(
                label="Scan Date",
                include_time=False,
                type="datetime",
                value=datetime.today,
            )
            image = gr.Image(label="Chest X-ray Image", type="pil")
            predict_button = gr.Button("🔍 Predict")
        with gr.Column():
            summary = gr.Markdown()
            output = gr.Label(num_top_classes=TOP_K)

    predict_button.click(
        predict,
        inputs=[patient_name, scan_date, image],
        outputs=[summary, output],
    )

if __name__ == "__main__":
    demo.launch()