# coding=utf-8
"""Training script for the CP-VTON+ GMM (cloth warping) and TOM (try-on) stages.

Select the stage with ``--stage GMM`` or ``--stage TOM``; see ``get_opt`` for
all command-line options. Checkpoints are written under
``<checkpoint_dir>/<name>/`` and tensorboard logs under ``<tensorboard_dir>/<name>``.
"""
import argparse
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter

from cp_dataset import CPDataset, CPDataLoader
from networks import GicLoss, GMM, UnetGenerator, VGGLoss, load_checkpoint, save_checkpoint
from visualization import board_add_image, board_add_images


def get_opt():
    """Parse and return the command-line options for training."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--name", default="GMM")
    # parser.add_argument("--name", default="TOM")

    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=1)
    parser.add_argument('-b', '--batch-size', type=int, default=4)

    parser.add_argument("--dataroot", default="data")

    parser.add_argument("--datamode", default="train")

    parser.add_argument("--stage", default="GMM")
    # parser.add_argument("--stage", default="TOM")

    parser.add_argument("--data_list", default="train_pairs.txt")

    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)
    parser.add_argument("--radius", type=int, default=5)
    parser.add_argument("--grid_size", type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='initial learning rate for adam')
    parser.add_argument('--tensorboard_dir', type=str,
                        default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str,
                        default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--checkpoint', type=str, default='',
                        help='model checkpoint for initialization')
    parser.add_argument("--display_count", type=int, default=20)
    parser.add_argument("--save_count", type=int, default=5000)
    parser.add_argument("--keep_step", type=int, default=100000)
    parser.add_argument("--decay_step", type=int, default=100000)
    parser.add_argument("--shuffle", action='store_true',
                        help='shuffle input data')

    opt = parser.parse_args()
    return opt


def _checkpoint_path(opt, filename):
    """Return ``<checkpoint_dir>/<name>/<filename>`` for this run."""
    return os.path.join(opt.checkpoint_dir, opt.name, filename)


def _make_optimizer(opt, model):
    """Build the Adam optimizer and its linear-decay learning-rate scheduler.

    The LR multiplier is 1.0 for the first ``opt.keep_step`` scheduler steps,
    then decreases linearly, reaching ``1 / (opt.decay_step + 1)`` after a
    further ``opt.decay_step`` steps.

    Returns:
        (optimizer, scheduler) tuple.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))
    return optimizer, scheduler


def _update(optimizer, scheduler, loss):
    """Run one backward pass and parameter update, then advance the LR schedule."""
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Without this call LambdaLR never advances and the learning rate would
    # stay at opt.lr for the whole run, ignoring keep_step/decay_step.
    scheduler.step()


def _save_periodic(opt, model, step):
    """Save an intermediate checkpoint every ``opt.save_count`` steps."""
    if (step + 1) % opt.save_count == 0:
        save_checkpoint(model, _checkpoint_path(opt, 'step_%06d.pth' % (step + 1)))


def train_gmm(opt, train_loader, model, board):
    """Train the GMM warping model for ``opt.keep_step + opt.decay_step`` steps.

    Each step predicts a sampling grid from the agnostic person representation
    and the cloth mask, warps the cloth with it and applies an L1 loss against
    ``parse_cloth`` plus a GicLoss regularizer on the grid.

    Args:
        opt: parsed options from ``get_opt``.
        train_loader: loader exposing ``next_batch()`` returning a dict of tensors.
        model: the GMM model; moved to CUDA and put in train mode here.
        board: tensorboard ``SummaryWriter`` used for images and scalars.
    """
    model.cuda()
    model.train()

    criterionL1 = nn.L1Loss()
    gicloss = GicLoss(opt)
    optimizer, scheduler = _make_optimizer(opt, model)

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()

        grid, _theta = model(agnostic, cm)  # can be added c too for new training
        # NOTE(review): no align_corners is passed, so PyTorch's default applies
        # (False since 1.3.0, with a warning); confirm this matches how GMM builds `grid`.
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        # Loss for the warped cloth. This corresponds to the working reference code.
        # The loss as given in the paper is the warped-mask loss below; to train as
        # per the paper, replace the line that follows with these two lines
        # (thanks @xuxiaochun025 for fixing the git code):
        #     warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        #     Lwarp = criterionL1(warped_mask, cm)
        Lwarp = criterionL1(warped_cloth, im_c)

        # Grid regularization loss, normalized by the number of grid points
        # (grid is (N, H_out, W_out, 2) as required by F.grid_sample).
        Lgic = gicloss(grid)
        # 200x200 = 40.000 * 0.001
        Lgic = Lgic / (grid.shape[0] * grid.shape[1] * grid.shape[2])

        loss = Lwarp + 40 * Lgic  # total GMM loss

        _update(optimizer, scheduler, loss)

        if (step + 1) % opt.display_count == 0:
            # Built only on display steps to avoid extra tensor work every iteration.
            visuals = [[im_h, shape, im_pose],
                       [c, warped_cloth, im_c],
                       [warped_grid, (warped_cloth + im) * 0.5, im]]
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('loss', loss.item(), step + 1)
            board.add_scalar('40*Lgic', (40 * Lgic).item(), step + 1)
            board.add_scalar('Lwarp', Lwarp.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, (40*Lgic): %.8f, Lwarp: %.6f'
                  % (step + 1, t, loss.item(), (40 * Lgic).item(), Lwarp.item()),
                  flush=True)

        _save_periodic(opt, model, step)


