"""Gradio app that runs XrayReg regression checkpoints and shows a GradCAM overlay."""

import logging
import sys
from functools import lru_cache

import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import numpy as np
import spaces
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from scripts.trainer import XrayReg

logger = logging.getLogger(__name__)

# Size every image is resized to before inference. It is square, so PIL's
# (width, height) order and torchvision's (height, width) order coincide.
_INPUT_SIZE = (224, 224)

# Model selection dropdown options (extendable).
# Dropdown label -> checkpoint path (relative to the working directory the
# checkpoints are downloaded into) and backbone name.
MODEL_OPTIONS = {
    "XrayReg (912yp4l6) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/912yp4l6/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
    "XrayReg (ie399gjr) [vit_small_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/ie399gjr/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_small_patch16_224_in21k",
    },
    "XrayReg (kcku20nx) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/kcku20nx/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
    "XrayReg (ohtmkj0i) [vit_base_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/ohtmkj0i/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_base_patch16_224_in21k",
    },
    "XrayReg (vlk8qrkx) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/vlk8qrkx/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
}


def preprocess_image(inp):
    """Preprocess a PIL image for a 3-channel model.

    The image is converted to RGB first, so grayscale or RGBA uploads still
    yield a 3-channel tensor and a 3-channel overlay base image. The overlay
    base is what ``show_cam_on_image`` expects.

    Args:
        inp: PIL image.

    Returns:
        input_tensor: Tensor of shape (1, 3, 224, 224) to be fed into the model.
        rgb_img: float32 NumPy array of shape (224, 224, 3) normalized to
            [0, 1] for GradCAM visualization.

    Raises:
        Any exception raised by PIL/torchvision; it is logged (with
        traceback) before being re-raised.
    """
    try:
        rgb_inp = inp.convert("RGB")
        preprocess = transforms.Compose(
            [
                transforms.Resize(_INPUT_SIZE),
                transforms.ToTensor(),
            ]
        )
        input_tensor = preprocess(rgb_inp).unsqueeze(0)
        rgb_img = np.array(rgb_inp.resize(_INPUT_SIZE)).astype(np.float32) / 255.0
        return input_tensor, rgb_img
    except Exception as e:
        logger.exception("Error in image preprocessing: %s", e)
        raise


def load_custom_model(model_key):
    """Load the checkpoint selected by ``model_key`` and prepare it for inference.

    Args:
        model_key: A key of ``MODEL_OPTIONS``.

    Returns:
        The ``.model`` attribute of the loaded ``XrayReg`` module. It is in
        eval mode, on CUDA when available, and has gradients enabled on all
        parameters (needed by GradCAM).

    Raises:
        KeyError: If ``model_key`` is not in ``MODEL_OPTIONS``.
    """
    model_info = MODEL_OPTIONS[model_key]
    # NOTE(review): model_info["model_name"] is not passed to
    # load_from_checkpoint. The architecture presumably comes from the
    # hyperparameters stored in the checkpoint — confirm.
    lightning_module = XrayReg.load_from_checkpoint(model_info["ckpt"])
    model = lightning_module.model
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()
    # The input tensor never requires grad, so GradCAM can only build a
    # backward graph to the hooked layer if the parameters track gradients.
    model.requires_grad_(True)
    return model


