# ============================================================
# Gradio App: ResNet18 Classification + SwinUNet Segmentation
# With Proper AES Encryption for Uploaded Images
# ============================================================
import os
import io

import torch
import torch.nn as nn
from torchvision import models, transforms
import timm
import gradio as gr
import numpy as np
from PIL import Image
import cv2
from cryptography.fernet import Fernet

# ---------------- Security Setup ----------------
# NOTE(review): the Fernet key is persisted unencrypted next to the app, so
# anyone with filesystem access can decrypt stored payloads. For production,
# prefer an environment variable or a secrets manager.
KEY_PATH = "secret.key"

if not os.path.exists(KEY_PATH):
    with open(KEY_PATH, "wb") as f:
        f.write(Fernet.generate_key())

with open(KEY_PATH, "rb") as f:
    key = f.read()

fernet = Fernet(key)


def encrypt_bytes(image_bytes):
    """Encrypt raw image bytes and return the Fernet token (bytes)."""
    return fernet.encrypt(image_bytes)


def decrypt_bytes(encrypted_bytes):
    """Decrypt a Fernet token and return the decoded PIL image."""
    decrypted = fernet.decrypt(encrypted_bytes)
    return Image.open(io.BytesIO(decrypted))


# ---------------- Device ----------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# ---------------- Classification Model ----------------
class BrainTumorResNet18(nn.Module):
    """ResNet18 backbone with a dropout + linear head for 4-way MRI classification."""

    def __init__(self, num_classes=4, pretrained=False):
        super().__init__()
        self.model = models.resnet18(pretrained=pretrained)
        in_features = self.model.fc.in_features
        # Replace the ImageNet head with a dropout-regularised classifier.
        self.model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        return self.model(x)


clf_model = BrainTumorResNet18(num_classes=4).to(DEVICE)
clf_model.load_state_dict(
    torch.load("models/best_resnet18_mri.pth", map_location=DEVICE)
)
clf_model.eval()

clf_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

CLASS_NAMES = ["glioma", "meningioma", "notumor", "pituitary"]


# ---------------- Segmentation Model ----------------
class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) layers: a standard UNet decoder block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class SwinUNet(nn.Module):
    """UNet-style decoder on top of a timm Swin transformer encoder.

    The encoder is created with ``features_only=True`` so it returns four
    feature maps; each decoder stage upsamples, aligns spatial size, and
    fuses the matching skip connection with a ConvBlock. The final 1x1 conv
    emits raw logits (no sigmoid).
    """

    def __init__(self, encoder_name="swin_small_patch4_window7_224",
                 pretrained=True, num_classes=1):
        super().__init__()
        self.encoder = timm.create_model(
            encoder_name,
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3),
        )
        enc_chs = self.encoder.feature_info.channels()
        self.up3 = nn.ConvTranspose2d(enc_chs[3], enc_chs[2], 2, stride=2)
        self.dec3 = ConvBlock(enc_chs[2] * 2, enc_chs[2])
        self.up2 = nn.ConvTranspose2d(enc_chs[2], enc_chs[1], 2, stride=2)
        self.dec2 = ConvBlock(enc_chs[1] * 2, enc_chs[1])
        self.up1 = nn.ConvTranspose2d(enc_chs[1], enc_chs[0], 2, stride=2)
        self.dec1 = ConvBlock(enc_chs[0] * 2, enc_chs[0])
        self.final_up = nn.ConvTranspose2d(enc_chs[0], 64, 2, stride=2)
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, 1),  # raw logits
        )

    def _ensure_nchw(self, feat, expected_ch):
        # timm Swin models may emit NHWC feature maps; permute them to NCHW
        # by checking which axis carries the expected channel count.
        if feat.ndim == 4:
            if feat.shape[1] == expected_ch:
                return feat
            if feat.shape[-1] == expected_ch:
                return feat.permute(0, 3, 1, 2).contiguous()
        return feat

    def forward(self, x):
        feats = self.encoder(x)
        expected = self.encoder.feature_info.channels()
        feats = [self._ensure_nchw(f, c) for f, c in zip(feats, expected)]
        f0, f1, f2, f3 = feats

        # Decoder stage 3: deepest features + skip f2.
        d3 = self.up3(f3)
        if d3.shape[-2:] != f2.shape[-2:]:
            d3 = nn.functional.interpolate(
                d3, size=f2.shape[-2:], mode="bilinear", align_corners=False)
        d3 = self.dec3(torch.cat([d3, f2], dim=1))

        # Decoder stage 2: + skip f1.
        d2 = self.up2(d3)
        if d2.shape[-2:] != f1.shape[-2:]:
            d2 = nn.functional.interpolate(
                d2, size=f1.shape[-2:], mode="bilinear", align_corners=False)
        d2 = self.dec2(torch.cat([d2, f1], dim=1))

        # Decoder stage 1: + skip f0.
        d1 = self.up1(d2)
        if d1.shape[-2:] != f0.shape[-2:]:
            d1 = nn.functional.interpolate(
                d1, size=f0.shape[-2:], mode="bilinear", align_corners=False)
        d1 = self.dec1(torch.cat([d1, f0], dim=1))

        out = self.final_up(d1)
        return self.final_conv(out)


