Spaces:
Configuration error
Configuration error
File size: 10,089 Bytes
0034848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
from __future__ import absolute_import
import json
import typing
import warnings
from abc import ABC, ABCMeta, abstractmethod
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union
try:
import yaml
yaml_available = True
except ImportError:
yaml_available = False
from custom_albumentations import __version__
# Public API exported by `from ... import *`.
__all__ = ["to_dict", "from_dict", "save", "load"]
# Class registries filled by `SerializableMeta.__new__`. Keys are the names returned by a
# class's `get_class_fullname()`; values are the class objects themselves. Classes whose
# `is_serializable()` returns True go into SERIALIZABLE_REGISTRY (used by `from_dict` to
# rebuild transforms); all other registered classes go into NON_SERIALIZABLE_REGISTRY.
SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {}
NON_SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {}
def shorten_class_name(class_fullname: str) -> str:
    """Reduce a dotted class path to its bare class name when it lives in ``albumentations``.

    Only names whose first dotted component is exactly ``albumentations`` are shortened,
    to their last component. Names without a dot, and names from any other top-level
    package, are returned unchanged.

    NOTE(review): this module imports from ``custom_albumentations``; classes defined there
    do not match the check below and keep their full dotted path. Confirm that is intended.
    """
    top_module, separator, _ = class_fullname.partition(".")
    if not separator or top_module != "albumentations":
        return class_fullname
    return class_fullname.rsplit(".", 1)[-1]
def get_shortest_class_fullname(cls: Type) -> str:
    """Return the registry name for ``cls``.

    Builds ``"<module>.<ClassName>"`` from the class's ``__module__`` and ``__name__`` and
    passes it through ``shorten_class_name``.
    """
    dotted_path = f"{cls.__module__}.{cls.__name__}"
    return shorten_class_name(dotted_path)
class SerializableMeta(ABCMeta):
    """
    Metaclass that records each class it creates in `SERIALIZABLE_REGISTRY` or
    `NON_SERIALIZABLE_REGISTRY`, keyed by the class's `get_class_fullname()`, so the
    deserializer can later map a stored class name back to the class object.

    A class goes into `SERIALIZABLE_REGISTRY` when its `is_serializable()` returns True,
    otherwise into `NON_SERIALIZABLE_REGISTRY`. The class named "Serializable" and any
    class that lists `ABC` directly among its bases are not registered.
    """

    def __new__(mcs, name: str, bases: Tuple[type, ...], *args, **kwargs) -> "SerializableMeta":
        new_cls = super().__new__(mcs, name, bases, *args, **kwargs)

        # The abstract root and explicit ABC subclasses are never registered.
        if name == "Serializable" or ABC in bases:
            return new_cls

        # is_serializable() is evaluated before get_class_fullname(), as before.
        registry = SERIALIZABLE_REGISTRY if new_cls.is_serializable() else NON_SERIALIZABLE_REGISTRY
        registry[new_cls.get_class_fullname()] = new_cls
        return new_cls

    # Default implementations defined on the metaclass itself. `Serializable` declares its
    # own versions of these methods.
    @classmethod
    def is_serializable(mcs) -> bool:
        return False

    @classmethod
    def get_class_fullname(mcs) -> str:
        return get_shortest_class_fullname(mcs)

    @classmethod
    def _to_dict(mcs) -> Dict[str, Any]:
        return {}
