| import argparse |
| import os |
| import yaml |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import torchvision |
| import utils |
| from models.unet import DiffusionUNet |
| import torchdiffeq |
| import math |
| from torchvision.transforms.functional import crop |
|
|
|
|
def dict2namespace(config):
    """Recursively convert a nested dict into an ``argparse.Namespace``.

    Nested dicts become nested namespaces, so ``cfg["a"]["b"]`` is
    reachable as ``ns.a.b``. Non-dict values are attached unchanged.
    """
    return argparse.Namespace(
        **{
            key: dict2namespace(value) if isinstance(value, dict) else value
            for key, value in config.items()
        }
    )
|
|
|
|
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Return a 1-D float64 numpy array of per-step betas.

    Supported schedules: "quad" (linear in sqrt-beta, then squared),
    "linear", "const" (all beta_end), "jsd" (1/N, 1/(N-1), ..., 1), and
    "sigmoid" (sigmoid ramp from beta_start to beta_end).

    Raises NotImplementedError for any other schedule name.
    """
    if beta_schedule == "quad":
        sqrt_betas = np.linspace(
            beta_start**0.5,
            beta_end**0.5,
            num_diffusion_timesteps,
            dtype=np.float64,
        )
        return sqrt_betas**2
    if beta_schedule == "linear":
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    if beta_schedule == "const":
        return beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    if beta_schedule == "jsd":
        return 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    if beta_schedule == "sigmoid":
        grid = np.linspace(-6, 6, num_diffusion_timesteps)
        # sigmoid(x) = 1 / (1 + exp(-x)), rescaled onto [beta_start, beta_end]
        return (1.0 / (1.0 + np.exp(-grid))) * (beta_end - beta_start) + beta_start
    raise NotImplementedError(beta_schedule)
|
|
|
|
class VPDiffusionFlow:
    """Continuous-time velocity field wrapping a pretrained DiffusionUNet.

    In ``flow_mode == "reflow"`` the network output is used directly as the
    velocity; in ``flow_mode == "vp"`` (default) it is treated as a noise
    (epsilon) prediction and converted to a VP-style velocity
    ``v = -0.5*beta(t)*x + 0.5*beta(t)/sqrt(1 - alpha_bar)*epsilon``.
    Large inputs are handled by evaluating the network on overlapping
    patches and blending the results with a 2-D Hann window.
    """

    def __init__(self, args, config):
        """Build the UNet and the discrete/continuous beta schedules.

        Args:
            args: CLI namespace; only ``flow_mode`` is read (defaults to "vp").
            config: parsed YAML config; must provide ``device`` and the
                ``diffusion.*`` schedule parameters.
        """
        self.args = args
        # Fall back to "vp" when args carries no flow_mode attribute.
        self.flow_mode = getattr(args, "flow_mode", "vp")
        self.config = config
        self.device = config.device

        self.model = DiffusionUNet(config)
        self.model.to(self.device)

        # Discrete DDPM-style schedule with N = num_diffusion_timesteps steps.
        self.num_timesteps = config.diffusion.num_diffusion_timesteps
        betas = get_beta_schedule(
            beta_schedule=config.diffusion.beta_schedule,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=self.num_timesteps,
        )
        self.betas = torch.from_numpy(betas).float().to(self.device)

        # Endpoints kept separately for the continuous linear interpolation
        # used by get_beta_t / get_alpha_bar_t.
        self.beta_start = config.diffusion.beta_start
        self.beta_end = config.diffusion.beta_end

        # alphas_cumprod[k] = prod_{i<=k} (1 - beta_i), indexed by discrete step.
        alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)

    def load_ckpt(self, load_path):
        """Load model weights from ``load_path`` and switch to eval mode.

        Accepts either a raw state dict or a checkpoint dict containing a
        "state_dict" entry. A leading "module." is stripped from every key
        (presumably left by nn.DataParallel wrapping — TODO confirm).
        Loading is strict, so any key mismatch raises.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoint files.
        checkpoint = torch.load(load_path, map_location=self.device)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        # Strip the "module." prefix (7 characters) where present.
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                new_state_dict[k[7:]] = v
            else:
                new_state_dict[k] = v
        state_dict = new_state_dict

        self.model.load_state_dict(state_dict, strict=True)
        print(f"=> loaded checkpoint '{load_path}'")
        self.model.eval()

    def get_beta_t(self, t):
        """Linearly interpolate beta over continuous time t, clamped to [0, 1].

        Returns a Python float: beta_start + t * (beta_end - beta_start).
        """
        scalar_t = t.item() if isinstance(t, torch.Tensor) else t
        scalar_t = max(0.0, min(1.0, scalar_t))
        return self.beta_start + scalar_t * (self.beta_end - self.beta_start)

    def get_alpha_bar_t(self, t):
        """Analytic alpha_bar(t) = exp(-N * int_0^t beta(s) ds) as a float.

        The closed form below integrates the *linear* beta interpolation of
        get_beta_t; it does not match non-linear schedules ("quad",
        "sigmoid", ...) — NOTE(review): confirm this is only used with the
        linear schedule.
        """
        scalar_t = t.item() if isinstance(t, torch.Tensor) else t
        scalar_t = max(0.0, min(1.0, scalar_t))

        N = self.num_timesteps

        # int_0^t (b0 + (b1 - b0) s) ds = b0*t + 0.5*(b1 - b0)*t^2
        b0 = self.beta_start
        b1 = self.beta_end
        integral = N * (b0 * scalar_t + 0.5 * (b1 - b0) * scalar_t**2)
        return math.exp(-integral)

    def overlapping_grid_indices(self, x_cond, output_size, r=None):
        """Top-left (row, col) start indices of overlapping patches.

        Patches of side ``output_size`` are placed every ``r`` pixels
        (default stride 16) so that each stays fully inside the image.
        Returns the row list and column list separately.
        """
        _, c, h, w = x_cond.shape
        r = 16 if r is None else r
        h_list = [i for i in range(0, h - output_size + 1, r)]
        w_list = [i for i in range(0, w - output_size + 1, r)]
        return h_list, w_list

    def get_blending_window(self, patch_size):
        """2-D Hann blending window of shape (1, 1, P, P) on self.device.

        Outer product of two 1-D symmetric Hann windows; the window is zero
        at the patch borders, so overlapping patches are feathered together.
        """
        w = torch.hann_window(patch_size, periodic=False, device=self.device)
        w2d = w.unsqueeze(0) * w.unsqueeze(1)
        return w2d.view(1, 1, patch_size, patch_size)

    def get_velocity(self, x, t, x_cond, patch_size=None, r_stride=16):
        """Velocity dx/dt at continuous time ``t`` for state ``x``.

        If ``x`` already matches ``patch_size`` (or no patch size is given),
        the network is run once via _get_velocity_single. Otherwise the
        padded image is processed in overlapping ``patch_size`` tiles with
        stride ``r_stride``, Hann-blended, and the merged network output is
        converted to a velocity according to self.flow_mode.
        """
        if patch_size is None or (
            x.shape[2] == patch_size and x.shape[3] == patch_size
        ):
            return self._get_velocity_single(x, t, x_cond)

        # Map continuous t in [0, 1] to a discrete timestep index for the
        # network's timestep embedding.
        N = self.num_timesteps
        t_idx = min(int(t * N), N - 1)
        t_input_scalar = t_idx

        # Reflect-pad by half a patch so border pixels get full patch
        # coverage; the pad is cropped off again after merging.
        pad_size = patch_size // 2
        x_padded = torch.nn.functional.pad(
            x, (pad_size, pad_size, pad_size, pad_size), mode="reflect"
        )
        x_cond_padded = torch.nn.functional.pad(
            x_cond, (pad_size, pad_size, pad_size, pad_size), mode="reflect"
        )

        h_list, w_list = self.overlapping_grid_indices(x_padded, patch_size, r=r_stride)
        corners = [(i, j) for i in h_list for j in w_list]

        window = self.get_blending_window(patch_size)

        # Per-pixel sum of window weights, used to normalize the blend.
        x_grid_mask = torch.zeros_like(x_padded, device=self.device)
        for hi, wi in corners:
            x_grid_mask[:, :, hi : hi + patch_size, wi : wi + patch_size] += window

        output_accum = torch.zeros_like(x_padded, device=self.device)

        # How many patches to push through the network per forward pass.
        batch_size = 64

        # VP conversion constants depend only on t, so compute them once.
        # (Only defined for "vp"; the "reflow" branch below never reads them.)
        if self.flow_mode == "vp":
            beta_discrete = self.get_beta_t(t)
            beta_cont = beta_discrete * N
            ab = self.alphas_cumprod[t_idx]

        for i in range(0, len(corners), batch_size):
            batch_corners = corners[i : i + batch_size]

            # Stack the state and conditioning crops along the batch dim.
            x_batch = torch.cat(
                [
                    crop(x_padded, hi, wi, patch_size, patch_size)
                    for (hi, wi) in batch_corners
                ],
                dim=0,
            )
            cond_batch = torch.cat(
                [
                    crop(x_cond_padded, hi, wi, patch_size, patch_size)
                    for (hi, wi) in batch_corners
                ],
                dim=0,
            )
            t_batch = torch.tensor(
                [t_input_scalar] * x_batch.size(0), device=self.device
            )

            with torch.no_grad():
                model_output = self.model(
                    torch.cat([cond_batch, x_batch], dim=1), t_batch
                )

            # Feather each patch output with the Hann window before
            # scattering it back into the accumulator.
            weighted_output = model_output * window

            # NOTE(review): accumulates into batch index 0 only — this
            # assumes x has batch size 1; confirm callers never pass B > 1.
            for idx, (hi, wi) in enumerate(batch_corners):
                output_accum[0, :, hi : hi + patch_size, wi : wi + patch_size] += (
                    weighted_output[idx]
                )

        # Normalize by the summed window weights (epsilon guards pixels the
        # zero-edged Hann window barely covers).
        model_output_full = torch.div(output_accum, x_grid_mask + 1e-8)

        # Crop away the reflect padding to restore the original spatial size.
        if pad_size > 0:
            model_output_full = model_output_full[
                :, :, pad_size:-pad_size, pad_size:-pad_size
            ]

        if self.flow_mode == "reflow":
            # Network output already is the velocity.
            v = model_output_full
        else:
            # Treat the output as epsilon and convert:
            # v = -0.5*beta*x + 0.5*beta/sqrt(1 - alpha_bar)*epsilon
            epsilon = model_output_full
            coeff1 = -0.5 * beta_cont
            coeff2 = 0.5 * beta_cont / torch.sqrt(1 - ab)
            v = coeff1 * x + coeff2 * epsilon

        return v

    def _get_velocity_single(self, x, t, x_cond):
        """Single-pass velocity for inputs that already fit the network.

        Same time-index mapping and flow_mode conversion as get_velocity,
        without the patching/blending machinery.
        """
        # Continuous t in [0, 1] -> discrete timestep index for the network.
        N = self.num_timesteps
        t_idx = min(int(t * N), N - 1)
        t_input = torch.tensor([t_idx] * x.size(0), device=self.device)

        with torch.no_grad():
            model_output = self.model(torch.cat([x_cond, x], dim=1), t_input)

        if self.flow_mode == "reflow":
            # Network output is used directly as the velocity.
            return model_output
        else:
            # Epsilon-prediction -> VP velocity conversion (see get_velocity).
            epsilon = model_output
            beta_discrete = self.get_beta_t(t)
            beta_cont = beta_discrete * N
            ab = self.alphas_cumprod[t_idx]

            coeff1 = -0.5 * beta_cont
            coeff2 = 0.5 * beta_cont / torch.sqrt(1 - ab)

            v = coeff1 * x + coeff2 * epsilon
            return v
