Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from efficientnet_pytorch import EfficientNet | |
| from PIL import Image | |
| from datetime import datetime | |
# Define HardSwish activation
class HardSwish(nn.Module):
    """Parameter-free hard-swish activation: ``x * clamp(x + 3, 0, 6) / 6``."""

    def forward(self, x):
        """Apply the activation element-wise to the tensor ``x``."""
        # Piecewise-linear gate in [0, 1]: 0 for x <= -3, 1 for x >= 3.
        gate = (x + 3).clamp(0, 6) / 6
        return x * gate
# Define custom EfficientNet model
class CustomEfficientNet(nn.Module):
    """EfficientNet-B3 backbone with its final layer swapped for a custom head.

    The head is Linear(in_features, 512) -> HardSwish -> Dropout(0.4) ->
    Linear(512, num_classes), so the network emits one raw score per class.
    """

    def __init__(self, num_classes):
        """Build the backbone and attach a head with ``num_classes`` outputs."""
        super().__init__()
        backbone = EfficientNet.from_name('efficientnet-b3')
        in_features = backbone._fc.in_features
        # Keep the head an unnamed, positional Sequential in this exact order:
        # the state-dict keys (model._fc.0.*, model._fc.3.*) depend on it.
        head_layers = [
            nn.Linear(in_features, 512),
            HardSwish(),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_classes),
        ]
        backbone._fc = nn.Sequential(*head_layers)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and the replacement head."""
        return self.model(x)
# Class names, in the order of the model's output units (index i of the
# model output is reported under class_names[i]).
class_names = [
    'No Finding',
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
]
# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model
# The head size is derived from class_names so the label list and the model
# output width cannot drift apart (class_names currently has 14 entries).
model = CustomEfficientNet(num_classes=len(class_names))

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
checkpoint = torch.load('Final_global_model.pth.tar', map_location=device)

# The checkpoint is either a wrapper dict holding the weights under
# 'state_dict' or the bare state dict itself.
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
model.load_state_dict(state_dict)

model = model.to(device)
model.eval()  # inference mode: disables dropout
# Transformations
# Target (height, width) every input image is resized to.
_INPUT_SIZE = (300, 300)
# Per-channel normalisation constants (these values match the widely used
# ImageNet channel mean / standard deviation).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Resize -> tensor conversion -> per-channel normalisation, applied in order.
_preprocessing_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# Prediction function
def _format_scan_date(scan_date):
    """Render the scan-date widget value as ``YYYY-MM-DD`` text.

    Accepts whatever the date widget may hand over: ``None`` or an empty
    string (nothing selected), an object with ``strftime`` (date/datetime),
    a numeric POSIX timestamp, or an already formatted string.
    """
    if scan_date is None or scan_date == "":
        return "Not provided"
    if hasattr(scan_date, "strftime"):
        return scan_date.strftime('%Y-%m-%d')
    if isinstance(scan_date, (int, float)):
        return datetime.fromtimestamp(scan_date).strftime('%Y-%m-%d')
    return str(scan_date)


def predict(patient_name, scan_date, image):
    """Classify a chest X-ray and report the five most probable findings.

    Args:
        patient_name: Free-text name echoed back in the summary.
        scan_date: Date of the scan; a date/datetime, timestamp, string, or
            ``None`` (see ``_format_scan_date``).
        image: The uploaded X-ray as a PIL image or an array accepted by
            ``Image.fromarray``.

    Returns:
        A ``(summary, top5)`` tuple: ``summary`` is a Markdown string and
        ``top5`` maps the five highest-scoring class names to their sigmoid
        probabilities as floats in [0, 1], highest first.

    Raises:
        ValueError: If no image was uploaded.
    """
    if image is None:
        raise ValueError("❌ Error: No image uploaded.")

    # Ensure image is an RGB PIL image, which the transform pipeline expects.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Independent per-class sigmoid: scores are not expected to sum to 1.
    probs = torch.sigmoid(logits)[0].cpu().tolist()

    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    # gr.Label expects float confidences, not pre-formatted percentage
    # strings; the component renders the percentages itself.
    top5 = dict(ranked[:5])

    summary = (
        f"👤 **Patient Name**: {patient_name}\n"
        f"📅 **Scan Date**: {_format_scan_date(scan_date)}\n\n"
        "### Top 5 Predictions"
    )
    return summary, top5
# Gradio UI: inputs (name, date, image) on the left, results on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🩺 Chest X-ray Disease Classifier\n"
        "Upload a chest X-ray and get the top 5 predicted diseases with probability scores."
    )
    with gr.Row():
        with gr.Column():
            patient_name = gr.Textbox(label="Patient Name", placeholder="Enter full name...")
            # gr.Date is not a Gradio component; gr.DateTime restricted to a
            # date and returning a datetime object is the built-in equivalent.
            # The default is a callable so it is re-evaluated on page load
            # rather than frozen at start-up.
            scan_date = gr.DateTime(
                label="Scan Date",
                include_time=False,
                type="datetime",
                value=lambda: datetime.today().strftime("%Y-%m-%d"),
            )
            image = gr.Image(label="Chest X-ray Image", type="pil")
            predict_button = gr.Button("🔍 Predict")
        with gr.Column():
            summary = gr.Markdown()
            output = gr.Label(num_top_classes=5)

    # Wire the button: predict(patient_name, scan_date, image) -> (summary, output).
    predict_button.click(
        predict,
        inputs=[patient_name, scan_date, image],
        outputs=[summary, output],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()