def train_tom(opt, train_loader, model, board):
    """Train the TOM try-on generator for ``opt.keep_step + opt.decay_step`` steps.

    The generator output is split into a 3-channel rendered person and a
    1-channel composition mask; the try-on image blends the cloth and the
    rendered person with that mask. The loss is L1 + VGG against the real image
    plus an L1 loss between the composition mask and ``parse_cloth_mask``.

    Args:
        opt: parsed options from ``get_opt``.
        train_loader: loader exposing ``next_batch()`` returning a dict of tensors.
        model: the TOM generator; moved to CUDA and put in train mode here.
        board: tensorboard ``SummaryWriter`` used for images and scalars.
    """
    model.cuda()
    model.train()

    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()
    optimizer, scheduler = _make_optimizer(opt, model)

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        # Used only for visualization, so they stay on their original device.
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        pcm = inputs['parse_cloth_mask'].cuda()

        # CP-VTON+ also feeds the cloth mask; CP-VTON used torch.cat([agnostic, c], 1).
        outputs = model(torch.cat([agnostic, c, cm], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        # CP-VTON+ supervises the mask with pcm; CP-VTON used criterionMask(m_composite, cm).
        loss_mask = criterionMask(m_composite, pcm)
        loss = loss_l1 + loss_vgg + loss_mask

        _update(optimizer, scheduler, loss)

        if (step + 1) % opt.display_count == 0:
            # CP-VTON showed cm*2-1 in place of pcm*2-1.
            visuals = [[im_h, shape, im_pose],
                       [c, pcm * 2 - 1, m_composite * 2 - 1],
                       [p_rendered, p_tryon, im]]
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('L1', loss_l1.item(), step + 1)
            board.add_scalar('VGG', loss_vgg.item(), step + 1)
            board.add_scalar('MaskL1', loss_mask.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
                  % (step + 1, t, loss.item(), loss_l1.item(),
                     loss_vgg.item(), loss_mask.item()), flush=True)

        _save_periodic(opt, model, step)


def _load_if_given(model, checkpoint):
    """Load ``checkpoint`` into ``model`` when a path is given and exists.

    An empty path means "train from scratch". A non-empty path that does not
    exist is reported instead of being silently ignored; training still
    proceeds from scratch as before.
    """
    if not checkpoint:
        return
    if os.path.exists(checkpoint):
        load_checkpoint(model, checkpoint)
    else:
        print('Warning: checkpoint %s not found, training from scratch' % checkpoint,
              flush=True)


def main():
    """Parse options, build the data pipeline and train the selected stage.

    Raises:
        NotImplementedError: if ``--stage`` is neither 'GMM' nor 'TOM'.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset and dataloader
    train_dataset = CPDataset(opt)
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(logdir=os.path.join(opt.tensorboard_dir, opt.name))

    try:
        # create model & train & save the final checkpoint
        if opt.stage == 'GMM':
            model = GMM(opt)
            _load_if_given(model, opt.checkpoint)
            train_gmm(opt, train_loader, model, board)
            save_checkpoint(model, _checkpoint_path(opt, 'gmm_final.pth'))
        elif opt.stage == 'TOM':
            # 26 input channels = agnostic + cloth + cloth mask (CP-VTON used 25,
            # without the mask); 4 output channels = rendered person (3) + mask (1).
            model = UnetGenerator(26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
            _load_if_given(model, opt.checkpoint)
            train_tom(opt, train_loader, model, board)
            save_checkpoint(model, _checkpoint_path(opt, 'tom_final.pth'))
        else:
            raise NotImplementedError('Model [%s] is not implemented' % opt.stage)
    finally:
        # Flush pending tensorboard events even if training fails.
        board.close()

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))


if __name__ == "__main__":
    main()