|
|
|
|
def ode_solve(
    flow_model,
    x_init,
    x_cond,
    steps=100,
    method="dopri5",
    patch_size=64,
    atol=1e-4,
    rtol=1e-4,
):
    """Integrate the flow ODE from t=1 down to t=0 with torchdiffeq.

    ``flow_model.get_velocity`` provides the drift. ``steps`` only sets the
    number of evaluation points handed to the solver; ``method``, ``atol``
    and ``rtol`` configure the solver itself. Returns the state at t=0.
    """
    print(f"ODE Solve: Method={method}, Steps={steps}, atol={atol}, rtol={rtol}")

    # Mutable cell so the closure can count drift evaluations for logging.
    n_calls = [0]

    def drift_func(t, x):
        n_calls[0] += 1
        print(f"Step {n_calls[0]}, t={t.item():.6f}")
        return flow_model.get_velocity(x, t, x_cond, patch_size=patch_size)

    # Time runs backwards: noise at t=1, restored image at t=0.
    time_grid = torch.linspace(1.0, 0.0, steps + 1, device=x_init.device)
    trajectory = torchdiffeq.odeint(
        drift_func, x_init, time_grid, method=method, atol=atol, rtol=rtol
    )
    return trajectory[-1]
|
|
|
|
def main():
    """CLI entry point: ODE-based restoration over a validation set.

    Loads a YAML config and a pretrained checkpoint, builds the
    VPDiffusionFlow, then for each validation image integrates the flow
    ODE from Gaussian noise conditioned on the degraded input and saves
    the input/output image pair under --output.

    Fixes vs. original: single torch.device construction, tuple-form
    isinstance check, and the unused enumerate index dropped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, required=True)
    parser.add_argument(
        "--data_dir", type=str, default=None, help="Override data_dir in config"
    )
    parser.add_argument(
        "--dataset", type=str, default=None, help="Override dataset name"
    )
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--output", type=str, default="results/diff2flow")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument(
        "--patch_size", type=int, default=64, help="Patch size for model"
    )
    parser.add_argument(
        "--method", type=str, default="dopri5", help="ODE solver method"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-4, help="Absolute tolerance for ODE solver"
    )
    parser.add_argument(
        "--rtol", type=float, default=1e-4, help="Relative tolerance for ODE solver"
    )
    parser.add_argument(
        "--flow_mode",
        type=str,
        default="vp",
        choices=["vp", "reflow"],
        help="Flow mode: vp (default) or reflow",
    )
    args = parser.parse_args()

    # Config is resolved relative to the local "configs" directory.
    with open(os.path.join("configs", args.config), "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)

    # Optional CLI overrides for dataset location/name.
    if args.data_dir:
        config.data.data_dir = args.data_dir
    if args.dataset:
        config.data.dataset = args.dataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config.device = device

    # Seed CPU torch and numpy for reproducible noise initialization.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print("Initializing VPDiffusionFlow...")
    flow = VPDiffusionFlow(args, config)
    flow.load_ckpt(args.resume)

    os.makedirs(args.output, exist_ok=True)

    # Imported lazily so dataset dependencies are only needed at run time.
    import datasets

    print(f"Loading dataset {config.data.dataset}...")
    DATASET = datasets.__dict__[config.data.dataset](config)

    # NOTE(review): `validation` receives the dataset *name* when --dataset
    # is given, otherwise the literal "raindrop" — confirm this matches the
    # get_loaders() contract.
    _, val_loader = DATASET.get_loaders(
        parse_patches=False,
        validation=config.data.dataset if args.dataset else "raindrop",
    )

    for x_batch, img_id in val_loader:
        print(f"Processing image {img_id}...")

        x_batch = x_batch.to(device)

        # First 3 channels hold the degraded (conditioning) image.
        x_cond = x_batch[:, :3, :, :]

        # Map pixel values into the model's working range.
        x_cond = utils.sampling.data_transform(x_cond)

        # Start the reverse-time ODE from pure Gaussian noise.
        B, C, H, W = x_cond.shape
        x_init = torch.randn(B, 3, H, W, device=device)

        print(f"Starting flow matching inference for image {img_id}, shape {H}x{W}...")
        x_pred = ode_solve(
            flow,
            x_init,
            x_cond,
            steps=args.steps,
            patch_size=args.patch_size,
            method=args.method,
            atol=args.atol,
            rtol=args.rtol,
        )

        # Back to image range for saving.
        x_pred = utils.sampling.inverse_data_transform(x_pred)
        x_cond_img = utils.sampling.inverse_data_transform(x_cond)

        # Some loaders wrap the id in a tuple/list; unwrap the first entry.
        if isinstance(img_id, (tuple, list)):
            idx = img_id[0]
        else:
            idx = img_id

        utils.logging.save_image(
            x_cond_img[0], os.path.join(args.output, f"{idx}_input.png")
        )
        utils.logging.save_image(
            x_pred[0], os.path.join(args.output, f"{idx}_flow.png")
        )

    print("Done.")
|
|
|
|
# Script entry point: run inference when executed directly (not on import).
if __name__ == "__main__":
    main()
|
|