File size: 9,797 Bytes
030df3b
 
 
 
 
 
 
0c11cd5
 
b2c1451
 
 
 
 
030df3b
 
 
 
0c11cd5
030df3b
 
6de914c
030df3b
 
 
 
0c11cd5
 
b2c1451
0c11cd5
030df3b
 
 
 
 
b2c1451
030df3b
b2c1451
030df3b
b2c1451
030df3b
 
 
b2c1451
030df3b
 
 
b2c1451
030df3b
 
 
 
 
0c11cd5
b2c1451
 
030df3b
 
b2c1451
030df3b
b2c1451
 
030df3b
b2c1451
 
030df3b
0c11cd5
b2c1451
 
 
 
 
 
 
 
 
 
 
33b3ad3
 
b2c1451
33b3ad3
 
b2c1451
33b3ad3
b2c1451
 
 
030df3b
 
1084fb5
030df3b
b2c1451
030df3b
 
b2c1451
030df3b
 
0c11cd5
030df3b
0c11cd5
030df3b
 
b2c1451
030df3b
 
 
 
0c11cd5
 
b2c1451
0c11cd5
030df3b
 
 
0c11cd5
030df3b
 
 
 
 
 
 
 
 
 
6de914c
030df3b
 
0c11cd5
6de914c
0c11cd5
030df3b
 
0c11cd5
030df3b
 
 
b2c1451
0c11cd5
 
030df3b
 
 
 
 
 
 
b2c1451
030df3b
0c11cd5
030df3b
 
 
0c11cd5
 
030df3b
 
 
b2c1451
030df3b
0c11cd5
030df3b
 
 
0c11cd5
b2c1451
0c11cd5
 
 
030df3b
 
 
 
0c11cd5
 
 
 
 
030df3b
0c11cd5
b2c1451
0c11cd5
030df3b
 
b2c1451
030df3b
 
 
 
 
 
 
 
 
 
 
 
0c11cd5
 
030df3b
0c11cd5
030df3b
b2c1451
030df3b
0c11cd5
030df3b
 
 
 
 
0c11cd5
 
030df3b
 
 
 
0c11cd5
030df3b
 
 
 
 
 
 
 
 
b2c1451
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import threading
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, login
from PIL import Image

# --- Logging Setup ---
# Root logger emits INFO and above as "timestamp - LEVEL - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Local paths where the Chain-of-Zoom checkpoints are expected to live.
LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
DAPE_PATH = "ckpt/DAPE/DAPE.pth"

# One entry per checkpoint: the Hub repository, the file inside that
# repository, and the local path the file is copied to.
CHECKPOINT_FILES_CONFIG = {
    "SR_LoRA": {
        "repo_id": "bryandmc/Chain-of-Zoom",
        "filename": "SR_LoRA/model_20001.pkl",
        "target_path": LORA_PATH,
    },
    "SR_VAE": {
        "repo_id": "bryandmc/Chain-of-Zoom",
        "filename": "SR_VAE/vae_encoder_20001.pt",
        "target_path": VAE_PATH,
    },
    "DAPE": {
        "repo_id": "bryandmc/Chain-of-Zoom",
        "filename": "DAPE/DAPE.pth",
        "target_path": DAPE_PATH,
    },
}

# --- Device Detection ---
# Use the GPU whenever torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")

# --- Hugging Face Token ---
# The token is read from the HF_TOKEN environment variable; a failed login is
# only logged so that startup can continue.
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if not HF_AUTH_TOKEN:
    logger.warning("HF_TOKEN not found. Downloads of gated models may fail.")
else:
    try:
        login(token=HF_AUTH_TOKEN)
        logger.info("Successfully logged in to Hugging Face Hub.")
    except Exception as e:
        logger.warning(f"Hugging Face login failed: {e}")
# --- Model Download Function ---
def download_coz_support_models():
    """Make sure every checkpoint in ``CHECKPOINT_FILES_CONFIG`` exists locally.

    For each entry whose ``target_path`` is missing, the file is fetched with
    ``hf_hub_download`` (using ``HF_AUTH_TOKEN``) and copied from the Hub cache
    to ``target_path``. Entries whose target already exists are skipped.

    Raises:
        Exception: whatever the download or the copy raised; it is logged and
            re-raised unchanged.
    """
    logger.info("Checking and downloading CoZ support models...")
    for model_key, model_info in CHECKPOINT_FILES_CONFIG.items():
        target_file_path = Path(model_info["target_path"])
        if target_file_path.exists():
            logger.info(f"{model_key} already exists at {target_file_path}")
            continue

        logger.info(f"Downloading {model_key} from {model_info['repo_id']}...")
        target_file_path.parent.mkdir(parents=True, exist_ok=True)
        # Copy to a sibling temp name and rename afterwards: an interrupted
        # copy must not leave a truncated file at the final path, because the
        # exists() check above would then accept it on every later start.
        partial_file_path = target_file_path.with_name(target_file_path.name + ".partial")
        try:
            # force_download/resume_download are intentionally not passed:
            # the former was only the default, the latter is deprecated.
            cached_file_path = hf_hub_download(
                repo_id=model_info['repo_id'],
                filename=model_info['filename'],
                token=HF_AUTH_TOKEN,
            )
            shutil.copy(cached_file_path, partial_file_path)
            partial_file_path.replace(target_file_path)
            logger.info(f"{model_key} downloaded to {target_file_path}")
        except Exception as e:
            logger.error(f"Error downloading {model_key}: {e}")
            # Best-effort cleanup of a half-written copy before propagating.
            if partial_file_path.exists():
                partial_file_path.unlink()
            raise
    logger.info("All CoZ support models checked.")

