|
|
""" |
|
|
Code borrowed from https://github.com/zijundeng/pytorch-semantic-segmentation |
|
|
MIT License |
|
|
|
|
|
Copyright (c) 2017 ZijunDeng |
|
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy |
|
|
of this software and associated documentation files (the "Software"), to deal |
|
|
in the Software without restriction, including without limitation the rights |
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|
|
copies of the Software, and to permit persons to whom the Software is |
|
|
furnished to do so, subject to the following conditions: |
|
|
|
|
|
The above copyright notice and this permission notice shall be included in all |
|
|
copies or substantial portions of the Software. |
|
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
|
SOFTWARE. |
|
|
""" |
|
|
|
|
|
import random |
|
|
import math |
|
|
import numbers |
|
|
|
|
|
import numpy as np |
|
|
import torchvision.transforms as torch_tr |
|
|
import torch |
|
|
from PIL import Image, ImageFilter, ImageOps |
|
|
|
|
|
|
|
|
from skimage.filters import gaussian |
|
|
|
|
|
class RandomGaussianBlur(object):
    """
    Blur both images of a pair with one randomly drawn Gaussian sigma.

    sigma is drawn uniformly from [0.15, 1.30). The mask is returned untouched.
    """

    @staticmethod
    def _blur(pil_img, sigma):
        # gaussian() rescales integer input to floats in [0, 1] (preserve_range
        # defaults to False), hence the multiplication back to the 0-255 range.
        smoothed = gaussian(np.array(pil_img), sigma=sigma, channel_axis=-1) * 255
        return Image.fromarray(smoothed.astype(np.uint8))

    def __call__(self, imgs, mask):
        first, second = imgs[0], imgs[1]
        sigma = 0.15 + random.random() * 1.15
        return self._blur(first, sigma), self._blur(second, sigma), mask
|
|
|
|
|
|
|
|
|
|
|
class RandomScale(object):
    """
    Randomly rescale an image and its mask by the same factor.

    Args:
        scale_list: candidate scale factors. With ``mode='value'`` one of them
            is picked at random; with ``mode='range'`` a factor is drawn
            uniformly from ``[scale_list[0], scale_list[-1]]``.
        mode: ``'value'`` or ``'range'``; any other value keeps the scale at 1.0.

    The image is resampled bicubically, the mask with nearest neighbour.
    """

    def __init__(self, scale_list=(0.75, 1.0, 1.25), mode='value'):
        # Tuple default: avoids a mutable default shared across instances.
        self.scale_list = scale_list
        self.mode = mode

    def __call__(self, img, mask):
        # PIL reports size as (width, height); the previous unpacking swapped
        # the two, so non-square inputs came out with transposed dimensions.
        ow, oh = img.size
        scale_amt = 1.0
        if self.mode == 'value':
            # Draw a scalar rather than a 1-element array.
            scale_amt = float(np.random.choice(self.scale_list))
        elif self.mode == 'range':
            scale_amt = random.uniform(self.scale_list[0], self.scale_list[-1])
        w = int(scale_amt * ow)
        h = int(scale_amt * oh)
        return img.resize((w, h), Image.BICUBIC), mask.resize((w, h), Image.NEAREST)
|
|
|
|
|
class SmartCropV1(object):
    """
    Random crop that retries until no single class dominates the crop.

    Up to 11 random crops are tried. The first crop whose mask holds at least
    two labels other than ``ignore_index``, none of which covers ``max_ratio``
    or more of those pixels, is returned. If no crop qualifies, the last crop
    tried is returned.

    ``img`` may be a single image, which returns ``(img, mask)``, or an
    ``[img, imgB]`` pair, which returns ``(img, imgB, mask)``. Cropping is
    delegated to ``RandomCrop``, whose call signature takes an image pair.
    """

    # Number of crops tried before giving up (the original loop broke at count > 10).
    _MAX_ATTEMPTS = 11

    def __init__(self, crop_size=512,
                 max_ratio=0.75,
                 ignore_index=12, nopad=False):
        self.crop_size = crop_size
        self.max_ratio = max_ratio
        self.ignore_index = ignore_index
        self.crop = RandomCrop(crop_size, ignore_index=ignore_index, nopad=nopad)

    def _is_balanced(self, mask_crop):
        """Return True if the crop has >1 valid label and none reaches max_ratio."""
        labels, counts = np.unique(np.array(mask_crop), return_counts=True)
        counts = counts[labels != self.ignore_index]
        return len(counts) > 1 and np.max(counts) / np.sum(counts) < self.max_ratio

    def __call__(self, img, mask):
        paired = isinstance(img, (list, tuple))
        # RandomCrop always works on a pair; a single image is fed twice and
        # the duplicate is dropped on return.
        img_a, img_b = (img[0], img[1]) if paired else (img, img)
        assert img_a.size == mask.size

        for _ in range(self._MAX_ATTEMPTS):
            crop_a, crop_b, mask_crop = self.crop([img_a, img_b], mask)
            if self._is_balanced(mask_crop):
                break

        if paired:
            return crop_a, crop_b, mask_crop
        return crop_a, mask_crop
