# coding=utf-8
"""Training script for the try-on condition generator (tocg).

Trains a ConditionGenerator that warps a cloth image onto a person
representation and predicts a semantic segmentation map, optionally with an
adversarial segmentation discriminator D. Progress is logged to tensorboard
and checkpoints are written periodically.
"""
import argparse
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
from torchvision.utils import make_grid
from tqdm import tqdm
from tensorboardX import SummaryWriter

from cp_dataset import CPDataset, CPDatasetTest, CPDataLoader
from networks import ConditionGenerator, VGGLoss, GANLoss, load_checkpoint, save_checkpoint, define_D
from networks import make_grid as mkgrid  # flow sampling grid; distinct from torchvision's image grid
from utils import *


def iou_metric(y_pred_batch, y_true_batch):
    """Return the mean intersection-over-union of a batch.

    Args:
        y_pred_batch: soft predictions; thresholded at 0.5 before comparison.
        y_true_batch: binary ground-truth masks of the same batch size.

    Returns:
        Scalar tensor: IoU averaged over the batch (epsilon-smoothed).
    """
    B = y_pred_batch.shape[0]
    iou = 0
    for i in range(B):
        y_pred = y_pred_batch[i]
        y_true = y_true_batch[i]
        # y_pred is not one-hot, so threshold it first
        y_pred = (y_pred > 0.5).flatten()
        y_true = y_true.flatten()
        intersection = torch.sum(y_pred[y_true == 1])
        union = torch.sum(y_pred) + torch.sum(y_true)
        iou += (intersection + 1e-7) / (union - intersection + 1e-7) / B
    return iou


def remove_overlap(seg_out, warped_cm):
    """Suppress the warped cloth mask where the segmentation predicts body parts.

    Channels 1:3 and 5: of ``seg_out`` (non-cloth regions) are summed and the
    overlapping portion is subtracted from ``warped_cm``.

    Args:
        seg_out: (N, C, H, W) softmaxed segmentation output.
        warped_cm: (N, 1, H, W) warped cloth mask.

    Returns:
        The occlusion-handled cloth mask.
    """
    assert len(warped_cm.shape) == 4
    warped_cm = warped_cm - (torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)).sum(dim=1, keepdim=True) * warped_cm
    return warped_cm


