# xray-reg / app.py
# SuperSecureHuman's picture
# Update app.py
# bce143a verified
import logging
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import numpy as np
import spaces
from pytorch_grad_cam.utils.image import show_cam_on_image
from scripts.trainer import XrayReg
# Checkpoints offered in the model dropdown, as (run id, backbone name) pairs.
# Append a pair here to expose another checkpoint.
_MODEL_RUNS = (
    ("912yp4l6", "vit_large_patch16_224_in21k"),
    ("ie399gjr", "vit_small_patch16_224_in21k"),
    ("kcku20nx", "vit_large_patch16_224_in21k"),
    ("ohtmkj0i", "vit_base_patch16_224_in21k"),
    ("vlk8qrkx", "vit_large_patch16_224_in21k"),
)

# Maps each dropdown label to its checkpoint path ("ckpt") and backbone
# name ("model_name"). Insertion order follows _MODEL_RUNS.
MODEL_OPTIONS = {
    f"XrayReg ({run_id}) [{backbone}]": {
        "ckpt": f"xray_regression_noaug/{run_id}/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": backbone,
    }
    for run_id, backbone in _MODEL_RUNS
}
def preprocess_image(inp):
    """
    Preprocess the input image.

    Args:
        inp: PIL image to preprocess.

    Returns:
        input_tensor: Tensor to be fed into the model (resized to 224x224,
            with a leading batch dimension added).
        rgb_img: 224x224 RGB NumPy array normalized to [0, 1] for GradCAM
            visualization.

    Raises:
        Exception: Any error raised during preprocessing is logged and
            re-raised unchanged.
    """
    try:
        preprocess = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ]
        )
        input_tensor = preprocess(inp).unsqueeze(0)
        # Force three channels so grayscale ("L") or RGBA inputs still give
        # an HxWx3 array for the CAM overlay, as preprocess_image_custom does.
        rgb_source = inp.resize((224, 224)).convert("RGB")
        rgb_img = np.array(rgb_source).astype(np.float32) / 255.0
        return input_tensor, rgb_img
    except Exception as e:
        logging.error("Error in image preprocessing: %s", e)
        raise
def load_custom_model(model_key):
    """
    Load the checkpoint registered under ``model_key`` and return its network.

    Args:
        model_key: A key of ``MODEL_OPTIONS``.

    Returns:
        The ``model`` attribute of the loaded ``XrayReg`` checkpoint, moved to
        CUDA when available, switched to eval mode, with ``requires_grad``
        enabled on every parameter.

    Raises:
        KeyError: If ``model_key`` is not present in ``MODEL_OPTIONS``.
    """
    model_info = MODEL_OPTIONS[model_key]
    # NOTE(review): model_info["model_name"] is not forwarded to
    # load_from_checkpoint; presumably the checkpoint stores its own
    # hyperparameters -- confirm before passing any override here.
    lightning_module = XrayReg.load_from_checkpoint(model_info["ckpt"])
    model = lightning_module.model
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    # Gradients stay enabled even in eval mode: the GradCAM step in
    # predict_and_cam_custom backpropagates through this network.
    for param in model.parameters():
        param.requires_grad = True
    return model
def preprocess_image_custom(inp):
    """
    Prepare a PIL image for the X-ray regression model and for CAM overlay.

    Returns:
        input_tensor: Single-channel 224x224 tensor with a leading batch
            dimension.
        rgb_img: 224x224 RGB float32 NumPy array scaled to [0, 1].
    """
    target_size = (224, 224)

    # Model input: resize, collapse to one grayscale channel, convert to tensor.
    to_model_input = transforms.Compose(
        [
            transforms.Resize(target_size),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ]
    )
    batch = to_model_input(inp).unsqueeze(0)

    # Overlay base: same spatial size, but always three channels.
    overlay_pixels = np.array(inp.resize(target_size).convert("RGB"))
    rgb_img = overlay_pixels.astype(np.float32) / 255.0

    return batch, rgb_img
