|
import argparse |
|
import os |
|
|
|
import yaml |
|
|
|
# Public API of this module: only these names are exported by ``import *``.
__all__ = [
    "get_config",
    "print_config",
]
|
|
|
|
|
def get_config(args):
    """Build the run configuration from YAML files plus command-line overrides.

    The YAML file named by ``args.config`` is loaded, every key it lacks is
    filled in from ``default.yml``, and the merged mapping is converted to
    nested namespaces. Selected fields are then overridden from ``args``.

    For the numeric options (``nsigma``, ``step_lr``, ``batch_size``,
    ``fid_num_samples``, ``begin_ckpt``, ``end_ckpt``, ``D_steps``) the value
    ``0`` means "not given": the value from the YAML files is kept.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``config``,
            ``consistent``, ``nsigma``, ``step_lr``, ``batch_size``,
            ``model_types``, ``ODI_steps``, ``fid_num_samples``,
            ``begin_ckpt``, ``end_ckpt``, ``adam``, ``adam_beta``, ``D_adam``,
            ``D_adam_beta`` and ``D_steps``.

    Returns:
        The merged configuration as nested namespaces.

    Side effects:
        ``args.ODI_steps`` is rewritten in place from ``-1`` to ``None``.
    """
    config = dict2namespace(
        setdefault(_get_raw_config(args.config), _get_raw_config("default.yml"))
    )

    # Fields that fall back to another section when the YAML omits them.
    if not hasattr(config.sampling, "sigma_dist"):
        config.sampling.sigma_dist = config.model.sigma_dist
    if not hasattr(config.biggan, "resolution"):
        config.biggan.resolution = config.data.image_size

    # Sampling overrides.
    if args.consistent:
        config.sampling.consistent = args.consistent
        config.sampling.noise_first = False
    if args.nsigma != 0:
        config.sampling.nsigma = args.nsigma
    if args.step_lr != 0:
        config.sampling.step_lr = args.step_lr
    if args.batch_size != 0:
        config.sampling.batch_size = args.batch_size
        config.fast_fid.batch_size = args.batch_size

    # When exactly one model type is requested, cap the sampling batch size
    # for the listed datasets. This must run after the batch-size override
    # above so that the cap also applies to a command-line value.
    single_model_type = None
    if args.model_types is not None and len(args.model_types) == 1:
        single_model_type = args.model_types[0]
    capped_datasets = ['tinyImages', 'CIFAR10']
    if single_model_type in [0, 6, 23] and config.data.dataset in capped_datasets:
        config.sampling.batch_size = min(200, config.sampling.batch_size)
    elif single_model_type in [8] and config.data.dataset in capped_datasets:
        config.sampling.batch_size = min(800, config.sampling.batch_size)

    # -1 is the command-line sentinel that is normalised to None on ``args``.
    if args.ODI_steps == -1:
        args.ODI_steps = None

    # Fast-FID / checkpoint overrides.
    if args.fid_num_samples != 0:
        config.fast_fid.num_samples = args.fid_num_samples
    if args.begin_ckpt != 0:
        config.fast_fid.begin_ckpt = args.begin_ckpt
        config.sampling.ckpt_id = args.begin_ckpt
    if args.end_ckpt != 0:
        # The end checkpoint comes from ``end_ckpt`` (not ``begin_ckpt``).
        config.fast_fid.end_ckpt = args.end_ckpt

    # Optimiser overrides.
    if args.adam:
        config.optim.beta1 = args.adam_beta[0]
        config.optim.beta2 = args.adam_beta[1]
    if args.D_adam:
        config.optim.adv_beta1 = args.D_adam_beta[0]
        config.optim.adv_beta2 = args.D_adam_beta[1]
    if args.D_steps != 0:
        config.adversarial.D_steps = args.D_steps

    return config
|
|
|
|
|
def _get_raw_config(name):
    """Parse the YAML file ``name`` and return the loaded object.

    ``name`` is joined onto the directory that contains this module, so a
    relative name is looked up next to this file.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(module_dir, name)
    # NOTE(review): FullLoader is not meant for untrusted YAML; this looks
    # like it only reads project config files -- confirm.
    with open(config_path, 'r') as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
|
|
|
|
|
def setdefault(config, default):
    """Recursively fill ``config`` with every key it lacks from ``default``.

    Values already present in ``config`` always win. When both sides hold a
    mapping under the same key, the merge descends into it; otherwise the
    default value is inserted only if the key is absent.

    A ``None`` on the ``config`` side where ``default`` has a mapping is
    treated as an empty mapping. YAML produces ``None`` for an empty file and
    for a section header with no entries, which previously raised
    ``TypeError`` here.

    Args:
        config: Mapping to complete (modified in place), or ``None``.
        default: Mapping that supplies the fallback values.

    Returns:
        ``config`` itself, or a new dict when ``config`` was ``None``.
    """
    if config is None:
        config = {}

    for key, fallback in default.items():
        if isinstance(fallback, dict) and key in config:
            # Re-assign so that a None section is replaced by the filled dict;
            # for an existing dict this stores the same object back.
            config[key] = setdefault(config[key], fallback)
        else:
            config.setdefault(key, fallback)
    return config
|
|
|
|
|
def dict2namespace(config):
    """Convert a (possibly nested) dict into nested ``argparse.Namespace`` objects.

    Every key becomes an attribute. Values that are dicts are converted
    recursively; all other values are stored unchanged.
    """
    result = argparse.Namespace()
    for name, entry in config.items():
        setattr(
            result,
            name,
            dict2namespace(entry) if isinstance(entry, dict) else entry,
        )
    return result
|
|
|
|
|
def print_config(config):
    """Print ``config`` as block-style YAML between two 80-character rules.

    The opening rule is made of ``>`` characters and the closing rule of
    ``<`` characters.
    """
    rule_width = 80
    print(">" * rule_width)
    yaml_text = yaml.dump(config, default_flow_style=False)
    print(yaml_text)
    print("<" * rule_width)
|
|
|
|