JasonSmithSO's picture
Upload 777 files
0034848 verified
import random
from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
from ...core.transforms_interface import DualTransform, KeypointType
from .functional import cutout
__all__ = ["CoarseDropout"]
class CoarseDropout(DualTransform):
"""CoarseDropout of the rectangular regions in the image.
Args:
max_holes (int): Maximum number of regions to zero out.
max_height (int, float): Maximum height of the hole.
If float, it is calculated as a fraction of the image height.
max_width (int, float): Maximum width of the hole.
If float, it is calculated as a fraction of the image width.
min_holes (int): Minimum number of regions to zero out. If `None`,
`min_holes` is be set to `max_holes`. Default: `None`.
min_height (int, float): Minimum height of the hole. Default: None. If `None`,
`min_height` is set to `max_height`. Default: `None`.
If float, it is calculated as a fraction of the image height.
min_width (int, float): Minimum width of the hole. If `None`, `min_height` is
set to `max_width`. Default: `None`.
If float, it is calculated as a fraction of the image width.
fill_value (int, float, list of int, list of float): value for dropped pixels.
mask_fill_value (int, float, list of int, list of float): fill value for dropped pixels
in mask. If `None` - mask is not affected. Default: `None`.
Targets:
image, mask, keypoints
Image types:
uint8, float32
Reference:
| https://arxiv.org/abs/1708.04552
| https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
| https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py
"""
def __init__(
self,
max_holes: int = 8,
max_height: int = 8,
max_width: int = 8,
min_holes: Optional[int] = None,
min_height: Optional[int] = None,
min_width: Optional[int] = None,
fill_value: int = 0,
mask_fill_value: Optional[int] = None,
always_apply: bool = False,
p: float = 0.5,
):
super(CoarseDropout, self).__init__(always_apply, p)
self.max_holes = max_holes
self.max_height = max_height
self.max_width = max_width
self.min_holes = min_holes if min_holes is not None else max_holes
self.min_height = min_height if min_height is not None else max_height
self.min_width = min_width if min_width is not None else max_width
self.fill_value = fill_value
self.mask_fill_value = mask_fill_value
if not 0 < self.min_holes <= self.max_holes:
raise ValueError("Invalid combination of min_holes and max_holes. Got: {}".format([min_holes, max_holes]))
self.check_range(self.max_height)
self.check_range(self.min_height)
self.check_range(self.max_width)
self.check_range(self.min_width)
if not 0 < self.min_height <= self.max_height:
raise ValueError(
"Invalid combination of min_height and max_height. Got: {}".format([min_height, max_height])
)
if not 0 < self.min_width <= self.max_width:
raise ValueError("Invalid combination of min_width and max_width. Got: {}".format([min_width, max_width]))
def check_range(self, dimension):
if isinstance(dimension, float) and not 0 <= dimension < 1.0:
raise ValueError(
"Invalid value {}. If using floats, the value should be in the range [0.0, 1.0)".format(dimension)
)
def apply(
self,
img: np.ndarray,
fill_value: Union[int, float] = 0,
holes: Iterable[Tuple[int, int, int, int]] = (),
**params
) -> np.ndarray:
return cutout(img, holes, fill_value)
def apply_to_mask(
self,
img: np.ndarray,
mask_fill_value: Union[int, float] = 0,
holes: Iterable[Tuple[int, int, int, int]] = (),
**params
) -> np.ndarray:
if mask_fill_value is None:
return img
return cutout(img, holes, mask_fill_value)
def get_params_dependent_on_targets(self, params):
img = params["image"]
height, width = img.shape[:2]
holes = []
for _n in range(random.randint(self.min_holes, self.max_holes)):
if all(
[
isinstance(self.min_height, int),
isinstance(self.min_width, int),
isinstance(self.max_height, int),
isinstance(self.max_width, int),
]
):
hole_height = random.randint(self.min_height, self.max_height)
hole_width = random.randint(self.min_width, self.max_width)
elif all(
[
isinstance(self.min_height, float),
isinstance(self.min_width, float),
isinstance(self.max_height, float),
isinstance(self.max_width, float),
]
):
hole_height = int(height * random.uniform(self.min_height, self.max_height))
hole_width = int(width * random.uniform(self.min_width, self.max_width))
else:
raise ValueError(
"Min width, max width, \
min height and max height \
should all either be ints or floats. \
Got: {} respectively".format(
[
type(self.min_width),
type(self.max_width),
type(self.min_height),
type(self.max_height),
]
)
)
y1 = random.randint(0, height - hole_height)
x1 = random.randint(0, width - hole_width)
y2 = y1 + hole_height
x2 = x1 + hole_width
holes.append((x1, y1, x2, y2))
return {"holes": holes}
@property
def targets_as_params(self):
return ["image"]
def _keypoint_in_hole(self, keypoint: KeypointType, hole: Tuple[int, int, int, int]) -> bool:
x1, y1, x2, y2 = hole
x, y = keypoint[:2]
return x1 <= x < x2 and y1 <= y < y2
def apply_to_keypoints(
self, keypoints: Sequence[KeypointType], holes: Iterable[Tuple[int, int, int, int]] = (), **params
) -> List[KeypointType]:
result = set(keypoints)
for hole in holes:
for kp in keypoints:
if self._keypoint_in_hole(kp, hole):
result.discard(kp)
return list(result)
def get_transform_init_args_names(self):
return (
"max_holes",
"max_height",
"max_width",
"min_holes",
"min_height",
"min_width",
"fill_value",
"mask_fill_value",
)