Spaces:
Configuration error
Configuration error
from __future__ import division | |
import random | |
import typing | |
import warnings | |
from collections import defaultdict | |
import numpy as np | |
from .. import random_utils | |
from .bbox_utils import BboxParams, BboxProcessor | |
from .keypoints_utils import KeypointParams, KeypointsProcessor | |
from .serialization import ( | |
SERIALIZABLE_REGISTRY, | |
Serializable, | |
get_shortest_class_fullname, | |
instantiate_nonserializable, | |
) | |
from .transforms_interface import BasicTransform | |
from .utils import format_args, get_shape | |
# Names exported by a star-import of this module.
__all__ = [
    "BaseCompose",
    "Compose",
    "SomeOf",
    "OneOf",
    "OneOrOther",
    "BboxParams",
    "KeypointParams",
    "ReplayCompose",
    "Sequential",
]

# Default indent, and per-nesting-level increment, used by ``indented_repr``.
REPR_INDENT_STEP = 2

# One pipeline element: a leaf transform or a nested composition.
TransformType = typing.Union[BasicTransform, "BaseCompose"]
# An ordered collection of pipeline elements.
TransformsSeqType = typing.Sequence[TransformType]
def get_always_apply(transforms: typing.Union["BaseCompose", TransformsSeqType]) -> TransformsSeqType:
    """Collect the transforms flagged ``always_apply``, flattening nested compositions.

    Args:
        transforms: A composition, or a sequence of transforms/compositions, to scan.

    Returns:
        A flat list, in traversal order, of the non-composition transforms
        whose ``always_apply`` attribute is truthy.
    """
    selected: typing.List[TransformType] = []
    for candidate in transforms:  # type: ignore
        if isinstance(candidate, BaseCompose):
            # Compositions are never selected themselves; only their matching leaves are.
            selected += get_always_apply(candidate)
            continue
        if candidate.always_apply:
            selected.append(candidate)
    return selected