seg_model = SwinUNet().to(DEVICE)
# strict=False tolerates key mismatches between checkpoint and model.
seg_model.load_state_dict(
    torch.load("models/swinunet_best (6).pth", map_location=DEVICE),
    strict=False,
)
seg_model.eval()

seg_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


# ---------------- Inference Function ----------------
def predict(img):
    """Classify and segment an uploaded MRI.

    Args:
        img: HxWxC uint8 numpy array from the Gradio image widget, or None
            when no image was supplied.

    Returns:
        Tuple of (overlay image as numpy array, textual prediction summary).
    """
    # Robustness fix: Gradio passes None when no image is supplied;
    # Image.fromarray(None) would raise.
    if img is None:
        return None, "No image provided."

    pil_img = Image.fromarray(img).convert("RGB")

    # ---- Encrypt image in memory ----
    img_bytes = io.BytesIO()
    pil_img.save(img_bytes, format="PNG")
    encrypted_bytes = encrypt_bytes(img_bytes.getvalue())

    # ---- Decrypt immediately for inference ----
    decrypted_img = decrypt_bytes(encrypted_bytes)

    # ---- Classification ----
    x = clf_transform(decrypted_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = clf_model(x)
        probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    pred_class = CLASS_NAMES[int(np.argmax(probs))]
    conf = float(np.max(probs))

    # ---- Segmentation ----
    seg_in = seg_transform(decrypted_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # Bug fix: the model outputs raw logits (final_conv has no sigmoid),
        # so apply sigmoid before thresholding at 0.5. Thresholding logits
        # directly at 0.5 silently moved the decision boundary to p ~= 0.62.
        mask_prob = torch.sigmoid(seg_model(seg_in))[0, 0].cpu().numpy()
    mask = (mask_prob > 0.5).astype(np.uint8)

    # ---- Overlay mask on the (resized) input ----
    img_np = np.array(decrypted_img.resize((224, 224)))
    mask_resized = cv2.resize(
        mask, (img_np.shape[1], img_np.shape[0]),
        interpolation=cv2.INTER_NEAREST)
    overlay = img_np.copy()
    overlay[mask_resized > 0] = [255, 0, 0]  # paint tumour region red
    blended = cv2.addWeighted(img_np, 0.7, overlay, 0.3, 0)

    return blended, f"Prediction: {pred_class} (conf: {conf:.2f})"


# ---------------- Gradio UI ----------------
example_images = [
    "images/img1.jpg", "images/img2.jpg", "images/img3.jpg",
    "images/img4.jpg", "images/img5.jpg", "images/img6.jpg",
    "images/img7.jpg", "images/img8.jpg", "images/img9.jpg",
    "images/img10.jpg",
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="numpy"), gr.Textbox()],
    title="ONCOSCAN - (Brain Tumor Classification + Segmentation) ",
    description="Upload an MRI or click on one of the example images. The app will classify tumor type (ResNet18) and segment tumor region (SwinUNet).",
    examples=example_images,
    cache_examples=False,
)

demo.launch(debug=True)