# Spaces:
# Configuration error
import random | |
from typing import Any, Dict, Optional, Tuple, Union | |
import cv2 | |
import numpy as np | |
from skimage.measure import label | |
from ...core.transforms_interface import DualTransform, to_tuple | |
# Names exported by ``from <module> import *``.
__all__ = [
    "MaskDropout",
]
class MaskDropout(DualTransform):
    """Zero out image and mask regions belonging to randomly chosen object instances.

    The mask is labelled with ``skimage.measure.label`` (zero values are treated
    as background). A random subset of the resulting labels is selected, and the
    selected pixels are filled in both the image and the mask.

    The mask must be a single-channel image, either ``(H, W)`` or ``(H, W, 1)``.
    The image can have any number of channels.

    Inspired by https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254

    Args:
        max_objects: Maximum number of labels that can be zeroed out. Can be a
            tuple, in which case it is interpreted as ``[min, max]``.
        image_fill_value: Fill value to use when filling the image.
            Can be ``'inpaint'`` to apply inpainting (works only for 3-channel images).
        mask_fill_value: Fill value to use when filling the mask.

    Targets:
        image, mask

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        max_objects: Union[int, Tuple[int, int]] = 1,
        image_fill_value: Union[int, float, str] = 0,
        mask_fill_value: Union[int, float] = 0,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        # Stored as a (min, max) pair; both ends are read when sampling below.
        self.max_objects = to_tuple(max_objects, 1)
        self.image_fill_value = image_fill_value
        self.mask_fill_value = mask_fill_value

    @property
    def targets_as_params(self):
        """Names of the targets whose data is needed to compute the dropout parameters."""
        # Must be a property: callers iterate over it and index kwargs with its items.
        return ["mask"]

    def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Choose which labelled objects of the mask to drop.

        Args:
            params: Target data; must contain the ``"mask"`` array.

        Returns:
            A new dict containing everything in ``params`` plus ``"dropout_mask"``:
            a boolean ``(H, W)`` array marking the pixels to fill, or ``None`` when
            there is nothing to drop (no labels found, or zero objects sampled).
        """
        mask = params["mask"]
        # Accept (H, W, 1) masks. Labelling them in 3-D would produce an (H, W, 1)
        # label image, which cannot be combined into a 2-D dropout mask and cannot
        # index a multi-channel image.
        if mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[..., 0]

        label_image, num_labels = label(mask, return_num=True)

        dropout_mask = None
        if num_labels > 0:
            objects_to_drop = random.randint(int(self.max_objects[0]), int(self.max_objects[1]))
            objects_to_drop = min(num_labels, objects_to_drop)

            if objects_to_drop == num_labels:
                # Every labelled object is dropped: select all non-background labels.
                dropout_mask = label_image > 0
            elif objects_to_drop > 0:
                # Labels produced by `label` run from 1 to num_labels (0 is background).
                labels_index = random.sample(range(1, num_labels + 1), objects_to_drop)
                dropout_mask = np.isin(label_image, labels_index)

        # Return a merged copy rather than mutating the caller's dict.
        return {**params, "dropout_mask": dropout_mask}

    def apply(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
        """Fill the dropped regions of the image, or inpaint them when requested.

        Args:
            img: Input image.
            dropout_mask: Boolean mask of pixels to fill, or ``None`` for a no-op.

        Returns:
            The input image unchanged when ``dropout_mask`` is ``None``; otherwise a
            new array with the selected pixels filled or inpainted.
        """
        if dropout_mask is None:
            return img

        if self.image_fill_value == "inpaint":
            inpaint_mask = dropout_mask.astype(np.uint8)
            # boundingRect returns (x, y, width, height). The inpainting radius is
            # capped at 3 and shrinks for very small regions.
            _, _, width, height = cv2.boundingRect(inpaint_mask)
            radius = min(3, max(width, height) // 2)
            return cv2.inpaint(img, inpaint_mask, radius, cv2.INPAINT_NS)

        img = img.copy()
        img[dropout_mask] = self.image_fill_value
        return img

    def apply_to_mask(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
        """Fill the dropped regions of the mask with ``mask_fill_value``.

        Args:
            img: Input mask.
            dropout_mask: Boolean mask of pixels to fill, or ``None`` for a no-op.

        Returns:
            The input mask unchanged when ``dropout_mask`` is ``None``; otherwise a
            filled copy.
        """
        if dropout_mask is None:
            return img
        img = img.copy()
        img[dropout_mask] = self.mask_fill_value
        return img

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        """Names of the constructor arguments used for serialization."""
        return "max_objects", "image_fill_value", "mask_fill_value"