import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
import argparse
import os
import time
import numpy as np
from cp_dataset import CPDatasetTest, CPDataLoader
from networks import ConditionGenerator, load_checkpoint, define_D
from tqdm import tqdm
from tensorboardX import SummaryWriter
from utils import *
from get_norm_const import D_logit
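
# Test-time script for the try-on condition generator (ToCG): it warps the clothing item,
# predicts the fused segmentation map, saves a visualization grid per sample, and, when a
# discriminator checkpoint is supplied, writes per-sample rejection-sampling scores to
# rejection_prob.txt.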


def get_opt():
    parser = argparse.ArgumentParser()

    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=8)
    parser.add_argument('--fp16', action='store_true', help='use amp')

    parser.add_argument("--dataroot", default="./data/zalando-hd-resize")
    parser.add_argument("--datamode", default="test")
    parser.add_argument("--data_list", default="test_pairs.txt")
    parser.add_argument("--datasetting", default="paired")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)

    parser.add_argument('--tensorboard_dir', type=str, default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--tocg_checkpoint', type=str, default='', help='tocg checkpoint')
    parser.add_argument('--D_checkpoint', type=str, default='', help='D checkpoint')
    parser.add_argument("--tensorboard_count", type=int, default=100)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
    parser.add_argument("--semantic_nc", type=int, default=13)
    parser.add_argument("--output_nc", type=int, default=13)

    # network
    parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
    parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")

    # training
    parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'], default='warp_grad')

    # Hyper-parameters
    parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
    parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")

    # Discriminator
    parser.add_argument('--Ddownx2', action='store_true', help="Downsample D's input to increase the receptive field")
    parser.add_argument('--Ddropout', action='store_true', help="Apply dropout to D")
    parser.add_argument('--num_D', type=int, default=2, help='number of discriminator scales (multi-scale D)')
    parser.add_argument('--spectral', action='store_true', help="Apply spectral normalization to D")
    parser.add_argument('--norm_const', type=float, help='Normalizing constant for rejection sampling')

    opt = parser.parse_args()
    return opt
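
# Example invocation (a sketch; the script name and checkpoint paths are illustrative,
# adjust them to your own setup):
#   python test_condition.py --gpu_ids 0 \
#       --dataroot ./data/zalando-hd-resize --datamode test --data_list test_pairs.txt \
#       --tocg_checkpoint ./checkpoints/tocg_final.pth --occlusion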


def test(opt, test_loader, board, tocg, D=None):
    # Model
    tocg.cuda()
    tocg.eval()
    if D is not None:
        D.cuda()
        D.eval()

    os.makedirs(os.path.join('./output', opt.tocg_checkpoint.split('/')[-2], opt.tocg_checkpoint.split('/')[-1],
                             opt.datamode, opt.datasetting, 'multi-task'), exist_ok=True)

    num = 0
    iter_start_time = time.time()
    if D is not None:
        D_score = []
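
    # Iterate over the test split; all model forward passes below run under torch.no_grad()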
    for inputs in test_loader.data_loader:

        # input1
        c_paired = inputs['cloth'][opt.datasetting].cuda()
        cm_paired = inputs['cloth_mask'][opt.datasetting].cuda()
        cm_paired = torch.FloatTensor((cm_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
        # input2
        parse_agnostic = inputs['parse_agnostic'].cuda()
        densepose = inputs['densepose'].cuda()
        openpose = inputs['pose'].cuda()
        # GT
        label_onehot = inputs['parse_onehot'].cuda()  # CE
        label = inputs['parse'].cuda()  # GAN loss
        parse_cloth_mask = inputs['pcm'].cuda()  # L1
        im_c = inputs['parse_cloth'].cuda()  # VGG
        # visualization
        im = inputs['image']

        with torch.no_grad():
            # inputs
            input1 = torch.cat([c_paired, cm_paired], 1)
            input2 = torch.cat([parse_agnostic, densepose], 1)

            # forward
            flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)

            # warped cloth mask one hot
            warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()

            if opt.clothmask_composition != 'no_composition':
                if opt.clothmask_composition == 'detach':
                    cloth_mask = torch.ones_like(fake_segmap)
                    cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                    fake_segmap = fake_segmap * cloth_mask
                if opt.clothmask_composition == 'warp_grad':
                    cloth_mask = torch.ones_like(fake_segmap)
                    cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                    fake_segmap = fake_segmap * cloth_mask
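
            # Score the generated segmentation with the discriminator (rejection-sampling branch)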
            if D is not None:
                fake_segmap_softmax = F.softmax(fake_segmap, dim=1)
                pred_segmap = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax), dim=1))
                score = D_logit(pred_segmap)
                # score = torch.exp(score) / opt.norm_const
                score = (score / (1 - score)) / opt.norm_const
                print("prob0", score)
                for i in range(cm_paired.shape[0]):
                    name = inputs['c_name']['paired'][i].replace('.jpg', '.png')
                    D_score.append((name, score[i].item()))

            # generated fake cloth mask & misalign mask
            fake_clothmask = (torch.argmax(fake_segmap.detach(), dim=1, keepdim=True) == 3).long()
            misalign = fake_clothmask - warped_cm_onehot
            misalign[misalign < 0.0] = 0.0
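
            # Save one visualization grid per sample: inputs, warped cloth, predicted / GT parse maps, misalignment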
            for i in range(c_paired.shape[0]):
                grid = make_grid([(c_paired[i].cpu() / 2 + 0.5), (cm_paired[i].cpu()).expand(3, -1, -1), visualize_segmap(parse_agnostic.cpu(), batch=i), ((densepose.cpu()[i] + 1) / 2),
                                  (im_c[i].cpu() / 2 + 0.5), parse_cloth_mask[i].cpu().expand(3, -1, -1), (warped_cloth_paired[i].cpu().detach() / 2 + 0.5), (warped_cm_onehot[i].cpu().detach()).expand(3, -1, -1),
                                  visualize_segmap(label.cpu(), batch=i), visualize_segmap(fake_segmap.cpu(), batch=i), (im[i] / 2 + 0.5), (misalign[i].cpu().detach()).expand(3, -1, -1)],
                                 nrow=4)
                save_image(grid, os.path.join('./output', opt.tocg_checkpoint.split('/')[-2], opt.tocg_checkpoint.split('/')[-1],
                                              opt.datamode, opt.datasetting, 'multi-task',
                                              (inputs['c_name']['paired'][i].split('.')[0] + '_' +
                                               inputs['c_name']['unpaired'][i].split('.')[0] + '.png')))

        num += c_paired.shape[0]
        print(num)

    if D is not None:
        D_score.sort(key=lambda x: x[1], reverse=True)
        # Save D_score
        score_path = os.path.join('./output', opt.tocg_checkpoint.split('/')[-2], opt.tocg_checkpoint.split('/')[-1],
                                  opt.datamode, opt.datasetting, 'multi-task', 'rejection_prob.txt')
        with open(score_path, 'a') as f:
            for name, score in D_score:
                f.write(name + ' ' + str(score) + '\n')

    print(f"Test time {time.time() - iter_start_time}")
def main():
    opt = get_opt()
    print(opt)
    print("Start to test!")

    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    # create test dataset & loader
    test_dataset = CPDatasetTest(opt)
    test_loader = CPDataLoader(opt, test_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.tocg_checkpoint.split('/')[-2], opt.tocg_checkpoint.split('/')[-1], opt.datamode, opt.datasetting))

    # Model
    input1_nc = 4  # cloth + cloth-mask
    input2_nc = opt.semantic_nc + 3  # parse_agnostic + densepose
    tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc, ngf=96, norm_layer=nn.BatchNorm2d)
    if opt.D_checkpoint != '' and os.path.exists(opt.D_checkpoint):
        if opt.norm_const is None:
            raise NotImplementedError("--norm_const is required when a D checkpoint is given")
        D = define_D(input_nc=input1_nc + input2_nc + opt.output_nc, Ddownx2=opt.Ddownx2, Ddropout=opt.Ddropout, n_layers_D=3, spectral=opt.spectral, num_D=opt.num_D)
    else:
        D = None

    # Load checkpoints
    load_checkpoint(tocg, opt.tocg_checkpoint)
    if D is not None:
        load_checkpoint(D, opt.D_checkpoint)

    # Test
    test(opt, test_loader, board, tocg, D=D)

    print("Finished testing!")


if __name__ == "__main__":
    main()