import gradio as gr
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
import cv2
from tensorflow.keras.models import load_model
from PIL import Image
import os
import pickle
from tensorflow.keras import backend as K

# I/O image dimensions
DISPLAY_DIMS = (256, 256)  # For display
CLASS_DIMS = (224, 224)    # For classification model input
SEG_DIMS = (128, 128)      # For segmentation model input

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define Dice coefficient function for the TensorFlow segmentation model
def dice_coefficient(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(tf.cast(y_true, tf.float32))
    y_pred_f = K.flatten(tf.cast(y_pred, tf.float32))
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
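# Note (explanatory comment, not part of the original training code): this follows
# the usual Dice definition 2*|A ∩ B| / (|A| + |B|), where A is the predicted mask
# and B the ground truth; the `smooth` term guards against division by zero on
# empty masks. It is only needed here as a custom object when deserializing the
# Keras model, which presumably referenced it as a metric or loss during training.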
# Define classification model (PyTorch)
class ClassificationModel(nn.Module):
    def __init__(self, input_dim):
        super(ClassificationModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 16)
        self.fc4 = nn.Linear(16, 2)  # Binary classification
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x
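# Note (assumption, added for clarity): input_dim is determined at load time from
# the fitted feature selector (selector.get_support().sum()), i.e. the number of
# features kept out of the 512-dimensional pooled ResNet-18 embedding. The network
# outputs raw logits for two classes; softmax is applied only at inference time.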
# Load models
try:
    # Load ResNet feature extractor
    resnet = models.resnet18(pretrained=True)
    resnet = nn.Sequential(*list(resnet.children())[:-1])  # Remove the final FC layer
    resnet.to(device)
    resnet.eval()

    # Load feature selector
    with open("feature_selector.pkl", "rb") as f:
        selector = pickle.load(f)

    # Load classification model
    input_dim = selector.get_support().sum()  # Number of selected features
    classification_model = ClassificationModel(input_dim).to(device)
    classification_model.load_state_dict(torch.load("trained_model.pth", map_location=device))
    classification_model.eval()

    # Image transformation for the PyTorch model
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
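    # Note (standard practice, stated here for clarity): the Resize/ToTensor/Normalize
    # values above are the usual ImageNet mean and std used with torchvision's
    # pretrained ResNet-18, so the extracted 512-dimensional features match the
    # statistics the backbone was trained with.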
    # Load segmentation model
    segmentation_model = None
    if os.path.exists("segmentation_model.h5"):
        segmentation_model = load_model("segmentation_model.h5",
                                        custom_objects={'dice_coefficient': dice_coefficient},
                                        compile=False)
        print("Loaded segmentation_model.h5")
    elif os.path.exists("best_model.keras"):
        segmentation_model = load_model("best_model.keras",
                                        custom_objects={'dice_coefficient': dice_coefficient},
                                        compile=False)
        print("Loaded best_model.keras")

    models_loaded = True
    print("Models loaded successfully!")
except Exception as e:
    print(f"Error loading models: {e}")
    print("The app will run in demo mode with simulated predictions.")
    models_loaded = False
    resnet = None
    selector = None
    classification_model = None
    segmentation_model = None
    transform = None
# Preprocess an image for classification
def preprocess_for_classification(image):
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.array(image))
    image = image.convert("RGB")  # Ensure RGB
    return transform(image).unsqueeze(0).to(device)
# Preprocess an image for segmentation
def preprocess_for_segmentation(image):
    if isinstance(image, Image.Image):
        image = np.array(image)
    # Convert to RGB if needed
    if len(image.shape) == 2:  # Grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:  # RGBA
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # Resize to the segmentation model's input size
    image = cv2.resize(image, SEG_DIMS)
    # Normalize to [0, 1]
    image = image / 255.0
    # Add batch dimension
    image = np.expand_dims(image, axis=0)
    return image
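# Note (assumption about the segmentation model's input contract): the array
# returned above has shape (1, 128, 128, 3) with float values in [0, 1], which is
# presumed to match what the Keras U-Net was trained on; if the saved model expects
# grayscale input or a different size, adjust SEG_DIMS and the channel handling.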
# Classify COVID-19 using the PyTorch model
def classify_image(image):
    if image is None:
        return "No image provided", None, 0
    try:
        if models_loaded and resnet is not None and classification_model is not None:
            # Preprocess and extract features
            img_tensor = preprocess_for_classification(image)
            with torch.no_grad():
                features = resnet(img_tensor).view(-1).cpu().numpy()
            # Select features using the fitted feature selector
            features_selected = selector.transform(features.reshape(1, -1))
            input_tensor = torch.tensor(features_selected, dtype=torch.float32).to(device)
            # Make prediction
            with torch.no_grad():
                output = classification_model(input_tensor)
                print("Classification output:", output)
                predicted_class = torch.argmax(output, dim=1).item()
                print("Classification predicted class:", predicted_class)
                probabilities = F.softmax(output, dim=1)
                print("Classification probabilities:", probabilities)
                confidence = probabilities[0][predicted_class].item()
            # Map class index to label (0 -> COVID, 1 -> Non-COVID)
            status = "COVID" if predicted_class == 0 else "Non-COVID"
            return f"Predicted: {status} (Class: {predicted_class}, Confidence: {confidence:.2f})", image, predicted_class
        else:
            # Demo mode with simulated predictions
            import random
            predicted_class = random.randint(0, 1)  # 0 or 1
            confidence = random.uniform(0.7, 0.99)
            status = "COVID" if predicted_class == 0 else "Non-COVID"
            return f"Predicted: {status} (Class: {predicted_class}, Confidence: {confidence:.2f}) [DEMO]", image, predicted_class
    except Exception as e:
        return f"Error during classification: {str(e)}", image, 0
# Segment lesions in CT images
def segment_image(image):
    if image is None:
        # No input: nothing to segment, no overlay
        return None, None
    try:
        if models_loaded and segmentation_model is not None:
            # Preprocess for segmentation
            input_image = preprocess_for_segmentation(image)
            # Predict mask
            pred_mask = segmentation_model.predict(input_image)
            binary_mask = (pred_mask > 0.5).astype(np.uint8)
            # Prepare the original image for display
            display_image = np.array(image)
            display_image = cv2.resize(display_image, DISPLAY_DIMS)
            # Resize predicted mask to match the display image
            display_mask = cv2.resize(binary_mask[0].squeeze(), DISPLAY_DIMS)
            # Create overlay
            overlay = display_image.copy()
            if len(overlay.shape) == 2:  # Grayscale
                overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2RGB)
            elif overlay.shape[2] == 4:  # RGBA
                overlay = cv2.cvtColor(overlay, cv2.COLOR_RGBA2RGB)
            # Apply a red tint to segmented areas
            overlay[:, :, 0] = np.maximum(overlay[:, :, 0], display_mask * 255)  # Boost red channel
            overlay[:, :, 1] = np.where(display_mask > 0, overlay[:, :, 1] * 0.5, overlay[:, :, 1])  # Reduce green
            overlay[:, :, 2] = np.where(display_mask > 0, overlay[:, :, 2] * 0.5, overlay[:, :, 2])  # Reduce blue
            # Calculate lesion percentage (not currently shown in the UI)
            lesion_percentage = np.sum(binary_mask) / binary_mask.size * 100
            # Enhance the segmentation mask for visibility:
            # convert to a 3-channel image with a heatmap colormap
            enhanced_mask = cv2.normalize(display_mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            enhanced_mask = cv2.applyColorMap(enhanced_mask, cv2.COLORMAP_JET)
            # return f"Lesion Coverage: {lesion_percentage:.2f}%", enhanced_mask, overlay
            return enhanced_mask, overlay
        else:
            # Demo mode with simulated segmentation
            return simulate_segmentation(image)
    except Exception as e:
        # Keep the return arity consistent with the success path (mask, overlay)
        print(f"Error during segmentation: {e}")
        return None, image
# Simulate segmentation for demo mode
def simulate_segmentation(image):
    import random
    display_image = np.array(image)
    if len(display_image.shape) == 2:
        display_image = cv2.cvtColor(display_image, cv2.COLOR_GRAY2RGB)
    elif display_image.shape[2] == 4:
        display_image = cv2.cvtColor(display_image, cv2.COLOR_RGBA2RGB)
    display_image = cv2.resize(display_image, DISPLAY_DIMS)
    # Create a blank mask
    mask = np.zeros(DISPLAY_DIMS, dtype=np.uint8)
    # Simulate random blobs
    num_blobs = random.randint(1, 3)
    for i in range(num_blobs):
        center_x = random.randint(50, DISPLAY_DIMS[0] - 50)
        center_y = random.randint(50, DISPLAY_DIMS[1] - 50)
        radius = random.randint(10, 30)
        cv2.circle(mask, (center_x, center_y), radius, 1, -1)
    # Create colored overlay with a red tint on the simulated lesions
    overlay = display_image.copy()
    overlay[:, :, 0] = np.maximum(overlay[:, :, 0], mask * 255)  # Boost red channel
    overlay[:, :, 1] = np.where(mask > 0, overlay[:, :, 1] * 0.5, overlay[:, :, 1])  # Reduce green
    overlay[:, :, 2] = np.where(mask > 0, overlay[:, :, 2] * 0.5, overlay[:, :, 2])  # Reduce blue
    # Enhance the mask for visibility
    enhanced_mask = cv2.normalize(mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    enhanced_mask = cv2.applyColorMap(enhanced_mask, cv2.COLORMAP_JET)
    lesion_percentage = np.sum(mask) / mask.size * 100
    # return f"Lesion Coverage: {lesion_percentage:.2f}% [DEMO]", enhanced_mask, overlay
    return enhanced_mask, overlay
# Run both classification and segmentation
def process_image(image):
    if image is None:
        return None, "No image provided", None, "No image provided"
    # Run classification
    classification_result, processed_image, predicted_class = classify_image(image)
    # Run segmentation (now for all images, regardless of predicted class)
    # segmentation_result, segmentation_map, overlay_image = segment_image(image)
    segmentation_map, overlay_image = segment_image(image)
    # Combine results
    # combined_result = f"{classification_result}\n{segmentation_result}"
    # return overlay_image, combined_result, segmentation_map, classification_result
    return overlay_image, classification_result, segmentation_map, classification_result
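# Note (explanatory, an assumption about the UI wiring below): the four return
# values map, in order, to the Gradio outputs [overlay_image, result_text,
# segmentation_image, classification_text]; the classification string is shown
# twice because the lesion-coverage summary return was commented out above.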
# Load example images
def load_covid_examples():
    examples = []
    try:
        # Look for COVID example images
        for i in range(1, 6):
            covid_path = f"./examples/Covid ({i}).png"
            if os.path.exists(covid_path):
                examples.append([covid_path])
        # If no COVID examples were found, create placeholders
        if len(examples) == 0:
            for i in range(1, 6):
                covid_img = np.ones((256, 256, 3), dtype=np.uint8) * 200
                cv2.putText(covid_img, f"COVID Example {i}", (30, 128),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (100, 100, 100), 2)
                examples.append([covid_img])
    except Exception as e:
        print(f"Could not load COVID examples: {e}")
    return examples
def load_non_covid_examples():
    examples = []
    try:
        # Look for Non-COVID example images
        for i in range(1, 6):
            non_covid_path = f"./examples/Non-Covid ({i}).png"
            if os.path.exists(non_covid_path):
                examples.append([non_covid_path])
        # If no Non-COVID examples were found, create placeholders
        if len(examples) == 0:
            for i in range(1, 6):
                non_covid_img = np.ones((256, 256, 3), dtype=np.uint8) * 200
                cv2.putText(non_covid_img, f"Non-COVID Example {i}", (30, 128),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (100, 100, 100), 2)
                examples.append([non_covid_img])
    except Exception as e:
        print(f"Could not load Non-COVID examples: {e}")
    return examples
class GradioInterface:
    def __init__(self):
        self.covid_examples = load_covid_examples()
        self.non_covid_examples = load_non_covid_examples()

    def create_interface(self):
        app_styles = """
        <style>
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }
            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }
            .app-title {
                font-size: 48px;
                margin: 0;
                color: #fafafa;
            }
            .app-subtitle {
                font-size: 24px;
                margin: 8px 0 16px;
                color: #fafafa;
            }
            .app-description {
                font-size: 16px;
                line-height: 1.6;
                opacity: 0.8;
                margin-bottom: 24px;
            }
            /* Button Styles */
            .publication-links {
                display: flex;
                justify-content: center;
                flex-wrap: wrap;
                gap: 8px;
                margin-bottom: 16px;
            }
            .publication-link {
                display: inline-flex;
                align-items: center;
                padding: 8px 16px;
                background-color: #333;
                color: #fff !important;
                text-decoration: none !important;
                border-radius: 20px;
                font-size: 14px;
                transition: background-color 0.3s;
            }
            .publication-link:hover {
                background-color: #555;
            }
            .publication-link i {
                margin-right: 8px;
            }
            /* Content Styles */
            .content-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 24px;
                margin-bottom: 24px;
            }
            /* Image Styles */
            .image-preview img {
                max-width: 256px;
                max-height: 256px;
                margin: 0 auto;
                border-radius: 4px;
                display: block;
                object-fit: contain;
            }
            /* Control Styles */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }
            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                cursor: pointer;
                transition: background-color 0.3s;
            }
            .gr-button:hover {
                background-color: #5a5a5a;
            }
            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }
            .gr-form {
                background-color: transparent;
            }
            .gr-panel {
                border: none;
                background-color: transparent;
            }
        </style>
        """
header_html = f""" | |
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css"> | |
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css"> | |
{app_styles} | |
<div class="app-header"> | |
<h1 class="app-title">COVID-19 CT Analysis System</h1> | |
<h2 class="app-subtitle">Classification & Lesion Segmentation</h2> | |
<p class="app-description"> | |
Upload CT scan images to detect COVID-19 and segment lesions if present. | |
The system uses ResNet-18 for feature extraction and a U-Net for lesion segmentation. | |
</p> | |
</div> | |
""" | |
js_func = """ | |
function refresh() { | |
const url = new URL(window.location); | |
if (url.searchParams.get('__theme') !== 'dark') { | |
url.searchParams.set('__theme', 'dark'); | |
window.location.href = url.href; | |
} | |
} | |
""" | |
        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Upload CT Scan Image", type="pil", image_mode="RGB", elem_classes="image-preview")
                    run_button = gr.Button("Analyze Image", elem_classes="gr-button")
                    with gr.Row():
                        with gr.Column(scale=1):
                            covid_examples_title = gr.Markdown("### COVID Examples")
                            covid_examples = gr.Examples(
                                examples=self.covid_examples,
                                inputs=input_image,
                                label=""
                            )
                        with gr.Column(scale=1):
                            non_covid_examples_title = gr.Markdown("### Non-COVID Examples")
                            non_covid_examples = gr.Examples(
                                examples=self.non_covid_examples,
                                inputs=input_image,
                                label=""
                            )
                with gr.Column():
                    with gr.Tab("Results"):
                        overlay_image = gr.Image(label="Segmentation Overlay", elem_classes="image-preview")
                        result_text = gr.Textbox(label="Analysis Results")
                    with gr.Tab("Segmentation Details"):
                        segmentation_image = gr.Image(label="Lesion Segmentation Map", elem_classes="image-preview")
                        classification_text = gr.Textbox(label="Classification Details")
            run_button.click(
                fn=process_image,
                inputs=input_image,
                outputs=[overlay_image, result_text, segmentation_image, classification_text],
            )
        return demo
def main():
    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=True)


if __name__ == "__main__":
    main()