Spaces:
Sleeping
Sleeping
import torch | |
from PIL import Image | |
import numpy as np | |
device=torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model=torch.load('model/model_epoch_49.pth',map_location=device) | |
def evaluate(model,image): | |
model.eval() | |
with torch.no_grad(): | |
image = image.to(device) | |
#outputs= model(image.unsqueeze(0)) | |
outputs= model(image) | |
return outputs.squeeze(0).squeeze(0).cpu() | |
def predict(input_image): | |
#input_image=Image.open(inp_img).convert('RGB') | |
input_image=input_image.resize((512,512)) | |
input_image_torch=torch.tensor(np.array(input_image)).permute(2,0,1).unsqueeze(0).float()/255.0 | |
mask=evaluate(model,input_image_torch) | |
mask=mask.permute(1,2,0).numpy() | |
return mask | |
def calculate_input_illuminance(image): | |
""" | |
Calculate illuminance: I_input = C_r + C_g + C_b | |
""" | |
return np.sum(image, axis=2) | |
def generate_recovery_weight_matrix(illuminance_matrix, alpha=15): | |
""" | |
Generate recovery weights using power function | |
Formula: W_r = ((I_input - min) / (max - min))^α | |
""" | |
I_min = np.min(illuminance_matrix) | |
I_max = np.max(illuminance_matrix) | |
if I_max == I_min: | |
normalized = np.zeros_like(illuminance_matrix) | |
else: | |
normalized = (illuminance_matrix - I_min) / (I_max - I_min) | |
# Apply power function with α = 15 | |
W_r = np.power(normalized, alpha) | |
return W_r | |
def recover_light_sources(original_image, network_output, alpha=15): | |
""" | |
Final recovery: I_final = (1 - W_r) ⊙ N(C) + W_r ⊙ C | |
""" | |
# Calculate illuminance and recovery weights | |
I_input = calculate_input_illuminance(original_image) | |
W_r = generate_recovery_weight_matrix(I_input, alpha) | |
# Expand to match image dimensions | |
W_r_expanded = np.expand_dims(W_r, axis=2) | |
W_r_expanded = np.repeat(W_r_expanded, 3, axis=2) | |
# Convex combination for light source recovery | |
I_final = (1 - W_r_expanded) * network_output + W_r_expanded * original_image | |
return np.clip(I_final, 0, 255).astype(np.uint8) |