# Chain-of-Zoom / app.py
import os
import shutil
import subprocess
import uuid
from pathlib import Path

from PIL import Image, ImageDraw
import gradio as gr
import spaces

# Uploaded inputs are staged under samples/<session_id>/, and inference
# outputs are written under inference_results/coz_vlmprompt/<session_id>/.
INPUT_DIR = "samples"
OUTPUT_DIR = "inference_results/coz_vlmprompt"

def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Scale the shorter side to `size`, then center-crop to a size x size square."""
    w, h = img.size
    scale = size / min(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
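
# A worked example (hypothetical input size): a 1024x768 source gives
# scale = 512/768, an intermediate resize to 682x512, and a crop box of
# (85, 0, 597, 512), i.e. the final 512x512 square.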

def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
    try:
        orig = Image.open(image_path).convert("RGB")
    except Exception as e:
        # Show the error on a gray placeholder instead of raising.
        fallback = Image.new("RGB", (512, 512), (200, 200, 200))
        draw = ImageDraw.Draw(fallback)
        draw.text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    base = resize_and_center_crop(orig, 512)

    # At 1x all four boxes coincide; otherwise each successive box halves,
    # starting from 512 / scale.
    scale_int = int(scale_option.replace("x", ""))
    if scale_int == 1:
        sizes = [512] * 4
    else:
        sizes = [512 // (scale_int * (2 ** i)) for i in range(4)]

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]
    width = 3
    for idx, s in enumerate(sizes):
        x0 = (512 - s) // 2
        y0 = (512 - s) // 2
        x1 = x0 + s
        y1 = y0 + s
        draw.rectangle([(x0, y0), (x1, y1)], outline=colors[idx], width=width)
    return base
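
# For reference, the drawn box side lengths per scale option work out to:
#   "1x" -> [512, 512, 512, 512]
#   "2x" -> [256, 128, 64, 32]
#   "4x" -> [128, 64, 32, 16]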

@spaces.GPU(duration=120)
def run_with_upload(uploaded_image_path, upscale_option, session_id=None):
    """
    Each invocation creates/uses:
      - samples/<session_id>/input.png                                      ← user's uploaded image
      - inference_results/coz_vlmprompt/<session_id>/per-sample/input/*.png ← inference outputs

    Returns (gallery_paths, session_id) so the session id can be kept in
    gr.State for later caption lookups.
    """
    # A bare gr.State() arrives as None, so mint a fresh id on first use.
    if session_id is None:
        session_id = uuid.uuid4().hex

    if uploaded_image_path is None:
        return [], session_id

    # 1) Prepare a per-session input directory
    session_folder = os.path.join(INPUT_DIR, str(session_id))
    os.makedirs(session_folder, exist_ok=True)

    # 2) Clear only this session's folder
    for fn in os.listdir(session_folder):
        full_path = os.path.join(session_folder, fn)
        if os.path.isfile(full_path) or os.path.islink(full_path):
            os.remove(full_path)
        elif os.path.isdir(full_path):
            shutil.rmtree(full_path)

    # 3) Save the uploaded image to session_folder/input.png
    try:
        pil_img = Image.open(uploaded_image_path).convert("RGB")
        save_path = Path(session_folder) / "input.png"
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save uploaded image: {e}")
        return [], session_id

    # 4) Define a per-session output directory
    session_output_dir = os.path.join(OUTPUT_DIR, str(session_id))
    os.makedirs(session_output_dir, exist_ok=True)

    # 5) Build and run the inference command
    upscale_value = upscale_option.replace("x", "")
    cmd = [
        "python", "inference_coz.py",
        "-i", session_folder,
        "-o", session_output_dir,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
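
    # Equivalent shell invocation for a 2x run, assuming the checkpoints are
    # laid out under ckpt/ as above:
    #
    #   python inference_coz.py -i samples/<id> -o inference_results/coz_vlmprompt/<id> \
    #       --rec_type recursive_multiscale --prompt_type vlm --upscale 2 \
    #       --lora_path ckpt/SR_LoRA/model_20001.pkl \
    #       --vae_path ckpt/SR_VAE/vae_encoder_20001.pt \
    #       --pretrained_model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers \
    #       --ram_ft_path ckpt/DAPE/DAPE.pth --ram_path ckpt/RAM/ram_swin_large_14m.pth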
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as err:
        print("Inference failed:", err)
        return [], session_id

    # 6) Gather output file paths (1.png through 4.png)
    per_sample_dir = os.path.join(session_output_dir, "per-sample", "input")
    expected_files = [os.path.join(per_sample_dir, f"{i}.png") for i in range(1, 5)]
    for fp in expected_files:
        if not os.path.isfile(fp):
            print(f"Warning: expected file not found: {fp}")
            return [], session_id
    return expected_files, session_id
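
# On success, the output layout this app expects from inference_coz.py is
# (the txt/ captions are what get_caption below reads):
#   inference_results/coz_vlmprompt/<session_id>/per-sample/input/1.png ... 4.png
#   inference_results/coz_vlmprompt/<session_id>/per-sample/input/txt/0.txt ... 3.txt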

def get_caption(src_gallery, session_id, evt: gr.SelectData):
    if not src_gallery or not os.path.isfile(src_gallery[evt.index][0]):
        return "No caption available."

    selected_image_path = src_gallery[evt.index][0]
    base = os.path.basename(selected_image_path)   # e.g. "2.png"
    stem = os.path.splitext(base)[0]               # e.g. "2"

    # Captions live alongside this session's outputs; the output images are
    # 1-indexed while the caption files are 0-indexed, hence the -1 below.
    txt_folder = os.path.join(OUTPUT_DIR, str(session_id), "per-sample", "input", "txt")
    txt_path = os.path.join(txt_folder, f"{int(stem) - 1}.txt")
    if not os.path.isfile(txt_path):
        return f"Caption file not found: {int(stem) - 1}.txt"

    try:
        with open(txt_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return caption if caption else "(Caption file is empty.)"
    except Exception as e:
        return f"Error reading caption: {e}"
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Chain-of-Zoom</h1>
            <p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://github.com/bryanswkim/Chain-of-Zoom">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
        </div>
        """
    )

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                upload_image = gr.Image(label="Upload your input image", type="filepath")
                upscale_radio = gr.Radio(choices=["1x", "2x", "4x"], value="2x", show_label=False)
                run_button = gr.Button("Chain-of-Zoom it")
                preview_with_box = gr.Image(label="Preview (512×512 with centered boxes)", type="pil", interactive=False)
            with gr.Column():
                output_gallery = gr.Gallery(label="Inference Results", show_label=True, columns=[2], rows=[2])
                caption_text = gr.Textbox(label="Caption", lines=4, placeholder="Click on any image above to see its caption here.")

    upload_image.change(
        fn=lambda img_path, scale_opt: make_preview_with_boxes(img_path, scale_opt) if img_path is not None else None,
        inputs=[upload_image, upscale_radio],
        outputs=[preview_with_box],
    )
    upscale_radio.change(
        fn=lambda img_path, scale_opt: make_preview_with_boxes(img_path, scale_opt) if img_path is not None else None,
        inputs=[upload_image, upscale_radio],
        outputs=[preview_with_box],
    )
    # A bare gr.State() holds None until a run populates it; run_with_upload
    # mints a session id on first use and returns it here so get_caption can
    # locate the session's caption files.
    session_state = gr.State()

    run_button.click(
        fn=run_with_upload,
        inputs=[upload_image, upscale_radio, session_state],
        outputs=[output_gallery, session_state],
    )

    output_gallery.select(
        fn=get_caption,
        inputs=[output_gallery, session_state],
        outputs=[caption_text],
    )
demo.launch(share=True)