# Hosted as a Hugging Face Space (page status banner captured as "Sleeping").
# ============================================================
# Gradio App: ResNet18 Classification + SwinUNet Segmentation
# With Proper AES Encryption for Uploaded Images
# ============================================================
import os | |
import io | |
import torch | |
import torch.nn as nn | |
from torchvision import models, transforms | |
import timm | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
from cryptography.fernet import Fernet | |
# ---------------- Security Setup ----------------
# A symmetric Fernet key is generated once and persisted next to the app so
# bytes encrypted in one session can still be decrypted after a restart.
KEY_PATH = "secret.key"
if not os.path.exists(KEY_PATH):
    with open(KEY_PATH, "wb") as f:
        f.write(Fernet.generate_key())
    # The key grants access to every encrypted upload; make it owner-only
    # (the default file mode would leave it world-readable).
    os.chmod(KEY_PATH, 0o600)
with open(KEY_PATH, "rb") as f:
    key = f.read()
fernet = Fernet(key)
def encrypt_bytes(image_bytes):
    """Encrypt raw image bytes with the app-wide Fernet key; returns token bytes."""
    token = fernet.encrypt(image_bytes)
    return token
def decrypt_bytes(encrypted_bytes):
    """Decrypt a Fernet token and return the payload as a PIL image."""
    raw = fernet.decrypt(encrypted_bytes)
    buffer = io.BytesIO(raw)
    return Image.open(buffer)
