upscaler / app.py
aronsaras's picture
Update app.py
83135ce verified
import functools
import importlib.util
import os
import subprocess
import sys
import time

import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
from diffusers.models import AutoencoderKL
from PIL import Image
import cv2
import numpy as np
import gradio as gr
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
# Install Real-ESRGAN with dependencies, unless it is already importable
# (avoids a slow network reinstall on every start).
if importlib.util.find_spec("RealESRGAN") is None:
    try:
        # Run pip through the current interpreter so the packages land in the
        # environment this app is running in, and pass an argument list so no
        # shell is involved.
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install",
                "git+https://github.com/inference-sh/Real-ESRGAN.git",
                "basicsr",
                "opencv-python-headless",
                "--no-cache-dir",
            ],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Failed to install Real-ESRGAN: {e}")
    # Let the import system notice packages installed after interpreter start.
    importlib.invalidate_caches()
from RealESRGAN import RealESRGAN
# Force CPU usage: the pipeline, the Real-ESRGAN upscaler and the RNG below
# are all placed on this device.
device = torch.device("cpu")
ENABLE_CPU_OFFLOAD = True
# NOTE(review): not read anywhere in the visible code.
USE_TORCH_COMPILE = False

# Create model directories (base checkpoint, ControlNet, VAE, upscaler weights).
for _model_subdir in ("Stable-diffusion", "ControlNet", "VAE", "upscalers"):
    os.makedirs(f"models/{_model_subdir}", exist_ok=True)
# Download models on-demand
def download_model(repo_id, filename, local_dir):
    """Fetch a single file from a Hugging Face Hub repository into *local_dir*.

    Progress is printed to stdout. Any error from the download is printed
    and then re-raised unchanged.

    Args:
        repo_id: Hub repository id, e.g. ``"lllyasviel/ControlNet-v1-1"``.
        filename: Path of the file inside that repository.
        local_dir: Local directory passed to ``hf_hub_download``.
    """
    print(f"Downloading {filename} from {repo_id}...")
    try:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
    except Exception as err:
        print(f"Failed to download {filename}: {err}")
        raise
    print(f"Successfully downloaded {filename}")
