{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "877fa66b",
   "metadata": {},
   "outputs": [
    {
     "ename": "OSError",
     "evalue": "Unable to synchronously open file (file signature not found)",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mOSError\u001b[0m                                   Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[2], line 95\u001b[0m\n\u001b[0;32m     92\u001b[0m swinunet_model\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m     94\u001b[0m \u001b[38;5;66;03m# ----------- Load Classifier Model -----------\u001b[39;00m\n\u001b[1;32m---> 95\u001b[0m classifier_model \u001b[38;5;241m=\u001b[39m \u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkeras\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodels\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./cnn-swinunet.keras\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m     97\u001b[0m \u001b[38;5;66;03m# ----------- Transform -----------\u001b[39;00m\n\u001b[0;32m     98\u001b[0m transform \u001b[38;5;241m=\u001b[39m T\u001b[38;5;241m.\u001b[39mCompose([\n\u001b[0;32m     99\u001b[0m     T\u001b[38;5;241m.\u001b[39mResize((\u001b[38;5;241m224\u001b[39m, \u001b[38;5;241m224\u001b[39m)),\n\u001b[0;32m    100\u001b[0m     T\u001b[38;5;241m.\u001b[39mToTensor()\n\u001b[0;32m    101\u001b[0m ])\n",
      "File \u001b[1;32mc:\\Users\\PC\\anaconda3\\envs\\main-gpu\\lib\\site-packages\\keras\\utils\\traceback_utils.py:70\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     67\u001b[0m     filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[0;32m     68\u001b[0m     \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[0;32m     69\u001b[0m     \u001b[38;5;66;03m# `tf.debugging.disable_traceback_filtering()`\u001b[39;00m\n\u001b[1;32m---> 70\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m     71\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m     72\u001b[0m     \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
      "File \u001b[1;32mc:\\Users\\PC\\anaconda3\\envs\\main-gpu\\lib\\site-packages\\h5py\\_hl\\files.py:564\u001b[0m, in \u001b[0;36mFile.__init__\u001b[1;34m(self, name, mode, driver, libver, userblock_size, swmr, rdcc_nslots, rdcc_nbytes, rdcc_w0, track_order, fs_strategy, fs_persist, fs_threshold, fs_page_size, page_buf_size, min_meta_keep, min_raw_keep, locking, alignment_threshold, alignment_interval, meta_block_size, **kwds)\u001b[0m\n\u001b[0;32m    555\u001b[0m     fapl \u001b[38;5;241m=\u001b[39m make_fapl(driver, libver, rdcc_nslots, rdcc_nbytes, rdcc_w0,\n\u001b[0;32m    556\u001b[0m                      locking, page_buf_size, min_meta_keep, min_raw_keep,\n\u001b[0;32m    557\u001b[0m                      alignment_threshold\u001b[38;5;241m=\u001b[39malignment_threshold,\n\u001b[0;32m    558\u001b[0m                      alignment_interval\u001b[38;5;241m=\u001b[39malignment_interval,\n\u001b[0;32m    559\u001b[0m                      meta_block_size\u001b[38;5;241m=\u001b[39mmeta_block_size,\n\u001b[0;32m    560\u001b[0m                      \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[0;32m    561\u001b[0m     fcpl \u001b[38;5;241m=\u001b[39m make_fcpl(track_order\u001b[38;5;241m=\u001b[39mtrack_order, fs_strategy\u001b[38;5;241m=\u001b[39mfs_strategy,\n\u001b[0;32m    562\u001b[0m                      fs_persist\u001b[38;5;241m=\u001b[39mfs_persist, fs_threshold\u001b[38;5;241m=\u001b[39mfs_threshold,\n\u001b[0;32m    563\u001b[0m                      fs_page_size\u001b[38;5;241m=\u001b[39mfs_page_size)\n\u001b[1;32m--> 564\u001b[0m     fid \u001b[38;5;241m=\u001b[39m \u001b[43mmake_fid\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muserblock_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfapl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfcpl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mswmr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mswmr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    566\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(libver, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[0;32m    567\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_libver \u001b[38;5;241m=\u001b[39m libver\n",
      "File \u001b[1;32mc:\\Users\\PC\\anaconda3\\envs\\main-gpu\\lib\\site-packages\\h5py\\_hl\\files.py:238\u001b[0m, in \u001b[0;36mmake_fid\u001b[1;34m(name, mode, userblock_size, fapl, fcpl, swmr)\u001b[0m\n\u001b[0;32m    236\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m swmr \u001b[38;5;129;01mand\u001b[39;00m swmr_support:\n\u001b[0;32m    237\u001b[0m         flags \u001b[38;5;241m|\u001b[39m\u001b[38;5;241m=\u001b[39m h5f\u001b[38;5;241m.\u001b[39mACC_SWMR_READ\n\u001b[1;32m--> 238\u001b[0m     fid \u001b[38;5;241m=\u001b[39m \u001b[43mh5f\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflags\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfapl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfapl\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    239\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr+\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m    240\u001b[0m     fid \u001b[38;5;241m=\u001b[39m h5f\u001b[38;5;241m.\u001b[39mopen(name, h5f\u001b[38;5;241m.\u001b[39mACC_RDWR, fapl\u001b[38;5;241m=\u001b[39mfapl)\n",
      "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
      "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
      "File \u001b[1;32mh5py\\\\h5f.pyx:102\u001b[0m, in \u001b[0;36mh5py.h5f.open\u001b[1;34m()\u001b[0m\n",
      "\u001b[1;31mOSError\u001b[0m: Unable to synchronously open file (file signature not found)"
     ]
    }
   ],
   "source": [
    "import gradio as gr\n",
    "import numpy as np\n",
    "import timm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as T\n",
    "import tensorflow as tf\n",
    "from PIL import Image\n",
    "from skimage.transform import resize\n",
    "\n",
    "# ----------- Constants -----------\n",
    "CLASSES = [\"Glioma\", \"Meningioma\", \"No Tumor\", \"Pituitary\"]\n",
    "IMG_SIZE = (224, 224)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# ----------- Segmentation Model Definition -----------\n",
    "swin = timm.create_model('swin_base_patch4_window7_224', pretrained = False, features_only = True)\n",
    "\n",
    "class UNetDecoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        def conv_block(in_c, out_c):\n",
    "            return nn.Sequential(\n",
    "                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),\n",
    "                nn.ReLU(inplace=True)\n",
    "            )\n",
    "\n",
    "        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)\n",
    "        self.dec3 = conv_block(768, 256)\n",
    "\n",
    "        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n",
    "        self.dec2 = conv_block(384, 128)\n",
    "\n",
    "        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n",
    "        self.dec1 = conv_block(192, 64)\n",
    "\n",
    "        self.final = nn.Conv2d(64, 1, kernel_size=1)\n",
    "\n",
    "    def forward(self, features):\n",
    "        e1, e2, e3, e4 = features  # e4 is reduced 512 channels\n",
    "\n",
    "        d3 = self.up3(e4)\n",
    "        d3 = self.dec3(torch.cat([d3, e3], dim=1))  # concat 256 + 512 = 768\n",
    "\n",
    "        d2 = self.up2(d3)\n",
    "        d2 = self.dec2(torch.cat([d2, e2], dim=1))  # concat 128 + 256 = 384\n",
    "\n",
    "        d1 = self.up1(d2)\n",
    "        d1 = self.dec1(torch.cat([d1, e1], dim=1))  # concat 64 + 128 = 192\n",
    "\n",
    "        out = F.interpolate(d1, scale_factor=4, mode='bilinear', align_corners=False)\n",
    "        return torch.sigmoid(self.final(out))\n",
    "    \n",
    "class SwinUNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.encoder = swin\n",
    "        self.channel_reducer = nn.Conv2d(1024, 512, kernel_size=1)\n",
    "        self.decoder = UNetDecoder()\n",
    "\n",
    "    def forward(self, x):\n",
    "        if x.shape[1] == 1:\n",
    "            x = x.repeat(1, 3, 1, 1)\n",
    "\n",
    "        features = self.encoder(x)\n",
    "        features = [self._to_channels_first(f) for f in features]\n",
    "\n",
    "        features[3] = self.channel_reducer(features[3])\n",
    "\n",
    "        output = self.decoder(features)\n",
    "        return output\n",
    "\n",
    "    def _to_channels_first(self, feature):\n",
    "        if feature.dim() == 4:\n",
    "            return feature.permute(0, 3, 1, 2).contiguous()\n",
    "        elif feature.dim() == 3:\n",
    "            B, N, C = feature.shape\n",
    "            H = W = int(N ** 0.5)\n",
    "            feature = feature.permute(0, 2, 1).contiguous()\n",
    "            return feature.view(B, C, H, W)\n",
    "        else:\n",
    "            raise ValueError(f\"Unexpected feature shape: {feature.shape}\")\n",
    "\n",
    "# ----------- Load Swin-UNet -----------\n",
    "swinunet_model = SwinUNet()\n",
    "swinunet_model.load_state_dict(torch.load(\"./swinunet.pth\", map_location = device))\n",
    "swinunet_model = swinunet_model.to(device)\n",
    "swinunet_model.eval()\n",
    "\n",
    "# ----------- Load Classifier Model -----------\n",
    "classifier_model = tf.keras.models.load_model(\"./cnn-swinunet.keras\")\n",
    "\n",
    "# ----------- Transform -----------\n",
    "transform = T.Compose([\n",
    "    T.Resize((224, 224)),\n",
    "    T.ToTensor()\n",
    "])\n",
    "\n",
    "# ----------- Segmentation -----------\n",
    "def segmentation(image: Image.Image) -> np.ndarray:\n",
    "    # Convert to grayscale and tensor\n",
    "    image = image.convert(\"L\")\n",
    "    input_tensor = transform(image).unsqueeze(0).to(device)  # [1, 1, 224, 224]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        mask_pred = swinunet_model(input_tensor)\n",
    "        mask_pred = F.interpolate(mask_pred, size=(224, 224), mode=\"bilinear\", align_corners=False)\n",
    "        mask_pred = (mask_pred > 0.5).float()\n",
    "\n",
    "    image_np = input_tensor.squeeze().cpu().numpy()  # [224, 224]\n",
    "    mask_np = mask_pred.squeeze().cpu().numpy()      # [224, 224]\n",
    "    \n",
    "    combined = np.stack([image_np, mask_np], axis=-1)  # [224, 224, 2]\n",
    "    return combined\n",
    "\n",
    "def predict(image: Image.Image):\n",
    "    combined = segmentation(image)\n",
    "    combined = np.expand_dims(combined, axis=0) # Shape: (1, 224, 224, 2)\n",
    "\n",
    "    probs = classifier_model.predict(combined)[0]\n",
    "\n",
    "    return CLASSES[int(np.argmax(probs, axis=1)[0])]\n",
    "\n",
    "demo = gr.Interface(\n",
    "    fn = predict,\n",
    "    inputs = gr.Image(type=\"pil\", label=\"Brain MRI\"),\n",
    "    outputs = gr.Label(num_top_classes=4),\n",
    "    title = \"Brain‑Tumor Classifier (.tflite)\",\n",
    "    description = \"Returns: Glioma, Meningioma, No Tumor, Pituitary\"\n",
    ")\n",
    "\n",
    "\n",
    "demo.launch()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "main-gpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}