File size: 7,654 Bytes
1d1ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
83135ce
 
 
 
1d1ede2
 
 
 
 
83135ce
 
1d1ede2
 
 
 
 
 
 
83135ce
 
 
 
1d1ede2
83135ce
 
 
 
1d1ede2
83135ce
1d1ede2
 
 
 
 
 
 
 
83135ce
1d1ede2
 
 
 
 
 
 
 
83135ce
 
 
 
 
 
 
 
 
1d1ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83135ce
1d1ede2
 
 
 
83135ce
1d1ede2
 
83135ce
1d1ede2
 
 
 
 
 
 
 
83135ce
 
1d1ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83135ce
1d1ede2
 
 
 
 
 
 
 
 
 
83135ce
 
1d1ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83135ce
1d1ede2
 
 
 
 
 
 
 
 
 
83135ce
 
1d1ede2
 
 
 
 
 
 
 
83135ce
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import functools
import importlib
import os
import subprocess
import sys
import time

import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
from diffusers.models import AutoencoderKL
from PIL import Image
import cv2
import numpy as np
import gradio as gr
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download

# Make the RealESRGAN package importable, installing it with its dependencies
# only when it is missing, so a warm restart does not re-run pip every time.
try:
    from RealESRGAN import RealESRGAN
except ImportError:
    try:
        # Run pip through the current interpreter so the package lands in the
        # environment this process imports from; list arguments avoid the shell.
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install",
                "git+https://github.com/inference-sh/Real-ESRGAN.git",
                "basicsr",
                "opencv-python-headless",
                "--no-cache-dir",
            ],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Failed to install Real-ESRGAN: {e}")
    # Packages installed while the process is running may be hidden behind
    # stale import-finder caches; drop them before retrying the import.
    importlib.invalidate_caches()
    # If the installation failed this raises ImportError, as the app cannot
    # run without the upscaler.
    from RealESRGAN import RealESRGAN

# Inference is pinned to the CPU; the flags below are read when the
# diffusion pipeline is assembled.
device = torch.device("cpu")
ENABLE_CPU_OFFLOAD = True
USE_TORCH_COMPILE = False

# Ensure every local checkpoint folder exists before anything is downloaded.
for _model_dir in (
    "models/Stable-diffusion",
    "models/ControlNet",
    "models/VAE",
    "models/upscalers",
):
    os.makedirs(_model_dir, exist_ok=True)

# Download models on-demand
def download_model(repo_id, filename, local_dir):
    """Fetch one file from the Hugging Face Hub into *local_dir*, logging progress.

    Args:
        repo_id: Hub repository that hosts the file.
        filename: Name of the file inside the repository.
        local_dir: Directory handed to ``hf_hub_download`` as the target.

    Raises:
        Exception: Whatever ``hf_hub_download`` raised, re-raised unchanged
            after a failure message has been printed.
    """
    print(f"Downloading {filename} from {repo_id}...")
    try:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
    except Exception as err:
        print(f"Failed to download {filename}: {err}")
        raise
    print(f"Successfully downloaded {filename}")

# Timer decorator
def timer_func(func):
    """Decorate *func* so every successful call prints how long it took.

    The wrapper forwards all positional and keyword arguments and returns the
    wrapped function's result unchanged. If *func* raises, the exception
    propagates and no timing line is printed.

    Args:
        func: Callable to time.

    Returns:
        A wrapper that keeps *func*'s name, docstring and signature metadata.
    """
    @functools.wraps(func)  # keep __name__/__doc__/signature of the original
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measurement cannot go negative
        # when the system clock is adjusted mid-call.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start_time:.2f} seconds")
        return result
    return wrapper