# Timer decorator
def timer_func(func):
    """Decorate *func* so that every call prints how long it took.

    The wrapper preserves *func*'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) and returns whatever *func* returns. Exceptions raised by
    *func* propagate unchanged; the duration is printed in that case too.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, unlike time.time, so durations can't go
        # negative or jump when the wall clock is adjusted.
        start_time = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            print(f"{func.__name__} took {time.perf_counter() - start_time:.2f} seconds")
    return wrapper
# Lazy pipeline
class LazyLoadPipeline:
    """Tile-ControlNet img2img pipeline that is assembled on first use.

    Calling an instance forwards all arguments to the underlying diffusers
    pipeline, building it first (and downloading any missing checkpoints into
    ``models/``) if that has not happened yet.
    """

    def __init__(self):
        self.pipe = None  # built by load(); None until first use

    @timer_func
    def load(self):
        """Build the pipeline once and return it.

        Downloads the base model, the tile ControlNet and the VAE when they
        are not on disk. Then it swaps in a DDIM scheduler and places the
        pipeline on the module-level ``device``.

        Returns:
            The cached ``StableDiffusionControlNetImg2ImgPipeline``.
        """
        if self.pipe is None:
            print("Setting up pipeline...")
            # PyTorch has little float16 kernel support on CPU, and diffusers
            # warns that fp16 pipelines moved to CPU will fail at inference.
            # Use half precision only when computing on CUDA.
            on_cuda = device.type == "cuda"
            dtype = torch.float16 if on_cuda else torch.float32
            # Download models if not present
            for _label, (repo_id, filename, local_dir) in [
                ("MODEL", ("dantea1118/juggernaut_reborn", "juggernaut_reborn.safetensors", "models/Stable-diffusion")),
                ("CONTROLNET", ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")),
                ("VAE", ("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors", "models/VAE")),
            ]:
                if not os.path.exists(os.path.join(local_dir, filename)):
                    download_model(repo_id, filename, local_dir)
            controlnet = ControlNetModel.from_single_file(
                "models/ControlNet/control_v11f1e_sd15_tile.pth", torch_dtype=dtype
            )
            model_path = "models/Stable-diffusion/juggernaut_reborn.safetensors"
            pipe = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
                model_path,
                controlnet=controlnet,
                torch_dtype=dtype,
                use_safetensors=True,
            )
            # Replace the checkpoint's bundled VAE with the ft-MSE VAE.
            vae = AutoencoderKL.from_single_file(
                "models/VAE/vae-ft-mse-840000-ema-pruned.safetensors",
                torch_dtype=dtype,
            )
            pipe.vae = vae
            pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
            # Model CPU offload moves sub-models onto an accelerator for each
            # forward pass. With CPU as the compute device there is nothing to
            # offload (and it expects an accelerator), so only use it on CUDA.
            # Otherwise place the whole pipeline on `device` directly.
            if ENABLE_CPU_OFFLOAD and on_cuda:
                print("Enabling CPU offloading...")
                pipe.enable_model_cpu_offload()
            else:
                pipe.to(device)
            self.pipe = pipe
        return self.pipe

    def __call__(self, *args, **kwargs):
        """Run the pipeline with the given arguments, loading it first if needed."""
        if self.pipe is None:
            self.load()
        return self.pipe(*args, **kwargs)
# Lazy Real-ESRGAN
class LazyRealESRGAN:
    """Real-ESRGAN upscaler whose weights are fetched and loaded on first use.

    Args:
        device: Device handed to the ``RealESRGAN`` constructor.
        scale: Upscaling factor; also selects the weights file
            ``models/upscalers/RealESRGAN_x{scale}.pth``.
    """

    def __init__(self, device, scale):
        self.device = device
        self.scale = scale
        self.model = None  # populated by load_model()

    def load_model(self):
        """Download the weights if missing, then build the model (only once)."""
        if self.model is not None:
            return
        weights_file = f"RealESRGAN_x{self.scale}.pth"
        weights_path = f"models/upscalers/{weights_file}"
        if not os.path.exists(weights_path):
            download_model("ai-forever/Real-ESRGAN", weights_file, "models/upscalers")
        self.model = RealESRGAN(self.device, scale=self.scale)
        self.model.load_weights(weights_path, download=False)

    def predict(self, img):
        """Upscale *img* with the model, loading it first if necessary."""
        self.load_model()
        return self.model.predict(img)
# Shared x2 upscaler; its weights are only loaded on the first predict() call.
lazy_realesrgan_x2 = LazyRealESRGAN(device=device, scale=2)


@timer_func
def resize_and_upscale(input_image, resolution):
    """Resize the image so its short side is about *resolution*, then upscale it x2.

    Both sides are scaled by ``resolution / short_side`` and each is snapped
    to the nearest multiple of 64. The resized RGB image is then passed
    through the x2 Real-ESRGAN model.

    Args:
        input_image: PIL image (any mode; converted to RGB first).
        resolution: Target size for the short side before upscaling.

    Returns:
        Whatever ``lazy_realesrgan_x2.predict`` returns for the resized image.
    """
    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    ratio = float(resolution) / min(height, width)
    target_height = int(round(height * ratio / 64.0)) * 64
    target_width = int(round(width * ratio / 64.0)) * 64
    resized = rgb_image.resize((target_width, target_height), resample=Image.LANCZOS)
    return lazy_realesrgan_x2.predict(resized)
@timer_func
def create_hdr_effect(original_image, hdr):
    """Fuse synthetic under-, normal- and over-exposed copies with Mertens fusion.

    Args:
        original_image: PIL image; treated as RGB (converted with
            ``COLOR_RGB2BGR`` for OpenCV).
        hdr: Effect strength. The gains used are ``1 - 0.7*hdr``, ``1`` and
            ``1 + 0.2*hdr``. A value of 0 disables the effect.

    Returns:
        A new PIL image, or ``original_image`` itself when ``hdr == 0``.
    """
    if hdr != 0:
        bgr = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
        brackets = [
            cv2.convertScaleAbs(bgr, alpha=gain)
            for gain in (1.0 - 0.7 * hdr, 1.0, 1.0 + 0.2 * hdr)
        ]
        fused = cv2.createMergeMertens().process(brackets)
        # Mertens output is float (roughly 0..1); scale and clamp to 8-bit.
        fused_8bit = np.clip(fused * 255, 0, 255).astype('uint8')
        return Image.fromarray(cv2.cvtColor(fused_8bit, cv2.COLOR_BGR2RGB))
    return original_image
lazy_pipe = LazyLoadPipeline()


@timer_func
def gradio_process_image(input_image, resolution, num_inference_steps, strength, hdr, guidance_scale):
    """Gradio callback: upscale *input_image* and refine it with the tile-ControlNet pipeline.

    Args:
        input_image: PIL image from the upload widget, or None if nothing was uploaded.
        resolution: Target short side forwarded to ``resize_and_upscale``.
        num_inference_steps: Denoising steps requested from the pipeline.
        strength: img2img strength (the UI slider spans 0..1).
        hdr: Strength forwarded to ``create_hdr_effect``; 0 disables it.
        guidance_scale: Classifier-free guidance scale.

    Returns:
        ``[input, result]`` as numpy arrays for the before/after slider.

    Raises:
        gr.Error: if no image was uploaded, or if strength and step count
            leave the pipeline with no denoising step to run.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    # Sliders may deliver floats; the pipeline's timestep slicing needs an int.
    num_inference_steps = int(num_inference_steps)
    # diffusers' img2img pipelines run about int(steps * strength) denoising
    # steps and fail with an opaque error when that is 0 (e.g. strength=0).
    # Reject that up front, before the expensive upscale.
    if int(num_inference_steps * strength) < 1:
        raise gr.Error(
            "Strength x Inference Steps is too low: increase Strength or Inference Steps."
        )
    print("Starting image processing...")
    condition_image = resize_and_upscale(input_image, resolution)
    condition_image = create_hdr_effect(condition_image, hdr)
    prompt = "masterpiece, best quality, highres"
    negative_prompt = "low quality, normal quality, blurry, lowres"
    # The upscaled image is both the img2img init image and the tile-ControlNet condition.
    options = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": condition_image,
        "control_image": condition_image,
        "width": condition_image.size[0],
        "height": condition_image.size[1],
        "strength": strength,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        # Fixed seed keeps results reproducible for identical inputs.
        "generator": torch.Generator(device=device).manual_seed(0),
    }
    print("Running inference...")
    result = lazy_pipe(**options).images[0]
    print("Image processing completed successfully")
    return [np.array(input_image), np.array(result)]
