# Brain-Tumor Net — Gradio app (Hugging Face Space)
# Swin-UNet segmentation + Keras CNN classification of brain MRI scans.
import traceback

import gradio as gr
import keras
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from skimage.transform import resize
# ----------- Constants -----------
# Class labels, in the order emitted by the classifier's output head.
CLASSES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
# Input resolution shared by the segmentation and classification models.
IMG_SIZE = (224, 224)
# Prefer GPU when present; everything below is moved onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ----------- Segmentation Encoder -----------
# Swin-B backbone used purely as a feature-pyramid extractor; weights are
# not downloaded here — they arrive via the swinunet.pth checkpoint below.
swin = timm.create_model(
    'swin_base_patch4_window7_224',
    pretrained=False,
    features_only=True,
)
class UNetDecoder(nn.Module):
    """U-Net-style decoder over four Swin feature maps.

    Expects ``features = [e1, e2, e3, e4]`` with channel widths
    128, 256, 512 and 512 (e4 already reduced from 1024 upstream).
    Returns a single-channel sigmoid mask, upsampled 4x beyond e1's
    spatial resolution.
    """

    def __init__(self):
        super().__init__()

        def conv_block(in_ch, out_ch):
            # Two 3x3 conv + ReLU layers; spatial size is preserved.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        # Each stage: transpose-conv upsample, then fuse with the skip
        # connection of matching resolution.
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(768, 256)   # 256 (upsampled) + 512 (e3)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(384, 128)   # 128 (upsampled) + 256 (e2)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(192, 64)    # 64 (upsampled) + 128 (e1)
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, features):
        e1, e2, e3, e4 = features
        x = self.dec3(torch.cat([self.up3(e4), e3], dim=1))
        x = self.dec2(torch.cat([self.up2(x), e2], dim=1))
        x = self.dec1(torch.cat([self.up1(x), e1], dim=1))
        # Recover the full input resolution (4x the finest skip) before
        # the 1x1 projection and sigmoid.
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)
        return torch.sigmoid(self.final(x))
class SwinUNet(nn.Module):
    """Swin-Transformer encoder + U-Net decoder for binary segmentation."""

    def __init__(self):
        super().__init__()
        self.encoder = swin  # module-level timm backbone (features_only)
        # Swin-B's deepest stage emits 1024 channels; the decoder expects 512.
        self.channel_reducer = nn.Conv2d(1024, 512, kernel_size=1)
        self.decoder = UNetDecoder()

    def forward(self, x):
        # Grayscale input: replicate to the 3 channels the backbone expects.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        feats = [self._to_channels_first(f) for f in self.encoder(x)]
        feats[3] = self.channel_reducer(feats[3])
        return self.decoder(feats)

    def _to_channels_first(self, feature):
        # timm may hand back NHWC maps (4-D) or token sequences (3-D);
        # normalise both to NCHW for the convolutional decoder.
        if feature.dim() == 4:
            return feature.permute(0, 3, 1, 2).contiguous()
        if feature.dim() == 3:
            B, N, C = feature.shape
            side = int(N ** 0.5)  # tokens are assumed to form a square grid
            return feature.permute(0, 2, 1).contiguous().view(B, C, side, side)
        raise ValueError(f"Unexpected feature shape: {feature.shape}")
# ----------- Load Models -----------
# Segmentation network: architecture defined above, weights from disk.
swinunet_model = SwinUNet()
state_dict = torch.load("swinunet.pth", map_location=device)
swinunet_model.load_state_dict(state_dict)
swinunet_model = swinunet_model.to(device)
swinunet_model.eval()  # inference only

# Classifier: Keras model consuming the (image, mask) 2-channel stack.
classifier_model = keras.models.load_model("cnn-swinunet")
# ----------- Preprocessing -----------
# PIL image -> resized float tensor in [0, 1] at the models' input size.
transform = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
])
# ----------- Segmentation ----------- | |
def segmentation(image: Image.Image) -> np.ndarray: | |
# Convert to grayscale and tensor | |
image = image.convert("L") | |
input_tensor = transform(image).unsqueeze(0).to(device) # [1, 1, 224, 224] | |
with torch.no_grad(): | |
mask_pred = swinunet_model(input_tensor) | |
mask_pred = F.interpolate(mask_pred, size=(224, 224), mode="bilinear", align_corners=False) | |
mask_pred = (mask_pred > 0.5).float() | |
image_np = input_tensor.squeeze().cpu().numpy() # [224, 224] | |
mask_np = mask_pred.squeeze().cpu().numpy() # [224, 224] | |
combined = np.stack([image_np, mask_np], axis=-1) # [224, 224, 2] | |
return combined | |
def predict(image: Image.Image):
    """Classify a brain MRI into one of CLASSES.

    Returns a ``{label: probability}`` dict so ``gr.Label`` can render the
    top-4 confidence ranking (the UI declares ``num_top_classes=4``, which a
    bare argmax string cannot populate). On failure, returns the traceback
    string so the error is visible in the UI.
    """
    try:
        combined = segmentation(image)            # (224, 224, 2)
        batch = np.expand_dims(combined, axis=0)  # (1, 224, 224, 2)
        probs = classifier_model.predict(batch)[0]
        # Fix: hand Gradio the full distribution, not just the argmax label,
        # so all four class confidences are displayed.
        return {label: float(p) for label, p in zip(CLASSES, probs)}
    except Exception:
        traceback_str = traceback.format_exc()
        print(traceback_str)  # also surface the failure in server logs
        return traceback_str
# ----------- Gradio UI -----------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Brain MRI"),
    outputs=gr.Label(num_top_classes=4),
    title="Brain-Tumor Net",  # fixed stray ')' typo in the title
    description="Returns: Glioma, Meningioma, No Tumor, Pituitary",
)

# Fix: the guard compared against "main" (never true, since the dunder is
# "__main__"), and an unconditional demo.launch() above it launched the app
# twice when run as a script. One guarded launch is sufficient — Spaces
# runners execute the file as __main__.
if __name__ == "__main__":
    demo.launch()