|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""YACS -- Yet Another Configuration System is designed to be a simple |
|
configuration management system for academic and industrial research |
|
projects. |
|
|
|
See README.md for usage and examples. |
|
""" |
|
|
|
import copy |
|
import io |
|
import logging |
|
import os |
|
import sys |
|
from ast import literal_eval |
|
|
|
import yaml |
|
|
|
|
|
|
|
# True when the interpreter is Python 2; several compatibility branches in
# this module key off this flag.
_PY2 = sys.version_info[0] == 2

# File extensions parsed as YAML when loading a config from a file object.
# The empty extension is included so extension-less files are read as YAML.
_YAML_EXTS = {"", ".yaml", ".yml"}
# File extensions executed as Python source when loading a config.
_PY_EXTS = {".py"}

# Types accepted as "file objects" by CfgNode.load_cfg.
_FILE_TYPES = (io.IOBase,)

# Exact leaf-value types a config may hold (checked with `type(value) in ...`,
# so subclasses are not accepted).
_VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}

if _PY2:
    # Python 2 additionally needs unicode strings as valid values, and loads
    # Python-source configs through the legacy `imp` module.
    _VALID_TYPES = _VALID_TYPES.union({unicode})  # noqa: F821
    import imp
else:
    import importlib.util

logger = logging.getLogger(__name__)
|
|
|
|
|
class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.
    """

    # Keys under which internal bookkeeping is stored in the instance
    # ``__dict__``, keeping it out of the dict items that form the config.
    IMMUTABLE = "__immutable__"
    DEPRECATED_KEYS = "__deprecated_keys__"
    RENAMED_KEYS = "__renamed_keys__"
    NEW_ALLOWED = "__new_allowed__"

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        """
        Args:
            init_dict (dict): the possibly-nested dictionary to initialize the
                CfgNode. It is deep-copied; nested dicts become CfgNodes.
            key_list (list[str]): a list of names which index this CfgNode from
                the root. Currently only used for logging purposes.
            new_allowed (bool): whether adding new key is allowed when merging with
                other configs.
        """
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        init_dict = self._create_config_tree_from_dict(init_dict, key_list)
        super(CfgNode, self).__init__(init_dict)
        # Internal state is written to __dict__ directly: going through
        # __setattr__ would store it as a config key instead.
        # Whether the node (and its children) reject attribute assignment.
        self.__dict__[CfgNode.IMMUTABLE] = False
        # Full keys (e.g. "FOO.BAR") that are ignored, with a warning, when
        # merged.
        self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
        # Maps an old full key to either its new name or a (new_name, message)
        # tuple; merging an old key raises KeyError.
        self.__dict__[CfgNode.RENAMED_KEYS] = {}
        # Whether merges may add keys that this node does not already have.
        self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed

    @classmethod
    def _create_config_tree_from_dict(cls, dic, key_list):
        """
        Create a configuration tree using the given dict.
        Any dict-like objects inside dict will be treated as a new CfgNode.

        Args:
            dic (dict): the (possibly nested) dict to convert. It is deep-copied
                first, so the caller's object is never modified.
            key_list (list[str]): a list of names which index this CfgNode from
                the root. Currently only used for logging purposes.

        Returns:
            dict: a copy of `dic` whose nested dicts are CfgNode instances.

        Raises:
            AssertionError: if a leaf value's type is not in _VALID_TYPES.
        """
        dic = copy.deepcopy(dic)
        for k, v in dic.items():
            if isinstance(v, dict):
                # Nested dicts become child nodes.
                dic[k] = cls(v, key_list=key_list + [k])
            else:
                # Leaves must be one of the exact allowed value types.
                _assert_with_logging(
                    _valid_type(v, allow_cfg_node=False),
                    "Key {} with value {} is not a valid type; valid types: {}".format(
                        ".".join(key_list + [k]), type(v), _VALID_TYPES
                    ),
                )
        return dic

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails: expose dict items
        # as attributes.
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        if self.is_frozen():
            raise AttributeError(
                "Attempted to set {} to {}, but CfgNode is immutable".format(
                    name, value
                )
            )

        _assert_with_logging(
            name not in self.__dict__,
            "Invalid attempt to modify internal CfgNode state: {}".format(name),
        )
        _assert_with_logging(
            _valid_type(value, allow_cfg_node=True),
            "Invalid type {} for key {}; valid types = {}".format(
                type(value), name, _VALID_TYPES
            ),
        )

        self[name] = value

    def __str__(self):
        def _indent(s_, num_spaces):
            # Indent every line except the first by `num_spaces`.
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        s = []
        for k, v in sorted(self.items()):
            # Child nodes start on their own line; leaves follow the key.
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        return "\n".join(s)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())

    def dump(self, **kwargs):
        """Dump to a string.

        The tree is converted to plain nested dicts and serialized with
        ``yaml.safe_dump``; `kwargs` are forwarded to it.
        """

        def convert_to_dict(cfg_node, key_list):
            if not isinstance(cfg_node, CfgNode):
                _assert_with_logging(
                    _valid_type(cfg_node),
                    "Key {} with value {} is not a valid type; valid types: {}".format(
                        ".".join(key_list), type(cfg_node), _VALID_TYPES
                    ),
                )
                return cfg_node
            else:
                cfg_dict = dict(cfg_node)
                for k, v in cfg_dict.items():
                    cfg_dict[k] = convert_to_dict(v, key_list + [k])
                return cfg_dict

        self_as_dict = convert_to_dict(self, [])
        return yaml.safe_dump(self_as_dict, **kwargs)

    def merge_from_file(self, cfg_filename):
        """Load a yaml config file and merge it into this CfgNode."""
        # io.open (rather than the builtin) accepts `encoding` on Python 2 as
        # well; on Python 3 it is the same function as the builtin open.
        with io.open(cfg_filename, "r", encoding="utf-8") as f:
            cfg = self.load_cfg(f)
        self.merge_from_other_cfg(cfg)

    def merge_from_other_cfg(self, cfg_other):
        """Merge `cfg_other` into this CfgNode."""
        _merge_a_into_b(cfg_other, self, self, [])

    def merge_from_list(self, cfg_list):
        """Merge config (keys, values) in a list (e.g., from command line) into
        this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.

        String values are decoded with ``_decode_cfg_value`` and then
        type-checked against the existing value.

        Raises:
            AssertionError: if the list has odd length or a key does not exist.
            KeyError: if a key has been registered as renamed.
            ValueError: if a value's type does not match the existing value.
        """
        _assert_with_logging(
            len(cfg_list) % 2 == 0,
            "Override list has odd length: {}; it must be a list of pairs".format(
                cfg_list
            ),
        )
        root = self
        for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
            if root.key_is_deprecated(full_key):
                continue
            if root.key_is_renamed(full_key):
                root.raise_key_rename_error(full_key)
            key_list = full_key.split(".")
            d = self
            # Walk down to the node that owns the last key component.
            for subkey in key_list[:-1]:
                _assert_with_logging(
                    subkey in d, "Non-existent key: {}".format(full_key)
                )
                d = d[subkey]
            subkey = key_list[-1]
            _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
            value = self._decode_cfg_value(v)
            value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
            d[subkey] = value

    def freeze(self):
        """Make this CfgNode and all of its children immutable."""
        self._immutable(True)

    def defrost(self):
        """Make this CfgNode and all of its children mutable."""
        self._immutable(False)

    def is_frozen(self):
        """Return mutability."""
        return self.__dict__[CfgNode.IMMUTABLE]

    def _immutable(self, is_immutable):
        """Set immutability to is_immutable and recursively apply the setting
        to all nested CfgNodes.
        """
        self.__dict__[CfgNode.IMMUTABLE] = is_immutable
        # Recursively set immutable state
        for v in self.__dict__.values():
            if isinstance(v, CfgNode):
                v._immutable(is_immutable)
        for v in self.values():
            if isinstance(v, CfgNode):
                v._immutable(is_immutable)

    def clone(self):
        """Recursively copy this CfgNode."""
        return copy.deepcopy(self)

    def register_deprecated_key(self, key):
        """Register key (e.g. `FOO.BAR`) as a deprecated option. When merging
        deprecated keys a warning is generated and the key is ignored.
        """
        _assert_with_logging(
            key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
            "key {} is already registered as a deprecated key".format(key),
        )
        self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)

    def register_renamed_key(self, old_name, new_name, message=None):
        """Register a key as having been renamed from `old_name` to `new_name`.
        When merging a renamed key, an exception is thrown alerting the user to
        the fact that the key has been renamed.
        """
        _assert_with_logging(
            old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
            "key {} is already registered as a renamed cfg key".format(old_name),
        )
        value = new_name
        if message:
            value = (new_name, message)
        self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value

    def key_is_deprecated(self, full_key):
        """Test if a key is deprecated (logging a warning if it is)."""
        if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
            logger.warning("Deprecated config key (ignoring): {}".format(full_key))
            return True
        return False

    def key_is_renamed(self, full_key):
        """Test if a key is renamed."""
        return full_key in self.__dict__[CfgNode.RENAMED_KEYS]

    def raise_key_rename_error(self, full_key):
        """Raise KeyError telling the user that `full_key` was renamed."""
        new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
        if isinstance(new_key, tuple):
            msg = " Note: " + new_key[1]
            new_key = new_key[0]
        else:
            msg = ""
        raise KeyError(
            "Key {} was renamed to {}; please update your config.{}".format(
                full_key, new_key, msg
            )
        )

    def is_new_allowed(self):
        """Return whether merges may add keys missing from this node."""
        return self.__dict__[CfgNode.NEW_ALLOWED]

    @classmethod
    def load_cfg(cls, cfg_file_obj_or_str):
        """
        Load a cfg.

        Args:
            cfg_file_obj_or_str (str or file):
                Supports loading from:
                - A file object backed by a YAML file
                - A file object backed by a Python source file that exports an attribute
                  "cfg" that is either a dict or a CfgNode
                - A string that can be parsed as valid YAML
        """
        _assert_with_logging(
            isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
            "Expected first argument to be of type {} or {}, but it was {}".format(
                _FILE_TYPES, str, type(cfg_file_obj_or_str)
            ),
        )
        if isinstance(cfg_file_obj_or_str, str):
            return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
        elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
            return cls._load_cfg_from_file(cfg_file_obj_or_str)
        else:
            raise NotImplementedError("Impossible to reach here (unless there's a bug)")

    @classmethod
    def _load_cfg_from_file(cls, file_obj):
        """Load a config from a YAML file or a Python source file.

        The format is chosen from the extension of ``file_obj.name``. File
        objects without a ``name`` (e.g. ``io.StringIO``) have no extension
        and are therefore parsed as YAML.
        """
        filename = getattr(file_obj, "name", "")
        _, file_extension = os.path.splitext(filename)
        if file_extension in _YAML_EXTS:
            return cls._load_cfg_from_yaml_str(file_obj.read())
        elif file_extension in _PY_EXTS:
            return cls._load_cfg_py_source(filename)
        else:
            raise Exception(
                "Attempt to load from an unsupported file type {}; "
                "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
            )

    @classmethod
    def _load_cfg_from_yaml_str(cls, str_obj):
        """Load a config from a YAML string encoding."""
        # safe_load (not load) so that YAML cannot construct arbitrary objects.
        cfg_as_dict = yaml.safe_load(str_obj)
        return cls(cfg_as_dict)

    @classmethod
    def _load_cfg_py_source(cls, filename):
        """Load a config from a Python source file.

        NOTE(review): this executes the file, so it must come from a trusted
        source.
        """
        module = _load_module_from_file("yacs.config.override", filename)
        _assert_with_logging(
            hasattr(module, "cfg"),
            "Python module from file {} must have 'cfg' attr".format(filename),
        )
        VALID_ATTR_TYPES = {dict, CfgNode}
        # isinstance (not an exact type match) so dict/CfgNode subclasses are
        # accepted too.
        _assert_with_logging(
            isinstance(module.cfg, tuple(VALID_ATTR_TYPES)),
            "Imported module 'cfg' attr must be in {} but is {} instead".format(
                VALID_ATTR_TYPES, type(module.cfg)
            ),
        )
        return cls(module.cfg)

    @classmethod
    def _decode_cfg_value(cls, value):
        """
        Decodes a raw config value (e.g., from a yaml config files or command
        line argument) into a Python object.

        If the value is a dict, it will be interpreted as a new CfgNode.
        If the value is a str, it will be evaluated as literals.
        Otherwise it is returned as-is.
        """
        # Configs from raw yaml will contain dictionaries that have to be
        # converted to CfgNode objects.
        if isinstance(value, dict):
            return cls(value)
        # Everything that is not a string is already decoded.
        if not isinstance(value, str):
            return value
        # Try to interpret the string as a Python literal ("1", "[1, 2]",
        # "True", ...). Strings that are not literals (e.g. "foo" raises
        # ValueError, "foo bar" raises SyntaxError) are kept as strings.
        try:
            value = literal_eval(value)
        except (ValueError, SyntaxError):
            pass
        return value
