Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,599 Bytes
465d7e4 bce143a 465d7e4 bce143a 465d7e4 407fa6f 465d7e4 407fa6f 465d7e4 bce143a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import logging
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import numpy as np
import spaces
from pytorch_grad_cam.utils.image import show_cam_on_image
from scripts.trainer import XrayReg
# Checkpoints selectable in the UI, as (training run id, backbone name) pairs.
# Append a pair here to expose another checkpoint in the dropdown.
_MODEL_RUNS = (
    ("912yp4l6", "vit_large_patch16_224_in21k"),
    ("ie399gjr", "vit_small_patch16_224_in21k"),
    ("kcku20nx", "vit_large_patch16_224_in21k"),
    ("ohtmkj0i", "vit_base_patch16_224_in21k"),
    ("vlk8qrkx", "vit_large_patch16_224_in21k"),
)

# Dropdown label -> {"ckpt": checkpoint path, "model_name": backbone name}.
MODEL_OPTIONS = {
    f"XrayReg ({run_id}) [{backbone}]": {
        "ckpt": f"xray_regression_noaug/{run_id}/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": backbone,
    }
    for run_id, backbone in _MODEL_RUNS
}
def preprocess_image(inp):
    """Preprocess a PIL image for the model and for the GradCAM overlay.

    Args:
        inp: PIL image to preprocess.

    Returns:
        input_tensor: Batched tensor of shape ``(1, C, 224, 224)`` produced by
            ``ToTensor`` (values in [0, 1]); C follows the image mode
            (e.g. 3 for RGB), to be fed into the model.
        rgb_img: ``(224, 224, 3)`` float32 NumPy array normalized to [0, 1]
            for GradCAM visualization.

    Raises:
        Any exception raised while resizing/converting is logged with its
        traceback and re-raised.
    """
    try:
        preprocess = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ]
        )
        input_tensor = preprocess(inp).unsqueeze(0)
        # show_cam_on_image needs an HxWx3 image: force RGB so grayscale or
        # RGBA uploads do not produce a 2-D or 4-channel array
        # (same treatment as preprocess_image_custom).
        resized_rgb = inp.resize((224, 224)).convert("RGB")
        rgb_img = np.array(resized_rgb).astype(np.float32) / 255.0
    except Exception as e:
        logging.exception("Error in image preprocessing: %s", e)
        raise
    return input_tensor, rgb_img
def load_custom_model(model_key):
    """Load the regression network selected in the dropdown.

    Args:
        model_key: A key of ``MODEL_OPTIONS``.

    Returns:
        The inner ``.model`` of the loaded ``XrayReg`` checkpoint. It is in
        eval mode, moved to CUDA when available, and has gradients enabled on
        all parameters.

    Raises:
        KeyError: If ``model_key`` is not a key of ``MODEL_OPTIONS``.
    """
    model_info = MODEL_OPTIONS[model_key]
    # NOTE(review): the backbone name is not passed to the loader. Presumably
    # XrayReg restores the architecture from the hyperparameters stored in the
    # checkpoint, and model_info["model_name"] is only used for the label.
    # Confirm before relying on it.
    lightning_module = XrayReg.load_from_checkpoint(model_info["ckpt"])
    model = lightning_module.model
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    # GradCAM backpropagates from the output to a hooked layer. The input
    # tensor does not require grad, so the parameters must, or no autograd
    # graph would be recorded during the forward pass.
    model.requires_grad_(True)
    return model
def preprocess_image_custom(inp):
    """Prepare an uploaded image for inference and for the GradCAM backdrop.

    Args:
        inp: PIL image from the Gradio input.

    Returns:
        ``(input_tensor, rgb_img)``:
            input_tensor: The image resized to 224x224 and reduced to a single
                grayscale channel, as a batched ``(1, 1, 224, 224)`` tensor.
            rgb_img: An RGB float32 copy scaled to [0, 1], used when
                overlaying the CAM.
    """
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ])
    # Add the leading batch dimension.
    batched = pipeline(inp)[None, ...]

    backdrop = inp.resize((224, 224)).convert("RGB")
    rgb_img = np.asarray(backdrop).astype(np.float32) / 255.0
    return batched, rgb_img
def predict_custom(model, input_tensor):
    """Run one gradient-free forward pass and return a scalar prediction.

    Args:
        model: Callable mapping a batched tensor to an output tensor.
        input_tensor: Batched input. It is moved to CUDA when a GPU is
            available.

    Returns:
        The first element of the flattened model output, as a Python float.
    """
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()
    with torch.no_grad():
        output = model(input_tensor)
    first_value = output.cpu().numpy().flatten()[0]
    return float(first_value)
