File size: 10,089 Bytes
0034848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from __future__ import absolute_import

import json
import typing
import warnings
from abc import ABC, ABCMeta, abstractmethod
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union

try:
    import yaml

    yaml_available = True
except ImportError:
    yaml_available = False


from custom_albumentations import __version__

# Names exported by `from <this module> import *`.
__all__ = ["to_dict", "from_dict", "save", "load"]


# Classes created through `SerializableMeta` whose `is_serializable()` returned a truthy value at class-creation
# time, keyed by the result of their `get_class_fullname()`. `from_dict` looks classes up in this mapping.
SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {}
# Classes created through `SerializableMeta` whose `is_serializable()` returned a falsy value, keyed the same way.
# `instantiate_nonserializable` consults this mapping to detect transforms that must be supplied by the caller.
NON_SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {}


def shorten_class_name(class_fullname: str) -> str:
    """
    Strip the module path from a dotted class name that lives under the top-level `albumentations` package.

    Args:
        class_fullname (str): Either a bare class name or a dotted path such as
            `albumentations.augmentations.transforms.Blur`.

    Returns:
        str: The last dotted component when the name contains at least one dot and its first component
            is exactly `albumentations`; otherwise `class_fullname` unchanged.

    Note:
        Only the literal prefix `albumentations` is recognised. Names rooted in any other package
        (including `custom_albumentations`) are returned as-is.
    """
    top_module, separator, remainder = class_fullname.partition(".")
    # An empty separator means there was no dot at all, i.e. the name is already as short as it gets.
    if separator and top_module == "albumentations":
        return remainder.rsplit(".", 1)[-1]
    return class_fullname


def get_shortest_class_fullname(cls: Type) -> str:
    """
    Build the `<module>.<name>` path of a class and pass it through `shorten_class_name`.

    Args:
        cls (Type): The class whose name should be produced.

    Returns:
        str: Whatever `shorten_class_name` returns for `cls.__module__` and `cls.__name__` joined by a dot.
    """
    dotted_path = f"{cls.__module__}.{cls.__name__}"
    return shorten_class_name(dotted_path)


class SerializableMeta(ABCMeta):
    """
    Metaclass that records every class it creates in `SERIALIZABLE_REGISTRY` or `NON_SERIALIZABLE_REGISTRY`,
    so that the class can later be located by its full name when a transformation pipeline is deserialized.

    Two kinds of classes are never recorded: the class literally named `Serializable`, and any class
    that lists `ABC` among its direct bases.
    """

    def __new__(mcs, name: str, bases: Tuple[type, ...], *args, **kwargs) -> "SerializableMeta":
        """Create the class and, unless it is exempt, store it in the matching registry."""
        new_cls = super().__new__(mcs, name, bases, *args, **kwargs)

        # The root interface and explicitly abstract classes are not registered.
        if name == "Serializable" or ABC in bases:
            return new_cls

        # The registry is picked first (this calls `is_serializable`), then the key is computed.
        registry = SERIALIZABLE_REGISTRY if new_cls.is_serializable() else NON_SERIALIZABLE_REGISTRY
        registry[new_cls.get_class_fullname()] = new_cls
        return new_cls

    @classmethod
    def is_serializable(mcs) -> bool:
        """Metaclass-level default: always reports `False`."""
        return False

    @classmethod
    def get_class_fullname(mcs) -> str:
        """Metaclass-level default: the shortest full name of the metaclass itself."""
        return get_shortest_class_fullname(mcs)

    @classmethod
    def _to_dict(mcs) -> Dict[str, Any]:
        """Metaclass-level default: an empty dictionary."""
        return {}


