# Spaces: Runtime error
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import threading
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, login
from PIL import Image
# --- Logging Setup ---
# Process-wide logging: INFO and above, each record prefixed with a timestamp
# and its level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configuration ---
# Every Chain-of-Zoom support checkpoint comes from the same Hub repository.
_COZ_REPO_ID = "bryandmc/Chain-of-Zoom"

# Local destinations; these paths are also handed to inference_coz.py.
LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
DAPE_PATH = "ckpt/DAPE/DAPE.pth"

# key -> {Hub repo, file path inside the repo, local copy destination}
CHECKPOINT_FILES_CONFIG = {
    key: {"repo_id": _COZ_REPO_ID, "filename": repo_filename, "target_path": local_path}
    for key, repo_filename, local_path in (
        ("SR_LoRA", "SR_LoRA/model_20001.pkl", LORA_PATH),
        ("SR_VAE", "SR_VAE/vae_encoder_20001.pt", VAE_PATH),
        ("DAPE", "DAPE/DAPE.pth", DAPE_PATH),
    )
}
# --- Device Detection ---
# Use the GPU whenever PyTorch reports CUDA is available; otherwise run on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")
# --- Hugging Face Token ---
# Token comes from the HF_TOKEN environment variable; None when it is unset.
# A failed login is only a warning so the app can still start.
HF_AUTH_TOKEN = os.environ.get("HF_TOKEN")
if not HF_AUTH_TOKEN:
    logger.warning("HF_TOKEN not found. Downloads of gated models may fail.")
else:
    try:
        login(token=HF_AUTH_TOKEN)
    except Exception as e:
        logger.warning(f"Hugging Face login failed: {e}")
    else:
        logger.info("Successfully logged in to Hugging Face Hub.")
# --- Model Download Function ---
def download_coz_support_models():
    """Make sure every checkpoint listed in CHECKPOINT_FILES_CONFIG exists locally.

    For each entry whose ``target_path`` is missing, the file is fetched with
    ``hf_hub_download`` (into the huggingface_hub cache) and copied to
    ``target_path``. Files that already exist are left untouched.

    Raises:
        Exception: whatever the download or copy raised. It is logged with a
            traceback and re-raised so that start-up fails loudly.
    """
    logger.info("Checking and downloading CoZ support models...")
    for model_key, model_info in CHECKPOINT_FILES_CONFIG.items():
        target_file_path = Path(model_info["target_path"])
        if target_file_path.exists():
            logger.info(f"{model_key} already exists at {target_file_path}")
            continue

        logger.info(f"Downloading {model_key} from {model_info['repo_id']}...")
        target_file_path.parent.mkdir(parents=True, exist_ok=True)
        # Copy to a sibling temp name and atomically rename into place.
        # Otherwise an interrupted copy leaves a truncated file, which the
        # exists() check above would accept as complete on the next start.
        partial_path = target_file_path.with_name(target_file_path.name + ".partial")
        try:
            # `resume_download` is deprecated in huggingface_hub (interrupted
            # downloads always resume now), and force_download=False is the
            # default, so neither is passed.
            cached_file_path = hf_hub_download(
                repo_id=model_info["repo_id"],
                filename=model_info["filename"],
                token=HF_AUTH_TOKEN,
            )
            shutil.copy(cached_file_path, partial_path)
            os.replace(partial_path, target_file_path)
        except Exception as e:
            logger.exception(f"Error downloading {model_key}: {e}")
            try:
                partial_path.unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise
        logger.info(f"{model_key} downloaded to {target_file_path}")
    logger.info("All CoZ support models checked.")
# Download models at startup; any failure is logged and aborts app start-up.
logger.info("Starting model download...")
try:
    download_coz_support_models()
except Exception as e:
    logger.error(f"Failed to download models: {e}")
    raise
logger.info("Model download completed.")
# --- Prefetch Stable Diffusion 3 weights ---
# Inference runs in a separate process (inference_coz.py, launched by
# run_chain_of_zoom) that loads the pipeline itself. Nothing in this process
# uses a pipeline object. Instantiating StableDiffusion3Pipeline here therefore
# only pinned an unused full copy of the model in this process's memory
# (float32 on CPU). Only download the weights into the Hugging Face cache, so
# the first request does not pay for the download and a missing or
# inaccessible model is still detected at start-up.
# `token=` replaces the deprecated `use_auth_token=` keyword.
SD3_MODEL_ID = "stabilityai/stable-diffusion-3-medium-diffusers"
logger.info("Prefetching Stable Diffusion 3 weights...")
try:
    from diffusers import StableDiffusion3Pipeline

    SD3_LOCAL_PATH = StableDiffusion3Pipeline.download(
        SD3_MODEL_ID,
        token=HF_AUTH_TOKEN,
    )
    logger.info(f"Stable Diffusion 3 weights cached at {SD3_LOCAL_PATH}.")