def preprocess_image_custom(inp):
    """Preprocess a PIL image for the XrayReg models, which are fed 1-channel input.

    Args:
        inp: PIL image.

    Returns:
        input_tensor: Grayscale tensor of shape (1, 1, 224, 224).
        rgb_img: float32 NumPy array of shape (224, 224, 3) normalized to
            [0, 1], used as the base image for the GradCAM overlay.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize(_INPUT_SIZE),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ]
    )
    input_tensor = preprocess(inp).unsqueeze(0)
    rgb_img = np.array(inp.resize(_INPUT_SIZE).convert("RGB")).astype(np.float32) / 255.0
    return input_tensor, rgb_img


def _model_device(model):
    """Return the device of ``model``'s first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _last_conv_layers(model):
    """Return a one-element list holding the last ``torch.nn.Conv2d`` in ``model``.

    Raises:
        ValueError: If ``model`` contains no ``torch.nn.Conv2d`` module.
    """
    conv_layers = [m for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
    if not conv_layers:
        raise ValueError("GradCAM needs a torch.nn.Conv2d layer, but the model has none.")
    return conv_layers[-1:]


def predict_custom(model, input_tensor):
    """Run ``model`` on ``input_tensor`` without gradients.

    Returns:
        The first element of the flattened model output, as a Python float.
    """
    with torch.no_grad():
        # Follow the model's actual device instead of guessing from CUDA
        # availability, so a CPU-resident model also works on a GPU host.
        pred = model(input_tensor.to(_model_device(model)))
    return float(pred.flatten()[0].item())


@spaces.GPU
def predict_and_cam_custom(inp, model):
    """Predict the regression value for ``inp`` and build a GradCAM overlay.

    Args:
        inp: PIL image.
        model: A model as returned by ``load_custom_model``.

    Returns:
        Tuple ``(value, cam_pil)``: the regression output as a float, and the
        GradCAM heat map blended onto the resized input as a PIL image
        (returned as a tuple for Gradio's two outputs).

    Raises:
        ValueError: If ``model`` contains no ``torch.nn.Conv2d`` layer.
    """
    input_tensor, rgb_img = preprocess_image_custom(inp)
    value = predict_custom(model, input_tensor)

    # GradCAM for regression: hook the last Conv2d layer and back-propagate
    # from output index 0.
    # NOTE(review): for the ViT backbones in MODEL_OPTIONS, the last Conv2d
    # is presumably the patch-embedding projection. Targeting a transformer
    # block with a reshape_transform may give more useful maps — confirm.
    target_layers = _last_conv_layers(model)
    targets = [ClassifierOutputTarget(0)]
    cam = GradCAM(model=model, target_layers=target_layers)
    try:
        grayscale_cam = cam(
            input_tensor=input_tensor.to(_model_device(model)),
            targets=targets,
        )[0]
    finally:
        # The model is cached and shared across requests. Remove the hooks
        # GradCAM registered on it now, rather than relying on garbage
        # collection of the GradCAM object.
        cam.activations_and_grads.release()

    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    return value, Image.fromarray(cam_image)


def create_interface_custom():
    """Build the Gradio interface: image + model dropdown -> value + GradCAM image."""

    # Cache loaded models so switching back to a model does not reload it.
    @lru_cache(maxsize=5)
    def cached_load_model(model_key):
        return load_custom_model(model_key)

    def predict_wrapper(inp, model_key):
        # With live=True this fires on every input change, including before an
        # image is uploaded or a model is chosen. Show empty outputs instead
        # of raising on a missing image or an unknown/None model key.
        if inp is None or model_key not in MODEL_OPTIONS:
            return None, None
        model = cached_load_model(model_key)
        return predict_and_cam_custom(inp, model)

    interface = gr.Interface(
        fn=predict_wrapper,
        inputs=[
            gr.Image(type="pil"),
            gr.Dropdown(
                list(MODEL_OPTIONS.keys()),
                value=next(iter(MODEL_OPTIONS), None),
                label="Model",
            ),
        ],
        outputs=[
            gr.Number(label="Regression Output"),
            gr.Image(type="pil", label="GradCAM Visualization"),
        ],
        examples=None,
        title="Xray Regression Gradio App",
        description="Upload an X-ray image and select a model to get regression output and GradCAM visualization.",
        allow_flagging="never",
        live=True,  # Re-run prediction whenever the image or the selected model changes.
    )
    return interface


def download_models(repo_id="SuperSecureHuman/xray-reg-models", local_dir="./"):
    """Download the checkpoint repository from the Hugging Face Hub into ``local_dir``.

    The checkpoint paths in ``MODEL_OPTIONS`` are relative, so ``local_dir``
    should be the directory the app runs from (the default, ``"./"``).
    """
    import huggingface_hub

    huggingface_hub.snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
    )


def main():
    """Configure logging, fetch the checkpoints, and launch the Gradio app."""
    # Configure logging first so download failures are formatted consistently.
    logging.basicConfig(level=logging.INFO)
    try:
        download_models()
    except Exception as e:
        logger.exception("Error downloading models: %s", e)
        # sys.exit, unlike the site-provided exit(), is always available.
        sys.exit(1)
    interface = create_interface_custom()
    interface.launch()


if __name__ == "__main__":
    main()