# xray-reg / app.py
# Last commit 407fa6f ("Update app.py") by SuperSecureHuman.
import logging
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import numpy as np
from pytorch_grad_cam.utils.image import show_cam_on_image
from scripts.trainer import XrayReg
# Checkpoints offered in the model dropdown, as (run id, architecture name)
# pairs. Append a pair here to expose another trained run.
_CHECKPOINT_RUNS = (
    ("912yp4l6", "vit_large_patch16_224_in21k"),
    ("ie399gjr", "vit_small_patch16_224_in21k"),
    ("kcku20nx", "vit_large_patch16_224_in21k"),
    ("ohtmkj0i", "vit_base_patch16_224_in21k"),
    ("vlk8qrkx", "vit_large_patch16_224_in21k"),
)

# Dropdown label -> {"ckpt": checkpoint path, "model_name": architecture name}.
MODEL_OPTIONS = {
    f"XrayReg ({run_id}) [{arch}]": {
        "ckpt": f"xray_regression_noaug/{run_id}/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": arch,
    }
    for run_id, arch in _CHECKPOINT_RUNS
}
def preprocess_image(inp):
    """
    Preprocess the input image.

    Args:
        inp: PIL image to preprocess.

    Returns:
        input_tensor: Tensor to be fed into the model, shaped
            (1, C, 224, 224), where C is the channel count ``ToTensor``
            produces for the image's mode.
        rgb_img: 224x224x3 float32 NumPy array normalized to [0, 1] for
            GradCAM visualization. The image is converted to RGB, so
            grayscale or RGBA uploads still yield a three-channel array.

    Raises:
        Any exception raised while resizing/converting the image is logged
        (with traceback) and re-raised.
    """
    try:
        preprocess = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ]
        )
        input_tensor = preprocess(inp).unsqueeze(0)
        # Resize, then force RGB (same order as preprocess_image_custom) so the
        # visualization base always has exactly three channels.
        rgb_img = (
            np.array(inp.resize((224, 224)).convert("RGB")).astype(np.float32)
            / 255.0
        )
        return input_tensor, rgb_img
    except Exception as e:
        # logging.exception also records the traceback, not just the message.
        logging.exception("Error in image preprocessing: %s", e)
        raise
def load_custom_model(model_key):
    """Load the checkpoint registered under ``model_key`` and prepare it for inference.

    Args:
        model_key: A key of ``MODEL_OPTIONS`` (the dropdown label).

    Returns:
        The inner ``.model`` of the loaded ``XrayReg`` module, moved to CUDA
        when available, switched to eval mode, with gradients enabled on all
        parameters.

    Raises:
        KeyError: if ``model_key`` is not a key of ``MODEL_OPTIONS``.
    """
    model_info = MODEL_OPTIONS[model_key]
    # Only the checkpoint path is passed to the loader; model_info["model_name"]
    # is not. NOTE(review): this assumes the architecture is restored from the
    # checkpoint's saved hyperparameters — confirm against XrayReg.
    lightning_module = XrayReg.load_from_checkpoint(model_info["ckpt"])
    model = lightning_module.model
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    # Keep gradients enabled: predict_and_cam_custom runs GradCAM, which
    # back-propagates through this model.
    model.requires_grad_(True)
    return model
def preprocess_image_custom(inp):
    """Prepare a PIL image for the single-channel regression models.

    Returns:
        A ``(1, 1, 224, 224)`` grayscale tensor for the model, and a
        ``224x224x3`` float32 RGB array scaled to [0, 1] to use as the
        GradCAM overlay background.
    """
    side = (224, 224)
    to_model_input = transforms.Compose(
        [
            transforms.Resize(side),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ]
    )
    batch = to_model_input(inp)[None, ...]

    # Visualization copy: same spatial size, but kept as 3-channel RGB.
    overlay_base = np.asarray(inp.resize(side).convert("RGB"), dtype=np.float32)
    return batch, overlay_base / 255.0
def predict_custom(model, input_tensor):
    """Run one forward pass and return the first output element as a float.

    Gradient tracking is disabled for the pass, and the input is moved to the
    GPU first whenever CUDA is available.
    """
    with torch.no_grad():
        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()
        output = model(input_tensor)
        flat = output.cpu().numpy().ravel()
    return float(flat[0])
