JasonSmithSO's picture
Upload 777 files
0034848 verified
import random
from typing import Any, Dict, Optional, Tuple, Union
import cv2
import numpy as np
from skimage.measure import label
from ...core.transforms_interface import DualTransform, to_tuple
__all__ = ["MaskDropout"]
class MaskDropout(DualTransform):
    """Zero out image and mask regions of randomly chosen mask objects.

    Connected components of the mask are found with ``skimage.measure.label``;
    a random subset of them is then filled in both the image and the mask.

    Mask must be single-channel image, zero values treated as background.
    Image can be any number of channels.

    Inspired by https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254

    Args:
        max_objects: Maximum number of labels that can be zeroed out.
            Can be a tuple, in which case it is [min, max].
        image_fill_value: Fill value to use when filling image.
            Can be 'inpaint' to apply inpainting (works only for 3-channel images).
        mask_fill_value: Fill value to use when filling mask.
        always_apply: Forwarded unchanged to the base transform constructor.
        p: Forwarded unchanged to the base transform constructor.

    Targets:
        image, mask

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        max_objects: Union[int, Tuple[int, int]] = 1,
        image_fill_value: Union[int, float, str] = 0,
        mask_fill_value: Union[int, float] = 0,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        # NOTE(review): to_tuple is a project helper; presumably it expands a
        # scalar into a (low, high) pair using 1 as the other bound - confirm.
        self.max_objects = to_tuple(max_objects, 1)
        self.image_fill_value = image_fill_value
        self.mask_fill_value = mask_fill_value

    @property
    def targets_as_params(self):
        """Names of the targets required by ``get_params_dependent_on_targets``."""
        return ["mask"]

    def get_params_dependent_on_targets(self, params) -> Dict[str, Any]:
        """Pick the objects to drop and build the boolean dropout mask.

        Args:
            params: Mapping that contains the input mask under the ``"mask"`` key.

        Returns:
            A new dict holding every entry of ``params`` plus ``"dropout_mask"``:
            a boolean array marking pixels to fill, or ``None`` when the mask
            contains no labelled objects. The input mapping is not modified.
        """
        mask = params["mask"]
        label_image, num_labels = label(mask, return_num=True)

        if num_labels == 0:
            dropout_mask = None
        else:
            # Draw how many objects to remove, capped by how many exist.
            drawn = random.randint(int(self.max_objects[0]), int(self.max_objects[1]))
            objects_to_drop = min(num_labels, drawn)

            if objects_to_drop == num_labels:
                # Every object goes: no need to sample individual labels.
                dropout_mask = mask > 0
            else:
                # Labels produced by `label` run from 1 to num_labels inclusive.
                labels_index = random.sample(range(1, num_labels + 1), objects_to_drop)
                # One vectorised membership test instead of OR-ing a
                # full-image comparison per selected label.
                dropout_mask = np.isin(label_image, labels_index)

        # Build a fresh dict rather than mutating the caller's mapping.
        return {**params, "dropout_mask": dropout_mask}

    def apply(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
        """Fill (or inpaint) the image pixels selected by ``dropout_mask``.

        Args:
            img: Image to process.
            dropout_mask: Boolean mask of pixels to fill; ``None`` means no-op.
            **params: Extra parameters, ignored here.

        Returns:
            The input image itself when ``dropout_mask`` is ``None``; otherwise
            a new array with the selected region inpainted or overwritten with
            ``image_fill_value``.
        """
        if dropout_mask is None:
            return img

        if self.image_fill_value == "inpaint":
            # The boolean mask is converted to uint8 before being handed to OpenCV.
            dropout_mask = dropout_mask.astype(np.uint8)
            _, _, w, h = cv2.boundingRect(dropout_mask)
            # Inpainting radius: half of the larger bounding-box side, capped at 3.
            radius = min(3, max(w, h) // 2)
            img = cv2.inpaint(img, dropout_mask, radius, cv2.INPAINT_NS)
        else:
            # Copy first so the caller's array is left untouched.
            img = img.copy()
            img[dropout_mask] = self.image_fill_value

        return img

    def apply_to_mask(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
        """Fill the mask pixels selected by ``dropout_mask`` with ``mask_fill_value``.

        Args:
            img: Mask to process.
            dropout_mask: Boolean mask of pixels to fill; ``None`` means no-op.
            **params: Extra parameters, ignored here.

        Returns:
            The input mask itself when ``dropout_mask`` is ``None``; otherwise
            a modified copy.
        """
        if dropout_mask is None:
            return img

        img = img.copy()
        img[dropout_mask] = self.mask_fill_value
        return img

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        """Return the names of the constructor arguments stored on the instance."""
        return "max_objects", "image_fill_value", "mask_fill_value"