import gradio as gr
from PIL import Image
import os
import numpy as np
import tensorflow as tf
import requests
from skimage.color import lab2rgb
from models.autoencoder_gray2color import SpatialAttention
from models.unet_gray2color import SelfAttentionLayer

# Force a global float32 policy so inference dtypes are consistent
tf.keras.mixed_precision.set_global_policy('float32')

# Model-specific input shapes
MODEL_INPUT_SHAPES = {
    "autoencoder": (512, 512),
    "unet": (1024, 1024),
    "transformer": (1024, 1024)
}
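
# These shapes presumably match each model's training resolution; inputs are
# resized to them before inference and restored to the original size afterwards.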

# Checkpoint paths; the parent directory name (e.g. "unet") doubles as the model key
load_model_paths = [
    "./ckpts/autoencoder/autoencoder_colorization_model.h5",
    "./ckpts/unet/unet_colorization_model.keras",
    "./ckpts/transformer/transformer_colorization_model.keras"
]

# Custom layers needed to deserialize each saved architecture
custom_objects = {
    "autoencoder": {'SpatialAttention': SpatialAttention},
    "unet": {'SelfAttentionLayer': SelfAttentionLayer},
    "transformer": None
}

# Load models at startup
models = {}
print("Loading models...")
for path in load_model_paths:
    model_name = os.path.basename(os.path.dirname(path))
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        url_map = {
            "autoencoder": "https://huggingface.co/danhtran2mind/autoencoder-grayscale2color-landscape/resolve/main/ckpts/best_model.h5",
            "unet": "https://huggingface.co/<your-username>/unet-grayscale2color-landscape/resolve/main/ckpts/unet_colorization_model.keras",  # Replace with valid URL
            "transformer": "https://huggingface.co/<your-username>/transformer-grayscale2color-landscape/resolve/main/ckpts/transformer_colorization_model.keras"  # Replace with valid URL
        }
        if model_name in url_map:
            print(f"Downloading {model_name} model from {url_map[model_name]}...")
            with requests.get(url_map[model_name], stream=True) as r:
                r.raise_for_status()
                with open(path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            print(f"Download complete for {model_name}.")
    
    print(f"Loading {model_name} model from {path}...")
    models[model_name] = tf.keras.models.load_model(
        path,
        custom_objects=custom_objects[model_name],
        compile=False
    )
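    # Compiling is optional for inference: predict() runs on an uncompiled
    # model. The optimizer/loss only matter if the models are fine-tuned later.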
    models[model_name].compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=7e-5),
        loss=tf.keras.losses.MeanSquaredError()
    )
    print(f"{model_name} model loaded.")

print("All models loaded.")

def process_image(input_img, model_name):
    """Colorize a grayscale PIL image with the selected model and return an RGB image at the original resolution."""
    # Store the original dimensions so the output can be resized back
    original_width, original_height = input_img.size
    # Get model-specific input shape
    width, height = MODEL_INPUT_SHAPES[model_name.lower()]
    # Convert PIL Image to grayscale and resize to model input size
    img = input_img.convert("L")
    img = img.resize((width, height))
    img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
    img_array = img_array[None, ..., 0:1]  # Shape: (1, height, width, 1)
    
    # Select model
    selected_model = models[model_name.lower()]
    # Run inference
    output_array = selected_model.predict(img_array)  # Shape: (1, height, width, 2)
    
    # Extract L* and a*b*
    L_channel = img_array[0, :, :, 0] * 100.0  # Denormalize L* to [0, 100]
    ab_channels = output_array[0] * 128.0  # Scale predicted a*b* back to CIELAB range
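    # CIELAB convention: L* (lightness) spans [0, 100] and a*/b* roughly
    # [-128, 127]; the models are assumed to output a*b* scaled to [-1, 1],
    # hence the factor of 128 above.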
    
    # Combine L*, a*, b*
    lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1)
    
    # Convert to RGB
    rgb_array = lab2rgb(lab_image)
    rgb_array = np.clip(rgb_array, 0, 1) * 255.0
    rgb_image = Image.fromarray(rgb_array.astype(np.uint8), mode="RGB")
    
    # Resize output to original resolution
    rgb_image = rgb_image.resize((original_width, original_height), Image.Resampling.LANCZOS)
    return rgb_image
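
# Example (hypothetical paths): with the checkpoints in place, the pipeline
# can also be driven without the UI, e.g.:
#   gray = Image.open("examples/example_input_1.jpg")
#   colorized = process_image(gray, "Autoencoder")
#   colorized.save("colorized_output.jpg")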

custom_css = """
body {background: linear-gradient(135deg, #f0f4f8 0%, #d9e2ec 100%) !important;}
.gradio-container {background: transparent !important;}
h1, .gr-title {color: #007bff !important; font-family: 'Segoe UI', sans-serif;}
.gr-description {color: #333333 !important; font-size: 1.1em;}
.gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.1);}
.gr-button {background: linear-gradient(90deg, #007bff 0%, #00c4cc 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
"""

demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L"),
        gr.Dropdown(
            choices=["Autoencoder", "Unet", "Transformer"],
            label="Select Model",
            value="Autoencoder"
        )
    ],
    outputs=gr.Image(type="pil", label="Colorized Output"),
    title="🌄 Gray2Color Landscape Colorization",
    description=(
        "<div style='font-size:1.15em;line-height:1.6em;'>"
        "Transform your <b>grayscale landscape</b> photos into vivid color using advanced deep learning models.<br>"
        "Upload a grayscale image, select a model (Autoencoder, U-Net, or Transformer), and see the results!"
        "</div>"
    ),
    theme="soft",
    css=custom_css,
    allow_flagging="never",
    examples=[
        ["examples/example_input_1.jpg", "Autoencoder"],
        ["examples/example_input_2.jpg", "Unet"]
    ]
)

if __name__ == "__main__":
    demo.launch()
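    # demo.launch(share=True) would additionally create a temporary public link.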