# Download models at startup
# Any failure is logged and re-raised, so the app does not start without
# its checkpoints.
try:
    logger.info("Starting model download...")
    download_coz_support_models()
    logger.info("Model download completed.")
except Exception as e:
    logger.error(f"Failed to download models: {e}")
    raise

# --- Preload Stable Diffusion Model ---
# Loading the pipeline here also populates the local Hugging Face cache.
# NOTE(review): `pipeline` is not referenced again in the visible part of this
# file (inference runs in a subprocess) — confirm whether it needs to stay
# resident in memory.
logger.info("Preloading Stable Diffusion model configuration...")
try:
    from diffusers import StableDiffusion3Pipeline
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        # Half precision only on GPU; CPU runs in float32.
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        # `token` replaces the deprecated `use_auth_token` keyword.
        token=HF_AUTH_TOKEN,
    )
    logger.info("Stable Diffusion 3 model configuration preloaded.")
except Exception as e:
    logger.error(f"Failed to preload Stable Diffusion model: {e}")
    raise

# --- Main Inference Function ---
def _drain_coz_stream(stream, collected_lines, log_line, label):
    """Read ``stream`` line by line until EOF, logging and collecting each line."""
    for line in iter(stream.readline, ""):
        log_line(f"[CoZ {label}] {line.strip()}")
        collected_lines.append(line)


def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str, seed: int):
    """Super-resolve ``input_image`` by running ``inference_coz.py`` in a subprocess.

    The image is saved into a temporary directory, ``inference_coz.py`` is run
    on that directory, and the resulting image is located either from a
    "Saving image to" line on the child's stdout or by globbing the output
    directory.

    Args:
        input_image: Image to enlarge; ``None`` is rejected.
        magnification: Magnification factor; coerced to ``int``.
        caption: Optional caption; passed as ``--caption`` only when it is
            non-blank.
        seed: Random seed; coerced to ``int``.

    Returns:
        The output image, fully loaded into memory.

    Raises:
        gr.Error: if no image is given, ``magnification``/``seed`` are not
            numbers, ``inference_coz.py`` is missing, the subprocess exits
            with a non-zero code, or no output image can be found.
    """
    if input_image is None:
        logger.error("No input image provided.")
        raise gr.Error("Please upload an image.")

    # UI widgets may hand over floats (e.g. 4.0); "4.0" would break both the
    # command-line arguments and the output-file glob below.
    try:
        magnification = int(magnification)
        seed = int(seed)
    except (TypeError, ValueError) as exc:
        logger.error(f"Invalid magnification/seed: {magnification!r}, {seed!r}")
        raise gr.Error("Magnification and seed must be whole numbers.") from exc

    logger.info(f"Starting inference with magnification={magnification}, seed={seed}, caption={caption}")
    with tempfile.TemporaryDirectory() as temp_base_str:
        temp_base_dir = Path(temp_base_str)
        input_img_parent_dir = temp_base_dir / "input_images_root"
        input_img_parent_dir.mkdir(parents=True, exist_ok=True)
        input_image_filename = "source_image.png"
        input_image_path = input_img_parent_dir / input_image_filename
        input_image.save(input_image_path, "PNG")
        logger.info(f"Input image saved to {input_image_path}")

        output_img_dir = temp_base_dir / "output_data"
        output_img_dir.mkdir(parents=True, exist_ok=True)

        # Check if inference_coz.py exists
        if not Path("inference_coz.py").exists():
            logger.error("inference_coz.py not found in repository.")
            raise gr.Error("inference_coz.py not found in repository. Please check the Chain-of-Zoom repository.")

        command = [
            # Run the child with the interpreter (and environment) of this process.
            sys.executable or "python", "inference_coz.py",
            "-i", str(input_img_parent_dir),
            "-o", str(output_img_dir),
            "--rec_type", "recursive_multiscale",
            "--prompt_type", "vlm",
            "--lora_path", LORA_PATH,
            "--vae_path", VAE_PATH,
            "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
            "--ram_ft_path", DAPE_PATH,
            "--efficient_memory",
            "--magnification", str(magnification),
            "--seed", str(seed),
            "--image_num", "16" if DEVICE == "cpu" else "32",
        ]

        if DEVICE == "cpu":
            command.append("--no_cuda")  # Assumes inference_coz.py supports this flag

        if caption and caption.strip():
            command.extend(["--caption", caption.strip()])

        # Keep a separate copy for logging so the token never reaches the logs.
        logged_command = list(command)
        if HF_AUTH_TOKEN:
            command.extend(["--hf_token", HF_AUTH_TOKEN])
            logged_command.extend(["--hf_token", "***"])

        logger.info(f"Running command: {' '.join(logged_command)}")

        stdout_lines = []
        stderr_lines = []
        output_image_path_from_log = None

        with subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        ) as process:
            # stderr is drained on its own thread: reading stdout to EOF first
            # can deadlock once the child fills the stderr pipe buffer.
            stderr_thread = threading.Thread(
                target=_drain_coz_stream,
                args=(process.stderr, stderr_lines, logger.warning, "STDERR"),
                daemon=True,
            )
            stderr_thread.start()

            # Stream stdout
            for line in iter(process.stdout.readline, ""):
                logger.info(f"[CoZ STDOUT] {line.strip()}")
                stdout_lines.append(line)
                if "Saving image to" in line:
                    # The last such line wins.
                    output_image_path_from_log = line.rpartition("Saving image to")[2].strip()

            stderr_thread.join()
            process.wait()

        if process.returncode != 0:
            error_message = f"Chain-of-Zoom failed.\nSTDOUT:\n{''.join(stdout_lines[-5:])}\nSTDERR:\n{''.join(stderr_lines[-5:])}"
            logger.error(error_message)
            raise gr.Error(f"Processing failed: {error_message}")

        # Find output image
        final_output_image_path = None
        if output_image_path_from_log and Path(output_image_path_from_log).exists():
            final_output_image_path = Path(output_image_path_from_log)
        else:
            processed_output_subdir = output_img_dir / input_img_parent_dir.name
            # Sorted so the pick is deterministic when several files match.
            potential_files = sorted(
                processed_output_subdir.glob(f"{Path(input_image_filename).stem}_x{magnification}_*.png")
            )
            if potential_files:
                final_output_image_path = potential_files[0]

        if not final_output_image_path or not final_output_image_path.exists():
            all_files = list(output_img_dir.rglob("*"))
            logger.error(f"Output image not found in {output_img_dir}. Files found: {all_files}")
            raise gr.Error(f"Output image not found in {output_img_dir}. Files found: {all_files}")

        # Image.open is lazy; copy the pixels into memory and close the file
        # before the temporary directory is removed.
        with Image.open(final_output_image_path) as opened_image:
            output_image = opened_image.copy()
        logger.info(f"Output image generated: {final_output_image_path}")
        return output_image

