chain-of-zoom / app.py
broadfield-dev's picture
Update app.py
33b3ad3 verified
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import threading
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, login
from PIL import Image
# --- Logging Setup ---
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Local paths at which the Chain-of-Zoom support checkpoints are expected.
LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
DAPE_PATH = "ckpt/DAPE/DAPE.pth"

# One entry per checkpoint: the Hub repo, the file inside that repo, and the
# local destination the file is copied to.
CHECKPOINT_FILES_CONFIG = {
    model_key: {
        "repo_id": "bryandmc/Chain-of-Zoom",
        "filename": repo_file,
        "target_path": local_path,
    }
    for model_key, repo_file, local_path in (
        ("SR_LoRA", "SR_LoRA/model_20001.pkl", LORA_PATH),
        ("SR_VAE", "SR_VAE/vae_encoder_20001.pt", VAE_PATH),
        ("DAPE", "DAPE/DAPE.pth", DAPE_PATH),
    )
}

# --- Device Detection ---
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")

# --- Hugging Face Token ---
# The token is optional; without it only a warning is logged.
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if not HF_AUTH_TOKEN:
    logger.warning("HF_TOKEN not found. Downloads of gated models may fail.")
else:
    try:
        login(token=HF_AUTH_TOKEN)
    except Exception as e:
        # A failed login is not fatal here; startup continues.
        logger.warning(f"Hugging Face login failed: {e}")
    else:
        logger.info("Successfully logged in to Hugging Face Hub.")
# --- Model Download Function ---
def download_coz_support_models():
    """Ensure every Chain-of-Zoom support checkpoint exists at its local path.

    For each entry in ``CHECKPOINT_FILES_CONFIG`` whose ``target_path`` does
    not exist, the file is fetched with ``hf_hub_download`` and copied from
    the Hub cache to the target path. Entries whose target already exists are
    left untouched.

    Raises:
        Exception: Any error raised while downloading or copying is logged
            and re-raised unchanged.
    """
    logger.info("Checking and downloading CoZ support models...")
    for model_key, model_info in CHECKPOINT_FILES_CONFIG.items():
        target_file_path = Path(model_info["target_path"])
        if target_file_path.exists():
            logger.info(f"{model_key} already exists at {target_file_path}")
            continue

        logger.info(f"Downloading {model_key} from {model_info['repo_id']}...")
        target_file_path.parent.mkdir(parents=True, exist_ok=True)
        # Copy to a sibling staging file and rename it into place. An
        # interrupted copy then never leaves a truncated file at the target
        # path, which the exists() check above would accept on the next start.
        staging_path = target_file_path.with_name(target_file_path.name + ".partial")
        try:
            # `resume_download` is intentionally not passed: it is deprecated
            # in current huggingface_hub releases.
            cached_file_path = hf_hub_download(
                repo_id=model_info['repo_id'],
                filename=model_info['filename'],
                token=HF_AUTH_TOKEN,
                force_download=False,
            )
            shutil.copy(cached_file_path, staging_path)
            os.replace(staging_path, target_file_path)
            logger.info(f"{model_key} downloaded to {target_file_path}")
        except Exception as e:
            logger.error(f"Error downloading {model_key}: {e}")
            # Best-effort cleanup; the original error is what gets re-raised.
            try:
                staging_path.unlink()
            except OSError:
                pass
            raise
    logger.info("All CoZ support models checked.")
# Download models at startup
# A failed download is logged and re-raised, so the import of this module fails.
try:
    logger.info("Starting model download...")
    download_coz_support_models()
    logger.info("Model download completed.")
except Exception as e:
    logger.error(f"Failed to download models: {e}")
    raise

# --- Preload Stable Diffusion Model ---
# NOTE(review): `pipeline` is not referenced anywhere else in the visible
# file (inference runs in a subprocess), so this load presumably exists only
# to warm the Hub cache — confirm before relying on it or removing it.
logger.info("Preloading Stable Diffusion model configuration...")
try:
    from diffusers import StableDiffusion3Pipeline
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        # `token` is the supported keyword; `use_auth_token` is deprecated.
        token=HF_AUTH_TOKEN,
    )
    logger.info("Stable Diffusion 3 model configuration preloaded.")
except Exception as e:
    logger.error(f"Failed to preload Stable Diffusion model: {e}")
    raise
