import gradio as gr
import subprocess
import os
import shutil
from pathlib import Path
import tempfile
from PIL import Image
from huggingface_hub import hf_hub_download, login
import torch
import logging

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
DAPE_PATH = "ckpt/DAPE/DAPE.pth"

CHECKPOINT_FILES_CONFIG = {
    "SR_LoRA": {"repo_id": "bryandmc/Chain-of-Zoom", "filename": "SR_LoRA/model_20001.pkl", "target_path": LORA_PATH},
    "SR_VAE": {"repo_id": "bryandmc/Chain-of-Zoom", "filename": "SR_VAE/vae_encoder_20001.pt", "target_path": VAE_PATH},
    "DAPE": {"repo_id": "bryandmc/Chain-of-Zoom", "filename": "DAPE/DAPE.pth", "target_path": DAPE_PATH},
}

# --- Device Detection ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")

# --- Hugging Face Token ---
HF_AUTH_TOKEN = os.environ.get("HF_TOKEN")
if HF_AUTH_TOKEN:
    try:
        login(token=HF_AUTH_TOKEN)
        logger.info("Successfully logged in to Hugging Face Hub.")
    except Exception as e:
        logger.warning(f"Hugging Face login failed: {e}")
else:
    logger.warning("HF_TOKEN not found. Downloads of gated models may fail.")

# --- Model Download Function ---
def download_coz_support_models():
    """Fetch the CoZ LoRA, VAE, and DAPE checkpoints into ckpt/ if missing."""
    logger.info("Checking and downloading CoZ support models...")
    for model_key, model_info in CHECKPOINT_FILES_CONFIG.items():
        target_file_path = Path(model_info["target_path"])
        if not target_file_path.exists():
            logger.info(f"Downloading {model_key} from {model_info['repo_id']}...")
            target_file_path.parent.mkdir(parents=True, exist_ok=True)
            try:
                cached_file_path = hf_hub_download(
                    repo_id=model_info["repo_id"],
                    filename=model_info["filename"],
                    token=HF_AUTH_TOKEN,
                )
                shutil.copy(cached_file_path, target_file_path)
                logger.info(f"{model_key} downloaded to {target_file_path}")
            except Exception as e:
                logger.error(f"Error downloading {model_key}: {e}")
                raise
        else:
            logger.info(f"{model_key} already exists at {target_file_path}")
    logger.info("All CoZ support models checked.")

# Download models at startup
try:
    logger.info("Starting model download...")
    download_coz_support_models()
    logger.info("Model download completed.")
except Exception as e:
    logger.error(f"Failed to download models: {e}")
    raise

# --- Preload Stable Diffusion Model ---
# Loading the pipeline once at startup warms the local Hugging Face cache, so
# the inference subprocess launched later does not download the base model itself.
logger.info("Preloading Stable Diffusion model configuration...")
try:
    from diffusers import StableDiffusion3Pipeline
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        token=HF_AUTH_TOKEN,
    )
    logger.info("Stable Diffusion 3 model configuration preloaded.")
    del pipeline  # Only the on-disk cache is needed; free the memory.
except Exception as e:
    logger.error(f"Failed to preload Stable Diffusion model: {e}")
    raise

# --- Main Inference Function ---
def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str, seed: int):
    if input_image is None:
        logger.error("No input image provided.")
        raise gr.Error("Please upload an image.")

    logger.info(f"Starting inference with magnification={magnification}, seed={seed}, caption={caption}")

    with tempfile.TemporaryDirectory() as temp_base_str:
        temp_base_dir = Path(temp_base_str)
        input_img_parent_dir = temp_base_dir / "input_images_root"
        input_img_parent_dir.mkdir(parents=True, exist_ok=True)
        input_image_filename = "source_image.png"
        input_image_path = input_img_parent_dir / input_image_filename
        input_image.save(input_image_path, "PNG")
        logger.info(f"Input image saved to {input_image_path}")

        output_img_dir = temp_base_dir / "output_data"
        output_img_dir.mkdir(parents=True, exist_ok=True)

        # Check that the inference script from the Chain-of-Zoom repo is present
        if not Path("inference_coz.py").exists():
            logger.error("inference_coz.py not found in repository.")
            raise gr.Error("inference_coz.py not found in repository. Please check the Chain-of-Zoom repository.")

        command = [
            "python", "inference_coz.py",
            "-i", str(input_img_parent_dir),
            "-o", str(output_img_dir),
            "--rec_type", "recursive_multiscale",
            "--prompt_type", "vlm",
            "--lora_path", LORA_PATH,
            "--vae_path", VAE_PATH,
            "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
            "--ram_ft_path", DAPE_PATH,
            "--efficient_memory",
            "--magnification", str(magnification),
            "--seed", str(seed),
            "--image_num", "16" if DEVICE == "cpu" else "32",
        ]
        if DEVICE == "cpu":
            command.append("--no_cuda")  # Assumes inference_coz.py supports this flag
        if caption and caption.strip():
            command.extend(["--caption", caption.strip()])
        if HF_AUTH_TOKEN:
            command.extend(["--hf_token", HF_AUTH_TOKEN])

        # Redact the token so it never ends up in the logs
        logged_command = ["***" if arg == HF_AUTH_TOKEN else arg for arg in command]
        logger.info(f"Running command: {' '.join(logged_command)}")

        # Merge stderr into stdout and stream a single pipe; draining the two
        # pipes one after the other can deadlock once either OS buffer fills.
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )

        output_lines = []
        output_image_path_from_log = None
        if process.stdout:
            for line in iter(process.stdout.readline, ""):
                logger.info(f"[CoZ] {line.strip()}")
                output_lines.append(line)
                if "Saving image to" in line:
                    output_image_path_from_log = line.split("Saving image to")[-1].strip()

        process.wait()

        if process.returncode != 0:
            error_message = f"Chain-of-Zoom failed.\nLast output:\n{''.join(output_lines[-10:])}"
            logger.error(error_message)
            raise gr.Error(f"Processing failed: {error_message}")

        # Find the output image: trust the path reported in the log first,
        # then fall back to globbing the output directory.
        final_output_image_path = None
        if output_image_path_from_log and Path(output_image_path_from_log).exists():
            final_output_image_path = Path(output_image_path_from_log)
        else:
            processed_output_subdir = output_img_dir / input_img_parent_dir.name
            potential_files = sorted(processed_output_subdir.glob(f"{Path(input_image_filename).stem}_x{magnification}_*.png"))
            if potential_files:
                final_output_image_path = potential_files[0]

        if not final_output_image_path or not final_output_image_path.exists():
            all_files = list(output_img_dir.rglob("*"))
            logger.error(f"Output image not found in {output_img_dir}. Files found: {all_files}")
            raise gr.Error(f"Output image not found in {output_img_dir}. Files found: {all_files}")

        output_image = Image.open(final_output_image_path)
        output_image.load()  # Force a full read before the temp directory is cleaned up
        logger.info(f"Output image generated: {final_output_image_path}")
        return output_image

# --- Gradio Interface ---
css = """
.gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
"""

title = "Chain-of-Zoom: Extreme Image Super-Resolution Demo"
description = """
Upload an image and select a magnification factor. Provide an optional caption (if empty, a VLM will generate one).
Optimized for CPU and GPU environments. Ensure HF_TOKEN is set in Space secrets for model access.
[Chain-of-Zoom GitHub](https://github.com/bryanswkim/Chain-of-Zoom)
"""
article = ""
" logger.info("Initializing Gradio interface...") with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: gr.Markdown(f"