# --- Gradio Interface ---
# Text shown around the demo widgets.
title = "Chain-of-Zoom: Extreme Image Super-Resolution Demo"
description = """
Upload an image and select a magnification factor. Provide an optional caption (if empty, a VLM will generate one).
Optimized for CPU and GPU environments. Ensure HF_TOKEN is set in Space secrets for model access.
[Chain-of-Zoom GitHub](https://github.com/bryanswkim/Chain-of-Zoom)
"""
article = "<p style='text-align: center;'><a href='https://github.com/bryanswkim/Chain-of-Zoom' target='_blank'>Chain-of-Zoom GitHub</a></p>"

# Custom stylesheet passed to gr.Blocks.
css = """
.gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
"""

logger.info("Initializing Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(f"<h1 style='text-align: center'>{title}</h1>")
    gr.Markdown(description)

    with gr.Row():
        # First column: inputs and the trigger button.
        with gr.Column(scale=1):
            input_image_gr = gr.Image(label="Input Image", type="pil")
            magnification_gr = gr.Slider(
                label="Magnification Factor (2x-16x)",
                minimum=2,
                maximum=16,
                step=1,
                value=4,
            )
            caption_gr = gr.Textbox(
                label="Optional Caption",
                placeholder="e.g., a photo of a cat",
            )
            seed_gr = gr.Number(label="Seed", value=42, precision=0)
            run_button = gr.Button("Zoom In!", variant="primary")
        # Second column: the result.
        with gr.Column(scale=1):
            output_image_gr = gr.Image(
                label="Output Super-Resolved Image",
                type="pil",
            )

    gr.Markdown(article)

    # The button feeds the four inputs to run_chain_of_zoom and shows the
    # returned image in the output component.
    run_button.click(
        fn=run_chain_of_zoom,
        inputs=[
            input_image_gr,
            magnification_gr,
            caption_gr,
            seed_gr,
        ],
        outputs=output_image_gr,
    )

if __name__ == "__main__":
    logger.info("Launching Gradio app...")
    try:
        # NOTE(review): launch() normally blocks in script mode, so the
        # success message below is likely only logged once the server
        # stops — confirm.
        demo.launch(server_port=7860, server_name="0.0.0.0")
        logger.info("Gradio app launched successfully.")
    except Exception as e:
        logger.error(f"Failed to launch Gradio app: {e}")
        raise