class Serializable(metaclass=SerializableMeta):
    """
    Interface for objects that can be turned into a plain-dictionary representation.

    Subclasses provide `is_serializable`, `get_class_fullname` and `_to_dict`; the public `to_dict`
    wraps the result of `_to_dict` together with the library version.
    """

    @classmethod
    @abstractmethod
    def is_serializable(cls) -> bool:
        """Report whether the class can be serialized. Must be overridden."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_class_fullname(cls) -> str:
        """Return the name under which the class is stored in the registries. Must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def _to_dict(self) -> Dict[str, Any]:
        """Return the dictionary describing this object. Must be overridden."""
        raise NotImplementedError

    def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
        """
        Convert this transform pipeline to a serializable representation built only from standard
        python data types: dictionaries, lists, strings, integers, and floats.

        Args:
            on_not_implemented_error (str): `raise` or `warn`. Controls what happens when `_to_dict`
                raises `NotImplementedError`: with 'raise' the exception propagates, with 'warn' a warning
                is issued and no transform parameters are serialized.

        Returns:
            dict: A mapping with the keys `__version__` and `transform`.

        Raises:
            ValueError: If `on_not_implemented_error` is neither 'raise' nor 'warn'.
            NotImplementedError: If `_to_dict` raises it and `on_not_implemented_error` is 'raise'.
        """
        supported_modes = {"raise", "warn"}
        if on_not_implemented_error not in supported_modes:
            raise ValueError(
                "Unknown on_not_implemented_error value: {}. Supported values are: 'raise' and 'warn'".format(
                    on_not_implemented_error
                )
            )

        try:
            serialized = self._to_dict()
        except NotImplementedError:
            if on_not_implemented_error == "raise":
                raise

            # 'warn' mode: fall back to an empty description and tell the user how to fix it.
            serialized = {}
            warnings.warn(
                "Got NotImplementedError while trying to serialize {obj}. Object arguments are not preserved. "
                "Implement either '{cls_name}.get_transform_init_args_names' or '{cls_name}.get_transform_init_args' "
                "method to make the transform serializable".format(obj=self, cls_name=self.__class__.__name__)
            )

        return {"__version__": __version__, "transform": serialized}


def to_dict(transform: Serializable, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
    """
    Functional shortcut for `transform.to_dict(...)`: convert a transform pipeline to a serializable
    representation built only from standard python data types (dictionaries, lists, strings, integers,
    and floats).

    Args:
        transform: The transform to serialize. When it does not implement serialization and
            `on_not_implemented_error` is 'raise', `NotImplementedError` is raised. When
            `on_not_implemented_error` is 'warn', the error is ignored but no transform parameters
            are serialized.
        on_not_implemented_error (str): `raise` or `warn`.

    Returns:
        dict: Whatever `transform.to_dict` returns.
    """
    serialized = transform.to_dict(on_not_implemented_error)
    return serialized


def instantiate_nonserializable(
    transform: Dict[str, Any], nonserializable: Optional[Dict[str, Any]] = None
) -> Optional[Serializable]:
    """
    Resolve a serialized non-serializable transform to the live object supplied by the caller.

    Args:
        transform (dict): Serialized description of a single transform. Its `__class_fullname__` entry
            decides whether the transform is treated as non-serializable, and its `__name__` entry is
            the key looked up in `nonserializable`.
        nonserializable (dict): Mapping from transform name to the transform object to use.

    Returns:
        The object stored in `nonserializable` under the transform's name, or `None` when the
        transform's class is not present in `NON_SERIALIZABLE_REGISTRY`.

    Raises:
        ValueError: If the transform is non-serializable and `nonserializable` is `None`, or if
            `nonserializable` has no (non-`None`) entry for the transform's name.
        KeyError: If the transform is non-serializable but has no `__name__` entry.
    """
    if transform.get("__class_fullname__") not in NON_SERIALIZABLE_REGISTRY:
        return None

    name = transform["__name__"]
    if nonserializable is None:
        raise ValueError(
            "To deserialize a non-serializable transform with name {name} you need to pass a dict with "
            "this transform as the `nonserializable` argument".format(name=name)
        )

    result_transform = nonserializable.get(name)
    # Check the looked-up object, not the input dict: a missing entry must be reported here
    # instead of leaking out as `None`.
    if result_transform is None:
        raise ValueError(
            "Non-serializable transform with name {name} was not found in `nonserializable`".format(name=name)
        )
    return result_transform


def from_dict(
    transform_dict: Dict[str, Any],
    nonserializable: Optional[Dict[str, Any]] = None,
    lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
) -> Optional[Serializable]:
    """
    Rebuild a transform pipeline from its dictionary representation.

    Args:
        transform_dict (dict): A dictionary with serialized transform pipeline. The description of the
            transform itself is read from its `transform` entry.
        nonserializable (dict): A dictionary that contains non-serializable transforms.
            This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
            Keys in that dictionary should be named same as `name` arguments in respective transforms from
            a serialized pipeline.
        lambda_transforms (dict): Deprecated. Use 'nonserializable' instead. When given, it overrides
            `nonserializable` and a `DeprecationWarning` is issued.

    Returns:
        The restored transform: either the object taken from `nonserializable`, or a new instance of the
        registered class constructed from the serialized arguments.

    Raises:
        KeyError: If `transform_dict` has no `transform` entry, the transform has no `__class_fullname__`
            entry, or the class name is not present in `SERIALIZABLE_REGISTRY`.
        ValueError: Propagated from `instantiate_nonserializable`.
    """
    if lambda_transforms != "deprecated":
        warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
        nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)

    register_additional_transforms()
    transform = transform_dict["transform"]

    restored = instantiate_nonserializable(transform, nonserializable)
    # Compare against None explicitly: a user-supplied transform may legitimately be falsy
    # (e.g. it defines `__len__` or `__bool__`) and must still be returned.
    if restored is not None:
        return restored

    name = transform["__class_fullname__"]
    args = {k: v for k, v in transform.items() if k != "__class_fullname__"}
    cls = SERIALIZABLE_REGISTRY[shorten_class_name(name)]
    if "transforms" in args:
        # Nested pipelines: restore every child with the same set of non-serializable transforms.
        args["transforms"] = [from_dict({"transform": t}, nonserializable=nonserializable) for t in args["transforms"]]
    return cls(**args)