except Exception as e:
    logger.error(f"Failed to preload Stable Diffusion model: {e}")
    raise
# --- Main Inference Function ---
def _drain_stream(stream, sink, log_fn, tag):
    """Read *stream* line by line until EOF, logging each line and appending it to *sink*."""
    for line in iter(stream.readline, ""):
        log_fn(f"[CoZ {tag}] {line.strip()}")
        sink.append(line)


def _build_coz_command(input_dir, output_dir, magnification, seed, caption):
    """Return the argv list for one inference_coz.py run."""
    command = [
        # Same interpreter/venv as this app, rather than whatever "python" is on PATH.
        sys.executable, "inference_coz.py",
        "-i", str(input_dir),
        "-o", str(output_dir),
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--lora_path", LORA_PATH,
        "--vae_path", VAE_PATH,
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", DAPE_PATH,
        "--efficient_memory",
        "--magnification", str(magnification),
        "--seed", str(seed),
        "--image_num", "16" if DEVICE == "cpu" else "32",
    ]
    if DEVICE == "cpu":
        command.append("--no_cuda")  # NOTE(review): assumes inference_coz.py supports this flag
    caption = (caption or "").strip()
    if caption:
        command.extend(["--caption", caption])
    if HF_AUTH_TOKEN:
        command.extend(["--hf_token", HF_AUTH_TOKEN])
    return command


def _format_command_for_log(command):
    """Join *command* for logging, masking the value that follows ``--hf_token``."""
    shown = []
    hide_next = False
    for arg in command:
        shown.append("***" if hide_next else arg)
        hide_next = arg == "--hf_token"
    return " ".join(shown)


def _run_coz(command):
    """Run *command* and stream its output to the log.

    Returns:
        (returncode, stdout_lines, stderr_lines, saved_path). saved_path is the
        text after the last "Saving image to" marker seen on stdout, or None.
    """
    stdout_lines = []
    stderr_lines = []
    saved_path = None
    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        encoding="utf-8",
        errors="replace",  # never crash on undecodable child output
        bufsize=1,
    ) as process:
        # Drain stderr concurrently. Reading stdout to EOF before touching
        # stderr deadlocks as soon as the child fills the stderr pipe buffer.
        stderr_reader = threading.Thread(
            target=_drain_stream,
            args=(process.stderr, stderr_lines, logger.warning, "STDERR"),
            daemon=True,
        )
        stderr_reader.start()
        for line in iter(process.stdout.readline, ""):
            logger.info(f"[CoZ STDOUT] {line.strip()}")
            stdout_lines.append(line)
            if "Saving image to" in line:
                saved_path = line.split("Saving image to")[-1].strip()
        stderr_reader.join()
    # Popen.__exit__ has waited for the child, so returncode is set.
    return process.returncode, stdout_lines, stderr_lines, saved_path


def _find_output_image(output_img_dir, input_img_parent_dir, input_image_filename, magnification, path_from_log):
    """Locate the result image: prefer the path printed by the child, else glob the output dir."""
    if path_from_log and Path(path_from_log).exists():
        return Path(path_from_log)
    processed_output_subdir = output_img_dir / input_img_parent_dir.name
    pattern = f"{Path(input_image_filename).stem}_x{magnification}_*.png"
    # Sorted so the choice is deterministic when several files match.
    matches = sorted(processed_output_subdir.glob(pattern))
    return matches[0] if matches else None