# --- Main Inference Function ---
def _redact_secret(args, secret):
    """Return a copy of *args* with every element equal to *secret* masked."""
    if not secret:
        return list(args)
    return ["***" if arg == secret else arg for arg in args]


def _drain_stream(stream, log, prefix, sink):
    """Read *stream* line by line until EOF, logging and collecting each line."""
    for line in iter(stream.readline, ""):
        log(f"{prefix} {line.strip()}")
        sink.append(line)


def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str, seed: int):
    """Super-resolve *input_image* by running ``inference_coz.py`` in a subprocess.

    The image is written to a temporary directory, the script is invoked on
    that directory, and the resulting image is read back before the temporary
    directory is removed.

    Args:
        input_image: Image to enlarge; ``None`` is rejected.
        magnification: Zoom factor, forwarded as ``--magnification``.
        caption: Optional prompt; forwarded as ``--caption`` only when it
            contains non-whitespace text.
        seed: Value forwarded as ``--seed``.

    Returns:
        The output image, fully loaded into memory.

    Raises:
        gr.Error: If no image is given, ``magnification``/``seed`` are not
            integers, ``inference_coz.py`` is missing, the subprocess exits
            with a non-zero status, or no output image can be found.
    """
    if input_image is None:
        logger.error("No input image provided.")
        raise gr.Error("Please upload an image.")

    # UI widgets may deliver numbers as floats; normalise them so the command
    # line and the output-file glob below see "4" rather than "4.0".
    try:
        magnification = int(magnification)
        seed = int(seed)
    except (TypeError, ValueError) as exc:
        logger.error(f"Invalid magnification/seed: {magnification!r}, {seed!r}")
        raise gr.Error("Magnification and seed must be whole numbers.") from exc

    logger.info(f"Starting inference with magnification={magnification}, seed={seed}, caption={caption}")

    with tempfile.TemporaryDirectory() as temp_base_str:
        temp_base_dir = Path(temp_base_str)
        input_img_parent_dir = temp_base_dir / "input_images_root"
        input_img_parent_dir.mkdir(parents=True, exist_ok=True)
        input_image_filename = "source_image.png"
        input_image_path = input_img_parent_dir / input_image_filename
        input_image.save(input_image_path, "PNG")
        logger.info(f"Input image saved to {input_image_path}")

        output_img_dir = temp_base_dir / "output_data"
        output_img_dir.mkdir(parents=True, exist_ok=True)

        # Check if inference_coz.py exists
        if not Path("inference_coz.py").exists():
            logger.error("inference_coz.py not found in repository.")
            raise gr.Error("inference_coz.py not found in repository. Please check the Chain-of-Zoom repository.")

        command = [
            # Run the script with the interpreter that runs this app, so it
            # sees the same installed packages.
            sys.executable or "python", "inference_coz.py",
            "-i", str(input_img_parent_dir),
            "-o", str(output_img_dir),
            "--rec_type", "recursive_multiscale",
            "--prompt_type", "vlm",
            "--lora_path", LORA_PATH,
            "--vae_path", VAE_PATH,
            "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
            "--ram_ft_path", DAPE_PATH,
            "--efficient_memory",
            "--magnification", str(magnification),
            "--seed", str(seed),
            "--image_num", "16" if DEVICE == "cpu" else "32",
        ]
        if DEVICE == "cpu":
            command.append("--no_cuda")  # Assumes inference_coz.py supports this flag
        if caption and caption.strip():
            command.extend(["--caption", caption.strip()])
        if HF_AUTH_TOKEN:
            command.extend(["--hf_token", HF_AUTH_TOKEN])

        # Never write the access token to the log.
        logger.info(f"Running command: {' '.join(_redact_secret(command, HF_AUTH_TOKEN))}")

        stdout_lines = []
        stderr_lines = []
        output_image_path_from_log = None
        with subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        ) as process:
            # Drain stderr concurrently. Reading stdout to EOF first would
            # deadlock as soon as the child fills the stderr pipe buffer.
            stderr_reader = threading.Thread(
                target=_drain_stream,
                args=(process.stderr, logger.warning, "[CoZ STDERR]", stderr_lines),
                daemon=True,
            )
            stderr_reader.start()

            # Stream stdout
            for line in iter(process.stdout.readline, ""):
                logger.info(f"[CoZ STDOUT] {line.strip()}")
                stdout_lines.append(line)
                if "Saving image to" in line:
                    output_image_path_from_log = line.split("Saving image to")[-1].strip()

            stderr_reader.join()
            process.wait()

        if process.returncode != 0:
            error_message = f"Chain-of-Zoom failed.\nSTDOUT:\n{''.join(stdout_lines[-5:])}\nSTDERR:\n{''.join(stderr_lines[-5:])}"
            logger.error(error_message)
            raise gr.Error(f"Processing failed: {error_message}")

        # Find output image: prefer the path announced on stdout, otherwise
        # fall back to the expected file-name pattern in the output directory.
        final_output_image_path = None
        if output_image_path_from_log and Path(output_image_path_from_log).is_file():
            final_output_image_path = Path(output_image_path_from_log)
        else:
            processed_output_subdir = output_img_dir / input_img_parent_dir.name
            pattern = f"{Path(input_image_filename).stem}_x{magnification}_*.png"
            # Sorted so the choice is deterministic when several files match.
            potential_files = sorted(processed_output_subdir.glob(pattern))
            if potential_files:
                final_output_image_path = potential_files[0]

        if final_output_image_path is None or not final_output_image_path.is_file():
            all_files = list(output_img_dir.rglob("*"))
            logger.error(f"Output image not found in {output_img_dir}. Files found: {all_files}")
            raise gr.Error(f"Output image not found in {output_img_dir}. Files found: {all_files}")

        output_image = Image.open(final_output_image_path)
        # Image.open is lazy; read the pixel data now, because the file is
        # deleted together with the temporary directory on exit.
        output_image.load()
        logger.info(f"Output image generated: {final_output_image_path}")

    return output_image