|
|
|
|
|
# Module-level shortcut for CfgNode.load_cfg (loads a YAML string, a YAML file
# object, or a Python-source file object into a CfgNode).
load_cfg = CfgNode.load_cfg
|
|
|
|
|
def _valid_type(value, allow_cfg_node=False):
    """Return whether `value` may be stored in a config.

    Leaf values qualify only when their exact type (subclasses excluded) is
    in _VALID_TYPES. CfgNode instances qualify only if `allow_cfg_node` is
    truthy.
    """
    if type(value) in _VALID_TYPES:
        return True
    return allow_cfg_node and isinstance(value, CfgNode)
|
|
|
|
|
def _merge_a_into_b(a, b, root, key_list):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.

    Args:
        a (CfgNode): source config; it is not modified (values are deep-copied).
        b (CfgNode): destination config, updated in place.
        root (CfgNode): root of the destination tree, consulted for deprecated
            and renamed keys.
        key_list (list[str]): path from the root to `b`, used to build the
            dotted full key names reported in errors.

    Raises:
        AssertionError: if `a` or `b` is not a CfgNode.
        ValueError: if a value's type is incompatible with the existing one.
        KeyError: if `a` has a key missing from `b` that is not deprecated and
            `b` does not allow new keys (including keys registered as renamed).
    """
    _assert_with_logging(
        isinstance(a, CfgNode),
        "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
    )
    _assert_with_logging(
        isinstance(b, CfgNode),
        "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
    )

    for k, v_ in a.items():
        full_key = ".".join(key_list + [k])

        # Copy so that `b` never shares mutable values with `a`, then decode
        # string literals / plain dicts into Python values / CfgNodes.
        v = copy.deepcopy(v_)
        v = b._decode_cfg_value(v)

        if k in b:
            v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
            if isinstance(v, CfgNode):
                # Recurse so keys of b[k] that are absent from v are kept.
                _merge_a_into_b(v, b[k], root, key_list + [k])
            else:
                # NOTE: item assignment bypasses CfgNode.__setattr__, so this
                # does not check whether `b` is frozen.
                b[k] = v
        elif b.is_new_allowed():
            b[k] = v
        elif root.key_is_deprecated(full_key):
            # Deprecated keys are ignored (key_is_deprecated logs a warning).
            continue
        elif root.key_is_renamed(full_key):
            root.raise_key_rename_error(full_key)
        else:
            raise KeyError("Non-existent config key: {}".format(full_key))
|
|
|
|
|
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    """Checks that `replacement`, which is intended to replace `original` is of
    the right type. The type is correct if it matches exactly or is one of a few
    cases in which the type can be easily coerced.

    Args:
        replacement: the new value.
        original: the existing value whose type must be matched.
        key (str): the last component of the key (unused; kept for
            interface compatibility).
        full_key (str): the dotted full key, used in the error message.

    Returns:
        `replacement`, converted to the type of `original` when a supported
        coercion (tuple <-> list; on Python 2 also str -> unicode) applies.

    Raises:
        ValueError: if the types differ and no coercion applies.
    """
    original_type = type(original)
    replacement_type = type(replacement)

    # The types must match (with some exceptions)
    if replacement_type == original_type:
        return replacement

    # Cast replacement from from_type to to_type if the replacement and original
    # types match from_type and to_type
    def conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            return True, to_type(replacement)
        else:
            return False, None

    # Conditionally casts
    # list <-> tuple
    casts = [(tuple, list), (list, tuple)]
    # For py2: allow converting from str (bytes) to a unicode string. On
    # Python 3 `unicode` is undefined, so the NameError is the expected path;
    # any other error is left to propagate.
    try:
        casts.append((str, unicode))  # noqa: F821
    except NameError:
        pass

    for (from_type, to_type) in casts:
        converted, converted_value = conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
        "key: {}".format(
            original_type, replacement_type, original, replacement, full_key
        )
    )
|
|
|
|
|
def _assert_with_logging(cond, msg):
    """Raise ``AssertionError(msg)`` if `cond` is falsy, logging `msg` first.

    An explicit raise is used instead of an ``assert`` statement so that
    config validation still runs when Python is started with ``-O`` (which
    strips assert statements). The exception type and message are the same
    as those of the former ``assert cond, msg``.
    """
    if not cond:
        logger.debug(msg)
        raise AssertionError(msg)
|
|
|
|
|
def _load_module_from_file(name, filename):
    """Execute the Python source file `filename` as a module named `name`.

    Args:
        name (str): module name to give the loaded module.
        filename (str): path of the Python source file.

    Returns:
        The executed module object.

    Raises:
        ImportError: on Python 3, if no loader can be determined for
            `filename`.
    """
    if _PY2:
        # `imp` is only imported at module level under Python 2.
        module = imp.load_source(name, filename)
    else:
        spec = importlib.util.spec_from_file_location(name, filename)
        # spec_from_file_location returns None when it cannot pick a loader
        # for the file; fail with a clear message instead of an
        # AttributeError on `None.loader`.
        if spec is None or spec.loader is None:
            raise ImportError(
                "Cannot load module {} from file {}".format(name, filename)
            )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    return module
|
|