# Gradio interface
# HTML header rendered at the top of the page.
title = """<h1 align="center">Image Upscaler with Tile ControlNet</h1>
<p align="center">CPU-optimized for Hugging Face Spaces</p>"""
# Layout has two columns: upload plus run button on the left, the before/after
# comparison on the right. Tuning sliders sit in a collapsed accordion.
with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            # type="pil" -> the callback receives a PIL image (or None when empty).
            input_image = gr.Image(type="pil", label="Input Image")
            run_button = gr.Button("Enhance Image")
        with gr.Column():
            # Shows the [input, result] numpy pair returned by gradio_process_image.
            output_slider = ImageSlider(label="Before / After", type="numpy")
    with gr.Accordion("Advanced Options", open=False):
        # Short side the input is resized to before the x2 Real-ESRGAN pass.
        resolution = gr.Slider(minimum=256, maximum=768, value=512, step=64, label="Resolution")
        num_inference_steps = gr.Slider(minimum=1, maximum=15, value=10, step=1, label="Inference Steps")
        # NOTE(review): a minimum of 0 (or low step counts) may leave img2img with
        # zero denoising steps; confirm the callback rejects that combination.
        strength = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.01, label="Strength")
        # 0 makes create_hdr_effect return its input unchanged.
        hdr = gr.Slider(minimum=0, maximum=1, value=0, step=0.1, label="HDR Effect")
        guidance_scale = gr.Slider(minimum=0, maximum=10, value=3, step=0.5, label="Guidance Scale")
    # The inputs list is passed positionally, so its order must match
    # gradio_process_image's parameter order.
    run_button.click(fn=gradio_process_image,
                     inputs=[input_image, resolution, num_inference_steps, strength, hdr, guidance_scale],
                     outputs=output_slider)
if __name__ == "__main__":
    demo.launch()