# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

import argparse
import os
import time

from cp_dataset import CPDataset, CPDataLoader
from networks import GMM, UnetGenerator, load_checkpoint

from tensorboardX import SummaryWriter
from visualization import board_add_image, board_add_images, save_images


def get_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="GMM")
    # parser.add_argument("--name", default="TOM")
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=1)
    parser.add_argument('-b', '--batch-size', type=int, default=4)
    parser.add_argument("--dataroot", default="data")
    # parser.add_argument("--datamode", default="train")
    parser.add_argument("--datamode", default="test")
    parser.add_argument("--stage", default="GMM")
    # parser.add_argument("--stage", default="TOM")
    # parser.add_argument("--data_list", default="train_pairs.txt")
    parser.add_argument("--data_list", default="test_pairs.txt")
    # parser.add_argument("--data_list", default="test_pairs_same.txt")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)
    parser.add_argument("--radius", type=int, default=5)
    parser.add_argument("--grid_size", type=int, default=5)
    parser.add_argument('--tensorboard_dir', type=str, default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--result_dir', type=str, default='result', help='save result infos')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/GMM/gmm_final.pth', help='model checkpoint for test')
    # parser.add_argument('--checkpoint', type=str, default='checkpoints/TOM/tom_final.pth', help='model checkpoint for test')
    parser.add_argument("--display_count", type=int, default=1)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')

    opt = parser.parse_args()
    return opt


def test_gmm(opt, test_loader, model, board):
    model.cuda()
    model.eval()

    base_name = os.path.basename(opt.checkpoint)
    name = opt.name
    save_dir = os.path.join(opt.result_dir, name, opt.datamode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # output directories for the warped cloth/mask and the visualization images
    warp_cloth_dir = os.path.join(save_dir, 'warp-cloth')
    if not os.path.exists(warp_cloth_dir):
        os.makedirs(warp_cloth_dir)
    warp_mask_dir = os.path.join(save_dir, 'warp-mask')
    if not os.path.exists(warp_mask_dir):
        os.makedirs(warp_mask_dir)
    result_dir1 = os.path.join(save_dir, 'result_dir')
    if not os.path.exists(result_dir1):
        os.makedirs(result_dir1)
    overlayed_TPS_dir = os.path.join(save_dir, 'overlayed_TPS')
    if not os.path.exists(overlayed_TPS_dir):
        os.makedirs(overlayed_TPS_dir)
    warped_grid_dir = os.path.join(save_dir, 'warped_grid')
    if not os.path.exists(warped_grid_dir):
        os.makedirs(warped_grid_dir)

    for step, inputs in enumerate(test_loader.data_loader):
        iter_start_time = time.time()

        c_names = inputs['c_name']
        im_names = inputs['im_name']
        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()
        shape_ori = inputs['shape_ori']  # original body shape without blurring

        # predict the TPS grid, then warp the cloth, its mask and the grid image with it
        grid, theta = model(agnostic, cm)
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
        overlay = 0.7 * warped_cloth + 0.3 * im
        visuals = [[im_h, shape, im_pose],
                   [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im]]

        # save warping results, indexed by the person image name
        # save_images(warped_cloth, c_names, warp_cloth_dir)
        # save_images(warped_mask*2-1, c_names, warp_mask_dir)
        save_images(warped_cloth, im_names, warp_cloth_dir)
        save_images(warped_mask * 2 - 1, im_names, warp_mask_dir)
        save_images(shape_ori.cuda() * 0.2 + warped_cloth * 0.8, im_names, result_dir1)
        save_images(warped_grid, im_names, warped_grid_dir)
        save_images(overlay, im_names, overlayed_TPS_dir)

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f' % (step + 1, t), flush=True)


def test_tom(opt, test_loader, model, board):
    model.cuda()
    model.eval()

    base_name = os.path.basename(opt.checkpoint)
    # save_dir = os.path.join(opt.result_dir, base_name, opt.datamode)
    save_dir = os.path.join(opt.result_dir, opt.name, opt.datamode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # output directories for the try-on result and its intermediate components
    try_on_dir = os.path.join(save_dir, 'try-on')
    if not os.path.exists(try_on_dir):
        os.makedirs(try_on_dir)
    p_rendered_dir = os.path.join(save_dir, 'p_rendered')
    if not os.path.exists(p_rendered_dir):
        os.makedirs(p_rendered_dir)
    m_composite_dir = os.path.join(save_dir, 'm_composite')
    if not os.path.exists(m_composite_dir):
        os.makedirs(m_composite_dir)
    im_pose_dir = os.path.join(save_dir, 'im_pose')
    if not os.path.exists(im_pose_dir):
        os.makedirs(im_pose_dir)
    shape_dir = os.path.join(save_dir, 'shape')
    if not os.path.exists(shape_dir):
        os.makedirs(shape_dir)
    im_h_dir = os.path.join(save_dir, 'im_h')
    if not os.path.exists(im_h_dir):
        os.makedirs(im_h_dir)

    # for test data
    print('Dataset size: %05d!' % (len(test_loader.dataset)), flush=True)
    for step, inputs in enumerate(test_loader.data_loader):
        iter_start_time = time.time()

        im_names = inputs['im_name']
        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()

        # outputs = model(torch.cat([agnostic, c], 1))  # CP-VTON
        outputs = model(torch.cat([agnostic, c, cm], 1))  # CP-VTON+
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid (same behavior)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        # final try-on: warped cloth where the composition mask is on, rendered person elsewhere
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose],
                   [c, 2 * cm - 1, m_composite],
                   [p_rendered, p_tryon, im]]

        save_images(p_tryon, im_names, try_on_dir)
        save_images(im_h, im_names, im_h_dir)
        save_images(shape, im_names, shape_dir)
        save_images(im_pose, im_names, im_pose_dir)
        save_images(m_composite, im_names, m_composite_dir)
        save_images(p_rendered, im_names, p_rendered_dir)  # For test data

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f' % (step + 1, t), flush=True)


def main():
    opt = get_opt()
    print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    test_dataset = CPDataset(opt)

    # create dataloader
    test_loader = CPDataLoader(opt, test_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(logdir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & test
    if opt.stage == 'GMM':
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, test_loader, model, board)
    elif opt.stage == 'TOM':
        # model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON
        model = UnetGenerator(26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON+
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, test_loader, model, board)
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))


if __name__ == "__main__":
    main()
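
# Usage sketch (not part of the original script): example invocations of the flags
# defined in get_opt(). Assumptions: this file is saved as test.py, and the default
# checkpoints/ and data/ layout referenced by the defaults above exists; adjust
# paths to your setup.
#
#   # stage 1: warp the in-shop cloth with a trained GMM
#   python test.py --name GMM --stage GMM --datamode test \
#       --data_list test_pairs.txt --checkpoint checkpoints/GMM/gmm_final.pth
#
#   # stage 2: generate the final try-on with a trained TOM
#   python test.py --name TOM --stage TOM --datamode test \
#       --data_list test_pairs.txt --checkpoint checkpoints/TOM/tom_final.pth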