# Lazy pipeline
class LazyLoadPipeline:
    """Build the Stable Diffusion + tile ControlNet img2img pipeline on first use.

    Constructing an instance is cheap. Checkpoints are downloaded (when they
    are missing locally) and loaded only when ``load()`` or the instance
    itself is first called; the finished pipeline is cached on ``self.pipe``.
    """

    # (repo_id, filename, local_dir) of every checkpoint ``load()`` needs.
    _CHECKPOINTS = (
        ("dantea1118/juggernaut_reborn", "juggernaut_reborn.safetensors", "models/Stable-diffusion"),
        ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet"),
        ("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors", "models/VAE"),
    )

    def __init__(self):
        self.pipe = None  # populated by load()

    @timer_func
    def load(self):
        """Return the pipeline, building and caching it on the first call.

        Weights are loaded in float16 only when the pipeline will execute on
        an accelerator. On a CPU-only setup float32 is used instead, because
        diffusers warns that float16 pipelines cannot run on the ``cpu``
        device. Model CPU offloading likewise needs an accelerator to move
        sub-models onto, so it is skipped when CUDA is unavailable.

        Raises:
            Exception: Propagated from ``download_model`` or from the
                diffusers loaders when a checkpoint cannot be fetched or read.
        """
        if self.pipe is not None:
            return self.pipe

        print("Setting up pipeline...")
        # Download models if not present
        for repo_id, filename, local_dir in self._CHECKPOINTS:
            if not os.path.exists(os.path.join(local_dir, filename)):
                download_model(repo_id, filename, local_dir)

        use_offload = ENABLE_CPU_OFFLOAD and torch.cuda.is_available()
        on_accelerator = use_offload or device.type != "cpu"
        dtype = torch.float16 if on_accelerator else torch.float32

        controlnet = ControlNetModel.from_single_file(
            "models/ControlNet/control_v11f1e_sd15_tile.pth", torch_dtype=dtype
        )
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
            "models/Stable-diffusion/juggernaut_reborn.safetensors",
            controlnet=controlnet,
            torch_dtype=dtype,
            use_safetensors=True,
        )
        # Swap in the separately downloaded VAE checkpoint.
        pipe.vae = AutoencoderKL.from_single_file(
            "models/VAE/vae-ft-mse-840000-ema-pruned.safetensors",
            torch_dtype=dtype,
        )
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.to(device)
        if use_offload:
            print("Enabling CPU offloading...")
            pipe.enable_model_cpu_offload()
        self.pipe = pipe
        return self.pipe

    def __call__(self, *args, **kwargs):
        """Run the pipeline with the given arguments, loading it first if needed."""
        if self.pipe is None:
            self.load()
        return self.pipe(*args, **kwargs)

# Lazy Real-ESRGAN
class LazyRealESRGAN:
    """Real-ESRGAN upscaler whose network and weights are created on first use.

    Args:
        device: Device handed to the ``RealESRGAN`` constructor.
        scale: Upscaling factor; it also selects the weights file
            ``RealESRGAN_x{scale}.pth`` under ``models/upscalers``.
    """

    def __init__(self, device, scale):
        self.model = None  # created by load_model()
        self.scale = scale
        self.device = device

    def load_model(self):
        """Create the model and load its weights, downloading them if absent.

        Does nothing when a model has already been created.
        """
        if self.model is not None:
            return
        weights_name = f"RealESRGAN_x{self.scale}.pth"
        weights_path = f"models/upscalers/{weights_name}"
        if not os.path.exists(weights_path):
            download_model("ai-forever/Real-ESRGAN", weights_name, "models/upscalers")
        self.model = RealESRGAN(self.device, scale=self.scale)
        # The file is already on disk, so the library must not fetch it again.
        self.model.load_weights(weights_path, download=False)

    def predict(self, img):
        """Upscale *img*, loading the model first when necessary."""
        self.load_model()
        upscaler = self.model
        return upscaler.predict(img)

# Shared scale-2 Real-ESRGAN wrapper; its model is only created on the first
# predict() call.
lazy_realesrgan_x2 = LazyRealESRGAN(device, scale=2)

@timer_func
def resize_and_upscale(input_image, resolution):
    """Resize an image to the working resolution, then run Real-ESRGAN on it.

    The image is converted to RGB and scaled so that its shorter side maps to
    *resolution*; both sides are then rounded to the nearest multiple of 64
    before Lanczos resampling. The resized image is finally passed through the
    shared ``lazy_realesrgan_x2`` upscaler.

    Args:
        input_image: PIL image to process.
        resolution: Target length of the shorter side before upscaling.

    Returns:
        Whatever ``lazy_realesrgan_x2.predict`` returns for the resized image.
    """
    rgb = input_image.convert("RGB")
    width, height = rgb.size
    scale = float(resolution) / min(height, width)
    target_w, target_h = (
        int(round(side * scale / 64.0)) * 64 for side in (width, height)
    )
    resized = rgb.resize((target_w, target_h), resample=Image.LANCZOS)
    return lazy_realesrgan_x2.predict(resized)

