File size: 4,679 Bytes
a8be84c
 
267c122
 
 
 
 
 
 
a8be84c
267c122
a8be84c
267c122
a8be84c
 
267c122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8be84c
267c122
 
 
a8be84c
267c122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8be84c
 
267c122
 
 
 
 
a8be84c
 
267c122
 
 
a8be84c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import keras
import traceback
from PIL import Image
from skimage.transform import resize

# ----------- Constants -----------
# Class names in the index order the Keras classifier was trained with;
# predict() maps argmax indices back through this list.
CLASSES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
# Spatial size every input is resized to (matches the Swin patch-window config).
IMG_SIZE = (224, 224)
# Run segmentation on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------- Segmentation Model Definition -----------
# timm Swin backbone used as the U-Net encoder. features_only=True makes it
# return the intermediate stage feature maps instead of classification logits;
# pretrained=False because weights are loaded from swinunet.pth below.
swin = timm.create_model('swin_base_patch4_window7_224', pretrained = False, features_only = True)

class UNetDecoder(nn.Module):
    """U-Net-style decoder over four Swin encoder stages.

    Expects ``features = [e1, e2, e3, e4]`` in NCHW layout with channel
    counts 128/256/512/512 (e4 already reduced from 1024 by the caller)
    and spatial sizes doubling from e4 up to e1. Produces a single-channel
    sigmoid mask upsampled 4x beyond e1's resolution.
    """

    def __init__(self):
        super().__init__()

        def double_conv(cin, cout):
            # Two 3x3 conv + ReLU pairs; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(cout, cout, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        # NOTE: attribute names are fixed — they are the state_dict keys
        # that swinunet.pth was saved under.
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = double_conv(768, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = double_conv(384, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = double_conv(192, 64)

        self.final = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, features):
        skip1, skip2, skip3, bottleneck = features  # bottleneck is the reduced 512-ch e4

        # Each step: transpose-conv upsample, concat the skip, refine.
        x = self.dec3(torch.cat([self.up3(bottleneck), skip3], dim=1))  # 256 + 512 = 768
        x = self.dec2(torch.cat([self.up2(x), skip2], dim=1))           # 128 + 256 = 384
        x = self.dec1(torch.cat([self.up1(x), skip1], dim=1))           # 64 + 128 = 192

        # Recover the remaining 4x of resolution lost to the patch embedding.
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)
        return torch.sigmoid(self.final(x))
    
class SwinUNet(nn.Module):
    """Swin-Transformer encoder + U-Net decoder for binary mask prediction.

    Grayscale input is tiled to 3 channels for the Swin backbone; the
    deepest stage is reduced 1024 -> 512 channels before decoding.
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names are fixed — they are the state_dict keys
        # that swinunet.pth was saved under. `swin` is the module-level
        # timm backbone created above.
        self.encoder = swin
        self.channel_reducer = nn.Conv2d(1024, 512, kernel_size=1)
        self.decoder = UNetDecoder()

    def forward(self, x):
        # Swin expects RGB; replicate a single-channel input across 3 channels.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        feats = [self._to_channels_first(f) for f in self.encoder(x)]
        feats[3] = self.channel_reducer(feats[3])  # 1024 -> 512 for the decoder
        return self.decoder(feats)

    def _to_channels_first(self, feature):
        # timm Swin stages emit either NHWC maps (4-D) or token sequences
        # (3-D, assumed square); normalize both to NCHW.
        if feature.dim() == 4:
            return feature.permute(0, 3, 1, 2).contiguous()
        if feature.dim() == 3:
            B, N, C = feature.shape
            side = int(N ** 0.5)
            return feature.permute(0, 2, 1).contiguous().view(B, C, side, side)
        raise ValueError(f"Unexpected feature shape: {feature.shape}")

# ----------- Load Swin-UNet -----------
# Restore trained segmentation weights from disk (checkpoint must sit next
# to this script), then move to the chosen device and freeze for inference.
swinunet_model = SwinUNet()
swinunet_model.load_state_dict(torch.load("swinunet.pth", map_location = device))
swinunet_model = swinunet_model.to(device)
swinunet_model.eval()

# ----------- Load Classifier Model -----------
# Keras classifier that consumes the (224, 224, 2) image+mask stack produced
# by segmentation(). "cnn-swinunet" is a SavedModel directory/file path.
classifier_model = keras.models.load_model("cnn-swinunet")

# ----------- Transform -----------
# PIL -> resized float tensor in [0, 1]; applied to the grayscale MRI.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor()
])

# ----------- Segmentation -----------
def segmentation(image: Image.Image) -> np.ndarray:
    # Convert to grayscale and tensor
    image = image.convert("L")
    input_tensor = transform(image).unsqueeze(0).to(device)  # [1, 1, 224, 224]

    with torch.no_grad():
        mask_pred = swinunet_model(input_tensor)
        mask_pred = F.interpolate(mask_pred, size=(224, 224), mode="bilinear", align_corners=False)
        mask_pred = (mask_pred > 0.5).float()

    image_np = input_tensor.squeeze().cpu().numpy()  # [224, 224]
    mask_np = mask_pred.squeeze().cpu().numpy()      # [224, 224]
    
    combined = np.stack([image_np, mask_np], axis=-1)  # [224, 224, 2]
    return combined

def predict(image: Image.Image):
    """Classify a brain-MRI image into one of the four CLASSES.

    Returns a {class_name: probability} dict so the ``gr.Label`` output
    (configured with num_top_classes=4) can display all four confidences —
    returning only the argmax string, as before, left the Label with no
    probabilities to rank. On failure the formatted traceback string is
    returned so the error surfaces in the UI (deliberate best-effort).
    """
    try:
        combined = segmentation(image)            # (224, 224, 2) image+mask stack
        batch = np.expand_dims(combined, axis=0)  # (1, 224, 224, 2) for Keras

        probs = classifier_model.predict(batch)[0]
        # float() cast: numpy scalars are not JSON-serializable for gradio.
        return {cls: float(p) for cls, p in zip(CLASSES, probs)}
    except Exception:
        traceback_str = traceback.format_exc()
        print(traceback_str)
        return traceback_str

# ----------- Gradio UI -----------
demo = gr.Interface(
    fn = predict,
    inputs = gr.Image(type = "pil", label = "Brain MRI"),
    outputs = gr.Label(num_top_classes = 4),
    title = "Brain‑Tumor Net",  # fixed: stray ')' removed from the title
    description = "Returns: Glioma, Meningioma, No Tumor, Pituitary"
)

# Launch only when executed as a script. The original code launched
# unconditionally at import time AND guarded on __name__ == "main",
# which never matches — the correct sentinel is "__main__".
if __name__ == "__main__":
    demo.launch()