devilops's picture
initial commit
38e9b9a
import argparse
import os
import yaml
__all__ = ['get_config', 'print_config']
def get_config(args):
    """Build the run configuration from YAML files and command-line overrides.

    The YAML file named by ``args.config`` is loaded, any keys it lacks are
    filled in from ``default.yml``, and the merged mapping is converted to
    nested namespaces. Selected attributes of ``args`` then override the
    corresponding config entries.

    For the numeric options compared against ``0`` below, a value of ``0``
    leaves the value from the YAML files untouched.

    Args:
        args: Parsed command-line arguments. Attributes read here are
            ``config``, ``consistent``, ``nsigma``, ``step_lr``,
            ``batch_size``, ``model_types``, ``ODI_steps``,
            ``fid_num_samples``, ``begin_ckpt``, ``end_ckpt``, ``adam``,
            ``adam_beta``, ``D_adam``, ``D_adam_beta`` and ``D_steps``.
            Note that ``args.ODI_steps`` is rewritten in place from ``-1``
            to ``None``.

    Returns:
        The nested namespace holding the merged configuration.
    """
    config = dict2namespace(setdefault(_get_raw_config(args.config), _get_raw_config("default.yml")))

    # Derive values that the YAML files are allowed to omit.
    if not hasattr(config.sampling, "sigma_dist"):
        config.sampling.sigma_dist = config.model.sigma_dist
    if not hasattr(config.biggan, "resolution"):
        config.biggan.resolution = config.data.image_size

    if args.consistent:
        config.sampling.consistent = args.consistent
        config.sampling.noise_first = False

    if args.nsigma != 0:
        config.sampling.nsigma = args.nsigma
    # A separate truthiness check on step_lr used to precede this one; it was
    # fully covered by the "!= 0" test, so a single check is kept.
    if args.step_lr != 0:
        config.sampling.step_lr = args.step_lr
    if args.batch_size != 0:
        config.sampling.batch_size = args.batch_size
        config.fast_fid.batch_size = args.batch_size

    # ToDo: experimental and is only using model_types
    # Cap the sampling batch size for specific single-model runs on the
    # listed datasets.
    if (args.model_types is not None and len(args.model_types) == 1
            and args.model_types[0] in [0, 6, 23]
            and config.data.dataset in ['tinyImages', 'CIFAR10']):
        config.sampling.batch_size = min(200, config.sampling.batch_size)
    if (args.model_types is not None and len(args.model_types) == 1
            and args.model_types[0] in [8]
            and config.data.dataset in ['tinyImages', 'CIFAR10']):
        config.sampling.batch_size = min(800, config.sampling.batch_size)

    # -1 is replaced by None directly on args (not on config).
    if args.ODI_steps == -1:
        args.ODI_steps = None

    if args.fid_num_samples != 0:
        config.fast_fid.num_samples = args.fid_num_samples
    if args.begin_ckpt != 0:
        config.fast_fid.begin_ckpt = args.begin_ckpt
        config.sampling.ckpt_id = args.begin_ckpt
    if args.end_ckpt != 0:
        # Fixed: this previously copied args.begin_ckpt, so --end_ckpt was
        # silently ignored.
        config.fast_fid.end_ckpt = args.end_ckpt

    if args.adam:
        config.optim.beta1 = args.adam_beta[0]
        config.optim.beta2 = args.adam_beta[1]
    if args.D_adam:
        config.optim.adv_beta1 = args.D_adam_beta[0]
        config.optim.adv_beta2 = args.D_adam_beta[1]
    if args.D_steps != 0:
        config.adversarial.D_steps = args.D_steps

    return config
def _get_raw_config(name):
    """Load a YAML file located next to this module and return its parsed content.

    Args:
        name: File name (or relative path) resolved against the directory
            that contains this source file.

    Returns:
        Whatever ``yaml.load`` produces for the file with ``FullLoader``.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(module_dir, name)
    with open(config_path, 'r') as stream:
        parsed = yaml.load(stream, Loader=yaml.FullLoader)
    return parsed
def setdefault(config, default):
    """Recursively fill ``config`` with entries it is missing from ``default``.

    Values already present in ``config`` are kept. When a default value is a
    dict and ``config`` already has that key, the merge descends into the
    nested mapping instead of replacing it.

    Args:
        config: Mapping to complete; it is modified in place.
        default: Mapping supplying fallback values.

    Returns:
        The same ``config`` object that was passed in.
    """
    for key in default:
        fallback = default.get(key)
        nested_override = isinstance(fallback, dict) and key in config
        if not nested_override:
            # Missing key, or a plain value: keep config's entry if it has one.
            config.setdefault(key, fallback)
            continue
        # Both sides define this section: merge one level down.
        setdefault(config.get(key), fallback)
    return config
def dict2namespace(config):
    """Convert a (possibly nested) dict into nested ``argparse.Namespace`` objects.

    Args:
        config: Mapping whose keys become attribute names. Values that are
            dicts are converted recursively; all other values are stored
            unchanged.

    Returns:
        A new ``argparse.Namespace`` with one attribute per key.
    """
    result = argparse.Namespace()
    for name, item in config.items():
        converted = dict2namespace(item) if isinstance(item, dict) else item
        setattr(result, name, converted)
    return result
def print_config(config):
    """Print ``config`` as block-style YAML between two 80-character rules.

    Args:
        config: Object to serialise with ``yaml.dump``.
    """
    rule_width = 80
    print(">" * rule_width)
    rendered = yaml.dump(config, default_flow_style=False)
    print(rendered)
    print("<" * rule_width)