Spaces:
Sleeping
Sleeping
File size: 1,736 Bytes
a9a2d42 afc145a 903bf4c fdc7f23 a9a2d42 d43a4a2 d9be14e d43a4a2 fdc7f23 0aeeb0e fdc7f23 a9a2d42 fdc7f23 0aeeb0e fdc7f23 b3cb0d8 d9be14e 0aeeb0e d43a4a2 0aeeb0e d9be14e fdc7f23 4cc73a0 fdc7f23 d9be14e d43a4a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import gradio as gr
import torch
from PIL import Image
import numpy as np
from app.utils import recover_light_sources
from model.model import model
from torchvision import transforms
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the trained weights. map_location remaps the stored tensors onto
# `device`, so a checkpoint written on a GPU machine still loads on a
# CPU-only host.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source.
chk = torch.load('model/model_epoch_49.pth', map_location=device)
model.load_state_dict(chk['model_state_dict'])
# load_state_dict copies values into the module's existing parameters; it does
# not relocate the module. predict() sends its input tensor to `device`, so the
# module has to live there too or the forward pass fails with a device
# mismatch on CUDA hosts. Calling .to() is harmless if it is already there.
model.to(device)

# Preprocessing applied to every input image before it reaches the network.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),  # -> (C,H,W), dtype=float32, range [0,1]
])
def evaluate(model, image):
    """
    Run the model on the given image tensor and return the result as a uint8 array.

    The model is switched to eval mode and called with gradient tracking
    disabled. Its output is clamped to [0, 1], the leading dimension is
    squeezed away, channels are moved last, and the values are scaled to
    [0, 255].

    Args:
        model: Callable module exposing ``eval()``; invoked as ``model(image)``.
        image: Input tensor. Assumed to be a single-image batch shaped
            (1, C, H, W) already on the model's device -- TODO confirm
            against callers.

    Returns:
        numpy.ndarray of dtype uint8, shaped (H, W, C) for a (1, C, H, W)
        model output.
    """
    model.eval()
    with torch.no_grad():
        outputs = torch.clamp(model(image), 0.0, 1.0)
        # squeeze(0) drops the batch axis (only when its size is 1);
        # permute moves channels last.
        outputs_np = outputs.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round to the nearest level before casting: a bare astype(uint8) truncates
    # toward zero, which darkens the image and turns values such as 254.99
    # into 254 instead of 255.
    return np.rint(outputs_np * 255.0).astype(np.uint8)
def predict(input_image):
    """
    Predict clean image from flare image, then recover light sources.

    Args:
        input_image: PIL image supplied by the Gradio input component, or
            None when the request was submitted without an image.

    Returns:
        The value returned by ``recover_light_sources`` for the predicted
        clean image and the resized RGB input.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio passes None when the user submits an empty image component;
    # report that clearly instead of failing on `.convert` below.
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Resize and prepare input tensor
    input_img = input_image.convert('RGB').resize((512, 512), Image.BILINEAR)
    input_tensor = transform(input_img).unsqueeze(0).to(device, dtype=torch.float32)

    # Get predicted clean image from model
    pred_clean_img = evaluate(model, input_tensor)  # uint8 predicted clean

    # Recover light sources using the resized input as the reference image
    final_img = recover_light_sources(
        network_output=pred_clean_img,
        original_image=input_img,
    )
    return final_img
# Build the web UI: one PIL image in, one image out, with bundled example inputs.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(),
    examples=[
        "test_imgs/test1.png",
        "test_imgs/test2.png",
        "test_imgs/test3.png",
    ],
)
# Start serving the interface.
demo.launch()
|