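# NOTE: this file was copied from a Hugging Face Spaces web page; the page
# banner text ("Spaces: Paused") was captured with it and is kept here as a comment.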
import os | |
import gc | |
import copy | |
import lpips | |
import torch | |
import wandb | |
from glob import glob | |
import numpy as np | |
from accelerate import Accelerator | |
from accelerate.utils import set_seed | |
from PIL import Image | |
from torchvision import transforms | |
from tqdm.auto import tqdm | |
from transformers import AutoTokenizer, CLIPTextModel | |
from diffusers.optimization import get_scheduler | |
from peft.utils import get_peft_model_state_dict | |
from cleanfid.fid import get_folder_features, build_feature_extractor, frechet_distance | |
import vision_aided_loss | |
from model import make_1step_sched | |
from cyclegan_turbo import CycleGAN_Turbo, VAE_encode, VAE_decode, initialize_unet, initialize_vae | |
from my_utils.training_utils import UnpairedDataset, build_transform, parse_args_unpaired_training | |
from my_utils.dino_struct import DinoStructureLoss | |
# Number of optimizer.step()/lr_scheduler.step() calls made per training step:
# the generator is stepped after the cycle, GAN and identity objectives, the
# discriminators after the fake and the real objectives. The LR schedules are
# stretched by these factors so warmup/decay lengths are counted in training steps.
_GEN_UPDATES_PER_STEP = 3
_DISC_UPDATES_PER_STEP = 2
# Glob patterns used to collect the test_A / test_B evaluation images.
_TEST_IMAGE_PATTERNS = ("*.jpg", "*.jpeg", "*.png", "*.bmp")


def _list_test_images(dataset_folder, split):
    """Return the sorted paths of all images matching _TEST_IMAGE_PATTERNS in ``dataset_folder/split``."""
    paths = []
    for pattern in _TEST_IMAGE_PATTERNS:
        paths.extend(glob(os.path.join(dataset_folder, split, pattern)))
    return sorted(paths)


def _is_due(step, every):
    """Return True on steps 1, 1 + every, 1 + 2*every, ...

    Same schedule as ``step % every == 1`` for ``every > 1``, but also fires on
    every step when ``every == 1`` (``step % 1 == 1`` is never true).
    """
    return (step - 1) % every == 0


def _folder_fid_stats(folder, feat_model):
    """Return (mean, covariance) of clean-fid features for every image in ``folder``."""
    features = get_folder_features(folder, model=feat_model, num_workers=0, num=None,
                                   shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
                                   mode="clean", custom_fn_resize=None, description="", verbose=True,
                                   custom_image_tranform=None)
    return np.mean(features, axis=0), np.cov(features, rowvar=False)


def _reference_fid_stats(image_paths, output_dir_ref, T_val, feat_model):
    """Save ``T_val``-transformed copies of ``image_paths`` into ``output_dir_ref`` and return their FID stats.

    Files that already exist in ``output_dir_ref`` are reused, not rewritten.
    """
    os.makedirs(output_dir_ref, exist_ok=True)
    for _path in tqdm(image_paths):
        # Name from the basename only (replacing ".jpg" in the full path could
        # rewrite directory names) and always store lossless PNG so reference
        # images are not re-compressed.
        stem = os.path.splitext(os.path.basename(_path))[0]
        outf = os.path.join(output_dir_ref, stem + ".png")
        if not os.path.exists(outf):
            with Image.open(_path) as raw:
                _img = T_val(raw.convert("RGB"))
            _img.save(outf)
    return _folder_fid_stats(output_dir_ref, feat_model)


def _evaluate_direction(image_paths, direction, caption_emb, fid_output_dir, ref_mu, ref_sigma,
                        T_val, feat_model, net_dino, vae_enc, unet, vae_dec,
                        noise_scheduler_1step, timesteps, max_images):
    """Translate test images in ``direction``, save them, and return (FID, mean DINO-struct score).

    Args:
        image_paths: input images from the source domain of ``direction``.
        direction: "a2b" or "b2a", forwarded to ``CycleGAN_Turbo.forward_with_networks``.
        caption_emb: fixed text embedding for the target domain (batch of 1).
        fid_output_dir: folder that receives ``{idx}.png`` outputs.
        ref_mu, ref_sigma: reference FID statistics of the target domain.
        max_images: evaluate at most this many images; <= 0 means all.
    """
    os.makedirs(fid_output_dir, exist_ok=True)
    dino_scores = []
    with torch.no_grad():
        for idx, input_img_path in enumerate(tqdm(image_paths)):
            # ">=" so exactly max_images images are evaluated (">" did one extra).
            if max_images > 0 and idx >= max_images:
                break
            with Image.open(input_img_path) as raw:
                input_img = T_val(raw.convert("RGB"))
            x = transforms.ToTensor()(input_img)
            x = transforms.Normalize([0.5], [0.5])(x).unsqueeze(0).cuda()
            out = CycleGAN_Turbo.forward_with_networks(x, direction, vae_enc, unet, vae_dec,
                                                       noise_scheduler_1step, timesteps, caption_emb)
            out_pil = transforms.ToPILImage()(out[0] * 0.5 + 0.5)
            out_pil.save(os.path.join(fid_output_dir, f"{idx}.png"))
            a = net_dino.preprocess(input_img).unsqueeze(0).cuda()
            b = net_dino.preprocess(out_pil).unsqueeze(0).cuda()
            dino_scores.append(net_dino.calculate_global_ssim_loss(a, b).item())
    gen_mu, gen_sigma = _folder_fid_stats(fid_output_dir, feat_model)
    return frechet_distance(ref_mu, ref_sigma, gen_mu, gen_sigma), np.mean(dino_scores)


def _save_checkpoint(outf, args, eval_unet, eval_vae_enc, eval_vae_dec,
                     l_modules_unet_encoder, l_modules_unet_decoder, l_modules_unet_others,
                     vae_lora_target_modules):
    """Save the UNet LoRA adapter weights, LoRA config and VAE encoder/decoder state dicts to ``outf``."""
    sd = {
        "l_target_modules_encoder": l_modules_unet_encoder,
        "l_target_modules_decoder": l_modules_unet_decoder,
        "l_modules_others": l_modules_unet_others,
        "rank_unet": args.lora_rank_unet,
        "sd_encoder": get_peft_model_state_dict(eval_unet, adapter_name="default_encoder"),
        "sd_decoder": get_peft_model_state_dict(eval_unet, adapter_name="default_decoder"),
        "sd_other": get_peft_model_state_dict(eval_unet, adapter_name="default_others"),
        "rank_vae": args.lora_rank_vae,
        "vae_lora_target_modules": vae_lora_target_modules,
        "sd_vae_enc": eval_vae_enc.state_dict(),
        "sd_vae_dec": eval_vae_dec.state_dict(),
    }
    torch.save(sd, outf)


def main(args):
    """Train CycleGAN-Turbo (SD-Turbo UNet + VAE with LoRA) on an unpaired A/B dataset.

    Each training step updates the generator on three objectives (cycle
    consistency, adversarial, identity), then the two vision-aided CLIP
    discriminators on fake and on real images. On the main process, every
    ``viz_freq`` / ``checkpointing_steps`` / ``validation_steps`` steps
    (starting at step 1) it logs images to wandb, saves a checkpoint, and
    computes FID and DINO-structure scores on ``test_A`` / ``test_B``.

    Args:
        args: namespace produced by ``parse_args_unpaired_training()``.

    Raises:
        NotImplementedError: if ``args.gan_disc_type`` is not ``"vagan_clip"``.
    """
    # Only the vision-aided CLIP discriminator is implemented; any other value
    # used to fail much later with a NameError on net_disc_a.
    if args.gan_disc_type != "vagan_clip":
        raise NotImplementedError(f"Unsupported gan_disc_type {args.gan_disc_type!r}; only 'vagan_clip' is implemented")

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with=args.report_to)
    set_seed(args.seed)
    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer", revision=args.revision, use_fast=False,)
    noise_scheduler_1step = make_1step_sched()
    text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()

    unet, l_modules_unet_encoder, l_modules_unet_decoder, l_modules_unet_others = initialize_unet(args.lora_rank_unet, return_lora_module_names=True)
    vae_a2b, vae_lora_target_modules = initialize_vae(args.lora_rank_vae, return_lora_module_names=True)

    weight_dtype = torch.float32
    vae_a2b.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.requires_grad_(False)

    net_disc_a = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device="cuda")
    net_disc_a.cv_ensemble.requires_grad_(False)  # Freeze feature extractor
    net_disc_b = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device="cuda")
    net_disc_b.cv_ensemble.requires_grad_(False)  # Freeze feature extractor

    crit_cycle, crit_idt = torch.nn.L1Loss(), torch.nn.L1Loss()

    if args.enable_xformers_memory_efficient_attention:
        unet.enable_xformers_memory_efficient_attention()
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    unet.conv_in.requires_grad_(True)
    # Separate VAE copy for the B -> A direction.
    vae_b2a = copy.deepcopy(vae_a2b)
    params_gen = CycleGAN_Turbo.get_traininable_params(unet, vae_a2b, vae_b2a)
    vae_enc = VAE_encode(vae_a2b, vae_b2a=vae_b2a)
    vae_dec = VAE_decode(vae_a2b, vae_b2a=vae_b2a)

    optimizer_gen = torch.optim.AdamW(params_gen, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),
                                      weight_decay=args.adam_weight_decay, eps=args.adam_epsilon,)
    params_disc = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
    optimizer_disc = torch.optim.AdamW(params_disc, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),
                                       weight_decay=args.adam_weight_decay, eps=args.adam_epsilon,)

    dataset_train = UnpairedDataset(dataset_folder=args.dataset_folder, image_prep=args.train_img_prep, split="train", tokenizer=tokenizer)
    train_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
    T_val = build_transform(args.val_img_prep)
    fixed_caption_src = dataset_train.fixed_caption_src
    fixed_caption_tgt = dataset_train.fixed_caption_tgt
    l_images_src_test = _list_test_images(args.dataset_folder, "test_A")
    l_images_tgt_test = _list_test_images(args.dataset_folder, "test_B")

    # Reference FID statistics (main process only; used only there).
    # A -> B outputs are compared against real test_B images, B -> A against test_A.
    if accelerator.is_main_process:
        feat_model = build_feature_extractor("clean", "cuda", use_dataparallel=False)
        a2b_ref_mu, a2b_ref_sigma = _reference_fid_stats(
            l_images_tgt_test, os.path.join(args.output_dir, "fid_reference_a2b"), T_val, feat_model)
        b2a_ref_mu, b2a_ref_sigma = _reference_fid_stats(
            l_images_src_test, os.path.join(args.output_dir, "fid_reference_b2a"), T_val, feat_model)

    # Each scheduler is stepped once per optimizer update, i.e. several times per
    # training step; scale the step counts so schedules span the whole run.
    lr_scheduler_gen = get_scheduler(args.lr_scheduler, optimizer=optimizer_gen,
                                     num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes * _GEN_UPDATES_PER_STEP,
                                     num_training_steps=args.max_train_steps * accelerator.num_processes * _GEN_UPDATES_PER_STEP,
                                     num_cycles=args.lr_num_cycles, power=args.lr_power)
    lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
                                      num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes * _DISC_UPDATES_PER_STEP,
                                      num_training_steps=args.max_train_steps * accelerator.num_processes * _DISC_UPDATES_PER_STEP,
                                      num_cycles=args.lr_num_cycles, power=args.lr_power)

    net_lpips = lpips.LPIPS(net='vgg')
    net_lpips.cuda()
    net_lpips.requires_grad_(False)

    # Fixed text embeddings: target caption for A -> B, source caption for B -> A.
    fixed_a2b_tokens = tokenizer(fixed_caption_tgt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
    fixed_a2b_emb_base = text_encoder(fixed_a2b_tokens.cuda().unsqueeze(0))[0].detach()
    fixed_b2a_tokens = tokenizer(fixed_caption_src, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
    fixed_b2a_emb_base = text_encoder(fixed_b2a_tokens.cuda().unsqueeze(0))[0].detach()
    del text_encoder, tokenizer  # free up some memory

    unet, vae_enc, vae_dec, net_disc_a, net_disc_b = accelerator.prepare(unet, vae_enc, vae_dec, net_disc_a, net_disc_b)
    net_lpips, optimizer_gen, optimizer_disc, train_dataloader, lr_scheduler_gen, lr_scheduler_disc = accelerator.prepare(
        net_lpips, optimizer_gen, optimizer_disc, train_dataloader, lr_scheduler_gen, lr_scheduler_disc
    )
    if accelerator.is_main_process:
        accelerator.init_trackers(args.tracker_project_name, config=dict(vars(args)))

    first_epoch = 0
    global_step = 0
    progress_bar = tqdm(range(0, args.max_train_steps), initial=global_step, desc="Steps",
                        disable=not accelerator.is_local_main_process,)

    # turn off eff. attn for the disc
    for net_disc in (net_disc_a, net_disc_b):
        for name, module in net_disc.named_modules():
            if "attn" in name:
                module.fused_attn = False

    l_acc = [unet, net_disc_a, net_disc_b, vae_enc, vae_dec]
    for _epoch in range(first_epoch, args.max_train_epochs):
        for _step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(*l_acc):
                img_a = batch["pixel_values_src"].to(dtype=weight_dtype)
                img_b = batch["pixel_values_tgt"].to(dtype=weight_dtype)
                bsz = img_a.shape[0]
                fixed_a2b_emb = fixed_a2b_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
                fixed_b2a_emb = fixed_b2a_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
                timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * bsz, device=img_a.device).long()

                # Cycle objective: A -> fake B -> rec A and B -> fake A -> rec B.
                cyc_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
                cyc_rec_a = CycleGAN_Turbo.forward_with_networks(cyc_fake_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
                loss_cycle_a = crit_cycle(cyc_rec_a, img_a) * args.lambda_cycle
                loss_cycle_a += net_lpips(cyc_rec_a, img_a).mean() * args.lambda_cycle_lpips
                cyc_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
                cyc_rec_b = CycleGAN_Turbo.forward_with_networks(cyc_fake_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
                loss_cycle_b = crit_cycle(cyc_rec_b, img_b) * args.lambda_cycle
                loss_cycle_b += net_lpips(cyc_rec_b, img_b).mean() * args.lambda_cycle_lpips
                accelerator.backward(loss_cycle_a + loss_cycle_b, retain_graph=False)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
                optimizer_gen.step()
                lr_scheduler_gen.step()
                optimizer_gen.zero_grad()

                # Generator adversarial objective for a->b and b->a.
                fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
                fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
                loss_gan_a = net_disc_a(fake_b, for_G=True).mean() * args.lambda_gan
                loss_gan_b = net_disc_b(fake_a, for_G=True).mean() * args.lambda_gan
                accelerator.backward(loss_gan_a + loss_gan_b, retain_graph=False)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
                optimizer_gen.step()
                lr_scheduler_gen.step()
                optimizer_gen.zero_grad()
                # The backward pass above also deposited generator-loss gradients
                # in the discriminator parameters; clear them so the
                # discriminator update below follows only its own losses.
                optimizer_disc.zero_grad()

                # Identity objective: translating an image already in the target domain should not change it.
                idt_a = CycleGAN_Turbo.forward_with_networks(img_b, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
                loss_idt_a = crit_idt(idt_a, img_b) * args.lambda_idt
                loss_idt_a += net_lpips(idt_a, img_b).mean() * args.lambda_idt_lpips
                idt_b = CycleGAN_Turbo.forward_with_networks(img_a, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
                loss_idt_b = crit_idt(idt_b, img_a) * args.lambda_idt
                loss_idt_b += net_lpips(idt_b, img_a).mean() * args.lambda_idt_lpips
                loss_g_idt = loss_idt_a + loss_idt_b
                accelerator.backward(loss_g_idt, retain_graph=False)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
                optimizer_gen.step()
                lr_scheduler_gen.step()
                optimizer_gen.zero_grad()

                # Discriminators on fake inputs (detached: no generator gradients).
                loss_D_A_fake = net_disc_a(fake_b.detach(), for_real=False).mean() * args.lambda_gan
                loss_D_B_fake = net_disc_b(fake_a.detach(), for_real=False).mean() * args.lambda_gan
                loss_D_fake = (loss_D_A_fake + loss_D_B_fake) * 0.5
                accelerator.backward(loss_D_fake, retain_graph=False)
                if accelerator.sync_gradients:
                    params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer_disc.step()
                lr_scheduler_disc.step()
                optimizer_disc.zero_grad()

                # Discriminators on real inputs.
                loss_D_A_real = net_disc_a(img_b, for_real=True).mean() * args.lambda_gan
                loss_D_B_real = net_disc_b(img_a, for_real=True).mean() * args.lambda_gan
                loss_D_real = (loss_D_A_real + loss_D_B_real) * 0.5
                accelerator.backward(loss_D_real, retain_graph=False)
                if accelerator.sync_gradients:
                    params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer_disc.step()
                lr_scheduler_disc.step()
                optimizer_disc.zero_grad()

            logs = {
                "cycle_a": loss_cycle_a.detach().item(),
                "cycle_b": loss_cycle_b.detach().item(),
                "gan_a": loss_gan_a.detach().item(),
                "gan_b": loss_gan_b.detach().item(),
                "disc_a": loss_D_A_fake.detach().item() + loss_D_A_real.detach().item(),
                "disc_b": loss_D_B_fake.detach().item() + loss_D_B_real.detach().item(),
                "idt_a": loss_idt_a.detach().item(),
                "idt_b": loss_idt_b.detach().item(),
            }

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if accelerator.is_main_process:
                    eval_unet = accelerator.unwrap_model(unet)
                    eval_vae_enc = accelerator.unwrap_model(vae_enc)
                    eval_vae_dec = accelerator.unwrap_model(vae_dec)

                    if _is_due(global_step, args.viz_freq):
                        for tracker in accelerator.trackers:
                            if tracker.name == "wandb":
                                viz_img_a = batch["pixel_values_src"].to(dtype=weight_dtype)
                                viz_img_b = batch["pixel_values_tgt"].to(dtype=weight_dtype)
                                log_dict = {
                                    "train/real_a": [wandb.Image(viz_img_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)],
                                    "train/real_b": [wandb.Image(viz_img_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)],
                                }
                                log_dict["train/rec_a"] = [wandb.Image(cyc_rec_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
                                log_dict["train/rec_b"] = [wandb.Image(cyc_rec_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
                                log_dict["train/fake_b"] = [wandb.Image(fake_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
                                log_dict["train/fake_a"] = [wandb.Image(fake_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
                                tracker.log(log_dict)
                        gc.collect()
                        torch.cuda.empty_cache()

                    if _is_due(global_step, args.checkpointing_steps):
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        _save_checkpoint(outf, args, eval_unet, eval_vae_enc, eval_vae_dec,
                                         l_modules_unet_encoder, l_modules_unet_decoder, l_modules_unet_others,
                                         vae_lora_target_modules)
                        gc.collect()
                        torch.cuda.empty_cache()

                    # compute val FID and DINO-Struct scores
                    if _is_due(global_step, args.validation_steps):
                        _timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * 1, device="cuda").long()
                        net_dino = DinoStructureLoss()
                        score_fid_a2b, dino_score_a2b = _evaluate_direction(
                            l_images_src_test, "a2b", fixed_a2b_emb_base.to(dtype=weight_dtype),
                            os.path.join(args.output_dir, f"fid-{global_step}/samples_a2b"),
                            a2b_ref_mu, a2b_ref_sigma, T_val, feat_model, net_dino,
                            eval_vae_enc, eval_unet, eval_vae_dec, noise_scheduler_1step, _timesteps,
                            args.validation_num_images)
                        print(f"step={global_step}, fid(a2b)={score_fid_a2b:.2f}, dino(a2b)={dino_score_a2b:.3f}")
                        score_fid_b2a, dino_score_b2a = _evaluate_direction(
                            l_images_tgt_test, "b2a", fixed_b2a_emb_base.to(dtype=weight_dtype),
                            os.path.join(args.output_dir, f"fid-{global_step}/samples_b2a"),
                            b2a_ref_mu, b2a_ref_sigma, T_val, feat_model, net_dino,
                            eval_vae_enc, eval_unet, eval_vae_dec, noise_scheduler_1step, _timesteps,
                            args.validation_num_images)
                        print(f"step={global_step}, fid(b2a)={score_fid_b2a:.2f}, dino(b2a)={dino_score_b2a:.3f}")
                        logs["val/fid_a2b"], logs["val/fid_b2a"] = score_fid_a2b, score_fid_b2a
                        logs["val/dino_struct_a2b"], logs["val/dino_struct_b2a"] = dino_score_a2b, dino_score_b2a
                        del net_dino  # free up memory

            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            if global_step >= args.max_train_steps:
                break
        # Stop the epoch loop as well; otherwise every remaining epoch would run
        # one more optimizer step before the inner check breaks again.
        if global_step >= args.max_train_steps:
            break
if __name__ == "__main__":
    # Script entry point: parse the unpaired-training CLI arguments and train.
    main(parse_args_unpaired_training())