# brain-tumor-net / app.py
# Author: xTorch8 — "Refactor application" (commit 267c122)
import gradio as gr
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import keras
import traceback
from PIL import Image
from skimage.transform import resize
# ----------- Constants -----------
# Class labels, indexed in the same order as the classifier's output vector.
CLASSES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
# Input resolution (H, W) shared by the segmentation and classification models.
IMG_SIZE = (224, 224)
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ----------- Segmentation Model Definition -----------
# Swin-B backbone used as the U-Net encoder. pretrained=False because the
# trained weights are loaded later from swinunet.pth (as part of SwinUNet);
# features_only=True makes timm return the per-stage feature maps.
swin = timm.create_model('swin_base_patch4_window7_224', pretrained = False, features_only = True)
class UNetDecoder(nn.Module):
    """U-Net style decoder that fuses four encoder stages into a 1-channel sigmoid mask.

    Expects ``features = [e1, e2, e3, e4]`` with channel counts
    (128, 256, 512, 512) where e4 is the deepest (smallest) map.
    """

    def __init__(self):
        super().__init__()

        def conv_block(in_c, out_c):
            # Classic U-Net double convolution: two 3x3 conv + ReLU layers.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        # NOTE: attribute names and layer layout must stay exactly as-is —
        # they define the state_dict keys loaded from swinunet.pth.
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(768, 256)   # 256 upsampled + 512 skip
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(384, 128)   # 128 upsampled + 256 skip
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(192, 64)    # 64 upsampled + 128 skip
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, features):
        skip1, skip2, skip3, bottom = features
        # Upsample, concatenate the matching encoder skip, then double-conv.
        x = self.dec3(torch.cat((self.up3(bottom), skip3), dim=1))
        x = self.dec2(torch.cat((self.up2(x), skip2), dim=1))
        x = self.dec1(torch.cat((self.up1(x), skip1), dim=1))
        # Recover the remaining 4x resolution before the 1x1 prediction head.
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)
        return torch.sigmoid(self.final(x))
class SwinUNet(nn.Module):
    """Swin Transformer encoder + U-Net decoder producing a single-channel mask."""

    def __init__(self):
        super().__init__()
        self.encoder = swin  # shared timm Swin backbone (features_only)
        # The deepest Swin stage emits 1024 channels; squeeze to the 512 the decoder expects.
        self.channel_reducer = nn.Conv2d(1024, 512, kernel_size=1)
        self.decoder = UNetDecoder()

    def forward(self, x):
        # Grayscale input: tile to 3 channels for the RGB backbone.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        feats = [self._to_channels_first(f) for f in self.encoder(x)]
        feats[3] = self.channel_reducer(feats[3])
        return self.decoder(feats)

    def _to_channels_first(self, feature):
        # timm Swin stages may emit NHWC maps (4-D) or token sequences
        # (3-D, N = H*W with a square grid); normalize both to NCHW.
        if feature.dim() == 4:
            return feature.permute(0, 3, 1, 2).contiguous()
        if feature.dim() == 3:
            B, N, C = feature.shape
            side = int(N ** 0.5)
            return feature.permute(0, 2, 1).contiguous().view(B, C, side, side)
        raise ValueError(f"Unexpected feature shape: {feature.shape}")
# ----------- Load Swin-UNet -----------
# Restore the trained segmentation weights; map_location keeps this working
# on CPU-only hosts, and eval() freezes dropout/batch-norm style layers.
swinunet_model = SwinUNet()
swinunet_model.load_state_dict(torch.load("swinunet.pth", map_location = device))
swinunet_model = swinunet_model.to(device)
swinunet_model.eval()
# ----------- Load Classifier Model -----------
# Keras CNN that classifies the 2-channel (image, mask) stack built by segmentation().
classifier_model = keras.models.load_model("cnn-swinunet")
# ----------- Transform -----------
# Resize to the models' input resolution and convert PIL image -> float tensor in [0, 1].
transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()])
# ----------- Segmentation -----------
def segmentation(image: Image.Image) -> np.ndarray:
    """Run Swin-UNet on one MRI image and stack image + binary mask.

    Returns a (224, 224, 2) float array: channel 0 is the normalized
    grayscale image, channel 1 the 0/1 tumor mask.
    """
    grayscale = image.convert("L")
    batch = transform(grayscale).unsqueeze(0).to(device)  # [1, 1, 224, 224]
    with torch.no_grad():
        mask = swinunet_model(batch)
        # Guarantee the mask matches the image resolution, then binarize at 0.5.
        mask = F.interpolate(mask, size=(224, 224), mode="bilinear", align_corners=False)
        mask = (mask > 0.5).float()
    image_np = batch.squeeze().cpu().numpy()  # [224, 224]
    mask_np = mask.squeeze().cpu().numpy()    # [224, 224]
    return np.stack((image_np, mask_np), axis=-1)  # [224, 224, 2]
def predict(image: Image.Image):
    """Classify a brain MRI image into one of CLASSES.

    Returns a {class_name: probability} dict on success — gr.Label needs
    the full distribution to honor num_top_classes=4; returning only the
    argmax name (as before) left the confidence display empty. On failure
    the traceback string is returned so the error surfaces in the UI.
    """
    try:
        combined = segmentation(image)
        batch = np.expand_dims(combined, axis=0)  # shape: (1, 224, 224, 2)
        probs = classifier_model.predict(batch)[0]
        return {cls: float(p) for cls, p in zip(CLASSES, probs)}
    except Exception:
        traceback_str = traceback.format_exc()
        print(traceback_str)
        return traceback_str
# Gradio UI: one image in, a label (top-4 class confidences) out.
demo = gr.Interface(
    fn = predict,
    inputs = gr.Image(type = "pil", label = "Brain MRI"),
    outputs = gr.Label(num_top_classes = 4),
    title = "Brain‑Tumor Net",  # fixed stray ')' typo in the title
    description = "Returns: Glioma, Meningioma, No Tumor, Pituitary"
)

# Launch only when run as a script. The old code launched unconditionally
# AND guarded on `__name__ == "main"` (never true — the correct sentinel is
# "__main__"), which double-launched under `python app.py`.
if __name__ == "__main__":
    demo.launch()