"""Gradio demo: segment a brain MRI with a Swin-UNet, then classify it.

The grayscale image and the predicted binary mask are stacked into a
two-channel array that is passed to a Keras classifier, which picks one of
``CLASSES``.
"""
import math
import traceback

import gradio as gr
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import keras
from PIL import Image
from skimage.transform import resize

# ----------- Constants -----------
CLASSES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
IMG_SIZE = (224, 224)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------- Segmentation Model Definition -----------
swin = timm.create_model(
    "swin_base_patch4_window7_224", pretrained=False, features_only=True
)


def _conv_block(in_c, out_c):
    """Return two 3x3 conv + ReLU layers mapping ``in_c`` to ``out_c`` channels."""
    # The Sequential layout (conv, relu, conv, relu) fixes the parameter
    # names in the state dict, so it must stay in this exact order.
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class UNetDecoder(nn.Module):
    """U-Net style decoder over four encoder feature maps.

    Each stage upsamples by 2 with a transposed convolution, concatenates the
    matching skip feature along the channel axis and refines the result with
    a double-conv block. The last stage is upsampled by 4 and projected to a
    single sigmoid channel.
    """

    def __init__(self):
        super().__init__()
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = _conv_block(768, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = _conv_block(384, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = _conv_block(192, 64)
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, features):
        """Decode ``features`` (four channels-first maps) into a sigmoid mask.

        ``features`` must unpack into exactly four tensors, ordered from the
        highest-resolution skip feature to the (channel-reduced) bottleneck.
        """
        e1, e2, e3, e4 = features  # e4 is already reduced to 512 channels
        d3 = self.up3(e4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))  # 256 + 512 = 768
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))  # 128 + 256 = 384
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))  # 64 + 128 = 192
        out = F.interpolate(
            d1, scale_factor=4, mode="bilinear", align_corners=False
        )
        return torch.sigmoid(self.final(out))


class SwinUNet(nn.Module):
    """Segmentation network: shared ``swin`` encoder plus ``UNetDecoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = swin
        # 1x1 conv bringing the deepest feature map from 1024 to 512 channels,
        # which is what the decoder's first up-convolution expects.
        self.channel_reducer = nn.Conv2d(1024, 512, kernel_size=1)
        self.decoder = UNetDecoder()

    def forward(self, x):
        """Return the decoder output for a batch ``x``.

        Single-channel input is repeated to three channels before it is fed
        to the encoder.
        """
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        features = self.encoder(x)
        features = [self._to_channels_first(f) for f in features]
        features[3] = self.channel_reducer(features[3])
        output = self.decoder(features)
        return output

    def _to_channels_first(self, feature):
        """Rearrange one encoder feature so channels sit on axis 1.

        A 4-D input has its last axis moved to position 1. A 3-D ``(B, N, C)``
        input is reshaped to ``(B, C, H, W)`` with ``H == W``.

        Raises:
            ValueError: if the tensor is neither 3-D nor 4-D, or if a 3-D
                tensor's token count ``N`` is not a perfect square.
        """
        if feature.dim() == 4:
            # NOTE(review): assumes the encoder emits channels-last maps.
            return feature.permute(0, 3, 1, 2).contiguous()
        if feature.dim() == 3:
            B, N, C = feature.shape
            # Integer square root avoids float rounding and lets us reject a
            # non-square token grid up front instead of failing inside view().
            side = math.isqrt(N)
            if side * side != N:
                raise ValueError(
                    f"Cannot reshape {N} tokens into a square grid: "
                    f"{feature.shape}"
                )
            feature = feature.permute(0, 2, 1).contiguous()
            return feature.view(B, C, side, side)
        raise ValueError(f"Unexpected feature shape: {feature.shape}")


# ----------- Load Swin-UNet -----------
swinunet_model = SwinUNet()
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files
# (consider weights_only=True if the installed torch version supports it).
swinunet_model.load_state_dict(torch.load("swinunet.pth", map_location=device))
swinunet_model = swinunet_model.to(device)
swinunet_model.eval()

# ----------- Load Classifier Model -----------
classifier_model = keras.models.load_model("cnn-swinunet")

# ----------- Transform -----------
transform = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
])


# ----------- Segmentation -----------
def segmentation(image: Image.Image) -> np.ndarray:
    """Return the resized grayscale image stacked with its predicted mask.

    Args:
        image: input picture; it is converted to single-channel ("L") mode.

    Returns:
        Array of shape ``IMG_SIZE + (2,)``: channel 0 is the resized image as
        produced by ``transform``, channel 1 is the mask thresholded at 0.5
        (values 0.0 or 1.0).
    """
    image = image.convert("L")
    input_tensor = transform(image).unsqueeze(0).to(device)  # [1, 1, H, W]

    with torch.no_grad():
        mask_pred = swinunet_model(input_tensor)
        mask_pred = F.interpolate(
            mask_pred, size=IMG_SIZE, mode="bilinear", align_corners=False
        )
        mask_pred = (mask_pred > 0.5).float()

    image_np = input_tensor.squeeze().cpu().numpy()  # [H, W]
    mask_np = mask_pred.squeeze().cpu().numpy()  # [H, W]
    combined = np.stack([image_np, mask_np], axis=-1)  # [H, W, 2]
    return combined


def predict(image: Image.Image):
    """Classify ``image`` and return the name of the most likely class.

    Returns:
        One entry of ``CLASSES``. If anything raises, the traceback is printed
        and its text is returned instead.
    """
    # NOTE(review): on failure the traceback text is shown in the output
    # widget as if it were a label; kept as-is because callers may rely on it.
    try:
        combined = segmentation(image)
        combined = np.expand_dims(combined, axis=0)  # add batch axis
        probs = classifier_model.predict(combined)[0]
        return CLASSES[int(np.argmax(probs))]
    except Exception:
        traceback_str = traceback.format_exc()
        print(traceback_str)
        return traceback_str


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Brain MRI"),
    outputs=gr.Label(num_top_classes=4),
    title="Brain‑Tumor Net",
    description="Returns: Glioma, Meningioma, No Tumor, Pituitary",
)

# Launch only when run as a script: the original compared against "main"
# (never true) and also launched unconditionally, starting a server on import.
if __name__ == "__main__":
    demo.launch()