File size: 2,058 Bytes
a9a2d42
 
 
 
 
92bae4b
a9a2d42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from PIL import Image
import numpy as np

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint holds a whole pickled nn.Module (it is called and put in eval
# mode directly in evaluate()), not a state_dict. PyTorch >= 2.6 defaults
# torch.load to weights_only=True, which refuses such pickles, so opt out
# explicitly (the keyword is available since PyTorch 1.13).
# SECURITY: weights_only=False unpickles arbitrary objects and can execute code;
# only ever point this at a trusted checkpoint file.
model=torch.load('model/model_epoch_49.pth',map_location=device,weights_only=False)

def evaluate(model,image):
    """Run ``model`` on ``image`` without gradient tracking and return the CPU result.

    Args:
        model: a torch module; it is switched to eval mode before the forward pass.
        image: input tensor handed to the model unchanged apart from being moved
            to the module-level ``device`` (``predict`` passes a batched
            ``(1, C, H, W)`` tensor).

    Returns:
        The model output with ``squeeze(0)`` applied twice (each call removes
        dimension 0 only if it has size 1), moved to the CPU.
    """
    model.eval()
    with torch.no_grad():
        batch = image.to(device)
        prediction = model(batch)
        # Drop the batch dimension, then a leading singleton dimension if present.
        for _ in range(2):
            prediction = prediction.squeeze(0)
        return prediction.cpu()
    
def predict(input_image):
    """Run the loaded network on a PIL image and return its output as an HWC array.

    The image is converted to RGB first. Grayscale/palette images would
    otherwise produce a 2-D array that the channel permute below cannot handle,
    and RGBA images a 4-channel tensor. The image is then resized to 512x512,
    scaled to [0, 1] and passed through ``evaluate`` with the module-level
    ``model``.

    Args:
        input_image: a ``PIL.Image.Image`` (any mode convertible to RGB).

    Returns:
        numpy array of the model output permuted from (C, H, W) to (H, W, C).
        NOTE(review): the spatial size and value range depend on the model;
        presumably 512x512 and [0, 1] like the input — confirm before mixing
        with 0-255 pixel data.
    """
    rgb_image = input_image.convert('RGB').resize((512, 512))
    # (H, W, 3) uint8 -> (1, 3, H, W) float32 in [0, 1]
    input_image_torch = torch.tensor(np.array(rgb_image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    mask = evaluate(model, input_image_torch)
    mask = mask.permute(1, 2, 0).numpy()
    return mask

def calculate_input_illuminance(image):
    """
    Return the per-pixel illuminance of a channels-last image.

    Illuminance is the plain channel sum, I_input = C_r + C_g + C_b, taken
    over axis 2. For uint8 input NumPy's sum promotes to a wider integer type,
    so values above 255 do not wrap.
    """
    channel_axis = 2  # image laid out as (H, W, channels)
    illuminance = np.sum(image, axis=channel_axis)
    return illuminance

def generate_recovery_weight_matrix(illuminance_matrix, alpha=15):
    """
    Generate recovery weights using power function
    Formula: W_r = ((I_input - min) / (max - min))^α

    Args:
        illuminance_matrix: array-like of per-pixel illuminance values.
        alpha: exponent of the power function; larger values push every pixel
            except the brightest ones toward a weight of 0.

    Returns:
        Floating-point array with the same shape as ``illuminance_matrix``,
        values in [0, 1]. Integer input yields float64; floating input keeps
        its dtype. A uniform (max == min) or empty input yields all zeros.
    """
    illuminance = np.asarray(illuminance_matrix)
    # Work in floating point from the start so the uniform-input branch returns
    # the same dtype as the normal (divided) branch.
    if not np.issubdtype(illuminance.dtype, np.floating):
        illuminance = illuminance.astype(np.float64)

    if illuminance.size == 0:
        # np.min/np.max raise on zero-size arrays; there is nothing to weight.
        return np.zeros_like(illuminance)

    I_min = np.min(illuminance)
    I_max = np.max(illuminance)

    if I_max == I_min:
        # No pixel is brighter than any other: avoid 0/0, weight nothing.
        normalized = np.zeros_like(illuminance)
    else:
        normalized = (illuminance - I_min) / (I_max - I_min)

    # Apply the power function with exponent `alpha`.
    W_r = np.power(normalized, alpha)
    return W_r

def recover_light_sources(original_image, network_output, alpha=15):
    """
    Final recovery: I_final = (1 - W_r) ⊙ N(C) + W_r ⊙ C

    W_r is derived per pixel from the original image's illuminance. Where it is
    1 (the brightest pixels) the original value C is kept; where it is 0 the
    network output N(C) is used; in between the two are mixed convexly.

    Args:
        original_image: channels-last image C, array-like of shape (H, W, K).
        network_output: N(C); must broadcast against ``original_image``
            (normally the same (H, W, K) shape).
        alpha: exponent of the recovery-weight power function.

    Returns:
        uint8 array of the blended image, clipped to [0, 255] and truncated
        (not rounded) to integers.
    """
    original = np.asarray(original_image)

    # Calculate illuminance and recovery weights
    I_input = calculate_input_illuminance(original)
    W_r = generate_recovery_weight_matrix(I_input, alpha)

    # (H, W) -> (H, W, 1): the weights broadcast over every channel, so this
    # works for any channel count without materialising a repeated copy.
    W_r_expanded = W_r[..., np.newaxis]

    # Convex combination for light source recovery.
    # NOTE(review): the result is clipped to 0-255, so network_output is
    # presumably expected on the same 0-255 scale as original_image; predict()
    # feeds the model inputs scaled to [0, 1] — confirm the caller rescales.
    I_final = (1 - W_r_expanded) * network_output + W_r_expanded * original
    return np.clip(I_final, 0, 255).astype(np.uint8)