class Serializable(metaclass=SerializableMeta):
    """Abstract base for objects that can be converted into plain python data.

    Subclasses implement `is_serializable`, `get_class_fullname` and `_to_dict`.
    The public `to_dict` wraps the output of `_to_dict` together with the library version.
    """

    @classmethod
    @abstractmethod
    def is_serializable(cls) -> bool:
        """Return whether the class belongs in `SERIALIZABLE_REGISTRY` (checked by `SerializableMeta`)."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_class_fullname(cls) -> str:
        """Return the name used as this class's key in the registries."""
        raise NotImplementedError

    @abstractmethod
    def _to_dict(self) -> Dict[str, Any]:
        """Return the instance-specific part of the serialized representation."""
        raise NotImplementedError

    def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
        """
        Convert this transform (pipeline) into a structure built only from standard python
        types: dictionaries, lists, strings, integers, and floats.

        Args:
            on_not_implemented_error (str): Policy used when `_to_dict` raises
                `NotImplementedError`. With 'raise' the exception propagates; with 'warn' a
                warning is emitted and an empty dict stands in for the transform's parameters.

        Returns:
            A dict of the form {"__version__": <library version>, "transform": <result of _to_dict>}.

        Raises:
            ValueError: if `on_not_implemented_error` is neither 'raise' nor 'warn'.
            NotImplementedError: re-raised from `_to_dict` under the 'raise' policy.
        """
        supported_policies = {"raise", "warn"}
        if on_not_implemented_error not in supported_policies:
            raise ValueError(
                "Unknown on_not_implemented_error value: {}. Supported values are: 'raise' and 'warn'".format(
                    on_not_implemented_error
                )
            )

        try:
            serialized_params = self._to_dict()
        except NotImplementedError as error:
            if on_not_implemented_error == "raise":
                raise error
            # 'warn' policy: carry on, but the transform's arguments are not preserved.
            serialized_params = {}
            warnings.warn(
                "Got NotImplementedError while trying to serialize {obj}. Object arguments are not preserved. "
                "Implement either '{cls_name}.get_transform_init_args_names' or '{cls_name}.get_transform_init_args' "
                "method to make the transform serializable".format(obj=self, cls_name=self.__class__.__name__)
            )

        return {"__version__": __version__, "transform": serialized_params}
def to_dict(transform: Serializable, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
    """
    Functional form of `Serializable.to_dict`: convert a transform pipeline into a structure
    built only from standard python types (dictionaries, lists, strings, integers, floats).

    Args:
        transform: The transform to serialize.
        on_not_implemented_error (str): 'raise' to propagate a `NotImplementedError` coming from
            the transform's `_to_dict`, or 'warn' to emit a warning and serialize no parameters.

    Returns:
        Whatever `transform.to_dict(on_not_implemented_error)` returns.
    """
    serialize = transform.to_dict
    return serialize(on_not_implemented_error)
def instantiate_nonserializable(
    transform: Dict[str, Any], nonserializable: Optional[Dict[str, Any]] = None
) -> Optional[Serializable]:
    """Resolve a serialized non-serializable transform to the object supplied by the caller.

    Args:
        transform: A single serialized transform dict. Its ``"__class_fullname__"`` entry decides
            whether it is non-serializable, and its ``"__name__"`` entry is the lookup key.
        nonserializable: Mapping from transform name to the transform object to use in its place.

    Returns:
        The object stored in ``nonserializable`` under the transform's name, if the transform's
        class is registered in ``NON_SERIALIZABLE_REGISTRY``. Otherwise ``None``, meaning the
        caller should reconstruct the transform normally.

    Raises:
        ValueError: if the transform is non-serializable and ``nonserializable`` is ``None``, or if
            it does not contain an entry for the transform's name.
    """
    if transform.get("__class_fullname__") not in NON_SERIALIZABLE_REGISTRY:
        return None

    name = transform["__name__"]
    if nonserializable is None:
        raise ValueError(
            "To deserialize a non-serializable transform with name {name} you need to pass a dict with "
            "this transform as the `nonserializable` argument".format(name=name)
        )

    result_transform = nonserializable.get(name)
    # Check the looked-up value, not the input dict: a missing entry must fail here with a clear
    # message instead of returning None and surfacing later as an unrelated registry KeyError.
    if result_transform is None:
        raise ValueError(
            "Non-serializable transform with name {name} was not found in `nonserializable`".format(name=name)
        )
    return result_transform
def from_dict(
    transform_dict: Dict[str, Any],
    nonserializable: Optional[Dict[str, Any]] = None,
    lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
) -> Optional[Serializable]:
    """Rebuild a transform pipeline from the dict produced by ``to_dict``.

    Args:
        transform_dict (dict): A dictionary with a serialized transform pipeline. The transform
            itself is read from its ``"transform"`` key.
        nonserializable (dict): A dictionary that contains non-serializable transforms.
            This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
            Keys in that dictionary should be named same as `name` arguments in respective transforms from
            a serialized pipeline.
        lambda_transforms (dict): Deprecated. Use 'nonserializable' instead. When passed, it replaces
            ``nonserializable`` and a ``DeprecationWarning`` is emitted.

    Returns:
        The reconstructed transform. Nested ``"transforms"`` lists are rebuilt recursively, with the
        same ``nonserializable`` mapping.

    Raises:
        KeyError: if a serialized class name is not present in ``SERIALIZABLE_REGISTRY``.
        ValueError: if a non-serializable transform cannot be resolved from ``nonserializable``.
    """
    if lambda_transforms != "deprecated":
        warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
        nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)

    register_additional_transforms()
    transform = transform_dict["transform"]

    # Compare with None explicitly: a caller-supplied transform object may be falsy (e.g. one
    # that defines __len__ and is empty), and that must not be mistaken for "not found".
    nonserializable_transform = instantiate_nonserializable(transform, nonserializable)
    if nonserializable_transform is not None:
        return nonserializable_transform

    name = transform["__class_fullname__"]
    args = {k: v for k, v in transform.items() if k != "__class_fullname__"}
    cls = SERIALIZABLE_REGISTRY[shorten_class_name(name)]
    if "transforms" in args:
        args["transforms"] = [from_dict({"transform": t}, nonserializable=nonserializable) for t in args["transforms"]]
    return cls(**args)
