Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from torchvision.utils import make_grid, save_image | |
import argparse | |
import os | |
import time | |
from cp_dataset import CPDatasetTest, CPDataLoader | |
from networks import ConditionGenerator, load_checkpoint, define_D | |
from tqdm import tqdm | |
from tensorboardX import SummaryWriter | |
from utils import * | |
from get_norm_const import D_logit | |
def get_opt(argv=None):
    """Build the command-line parser for the condition-generator test script and parse it.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list lets the
            options be built from tests or other scripts.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=8)
    parser.add_argument('--fp16', action='store_true', help='use amp')
    # data
    parser.add_argument("--dataroot", default="./data/zalando-hd-resize")
    parser.add_argument("--datamode", default="test")
    parser.add_argument("--data_list", default="test_pairs.txt")
    parser.add_argument("--datasetting", default="paired")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)
    # logging / checkpoints
    parser.add_argument('--tensorboard_dir', type=str, default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--tocg_checkpoint', type=str, default='', help='tocg checkpoint')
    parser.add_argument('--D_checkpoint', type=str, default='', help='D checkpoint')
    parser.add_argument("--tensorboard_count", type=int, default=100)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
    parser.add_argument("--semantic_nc", type=int, default=13)
    parser.add_argument("--output_nc", type=int, default=13)
    # network
    parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
    parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
    # training
    parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'], default='warp_grad')
    # Hyper-parameters
    parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
    parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")
    # Discriminator
    parser.add_argument('--Ddownx2', action='store_true', help="Downsample D's input to increase the receptive field")
    parser.add_argument('--Ddropout', action='store_true', help="Apply dropout to D")
    # The old help text ('Generator ngf') was a copy-paste error: this value is
    # forwarded to define_D(num_D=...).
    parser.add_argument('--num_D', type=int, default=2, help='Number of discriminators passed to define_D (num_D)')
    parser.add_argument('--spectral', action='store_true', help="Apply spectral normalization to D")
    parser.add_argument('--norm_const', type=float,
                        help='Normalizing constant for rejection sampling (required when a D checkpoint is used)')
    return parser.parse_args(argv)
def test(opt, test_loader, board, tocg, D=None):
    """Run the condition generator over the test loader and save visualization grids.

    For every batch, ``tocg`` is fed (cloth, binarized cloth mask) and
    (parse_agnostic, densepose). It returns a segmentation map, a warped cloth
    and a warped cloth mask. A 12-panel grid (4 per row) of inputs, ground truth
    and predictions is saved per sample under
    ``./output/<ckpt dir>/<ckpt file>/<datamode>/<datasetting>/multi-task/``.

    When ``D`` is given, each sample is also scored through ``D_logit``. The
    scores are appended, highest first, to ``rejection_prob.txt`` in that same
    directory.

    Args:
        opt: parsed options (see ``get_opt``).
        test_loader: object whose ``data_loader`` attribute yields batch dicts.
        board: tensorboard writer; accepted for interface compatibility, not used here.
        tocg: the condition generator network.
        D: optional discriminator used for rejection scoring.
    """
    tocg.cuda()
    tocg.eval()
    if D is not None:
        D.cuda()
        D.eval()

    # Output folder keyed by the last two components of the checkpoint path,
    # e.g. "checkpoints/exp/tocg.pth" -> ./output/exp/tocg.pth/...
    ckpt_parts = opt.tocg_checkpoint.split('/')
    output_dir = os.path.join('./output', ckpt_parts[-2], ckpt_parts[-1],
                              opt.datamode, opt.datasetting, 'multi-task')
    os.makedirs(output_dir, exist_ok=True)

    num = 0
    iter_start_time = time.time()
    D_score = []  # (cloth file name, score) pairs; only filled when D is given

    for inputs in test_loader.data_loader:
        # input1: cloth image + cloth mask binarized at 0.5. Thresholding on the
        # device gives the same float32 mask without a CPU round-trip, and
        # avoids the `np.float` alias that was removed in NumPy 1.24.
        c_paired = inputs['cloth'][opt.datasetting].cuda()
        cm_paired = (inputs['cloth_mask'][opt.datasetting].cuda() > 0.5).float()
        # input2
        parse_agnostic = inputs['parse_agnostic'].cuda()
        densepose = inputs['densepose'].cuda()
        # ground truth, only used for the visualization grid below
        label = inputs['parse'].cuda()
        parse_cloth_mask = inputs['pcm'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im = inputs['image']

        with torch.no_grad():
            input1 = torch.cat([c_paired, cm_paired], 1)
            input2 = torch.cat([parse_agnostic, densepose], 1)
            _, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)

            # binarized warped cloth mask
            warped_cm_onehot = (warped_clothmask_paired > 0.5).float()

            # Gate channel 3 of the predicted segmap with the warped cloth mask:
            # 'detach' uses the binarized mask, 'warp_grad' the raw one.
            if opt.clothmask_composition in ('detach', 'warp_grad'):
                cloth_mask = torch.ones_like(fake_segmap)
                if opt.clothmask_composition == 'detach':
                    cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                else:
                    cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                fake_segmap = fake_segmap * cloth_mask

            if D is not None:
                fake_segmap_softmax = torch.softmax(fake_segmap, dim=1)
                pred_segmap = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax), dim=1))
                score = D_logit(pred_segmap)
                # NOTE(review): assumes D_logit returns a per-sample probability p
                # in (0, 1); p / (1 - p) is then the odds, scaled by norm_const.
                score = (score / (1 - score)) / opt.norm_const
                print("prob0", score)
                for i in range(c_paired.shape[0]):
                    name = inputs['c_name']['paired'][i].replace('.jpg', '.png')
                    D_score.append((name, score[i].item()))

            # Pixels predicted as cloth (argmax == 3) that fall outside the
            # binarized warped cloth mask.
            fake_clothmask = (torch.argmax(fake_segmap.detach(), dim=1, keepdim=True) == 3).long()
            misalign = (fake_clothmask - warped_cm_onehot).clamp(min=0.0)

            # Move whole-batch tensors to the CPU once, not once per sample.
            parse_agnostic_cpu = parse_agnostic.cpu()
            densepose_cpu = densepose.cpu()
            label_cpu = label.cpu()
            fake_segmap_cpu = fake_segmap.cpu()

            for i in range(c_paired.shape[0]):
                grid = make_grid([
                    c_paired[i].cpu() / 2 + 0.5,
                    cm_paired[i].cpu().expand(3, -1, -1),
                    visualize_segmap(parse_agnostic_cpu, batch=i),
                    (densepose_cpu[i] + 1) / 2,
                    im_c[i].cpu() / 2 + 0.5,
                    parse_cloth_mask[i].cpu().expand(3, -1, -1),
                    warped_cloth_paired[i].cpu().detach() / 2 + 0.5,
                    warped_cm_onehot[i].cpu().detach().expand(3, -1, -1),
                    visualize_segmap(label_cpu, batch=i),
                    visualize_segmap(fake_segmap_cpu, batch=i),
                    im[i] / 2 + 0.5,
                    misalign[i].cpu().detach().expand(3, -1, -1),
                ], nrow=4)
                file_name = (inputs['c_name']['paired'][i].split('.')[0] + '_' +
                             inputs['c_name']['unpaired'][i].split('.')[0] + '.png')
                save_image(grid, os.path.join(output_dir, file_name))

        num += c_paired.shape[0]
        print(num)

    if D_score:
        # Highest score first. The file is opened once in append mode, so
        # re-running the same configuration adds a second copy of the list.
        D_score.sort(key=lambda item: item[1], reverse=True)
        with open(os.path.join(output_dir, 'rejection_prob.txt'), 'a') as f:
            f.writelines(name + ' ' + str(prob) + '\n' for name, prob in D_score)

    print(f"Test time {time.time() - iter_start_time}")
def main():
    """Parse options, build the test loader and networks, load checkpoints, and run ``test``.

    Raises:
        ValueError: if ``--tocg_checkpoint`` has no directory component (its last
            two path parts name the output folders), or if a D checkpoint is used
            without ``--norm_const``.
    """
    opt = get_opt()
    print(opt)
    print("Start to test %s!" % opt.tocg_checkpoint)

    # Only restrict visible devices when specific GPUs were requested: exporting
    # an empty CUDA_VISIBLE_DEVICES hides every GPU, yet all models and tensors
    # below are moved with .cuda(). This must run before CUDA is first initialized.
    if opt.gpu_ids:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    # The last two components of the checkpoint path name the tensorboard and
    # output sub-folders, so fail early with a clear message instead of an IndexError.
    ckpt_parts = opt.tocg_checkpoint.split('/')
    if len(ckpt_parts) < 2:
        raise ValueError("--tocg_checkpoint must be a path with a directory component, "
                         "e.g. checkpoints/<experiment>/<file>")

    use_D = opt.D_checkpoint != '' and os.path.exists(opt.D_checkpoint)
    if opt.D_checkpoint and not use_D:
        print(f"Warning: D checkpoint '{opt.D_checkpoint}' not found; running without D scoring.")
    # Check before building the dataset so a bad invocation fails fast.
    if use_D and opt.norm_const is None:
        raise ValueError("--norm_const is required when --D_checkpoint is given")

    # create test dataset & loader
    test_dataset = CPDatasetTest(opt)
    test_loader = CPDataLoader(opt, test_dataset)

    # visualization
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, ckpt_parts[-2], ckpt_parts[-1],
                                               opt.datamode, opt.datasetting))

    # Model
    input1_nc = 4  # cloth + cloth-mask
    input2_nc = opt.semantic_nc + 3  # parse_agnostic + densepose
    tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc,
                              ngf=96, norm_layer=nn.BatchNorm2d)
    if use_D:
        D = define_D(input_nc=input1_nc + input2_nc + opt.output_nc, Ddownx2=opt.Ddownx2,
                     Ddropout=opt.Ddropout, n_layers_D=3, spectral=opt.spectral, num_D=opt.num_D)
    else:
        D = None

    # Load checkpoints
    load_checkpoint(tocg, opt.tocg_checkpoint)
    if D is not None:
        load_checkpoint(D, opt.D_checkpoint)

    try:
        test(opt, test_loader, board, tocg, D=D)
    finally:
        board.close()
    print("Finished testing!")


if __name__ == "__main__":
    main()