Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from PIL import Image | |
import numpy as np | |
from app.utils import recover_light_sources | |
from model.model import model | |
from torchvision import transforms | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the trained weights. map_location remaps tensors that were saved on
# another device so the checkpoint also loads on a CPU-only host.
chk = torch.load('model/model_epoch_49.pth', map_location=device)
model.load_state_dict(chk['model_state_dict'])
# load_state_dict copies values into the existing parameters without moving
# them, so explicitly place the network on the same device that predict() sends
# its input tensors to; otherwise CUDA hosts hit a device-mismatch error.
# (No-op if model.model already put it there.)
model.to(device)

# Preprocessing for the network: fixed 512x512 input, converted to a float32
# (C, H, W) tensor with values in [0, 1].
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),  # -> (C,H,W), dtype=float32, range [0,1]
])
def evaluate(model, image):
    """
    Run the model on a single-image batch and return the result as a uint8 image.

    Args:
        model: torch.nn.Module to run. It is switched to eval mode before use.
        image: input tensor holding a batch of one image, presumably
            (1, C, H, W) float in [0, 1] (TODO confirm against the model).
            It is moved to the device of the model's parameters before the
            forward pass.

    Returns:
        numpy.ndarray of dtype uint8 and shape (H, W, C). The model output is
        clamped to [0, 1] and then scaled and rounded to [0, 255].
    """
    model.eval()
    # Run on whatever device the weights live on, so a caller that built the
    # tensor on another device does not hit a device-mismatch error. Modules
    # without parameters (e.g. nn.Identity) leave the tensor where it is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image = image.to(first_param.device)
    with torch.no_grad():
        outputs = model(image)
        outputs = torch.clamp(outputs, 0.0, 1.0)
        # Drop the batch dimension and go from (C, H, W) to (H, W, C).
        outputs_np = outputs.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round to the nearest level rather than truncating, which would bias
    # every pixel downward by up to one grey level (e.g. 0.999 -> 254).
    return np.round(outputs_np * 255.0).astype(np.uint8)
def predict(input_image):
    """
    Predict a flare-free image from a flare image, then recover light sources.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The result of ``recover_light_sources`` for the 512x512 model
        prediction and the resized input; it is shown in the Gradio output
        image.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of an AttributeError traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    # The network runs at a fixed 512x512 resolution. The resized RGB copy is
    # also what gets handed to the recovery step as the original image.
    input_img = input_image.convert('RGB').resize((512, 512), Image.BILINEAR)
    input_tensor = transform(input_img).unsqueeze(0).to(device, dtype=torch.float32)
    # uint8 (H, W, C) prediction of the clean image.
    pred_clean_img = evaluate(model, input_tensor)
    final_img = recover_light_sources(network_output=pred_clean_img, original_image=input_img)
    return final_img
# Web UI: one PIL image in, the de-flared result out, with three bundled
# sample images offered as clickable examples.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(),
    examples=[
        "test_imgs/test1.png",
        "test_imgs/test2.png",
        "test_imgs/test3.png",
    ],
)
# Start the Gradio server.
demo.launch()