# ---------------- Device ----------------
# Run on the GPU when one is present, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
# ---------------- Classification Model ----------------
class BrainTumorResNet18(nn.Module):
    """ResNet-18 backbone with a dropout + linear head for MRI classification.

    The stock final fc layer is swapped for Dropout(0.5) followed by a linear
    projection to ``num_classes`` logits.
    """

    def __init__(self, num_classes=4, pretrained=False):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        head_in = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(head_in, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        return self.model(x)
# Build the classifier, restore its trained weights, and freeze it in eval mode.
clf_model = BrainTumorResNet18(num_classes=4).to(DEVICE)
clf_state = torch.load("models/best_resnet18_mri.pth", map_location=DEVICE)
clf_model.load_state_dict(clf_state)
clf_model.eval()
# Classifier preprocessing: 224x224, tensor scaled to roughly [-1, 1].
_clf_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
clf_transform = transforms.Compose(_clf_steps)
# Index order must match the label encoding used at training time.
CLASS_NAMES = ["glioma", "meningioma", "notumor", "pituitary"]
# ---------------- Segmentation Model ----------------
class ConvBlock(nn.Module):
    """Two 3x3 Conv -> BatchNorm -> ReLU stages at a fixed output width."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for src in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class SwinUNet(nn.Module):
    """U-Net-style segmentation head over a timm Swin Transformer encoder.

    The encoder is built with ``features_only=True`` so it returns the four
    pyramid feature maps (out_indices 0-3). Each decoder stage upsamples with
    a transposed conv, concatenates the matching skip feature, and fuses the
    pair with a ConvBlock. The head emits ``num_classes`` channels of raw
    logits — no final activation is applied here.
    """
    def __init__(self, encoder_name="swin_small_patch4_window7_224", pretrained=True, num_classes=1):
        super().__init__()
        self.encoder = timm.create_model(encoder_name, pretrained=pretrained,
                                         features_only=True, out_indices=(0,1,2,3))
        # Per-stage channel counts, queried from timm rather than hard-coded.
        enc_chs = self.encoder.feature_info.channels()
        # Decoder: upsample the deepest features and fuse with skips, shallow-ward.
        self.up3 = nn.ConvTranspose2d(enc_chs[3], enc_chs[2], 2, stride=2)
        self.dec3 = ConvBlock(enc_chs[2]*2, enc_chs[2])
        self.up2 = nn.ConvTranspose2d(enc_chs[2], enc_chs[1], 2, stride=2)
        self.dec2 = ConvBlock(enc_chs[1]*2, enc_chs[1])
        self.up1 = nn.ConvTranspose2d(enc_chs[1], enc_chs[0], 2, stride=2)
        self.dec1 = ConvBlock(enc_chs[0]*2, enc_chs[0])
        # Final x2 upsample back toward input resolution, then a small conv head.
        self.final_up = nn.ConvTranspose2d(enc_chs[0], 64, 2, stride=2)
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, 1)
        )
    def _ensure_nchw(self, feat, expected_ch):
        # Some timm Swin variants return NHWC feature maps; permute to NCHW
        # when the expected channel count is found on the last axis instead
        # of axis 1. Anything else is passed through unchanged.
        if feat.ndim==4:
            if feat.shape[1]==expected_ch: return feat
            if feat.shape[-1]==expected_ch: return feat.permute(0,3,1,2).contiguous()
        return feat
    def forward(self, x):
        feats = self.encoder(x)
        expected = self.encoder.feature_info.channels()
        for i in range(len(feats)):
            feats[i] = self._ensure_nchw(feats[i], expected[i])
        f0,f1,f2,f3 = feats
        d3 = self.up3(f3)
        # Transposed convs can be off by a pixel for odd sizes; resize to the
        # skip feature's spatial size before concatenating along channels.
        if d3.shape[-2:] != f2.shape[-2:]:
            d3 = nn.functional.interpolate(d3, size=f2.shape[-2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3,f2], dim=1))
        d2 = self.up2(d3)
        if d2.shape[-2:] != f1.shape[-2:]:
            d2 = nn.functional.interpolate(d2, size=f1.shape[-2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2,f1], dim=1))
        d1 = self.up1(d2)
        if d1.shape[-2:] != f0.shape[-2:]:
            d1 = nn.functional.interpolate(d1, size=f0.shape[-2:], mode='bilinear', align_corners=False)
        d1 = self.dec1(torch.cat([d1,f0], dim=1))
        out = self.final_up(d1)
        return self.final_conv(out)
# Build the segmentation network and restore its trained weights.
# strict=False tolerates key mismatches between checkpoint and model;
# unmatched parameters keep their freshly initialized values.
seg_model = SwinUNet().to(DEVICE)
seg_state = torch.load("models/swinunet_best (6).pth", map_location=DEVICE)
seg_model.load_state_dict(seg_state, strict=False)
seg_model.eval()
# Segmentation preprocessing: resize to 224x224 and convert to a tensor;
# no normalization is applied.
seg_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
# ---------------- Inference Function ----------------
def predict(img):
    """Classify tumor type and overlay the segmented tumor region.

    Parameters
    ----------
    img : numpy array (H, W, channels) from the Gradio image widget, or
        None when the user submitted without an image.

    Returns
    -------
    (numpy array, str)
        A 224x224 RGB image with the predicted mask blended in red, and a
        text summary "Prediction: <class> (conf: <p>)".
    """
    # The widget passes None on an empty submit; fail with a friendly message
    # instead of a traceback from Image.fromarray(None).
    if img is None:
        raise gr.Error("Please upload an MRI image first.")
    pil_img = Image.fromarray(img).convert("RGB")

    # ---- Encrypt the upload in memory, then decrypt for inference ----
    # (demonstrates at-rest protection of the uploaded bytes; nothing
    # unencrypted is persisted by this function)
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    encrypted_bytes = encrypt_bytes(buf.getvalue())
    decrypted_img = decrypt_bytes(encrypted_bytes)

    # ---- Classification ----
    x = clf_transform(decrypted_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = clf_model(x)
        probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    pred_class = CLASS_NAMES[np.argmax(probs)]
    conf = float(np.max(probs))

    # ---- Segmentation ----
    seg_in = seg_transform(decrypted_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # The model emits raw logits; map them through sigmoid so the 0.5
        # threshold means probability 0.5 (logit 0). Thresholding the raw
        # logits at 0.5 would silently shift the decision boundary.
        prob_map = torch.sigmoid(seg_model(seg_in))[0, 0].cpu().numpy()
    mask = (prob_map > 0.5).astype(np.uint8)

    # ---- Overlay the binary mask in red on the resized input ----
    img_np = np.array(decrypted_img.resize((224, 224)))
    mask_resized = cv2.resize(mask, (img_np.shape[1], img_np.shape[0]), interpolation=cv2.INTER_NEAREST)
    overlay = img_np.copy()
    overlay[mask_resized > 0] = [255, 0, 0]  # red
    blended = cv2.addWeighted(img_np, 0.7, overlay, 0.3, 0)

    return blended, f"Prediction: {pred_class} (conf: {conf:.2f})"
# ---------------- Gradio UI ----------------
# Ten bundled example MRIs shown beneath the upload widget.
example_images = [f"images/img{i}.jpg" for i in range(1, 11)]
# Single-function UI: one image in, annotated image + prediction text out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),  # upload arrives as a numpy array
    outputs=[gr.Image(type="numpy"), gr.Textbox()],  # overlay + summary text
    title="ONCOSCAN - (Brain Tumor Classification + Segmentation) ",
    description="Upload an MRI or click on one of the example images. The app will classify tumor type (ResNet18) and segment tumor region (SwinUNet).",
    examples=example_images,
    cache_examples=False  # examples are run on click, not precomputed at startup
)
demo.launch(debug=True)  # debug=True surfaces server tracebacks in the console