def predict_custom(model, input_tensor):
    """
    Run ``model`` on ``input_tensor`` without gradient tracking.

    The input is moved to CUDA when CUDA is available. Returns the first
    element of the flattened model output as a Python float.
    """
    with torch.no_grad():
        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()
        output = model(input_tensor)
    first_value = output.cpu().numpy().flatten()[0]
    return float(first_value)
@spaces.GPU
def predict_and_cam_custom(inp, model):
    """
    Run the regression model on an image and build a GradCAM overlay.

    Args:
        inp: PIL image to analyse.
        model: Network to evaluate and explain, such as the one returned by
            ``load_custom_model``.

    Returns:
        tuple: ``(value, cam_pil)`` -- the predicted value as a float and the
        GradCAM heatmap blended over the resized input, as a PIL image.

    Raises:
        ValueError: If ``model`` contains no ``torch.nn.Conv2d`` layer to
            attach GradCAM to.
    """
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    input_tensor, rgb_img = preprocess_image_custom(inp)
    value = predict_custom(model, input_tensor)

    # GradCAM is attached to the last Conv2d found in the module tree.
    # NOTE(review): for the ViT backbones listed in MODEL_OPTIONS this is
    # presumably the patch-embedding projection -- confirm it is the intended
    # layer.
    conv_layers = [
        layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d)
    ]
    if not conv_layers:
        raise ValueError("Model has no Conv2d layer to use as a GradCAM target")
    target_layers = conv_layers[-1:]

    # The model emits a single regression value, so output index 0 is the target.
    targets = [ClassifierOutputTarget(0)]

    # The model is cached and reused across requests; the context manager
    # removes the hooks GradCAM registers on it as soon as the CAM is computed.
    with GradCAM(model=model, target_layers=target_layers) as gradcam:
        grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets)[0]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)
    # Return as tuple (number, image) for Gradio
    return value, cam_pil
def create_interface_custom():
    """
    Build the Gradio interface for X-ray regression with GradCAM output.

    Loaded models are memoized per dropdown key, so switching back to a
    previously used model does not reload its checkpoint.

    Returns:
        gr.Interface: Interface taking an image and a model choice and
        producing the regression value and a GradCAM visualization.
    """
    from functools import lru_cache

    @lru_cache(maxsize=5)
    def cached_load_model(model_key):
        return load_custom_model(model_key)

    def predict_wrapper(inp, model_key):
        # With live=True this runs on every input change, including when the
        # image is cleared or no model has been selected yet. Clear both
        # outputs instead of failing on a None image or an unknown key.
        if inp is None or not model_key:
            return None, None
        model = cached_load_model(model_key)
        return predict_and_cam_custom(inp, model)

    interface = gr.Interface(
        fn=predict_wrapper,
        inputs=[
            gr.Image(type="pil"),
            gr.Dropdown(list(MODEL_OPTIONS), label="Model"),
        ],
        outputs=[
            gr.Number(label="Regression Output"),
            gr.Image(type="pil", label="GradCAM Visualization"),
        ],
        examples=None,
        title="Xray Regression Gradio App",
        description="Upload an X-ray image and select a model to get regression output and GradCAM visualization.",
        allow_flagging="never",
        live=True,  # Re-run prediction whenever the image or model choice changes
    )
    return interface
def download_models():
    """Download a snapshot of the model repository into the current directory."""
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="SuperSecureHuman/xray-reg-models",
        local_dir="./",
    )
def main():
    """
    Download the model files, then build and launch the Gradio app.

    Raises:
        SystemExit: With status 1 if the model download fails.
    """
    # Configure logging before anything else so that messages emitted during
    # the download step already honour the INFO level.
    logging.basicConfig(level=logging.INFO)

    try:
        download_models()
    except Exception as e:
        logging.error("Error downloading models: %s", e)
        # Raise SystemExit directly: the bare ``exit`` helper is injected by
        # the ``site`` module and is not guaranteed to exist in every runtime.
        raise SystemExit(1)

    interface = create_interface_custom()
    interface.launch()


if __name__ == "__main__":
    main()