# Hugging Face Space status: Running on Zero
import os
import shutil
import subprocess
import sys
from pathlib import Path

import gradio as gr
import spaces
from PIL import Image
# Root folder that receives one sub-folder of uploaded input per session.
INPUT_DIR = "samples"
# Root folder under which each session's inference results are written.
OUTPUT_DIR = "inference_results/coz_vlmprompt"
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Scale ``img`` so its shorter side equals ``size``, then crop the central square.

    Args:
        img: Source image. It is not modified; ``resize`` and ``crop`` both
            return new images.
        size: Side length, in pixels, of the returned square image.

    Returns:
        A ``size`` x ``size`` image cut from the centre of the resized input.
    """
    w, h = img.size
    scale = size / min(w, h)
    # Round rather than truncate, and never drop below ``size``: float error in
    # ``w * scale`` (e.g. 512 / 49 * 49 == 511.99999999999994) would otherwise
    # make the short side ``size - 1`` and the crop below would run past the
    # image, padding one edge.
    new_w = max(size, round(w * scale))
    new_h = max(size, round(h * scale))
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
    """Build the 512x512 preview with four nested, centred boxes drawn on it.

    Args:
        image_path: Path of the image to preview.
        scale_option: Upscale choice such as ``"1x"``, ``"2x"`` or ``"4x"``;
            the ``"x"`` is stripped and the rest parsed as an integer.

    Returns:
        The centre-cropped 512x512 image with one outlined square per box.
        For a scale of 1 all four boxes cover the full frame; otherwise box
        ``i`` has side ``512 // (scale * 2**i)``. If the image cannot be
        opened, a grey 512x512 placeholder carrying the error text is
        returned instead.
    """
    # Imported here because the file only imports ``PIL.Image`` at the top.
    from PIL import ImageDraw

    try:
        # The context manager closes the underlying file; ``convert`` returns
        # an independent in-memory copy that stays valid afterwards.
        with Image.open(image_path) as src:
            orig = src.convert("RGB")
    except Exception as e:
        # UI boundary: show the problem in the preview instead of crashing.
        fallback = Image.new("RGB", (512, 512), (200, 200, 200))
        draw = ImageDraw.Draw(fallback)
        draw.text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    base = resize_and_center_crop(orig, 512)
    scale_int = int(scale_option.replace("x", ""))
    if scale_int == 1:
        sizes = [512] * 4
    else:
        sizes = [512 // (scale_int * (2 ** i)) for i in range(4)]

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]
    width = 3
    for color, s in zip(colors, sizes):
        x0 = y0 = (512 - s) // 2
        # Pillow treats both corners as inclusive, so a box of side ``s``
        # ends at ``x0 + s - 1``; this keeps the 1x box inside the canvas.
        x1 = y1 = x0 + s - 1
        draw.rectangle([(x0, y0), (x1, y1)], outline=color, width=width)
    return base
def _clear_directory(path):
    """Create ``path`` if needed and delete everything currently inside it."""
    os.makedirs(path, exist_ok=True)
    for name in os.listdir(path):
        full_path = os.path.join(path, name)
        # Symlinks are unlinked, never followed into their target directory.
        if os.path.isdir(full_path) and not os.path.islink(full_path):
            shutil.rmtree(full_path)
        else:
            os.remove(full_path)


def run_with_upload(uploaded_image_path, upscale_option, session_id=None):
    """Run the inference script on one uploaded image and return its outputs.

    Each invocation creates/uses:
      - samples/<session_id>/input.png  <- the uploaded image, re-saved as PNG
      - inference_results/coz_vlmprompt/<session_id>/per-sample/input/*.png
        <- inference outputs

    Args:
        uploaded_image_path: Path of the uploaded image, or ``None`` when
            nothing was uploaded.
        upscale_option: Upscale choice such as ``"2x"``; the ``"x"`` is
            stripped before the value is passed to ``--upscale``.
        session_id: Identifier used (via ``str()``) as the per-session folder
            name under both the input and the output root.

    Returns:
        The paths of ``1.png`` .. ``4.png`` in the session's output folder, or
        an empty list if there was no upload, the image could not be saved,
        the inference process failed, or an expected output is missing.
    """
    if uploaded_image_path is None:
        return []

    # 1) Prepare an empty per-session input directory.
    print(session_id)
    session_folder = os.path.join(INPUT_DIR, str(session_id))
    _clear_directory(session_folder)

    # 2) Save the uploaded image to session_folder/input.png.
    save_path = Path(session_folder) / "input.png"
    try:
        # Context manager so the uploaded file's handle is released.
        with Image.open(uploaded_image_path) as uploaded:
            uploaded.convert("RGB").save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save uploaded image: {e}")
        return []

    # 3) Prepare an empty per-session output directory. Clearing it prevents
    #    leftovers from an earlier run being reported as this run's results.
    session_output_dir = os.path.join(OUTPUT_DIR, str(session_id))
    _clear_directory(session_output_dir)

    # 4) Build and run the inference command with the interpreter that is
    #    running this app, so the child sees the same installed packages.
    upscale_value = upscale_option.replace("x", "")
    cmd = [
        sys.executable, "inference_coz.py",
        "-i", session_folder,
        "-o", session_output_dir,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # OSError covers the process failing to start at all.
        print("Inference failed:", err)
        return []

    # 5) Gather output file paths (1.png through 4.png).
    per_sample_dir = os.path.join(session_output_dir, "per-sample", "input")
    expected_files = [os.path.join(per_sample_dir, f"{i}.png") for i in range(1, 5)]
    for fp in expected_files:
        if not os.path.isfile(fp):
            print(f"Warning: expected file not found: {fp}")
            return []
    return expected_files
def get_caption(src_gallery, evt: gr.SelectData, session_id=None):
    """Return the caption text that belongs to the gallery image just selected.

    Args:
        src_gallery: Current gallery value; each item is indexed as
            ``item[0]`` to obtain the image path.
        evt: Selection event; ``evt.index`` is the position of the clicked
            image within ``src_gallery``.
        session_id: Identifier of the session whose output folder holds the
            captions. It is converted with ``str()``, the same way
            ``run_with_upload`` names the folder it writes to.

    Returns:
        The stripped caption text, or a human-readable message when the
        gallery is empty, the image or caption file is missing, the file is
        empty, or it cannot be read.
    """
    if not src_gallery:
        return "No caption available."
    selected_image_path = src_gallery[evt.index][0]
    if not os.path.isfile(selected_image_path):
        return "No caption available."

    base = os.path.basename(selected_image_path)  # e.g. "2.png"
    stem = os.path.splitext(base)[0]              # e.g. "2"
    try:
        # Image N.png is looked up as caption (N-1).txt.
        caption_name = f"{int(stem) - 1}.txt"
    except ValueError:
        # File name is not numeric, so no caption file can be derived.
        return "No caption available."

    # The folder is keyed by session, not by the clicked image's position.
    txt_folder = os.path.join(OUTPUT_DIR, str(session_id), "per-sample", "input", "txt")
    txt_path = os.path.join(txt_folder, caption_name)
    if not os.path.isfile(txt_path):
        return f"Caption file not found: {caption_name}"
    try:
        with open(txt_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
    except (OSError, UnicodeError) as e:
        return f"Error reading caption: {e}"
    return caption if caption else "(Caption file is empty.)"


def _start_session(request: gr.Request):
    """Return the per-connection session hash used to name session folders."""
    return request.session_hash


def _update_preview(img_path, scale_opt):
    """Redraw the boxed preview, or clear it when no image is loaded."""
    if img_path is None:
        return None
    return make_preview_with_boxes(img_path, scale_opt)


css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    # A bare gr.State() only ever holds None, which would make every visitor
    # share one folder. The value is filled per session by demo.load below.
    session_state = gr.State()

    gr.HTML(
        """
<div style="text-align: center;">
<h1>Chain-of-Zoom</h1>
<p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/bryanswkim/Chain-of-Zoom">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
</div>
"""
    )
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                upload_image = gr.Image(label="Upload your input image", type="filepath")
                upscale_radio = gr.Radio(choices=["1x", "2x", "4x"], value="2x", show_label=False)
                run_button = gr.Button("Chain-of-Zoom it")
                preview_with_box = gr.Image(label="Preview (512×512 with centered boxes)", type="pil", interactive=False)
            with gr.Column():
                output_gallery = gr.Gallery(label="Inference Results", show_label=True, columns=[2], rows=[2])
                caption_text = gr.Textbox(label="Caption", lines=4, placeholder="Click on any image above to see its caption here.")

    # Assign the session identifier as soon as the page is loaded.
    demo.load(fn=_start_session, outputs=[session_state])

    # The preview is refreshed when either the image or the scale changes.
    upload_image.change(
        fn=_update_preview,
        inputs=[upload_image, upscale_radio],
        outputs=[preview_with_box]
    )
    upscale_radio.change(
        fn=_update_preview,
        inputs=[upload_image, upscale_radio],
        outputs=[preview_with_box]
    )
    # session_state is passed as run_with_upload's session_id argument.
    run_button.click(
        fn=run_with_upload,
        inputs=[upload_image, upscale_radio, session_state],
        outputs=[output_gallery]
    )
    # The same session id lets get_caption read from the folder the run wrote to.
    output_gallery.select(
        fn=get_caption,
        inputs=[output_gallery, session_state],
        outputs=[caption_text]
    )

demo.launch(share=True)