|
|
|
|
|
|
|
|
|
|
|
class DeNormalize(object):
    """
    Undo per-channel normalisation in place, computing ``x * std + mean``.

    ``mean`` and ``std`` hold one entry per slice along the tensor's first
    dimension. Iteration stops at the shortest of the three sequences. The
    same (mutated) tensor is returned.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        per_channel_stats = zip(self.mean, self.std)
        for channel, (mu, sigma) in zip(tensor, per_channel_stats):
            channel.mul_(sigma).add_(mu)
        return tensor
|
|
|
|
|
|
|
|
class MaskToTensor(object):
    """Convert a label mask (PIL image or array-like) into an int64 torch tensor."""

    def __call__(self, img):
        labels = np.array(img, dtype=np.int32)
        as_tensor = torch.from_numpy(labels)
        return as_tensor.long()
|
|
|
|
|
|
|
|
class FreeScale(object):
    """
    Resize a single image to a fixed ``(height, width)`` size.

    NOTE(review): a second ``FreeScale`` class (image + mask) is defined later
    in this module with the same name. Once the module finishes loading, that
    later definition is the one bound to ``FreeScale``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # PIL's resize() expects (width, height), so flip the given order.
        self.size = tuple(reversed(size))
        self.interpolation = interpolation

    def __call__(self, img):
        target_size, method = self.size, self.interpolation
        return img.resize(target_size, method)
|
|
|
|
|
|
|
|
class FlipChannels(object):
    """Reverse the channel order of an H x W x C image (e.g. RGB <-> BGR)."""

    def __call__(self, img):
        pixels = np.array(img)
        swapped = pixels[:, :, ::-1].astype(np.uint8)
        return Image.fromarray(swapped)