def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str, seed: int):
    """Super-resolve *input_image* by running inference_coz.py in a subprocess.

    Args:
        input_image: Uploaded image (None when the user uploaded nothing).
        magnification: Zoom factor; coerced to int because UI sliders may deliver floats.
        caption: Optional caption; blank means none is passed to the script.
        seed: Random seed; coerced to int.

    Returns:
        The output image, fully loaded into memory (its file lives in a
        temporary directory that is deleted before this function returns).

    Raises:
        gr.Error: on missing or invalid input, a missing script, a non-zero exit
            status, or when no output image can be found.
    """
    if input_image is None:
        logger.error("No input image provided.")
        raise gr.Error("Please upload an image.")
    try:
        magnification = int(magnification)
        seed = int(seed)
    except (TypeError, ValueError):
        logger.error(f"Invalid numeric input: magnification={magnification!r}, seed={seed!r}")
        raise gr.Error("Magnification and seed must be whole numbers.") from None
    if not Path("inference_coz.py").exists():
        logger.error("inference_coz.py not found in repository.")
        raise gr.Error("inference_coz.py not found in repository. Please check the Chain-of-Zoom repository.")
    logger.info(f"Starting inference with magnification={magnification}, seed={seed}, caption={caption}")

    with tempfile.TemporaryDirectory() as temp_base_str:
        temp_base_dir = Path(temp_base_str)
        input_img_parent_dir = temp_base_dir / "input_images_root"
        input_img_parent_dir.mkdir(parents=True, exist_ok=True)
        input_image_filename = "source_image.png"
        input_image_path = input_img_parent_dir / input_image_filename
        # PNG cannot store some modes (e.g. CMYK). The SR pipeline presumably
        # works on 3-channel input anyway; confirm against inference_coz.py.
        input_image.convert("RGB").save(input_image_path, "PNG")
        logger.info(f"Input image saved to {input_image_path}")
        output_img_dir = temp_base_dir / "output_data"
        output_img_dir.mkdir(parents=True, exist_ok=True)

        command = _build_coz_command(input_img_parent_dir, output_img_dir, magnification, seed, caption)
        # The token is masked so it never ends up in the Space logs.
        logger.info(f"Running command: {_format_command_for_log(command)}")
        returncode, stdout_lines, stderr_lines, output_image_path_from_log = _run_coz(command)

        if returncode != 0:
            error_message = f"Chain-of-Zoom failed.\nSTDOUT:\n{''.join(stdout_lines[-5:])}\nSTDERR:\n{''.join(stderr_lines[-5:])}"
            logger.error(error_message)
            raise gr.Error(f"Processing failed: {error_message}")

        final_output_image_path = _find_output_image(
            output_img_dir, input_img_parent_dir, input_image_filename, magnification, output_image_path_from_log
        )
        if final_output_image_path is None or not final_output_image_path.exists():
            all_files = list(output_img_dir.rglob("*"))
            logger.error(f"Output image not found in {output_img_dir}. Files found: {all_files}")
            raise gr.Error(f"Output image not found in {output_img_dir}. Files found: {all_files}")

        # Image.open is lazy, so load the pixels into memory while the
        # temporary directory still exists.
        with Image.open(final_output_image_path) as produced:
            output_image = produced.copy()
        logger.info(f"Output image generated: {final_output_image_path}")
    return output_image
# --- Gradio Interface ---
# Static text and styling for the demo page. Each multi-line value is built
# from its lines and begins and ends with a newline.
css = "\n".join([
    "",
    ".gradio-container { font-family: 'IBM Plex Sans', sans-serif; }",
    ".gr-button { color: white; border-color: black; background: black; }",
    "footer { display: none !important; }",
    "",
])
title = "Chain-of-Zoom: Extreme Image Super-Resolution Demo"
description = "\n".join([
    "",
    "Upload an image and select a magnification factor. Provide an optional caption (if empty, a VLM will generate one).",
    "Optimized for CPU and GPU environments. Ensure HF_TOKEN is set in Space secrets for model access.",
    "[Chain-of-Zoom GitHub](https://github.com/bryanswkim/Chain-of-Zoom)",
    "",
])
article = "<p style='text-align: center;'><a href='https://github.com/bryanswkim/Chain-of-Zoom' target='_blank'>Chain-of-Zoom GitHub</a></p>"
logger.info("Initializing Gradio interface...")
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    heading_html = f"<h1 style='text-align: center'>{title}</h1>"
    gr.Markdown(heading_html)
    gr.Markdown(description)
    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column(scale=1):
            input_image_gr = gr.Image(label="Input Image", type="pil")
            magnification_gr = gr.Slider(
                label="Magnification Factor (2x-16x)",
                minimum=2,
                maximum=16,
                step=1,
                value=4,
            )
            caption_gr = gr.Textbox(placeholder="e.g., a photo of a cat", label="Optional Caption")
            seed_gr = gr.Number(value=42, precision=0, label="Seed")
            run_button = gr.Button("Zoom In!", variant="primary")
        # Right column: the super-resolved result.
        with gr.Column(scale=1):
            output_image_gr = gr.Image(label="Output Super-Resolved Image", type="pil")
    gr.Markdown(article)
    # Inputs are listed in the same order as run_chain_of_zoom's parameters.
    run_button.click(
        fn=run_chain_of_zoom,
        inputs=[input_image_gr, magnification_gr, caption_gr, seed_gr],
        outputs=output_image_gr,
    )
if __name__ == "__main__":
    logger.info("Launching Gradio app...")
    try:
        # When run as a script, launch() blocks until the server shuts down.
        # Any log line placed after it therefore fires at shutdown, not once
        # the app is up.
        demo.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.exception(f"Failed to launch Gradio app: {e}")
        raise
    logger.info("Gradio app stopped.")