def get_opt():
    """Parse and return command-line options for training."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--name", default="test")
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=8)
    parser.add_argument('--fp16', action='store_true', help='use amp')

    parser.add_argument("--dataroot", default="./data/")
    parser.add_argument("--datamode", default="train")
    parser.add_argument("--data_list", default="train_pairs.txt")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)

    parser.add_argument('--tensorboard_dir', type=str, default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--tocg_checkpoint', type=str, default='', help='tocg checkpoint')

    parser.add_argument("--tensorboard_count", type=int, default=100)
    parser.add_argument("--display_count", type=int, default=100)
    parser.add_argument("--save_count", type=int, default=10000)
    parser.add_argument("--load_step", type=int, default=0)
    parser.add_argument("--keep_step", type=int, default=300000)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
    parser.add_argument("--semantic_nc", type=int, default=13)
    parser.add_argument("--output_nc", type=int, default=13)

    # network
    parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
    parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
    parser.add_argument('--Ddownx2', action='store_true', help="Downsample D's input to increase the receptive field")
    parser.add_argument('--Ddropout', action='store_true', help="Apply dropout to D")
    parser.add_argument('--num_D', type=int, default=2, help='Generator ngf')

    # Cuda availability
    parser.add_argument('--cuda', default=False, help='cuda or cpu')

    # training
    parser.add_argument("--G_D_seperate", action='store_true')
    parser.add_argument("--no_GAN_loss", action='store_true')
    parser.add_argument("--lasttvonly", action='store_true')
    parser.add_argument("--interflowloss", action='store_true', help="Intermediate flow loss")
    parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'], default='warp_grad')
    parser.add_argument('--edgeawaretv', type=str, choices=['no_edge', 'last_only', 'weighted'], default="no_edge", help="Edge aware TV loss")
    parser.add_argument('--add_lasttv', action='store_true')

    # test visualize
    parser.add_argument("--no_test_visualize", action='store_true')
    parser.add_argument("--num_test_visualize", type=int, default=3)
    parser.add_argument("--test_datasetting", default="unpaired")
    parser.add_argument("--test_dataroot", default="./data/")
    parser.add_argument("--test_data_list", default="test_pairs.txt")

    # Hyper-parameters
    parser.add_argument('--G_lr', type=float, default=0.0002, help='Generator initial learning rate for adam')
    parser.add_argument('--D_lr', type=float, default=0.0002, help='Discriminator initial learning rate for adam')
    parser.add_argument('--CElamda', type=float, default=10, help='initial learning rate for adam')
    parser.add_argument('--GANlambda', type=float, default=1)
    parser.add_argument('--tvlambda', type=float, default=2)
    parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
    # NOTE: was default='1000' (a string); argparse does not run `type` on
    # defaults, so `step % opt.val_count` would raise TypeError.
    parser.add_argument('--val_count', type=int, default=1000)
    parser.add_argument('--spectral', action='store_true', help="Apply spectral normalization to D")
    parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")

    opt = parser.parse_args()
    return opt


def train(opt, train_loader, test_loader, val_loader, board, tocg, D):
    """Run the training loop for the condition generator.

    Args:
        opt: parsed options from :func:`get_opt`.
        train_loader: CPDataLoader yielding training batches.
        test_loader: CPDataLoader for test-set visualization (may be None
            when ``opt.no_test_visualize`` is set).
        val_loader: CPDataLoader for IoU validation (may be None).
        board: tensorboard SummaryWriter.
        tocg: ConditionGenerator model.
        D: segmentation discriminator.
    """
    # Move models to GPU and switch to train mode.
    tocg.cuda()
    tocg.train()
    D.cuda()
    D.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss(opt)
    if opt.fp16:
        criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.HalfTensor)
    else:
        criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor if opt.gpu_ids else torch.Tensor)

    # optimizer
    optimizer_G = torch.optim.Adam(tocg.parameters(), lr=opt.G_lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.D_lr, betas=(0.5, 0.999))

    for step in tqdm(range(opt.load_step, opt.keep_step)):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        # input1: cloth + binarized cloth mask
        c_paired = inputs['cloth']['paired'].cuda()
        cm_paired = inputs['cloth_mask']['paired'].cuda()
        cm_paired = torch.FloatTensor((cm_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
        # input2: cloth-agnostic parse + densepose
        parse_agnostic = inputs['parse_agnostic'].cuda()
        densepose = inputs['densepose'].cuda()
        openpose = inputs['pose'].cuda()
        # GT
        label_onehot = inputs['parse_onehot'].cuda()  # CE
        label = inputs['parse'].cuda()  # GAN loss
        parse_cloth_mask = inputs['pcm'].cuda()  # L1
        im_c = inputs['parse_cloth'].cuda()  # VGG
        # visualization
        im = inputs['image']

        # inputs
        input1 = torch.cat([c_paired, cm_paired], 1)
        input2 = torch.cat([parse_agnostic, densepose], 1)

        # forward
        flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)

        # warped cloth mask one hot
        warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()

        # fake segmap cloth channel * warped clothmask
        if opt.clothmask_composition != 'no_composition':
            if opt.clothmask_composition == 'detach':
                cloth_mask = torch.ones_like(fake_segmap.detach())
                cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                fake_segmap = fake_segmap * cloth_mask
            if opt.clothmask_composition == 'warp_grad':
                cloth_mask = torch.ones_like(fake_segmap.detach())
                cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                fake_segmap = fake_segmap * cloth_mask

        if opt.occlusion:
            # Remove body-part overlap then paste the warped cloth on white.
            warped_clothmask_paired = remove_overlap(F.softmax(fake_segmap, dim=1), warped_clothmask_paired)
            warped_cloth_paired = warped_cloth_paired * warped_clothmask_paired + torch.ones_like(warped_cloth_paired) * (1 - warped_clothmask_paired)

        # generated fake cloth mask & misalign mask (cloth = channel 3)
        fake_clothmask = (torch.argmax(fake_segmap.detach(), dim=1, keepdim=True) == 3).long()
        misalign = fake_clothmask - warped_cm_onehot
        misalign[misalign < 0.0] = 0.0

        # loss warping
        loss_l1_cloth = criterionL1(warped_clothmask_paired, parse_cloth_mask)
        loss_vgg = criterionVGG(warped_cloth_paired, im_c)

        # total-variation smoothness on the flow fields (flows are NHWC)
        loss_tv = 0
        if opt.edgeawaretv == 'no_edge':
            if not opt.lasttvonly:
                for flow in flow_list:
                    y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                    x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                    loss_tv = loss_tv + y_tv + x_tv
            else:
                for flow in flow_list[-1:]:
                    y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                    x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                    loss_tv = loss_tv + y_tv + x_tv
        else:
            if opt.edgeawaretv == 'last_only':
                flow = flow_list[-1]
                warped_clothmask_paired_down = F.interpolate(warped_clothmask_paired, flow.shape[1:3], mode='bilinear')
                y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :])
                x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :])
                # Down-weight TV where the cloth mask has an edge (exp falloff).
                mask_y = torch.exp(-150 * torch.abs(warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, 1:, :, :] - warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :-1, :, :]))
                mask_x = torch.exp(-150 * torch.abs(warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :, 1:, :] - warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :, :-1, :]))
                y_tv = (y_tv * mask_y).mean()
                x_tv = (x_tv * mask_x).mean()
                loss_tv = loss_tv + y_tv + x_tv
            elif opt.edgeawaretv == 'weighted':
                for i in range(5):
                    flow = flow_list[i]
                    warped_clothmask_paired_down = F.interpolate(warped_clothmask_paired, flow.shape[1:3], mode='bilinear')
                    y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :])
                    x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :])
                    mask_y = torch.exp(-150 * torch.abs(warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, 1:, :, :] - warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :-1, :, :]))
                    mask_x = torch.exp(-150 * torch.abs(warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :, 1:, :] - warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :, :-1, :]))
                    # Coarser flows contribute less: divide by 2^(4-i).
                    y_tv = (y_tv * mask_y).mean() / (2 ** (4 - i))
                    x_tv = (x_tv * mask_x).mean() / (2 ** (4 - i))
                    loss_tv = loss_tv + y_tv + x_tv
            if opt.add_lasttv:
                for flow in flow_list[-1:]:
                    y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                    x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                    loss_tv = loss_tv + y_tv + x_tv

        N, _, iH, iW = c_paired.size()
        # Intermediate flow loss: supervise every intermediate flow, weighted
        # by resolution so coarse levels count less.
        if opt.interflowloss:
            for i in range(len(flow_list) - 1):
                flow = flow_list[i]
                N, fH, fW, _ = flow.size()
                grid = mkgrid(N, iH, iW)
                flow = F.interpolate(flow.permute(0, 3, 1, 2), size=c_paired.shape[2:], mode=opt.upsample).permute(0, 2, 3, 1)
                # Normalize flow to [-1, 1] using the flow's own resolution.
                flow_norm = torch.cat([flow[:, :, :, 0:1] / ((fW - 1.0) / 2.0), flow[:, :, :, 1:2] / ((fH - 1.0) / 2.0)], 3)
                warped_c = F.grid_sample(c_paired, flow_norm + grid, padding_mode='border')
                warped_cm = F.grid_sample(cm_paired, flow_norm + grid, padding_mode='border')
                warped_cm = remove_overlap(F.softmax(fake_segmap, dim=1), warped_cm)
                loss_l1_cloth += criterionL1(warped_cm, parse_cloth_mask) / (2 ** (4 - i))
                loss_vgg += criterionVGG(warped_c, im_c) / (2 ** (4 - i))

        # loss segmentation: generator cross-entropy
        CE_loss = cross_entropy2d(fake_segmap, label_onehot.transpose(0, 1)[0].long())

        if opt.no_GAN_loss:
            loss_G = (10 * loss_l1_cloth + loss_vgg + opt.tvlambda * loss_tv) + (CE_loss * opt.CElamda)
            # step
            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()
        else:
            fake_segmap_softmax = torch.softmax(fake_segmap, 1)
            pred_segmap = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax), dim=1))
            loss_G_GAN = criterionGAN(pred_segmap, True)

            if not opt.G_D_seperate:
                # discriminator
                fake_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax.detach()), dim=1))
                real_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), label), dim=1))
                loss_D_fake = criterionGAN(fake_segmap_pred, False)
                loss_D_real = criterionGAN(real_segmap_pred, True)

                # loss sum: warping + seg_generation
                loss_G = (10 * loss_l1_cloth + loss_vgg + opt.tvlambda * loss_tv) + (CE_loss * opt.CElamda + loss_G_GAN * opt.GANlambda)
                loss_D = loss_D_fake + loss_D_real

                # step
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
            else:  # train G first, after that train D
                # loss G sum: warping + seg_generation
                loss_G = (10 * loss_l1_cloth + loss_vgg + opt.tvlambda * loss_tv) + (CE_loss * opt.CElamda + loss_G_GAN * opt.GANlambda)

                # step G
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # discriminator sees the updated generator's output
                with torch.no_grad():
                    _, fake_segmap, _, _ = tocg(input1, input2)
                    fake_segmap_softmax = torch.softmax(fake_segmap, 1)

                # loss discriminator
                fake_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax.detach()), dim=1))
                real_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), label), dim=1))
                loss_D_fake = criterionGAN(fake_segmap_pred, False)
                loss_D_real = criterionGAN(real_segmap_pred, True)
                loss_D = loss_D_fake + loss_D_real

                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()

        # Validation: mean IoU of the predicted segmentation on ~2000 samples.
        if (step + 1) % opt.val_count == 0 and val_loader is not None:
            tocg.eval()
            iou_list = []
            with torch.no_grad():
                for cnt in range(2000 // opt.batch_size):
                    inputs = val_loader.next_batch()

                    c_paired = inputs['cloth']['paired'].cuda()
                    cm_paired = inputs['cloth_mask']['paired'].cuda()
                    cm_paired = torch.FloatTensor((cm_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
                    parse_agnostic = inputs['parse_agnostic'].cuda()
                    densepose = inputs['densepose'].cuda()
                    label = inputs['parse'].cuda()

                    input1 = torch.cat([c_paired, cm_paired], 1)
                    input2 = torch.cat([parse_agnostic, densepose], 1)

                    # forward
                    flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)
                    # NOTE: recomputed per validation batch (the original
                    # reused the stale training-batch one-hot mask here).
                    warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()

                    # fake segmap cloth channel * warped clothmask
                    if opt.clothmask_composition != 'no_composition':
                        if opt.clothmask_composition == 'detach':
                            cloth_mask = torch.ones_like(fake_segmap.detach())
                            cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                            fake_segmap = fake_segmap * cloth_mask
                        if opt.clothmask_composition == 'warp_grad':
                            cloth_mask = torch.ones_like(fake_segmap.detach())
                            cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                            fake_segmap = fake_segmap * cloth_mask

                    # calculate iou
                    iou = iou_metric(F.softmax(fake_segmap, dim=1).detach(), label)
                    iou_list.append(iou.item())
            tocg.train()
            board.add_scalar('val/iou', np.mean(iou_list), step + 1)

        # tensorboard
        if (step + 1) % opt.tensorboard_count == 0:
            # loss G
            board.add_scalar('Loss/G', loss_G.item(), step + 1)
            board.add_scalar('Loss/G/l1_cloth', loss_l1_cloth.item(), step + 1)
            board.add_scalar('Loss/G/vgg', loss_vgg.item(), step + 1)
            board.add_scalar('Loss/G/tv', loss_tv.item(), step + 1)
            board.add_scalar('Loss/G/CE', CE_loss.item(), step + 1)
            if not opt.no_GAN_loss:
                board.add_scalar('Loss/G/GAN', loss_G_GAN.item(), step + 1)
                # loss D (only defined when the GAN loss is active)
                board.add_scalar('Loss/D', loss_D.item(), step + 1)
                board.add_scalar('Loss/D/pred_real', loss_D_real.item(), step + 1)
                board.add_scalar('Loss/D/pred_fake', loss_D_fake.item(), step + 1)

            grid = make_grid([(c_paired[0].cpu() / 2 + 0.5),
                              (cm_paired[0].cpu()).expand(3, -1, -1),
                              visualize_segmap(parse_agnostic.cpu()),
                              ((densepose.cpu()[0] + 1) / 2),
                              (im_c[0].cpu() / 2 + 0.5),
                              parse_cloth_mask[0].cpu().expand(3, -1, -1),
                              (warped_cloth_paired[0].cpu().detach() / 2 + 0.5),
                              (warped_cm_onehot[0].cpu().detach()).expand(3, -1, -1),
                              visualize_segmap(label.cpu()),
                              visualize_segmap(fake_segmap.cpu()),
                              (im[0] / 2 + 0.5),
                              (misalign[0].cpu().detach()).expand(3, -1, -1)],
                             nrow=4)
            board.add_images('train_images', grid.unsqueeze(0), step + 1)

            if not opt.no_test_visualize:
                inputs = test_loader.next_batch()

                # input1
                c_paired = inputs['cloth'][opt.test_datasetting].cuda()
                cm_paired = inputs['cloth_mask'][opt.test_datasetting].cuda()
                cm_paired = torch.FloatTensor((cm_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
                # input2
                parse_agnostic = inputs['parse_agnostic'].cuda()
                densepose = inputs['densepose'].cuda()
                openpose = inputs['pose'].cuda()
                # GT
                label_onehot = inputs['parse_onehot'].cuda()  # CE
                label = inputs['parse'].cuda()  # GAN loss
                parse_cloth_mask = inputs['pcm'].cuda()  # L1
                im_c = inputs['parse_cloth'].cuda()  # VGG
                # visualization
                im = inputs['image']

                tocg.eval()
                with torch.no_grad():
                    # inputs
                    input1 = torch.cat([c_paired, cm_paired], 1)
                    input2 = torch.cat([parse_agnostic, densepose], 1)

                    # forward
                    flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)
                    warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()

                    if opt.clothmask_composition != 'no_composition':
                        if opt.clothmask_composition == 'detach':
                            cloth_mask = torch.ones_like(fake_segmap)
                            cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                            fake_segmap = fake_segmap * cloth_mask
                        if opt.clothmask_composition == 'warp_grad':
                            cloth_mask = torch.ones_like(fake_segmap)
                            cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                            fake_segmap = fake_segmap * cloth_mask
                    if opt.occlusion:
                        warped_clothmask_paired = remove_overlap(F.softmax(fake_segmap, dim=1), warped_clothmask_paired)
                        warped_cloth_paired = warped_cloth_paired * warped_clothmask_paired + torch.ones_like(warped_cloth_paired) * (1 - warped_clothmask_paired)

                    # generated fake cloth mask & misalign mask
                    fake_clothmask = (torch.argmax(fake_segmap.detach(), dim=1, keepdim=True) == 3).long()
                    misalign = fake_clothmask - warped_cm_onehot
                    misalign[misalign < 0.0] = 0.0

                    for i in range(opt.num_test_visualize):
                        grid = make_grid([(c_paired[i].cpu() / 2 + 0.5),
                                          (cm_paired[i].cpu()).expand(3, -1, -1),
                                          visualize_segmap(parse_agnostic.cpu(), batch=i),
                                          ((densepose.cpu()[i] + 1) / 2),
                                          (im_c[i].cpu() / 2 + 0.5),
                                          parse_cloth_mask[i].cpu().expand(3, -1, -1),
                                          (warped_cloth_paired[i].cpu().detach() / 2 + 0.5),
                                          (warped_cm_onehot[i].cpu().detach()).expand(3, -1, -1),
                                          visualize_segmap(label.cpu(), batch=i),
                                          visualize_segmap(fake_segmap.cpu(), batch=i),
                                          (im[i] / 2 + 0.5),
                                          (misalign[i].cpu().detach()).expand(3, -1, -1)],
                                         nrow=4)
                        board.add_images(f'test_images/{i}', grid.unsqueeze(0), step + 1)
                tocg.train()

        # display
        if (step + 1) % opt.display_count == 0:
            t = time.time() - iter_start_time
            if not opt.no_GAN_loss:
                print("step: %8d, time: %.3f\nloss G: %.4f, L1_cloth loss: %.4f, VGG loss: %.4f, TV loss: %.4f CE: %.4f, G GAN: %.4f\nloss D: %.4f, D real: %.4f, D fake: %.4f"
                      % (step + 1, t, loss_G.item(), loss_l1_cloth.item(), loss_vgg.item(), loss_tv.item(), CE_loss.item(), loss_G_GAN.item(), loss_D.item(), loss_D_real.item(), loss_D_fake.item()), flush=True)

        # save
        if (step + 1) % opt.save_count == 0:
            save_checkpoint(tocg, os.path.join(opt.checkpoint_dir, opt.name, 'tocg_step_%06d.pth' % (step + 1)), opt)
            save_checkpoint(D, os.path.join(opt.checkpoint_dir, opt.name, 'D_step_%06d.pth' % (step + 1)), opt)


def main():
    """Entry point: build datasets, models, and run training."""
    opt = get_opt()
    print(opt)
    print("Start to train %s!" % opt.name)
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    # create train dataset & loader
    train_dataset = CPDataset(opt)
    train_loader = CPDataLoader(opt, train_dataset)

    # create test/val dataset & loaders (None when visualization is disabled)
    test_loader = None
    val_loader = None
    if not opt.no_test_visualize:
        # Temporarily retarget opt at the test split; restore batch size after.
        train_bsize = opt.batch_size
        opt.batch_size = opt.num_test_visualize
        opt.dataroot = opt.test_dataroot
        opt.datamode = 'test'
        opt.data_list = opt.test_data_list
        test_dataset = CPDatasetTest(opt)
        opt.batch_size = train_bsize
        val_dataset = Subset(test_dataset, np.arange(2000))
        test_loader = CPDataLoader(opt, test_dataset)
        val_loader = CPDataLoader(opt, val_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # Model
    input1_nc = 4  # cloth + cloth-mask
    input2_nc = opt.semantic_nc + 3  # parse_agnostic + densepose
    tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc, ngf=96, norm_layer=nn.BatchNorm2d)
    D = define_D(input_nc=input1_nc + input2_nc + opt.output_nc, Ddownx2=opt.Ddownx2, Ddropout=opt.Ddropout, n_layers_D=3, spectral=opt.spectral, num_D=opt.num_D)

    # Load Checkpoint
    if opt.tocg_checkpoint != '' and os.path.exists(opt.tocg_checkpoint):
        load_checkpoint(tocg, opt.tocg_checkpoint)

    # Train
    # NOTE: argument order must match the train() signature
    # (opt, train_loader, test_loader, val_loader, ...); the original call
    # passed val_loader and test_loader swapped.
    train(opt, train_loader, test_loader, val_loader, board, tocg, D)

    # Save Checkpoint
    save_checkpoint(tocg, os.path.join(opt.checkpoint_dir, opt.name, 'tocg_final.pth'), opt)
    save_checkpoint(D, os.path.join(opt.checkpoint_dir, opt.name, 'D_final.pth'), opt)
    print("Finished training %s!" % opt.name)


if __name__ == "__main__":
    main()