@timer_func
def create_hdr_effect(original_image, hdr):
    """Apply an HDR-like look by fusing three differently scaled exposures.

    Args:
        original_image: RGB image (PIL image or anything ``np.array`` accepts).
        hdr: Effect strength; ``0`` returns *original_image* untouched.

    Returns:
        The input itself when ``hdr == 0``, otherwise a new PIL image built
        from the Mertens fusion of a darkened, an unchanged and a brightened
        copy of the input.
    """
    if hdr == 0:
        return original_image

    bgr = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
    # Darker, neutral and brighter gains; the spread grows with `hdr`.
    gains = (1.0 - 0.7 * hdr, 1.0, 1.0 + 0.2 * hdr)
    exposures = [cv2.convertScaleAbs(bgr, alpha=gain) for gain in gains]
    fused = cv2.createMergeMertens().process(exposures)
    # The fused result is scaled by 255 and clipped into the 8-bit range.
    fused_u8 = np.clip(fused * 255, 0, 255).astype('uint8')
    rgb = cv2.cvtColor(fused_u8, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)

# Shared diffusion pipeline wrapper; the underlying pipeline is only built the
# first time it is loaded or called.
lazy_pipe = LazyLoadPipeline()

@timer_func
def gradio_process_image(input_image, resolution, num_inference_steps, strength, hdr, guidance_scale):
    """Gradio callback: upscale and enhance the uploaded image.

    The input is resized and upscaled, optionally given the HDR effect, and
    then used as both the init image and the control image of the img2img
    ControlNet pipeline with fixed prompts and a generator seeded with 0.

    Args:
        input_image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        resolution: Working resolution forwarded to ``resize_and_upscale``.
        num_inference_steps: Step count forwarded to the pipeline.
        strength: img2img strength forwarded to the pipeline.
        hdr: HDR effect strength forwarded to ``create_hdr_effect``.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        ``[before, after]`` as NumPy arrays, for the before/after slider.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    print("Starting image processing...")
    upscaled = resize_and_upscale(input_image, resolution)
    condition_image = create_hdr_effect(upscaled, hdr)

    out_width, out_height = condition_image.size
    # Fixed seed so repeated runs use the same generator state.
    seeded_generator = torch.Generator(device=device).manual_seed(0)

    print("Running inference...")
    output = lazy_pipe(
        prompt="masterpiece, best quality, highres",
        negative_prompt="low quality, normal quality, blurry, lowres",
        image=condition_image,
        control_image=condition_image,
        width=out_width,
        height=out_height,
        strength=strength,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=seeded_generator,
    )
    result = output.images[0]
    print("Image processing completed successfully")

    before, after = np.array(input_image), np.array(result)
    return [before, after]

# Gradio interface
# HTML banner shown at the top of the page.
title = """<h1 align="center">Image Upscaler with Tile ControlNet</h1>
<p align="center">CPU-optimized for Hugging Face Spaces</p>"""

with gr.Blocks() as demo:
    gr.HTML(title)
    # One row with two columns: upload + trigger button in the first,
    # the before/after comparison slider in the second.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            output_slider = ImageSlider(label="Before / After", type="numpy")
    # Tuning controls, collapsed by default (open=False).
    with gr.Accordion("Advanced Options", open=False):
        resolution = gr.Slider(minimum=256, maximum=768, value=512, step=64, label="Resolution")
        num_inference_steps = gr.Slider(minimum=1, maximum=15, value=10, step=1, label="Inference Steps")
        strength = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.01, label="Strength")
        hdr = gr.Slider(minimum=0, maximum=1, value=0, step=0.1, label="HDR Effect")
        guidance_scale = gr.Slider(minimum=0, maximum=10, value=3, step=0.5, label="Guidance Scale")

    # The inputs list is in the same order as gradio_process_image's parameters.
    run_button.click(fn=gradio_process_image, 
                     inputs=[input_image, resolution, num_inference_steps, strength, hdr, guidance_scale],
                     outputs=output_slider)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()