# train_condition.py
# Training script for the condition generator (tocg) and segmentation
# discriminator: cloth warping + segmentation-map generation.
import argparse
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from torch.utils.data import Subset
from torchvision.utils import make_grid
from tqdm import tqdm

from cp_dataset import CPDataset, CPDatasetTest, CPDataLoader
from networks import ConditionGenerator, VGGLoss, GANLoss, load_checkpoint, save_checkpoint, define_D
from networks import make_grid as mkgrid
from utils import *
def iou_metric(y_pred_batch, y_true_batch):
    """Return the mean per-sample IoU between predictions and binary targets.

    Predictions are soft scores and are binarized at 0.5 before comparison;
    each sample's IoU is averaged over the batch. A small epsilon keeps the
    ratio finite when both masks are empty.
    """
    batch_size = y_pred_batch.shape[0]
    mean_iou = 0
    for idx in range(batch_size):
        # Threshold soft scores into a hard mask, then flatten both masks.
        pred = (y_pred_batch[idx] > 0.5).flatten()
        true = y_true_batch[idx].flatten()
        inter = torch.sum(pred[true == 1])
        union = torch.sum(pred) + torch.sum(true) - inter
        mean_iou += (inter + 1e-7) / (union + 1e-7) / batch_size
    return mean_iou
