Spaces:
Sleeping
Sleeping
File size: 4,679 Bytes
a8be84c 267c122 a8be84c 267c122 a8be84c 267c122 a8be84c 267c122 a8be84c 267c122 a8be84c 267c122 a8be84c 267c122 a8be84c 267c122 a8be84c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import gradio as gr
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import keras
import traceback
from PIL import Image
from skimage.transform import resize
# ----------- Constants -----------
# Display names for the classifier outputs. predict() pairs these with the
# classifier's output vector by position, so the order here must match the
# order of the classifier's outputs (presumably the training label order).
CLASSES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
IMG_SIZE = (224, 224)
# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ----------- Segmentation Model Definition -----------
# Swin backbone used as a multi-scale feature extractor. It is created with
# pretrained=False; SwinUNet uses this instance as its encoder, so its weights
# are filled in when the SwinUNet checkpoint is loaded further down.
swin = timm.create_model(
    "swin_base_patch4_window7_224",
    pretrained=False,
    features_only=True,
)
class UNetDecoder(nn.Module):
    """U-Net style decoder producing a one-channel mask from four feature maps.

    ``forward`` expects four channels-first feature maps ordered from the
    highest to the lowest spatial resolution, with 128, 256, 512 and 512
    channels respectively. Each map must have twice the height and width of
    the next one, because every decoder stage upsamples by two before it is
    concatenated with the matching skip connection.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order are kept stable so that the
        # module's state_dict keys keep matching previously saved checkpoints.
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self._double_conv(768, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self._double_conv(384, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self._double_conv(192, 64)
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, features):
        """Decode ``features`` into a sigmoid-activated mask.

        Args:
            features: Sequence of exactly four tensors (see class docstring).

        Returns:
            Tensor with one channel and values in [0, 1], at 8x the spatial
            size of the first feature map's half-resolution neighbour, i.e.
            4x the size of the highest-resolution input feature map.
        """
        skip1, skip2, skip3, bottleneck = features

        # Stage 3: 256 upsampled channels + 512 skip channels = 768.
        x = torch.cat([self.up3(bottleneck), skip3], dim=1)
        x = self.dec3(x)

        # Stage 2: 128 upsampled channels + 256 skip channels = 384.
        x = torch.cat([self.up2(x), skip2], dim=1)
        x = self.dec2(x)

        # Stage 1: 64 upsampled channels + 128 skip channels = 192.
        x = torch.cat([self.up1(x), skip1], dim=1)
        x = self.dec1(x)

        # Upsample by four before the 1x1 projection to a single channel.
        x = F.interpolate(x, scale_factor=4, mode="bilinear", align_corners=False)
        logits = self.final(x)
        return torch.sigmoid(logits)
class SwinUNet(nn.Module):
    """Segmentation network: Swin feature encoder followed by ``UNetDecoder``.

    The encoder is the module-level ``swin`` instance. Its last feature map is
    projected from 1024 to 512 channels with a 1x1 convolution before all four
    feature maps are handed to the decoder.
    """

    def __init__(self):
        super().__init__()
        self.encoder = swin
        self.channel_reducer = nn.Conv2d(1024, 512, kernel_size=1)
        self.decoder = UNetDecoder()

    def forward(self, x):
        """Run the encoder and decoder on a batch of images.

        Args:
            x: Image batch in channels-first layout. A single-channel batch is
                tiled to three channels before it reaches the encoder.

        Returns:
            Whatever the decoder returns for the encoder's feature maps.
        """
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        stages = [self._to_channels_first(stage) for stage in self.encoder(x)]
        # Only the deepest stage is channel-reduced; the others pass through.
        stages[3] = self.channel_reducer(stages[3])
        return self.decoder(stages)

    def _to_channels_first(self, feature):
        """Rearrange an encoder feature into ``(B, C, H, W)`` layout.

        A 4-D input is treated as ``(B, H, W, C)``. A 3-D input is treated as
        ``(B, N, C)`` tokens of a square grid with side ``int(N ** 0.5)``.

        Raises:
            ValueError: If ``feature`` is neither 3-D nor 4-D.
        """
        rank = feature.dim()
        if rank == 4:
            return feature.permute(0, 3, 1, 2).contiguous()
        if rank != 3:
            raise ValueError(f"Unexpected feature shape: {feature.shape}")

        batch, tokens, channels = feature.shape
        side = int(tokens ** 0.5)
        channels_first = feature.permute(0, 2, 1).contiguous()
        return channels_first.view(batch, channels, side, side)
# ----------- Load Swin-UNet -----------
# Build the network and restore its weights from the local checkpoint. The
# checkpoint is passed straight into load_state_dict so that no extra
# module-level reference keeps a second copy of the weights alive.
swinunet_model = SwinUNet()
swinunet_model.load_state_dict(
    torch.load("swinunet.pth", map_location=device)
)
# Move to the selected device and switch to inference mode.
swinunet_model = swinunet_model.to(device).eval()

# ----------- Load Classifier Model -----------
# Keras classifier that predict() feeds with the stacked image + mask array.
classifier_model = keras.models.load_model("cnn-swinunet")

# ----------- Transform -----------
# Resize the PIL image to the model input size, then convert it to a tensor.
transform = T.Compose(
    [
        T.Resize(IMG_SIZE),
        T.ToTensor(),
    ]
)
# ----------- Segmentation -----------
def segmentation(image: Image.Image) -> np.ndarray:
    """Segment an image and return it stacked with its binary mask.

    The image is converted to grayscale, passed through the module-level
    ``transform`` and the Swin-UNet model, and the model output is thresholded
    at 0.5.

    Args:
        image: PIL image to segment.

    Returns:
        Array whose last axis holds two planes: the transformed grayscale
        image and the 0/1 mask. With the module's 224x224 transform the shape
        is ``(224, 224, 2)``.
    """
    grayscale = image.convert("L")
    # Add the batch axis and move to the model's device: [1, 1, 224, 224].
    input_tensor = transform(grayscale).unsqueeze(0).to(device)

    with torch.no_grad():
        probabilities = swinunet_model(input_tensor)
        # Make sure the mask has the same spatial size as the input plane.
        probabilities = F.interpolate(
            probabilities,
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        binary_mask = (probabilities > 0.5).float()

    # Drop the singleton batch/channel axes: each plane becomes [224, 224].
    image_plane = input_tensor.squeeze().cpu().numpy()
    mask_plane = binary_mask.squeeze().cpu().numpy()
    return np.stack((image_plane, mask_plane), axis=-1)  # [224, 224, 2]
def predict(image: Image.Image):
    """Classify a brain MRI image for display in a ``gr.Label`` component.

    The image is segmented, the resulting image + mask array is given a batch
    axis, and the Keras classifier is run on it.

    Args:
        image: PIL image supplied by the Gradio image input.

    Returns:
        On success, a dict mapping each name in ``CLASSES`` to the classifier
        score at the same position. ``gr.Label`` shows the highest-scoring
        entry as the prediction (the same class the argmax would pick) and can
        list the remaining scores, which a bare class-name string does not
        allow. On any failure, the formatted traceback string, so the error is
        visible in the UI as well as in the server log.
    """
    try:
        combined = segmentation(image)
        batch = np.expand_dims(combined, axis=0)  # Shape: (1, 224, 224, 2)
        scores = classifier_model.predict(batch)[0]
        # NOTE(review): scores are assumed to be per-class probabilities in
        # CLASSES order - confirm against the classifier's output layer.
        # float() converts NumPy scalars so the result is JSON-serialisable.
        return {label: float(score) for label, score in zip(CLASSES, scores)}
    except Exception:
        # Deliberate best-effort: surface the traceback instead of crashing
        # the request, and also print it for the server log.
        traceback_str = traceback.format_exc()
        print(traceback_str)
        return traceback_str
# Gradio UI: one image input wired to predict(), with a label output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Brain MRI"),
    outputs=gr.Label(num_top_classes=4),
    title="Brain‑Tumor Net",
    description="Returns: Glioma, Meningioma, No Tumor, Pituitary",
)

# Start the server only when this file is executed as a script. The module
# name to compare against is "__main__"; comparing with "main" never matches.
# Launching here alone also keeps an import of this module from starting (and
# blocking on) a server as a side effect.
if __name__ == "__main__":
    demo.launch()
|