|
|
|
|
|
|
|
|
class Compose(object):
    """
    Chain joint transforms that operate on an image pair plus a mask.

    Each transform is called as ``t([img, imgB], mask)`` and must return the
    triple ``(img, imgB, mask)``, which is fed to the next transform.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgs, mask):
        current_a, current_b = imgs[0], imgs[1]
        assert current_a.size == mask.size
        for transform in self.transforms:
            current_a, current_b, mask = transform([current_a, current_b], mask)
        return current_a, current_b, mask
|
|
|
|
|
class RandomCrop(object):
    """
    Take the same random crop from an image pair and its mask.

    ``size`` is either a number (square crop) or a ``(height, width)`` pair.

    If the input is smaller than the requested crop:
      * ``nopad=True``: the crop is shrunk to a square whose side equals the
        image's shorter side.
      * ``nopad=False``: both images are padded on all sides with
        ``pad_color``, and the mask with ``ignore_index``, until the crop fits.

    If ``centroid=(x, y)`` is given, the crop window is placed so that it
    covers the centroid, then clamped to the image bounds.
    """

    def __init__(self, size, ignore_index=0, nopad=True):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
        self.ignore_index = ignore_index
        self.nopad = nopad
        self.pad_color = (0, 0, 0)

    def _pad(self, img, imgB, mask, crop_h, crop_w):
        """Pad all three inputs symmetrically so a crop_h x crop_w window fits."""
        w, h = img.size
        pad_h = (crop_h - h) // 2 + 1 if crop_h > h else 0
        pad_w = (crop_w - w) // 2 + 1 if crop_w > w else 0
        if not (pad_h or pad_w):
            return img, imgB, mask
        border = (pad_w, pad_h, pad_w, pad_h)
        return (ImageOps.expand(img, border=border, fill=self.pad_color),
                ImageOps.expand(imgB, border=border, fill=self.pad_color),
                ImageOps.expand(mask, border=border, fill=self.ignore_index))

    def __call__(self, imgs, mask, centroid=None):
        img, imgB = imgs[0], imgs[1]
        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size

        if w == tw and h == th:
            return img, imgB, mask

        if self.nopad:
            if th > h or tw > w:
                # Shrink the crop to a square that fits the shorter side.
                th = tw = min(w, h)
        else:
            img, imgB, mask = self._pad(img, imgB, mask, th, tw)
            w, h = img.size

        if centroid is not None:
            c_x, c_y = centroid
            # Start somewhere the window still covers the centroid, then clamp.
            x1 = min(w - tw, max(0, random.randint(c_x - tw, c_x)))
            y1 = min(h - th, max(0, random.randint(c_y - th, c_y)))
        else:
            x1 = 0 if w == tw else random.randint(0, w - tw)
            y1 = 0 if h == th else random.randint(0, h - th)

        box = (x1, y1, x1 + tw, y1 + th)
        return img.crop(box), imgB.crop(box), mask.crop(box)
|
|
|
|
|
class CenterCrop(object):
    """
    Crop an image and its mask to ``size`` around their common centre.

    ``size`` is a number (square crop) or a ``(height, width)`` pair.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        self.size = size

    def __call__(self, img, mask):
        assert img.size == mask.size
        width, height = img.size
        crop_h, crop_w = self.size
        left = int(round((width - crop_w) / 2.))
        top = int(round((height - crop_h) / 2.))
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box), mask.crop(box)
|
|
|
|
|
|
|
|
class RandomHorizontallyFlip(object):
    """Mirror an image pair and its mask left-to-right with probability ``p``."""

    def __init__(self, p):
        self.p = p

    def __call__(self, imgs, mask):
        img, imgB = imgs[0], imgs[1]
        should_flip = random.random() < self.p
        if not should_flip:
            return img, imgB, mask
        flip = Image.FLIP_LEFT_RIGHT
        return img.transpose(flip), imgB.transpose(flip), mask.transpose(flip)
|
|
|
|
|
|
|
|
class RandomVerticalFlip(object):
    """Flip an image pair and its mask top-to-bottom with probability ``p``."""

    def __init__(self, p):
        self.p = p

    def __call__(self, imgs, mask):
        img, imgB = imgs[0], imgs[1]
        should_flip = random.random() < self.p
        if not should_flip:
            return img, imgB, mask
        flip = Image.FLIP_TOP_BOTTOM
        return img.transpose(flip), imgB.transpose(flip), mask.transpose(flip)
|
|
|
|
|
|
|
|
class FreeScale(object):
    """
    Resize an image (bilinear) and its mask (nearest) to a fixed size.

    ``size`` is given as ``(height, width)``.
    """

    def __init__(self, size):
        # PIL's resize() expects (width, height), so flip the given order.
        self.size = tuple(reversed(size))

    def __call__(self, img, mask):
        assert img.size == mask.size
        target = self.size
        resized_img = img.resize(target, Image.BILINEAR)
        resized_mask = mask.resize(target, Image.NEAREST)
        return resized_img, resized_mask
