# Calculate the normalization constant for discriminator rejection
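#
# In discriminator rejection sampling, the ratio D(x) / (1 - D(x)) is used as an
# (unnormalized) estimate of the density ratio p_data(x) / p_g(x); dividing by the
# largest ratio observed over a reference set turns it into an acceptance
# probability. This script sweeps the training set with the frozen condition
# generator (tocg) and discriminator (D) and reports that largest ratio as M.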
import torch
import torch.nn as nn
import torch.nn.functional as F  # used below for F.softmax
import numpy as np               # used below for the mask thresholding
import argparse
import os
import time

from cp_dataset import CPDataset, CPDataLoader
from networks import ConditionGenerator, load_checkpoint, define_D
from utils import *


def get_opt():
    parser = argparse.ArgumentParser()

    parser.add_argument("--name", default="test")
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=8)
    parser.add_argument('--fp16', action='store_true', help='use amp')

    parser.add_argument("--dataroot", default="./data")
    parser.add_argument("--datamode", default="train")
    parser.add_argument("--data_list", default="train_pairs_zalando.txt")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)

    parser.add_argument('--tensorboard_dir', type=str, default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--D_checkpoint', type=str, default='', help='D checkpoint')
    parser.add_argument('--tocg_checkpoint', type=str, default='', help='tocg checkpoint')

    parser.add_argument("--tensorboard_count", type=int, default=100)
    parser.add_argument("--display_count", type=int, default=100)
    parser.add_argument("--save_count", type=int, default=10000)
    parser.add_argument("--load_step", type=int, default=0)
    parser.add_argument("--keep_step", type=int, default=300000)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')

    parser.add_argument("--semantic_nc", type=int, default=13)
    parser.add_argument("--output_nc", type=int, default=13)

    # Condition generator
    parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
    parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
    parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'], default='warp_grad')

    # network structure
    parser.add_argument('--Ddownx2', action='store_true', help="Downsample D's input to increase the receptive field")
    parser.add_argument('--Ddropout', action='store_true', help="Apply dropout to D")
    parser.add_argument('--num_D', type=int, default=2, help='number of discriminators in the multiscale discriminator')
    parser.add_argument('--spectral', action='store_true', help="Apply spectral normalization to D")

    parser.add_argument("--test_datasetting", default="unpaired")
    parser.add_argument("--test_dataroot", default="./data/zalando-hd-resize")
    parser.add_argument("--test_data_list", default="test_pairs.txt")

    opt = parser.parse_args()
    return opt
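

# D_logit averages the final-layer output of each scale of the multiscale
# discriminator over channels and spatial dimensions, then sums the per-scale
# means divided by 2 (which averages them for the default --num_D of 2),
# giving one scalar score per sample in the batch.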
def D_logit(pred):
    score = 0
    for i in pred:
        score += i[-1].mean((1, 2, 3)) / 2
    return score
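

# get_const makes roughly one pass over the training data with both networks
# frozen, converts every patch-averaged score s (for both real and generated
# segmentation maps) into the ratio s / (1 - s), and returns the largest ratio
# observed, which serves as the normalization constant M.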
def get_const(opt, train_loader, tocg, D, length):
    # Model
    D.cuda()
    D.eval()
    tocg.cuda()
    tocg.eval()

    logit_list = []
    i = 0
    for step in range(length // opt.batch_size):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        # input1
        c_paired = inputs['cloth']['paired'].cuda()
        cm_paired = inputs['cloth_mask']['paired'].cuda()
        cm_paired = torch.FloatTensor((cm_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
        # input2
        parse_agnostic = inputs['parse_agnostic'].cuda()
        densepose = inputs['densepose'].cuda()
        openpose = inputs['pose'].cuda()

        # GT
        label_onehot = inputs['parse_onehot'].cuda()  # CE
        label = inputs['parse'].cuda()  # GAN loss
        parse_cloth_mask = inputs['pcm'].cuda()  # L1
        im_c = inputs['parse_cloth'].cuda()  # VGG

        # visualization
        im = inputs['image']

        with torch.no_grad():
            # inputs
            input1 = torch.cat([c_paired, cm_paired], 1)
            input2 = torch.cat([parse_agnostic, densepose], 1)

            flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)
            if opt.clothmask_composition != 'no_composition':
                if opt.clothmask_composition == 'detach':
                    warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
                    cloth_mask = torch.ones_like(fake_segmap.detach())
                    cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                    fake_segmap = fake_segmap * cloth_mask
                if opt.clothmask_composition == 'warp_grad':
                    cloth_mask = torch.ones_like(fake_segmap.detach())
                    cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                    fake_segmap = fake_segmap * cloth_mask

            fake_segmap_softmax = F.softmax(fake_segmap, dim=1)

            real_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), label), dim=1))
            fake_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax), dim=1))
print("real:", D_logit(real_segmap_pred), "fake:", D_logit(fake_segmap_pred))
# print(fake_segmap_pred)
logit_real = D_logit(real_segmap_pred)
logit_fake = D_logit(fake_segmap_pred)
for l in logit_real:
l = l / (1-l)
logit_list.append(l.item())
for l in logit_fake:
l = l / (1-l)
logit_list.append(l.item())
# i += logit_real.shape[0]+logit_fake.shape[0]
print("i:", i)
logit_list.sort()
return logit_list[-1]
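

# A minimal sketch (not part of this script's flow) of how M would typically be
# used at inference time for rejection sampling. The helper name below and its
# exact interface are illustrative assumptions, not functions from this repository;
# `score` is assumed to be a patch-averaged discriminator output in (0, 1), as
# produced by D_logit above.
def rejection_accept(score, M):
    # density-ratio estimate D / (1 - D)
    ratio = score / (1 - score)
    # normalizing by the constant M bounds the acceptance probability at 1
    accept_prob = torch.clamp(ratio / M, max=1.0)
    # Bernoulli accept/reject decision per sample
    return torch.rand_like(accept_prob) < accept_prob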


def main():
    opt = get_opt()
    print(opt)

    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    # create train dataset & loader
    train_dataset = CPDataset(opt)
    train_loader = CPDataLoader(opt, train_dataset)

    # Model
    input1_nc = 4  # cloth + cloth-mask
    input2_nc = opt.semantic_nc + 3  # parse_agnostic + densepose
    D = define_D(input_nc=input1_nc + input2_nc + opt.output_nc, Ddownx2=opt.Ddownx2, Ddropout=opt.Ddropout,
                 n_layers_D=3, spectral=opt.spectral, num_D=opt.num_D)
    tocg = ConditionGenerator(opt, input1_nc=4, input2_nc=input2_nc, output_nc=opt.output_nc, ngf=96,
                              norm_layer=nn.BatchNorm2d)

    # Load Checkpoint
    load_checkpoint(D, opt.D_checkpoint)
    load_checkpoint(tocg, opt.tocg_checkpoint)

    M = get_const(opt, train_loader, tocg, D, length=len(train_dataset))
    print("M:", M)


if __name__ == "__main__":
    main()