@spaces.GPU
def predict_and_cam_custom(inp, model):
    """Predict the regression value for ``inp`` and render a GradCAM overlay.

    Args:
        inp: PIL image from the Gradio input.
        model: Network returned by ``load_custom_model``.

    Returns:
        ``(value, cam_pil)``:
            value: The scalar prediction, as a float.
            cam_pil: The GradCAM heatmap blended onto the 224x224 RGB image,
                as a PIL image.

    Raises:
        ValueError: If ``model`` contains no ``torch.nn.Conv2d`` layer to
            attach GradCAM to.
    """
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    input_tensor, rgb_img = preprocess_image_custom(inp)
    value = predict_custom(model, input_tensor)

    # GradCAM for regression: hook the last Conv2d and explain output index 0.
    # NOTE(review): for the ViT backbones named in MODEL_OPTIONS, the last
    # Conv2d is presumably the patch-embedding projection. Confirm that this
    # is the intended explanation layer.
    conv_layers = [m for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
    if not conv_layers:
        raise ValueError("GradCAM needs at least one Conv2d layer in the model")

    # The model may be reused across calls (the interface caches it). The
    # context manager removes GradCAM's hooks deterministically instead of
    # relying on garbage collection.
    with GradCAM(model=model, target_layers=conv_layers[-1:]) as gradcam:
        targets = [ClassifierOutputTarget(0)]  # for regression, use the output
        grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets)[0]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    # Return as tuple (number, image) for Gradio
    return value, Image.fromarray(cam_image)
def create_interface_custom():
    """Build the Gradio interface for X-ray regression with GradCAM.

    Loaded models are memoised per dropdown key (up to 5 entries), so
    switching back to a recently used model does not reload its checkpoint.

    Returns:
        The configured ``gr.Interface``, not yet launched.
    """
    from functools import lru_cache

    @lru_cache(maxsize=5)
    def cached_load_model(model_key):
        """Load the model for ``model_key`` once and reuse it afterwards."""
        return load_custom_model(model_key)

    def predict_wrapper(inp, model_key):
        # With live=True, Gradio calls this on every input change, including
        # before an image is uploaded or a model is picked. Return empty
        # outputs instead of crashing on a None image or an unknown key.
        if inp is None or model_key not in MODEL_OPTIONS:
            return None, None
        model = cached_load_model(model_key)
        return predict_and_cam_custom(inp, model)

    model_keys = list(MODEL_OPTIONS.keys())
    interface = gr.Interface(
        fn=predict_wrapper,
        inputs=[
            gr.Image(type="pil"),
            # Preselect the first model so uploading an image is enough to
            # get a prediction.
            gr.Dropdown(
                model_keys,
                value=model_keys[0] if model_keys else None,
                label="Model",
            ),
        ],
        outputs=[
            gr.Number(label="Regression Output"),
            gr.Image(type="pil", label="GradCAM Visualization"),
        ],
        examples=None,
        title="Xray Regression Gradio App",
        description="Upload an X-ray image and select a model to get regression output and GradCAM visualization.",
        allow_flagging="never",
        live=True,  # re-run whenever the image or the selected model changes
    )
    return interface
def download_models():
    """Fetch the ``SuperSecureHuman/xray-reg-models`` Hub repo into ``./``.

    The checkpoint paths in ``MODEL_OPTIONS`` are relative, so presumably this
    repo provides the ``xray_regression_noaug/...`` files they point to
    (verify against the repo). Any error from ``huggingface_hub`` propagates
    to the caller.
    """
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="SuperSecureHuman/xray-reg-models",
        local_dir="./",
    )
def main():
    """Entry point: fetch the checkpoints, then build and launch the Gradio app.

    Exits the process with status 1 if the model download fails.
    """
    import sys

    # Configure logging before any step that may log, so the requested level
    # and handler apply to the download step as well.
    logging.basicConfig(level=logging.INFO)

    # NOTE(review): snapshot_download is expected to skip files already present
    # in local_dir. Confirm for the pinned huggingface_hub version.
    try:
        download_models()
    except Exception as e:
        # Top-level boundary: record the full traceback, then abort startup.
        logging.exception("Error downloading models: %s", e)
        sys.exit(1)

    interface = create_interface_custom()
    interface.launch()
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()