|
|
|
|
|
|
|
|
class Scale(object):
    """
    Resize so the longer side equals ``size``, preserving the aspect ratio.

    The image is resampled bilinearly and the mask with nearest neighbour.
    Inputs whose longer side already equals ``size`` are returned unchanged.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        if max(w, h) == self.size:
            return img, mask
        if w > h:
            new_size = (self.size, int(self.size * h / w))
        else:
            new_size = (int(self.size * w / h), self.size)
        return img.resize(new_size, Image.BILINEAR), mask.resize(new_size, Image.NEAREST)
|
|
|
|
|
|
|
|
class RandomSizedCrop(object):
    """
    Random area/aspect-ratio crop followed by a resize to ``size x size``.

    Up to 10 times, a crop covering 45-100% of the image area with aspect
    ratio in [0.5, 2] is drawn (its sides are swapped with probability 0.5).
    The first crop that fits inside the image is taken and resized: bilinear
    for the image, nearest for the mask. If none fits, the result falls back
    to ``Scale(size)`` followed by ``CenterCrop(size)``.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img, mask):
        assert img.size == mask.size
        for _ in range(10):
            img_w, img_h = img.size
            area = img_w * img_h
            target_area = random.uniform(0.45, 1.0) * area
            aspect_ratio = random.uniform(0.5, 2)

            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
            if random.random() < 0.5:
                crop_w, crop_h = crop_h, crop_w

            if crop_w > img_w or crop_h > img_h:
                continue

            x1 = random.randint(0, img_w - crop_w)
            y1 = random.randint(0, img_h - crop_h)
            box = (x1, y1, x1 + crop_w, y1 + crop_h)
            img, mask = img.crop(box), mask.crop(box)
            assert img.size == (crop_w, crop_h)

            out_size = (self.size, self.size)
            return img.resize(out_size, Image.BILINEAR), mask.resize(out_size, Image.NEAREST)

        # No random crop fitted: scale the longer side, then centre-crop.
        return CenterCrop(self.size)(*Scale(self.size)(img, mask))
|
|
|
|
|
|
|
|
class RandomRotate(object):
    """Rotate image and mask by one angle drawn uniformly from [-degree, degree)."""

    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        angle = random.random() * 2 * self.degree - self.degree
        rotated_img = img.rotate(angle, Image.BILINEAR)
        rotated_mask = mask.rotate(angle, Image.NEAREST)
        return rotated_img, rotated_mask
|
|
|
|
|
|
|
|
class RandomSized(object):
    """
    Randomly stretch an image and mask, rescale, then take a random crop.

    Width and height are each multiplied by an independent factor drawn from
    [0.5, 2]. The result is then rescaled so its longer side equals ``size``
    (``Scale``) and randomly cropped (``RandomCrop``, defaults ``nopad=True``
    and ``ignore_index=0``). Returns ``(img, mask)``.
    """

    def __init__(self, size):
        self.size = size
        self.scale = Scale(self.size)
        self.crop = RandomCrop(self.size)

    def __call__(self, img, mask):
        assert img.size == mask.size

        w = int(random.uniform(0.5, 2) * img.size[0])
        h = int(random.uniform(0.5, 2) * img.size[1])

        img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
        img, mask = self.scale(img, mask)

        # RandomCrop takes an [img, imgB] pair and returns three values; the
        # old call passed a bare image and crashed. Feed the image twice and
        # drop the duplicate.
        img, _, mask = self.crop([img, img], mask)
        return img, mask
|
|
|
|
|
|
|
|
class SlidingCropOld(object):
    """
    Split an image/mask into overlapping ``crop_size`` windows (legacy API).

    Window starts are spaced ``ceil(crop_size * stride_rate)`` pixels apart.
    Windows that run past the border are padded at the bottom/right: zeros for
    the image, ``ignore_label`` for the mask.

    If the longer side exceeds ``crop_size``, returns ``(img_list, mask_list)``
    of RGB / 'P' PIL images. Otherwise returns a single padded ``(img, mask)``.
    """

    def __init__(self, crop_size, stride_rate, ignore_label):
        self.crop_size = crop_size
        self.stride_rate = stride_rate
        self.ignore_label = ignore_label

    @staticmethod
    def _num_steps(length, crop_size, stride):
        """Number of window start positions needed to cover ``length`` pixels.

        A side shorter than the crop still needs one (padded) window. Without
        the clamp, the formula gave 0 or fewer steps for such a side whenever
        ``crop_size - length >= stride``, and no windows were produced at all.
        """
        return int(math.ceil(max(length - crop_size, 0) / float(stride))) + 1

    def _pad(self, img, mask):
        """Pad an HxWxC image and HxW mask at the bottom/right up to crop_size."""
        h, w = img.shape[: 2]
        pad_h = max(self.crop_size - h, 0)
        pad_w = max(self.crop_size - w, 0)
        img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
        mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label)
        return img, mask

    def __call__(self, img, mask):
        assert img.size == mask.size

        w, h = img.size
        long_size = max(h, w)

        img = np.array(img)
        mask = np.array(mask)

        if long_size > self.crop_size:
            stride = int(math.ceil(self.crop_size * self.stride_rate))
            h_step_num = self._num_steps(h, self.crop_size, stride)
            w_step_num = self._num_steps(w, self.crop_size, stride)
            img_sublist, mask_sublist = [], []
            for yy in range(h_step_num):
                for xx in range(w_step_num):
                    sy, sx = yy * stride, xx * stride
                    ey, ex = sy + self.crop_size, sx + self.crop_size
                    img_sub = img[sy: ey, sx: ex, :]
                    mask_sub = mask[sy: ey, sx: ex]
                    img_sub, mask_sub = self._pad(img_sub, mask_sub)
                    img_sublist.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB'))
                    mask_sublist.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P'))
            return img_sublist, mask_sublist
        else:
            img, mask = self._pad(img, mask)
            img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
            mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
            return img, mask
