Spaces:
Configuration error
Configuration error
import random | |
import warnings | |
from typing import Any, Dict, Tuple, Union | |
import numpy as np | |
from custom_albumentations.core.transforms_interface import ImageOnlyTransform | |
from .functional import cutout | |
__all__ = ["Cutout"] | |
class Cutout(ImageOnlyTransform): | |
"""CoarseDropout of the square regions in the image. | |
Args: | |
num_holes (int): number of regions to zero out | |
max_h_size (int): maximum height of the hole | |
max_w_size (int): maximum width of the hole | |
fill_value (int, float, list of int, list of float): value for dropped pixels. | |
Targets: | |
image | |
Image types: | |
uint8, float32 | |
Reference: | |
| https://arxiv.org/abs/1708.04552 | |
| https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py | |
| https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py | |
""" | |
def __init__( | |
self, | |
num_holes: int = 8, | |
max_h_size: int = 8, | |
max_w_size: int = 8, | |
fill_value: Union[int, float] = 0, | |
always_apply: bool = False, | |
p: float = 0.5, | |
): | |
super(Cutout, self).__init__(always_apply, p) | |
self.num_holes = num_holes | |
self.max_h_size = max_h_size | |
self.max_w_size = max_w_size | |
self.fill_value = fill_value | |
warnings.warn( | |
f"{self.__class__.__name__} has been deprecated. Please use CoarseDropout", | |
FutureWarning, | |
) | |
def apply(self, img: np.ndarray, fill_value: Union[int, float] = 0, holes=(), **params): | |
return cutout(img, holes, fill_value) | |
def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: | |
img = params["image"] | |
height, width = img.shape[:2] | |
holes = [] | |
for _n in range(self.num_holes): | |
y = random.randint(0, height) | |
x = random.randint(0, width) | |
y1 = np.clip(y - self.max_h_size // 2, 0, height) | |
y2 = np.clip(y1 + self.max_h_size, 0, height) | |
x1 = np.clip(x - self.max_w_size // 2, 0, width) | |
x2 = np.clip(x1 + self.max_w_size, 0, width) | |
holes.append((x1, y1, x2, y2)) | |
return {"holes": holes} | |
def targets_as_params(self): | |
return ["image"] | |
def get_transform_init_args_names(self) -> Tuple[str, ...]: | |
return ("num_holes", "max_h_size", "max_w_size") | |