Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
import copy | |
import dataclasses | |
from pathlib import Path | |
from typing import Any, Dict, Optional, Type, TypeVar | |
import yaml | |
from . import util | |
__all__ = ['ConfigBase', 'PathLike'] | |
T = TypeVar('T', bound='ConfigBase') | |
PathLike = util.PathLike | |
def _is_missing(obj: Any) -> bool: | |
return isinstance(obj, type(dataclasses.MISSING)) | |
class ConfigBase: | |
""" | |
Base class of config classes. | |
Subclass may override `_canonical_rules` and `_validation_rules`, | |
and `validate()` if the logic is complex. | |
""" | |
# Rules to convert field value to canonical format. | |
# The key is field name. | |
# The value is callable `value -> canonical_value` | |
# It is not type-hinted so dataclass won't treat it as field | |
_canonical_rules = {} # type: ignore | |
# Rules to validate field value. | |
# The key is field name. | |
# The value is callable `value -> valid` or `value -> (valid, error_message)` | |
# The rule will be called with canonical format and is only called when `value` is not None. | |
# `error_message` is used when `valid` is False. | |
# It will be prepended with class name and field name in exception message. | |
_validation_rules = {} # type: ignore | |
def __init__(self, *, _base_path: Optional[Path] = None, **kwargs): | |
""" | |
Initialize a config object and set some fields. | |
Name of keyword arguments can either be snake_case or camelCase. | |
They will be converted to snake_case automatically. | |
If a field is missing and don't have default value, it will be set to `dataclasses.MISSING`. | |
""" | |
if 'basepath' in kwargs: | |
_base_path = kwargs.pop('basepath') | |
kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()} | |
if _base_path is None: | |
_base_path = Path() | |
for field in dataclasses.fields(self): | |
value = kwargs.pop(util.case_insensitive(field.name), field.default) | |
if value is not None and not _is_missing(value): | |
# relative paths loaded from config file are not relative to pwd | |
if 'Path' in str(field.type): | |
value = Path(value).expanduser() | |
if not value.is_absolute(): | |
value = _base_path / value | |
setattr(self, field.name, value) | |
if kwargs: | |
cls = type(self).__name__ | |
fields = ', '.join(kwargs.keys()) | |
raise ValueError(f'{cls}: Unrecognized fields {fields}') | |
def load(cls: Type[T], path: PathLike) -> T: | |
""" | |
Load config from YAML (or JSON) file. | |
Keys in YAML file can either be camelCase or snake_case. | |
""" | |
data = yaml.safe_load(open(path)) | |
if not isinstance(data, dict): | |
raise ValueError(f'Content of config file {path} is not a dict/object') | |
return cls(**data, _base_path=Path(path).parent) | |
def json(self) -> Dict[str, Any]: | |
""" | |
Convert config to JSON object. | |
The keys of returned object will be camelCase. | |
""" | |
self.validate() | |
return dataclasses.asdict( | |
self.canonical(), | |
dict_factory=lambda items: dict((util.camel_case(k), v) for k, v in items if v is not None) | |
) | |
def canonical(self: T) -> T: | |
""" | |
Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format. | |
Noticeably, relative path may be converted to absolute path. | |
""" | |
ret = copy.deepcopy(self) | |
for field in dataclasses.fields(ret): | |
key, value = field.name, getattr(ret, field.name) | |
rule = ret._canonical_rules.get(key) | |
if rule is not None: | |
setattr(ret, key, rule(value)) | |
elif isinstance(value, ConfigBase): | |
setattr(ret, key, value.canonical()) | |
# value will be copied twice, should not be a performance issue anyway | |
elif isinstance(value, Path): | |
setattr(ret, key, str(value)) | |
return ret | |
def validate(self) -> None: | |
""" | |
Validate the config object and raise Exception if it's ill-formed. | |
""" | |
class_name = type(self).__name__ | |
config = self.canonical() | |
for field in dataclasses.fields(config): | |
key, value = field.name, getattr(config, field.name) | |
# check existence | |
if _is_missing(value): | |
raise ValueError(f'{class_name}: {key} is not set') | |
# check type (TODO) | |
type_name = str(field.type).replace('typing.', '') | |
optional = any([ | |
type_name.startswith('Optional['), | |
type_name.startswith('Union[') and 'None' in type_name, | |
type_name == 'Any' | |
]) | |
if value is None: | |
if optional: | |
continue | |
else: | |
raise ValueError(f'{class_name}: {key} cannot be None') | |
# check value | |
rule = config._validation_rules.get(key) | |
if rule is not None: | |
try: | |
result = rule(value) | |
except Exception: | |
raise ValueError(f'{class_name}: {key} has bad value {repr(value)}') | |
if isinstance(result, bool): | |
if not result: | |
raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range') | |
else: | |
if not result[0]: | |
raise ValueError(f'{class_name}: {key} {result[1]}') | |
# check nested config | |
if isinstance(value, ConfigBase): | |
value.validate() | |