def check_data_format(data_format: str) -> None:
    """Validate that ``data_format`` names a serialization format usable in this environment.

    Args:
        data_format: Either ``"json"`` or ``"yaml"``.

    Raises:
        ValueError: if ``data_format`` is not one of the supported formats, or if it is
            ``"yaml"`` but PyYAML could not be imported.
    """
    if data_format not in {"json", "yaml"}:
        raise ValueError("Unknown data_format {}. Supported formats are: 'json' and 'yaml'".format(data_format))
    # PyYAML is an optional import (the try/except around `import yaml` at module top sets
    # `yaml_available`). Without this check, save/load would fail later with a NameError on
    # the unbound `yaml` name.
    if data_format == "yaml" and not yaml_available:
        raise ValueError(
            "data_format 'yaml' requires PyYAML, which is not installed. Install it with `pip install pyyaml`."
        )
def save(
    transform: Serializable, filepath: str, data_format: str = "json", on_not_implemented_error: str = "raise"
) -> None:
    """
    Serialize a transform pipeline and write it to a file as JSON or YAML.

    Args:
        transform (obj): Transform to serialize.
        filepath (str): Filepath to write to.
        data_format (str): Serialization format. Should be either 'json' or 'yaml'.
        on_not_implemented_error (str): Policy when a transform doesn't implement serialization:
            'raise' propagates `NotImplementedError`; 'warn' ignores it and saves no arguments for
            that transform.
    """
    check_data_format(data_format)
    payload = transform.to_dict(on_not_implemented_error=on_not_implemented_error)

    # Choose the writer before touching the file, so a failure resolving it leaves `filepath` untouched.
    if data_format == "json":
        writer = json.dump
    else:
        writer = yaml.safe_dump

    with open(filepath, "w") as out_file:
        writer(payload, out_file)  # type: ignore
def load(
    filepath: str,
    data_format: str = "json",
    nonserializable: Optional[Dict[str, Any]] = None,
    lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
) -> object:
    """
    Read a serialized pipeline from a JSON or YAML file and rebuild the transform pipeline.

    Args:
        filepath (str): Filepath to read from.
        data_format (str): Serialization format. Should be either 'json' or 'yaml'.
        nonserializable (dict): Non-serializable transforms to plug back into the pipeline.
            Required when the saved pipeline contains non-serializable transforms; keys must match
            the `name` arguments those transforms had in the serialized pipeline.
        lambda_transforms (dict): Deprecated. Use 'nonserializable' instead.

    Returns:
        The result of `from_dict` applied to the parsed file contents.
    """
    if lambda_transforms != "deprecated":
        warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
        nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)

    check_data_format(data_format)

    # yaml.safe_load (not yaml.load) so the file cannot construct arbitrary python objects.
    if data_format == "json":
        reader = json.load
    else:
        reader = yaml.safe_load

    with open(filepath) as in_file:
        parsed = reader(in_file)  # type: ignore

    return from_dict(parsed, nonserializable=nonserializable)
def register_additional_transforms() -> None:
    """
    Try to import `custom_albumentations.pytorch`, whose transforms are not imported by the
    top-level package, so they are available when deserializing.

    Presumably the import registers those classes via `SerializableMeta` — confirm against the
    pytorch subpackage. A missing optional dependency is ignored.
    """
    try:
        import custom_albumentations.pytorch  # noqa: F401  (imported only for its side effects)
    except ImportError:
        # Raised e.g. when `torch` is not installed; there is nothing extra to register then.
        return
|