def remove_overlap(seg_out, warped_cm):
    """Suppress warped cloth-mask pixels claimed by non-cloth segmentation channels.

    Sums segmentation channels 1-2 and 5+ (everything except background,
    channels 3-4) into a per-pixel occlusion weight and scales the warped
    cloth mask down by that weight. Expects a 4-D (N, 1, H, W) mask.
    """
    assert len(warped_cm.shape) == 4
    occlusion = torch.cat(
        [seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1
    ).sum(dim=1, keepdim=True)
    return warped_cm - occlusion * warped_cm
def get_opt():
    """Build and parse the command-line options for condition-generator training.

    Returns:
        argparse.Namespace with dataset, network, training-schedule,
        visualization, and hyper-parameter options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--name", default="test")
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=8)
    parser.add_argument('--fp16', action='store_true', help='use amp')

    # dataset
    parser.add_argument("--dataroot", default="./data/")
    parser.add_argument("--datamode", default="train")
    parser.add_argument("--data_list", default="train_pairs.txt")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)

    # logging / checkpointing
    parser.add_argument('--tensorboard_dir', type=str, default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--tocg_checkpoint', type=str, default='', help='tocg checkpoint')
    parser.add_argument("--tensorboard_count", type=int, default=100)
    parser.add_argument("--display_count", type=int, default=100)
    parser.add_argument("--save_count", type=int, default=10000)
    parser.add_argument("--load_step", type=int, default=0)
    parser.add_argument("--keep_step", type=int, default=300000)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
    parser.add_argument("--semantic_nc", type=int, default=13)
    parser.add_argument("--output_nc", type=int, default=13)

    # network
    parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
    parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
    parser.add_argument('--Ddownx2', action='store_true', help="Downsample D's input to increase the receptive field")
    parser.add_argument('--Ddropout', action='store_true', help="Apply dropout to D")
    # fix: help text said 'Generator ngf' (copy-paste error); this option is
    # forwarded to define_D as the number of discriminators.
    parser.add_argument('--num_D', type=int, default=2, help='number of discriminators')

    # Cuda availability
    # NOTE(review): plain default=False with no type/action -- any value given
    # on the command line arrives as a (truthy) string. Kept unchanged for
    # backward compatibility with existing launch scripts; confirm intent.
    parser.add_argument('--cuda', default=False, help='cuda or cpu')

    # training
    parser.add_argument("--G_D_seperate", action='store_true')
    parser.add_argument("--no_GAN_loss", action='store_true')
    parser.add_argument("--lasttvonly", action='store_true')
    parser.add_argument("--interflowloss", action='store_true', help="Intermediate flow loss")
    parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'], default='warp_grad')
    parser.add_argument('--edgeawaretv', type=str, choices=['no_edge', 'last_only', 'weighted'], default="no_edge", help="Edge aware TV loss")
    parser.add_argument('--add_lasttv', action='store_true')

    # test visualize
    parser.add_argument("--no_test_visualize", action='store_true')
    parser.add_argument("--num_test_visualize", type=int, default=3)
    parser.add_argument("--test_datasetting", default="unpaired")
    parser.add_argument("--test_dataroot", default="./data/")
    parser.add_argument("--test_data_list", default="test_pairs.txt")

    # Hyper-parameters
    parser.add_argument('--G_lr', type=float, default=0.0002, help='Generator initial learning rate for adam')
    parser.add_argument('--D_lr', type=float, default=0.0002, help='Discriminator initial learning rate for adam')
    # fix: help text said 'initial learning rate for adam' (copy-paste error);
    # this is the cross-entropy loss weight.
    parser.add_argument('--CElamda', type=float, default=10, help='cross-entropy loss weight')
    parser.add_argument('--GANlambda', type=float, default=1)
    parser.add_argument('--tvlambda', type=float, default=2)
    parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
    # fix: default was the string '1000' for an int-typed option; argparse
    # happened to re-parse string defaults through `type`, but an int default
    # is the robust form.
    parser.add_argument('--val_count', type=int, default=1000)
    parser.add_argument('--spectral', action='store_true', help="Apply spectral normalization to D")
    parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")

    opt = parser.parse_args()
    return opt
def train(opt, train_loader, test_loader, val_loader, board, tocg, D):
    """Run the condition-generator / discriminator training loop.

    Executes ``opt.keep_step - opt.load_step`` iterations of joint cloth
    warping + segmentation training, with periodic validation (mean IoU over
    ~2000 val samples), tensorboard logging, test-set visualization, and
    checkpoint saving.

    Args:
        opt: parsed options from get_opt().
        train_loader: CPDataLoader yielding training batches via next_batch().
        test_loader: loader used only for visualization (when enabled).
        val_loader: loader over the validation subset (used every val_count steps).
        board: tensorboardX SummaryWriter.
        tocg: condition generator network.
        D: discriminator network.

    Fixes vs. the previous revision:
      * binarization is done in torch instead of a CPU/numpy round trip using
        ``np.float`` (alias removed in NumPy >= 1.24, which made this crash);
      * validation recomputes ``warped_cm_onehot`` per batch instead of
        reusing the stale training-batch tensor;
      * discriminator losses are only logged/printed when the GAN loss is
        enabled (was a NameError under --no_GAN_loss);
      * something is now printed at display_count even under --no_GAN_loss.
    """
    # Model
    # NOTE(review): .cuda() is unconditional even though a --cuda option
    # exists -- confirm CPU-only training is intentionally unsupported.
    tocg.cuda()
    tocg.train()
    D.cuda()
    D.train()

    # criterion: L1 (cloth mask), VGG perceptual (warped cloth), LSGAN (segmap)
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss(opt)
    if opt.fp16:
        criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.HalfTensor)
    else:
        criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor if opt.gpu_ids else torch.Tensor)

    # optimizer
    optimizer_G = torch.optim.Adam(tocg.parameters(), lr=opt.G_lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.D_lr, betas=(0.5, 0.999))

    for step in tqdm(range(opt.load_step, opt.keep_step)):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        # input1: cloth image + cloth mask
        c_paired = inputs['cloth']['paired'].cuda()
        cm_paired = inputs['cloth_mask']['paired'].cuda()
        # Binarize the mask on-device (fix for the removed np.float alias).
        cm_paired = (cm_paired > 0.5).float()
        # input2: cloth-agnostic parse + densepose
        parse_agnostic = inputs['parse_agnostic'].cuda()
        densepose = inputs['densepose'].cuda()
        openpose = inputs['pose'].cuda()
        # GT
        label_onehot = inputs['parse_onehot'].cuda()  # CE
        label = inputs['parse'].cuda()  # GAN loss
        parse_cloth_mask = inputs['pcm'].cuda()  # L1
        im_c = inputs['parse_cloth'].cuda()  # VGG
        # visualization
        im = inputs['image']

        # inputs
        input1 = torch.cat([c_paired, cm_paired], 1)
        input2 = torch.cat([parse_agnostic, densepose], 1)

        # forward
        flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)

        # warped cloth mask one hot (binarized, gradient-free)
        warped_cm_onehot = (warped_clothmask_paired.detach() > 0.5).float()

        # fake segmap cloth channel (index 3) * warped clothmask
        if opt.clothmask_composition != 'no_composition':
            if opt.clothmask_composition == 'detach':
                cloth_mask = torch.ones_like(fake_segmap.detach())
                cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                fake_segmap = fake_segmap * cloth_mask
            if opt.clothmask_composition == 'warp_grad':
                cloth_mask = torch.ones_like(fake_segmap.detach())
                cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                fake_segmap = fake_segmap * cloth_mask

        # occlusion handling: drop cloth pixels claimed by other segmentation
        # channels, then paste the warped cloth onto a white background
        if opt.occlusion:
            warped_clothmask_paired = remove_overlap(F.softmax(fake_segmap, dim=1), warped_clothmask_paired)
            warped_cloth_paired = warped_cloth_paired * warped_clothmask_paired + torch.ones_like(warped_cloth_paired) * (1 - warped_clothmask_paired)

        # generated fake cloth mask & misalign mask
        fake_clothmask = (torch.argmax(fake_segmap.detach(), dim=1, keepdim=True) == 3).long()
        misalign = fake_clothmask - warped_cm_onehot
        misalign[misalign < 0.0] = 0.0

        # loss warping
        loss_l1_cloth = criterionL1(warped_clothmask_paired, parse_cloth_mask)
        loss_vgg = criterionVGG(warped_cloth_paired, im_c)

        # total-variation smoothness on the predicted flows
        loss_tv = 0
        if opt.edgeawaretv == 'no_edge':
            if not opt.lasttvonly:
                for flow in flow_list:
                    y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                    x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                    loss_tv = loss_tv + y_tv + x_tv
            else:
                for flow in flow_list[-1:]:
                    y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                    x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                    loss_tv = loss_tv + y_tv + x_tv
        else:
            # edge-aware TV: exponentially down-weight flow gradients where the
            # warped cloth mask itself has an edge
            if opt.edgeawaretv == 'last_only':
                flow = flow_list[-1]
                warped_clothmask_paired_down = F.interpolate(warped_clothmask_paired, flow.shape[1:3], mode='bilinear')
                y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :])
                x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :])
                mask_y = torch.exp(-150 * torch.abs(warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, 1:, :, :] - warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :-1, :, :]))
                mask_x = torch.exp(-150 * torch.abs(warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :, 1:, :] - warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :, :-1, :]))
                y_tv = y_tv * mask_y
                x_tv = x_tv * mask_x
                y_tv = y_tv.mean()
                x_tv = x_tv.mean()
                loss_tv = loss_tv + y_tv + x_tv
            elif opt.edgeawaretv == 'weighted':
                for i in range(5):
                    flow = flow_list[i]
                    warped_clothmask_paired_down = F.interpolate(warped_clothmask_paired, flow.shape[1:3], mode='bilinear')
                    y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :])
                    x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :])
                    mask_y = torch.exp(-150 * torch.abs(warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, 1:, :, :] - warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :-1, :, :]))
                    mask_x = torch.exp(-150 * torch.abs(warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :, 1:, :] - warped_clothmask_paired_down.permute(0, 2, 3, 1)[:, :, :-1, :]))
                    y_tv = y_tv * mask_y
                    x_tv = x_tv * mask_x
                    # coarser pyramid levels are down-weighted by powers of two
                    y_tv = y_tv.mean() / (2 ** (4 - i))
                    x_tv = x_tv.mean() / (2 ** (4 - i))
                    loss_tv = loss_tv + y_tv + x_tv

            if opt.add_lasttv:
                for flow in flow_list[-1:]:
                    y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                    x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                    loss_tv = loss_tv + y_tv + x_tv

        N, _, iH, iW = c_paired.size()

        # Intermediate flow loss: supervise every flow except the last (which
        # the main warping losses already cover), upsampled to full resolution
        # and weighted by pyramid level.
        if opt.interflowloss:
            for i in range(len(flow_list) - 1):
                flow = flow_list[i]
                N, fH, fW, _ = flow.size()
                grid = mkgrid(N, iH, iW)
                flow = F.interpolate(flow.permute(0, 3, 1, 2), size=c_paired.shape[2:], mode=opt.upsample).permute(0, 2, 3, 1)
                flow_norm = torch.cat([flow[:, :, :, 0:1] / ((fW - 1.0) / 2.0), flow[:, :, :, 1:2] / ((fH - 1.0) / 2.0)], 3)
                warped_c = F.grid_sample(c_paired, flow_norm + grid, padding_mode='border')
                warped_cm = F.grid_sample(cm_paired, flow_norm + grid, padding_mode='border')
                warped_cm = remove_overlap(F.softmax(fake_segmap, dim=1), warped_cm)
                loss_l1_cloth += criterionL1(warped_cm, parse_cloth_mask) / (2 ** (4 - i))
                loss_vgg += criterionVGG(warped_c, im_c) / (2 ** (4 - i))

        # loss segmentation (generator): cross entropy against the GT parse
        CE_loss = cross_entropy2d(fake_segmap, label_onehot.transpose(0, 1)[0].long())

        if opt.no_GAN_loss:
            loss_G = (10 * loss_l1_cloth + loss_vgg + opt.tvlambda * loss_tv) + (CE_loss * opt.CElamda)
            # step
            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()
        else:
            fake_segmap_softmax = torch.softmax(fake_segmap, 1)
            pred_segmap = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax), dim=1))
            loss_G_GAN = criterionGAN(pred_segmap, True)

            if not opt.G_D_seperate:
                # joint update: discriminator losses from the same forward pass
                fake_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax.detach()), dim=1))
                real_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), label), dim=1))
                loss_D_fake = criterionGAN(fake_segmap_pred, False)
                loss_D_real = criterionGAN(real_segmap_pred, True)

                # loss sum
                loss_G = (10 * loss_l1_cloth + loss_vgg + opt.tvlambda * loss_tv) + (CE_loss * opt.CElamda + loss_G_GAN * opt.GANlambda)  # warping + seg_generation
                loss_D = loss_D_fake + loss_D_real

                # step
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
            else:  # train G first, then D on a fresh (no-grad) generator output
                # loss G sum
                loss_G = (10 * loss_l1_cloth + loss_vgg + opt.tvlambda * loss_tv) + (CE_loss * opt.CElamda + loss_G_GAN * opt.GANlambda)  # warping + seg_generation

                # step G
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # discriminator sees the just-updated generator's output
                with torch.no_grad():
                    _, fake_segmap, _, _ = tocg(input1, input2)
                fake_segmap_softmax = torch.softmax(fake_segmap, 1)

                # loss discriminator
                fake_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax.detach()), dim=1))
                real_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), label), dim=1))
                loss_D_fake = criterionGAN(fake_segmap_pred, False)
                loss_D_real = criterionGAN(real_segmap_pred, True)
                loss_D = loss_D_fake + loss_D_real

                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()

        # Validation: mean segmentation IoU over ~2000 validation samples
        if (step + 1) % opt.val_count == 0:
            tocg.eval()
            iou_list = []
            with torch.no_grad():
                for cnt in range(2000 // opt.batch_size):
                    inputs = val_loader.next_batch()
                    # input1
                    c_paired = inputs['cloth']['paired'].cuda()
                    cm_paired = inputs['cloth_mask']['paired'].cuda()
                    cm_paired = (cm_paired > 0.5).float()
                    # input2
                    parse_agnostic = inputs['parse_agnostic'].cuda()
                    densepose = inputs['densepose'].cuda()
                    openpose = inputs['pose'].cuda()
                    # GT
                    label_onehot = inputs['parse_onehot'].cuda()  # CE
                    label = inputs['parse'].cuda()  # GAN loss
                    parse_cloth_mask = inputs['pcm'].cuda()  # L1
                    im_c = inputs['parse_cloth'].cuda()  # VGG
                    # visualization
                    im = inputs['image']

                    input1 = torch.cat([c_paired, cm_paired], 1)
                    input2 = torch.cat([parse_agnostic, densepose], 1)

                    # forward
                    flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)
                    # fix: binarize THIS batch's warped cloth mask; the previous
                    # revision reused the stale training-batch warped_cm_onehot
                    warped_cm_onehot = (warped_clothmask_paired.detach() > 0.5).float()

                    # fake segmap cloth channel * warped clothmask
                    if opt.clothmask_composition != 'no_composition':
                        if opt.clothmask_composition == 'detach':
                            cloth_mask = torch.ones_like(fake_segmap.detach())
                            cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                            fake_segmap = fake_segmap * cloth_mask
                        if opt.clothmask_composition == 'warp_grad':
                            cloth_mask = torch.ones_like(fake_segmap.detach())
                            cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                            fake_segmap = fake_segmap * cloth_mask

                    # calculate iou
                    iou = iou_metric(F.softmax(fake_segmap, dim=1).detach(), label)
                    iou_list.append(iou.item())
            tocg.train()
            board.add_scalar('val/iou', np.mean(iou_list), step + 1)

        # tensorboard
        if (step + 1) % opt.tensorboard_count == 0:
            # loss G
            board.add_scalar('Loss/G', loss_G.item(), step + 1)
            board.add_scalar('Loss/G/l1_cloth', loss_l1_cloth.item(), step + 1)
            board.add_scalar('Loss/G/vgg', loss_vgg.item(), step + 1)
            # float() also handles the (unlikely) case where loss_tv stayed 0
            board.add_scalar('Loss/G/tv', float(loss_tv), step + 1)
            board.add_scalar('Loss/G/CE', CE_loss.item(), step + 1)
            if not opt.no_GAN_loss:
                board.add_scalar('Loss/G/GAN', loss_G_GAN.item(), step + 1)
                # loss D -- fix: only defined when GAN training is enabled; the
                # previous revision logged these unconditionally and raised
                # NameError under --no_GAN_loss
                board.add_scalar('Loss/D', loss_D.item(), step + 1)
                board.add_scalar('Loss/D/pred_real', loss_D_real.item(), step + 1)
                board.add_scalar('Loss/D/pred_fake', loss_D_fake.item(), step + 1)

            grid = make_grid([(c_paired[0].cpu() / 2 + 0.5), (cm_paired[0].cpu()).expand(3, -1, -1), visualize_segmap(parse_agnostic.cpu()), ((densepose.cpu()[0]+1)/2),
                              (im_c[0].cpu() / 2 + 0.5), parse_cloth_mask[0].cpu().expand(3, -1, -1), (warped_cloth_paired[0].cpu().detach() / 2 + 0.5), (warped_cm_onehot[0].cpu().detach()).expand(3, -1, -1),
                              visualize_segmap(label.cpu()), visualize_segmap(fake_segmap.cpu()), (im[0]/2 +0.5), (misalign[0].cpu().detach()).expand(3, -1, -1)],
                             nrow=4)
            board.add_images('train_images', grid.unsqueeze(0), step + 1)

            if not opt.no_test_visualize:
                inputs = test_loader.next_batch()
                # input1
                c_paired = inputs['cloth'][opt.test_datasetting].cuda()
                cm_paired = inputs['cloth_mask'][opt.test_datasetting].cuda()
                cm_paired = (cm_paired > 0.5).float()
                # input2
                parse_agnostic = inputs['parse_agnostic'].cuda()
                densepose = inputs['densepose'].cuda()
                openpose = inputs['pose'].cuda()
                # GT
                label_onehot = inputs['parse_onehot'].cuda()  # CE
                label = inputs['parse'].cuda()  # GAN loss
                parse_cloth_mask = inputs['pcm'].cuda()  # L1
                im_c = inputs['parse_cloth'].cuda()  # VGG
                # visualization
                im = inputs['image']

                tocg.eval()
                with torch.no_grad():
                    # inputs
                    input1 = torch.cat([c_paired, cm_paired], 1)
                    input2 = torch.cat([parse_agnostic, densepose], 1)

                    # forward
                    flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(input1, input2)
                    warped_cm_onehot = (warped_clothmask_paired.detach() > 0.5).float()

                    if opt.clothmask_composition != 'no_composition':
                        if opt.clothmask_composition == 'detach':
                            cloth_mask = torch.ones_like(fake_segmap)
                            cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                            fake_segmap = fake_segmap * cloth_mask
                        if opt.clothmask_composition == 'warp_grad':
                            cloth_mask = torch.ones_like(fake_segmap)
                            cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                            fake_segmap = fake_segmap * cloth_mask

                    if opt.occlusion:
                        warped_clothmask_paired = remove_overlap(F.softmax(fake_segmap, dim=1), warped_clothmask_paired)
                        warped_cloth_paired = warped_cloth_paired * warped_clothmask_paired + torch.ones_like(warped_cloth_paired) * (1 - warped_clothmask_paired)

                    # generated fake cloth mask & misalign mask
                    fake_clothmask = (torch.argmax(fake_segmap.detach(), dim=1, keepdim=True) == 3).long()
                    misalign = fake_clothmask - warped_cm_onehot
                    misalign[misalign < 0.0] = 0.0

                for i in range(opt.num_test_visualize):
                    grid = make_grid([(c_paired[i].cpu() / 2 + 0.5), (cm_paired[i].cpu()).expand(3, -1, -1), visualize_segmap(parse_agnostic.cpu(), batch=i), ((densepose.cpu()[i]+1)/2),
                                      (im_c[i].cpu() / 2 + 0.5), parse_cloth_mask[i].cpu().expand(3, -1, -1), (warped_cloth_paired[i].cpu().detach() / 2 + 0.5), (warped_cm_onehot[i].cpu().detach()).expand(3, -1, -1),
                                      visualize_segmap(label.cpu(), batch=i), visualize_segmap(fake_segmap.cpu(), batch=i), (im[i]/2 +0.5), (misalign[i].cpu().detach()).expand(3, -1, -1)],
                                     nrow=4)
                    board.add_images(f'test_images/{i}', grid.unsqueeze(0), step + 1)
                tocg.train()

        # display
        if (step + 1) % opt.display_count == 0:
            t = time.time() - iter_start_time
            if not opt.no_GAN_loss:
                print("step: %8d, time: %.3f\nloss G: %.4f, L1_cloth loss: %.4f, VGG loss: %.4f, TV loss: %.4f CE: %.4f, G GAN: %.4f\nloss D: %.4f, D real: %.4f, D fake: %.4f"
                      % (step + 1, t, loss_G.item(), loss_l1_cloth.item(), loss_vgg.item(), float(loss_tv), CE_loss.item(), loss_G_GAN.item(), loss_D.item(), loss_D_real.item(), loss_D_fake.item()), flush=True)
            else:
                # fix: previously nothing was printed at all under --no_GAN_loss
                print("step: %8d, time: %.3f\nloss G: %.4f, L1_cloth loss: %.4f, VGG loss: %.4f, TV loss: %.4f CE: %.4f"
                      % (step + 1, t, loss_G.item(), loss_l1_cloth.item(), loss_vgg.item(), float(loss_tv), CE_loss.item()), flush=True)

        # save
        if (step + 1) % opt.save_count == 0:
            save_checkpoint(tocg, os.path.join(opt.checkpoint_dir, opt.name, 'tocg_step_%06d.pth' % (step + 1)), opt)
            save_checkpoint(D, os.path.join(opt.checkpoint_dir, opt.name, 'D_step_%06d.pth' % (step + 1)), opt)
def main():
    """Entry point: parse options, build datasets/loaders/models, train, save."""
    opt = get_opt()
    print(opt)
    print("Start to train %s!" % opt.name)
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    # create train dataset & loader
    train_dataset = CPDataset(opt)
    train_loader = CPDataLoader(opt, train_dataset)

    # create test/validation dataset & loaders
    test_loader = None
    # fix: val_loader was undefined (NameError) when --no_test_visualize
    val_loader = None
    if not opt.no_test_visualize:
        # temporarily swap opt over to the test configuration while the test
        # dataset is built, then restore the training batch size
        train_bsize = opt.batch_size
        opt.batch_size = opt.num_test_visualize
        opt.dataroot = opt.test_dataroot
        opt.datamode = 'test'
        opt.data_list = opt.test_data_list
        test_dataset = CPDatasetTest(opt)
        opt.batch_size = train_bsize
        # validation uses the first 2000 test samples
        val_dataset = Subset(test_dataset, np.arange(2000))
        test_loader = CPDataLoader(opt, test_dataset)
        val_loader = CPDataLoader(opt, val_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # make sure the checkpoint directory exists before any save
    os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)

    # Model
    input1_nc = 4  # cloth + cloth-mask
    input2_nc = opt.semantic_nc + 3  # parse_agnostic + densepose
    # consistency fix: pass input1_nc instead of a second hard-coded 4
    tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc, ngf=96, norm_layer=nn.BatchNorm2d)
    D = define_D(input_nc=input1_nc + input2_nc + opt.output_nc, Ddownx2=opt.Ddownx2, Ddropout=opt.Ddropout, n_layers_D=3, spectral=opt.spectral, num_D=opt.num_D)

    # Load Checkpoint
    if not opt.tocg_checkpoint == '' and os.path.exists(opt.tocg_checkpoint):
        load_checkpoint(tocg, opt.tocg_checkpoint)

    # Train
    # fix: loaders were passed as (val, test) but train() declares
    # (..., test_loader, val_loader, ...) -- pass them in signature order so
    # validation really runs on the 2000-sample subset and visualization on
    # the full test loader
    train(opt, train_loader, test_loader, val_loader, board, tocg, D)

    # Save Checkpoint
    save_checkpoint(tocg, os.path.join(opt.checkpoint_dir, opt.name, 'tocg_final.pth'), opt)
    save_checkpoint(D, os.path.join(opt.checkpoint_dir, opt.name, 'D_final.pth'), opt)
    print("Finished training %s!" % opt.name)
# Standard script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()