|
|
|
|
|
|
|
|
class SlidingCrop(object):
    """
    Split an image/mask into overlapping ``crop_size`` windows.

    Window starts are spaced ``ceil(crop_size * stride_rate)`` pixels apart.
    Windows that run past the border are padded at the bottom/right: zeros for
    the image, ``ignore_label`` for the mask.

    Returns ``(img_slices, mask_slices, slices_info)``. Each info entry is
    ``[sy, ey, sx, ex, sub_h, sub_w]``: the window bounds in the source and
    the unpadded size of the window's content.
    """

    def __init__(self, crop_size, stride_rate, ignore_label):
        self.crop_size = crop_size
        self.stride_rate = stride_rate
        self.ignore_label = ignore_label

    @staticmethod
    def _num_steps(length, crop_size, stride):
        """Number of window start positions needed to cover ``length`` pixels.

        A side shorter than the crop still needs one (padded) window. Without
        the clamp, the formula gave 0 or fewer steps for such a side whenever
        ``crop_size - length >= stride``, and no windows were produced at all.
        """
        return int(math.ceil(max(length - crop_size, 0) / float(stride))) + 1

    def _pad(self, img, mask):
        """Pad up to crop_size at the bottom/right; also return the pre-pad h, w."""
        h, w = img.shape[: 2]
        pad_h = max(self.crop_size - h, 0)
        pad_w = max(self.crop_size - w, 0)
        img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
        mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label)
        return img, mask, h, w

    def __call__(self, img, mask):
        assert img.size == mask.size

        w, h = img.size
        long_size = max(h, w)

        img = np.array(img)
        mask = np.array(mask)

        if long_size > self.crop_size:
            stride = int(math.ceil(self.crop_size * self.stride_rate))
            h_step_num = self._num_steps(h, self.crop_size, stride)
            w_step_num = self._num_steps(w, self.crop_size, stride)
            img_slices, mask_slices, slices_info = [], [], []
            for yy in range(h_step_num):
                for xx in range(w_step_num):
                    sy, sx = yy * stride, xx * stride
                    ey, ex = sy + self.crop_size, sx + self.crop_size
                    img_sub = img[sy: ey, sx: ex, :]
                    mask_sub = mask[sy: ey, sx: ex]
                    img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub)
                    img_slices.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB'))
                    mask_slices.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P'))
                    slices_info.append([sy, ey, sx, ex, sub_h, sub_w])
            return img_slices, mask_slices, slices_info
        else:
            img, mask, sub_h, sub_w = self._pad(img, mask)
            img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
            mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
            return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]]
|
|
|
|
|
class PadImage(object):
    """
    Fit an image and its mask into a ``size x size`` canvas.

    If either side exceeds ``size``, both are first downscaled with the aspect
    ratio preserved, so that both sides fit. The image is resampled bicubically
    and the mask with nearest neighbour. The result is then padded on the
    right/bottom: 0 for the image, ``ignore_index`` for the mask.
    """

    def __init__(self, size, ignore_index):
        self.size = size
        self.ignore_index = ignore_index

    def __call__(self, img, mask):
        assert img.size == mask.size
        th, tw = self.size, self.size
        w, h = img.size

        if w > tw or h > th:
            if w / float(tw) >= h / float(th):
                # Width is the binding side (same formula as before).
                target_w, target_h = tw, int(h * (tw / float(w)))
            else:
                # Height is the binding side. Scaling by width here used to
                # leave the height above th, and the surplus rows were lost
                # by the negative-border expand below.
                target_w, target_h = int(w * (th / float(h))), th
            # Guard against a zero-sized resize target for extreme aspect ratios.
            target_w, target_h = max(target_w, 1), max(target_h, 1)
            img = img.resize((target_w, target_h), Image.BICUBIC)
            mask = mask.resize((target_w, target_h), Image.NEAREST)
            w, h = img.size

        border = (0, 0, tw - w, th - h)
        img = ImageOps.expand(img, border=border, fill=0)
        mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)
        return img, mask
