""" Code borrowed from https://github.com/zijundeng/pytorch-semantic-segmentation MIT License Copyright (c) 2017 ZijunDeng Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import random import math import numbers import numpy as np import torchvision.transforms as torch_tr import torch from PIL import Image, ImageFilter, ImageOps from skimage.filters import gaussian class RandomGaussianBlur(object): """ Apply Gaussian Blur """ def __call__(self, imgs, mask): img, imgB = imgs[0], imgs[1] sigma = 0.15 + random.random() * 1.15 blurred_img = gaussian(np.array(img), sigma=sigma, channel_axis=-1) blurred_img *= 255 blurred_imgB = gaussian(np.array(imgB), sigma=sigma, channel_axis=-1) blurred_imgB *= 255 return Image.fromarray(blurred_img.astype(np.uint8)), Image.fromarray(blurred_imgB.astype(np.uint8)), mask class RandomScale(object): def __init__(self, scale_list=[0.75, 1.0, 1.25], mode='value'): self.scale_list = scale_list self.mode = mode def __call__(self, img, mask): oh, ow = img.size scale_amt = 1.0 if self.mode == 'value': scale_amt = np.random.choice(self.scale_list, 1) elif self.mode == 'range': scale_amt = random.uniform(self.scale_list[0], self.scale_list[-1]) h = int(scale_amt * oh) w = int(scale_amt * ow) return img.resize((w, h), Image.BICUBIC), mask.resize((w, h), Image.NEAREST) class SmartCropV1(object): def __init__(self, crop_size=512, max_ratio=0.75, ignore_index=12, nopad=False): self.crop_size = crop_size self.max_ratio = max_ratio self.ignore_index = ignore_index self.crop = RandomCrop(crop_size, ignore_index=ignore_index, nopad=nopad) def __call__(self, img, mask): assert img.size == mask.size count = 0 while True: img_crop, mask_crop = self.crop(img.copy(), mask.copy()) count += 1 labels, cnt = np.unique(np.array(mask_crop), return_counts=True) cnt = cnt[labels != self.ignore_index] if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < self.max_ratio: break if count > 10: break return img_crop, mask_crop class DeNormalize(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, tensor): for t, m, s in zip(tensor, self.mean, self.std): t.mul_(s).add_(m) return tensor class MaskToTensor(object): def __call__(self, img): return torch.from_numpy(np.array(img, dtype=np.int32)).long() class FreeScale(object): def __init__(self, size, interpolation=Image.BILINEAR): self.size = tuple(reversed(size)) # size: (h, w) self.interpolation = interpolation def __call__(self, img): return img.resize(self.size, self.interpolation) class FlipChannels(object): def __call__(self, img): img = 
class FlipChannels(object):
    def __call__(self, img):
        img = np.array(img)[:, :, ::-1]
        return Image.fromarray(img.astype(np.uint8))


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgs, mask):
        img, imgB = imgs[0], imgs[1]
        assert img.size == mask.size
        for t in self.transforms:
            img, imgB, mask = t([img, imgB], mask)
        return img, imgB, mask


class RandomCrop(object):
    """
    Take a random crop from the image.

    First the image or crop size may need to be adjusted if the incoming image
    is too small...

    If the image is smaller than the crop, then:
        the image is padded up to the size of the crop,
        unless 'nopad', in which case the crop size is shrunk to fit the image.

    A random crop is taken such that the crop fits within the image.
    If a centroid is passed in, the crop must intersect the centroid.
    """
    def __init__(self, size, ignore_index=0, nopad=True):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.ignore_index = ignore_index
        self.nopad = nopad
        self.pad_color = (0, 0, 0)

    def __call__(self, imgs, mask, centroid=None):
        img, imgB = imgs[0], imgs[1]
        assert img.size == mask.size
        w, h = img.size  # PIL size is (w, h)
        th, tw = self.size
        if w == tw and h == th:
            return img, imgB, mask

        if self.nopad:
            if th > h or tw > w:
                # Instead of padding, adjust crop size to the shorter edge of image.
                shorter_side = min(w, h)
                th, tw = shorter_side, shorter_side
        else:
            # Check if we need to pad img to fit for crop_size.
            if th > h:
                pad_h = (th - h) // 2 + 1
            else:
                pad_h = 0
            if tw > w:
                pad_w = (tw - w) // 2 + 1
            else:
                pad_w = 0
            border = (pad_w, pad_h, pad_w, pad_h)
            if pad_h or pad_w:
                img = ImageOps.expand(img, border=border, fill=self.pad_color)
                imgB = ImageOps.expand(imgB, border=border, fill=self.pad_color)
                mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)
                w, h = img.size

        if centroid is not None:
            # Need to ensure that the centroid is covered by the crop and that
            # the crop sits fully within the image.
            c_x, c_y = centroid
            max_x = w - tw
            max_y = h - th
            x1 = random.randint(c_x - tw, c_x)
            x1 = min(max_x, max(0, x1))
            y1 = random.randint(c_y - th, c_y)
            y1 = min(max_y, max(0, y1))
        else:
            if w == tw:
                x1 = 0
            else:
                x1 = random.randint(0, w - tw)
            if h == th:
                y1 = 0
            else:
                y1 = random.randint(0, h - th)
        return img.crop((x1, y1, x1 + tw, y1 + th)), \
            imgB.crop((x1, y1, x1 + tw, y1 + th)), \
            mask.crop((x1, y1, x1 + tw, y1 + th))


class CenterCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))


class RandomHorizontallyFlip(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, imgs, mask):
        img, imgB = imgs[0], imgs[1]
        if random.random() < self.p:
            return img.transpose(Image.FLIP_LEFT_RIGHT), \
                imgB.transpose(Image.FLIP_LEFT_RIGHT), \
                mask.transpose(Image.FLIP_LEFT_RIGHT)
        return img, imgB, mask


class RandomVerticalFlip(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, imgs, mask):
        img, imgB = imgs[0], imgs[1]
        if random.random() < self.p:
            return img.transpose(Image.FLIP_TOP_BOTTOM), \
                imgB.transpose(Image.FLIP_TOP_BOTTOM), \
                mask.transpose(Image.FLIP_TOP_BOTTOM)
        return img, imgB, mask
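
# Illustrative sketch (not part of the original API): the pair-aware transforms
# above (Compose, RandomCrop, the flips, RandomGaussianBlur) all take a list of
# two aligned images plus a mask and return three values. The crop size, flip
# probability, and ignore index below are arbitrary example values.
def _example_joint_augmentation(img, imgB, mask, crop_size=256):
    joint = Compose([
        RandomHorizontallyFlip(p=0.5),
        RandomVerticalFlip(p=0.5),
        RandomGaussianBlur(),
        RandomCrop(crop_size, ignore_index=255, nopad=False),
    ])
    # All three outputs stay spatially aligned.
    return joint([img, imgB], mask)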
class FreeScale(object):
    def __init__(self, size):
        self.size = tuple(reversed(size))  # size: (h, w)

    def __call__(self, img, mask):
        assert img.size == mask.size
        return img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST)


class Scale(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        if (w >= h and w == self.size) or (h >= w and h == self.size):
            return img, mask
        if w > h:
            ow = self.size
            oh = int(self.size * h / w)
            return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)
        else:
            oh = self.size
            ow = int(self.size * w / h)
            return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)


class RandomSizedCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, mask):
        assert img.size == mask.size
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.45, 1.0) * area
            aspect_ratio = random.uniform(0.5, 2)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                mask = mask.crop((x1, y1, x1 + w, y1 + h))
                assert (img.size == (w, h))

                return img.resize((self.size, self.size), Image.BILINEAR), \
                    mask.resize((self.size, self.size), Image.NEAREST)

        # Fallback
        scale = Scale(self.size)
        crop = CenterCrop(self.size)
        return crop(*scale(img, mask))


class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        rotate_degree = random.random() * 2 * self.degree - self.degree
        return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST)


class RandomSized(object):
    def __init__(self, size):
        self.size = size
        self.scale = Scale(self.size)
        self.crop = RandomCrop(self.size)

    def __call__(self, img, mask):
        assert img.size == mask.size

        w = int(random.uniform(0.5, 2) * img.size[0])
        h = int(random.uniform(0.5, 2) * img.size[1])

        img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)

        # RandomCrop operates on an image pair and returns three values;
        # pass the image twice and keep a single crop.
        img, mask = self.scale(img, mask)
        img, _, mask = self.crop([img, img], mask)
        return img, mask


class SlidingCropOld(object):
    def __init__(self, crop_size, stride_rate, ignore_label):
        self.crop_size = crop_size
        self.stride_rate = stride_rate
        self.ignore_label = ignore_label

    def _pad(self, img, mask):
        h, w = img.shape[: 2]
        pad_h = max(self.crop_size - h, 0)
        pad_w = max(self.crop_size - w, 0)
        img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
        mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label)
        return img, mask

    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        long_size = max(h, w)

        img = np.array(img)
        mask = np.array(mask)

        if long_size > self.crop_size:
            stride = int(math.ceil(self.crop_size * self.stride_rate))
            h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1
            w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1
            img_sublist, mask_sublist = [], []
            for yy in range(h_step_num):
                for xx in range(w_step_num):
                    sy, sx = yy * stride, xx * stride
                    ey, ex = sy + self.crop_size, sx + self.crop_size
                    img_sub = img[sy: ey, sx: ex, :]
                    mask_sub = mask[sy: ey, sx: ex]
                    img_sub, mask_sub = self._pad(img_sub, mask_sub)
                    img_sublist.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB'))
                    mask_sublist.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P'))
            return img_sublist, mask_sublist
        else:
            img, mask = self._pad(img, mask)
            img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
            mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
            return img, mask
class SlidingCrop(object):
    def __init__(self, crop_size, stride_rate, ignore_label):
        self.crop_size = crop_size
        self.stride_rate = stride_rate
        self.ignore_label = ignore_label

    def _pad(self, img, mask):
        h, w = img.shape[: 2]
        pad_h = max(self.crop_size - h, 0)
        pad_w = max(self.crop_size - w, 0)
        img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
        mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label)
        return img, mask, h, w

    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        long_size = max(h, w)

        img = np.array(img)
        mask = np.array(mask)

        if long_size > self.crop_size:
            stride = int(math.ceil(self.crop_size * self.stride_rate))
            h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1
            w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1
            img_slices, mask_slices, slices_info = [], [], []
            for yy in range(h_step_num):
                for xx in range(w_step_num):
                    sy, sx = yy * stride, xx * stride
                    ey, ex = sy + self.crop_size, sx + self.crop_size
                    img_sub = img[sy: ey, sx: ex, :]
                    mask_sub = mask[sy: ey, sx: ex]
                    img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub)
                    img_slices.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB'))
                    mask_slices.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P'))
                    slices_info.append([sy, ey, sx, ex, sub_h, sub_w])
            return img_slices, mask_slices, slices_info
        else:
            img, mask, sub_h, sub_w = self._pad(img, mask)
            img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
            mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
            return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]]


class PadImage(object):
    def __init__(self, size, ignore_index):
        self.size = size
        self.ignore_index = ignore_index

    def __call__(self, img, mask):
        assert img.size == mask.size
        th, tw = self.size, self.size

        w, h = img.size
        if w > tw or h > th:
            # Downscale so the width fits, preserving aspect ratio.
            wpercent = tw / float(w)
            target_h = int(float(img.size[1]) * wpercent)
            img, mask = img.resize((tw, target_h), Image.BICUBIC), mask.resize((tw, target_h), Image.NEAREST)

        w, h = img.size
        # Pad
        img = ImageOps.expand(img, border=(0, 0, tw - w, th - h), fill=0)
        mask = ImageOps.expand(mask, border=(0, 0, tw - w, th - h), fill=self.ignore_index)

        return img, mask


class Resize(object):
    """
    Resize image and mask to the exact crop size.
    """
    def __init__(self, size):
        self.size = (size, size)

    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        if (w, h) == self.size:
            return img, mask
        return (img.resize(self.size, Image.BICUBIC), mask.resize(self.size, Image.NEAREST))


class ResizeImage(object):
    """
    Resize only the image to the exact crop size; the mask is returned unchanged.
    """
    def __init__(self, size):
        self.size = (size, size)

    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        if (w, h) == self.size:
            return img, mask
        return (img.resize(self.size, Image.BICUBIC), mask)
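
# Illustrative sketch (not part of the original API): stitching per-tile
# predictions from SlidingCrop back into a full-resolution label map using the
# [sy, ey, sx, ex, sub_h, sub_w] entries it returns. `predict_tile` is a
# hypothetical callable (e.g. a wrapped network) mapping a PIL crop to an HxW
# numpy array of class ids; stride_rate and num_classes are example values.
def _example_sliding_window_inference(img, mask, predict_tile, crop_size=512, num_classes=19):
    slicer = SlidingCrop(crop_size, stride_rate=2. / 3., ignore_label=255)
    img_slices, mask_slices, slices_info = slicer(img, mask)

    w, h = img.size
    votes = np.zeros((num_classes, h, w), dtype=np.float32)
    for crop, (sy, ey, sx, ex, sub_h, sub_w) in zip(img_slices, slices_info):
        pred = predict_tile(crop)           # HxW class ids for the (padded) tile
        valid = pred[:sub_h, :sub_w]        # drop the padded border
        for c in range(num_classes):
            votes[c, sy:sy + sub_h, sx:sx + sub_w] += (valid == c)
    return votes.argmax(axis=0)             # majority vote in overlapping regions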
class RandomSizeAndCrop(object):
    def __init__(self, size, crop_nopad, scale_min=1, scale_max=1.2, ignore_index=0, pre_size=None):
        self.size = size
        self.crop = RandomCrop(self.size, ignore_index=ignore_index, nopad=crop_nopad)
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.pre_size = pre_size

    def __call__(self, imgs, mask, centroid=None):
        img, imgB = imgs[0], imgs[1]
        assert img.size == mask.size

        # First, resize such that the shorter edge is pre_size.
        if self.pre_size is None:
            scale_amt = 1.
        elif img.size[1] < img.size[0]:
            scale_amt = self.pre_size / img.size[1]
        else:
            scale_amt = self.pre_size / img.size[0]
        scale_amt *= random.uniform(self.scale_min, self.scale_max)
        w, h = [int(i * scale_amt) for i in img.size]

        if centroid is not None:
            centroid = [int(c * scale_amt) for c in centroid]

        img, imgB, mask = img.resize((w, h), Image.BICUBIC), \
            imgB.resize((w, h), Image.BICUBIC), \
            mask.resize((w, h), Image.NEAREST)

        return self.crop([img, imgB], mask, centroid)


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly
            from [-hue, hue]. Should be >= 0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = torch_tr.Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
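
# Illustrative sketch (not part of the original API): a typical training
# pipeline for an aligned image pair plus label mask, combining the joint
# geometric transforms with the per-image ColorJitter defined above. All
# numeric parameters are example values, and the ImageNet mean/std are an
# assumption for the example.
def _example_training_pipeline(img, imgB, mask, crop_size=512):
    joint = Compose([
        RandomSizeAndCrop(crop_size, crop_nopad=False, scale_min=0.5, scale_max=2.0, ignore_index=255),
        RandomHorizontallyFlip(p=0.5),
        RandomGaussianBlur(),
    ])
    img, imgB, mask = joint([img, imgB], mask)

    # Photometric jitter is applied to each image independently; it does not
    # affect the mask and need not be identical across the pair.
    jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
    img, imgB = jitter(img), jitter(imgB)

    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    to_tensor = torch_tr.Compose([torch_tr.ToTensor(), torch_tr.Normalize(mean, std)])
    return to_tensor(img), to_tensor(imgB), MaskToTensor()(mask)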