|
import argparse |
|
from pathlib import Path |
|
from util import util |
|
import torch |
|
import models |
|
import data |
|
|
|
|
|
class BaseOptions:
    """Define the command-line options shared by training and test runs.

    The class also provides helpers for parsing, printing, and saving the
    options, and it collects the extra options contributed by the
    ``modify_commandline_options`` hooks that are looked up through
    ``models.get_option_setter`` and ``data.get_option_setter``.

    Note:
        ``gather_options`` and ``parse`` read ``self.isTrain``, which this
        class never assigns. It is expected to be provided by a subclass
        (or set on the instance) before either method is called.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized."""
        self.initialized = False

    def initialize(self, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Define the common options that are used in both training and test.

        Parameters:
            parser -- the argparse parser to which the options are added

        Returns:
            The same parser, with the common options registered.
        """
        # basic parameters
        parser.add_argument("--dataroot", required=True, help="path to images (should have subfolders trainA, trainB, valA, valB, etc)")
        parser.add_argument("--name", type=str, default="experiment_name", help="name of the experiment. It decides where to store samples and models")
        parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here")

        # model parameters
        parser.add_argument("--model", type=str, default="cycle_gan", help="chooses which model to use. [cycle_gan | pix2pix | test | colorization]")
        parser.add_argument("--input_nc", type=int, default=3, help="# of input image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in the last conv layer")
        parser.add_argument("--ndf", type=int, default=64, help="# of discrim filters in the first conv layer")
        parser.add_argument("--netD", type=str, default="basic", help="specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator")
        parser.add_argument("--netG", type=str, default="resnet_9blocks", help="specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]")
        parser.add_argument("--n_layers_D", type=int, default=3, help="only used if netD==n_layers")
        parser.add_argument("--norm", type=str, default="instance", help="instance normalization or batch normalization [instance | batch | none | syncbatch]")
        parser.add_argument("--init_type", type=str, default="normal", help="network initialization [normal | xavier | kaiming | orthogonal]")
        parser.add_argument("--init_gain", type=float, default=0.02, help="scaling factor for normal, xavier and orthogonal.")
        parser.add_argument("--no_dropout", action="store_true", help="no dropout for the generator")

        # dataset parameters
        parser.add_argument("--dataset_mode", type=str, default="unaligned", help="chooses how datasets are loaded. [unaligned | aligned | single | colorization]")
        parser.add_argument("--direction", type=str, default="AtoB", help="AtoB or BtoA")
        parser.add_argument("--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly")
        parser.add_argument("--num_threads", default=4, type=int, help="# threads for loading data")
        parser.add_argument("--batch_size", type=int, default=1, help="input batch size")
        parser.add_argument("--load_size", type=int, default=286, help="scale images to this size")
        parser.add_argument("--crop_size", type=int, default=256, help="then crop to this size")
        # The default is float("inf") (not an int) so that "no limit" compares
        # larger than any dataset length; values given on the CLI are ints.
        parser.add_argument("--max_dataset_size", type=int, default=float("inf"), help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.")
        parser.add_argument("--preprocess", type=str, default="resize_and_crop", help="scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]")
        parser.add_argument("--no_flip", action="store_true", help="if specified, do not flip the images for data augmentation")
        parser.add_argument("--display_winsize", type=int, default=256, help="display window size for both visdom and HTML")

        # additional parameters
        parser.add_argument("--epoch", type=str, default="latest", help="which epoch to load? set to latest to use latest cached model")
        # The default is the int 0 (not the string "0") so that
        # parser.get_default() matches the parsed value in print_options.
        parser.add_argument("--load_iter", type=int, default=0, help="which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]")
        parser.add_argument("--verbose", action="store_true", help="if specified, print more debugging information")
        parser.add_argument("--suffix", default="", type=str, help="customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}")

        # wandb parameters
        parser.add_argument("--use_wandb", action="store_true", help="if specified, then init wandb logging")
        parser.add_argument("--wandb_project_name", type=str, default="CycleGAN-and-pix2pix", help="specify wandb project name")

        self.initialized = True
        return parser

    def gather_options(self) -> argparse.Namespace:
        """Build the full parser (only once) and parse the command line.

        On the first call a parser is created with the common options, then
        extended with the model-specific and dataset-specific options
        returned by ``models.get_option_setter`` and
        ``data.get_option_setter``. The finished parser is stored on
        ``self.parser``.

        On later calls the stored parser is reused as-is: re-running the
        option setters on it would register the same arguments twice, which
        argparse rejects.

        Returns:
            The namespace produced by ``parser.parse_args()``.
        """
        existing = getattr(self, "parser", None)
        if self.initialized and existing is not None:
            return existing.parse_args()

        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # get the basic options; unknown flags are tolerated at this stage
        # because the model/dataset options are not registered yet
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def _format_options(self, opt) -> str:
        """Return the human-readable option table for ``opt``.

        Options are listed in sorted key order; any value that differs from
        the default registered on ``self.parser`` is tagged with that default.
        """
        parts = ["----------------- Options ---------------\n"]
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = "\t[default: %s]" % str(default) if value != default else ""
            parts.append("{:>25}: {:<30}{}\n".format(str(key), str(value), comment))
        parts.append("----------------- End -------------------")
        return "".join(parts)

    def print_options(self, opt):
        """Print and save options.

        It will print both current options and default values (if different).
        It will save options into a text file:
        [checkpoints_dir] / [name] / [phase]_opt.txt

        Parameters:
            opt -- the parsed options; must provide ``checkpoints_dir``,
                   ``name`` and ``phase``
        """
        message = self._format_options(opt)
        print(message)

        # save to the disk
        expr_dir = Path(opt.checkpoints_dir) / opt.name
        util.mkdirs(expr_dir)
        file_name = expr_dir / f"{opt.phase}_opt.txt"
        # explicit encoding so non-ASCII option values are written the same
        # way regardless of the platform's locale
        with open(file_name, "wt", encoding="utf-8") as opt_file:
            opt_file.write(message)
            opt_file.write("\n")

    def parse(self):
        """Parse the options, apply the experiment-name suffix, and print/save them.

        Returns:
            The parsed options, which are also stored on ``self.opt``.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # process opt.suffix: the suffix is a format template filled in from
        # the options themselves, e.g. "{model}_{netG}"
        if opt.suffix:
            opt.name = opt.name + "_" + opt.suffix.format(**vars(opt))

        self.print_options(opt)

        self.opt = opt
        return self.opt
|
|