def predict_and_cam_custom(inp, model):
    """Predict the regression value for ``inp`` and build a GradCAM overlay.

    Args:
        inp: PIL image from the Gradio ``Image`` input.
        model: Network returned by ``load_custom_model``. It must contain at
            least one ``torch.nn.Conv2d``; the last one (in module order) is
            used as the GradCAM target layer.

    Returns:
        ``(value, cam_pil)``: the prediction as a ``float`` and the GradCAM
        heatmap blended over the resized input as a PIL image.

    Raises:
        ValueError: if ``model`` contains no ``Conv2d`` layer.
    """
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    input_tensor, rgb_img = preprocess_image_custom(inp)
    value = predict_custom(model, input_tensor)

    # GradCAM needs a spatial feature map, so target the last Conv2d.
    # NOTE(review): for the ViT backbones in MODEL_OPTIONS this is presumably
    # the patch-embedding projection — confirm it is the intended layer.
    conv_layers = [m for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
    if not conv_layers:
        raise ValueError("GradCAM requires at least one Conv2d layer in the model")
    target_layers = conv_layers[-1:]

    # Output index 0 is the value predict_custom reports as the regression result.
    targets = [ClassifierOutputTarget(0)]
    # The context manager releases the hooks GradCAM registers on the (cached)
    # model as soon as the map is computed, instead of relying on GC.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    # Return as tuple (number, image) for Gradio
    return value, Image.fromarray(cam_image)
def create_interface_custom():
    """Build the Gradio interface for the X-ray regression demo.

    Models are loaded lazily on first use and memoized per dropdown key
    (up to 5), so switching back to a model does not reload its checkpoint.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    from functools import lru_cache

    model_keys = list(MODEL_OPTIONS.keys())

    # Use stateful model cache to avoid reloading on every prediction
    @lru_cache(maxsize=5)
    def cached_load_model(model_key):
        return load_custom_model(model_key)

    def predict_wrapper(inp, model_key):
        # With live=True this runs on every input change, so it can see an
        # empty image (None) or no selected model; return empty outputs
        # instead of crashing.
        if inp is None or model_key is None:
            return None, None
        model = cached_load_model(model_key)
        return predict_and_cam_custom(inp, model)

    interface = gr.Interface(
        fn=predict_wrapper,
        inputs=[
            gr.Image(type="pil"),
            # Preselect the first model so a prediction can run as soon as an
            # image is uploaded.
            gr.Dropdown(
                model_keys,
                value=model_keys[0] if model_keys else None,
                label="Model",
            ),
        ],
        outputs=[
            gr.Number(label="Regression Output"),
            gr.Image(type="pil", label="GradCAM Visualization"),
        ],
        examples=None,
        title="Xray Regression Gradio App",
        description="Upload an X-ray image and select a model to get regression output and GradCAM visualization.",
        allow_flagging="never",
        live=True,  # re-run whenever an input changes (image upload or model switch)
    )
    return interface
def download_models():
    """Fetch the model checkpoint repository from the Hugging Face Hub.

    The snapshot of ``SuperSecureHuman/xray-reg-models`` is written into the
    current working directory, which is where the relative ``ckpt`` paths in
    ``MODEL_OPTIONS`` are resolved from.
    """
    from huggingface_hub import snapshot_download

    snapshot_download(repo_id="SuperSecureHuman/xray-reg-models", local_dir="./")
def main():
    """Download the checkpoints, then build and launch the Gradio app.

    Exits the process with status 1 if the download fails.
    """
    # Configure logging before anything can log, so the INFO setup applies.
    logging.basicConfig(level=logging.INFO)
    # Download models if not already present
    try:
        download_models()
    except Exception as e:
        # logging.exception keeps the traceback for diagnosing download failures.
        logging.exception("Error downloading models: %s", e)
        # SystemExit does not depend on the site-provided exit() builtin.
        raise SystemExit(1)
    interface = create_interface_custom()
    interface.launch()
# Single script entry point. A duplicated guard here used to call main() a
# second time (re-download and relaunch) after the first launch() returned.
if __name__ == "__main__":
    main()