|
|
|
|
|
class Resize(object):
    """
    Resize an image and its mask to exactly ``size x size``.

    The image is resampled bicubically and the mask with nearest neighbour.
    Inputs already at the target size are returned unchanged.
    """

    def __init__(self, size):
        self.size = (size, size)

    def __call__(self, img, mask):
        assert img.size == mask.size
        # self.size is a (w, h) tuple. The previous check compared it to a
        # bare int, so the shortcut never fired.
        if img.size == self.size:
            return img, mask
        return (img.resize(self.size, Image.BICUBIC),
                mask.resize(self.size, Image.NEAREST))
|
|
|
|
|
class ResizeImage(object):
    """
    Resize only the image to exactly ``size x size`` (bicubic).

    The mask is returned as-is, at its original resolution. Images already at
    the target size are returned unchanged.
    """

    def __init__(self, size):
        self.size = (size, size)

    def __call__(self, img, mask):
        assert img.size == mask.size
        # self.size is a (w, h) tuple. The previous check compared it to a
        # bare int, so the shortcut never fired.
        if img.size == self.size:
            return img, mask
        return (img.resize(self.size, Image.BICUBIC),
                mask)
|
|
|
|
|
class RandomSizeAndCrop(object):
    """
    Randomly rescale an image pair and mask, then take a joint random crop.

    The scale factor is ``uniform(scale_min, scale_max)``. When ``pre_size`` is
    set, it is multiplied by the factor that maps the shorter side to
    ``pre_size``. Images are resized bicubically and the mask with nearest
    neighbour. An optional ``centroid`` (x, y) is rescaled alongside and
    forwarded to ``RandomCrop``.
    """

    def __init__(self, size, crop_nopad,
                 scale_min=1, scale_max=1.2, ignore_index=0, pre_size=None):
        self.size = size
        self.crop = RandomCrop(self.size, ignore_index=ignore_index, nopad=crop_nopad)
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.pre_size = pre_size

    def _base_scale(self, img):
        """Factor mapping the shorter side of ``img`` to ``pre_size`` (1.0 if unset)."""
        if self.pre_size is None:
            return 1.
        width, height = img.size
        return self.pre_size / height if height < width else self.pre_size / width

    def __call__(self, imgs, mask, centroid=None):
        img, imgB = imgs[0], imgs[1]
        assert img.size == mask.size

        scale_amt = self._base_scale(img) * random.uniform(self.scale_min, self.scale_max)
        w, h = (int(dim * scale_amt) for dim in img.size)

        if centroid is not None:
            centroid = [int(coord * scale_amt) for coord in centroid]

        new_size = (w, h)
        img = img.resize(new_size, Image.BICUBIC)
        imgB = imgB.resize(new_size, Image.BICUBIC)
        mask = mask.resize(new_size, Image.NEAREST)

        return self.crop([img, imgB], mask, centroid)
|
|
|
|
|
class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >= 0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on an image.

        Arguments are the same as those of ``__init__``. A factor is drawn for
        each non-zero argument, in the order brightness, contrast, saturation,
        hue. The resulting adjustments are then shuffled.

        Returns:
            A ``torchvision.transforms.Compose`` that applies the selected
            adjustments in random order (empty when every argument is 0).
        """
        # The adjust_* helpers live in torchvision.transforms.functional. They
        # were referenced here without being imported, so any non-zero jitter
        # raised NameError as soon as the transform was applied.
        from torchvision.transforms import functional as tv_f

        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(
                torch_tr.Lambda(lambda img: tv_f.adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(
                torch_tr.Lambda(lambda img: tv_f.adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(
                torch_tr.Lambda(lambda img: tv_f.adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(
                torch_tr.Lambda(lambda img: tv_f.adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = torch_tr.Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
|
|
|
|
|
|
|
|
|