import logging

import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import numpy as np
import spaces
from pytorch_grad_cam.utils.image import show_cam_on_image

from scripts.trainer import XrayReg

# Model selection dropdown options (extendable)
MODEL_OPTIONS = {
    "XrayReg (912yp4l6) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/912yp4l6/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
    "XrayReg (ie399gjr) [vit_small_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/ie399gjr/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_small_patch16_224_in21k",
    },
    "XrayReg (kcku20nx) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/kcku20nx/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
    "XrayReg (ohtmkj0i) [vit_base_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/ohtmkj0i/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_base_patch16_224_in21k",
    },
    "XrayReg (vlk8qrkx) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/vlk8qrkx/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
}
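
# The "ckpt" paths are relative to the working directory: download_models() snapshots
# the "SuperSecureHuman/xray-reg-models" repo into "./", which is assumed to contain
# the xray_regression_noaug/<run_id>/checkpoints/ tree these entries point to. The
# "model_name" field records each run's timm backbone for reference; loading below
# only needs the checkpoint path.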


def preprocess_image(inp):
    """
    Preprocess the input image.

    Returns:
        input_tensor: Tensor to be fed into the model.
        rgb_img: NumPy array normalized to [0, 1] for GradCAM visualization.
    """
    try:
        preprocess = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ]
        )
        input_tensor = preprocess(inp).unsqueeze(0)
        rgb_img = np.array(inp.resize((224, 224))).astype(np.float32) / 255.0
        return input_tensor, rgb_img
    except Exception as e:
        logging.error("Error in image preprocessing: %s", e)
        raise


def load_custom_model(model_key):
    model_info = MODEL_OPTIONS[model_key]
    # load_from_checkpoint reconstructs the module from the hyperparameters saved in
    # the checkpoint, so only the checkpoint path is needed here.
    model = XrayReg.load_from_checkpoint(model_info["ckpt"])
    # Unwrap the backbone and move it to the GPU when one is available.
    model = model.model.cuda() if torch.cuda.is_available() else model.model
    model.eval()
    # GradCAM needs gradients with respect to the activations, so keep
    # requires_grad enabled even though the weights stay frozen for inference.
    for param in model.parameters():
        param.requires_grad = True
    return model


def preprocess_image_custom(inp):
    """
    Preprocess the input image for the grayscale regression models.

    Returns:
        input_tensor: Single-channel tensor to be fed into the model.
        rgb_img: RGB NumPy array normalized to [0, 1] for the GradCAM overlay.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ]
    )
    input_tensor = preprocess(inp).unsqueeze(0)
    rgb_img = np.array(inp.resize((224, 224)).convert("RGB")).astype(np.float32) / 255.0
    return input_tensor, rgb_img


def predict_custom(model, input_tensor):
    with torch.no_grad():
        input_tensor = (
            input_tensor.cuda() if torch.cuda.is_available() else input_tensor
        )
        pred = model(input_tensor)
        pred = pred.cpu().numpy().flatten()[0]
    return float(pred)


def predict_and_cam_custom(inp, model):
    input_tensor, rgb_img = preprocess_image_custom(inp)
    value = predict_custom(model, input_tensor)

    # GradCAM for regression: hook the last Conv2d layer and target the single output.
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    target_layers = [
        layer
        for name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Conv2d)
    ][-1:]
    gradcam = GradCAM(model=model, target_layers=target_layers)
    # ClassifierOutputTarget(0) selects output index 0, which for a single-output
    # regressor is simply the predicted value.
    targets = [ClassifierOutputTarget(0)]
    # Run GradCAM with the input on the same device as the model.
    device = next(model.parameters()).device
    grayscale_cam = gradcam(input_tensor=input_tensor.to(device), targets=targets)[0]
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)
    # Return a (number, image) tuple matching the two Gradio outputs.
    return value, cam_pil
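

# Hedged alternative, not wired in above: the backbones listed in MODEL_OPTIONS are
# timm ViT names, and in a pure ViT the only Conv2d is the patch-embedding projection,
# so the CAM above is computed at the patch-embedding stage. The usual pytorch_grad_cam
# ViT recipe instead targets a late transformer block and passes a reshape_transform
# that folds the (B, 1 + 14*14, C) token sequence back into a (B, C, 14, 14) feature
# map. The layer path model.blocks[-1].norm1 is an assumption about the unwrapped timm
# module and may need adjusting for a different wrapper.
def vit_reshape_transform(tensor, height=14, width=14):
    # Drop the CLS token and fold the patch tokens back into a 14x14 grid.
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    # (B, H, W, C) -> (B, C, H, W) so GradCAM can treat it like a conv feature map.
    return result.permute(0, 3, 1, 2)


def gradcam_for_vit(model, input_tensor, targets):
    from pytorch_grad_cam import GradCAM

    cam = GradCAM(
        model=model,
        target_layers=[model.blocks[-1].norm1],  # assumed timm ViT attribute path
        reshape_transform=vit_reshape_transform,
    )
    return cam(input_tensor=input_tensor, targets=targets)[0]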


def create_interface_custom():
    # Cache loaded models so switching inputs does not reload a checkpoint on every call.
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def cached_load_model(model_key):
        return load_custom_model(model_key)

    # Run inference inside a ZeroGPU context; the Space runs on Zero and `spaces`
    # is imported at the top for this.
    @spaces.GPU
    def predict_wrapper(inp, model_key):
        model = cached_load_model(model_key)
        return predict_and_cam_custom(inp, model)

    interface = gr.Interface(
        fn=predict_wrapper,
        inputs=[
            gr.Image(type="pil"),
            gr.Dropdown(
                list(MODEL_OPTIONS.keys()),
                value=list(MODEL_OPTIONS.keys())[0],
                label="Model",
            ),
        ],
        outputs=[
            gr.Number(label="Regression Output"),
            gr.Image(type="pil", label="GradCAM Visualization"),
        ],
        examples=None,
        title="Xray Regression Gradio App",
        description="Upload an X-ray image and select a model to get the regression output and a GradCAM visualization.",
        allow_flagging="never",
        live=True,  # Re-run the prediction whenever the image or model selection changes
    )
    return interface


def download_models():
    import huggingface_hub

    repo_name = "SuperSecureHuman/xray-reg-models"
    local_dir = "./"
    huggingface_hub.snapshot_download(
        repo_id=repo_name,
        local_dir=local_dir,
    )
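    # Note: snapshot_download fetches the whole repository. If only the checkpoints
    # are needed, huggingface_hub's allow_patterns filter can narrow the download,
    # e.g. allow_patterns=["xray_regression_noaug/**"] (the pattern here is an
    # assumption about the repo layout, inferred from the MODEL_OPTIONS paths).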


def main():
    logging.basicConfig(level=logging.INFO)
    # Download the model checkpoints if they are not already present locally.
    try:
        download_models()
    except Exception as e:
        logging.error("Error downloading models: %s", e)
        raise SystemExit(1)
    interface = create_interface_custom()
    interface.launch()


if __name__ == "__main__":
    main()