import gradio as gr import numpy as np import tensorflow as tf import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models, transforms import cv2 from tensorflow.keras.models import load_model from PIL import Image import os import pickle from tensorflow.keras import backend as K # I/O image dimensions DISPLAY_DIMS = (256, 256) # For display CLASS_DIMS = (224, 224) # For classification model input SEG_DIMS = (128, 128) # For segmentation model input # Device configuration device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Define Dice Coefficient function for TensorFlow segmentation model def dice_coefficient(y_true, y_pred, smooth=1): y_true_f = K.flatten(tf.cast(y_true, tf.float32)) y_pred_f = K.flatten(tf.cast(y_pred, tf.float32)) intersection = K.sum(y_true_f * y_pred_f) return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) # Define Classification Model (PyTorch) class ClassificationModel(nn.Module): def __init__(self, input_dim): super(ClassificationModel, self).__init__() self.fc1 = nn.Linear(input_dim, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 16) self.fc4 = nn.Linear(16, 2) # Binary Classification self.dropout = nn.Dropout(0.3) def forward(self, x): x = F.relu(self.fc1(x)) x = self.dropout(x) x = F.relu(self.fc2(x)) x = self.dropout(x) x = F.relu(self.fc3(x)) x = self.fc4(x) return x # Load models try: # Load ResNet feature extractor resnet = models.resnet18(pretrained=True) resnet = nn.Sequential(*list(resnet.children())[:-1]) # Remove FC layer resnet.to(device) resnet.eval() # Load Feature Selector with open("feature_selector.pkl", "rb") as f: selector = pickle.load(f) # Load Classification Model input_dim = selector.get_support().sum() # Number of selected features classification_model = ClassificationModel(input_dim).to(device) classification_model.load_state_dict(torch.load("trained_model.pth", map_location=device)) classification_model.eval() # Image transformation for PyTorch model transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Load segmentation model segmentation_model = None if os.path.exists("segmentation_model.h5"): segmentation_model = load_model("segmentation_model.h5", custom_objects={'dice_coefficient': dice_coefficient}, compile=False) print("Loaded segmentation_model.h5") elif os.path.exists("best_model.keras"): segmentation_model = load_model("best_model.keras", custom_objects={'dice_coefficient': dice_coefficient}, compile=False) print("Loaded best_model.keras") models_loaded = True print("Models loaded successfully!") except Exception as e: print(f"Error loading models: {e}") print("The app will run in demo mode with simulated predictions.") models_loaded = False resnet = None selector = None classification_model = None segmentation_model = None transform = None # Function to preprocess image for classification def preprocess_for_classification(image): if not isinstance(image, Image.Image): image = Image.fromarray(np.array(image)) image = image.convert("RGB") # Ensure RGB return transform(image).unsqueeze(0).to(device) # Function to preprocess image for segmentation def preprocess_for_segmentation(image): if isinstance(image, Image.Image): image = np.array(image) # Convert to RGB if needed if len(image.shape) == 2: # Grayscale image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) elif image.shape[2] == 4: # RGBA image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) # Resize to segmentation model's input size image = cv2.resize(image, SEG_DIMS) # Normalize image = image / 255.0 # Add batch dimension image = np.expand_dims(image, axis=0) return image # Function to classify COVID-19 using PyTorch model def classify_image(image): if image is None: return "No image provided", None, 0 try: if models_loaded and resnet is not None and classification_model is not None: # Preprocess and extract features img_tensor = preprocess_for_classification(image) with torch.no_grad(): features = resnet(img_tensor).view(-1).cpu().numpy() # Select features using the feature selector features_selected = selector.transform(features.reshape(1, -1)) input_tensor = torch.tensor(features_selected, dtype=torch.float32).to(device) # Make prediction with torch.no_grad(): output = classification_model(input_tensor) print("Classification output:",output) predicted_class = torch.argmax(output, dim=1).item() print("Classification predicted class:",predicted_class) probabilities = F.softmax(output, dim=1) print("Classification probabilities:",probabilities) confidence = probabilities[0][predicted_class].item() # Map class index to label (0 -> COVID, 1 -> Non-COVID) status = "COVID" if predicted_class == 0 else "Non-COVID" return f"Predicted: {status} (Class: {predicted_class}, Confidence: {confidence:.2f})", image, predicted_class else: # Demo mode with simulated predictions import random predicted_class = random.randint(0, 1) # 0 or 1 confidence = random.uniform(0.7, 0.99) status = "COVID" if predicted_class == 0 else "Non-COVID" return f"Predicted: {status} (Class: {predicted_class}, Confidence: {confidence:.2f}) [DEMO]", image, predicted_class except Exception as e: return f"Error during classification: {str(e)}", image, 0 # Function to segment lesions in CT images def segment_image(image): if image is None: return "No segmentation performed", None, None try: if models_loaded and segmentation_model is not None: # Preprocess for segmentation input_image = preprocess_for_segmentation(image) # Predict mask pred_mask = segmentation_model.predict(input_image) binary_mask = (pred_mask > 0.5).astype(np.uint8) # Create colored overlay if isinstance(image, Image.Image): display_image = np.array(image) else: display_image = np.array(image) # Resize original image for display display_image = cv2.resize(display_image, DISPLAY_DIMS) # Resize predicted mask to match display image display_mask = cv2.resize(binary_mask[0].squeeze(), DISPLAY_DIMS) # Create overlay overlay = display_image.copy() if len(overlay.shape) == 2: # If grayscale overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2RGB) elif overlay.shape[2] == 4: # If RGBA overlay = cv2.cvtColor(overlay, cv2.COLOR_RGBA2RGB) # Apply red mask on segmented areas overlay[:, :, 0] = np.maximum(overlay[:, :, 0], display_mask * 255) # Red channel overlay[:, :, 1] = np.where(display_mask > 0, overlay[:, :, 1] * 0.5, overlay[:, :, 1]) # Reduce green overlay[:, :, 2] = np.where(display_mask > 0, overlay[:, :, 2] * 0.5, overlay[:, :, 2]) # Reduce blue # Calculate lesion percentage lesion_percentage = np.sum(binary_mask) / binary_mask.size * 100 # Enhance the segmentation mask for visibility # Convert to 3-channel image with a heatmap colormap enhanced_mask = cv2.normalize(display_mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) enhanced_mask = cv2.applyColorMap(enhanced_mask, cv2.COLORMAP_JET) # Apply color map for visibility # return f"Lesion Coverage: {lesion_percentage:.2f}%", enhanced_mask, overlay return enhanced_mask, overlay else: # Demo mode with simulated segmentation return simulate_segmentation(image) except Exception as e: return f"Error during segmentation: {str(e)}", None, image # Function to simulate segmentation for demo mode def simulate_segmentation(image): # For demo mode, create a simulated segmentation import random if isinstance(image, Image.Image): display_image = np.array(image) else: display_image = np.array(image) if len(display_image.shape) == 2: display_image = cv2.cvtColor(display_image, cv2.COLOR_GRAY2RGB) elif display_image.shape[2] == 4: display_image = cv2.cvtColor(display_image, cv2.COLOR_RGBA2RGB) display_image = cv2.resize(display_image, DISPLAY_DIMS) # Create a blank mask mask = np.zeros(DISPLAY_DIMS, dtype=np.uint8) # Simulate random blobs num_blobs = random.randint(1, 3) for i in range(num_blobs): center_x = random.randint(50, DISPLAY_DIMS[0]-50) center_y = random.randint(50, DISPLAY_DIMS[1]-50) radius = random.randint(10, 30) cv2.circle(mask, (center_x, center_y), radius, 1, -1) # Create colored overlay overlay = display_image.copy() # Apply red mask on segmented areas overlay[:, :, 0] = np.maximum(overlay[:, :, 0], mask * 255) # Red channel overlay[:, :, 1] = np.where(mask > 0, overlay[:, :, 1] * 0.5, overlay[:, :, 1]) # Reduce green overlay[:, :, 2] = np.where(mask > 0, overlay[:, :, 2] * 0.5, overlay[:, :, 2]) # Reduce blue # Enhance the mask for visibility enhanced_mask = cv2.normalize(mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) enhanced_mask = cv2.applyColorMap(enhanced_mask, cv2.COLORMAP_JET) # Apply color map for visibility lesion_percentage = np.sum(mask) / mask.size * 100 # return f"Lesion Coverage: {lesion_percentage:.2f}% [DEMO]", enhanced_mask, overlay return enhanced_mask, overlay # Function to run both classification and segmentation def process_image(image): if image is None: return None, "No image provided", None, "No image provided" # Run classification classification_result, processed_image, predicted_class = classify_image(image) # Run segmentation (now for all images regardless of class) # segmentation_result, segmentation_map, overlay_image = segment_image(image) segmentation_map, overlay_image = segment_image(image) # Combine results # combined_result = f"{classification_result}\n{segmentation_result}" # return overlay_image, combined_result, segmentation_map, classification_result return overlay_image, classification_result, segmentation_map, classification_result # Load example images def load_covid_examples(): examples = [] try: # Look for COVID example images for i in range(1, 6): covid_path = f"./examples/Covid ({i}).png" if os.path.exists(covid_path): examples.append([covid_path]) # If no COVID examples were found, create placeholders if len(examples) == 0: for i in range(1, 6): covid_img = np.ones((256, 256, 3), dtype=np.uint8) * 200 cv2.putText(covid_img, f"COVID Example {i}", (30, 128), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (100, 100, 100), 2) examples.append([covid_img]) except Exception as e: print(f"Could not load COVID examples: {e}") return examples def load_non_covid_examples(): examples = [] try: # Look for Non-COVID example images for i in range(1, 6): non_covid_path = f"./examples/Non-Covid ({i}).png" if os.path.exists(non_covid_path): examples.append([non_covid_path]) # If no Non-COVID examples were found, create placeholders if len(examples) == 0: for i in range(1, 6): non_covid_img = np.ones((256, 256, 3), dtype=np.uint8) * 200 cv2.putText(non_covid_img, f"Non-COVID Example {i}", (30, 128), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (100, 100, 100), 2) examples.append([non_covid_img]) except Exception as e: print(f"Could not load Non-COVID examples: {e}") return examples class GradioInterface: def __init__(self): self.covid_examples = load_covid_examples() self.non_covid_examples = load_non_covid_examples() def create_interface(self): app_styles = """ """ header_html = f""" {app_styles}
Upload CT scan images to detect COVID-19 and segment lesions if present. The system uses ResNet-18 for feature extraction and a U-Net for lesion segmentation.