wushuang98's picture
Upload 197 files
bcb05d1 verified
import numpy as np
from PIL import Image
import torch
import random
from torchvision import transforms
import torchvision.transforms.functional as TF
def apply_joint_transforms(rgb, mask, img_size, img_aug=True, test=True):
    """Run an image and its mask through the same geometric pipeline.

    The pipeline pads both inputs to a centred square, optionally applies a
    random rotation and a random crop (only when ``img_aug`` is true and
    ``test`` is false), resizes to ``img_size`` x ``img_size``, adds an extra
    border and resizes back to ``img_size`` x ``img_size``. Newly introduced
    pixels are filled with 255 in the image and 0 in the mask.

    Args:
        rgb: Image to transform. Its ``.size`` is read as ``(width, height)``
            -- presumably a PIL image; confirm against callers.
        mask: Mask that receives exactly the same geometry as ``rgb``.
        img_size: Side length of the square output.
        img_aug: Enables the random rotation / crop; has no effect when
            ``test`` is true.
        test: When true, the extra border is fixed at 16 pixels and no
            augmentation is applied. Otherwise the border width is drawn
            with ``random.randint(0, 32)``.

    Returns:
        Tuple ``(rgb_tensor, mask_tensor)`` produced by ``TF.to_tensor``.
    """
    # Draw the border width first so the order of random calls is stable.
    border = 16 if test else random.randint(0, 32)

    width, height = rgb.size[:2]
    side = max(height, width)
    pad_left = (side - width) // 2
    pad_top = (side - height) // 2
    # (left, top, right, bottom); the right/bottom side absorbs any odd pixel.
    square_pad = (
        pad_left,
        pad_top,
        side - width - pad_left,
        side - height - pad_top,
    )

    # 1. pad to a centred square
    rgb = TF.pad(rgb, square_pad, fill=255)
    mask = TF.pad(mask, square_pad, fill=0)

    if not test and img_aug:
        # 2. random rotation, applied with probability 0.1
        if random.random() < 0.1:
            degrees = random.uniform(-10, 10)
            rgb = TF.rotate(rgb, degrees, fill=255)
            mask = TF.rotate(mask, degrees, fill=0)

        # 3. random square crop (90-100% of the side), probability 0.1
        if random.random() < 0.1:
            crop_side = int(side * random.uniform(0.9, 1.0))
            top, left, crop_h, crop_w = transforms.RandomCrop.get_params(
                rgb, (crop_side, crop_side)
            )
            rgb = TF.crop(rgb, top, left, crop_h, crop_w)
            mask = TF.crop(mask, top, left, crop_h, crop_w)

    out_hw = (img_size, img_size)
    smooth = TF.InterpolationMode.BILINEAR
    # Nearest-neighbour for the mask so no intermediate values are blended in.
    hard = TF.InterpolationMode.NEAREST

    # 4. resize to the target resolution
    rgb = TF.resize(rgb, out_hw, interpolation=smooth)
    mask = TF.resize(mask, out_hw, interpolation=hard)

    # 5. add the extra border, then shrink back to the target resolution
    rgb = TF.resize(TF.pad(rgb, border, fill=255), out_hw, interpolation=smooth)
    mask = TF.resize(TF.pad(mask, border, fill=0), out_hw, interpolation=hard)

    # to tensor
    return TF.to_tensor(rgb), TF.to_tensor(mask)
def crop_recenter(image_no_bg, thereshold=100):
    """Crop an image to the bounding box of its foreground pixels.

    A pixel counts as foreground when the last channel of the image (cast to
    ``uint8``) is strictly greater than ``thereshold``. If the foreground
    covers less than 0.1% of the image (or the image has no pixels at all),
    the image is returned uncropped.

    Args:
        image_no_bg: Anything ``np.array`` can convert. The first two axes
            are read as height and width, and the last channel is used as
            the mask -- presumably alpha of an RGBA image; confirm against
            callers.
        thereshold: Foreground cut-off applied to the last channel. The
            (misspelled) name is kept because callers pass it by keyword.

    Returns:
        NumPy array containing the cropped (or untouched) image.
    """
    image_np = np.array(image_no_bg)
    alpha = image_np[..., -1].astype(np.uint8)
    foreground = alpha > thereshold

    H, W = image_np.shape[:2]
    n_foreground = np.count_nonzero(foreground)

    # Too little foreground to give a meaningful box: keep the full frame.
    # The explicit zero check also covers zero-area images, where
    # ``0 < 0`` would otherwise fall through to min()/max() on empty arrays.
    if n_foreground == 0 or n_foreground < (H * W) * 0.001:
        return image_np

    # nonzero() yields one index array per axis; every index is already
    # inside the image, so no clamping is required.
    rows, cols = np.nonzero(foreground)
    return image_np[rows.min():rows.max() + 1, cols.min():cols.max() + 1]
def _pil_to_rgb_or_rgba_array(pil_img):
    """Return ``pil_img`` as a NumPy array with 3 (RGB) or 4 (RGBA) channels."""
    # Grayscale, palette, CMYK, ... would otherwise be misread downstream,
    # which indexes the last axis as R, G, B, A.
    if pil_img.mode not in ("RGB", "RGBA"):
        pil_img = pil_img.convert("RGBA")
    return np.array(pil_img)


def preprocess_image(img):
    """Turn an input image into the 518x518 image-plus-mask tensor.

    The image is cropped to its foreground, composited onto a white
    background using its alpha channel, then passed through
    ``apply_joint_transforms`` in test mode (no augmentation).

    Args:
        img: A file path (``str``), a PIL image, or an array. Arrays are
            used as-is and are assumed to be ``(H, W, 3)`` or ``(H, W, 4)``
            with 8-bit values, since the code divides by 255 -- confirm
            against callers. Images without alpha are treated as fully
            opaque.

    Returns:
        Tensor made by concatenating the transformed image and its mask
        along dim 0 (presumably 3 colour channels followed by 1 mask
        channel; verify against the consumer).
    """
    if isinstance(img, str):
        # Context manager so the file handle is released as soon as the
        # pixel data has been copied into the array.
        with Image.open(img) as opened:
            img = _pil_to_rgb_or_rgba_array(opened)
    elif isinstance(img, Image.Image):
        img = _pil_to_rgb_or_rgba_array(img)

    if img.shape[-1] == 3:
        # No alpha channel: add a fully opaque one. It must be 255, not 1,
        # because the array is divided by 255 below; an alpha of 1/255
        # would wash the whole image out to white and empty the mask.
        opaque = np.full_like(img[..., 0:1], 255)
        img = np.concatenate([img, opaque], axis=-1)

    # thereshold=0: any pixel with non-zero alpha counts as foreground.
    img = crop_recenter(img, thereshold=0) / 255.

    alpha = img[..., 3:]
    mask = img[..., 3]
    # Alpha-composite onto a white (1.0) background.
    img = img[..., :3] * alpha + (1 - alpha)
    img = Image.fromarray((img * 255).astype(np.uint8))
    mask = Image.fromarray((mask * 255).astype(np.uint8))

    img, mask = apply_joint_transforms(img, mask, img_size=518,
                                       img_aug=False, test=True)
    img = torch.cat([img, mask], dim=0)
    return img