# wizzseen's picture
# Upload 948 files
# 8a6df40 verified
import torch
import torch.nn as nn
from torchvision.utils import make_grid, save_image
import argparse
import os
import time
from cp_dataset import CPDatasetTest, CPDataLoader
from networks import ConditionGenerator, load_checkpoint, define_D
from tqdm import tqdm
from tensorboardX import SummaryWriter
from utils import *
from get_norm_const import D_logit
def get_opt(argv=None):
    """Build the command-line parser for the condition-generator test script and parse it.

    Args:
        argv: optional list of argument strings. When None (the default), argparse
            reads ``sys.argv[1:]``, so existing ``get_opt()`` calls behave unchanged.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=8)
    parser.add_argument('--fp16', action='store_true', help='use amp')
    parser.add_argument("--dataroot", default="./data/zalando-hd-resize")
    parser.add_argument("--datamode", default="test")
    parser.add_argument("--data_list", default="test_pairs.txt")
    parser.add_argument("--datasetting", default="paired")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)
    parser.add_argument('--tensorboard_dir', type=str, default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--tocg_checkpoint', type=str, default='', help='tocg checkpoint')
    parser.add_argument('--D_checkpoint', type=str, default='', help='D checkpoint')
    parser.add_argument("--tensorboard_count", type=int, default=100)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
    parser.add_argument("--semantic_nc", type=int, default=13)
    parser.add_argument("--output_nc", type=int, default=13)

    # network
    parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
    parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")

    # training
    parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'], default='warp_grad')

    # Hyper-parameters
    parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
    parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")

    # Discriminator
    parser.add_argument('--Ddownx2', action='store_true', help="Downsample D's input to increase the receptive field")
    parser.add_argument('--Ddropout', action='store_true', help="Apply dropout to D")
    # Previously mislabelled 'Generator ngf'; this value is forwarded to define_D(num_D=...).
    parser.add_argument('--num_D', type=int, default=2, help='Number of discriminators (num_D argument of define_D)')
    parser.add_argument('--spectral', action='store_true', help="Apply spectral normalization to D")
    parser.add_argument('--norm_const', type=float, help='Normalizing constant for rejection sampling')

    return parser.parse_args(argv)
def _output_dir(opt):
    """Return ./output/<ckpt parent dir>/<ckpt file name>/<datamode>/<datasetting>/multi-task."""
    ckpt_parts = opt.tocg_checkpoint.split('/')
    return os.path.join('./output', ckpt_parts[-2], ckpt_parts[-1],
                        opt.datamode, opt.datasetting, 'multi-task')


def _compose_cloth_mask(fake_segmap, warped_clothmask, warped_cm_onehot, mode):
    """Multiply channel 3 of ``fake_segmap`` by a warped cloth mask chosen by ``mode``.

    'detach' uses the binarized mask, 'warp_grad' uses the raw warped mask, and any
    other value (e.g. 'no_composition') returns ``fake_segmap`` unchanged.
    """
    if mode == 'detach':
        mask = warped_cm_onehot
    elif mode == 'warp_grad':
        mask = warped_clothmask
    else:
        return fake_segmap
    cloth_mask = torch.ones_like(fake_segmap)
    cloth_mask[:, 3:4, :, :] = mask
    return fake_segmap * cloth_mask


def _rejection_scores(D, input1, input2, fake_segmap, norm_const):
    """Score a batch with D and return score / (1 - score) divided by ``norm_const``.

    NOTE(review): the odds transform assumes D_logit returns values in (0, 1) — confirm.
    """
    fake_segmap_softmax = torch.softmax(fake_segmap, dim=1)
    pred_segmap = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax), dim=1))
    score = D_logit(pred_segmap)
    return (score / (1 - score)) / norm_const


def test(opt, test_loader, board, tocg, D=None):
    """Run the condition generator over the test set and save visualization grids.

    For every batch, ``tocg`` is fed (cloth, binarized cloth mask) and
    (parse_agnostic, densepose). Its predicted segmentation is optionally composed
    with the warped cloth mask according to ``opt.clothmask_composition``. A
    12-image grid (``nrow=4``) per sample is then saved to the output directory.
    When ``D`` is given, a per-sample rejection score is also collected. The scores
    are sorted in descending order and appended to ``rejection_prob.txt`` in the
    same directory.

    Args:
        opt: options from ``get_opt()``; ``opt.norm_const`` must be set when D is used.
        test_loader: object whose ``data_loader`` attribute yields batch dicts.
        board: tensorboard writer (not used by this function).
        tocg: condition generator network.
        D: optional discriminator used for rejection scoring.
    """
    # Model
    tocg.cuda()
    tocg.eval()
    if D is not None:
        D.cuda()
        D.eval()

    out_dir = _output_dir(opt)
    os.makedirs(out_dir, exist_ok=True)

    num = 0
    iter_start_time = time.time()
    D_score = []

    for inputs in test_loader.data_loader:
        # input1: cloth + cloth mask, binarized on-device (np.float was removed in NumPy 1.24)
        c_paired = inputs['cloth'][opt.datasetting].cuda()
        cm_paired = (inputs['cloth_mask'][opt.datasetting].cuda() > 0.5).float()

        # input2
        parse_agnostic = inputs['parse_agnostic'].cuda()
        densepose = inputs['densepose'].cuda()

        # GT, used only for visualization here
        label = inputs['parse'].cuda()
        parse_cloth_mask = inputs['pcm'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im = inputs['image']

        with torch.no_grad():
            input1 = torch.cat([c_paired, cm_paired], 1)
            input2 = torch.cat([parse_agnostic, densepose], 1)

            # forward
            _, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)

            # binarized warped cloth mask
            warped_cm_onehot = (warped_clothmask_paired > 0.5).float()
            fake_segmap = _compose_cloth_mask(fake_segmap, warped_clothmask_paired,
                                              warped_cm_onehot, opt.clothmask_composition)

            if D is not None:
                score = _rejection_scores(D, input1, input2, fake_segmap, opt.norm_const)
                print("prob0", score)
                for i in range(cm_paired.shape[0]):
                    name = inputs['c_name']['paired'][i].replace('.jpg', '.png')
                    D_score.append((name, score[i].item()))

        # Generated cloth mask (argmax == channel 3) and its excess over the warped mask.
        fake_clothmask = (torch.argmax(fake_segmap.detach(), dim=1, keepdim=True) == 3).long()
        misalign = (fake_clothmask - warped_cm_onehot).clamp(min=0.0)

        # Copy each tensor to CPU once per batch rather than once per sample.
        c_cpu = c_paired.cpu()
        cm_cpu = cm_paired.cpu()
        agnostic_cpu = parse_agnostic.cpu()
        densepose_cpu = densepose.cpu()
        im_c_cpu = im_c.cpu()
        pcm_cpu = parse_cloth_mask.cpu()
        warped_cloth_cpu = warped_cloth_paired.detach().cpu()
        warped_cm_cpu = warped_cm_onehot.detach().cpu()
        label_cpu = label.cpu()
        fake_segmap_cpu = fake_segmap.detach().cpu()
        misalign_cpu = misalign.detach().cpu()

        for i in range(c_paired.shape[0]):
            grid = make_grid([
                c_cpu[i] / 2 + 0.5,
                cm_cpu[i].expand(3, -1, -1),
                visualize_segmap(agnostic_cpu, batch=i),
                (densepose_cpu[i] + 1) / 2,
                im_c_cpu[i] / 2 + 0.5,
                pcm_cpu[i].expand(3, -1, -1),
                warped_cloth_cpu[i] / 2 + 0.5,
                warped_cm_cpu[i].expand(3, -1, -1),
                visualize_segmap(label_cpu, batch=i),
                visualize_segmap(fake_segmap_cpu, batch=i),
                im[i] / 2 + 0.5,
                misalign_cpu[i].expand(3, -1, -1),
            ], nrow=4)
            file_name = (inputs['c_name']['paired'][i].split('.')[0] + '_'
                         + inputs['c_name']['unpaired'][i].split('.')[0] + '.png')
            save_image(grid, os.path.join(out_dir, file_name))

        num += c_paired.shape[0]
        print(num)

    if D is not None:
        D_score.sort(key=lambda item: item[1], reverse=True)
        # Opened once in append mode: repeated runs accumulate entries in the same file.
        with open(os.path.join(out_dir, 'rejection_prob.txt'), 'a') as f:
            f.writelines(name + ' ' + str(score) + '\n' for name, score in D_score)

    print(f"Test time {time.time() - iter_start_time}")
def main():
    """Script entry point: parse options, build data and models, load checkpoints, run test().

    Raises:
        ValueError: if ``--tocg_checkpoint`` lacks a directory component (the output
            and tensorboard paths are built from its last two '/'-separated parts), or
            if a discriminator checkpoint is used without ``--norm_const``.
    """
    opt = get_opt()
    print(opt)
    print("Start to test %s!" % opt.tocg_checkpoint)

    # Only restrict visible devices when specific GPUs were requested:
    # CUDA_VISIBLE_DEVICES="" hides every GPU, and the .cuda() calls in test() would fail.
    if opt.gpu_ids:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    ckpt_parts = opt.tocg_checkpoint.split('/')
    if len(ckpt_parts) < 2:
        raise ValueError("--tocg_checkpoint must be a path like <dir>/<file>, got %r" % opt.tocg_checkpoint)

    # create test dataset & loader
    test_dataset = CPDatasetTest(opt)
    test_loader = CPDataLoader(opt, test_dataset)

    # visualization
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, ckpt_parts[-2], ckpt_parts[-1],
                                               opt.datamode, opt.datasetting))

    # Model
    input1_nc = 4  # cloth + cloth-mask
    input2_nc = opt.semantic_nc + 3  # parse_agnostic + densepose
    tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc,
                              ngf=96, norm_layer=nn.BatchNorm2d)

    use_D = opt.D_checkpoint != '' and os.path.exists(opt.D_checkpoint)
    if opt.D_checkpoint and not use_D:
        print("Warning: D checkpoint %s not found; testing without discriminator." % opt.D_checkpoint)
    if use_D:
        if opt.norm_const is None:
            raise ValueError("--norm_const is required when --D_checkpoint is given")
        D = define_D(input_nc=input1_nc + input2_nc + opt.output_nc, Ddownx2=opt.Ddownx2, Ddropout=opt.Ddropout,
                     n_layers_D=3, spectral=opt.spectral, num_D=opt.num_D)
    else:
        D = None

    # Load Checkpoint
    load_checkpoint(tocg, opt.tocg_checkpoint)
    if D is not None:
        load_checkpoint(D, opt.D_checkpoint)

    # Test
    try:
        test(opt, test_loader, board, tocg, D=D)
    finally:
        board.close()
    print("Finished testing!")
# Run the test pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()