# --- Gradio Interface ---
# Stylesheet passed to gr.Blocks.
css = (
    "\n"
    ".gradio-container { font-family: 'IBM Plex Sans', sans-serif; }\n"
    ".gr-button { color: white; border-color: black; background: black; }\n"
    "footer { display: none !important; }\n"
)
title = "Chain-of-Zoom: Extreme Image Super-Resolution Demo"
# Markdown shown under the page heading.
description = (
    "\n"
    "Upload an image and select a magnification factor. Provide an optional caption (if empty, a VLM will generate one).\n"
    "Optimized for CPU and GPU environments. Ensure HF_TOKEN is set in Space secrets for model access.\n"
    "[Chain-of-Zoom GitHub](https://github.com/bryanswkim/Chain-of-Zoom)\n"
)
# Footer link rendered below the two columns.
article = (
    "<p style='text-align: center;'>"
    "<a href='https://github.com/bryanswkim/Chain-of-Zoom' target='_blank'>Chain-of-Zoom GitHub</a>"
    "</p>"
)

logger.info("Initializing Gradio interface...")
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align: center'>" + title + "</h1>")
    gr.Markdown(description)

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            input_image_gr = gr.Image(type="pil", label="Input Image")
            magnification_gr = gr.Slider(
                minimum=2,
                maximum=16,
                step=1,
                value=4,
                label="Magnification Factor (2x-16x)",
            )
            caption_gr = gr.Textbox(
                label="Optional Caption",
                placeholder="e.g., a photo of a cat",
            )
            seed_gr = gr.Number(label="Seed", value=42, precision=0)
            run_button = gr.Button("Zoom In!", variant="primary")
        # Right column: the result.
        with gr.Column(scale=1):
            output_image_gr = gr.Image(type="pil", label="Output Super-Resolved Image")

    gr.Markdown(article)

    # Input order matches the parameter order of run_chain_of_zoom.
    zoom_inputs = [input_image_gr, magnification_gr, caption_gr, seed_gr]
    run_button.click(
        fn=run_chain_of_zoom,
        inputs=zoom_inputs,
        outputs=output_image_gr,
    )
# Script entry point: serve the interface on all network interfaces, port 7860.
if __name__ == "__main__":
    logger.info("Launching Gradio app...")
    try:
        demo.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.error(f"Failed to launch Gradio app: {e}")
        raise
    else:
        # NOTE(review): launch() appears to block while the server runs when
        # started as a script, so this line may only be logged at shutdown —
        # confirm.
        logger.info("Gradio app launched successfully.")