"""Gradio demo: remove lens flare from an uploaded photo, then recover light sources.

The image is resized to a square network resolution, run through the
restoration model loaded from ``model.model``, and the model output is passed
together with the resized input to ``recover_light_sources`` from ``app.utils``.
"""
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from app.utils import recover_light_sources
from model.model import model

# Side length (pixels) of the square resolution the network is run at.
IMAGE_SIZE = 512
CHECKPOINT_PATH = "model/model_epoch_49.pth"
EXAMPLES = ["test_imgs/test1.png", "test_imgs/test2.png", "test_imgs/test3.png"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

chk = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(chk["model_state_dict"])
# load_state_dict copies values into the model's existing parameters, so it does
# not move the model. Inputs are sent to `device` below, so the weights must be
# there too; otherwise a CUDA host fails with a device-mismatch error.
model.to(device)
model.eval()

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),  # -> (C, H, W), dtype=float32, range [0, 1]
])


def evaluate(model, image):
    """Run ``model`` on a batched image tensor and return a uint8 numpy image.

    Args:
        model: the restoration network (switched to eval mode here).
        image: float tensor of shape (1, C, H, W) with values in [0, 1], as
            produced by ``transform(...).unsqueeze(0)``.

    Returns:
        numpy.ndarray of shape (H, W, C), dtype uint8, values in [0, 255].
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image.to(device, dtype=torch.float32))
    outputs = torch.clamp(outputs, 0.0, 1.0)
    outputs_np = outputs.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round rather than truncate: in float32, v / 255 * 255 can come out just
    # below an integer (e.g. 199.99998), which a plain astype would floor.
    return np.round(outputs_np * 255.0).astype(np.uint8)


def predict(input_image):
    """Predict a flare-free image from ``input_image``, then recover light sources.

    Args:
        input_image: PIL image from the Gradio input component, or None when
            the user submits without uploading an image.

    Returns:
        The result of ``recover_light_sources``, displayed by the output
        ``gr.Image`` component.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        # Gradio passes None for an empty image input; show a readable message
        # in the UI instead of an AttributeError traceback.
        raise gr.Error("Please upload an image first.")

    # Resize here (not only inside `transform`) so the image handed to
    # recover_light_sources is at the network resolution; presumably it must
    # match the size of the model output - confirm against app.utils.
    input_img = input_image.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR)
    input_tensor = transform(input_img).unsqueeze(0).to(device, dtype=torch.float32)

    # uint8 (H, W, C) predicted clean image.
    pred_clean_img = evaluate(model, input_tensor)

    return recover_light_sources(network_output=pred_clean_img, original_image=input_img)


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(),
    examples=EXAMPLES,
)
demo.launch()