Pushpesh
Fixed utils bug
92bae4b
import torch
from PIL import Image
import numpy as np
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint stores a whole pickled model object (evaluate() calls
# .eval() on it and invokes it directly), not just a state_dict. Since
# PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle arbitrary objects, so full unpickling is requested explicitly
# (the keyword exists from PyTorch 1.13 onwards).
# SECURITY: weights_only=False executes arbitrary pickle code -- only ever
# load checkpoints from a trusted source.
model = torch.load('model/model_epoch_49.pth', map_location=device, weights_only=False)
def evaluate(model, image):
    """Run one inference pass of ``model`` on ``image`` and return it on the CPU.

    The model is switched to eval mode and called under ``torch.no_grad()``.
    ``image`` is first moved to the module-level ``device``.

    Returns:
        The model output with dim 0 squeezed twice, as a CPU tensor. Each
        squeeze removes the leading dimension only if its size is 1.
    """
    model.eval()
    with torch.no_grad():
        output = model(image.to(device))
        for _ in range(2):
            # squeeze(0) is a no-op unless the leading dim has size 1.
            output = output.squeeze(0)
        return output.cpu()
def predict(input_image, size=(512, 512)):
    """Run the module-level ``model`` on a PIL image and return its output.

    Args:
        input_image: a PIL image. It is converted to RGB first, so grayscale,
            palette or RGBA inputs also yield exactly three channels.
        size: ``(width, height)`` the image is resized to before inference.
            Defaults to (512, 512), the previously hard-coded size.

    Returns:
        The model output as a numpy array in (H, W, C) layout.
        NOTE(review): evaluate() also squeezes a size-1 channel dimension,
        so a single-channel model output would be 2-D and the permute below
        would fail. Confirm that the model emits C > 1 channels.
    """
    # Force three channels. A grayscale ('L') image becomes a 2-D array that
    # permute(2, 0, 1) cannot handle. An RGBA image would feed 4 channels to
    # a model that presumably expects 3.
    rgb_image = input_image.convert('RGB').resize(size)
    # (H, W, 3) uint8 -> (1, 3, H, W) float32 scaled to [0, 1].
    input_image_torch = (
        torch.tensor(np.array(rgb_image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    )
    mask = evaluate(model, input_image_torch)
    # (C, H, W) tensor -> (H, W, C) numpy array.
    return mask.permute(1, 2, 0).numpy()
def calculate_input_illuminance(image):
    """Return the per-pixel illuminance map of an image.

    Formula: I_input = C_r + C_g + C_b

    The channel values are summed over axis 2, the channel axis of an
    (H, W, C) array. For small integer dtypes NumPy's sum accumulates in a
    wider integer, so uint8 pixels do not wrap (255 * 3 gives 765).
    """
    channel_axis = 2
    illuminance = np.sum(image, axis=channel_axis)
    return illuminance
def generate_recovery_weight_matrix(illuminance_matrix, alpha=15):
    """Generate recovery weights using a power function.

    Formula: W_r = ((I_input - min) / (max - min)) ** alpha

    The illuminance is min-max normalised to [0, 1] and raised to ``alpha``.
    A large alpha suppresses the weights everywhere except close to the
    brightest values.

    Args:
        illuminance_matrix: array-like of illuminance values (any shape).
        alpha: exponent of the power function (default 15).

    Returns:
        A floating-point array with the same shape as ``illuminance_matrix``.
        When the input is empty or constant (max == min), the normalisation
        is undefined and the normalised values are taken as 0, so the result
        is 0 ** alpha everywhere.
    """
    illuminance = np.asarray(illuminance_matrix)
    # np.min / np.max raise on an empty array, so only query them when there
    # is data. Otherwise the span is treated as degenerate.
    i_min = np.min(illuminance) if illuminance.size else 0
    i_max = np.max(illuminance) if illuminance.size else 0
    if i_max == i_min:
        # Use a float array so the return dtype matches the normal branch.
        # zeros_like on an integer illuminance map would give integer weights.
        normalized = np.zeros(illuminance.shape, dtype=np.float64)
    else:
        normalized = (illuminance - i_min) / (i_max - i_min)
    # Apply the power function with exponent alpha.
    return np.power(normalized, alpha)
def recover_light_sources(original_image, network_output, alpha=15):
    """Blend the network output with the original image near light sources.

    Final recovery: I_final = (1 - W_r) ⊙ N(C) + W_r ⊙ C

    W_r is the per-pixel recovery weight derived from the illuminance of
    ``original_image`` via generate_recovery_weight_matrix. Pixels near the
    top of the illuminance range keep mostly the original value C. All other
    pixels take mostly the network output N(C).

    Args:
        original_image: image C as an (H, W, C) array, presumably RGB. Its
            channels are summed over axis 2 to get the illuminance.
        network_output: N(C). It must broadcast against an (H, W, 3) array.
            NOTE(review): the result is clipped to [0, 255], so this is
            presumably expected on the same 0-255 scale as original_image.
            Confirm against the caller.
        alpha: exponent forwarded to generate_recovery_weight_matrix.

    Returns:
        The blended image as a uint8 array, rounded to the nearest integer
        and clipped to [0, 255].
    """
    # Calculate illuminance and recovery weights.
    illuminance = calculate_input_illuminance(original_image)
    weights = generate_recovery_weight_matrix(illuminance, alpha)
    # (H, W) -> (H, W, 3) so the same weight applies to all three channels.
    weights = np.repeat(np.expand_dims(weights, axis=2), 3, axis=2)
    # Convex combination for light source recovery.
    blended = (1 - weights) * network_output + weights * original_image
    # astype(np.uint8) truncates toward zero. That biases every pixel down and
    # turns float noise such as 99.99999999 into 99, so round first.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)