def check_data_format(data_format: str) -> None:
    """
    Validate the name of a serialization format.

    Args:
        data_format (str): Format name to check.

    Raises:
        ValueError: If `data_format` is neither 'json' nor 'yaml'.
    """
    supported_formats = {"json", "yaml"}
    if data_format in supported_formats:
        return
    raise ValueError("Unknown data_format {}. Supported formats are: 'json' and 'yaml'".format(data_format))


def save(
    transform: Serializable, filepath: str, data_format: str = "json", on_not_implemented_error: str = "raise"
) -> None:
    """
    Take a transform pipeline, serialize it and save a serialized version to a file
    using either json or yaml format.

    Args:
        transform (obj): Transform to serialize.
        filepath (str): Filepath to write to.
        data_format (str): Serialization format. Should be either `json` or 'yaml'.
        on_not_implemented_error (str): Parameter that describes what to do if a transform doesn't implement
            the `to_dict` method. If 'raise' then `NotImplementedError` is raised, if `warn` then the exception will be
            ignored and no transform arguments will be saved.

    Raises:
        ValueError: If `data_format` is not a supported format name.
        ImportError: If `data_format` is 'yaml' but PyYAML could not be imported.
    """
    check_data_format(data_format)
    # Fail early with an actionable message instead of a NameError on the missing `yaml` module.
    if data_format == "yaml" and not yaml_available:
        raise ImportError("You need to install PyYAML if you want to use YAML as the serialization format")

    transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error)
    dump_fn = json.dump if data_format == "json" else yaml.safe_dump
    with open(filepath, "w") as f:
        dump_fn(transform_dict, f)  # type: ignore


def load(
    filepath: str,
    data_format: str = "json",
    nonserializable: Optional[Dict[str, Any]] = None,
    lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
) -> object:
    """
    Load a serialized pipeline from a json or yaml file and construct a transform pipeline.

    Args:
        filepath (str): Filepath to read from.
        data_format (str): Serialization format. Should be either `json` or 'yaml'.
        nonserializable (dict): A dictionary that contains non-serializable transforms.
            This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
            Keys in that dictionary should be named same as `name` arguments in respective transforms from
            a serialized pipeline.
        lambda_transforms (dict): Deprecated. Use 'nonserializable' instead. When given, it overrides
            `nonserializable` and a `DeprecationWarning` is issued.

    Returns:
        The transform pipeline produced by `from_dict` for the file's content.

    Raises:
        ValueError: If `data_format` is not a supported format name.
        ImportError: If `data_format` is 'yaml' but PyYAML could not be imported.
    """
    if lambda_transforms != "deprecated":
        warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
        nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)

    check_data_format(data_format)
    # Fail early with an actionable message instead of a NameError on the missing `yaml` module.
    if data_format == "yaml" and not yaml_available:
        raise ImportError("You need to install PyYAML if you want to use YAML as the serialization format")

    load_fn = json.load if data_format == "json" else yaml.safe_load
    with open(filepath) as f:
        transform_dict = load_fn(f)  # type: ignore

    return from_dict(transform_dict, nonserializable=nonserializable)


def register_additional_transforms() -> None:
    """
    Best-effort import of `custom_albumentations.pytorch`, a submodule that is only reached through
    this explicit import here.

    The import is attempted for its side effects only; if it fails with `ImportError` the failure is
    silently ignored.
    """
    try:
        # Raises ImportError when `torch` is not installed.
        import custom_albumentations.pytorch
    except ImportError:
        return