class BaseCompose(Serializable):
    """Base container for an ordered sequence of transforms.

    Subclasses implement ``__call__`` to decide which of the contained
    transforms run. This class provides the container protocol, a nested
    ``repr``, serialization helpers, and propagation of additional targets
    and deterministic mode to every child.

    Args:
        transforms: Sequence of transforms/compositions. A single transform
            is accepted with a warning and wrapped into a one-element list.
        p (float): Probability stored on the composition; how it is used is
            up to the subclass.
    """

    def __init__(self, transforms: TransformsSeqType, p: float):
        if isinstance(transforms, (BaseCompose, BasicTransform)):
            warnings.warn(
                "transforms is single transform, but a sequence is expected! Transform will be wrapped into list."
            )
            transforms = [transforms]

        self.transforms = transforms
        self.p = p

        # Replay bookkeeping; switched on by ReplayCompose._restore_for_replay.
        self.replay_mode = False
        self.applied_in_replay = False

    def __len__(self) -> int:
        """Return the number of directly contained transforms."""
        return len(self.transforms)

    def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]:
        """Apply the composition; must be implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, item: int) -> TransformType:  # type: ignore
        """Return the contained transform at position ``item``."""
        return self.transforms[item]

    def __repr__(self) -> str:
        return self.indented_repr()

    def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
        """Build a multi-line representation with nested compositions indented.

        Args:
            indent: Number of spaces placed before each child line. Nested
                compositions are rendered with ``indent + REPR_INDENT_STEP``.

        Returns:
            The class name, one line per child, and a closing line holding the
            remaining serialized arguments.
        """
        # Dunder keys and the transform list are rendered separately, so drop them here.
        args = {k: v for k, v in self._to_dict().items() if not (k.startswith("__") or k == "transforms")}

        lines = [self.__class__.__name__ + "(["]
        for t in self.transforms:
            if hasattr(t, "indented_repr"):
                t_repr = t.indented_repr(indent + REPR_INDENT_STEP)  # type: ignore
            else:
                t_repr = repr(t)
            lines.append(" " * indent + t_repr + ",")
        lines.append(" " * (indent - REPR_INDENT_STEP) + "], {args})".format(args=format_args(args)))
        return "\n".join(lines)

    @classmethod
    def get_class_fullname(cls) -> str:
        """Return the class name as computed by ``get_shortest_class_fullname``."""
        return get_shortest_class_fullname(cls)

    @classmethod
    def is_serializable(cls) -> bool:
        """Compositions are always reported as serializable."""
        return True

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        """Serialize the composition and, recursively, its children."""
        return {
            "__class_fullname__": self.get_class_fullname(),
            "p": self.p,
            "transforms": [t._to_dict() for t in self.transforms],  # skipcq: PYL-W0212
        }

    def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
        """Serialize the structure, tagging each node with ``id(self)``.

        ``params`` is always ``None`` for a composition.
        """
        return {
            "__class_fullname__": self.get_class_fullname(),
            "id": id(self),
            "params": None,
            "transforms": [t.get_dict_with_id() for t in self.transforms],
        }

    def add_targets(self, additional_targets: typing.Optional[typing.Dict[str, str]]) -> None:
        """Forward ``additional_targets`` to every child; no-op when empty or ``None``."""
        if additional_targets:
            for t in self.transforms:
                t.add_targets(additional_targets)

    def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
        """Forward the deterministic-mode switch and its ``save_key`` to every child."""
        for t in self.transforms:
            t.set_deterministic(flag, save_key)
class Compose(BaseCompose):
    """Compose transforms and handle all transformations regarding bounding boxes

    Args:
        transforms (list): list of transformations to compose.
        bbox_params (BboxParams): Parameters for bounding boxes transforms
        keypoint_params (KeypointParams): Parameters for keypoints transforms
        additional_targets (dict): Dict with keys - new target name, values - old target name. ex: {'image2': 'image'}
        p (float): probability of applying all list of transforms. Default: 1.0.
        is_check_shapes (bool): If True shapes consistency of images/mask/masks would be checked on each call. If you
            would like to disable this check - pass False (do it only if you are sure in your data consistency).

    Raises:
        ValueError: if ``bbox_params`` or ``keypoint_params`` is neither a ``dict`` nor the matching params class.
    """

    def __init__(
        self,
        transforms: TransformsSeqType,
        bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None,
        keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
        additional_targets: typing.Optional[typing.Dict[str, str]] = None,
        p: float = 1.0,
        is_check_shapes: bool = True,
    ):
        super(Compose, self).__init__(transforms, p)

        self.processors: typing.Dict[str, typing.Union[BboxProcessor, KeypointsProcessor]] = {}

        if bbox_params:
            if isinstance(bbox_params, dict):
                b_params = BboxParams(**bbox_params)
            elif isinstance(bbox_params, BboxParams):
                b_params = bbox_params
            else:
                raise ValueError("unknown format of bbox_params, please use `dict` or `BboxParams`")
            self.processors["bboxes"] = BboxProcessor(b_params, additional_targets)

        if keypoint_params:
            if isinstance(keypoint_params, dict):
                k_params = KeypointParams(**keypoint_params)
            elif isinstance(keypoint_params, KeypointParams):
                k_params = keypoint_params
            else:
                raise ValueError("unknown format of keypoint_params, please use `dict` or `KeypointParams`")
            self.processors["keypoints"] = KeypointsProcessor(k_params, additional_targets)

        if additional_targets is None:
            additional_targets = {}

        self.additional_targets = additional_targets

        for proc in self.processors.values():
            proc.ensure_transforms_valid(self.transforms)

        self.add_targets(additional_targets)

        # Only the outermost Compose validates arguments; nested ones are switched off below.
        self.is_check_args = True
        self._disable_check_args_for_transforms(self.transforms)

        self.is_check_shapes = is_check_shapes

    @staticmethod
    def _disable_check_args_for_transforms(transforms: TransformsSeqType) -> None:
        """Recursively turn off argument checking on every nested ``Compose``."""
        for transform in transforms:
            if isinstance(transform, BaseCompose):
                Compose._disable_check_args_for_transforms(transform.transforms)
            if isinstance(transform, Compose):
                transform._disable_check_args()

    def _disable_check_args(self) -> None:
        """Turn off argument checking for this instance."""
        self.is_check_args = False

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        """Run the pipeline on the named targets in ``data``.

        Args:
            force_apply: When truthy, run the full transform list regardless of ``p``.
            **data: Targets passed by name (e.g. ``image=...``).

        Returns:
            The transformed targets; numpy arrays are made contiguous.

        Raises:
            KeyError: if any positional argument is given.
        """
        if args:
            raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
        if self.is_check_args:
            self._check_args(**data)
        assert isinstance(force_apply, (bool, int)), "force_apply must have bool or int type"
        need_to_run = force_apply or random.random() < self.p
        for proc in self.processors.values():
            proc.ensure_data_valid(data)
        # When the pipeline is skipped, transforms flagged always_apply still run.
        transforms = self.transforms if need_to_run else get_always_apply(self.transforms)

        check_each_transform = any(
            getattr(item.params, "check_each_transform", False) for item in self.processors.values()
        )

        for proc in self.processors.values():
            proc.preprocess(data)

        for t in transforms:
            data = t(**data)

            if check_each_transform:
                data = self._check_data_post_transform(data)
        data = Compose._make_targets_contiguous(data)  # ensure output targets are contiguous

        for proc in self.processors.values():
            proc.postprocess(data)

        return data

    def _check_data_post_transform(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Re-filter processor-managed fields against the current image shape.

        Only processors whose params set ``check_each_transform`` are applied.
        """
        rows, cols = get_shape(data["image"])

        for proc in self.processors.values():
            if not getattr(proc.params, "check_each_transform", False):
                continue

            for data_name in proc.data_fields:
                data[data_name] = proc.filter(data[data_name], rows, cols)
        return data

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        """Serialize the pipeline, including processor params and target settings."""
        dictionary = super(Compose, self)._to_dict()
        bbox_processor = self.processors.get("bboxes")
        keypoints_processor = self.processors.get("keypoints")
        dictionary.update(
            {
                "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None,  # skipcq: PYL-W0212
                "keypoint_params": keypoints_processor.params._to_dict()  # skipcq: PYL-W0212
                if keypoints_processor
                else None,
                "additional_targets": self.additional_targets,
                "is_check_shapes": self.is_check_shapes,
            }
        )
        return dictionary

    def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
        """Serialize the pipeline with per-node ids plus the ``Compose`` settings."""
        dictionary = super().get_dict_with_id()
        bbox_processor = self.processors.get("bboxes")
        keypoints_processor = self.processors.get("keypoints")
        dictionary.update(
            {
                "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None,  # skipcq: PYL-W0212
                "keypoint_params": keypoints_processor.params._to_dict()  # skipcq: PYL-W0212
                if keypoints_processor
                else None,
                "additional_targets": self.additional_targets,
                "params": None,
                "is_check_shapes": self.is_check_shapes,
            }
        )
        return dictionary

    def _check_args(self, **kwargs) -> None:
        """Validate target types and, optionally, that spatial shapes agree.

        Raises:
            TypeError: if an image/mask target is not a numpy array, or the first
                element of a non-empty ``masks`` target is not a numpy array.
            ValueError: if a bboxes target is passed without ``bbox_params``, or if
                ``is_check_shapes`` is set and the collected (height, width) pairs differ.
        """
        checked_single = ["image", "mask"]
        checked_multi = ["masks"]
        check_bbox_param = ["bboxes"]
        # ["bboxes", "keypoints"] could be almost any type, no need to check them
        shapes = []
        for data_name, data in kwargs.items():
            # Map an additional target (e.g. "image2") onto the target type it mirrors.
            internal_data_name = self.additional_targets.get(data_name, data_name)
            if internal_data_name in checked_single:
                if not isinstance(data, np.ndarray):
                    raise TypeError("{} must be numpy array type".format(data_name))
                shapes.append(data.shape[:2])
            if internal_data_name in checked_multi:
                if data is not None and len(data):
                    # Only the first element is inspected.
                    if not isinstance(data[0], np.ndarray):
                        raise TypeError("{} must be list of numpy arrays".format(data_name))
                    shapes.append(data[0].shape[:2])
            if internal_data_name in check_bbox_param and self.processors.get("bboxes") is None:
                raise ValueError("bbox_params must be specified for bbox transformations")

        if self.is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes):
            raise ValueError(
                "Height and Width of image, mask or masks should be equal. You can disable shapes check "
                "by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure "
                "about your data consistency)."
            )

    @staticmethod
    def _make_targets_contiguous(data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Return a new dict in which every numpy array value is C-contiguous."""
        result = {}
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                value = np.ascontiguousarray(value)
            result[key] = value
        return result
class OneOf(BaseCompose):
    """Select one of transforms to apply. Selected transform will be called with `force_apply=True`.
    Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights.

    Args:
        transforms (list): list of transformations to compose.
        p (float): probability of applying selected transform. Default: 0.5.
    """

    def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
        super().__init__(transforms, p)
        # Each child's own ``p`` acts as a sampling weight; rescale so they sum to one.
        weights = [child.p for child in self.transforms]
        total = sum(weights)
        self.transforms_ps = [weight / total for weight in weights]

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        """Apply one weighted-random child, or every child when in replay mode."""
        if self.replay_mode:
            for child in self.transforms:
                data = child(**data)
            return data

        # An empty composition passes data through without drawing a random number.
        if not self.transforms_ps:
            return data
        if not (force_apply or random.random() < self.p):
            return data

        idx: int = random_utils.choice(len(self.transforms), p=self.transforms_ps)
        chosen = self.transforms[idx]
        return chosen(force_apply=True, **data)
class SomeOf(BaseCompose):
    """Select N transforms to apply. Selected transforms will be called with `force_apply=True`.
    Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights.

    Args:
        transforms (list): list of transformations to compose.
        n (int): number of transforms to apply.
        replace (bool): Whether the sampled transforms are with or without replacement. Default: True.
        p (float): probability of applying selected transform. Default: 1.
    """

    def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True, p: float = 1):
        super().__init__(transforms, p)
        self.n = n
        self.replace = replace
        # Each child's own ``p`` acts as a sampling weight; rescale so they sum to one.
        weights = [child.p for child in self.transforms]
        total = sum(weights)
        self.transforms_ps = [weight / total for weight in weights]

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        """Apply ``n`` weighted-random children in sampled order, or all children in replay mode."""
        if self.replay_mode:
            for child in self.transforms:
                data = child(**data)
            return data

        # An empty composition passes data through without drawing a random number.
        if not self.transforms_ps:
            return data
        if not (force_apply or random.random() < self.p):
            return data

        picks = random_utils.choice(len(self.transforms), size=self.n, replace=self.replace, p=self.transforms_ps)
        for pick in picks:  # type: ignore
            data = self.transforms[pick](force_apply=True, **data)
        return data

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        """Serialize the composition, adding the ``n`` and ``replace`` settings."""
        serialized = super()._to_dict()
        serialized.update({"n": self.n, "replace": self.replace})
        return serialized
class OneOrOther(BaseCompose):
    """Select one or another transform to apply. Selected transform will be called with `force_apply=True`.

    Args:
        first: Transform chosen with probability ``p``. Used only when ``transforms`` is None.
        second: Transform chosen otherwise. Used only when ``transforms`` is None.
        transforms: Alternative way to pass the transforms; a warning is issued
            when its length is not 2.
        p (float): probability of choosing the first transform. Default: 0.5.

    Raises:
        ValueError: if ``transforms`` is None and ``first`` or ``second`` is missing.
    """

    def __init__(
        self,
        first: typing.Optional[TransformType] = None,
        second: typing.Optional[TransformType] = None,
        transforms: typing.Optional[TransformsSeqType] = None,
        p: float = 0.5,
    ):
        if transforms is None:
            both_given = first is not None and second is not None
            if not both_given:
                raise ValueError("You must set both first and second or set transforms argument.")
            transforms = [first, second]
        super().__init__(transforms, p)
        if len(self.transforms) != 2:
            warnings.warn("Length of transforms is not equal to 2.")

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        """Apply the first transform with probability ``p``, otherwise the last one.

        In replay mode every contained transform is called in order instead.
        ``force_apply`` is accepted but not consulted when making the choice.
        """
        if self.replay_mode:
            for child in self.transforms:
                data = child(**data)
            return data

        chosen = self.transforms[0] if random.random() < self.p else self.transforms[-1]
        return chosen(force_apply=True, **data)
class PerChannel(BaseCompose):
    """Apply transformations per-channel

    Args:
        transforms (list): list of transformations to compose.
        channels (sequence): channels to apply the transform to. Pass None to apply to all.
                             Default: None (apply to all)
        p (float): probability of applying the transform. Default: 0.5.
    """

    def __init__(
        self, transforms: TransformsSeqType, channels: typing.Optional[typing.Sequence[int]] = None, p: float = 0.5
    ):
        super(PerChannel, self).__init__(transforms, p)
        self.channels = channels

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        """Run every contained transform on each selected channel of ``data["image"]``.

        Args:
            force_apply: When truthy, apply regardless of ``p``.
            **data: Named targets; only ``"image"`` is read and rewritten.

        Returns:
            ``data``. When applied, ``data["image"]`` holds the per-channel results;
            a 2-D input image comes back with a trailing single-channel axis.
        """
        if force_apply or random.random() < self.p:
            image = data["image"]

            # Expand mono images to have a single channel
            if len(image.shape) == 2:
                image = np.expand_dims(image, -1)

            # Resolve "all channels" per call. Storing the range on ``self`` would
            # pin the first image's channel count for every later image.
            channels = range(image.shape[2]) if self.channels is None else self.channels

            for c in channels:
                for t in self.transforms:
                    # Results are written back into the channel slice of the working array.
                    image[:, :, c] = t(image=image[:, :, c])["image"]

            data["image"] = image

        return data
class ReplayCompose(Compose):
    """``Compose`` variant that returns a serialized record of each call for later replay.

    On construction every contained transform is switched to deterministic
    mode with the given ``save_key``. Each call adds, under ``save_key``, a
    nested dict describing the pipeline together with per-node ``params`` and
    ``applied`` entries. That dict can be handed to :meth:`replay`.

    Args:
        transforms (list): list of transformations to compose.
        bbox_params (BboxParams): Parameters for bounding boxes transforms
        keypoint_params (KeypointParams): Parameters for keypoints transforms
        additional_targets (dict): Dict with keys - new target name, values - old target name.
        p (float): probability of applying all list of transforms. Default: 1.0.
        is_check_shapes (bool): Whether shape consistency is checked on each call.
        save_key (str): Key under which the replay record is stored. Default: "replay".
    """

    def __init__(
        self,
        transforms: TransformsSeqType,
        bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None,
        keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
        additional_targets: typing.Optional[typing.Dict[str, str]] = None,
        p: float = 1.0,
        is_check_shapes: bool = True,
        save_key: str = "replay",
    ):
        super(ReplayCompose, self).__init__(
            transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes
        )
        self.set_deterministic(True, save_key=save_key)
        self.save_key = save_key

    def __call__(self, *args, force_apply: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
        """Run the pipeline and attach the replay record under ``save_key``."""
        # Presumably each deterministic transform records its sampled params in this
        # mapping; fill_with_params looks entries up by the node's id.
        kwargs[self.save_key] = defaultdict(dict)
        result = super(ReplayCompose, self).__call__(force_apply=force_apply, **kwargs)
        serialized = self.get_dict_with_id()
        self.fill_with_params(serialized, result[self.save_key])
        self.fill_applied(serialized)
        result[self.save_key] = serialized
        return result

    @staticmethod
    def replay(saved_augmentations: typing.Dict[str, typing.Any], **kwargs) -> typing.Dict[str, typing.Any]:
        """Rebuild a pipeline from a replay record and apply it to ``kwargs``.

        Args:
            saved_augmentations: Record produced by a previous ``ReplayCompose`` call.
            **kwargs: Named targets to transform.
        """
        augs = ReplayCompose._restore_for_replay(saved_augmentations)
        return augs(force_apply=True, **kwargs)

    @staticmethod
    def _restore_for_replay(
        transform_dict: typing.Dict[str, typing.Any], lambda_transforms: typing.Optional[dict] = None
    ) -> TransformType:
        """Recursively rebuild a transform from its replay record and put it in replay mode.

        Args:
            transform_dict: Serialized node containing at least ``applied`` and ``params``.
            lambda_transforms (dict): A dictionary that contains lambda transforms, that
                is instances of the Lambda class.
                This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
                in that dictionary should be named same as `name` arguments in respective lambda transforms from
                a serialized pipeline.

        Returns:
            The restored transform with ``replay_mode`` set and ``applied_in_replay``
            taken from the record.
        """
        applied = transform_dict["applied"]
        params = transform_dict["params"]
        lmbd = instantiate_nonserializable(transform_dict, lambda_transforms)
        if lmbd:
            transform = lmbd
        else:
            name = transform_dict["__class_fullname__"]
            # Bookkeeping keys are not constructor arguments.
            args = {k: v for k, v in transform_dict.items() if k not in ["__class_fullname__", "applied", "params"]}
            cls = SERIALIZABLE_REGISTRY[name]
            if "transforms" in args:
                args["transforms"] = [
                    ReplayCompose._restore_for_replay(t, lambda_transforms=lambda_transforms)
                    for t in args["transforms"]
                ]
            transform = cls(**args)

        transform = typing.cast(BasicTransform, transform)
        # Only leaf transforms carry recorded params; compositions just get the flags.
        if isinstance(transform, BasicTransform):
            transform.params = params
        transform.replay_mode = True
        transform.applied_in_replay = applied
        return transform

    def fill_with_params(self, serialized: dict, all_params: dict) -> None:
        """Replace each node's ``id`` with the params recorded for that id (in place)."""
        params = all_params.get(serialized.get("id"))
        serialized["params"] = params
        del serialized["id"]
        for transform in serialized.get("transforms", []):
            self.fill_with_params(transform, all_params)

    def fill_applied(self, serialized: typing.Dict[str, typing.Any]) -> bool:
        """Set ``applied`` on every node (in place) and return the value for this node.

        A leaf counts as applied when its ``params`` is not ``None``; a
        composition counts as applied when any of its children does.
        """
        if "transforms" in serialized:
            applied = [self.fill_applied(t) for t in serialized["transforms"]]
            serialized["applied"] = any(applied)
        else:
            serialized["applied"] = serialized.get("params") is not None
        return serialized["applied"]

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        """Serialize the pipeline, adding ``save_key``."""
        dictionary = super(ReplayCompose, self)._to_dict()
        dictionary.update({"save_key": self.save_key})
        return dictionary
class Sequential(BaseCompose):
    """Sequentially applies all transforms to targets.

    Note:
        This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose`
        the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to
        create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly
        chosen sequence to input data (see the `Example` section for an example definition of such pipeline).

        NOTE(review): ``p`` is stored by the constructor, but ``__call__`` does not
        consult it - every call runs all contained transforms.

    Example:
        >>> import custom_albumentations as A
        >>> transform = A.Compose([
        >>>    A.OneOf([
        >>>        A.Sequential([
        >>>            A.HorizontalFlip(p=0.5),
        >>>            A.ShiftScaleRotate(p=0.5),
        >>>        ]),
        >>>        A.Sequential([
        >>>            A.VerticalFlip(p=0.5),
        >>>            A.RandomBrightnessContrast(p=0.5),
        >>>        ]),
        >>>    ], p=1)
        >>> ])
    """

    def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
        super(Sequential, self).__init__(transforms, p)

    def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]:
        """Feed the named targets through each contained transform in order."""
        current = data
        for step in self.transforms:
            current = step(**current)
        return current