Spaces:
Configuration error
Configuration error
import random | |
from typing import Iterable, Optional, Tuple | |
import numpy as np | |
from ...core.transforms_interface import DualTransform | |
from . import functional as F | |
__all__ = ["GridDropout"] | |
class GridDropout(DualTransform):
    """GridDropout, drops out rectangular regions of an image and the corresponding mask in a grid fashion.

    The image is divided into square/rectangular grid units; inside every unit one
    rectangular hole is filled with a constant value.

    Args:
        ratio (float): the ratio of the mask holes to the unit_size (same for horizontal and vertical directions).
            Must be in the range (0, 1]. Default: 0.5.
        unit_size_min (int): minimum size of the grid unit. Must be between 2 and the image shorter edge.
            If `None`, holes_number_x and holes_number_y are used to setup the grid. Default: `None`.
        unit_size_max (int): maximum size of the grid unit (inclusive). Must be between 2 and the image shorter edge.
            If `None`, holes_number_x and holes_number_y are used to setup the grid. Default: `None`.
        holes_number_x (int): the number of grid units in x direction. Must be between 1 and image width//2.
            If `None`, grid unit width is set as max(2, image_width//10). Default: `None`.
        holes_number_y (int): the number of grid units in y direction. Must be between 1 and image height//2.
            If `None`, grid unit height is set equal to the grid unit width or image height, whatever is smaller
            (but never less than 2). Default: `None`.
        shift_x (int): offsets of the grid start in x direction from (0,0) coordinate.
            Clipped between 0 and grid unit_width - hole_width. Default: 0.
        shift_y (int): offsets of the grid start in y direction from (0,0) coordinate.
            Clipped between 0 and grid unit height - hole_height. Default: 0.
        random_offset (boolean): whether to offset the grid randomly between 0 and grid unit size - hole size.
            If `True`, entered shift_x, shift_y are ignored and set randomly. Default: `False`.
        fill_value (int): value for the dropped pixels. Default: 0.
        mask_fill_value (int): value for the dropped pixels in mask.
            If `None`, transformation is not applied to the mask. Default: `None`.

    Targets:
        image, mask

    Image types:
        uint8, float32

    References:
        https://arxiv.org/abs/2001.04086

    """

    def __init__(
        self,
        ratio: float = 0.5,
        unit_size_min: Optional[int] = None,
        unit_size_max: Optional[int] = None,
        holes_number_x: Optional[int] = None,
        holes_number_y: Optional[int] = None,
        shift_x: int = 0,
        shift_y: int = 0,
        random_offset: bool = False,
        fill_value: int = 0,
        mask_fill_value: Optional[int] = None,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        """Store the grid configuration.

        Raises:
            ValueError: if `ratio` is not in the range (0, 1].
        """
        super().__init__(always_apply, p)
        if not 0 < ratio <= 1:
            raise ValueError("ratio must be between 0 and 1.")
        self.ratio = ratio
        self.unit_size_min = unit_size_min
        self.unit_size_max = unit_size_max
        self.holes_number_x = holes_number_x
        self.holes_number_y = holes_number_y
        self.shift_x = shift_x
        self.shift_y = shift_y
        self.random_offset = random_offset
        self.fill_value = fill_value
        self.mask_fill_value = mask_fill_value

    def apply(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params) -> np.ndarray:
        """Fill the given `(x1, y1, x2, y2)` holes of the image with `fill_value` via `F.cutout`."""
        return F.cutout(img, holes, self.fill_value)

    def apply_to_mask(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params) -> np.ndarray:
        """Fill the holes of the mask with `mask_fill_value`; return the mask untouched when it is `None`."""
        if self.mask_fill_value is None:
            return img
        return F.cutout(img, holes, self.mask_fill_value)

    def get_params_dependent_on_targets(self, params: dict) -> dict:
        """Compute the hole rectangles for the image found under `params["image"]`.

        Returns:
            A dict with a single key `"holes"` holding a list of `(x1, y1, x2, y2)` tuples,
            clamped to the image bounds.

        Raises:
            ValueError: if the unit size limits or the holes numbers are outside of the
                ranges allowed for the current image size.
        """
        height, width = params["image"].shape[:2]

        if self.unit_size_min and self.unit_size_max:
            # set grid using unit size limits
            if not 2 <= self.unit_size_min <= self.unit_size_max:
                raise ValueError("Max unit size should be >= min size, both at least 2 pixels.")
            if self.unit_size_max > min(height, width):
                raise ValueError("Grid size limits must be within the shortest image edge.")
            # random.randint includes BOTH endpoints, so the upper bound is unit_size_max
            # itself; adding 1 could yield a unit larger than the validated maximum.
            unit_width = random.randint(self.unit_size_min, self.unit_size_max)
            unit_height = unit_width
        else:
            # set grid using holes numbers
            if self.holes_number_x is None:
                unit_width = max(2, width // 10)
            else:
                if not 1 <= self.holes_number_x <= width // 2:
                    raise ValueError("The hole_number_x must be between 1 and image width//2.")
                unit_width = width // self.holes_number_x

            if self.holes_number_y is None:
                unit_height = max(min(unit_width, height), 2)
            else:
                if not 1 <= self.holes_number_y <= height // 2:
                    raise ValueError("The hole_number_y must be between 1 and image height//2.")
                unit_height = height // self.holes_number_y

        # hole size: at least 1 pixel and at most unit length - 1, so that a unit is
        # never dropped completely (units are always >= 2 pixels at this point)
        hole_width = min(max(int(unit_width * self.ratio), 1), unit_width - 1)
        hole_height = min(max(int(unit_height * self.ratio), 1), unit_height - 1)

        # set offset of the grid
        max_shift_x = unit_width - hole_width
        max_shift_y = unit_height - hole_height
        if self.random_offset:
            # user supplied shifts are ignored in favour of random ones
            shift_x = random.randint(0, max_shift_x)
            shift_y = random.randint(0, max_shift_y)
        else:
            shift_x = 0 if self.shift_x is None else min(max(0, self.shift_x), max_shift_x)
            shift_y = 0 if self.shift_y is None else min(max(0, self.shift_y), max_shift_y)

        # One extra unit per axis covers the partial unit at the right/bottom border.
        # Coordinates are clamped to the image, so border holes may be cropped or zero-area.
        holes = []
        for i in range(width // unit_width + 1):
            x1 = min(shift_x + unit_width * i, width)
            x2 = min(x1 + hole_width, width)
            for j in range(height // unit_height + 1):
                y1 = min(shift_y + unit_height * j, height)
                y2 = min(y1 + hole_height, height)
                holes.append((x1, y1, x2, y2))

        return {"holes": holes}

    @property
    def targets_as_params(self):
        """Names of the targets that must be passed to `get_params_dependent_on_targets`."""
        return ["image"]

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        """Names of the constructor arguments that describe this transform."""
        return (
            "ratio",
            "unit_size_min",
            "unit_size_max",
            "holes_number_x",
            "holes_number_y",
            "shift_x",
            "shift_y",
            "random_offset",
            "fill_value",
            "mask_fill_value",
        )