SuperResolution / dataloaders /simple_dataset.py
alexnasa's picture
Upload 22 files
c382c9b verified
import cv2
import os
import glob
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random
import numpy as np
import math
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from PIL import Image
class SimpleDataset(Dataset):
def __init__(self, opt, fix_size=512):
self.opt = opt
self.image_root = opt['gt_path']
self.fix_size = fix_size
exts = ['*.jpg', '*.png']
self.image_list = []
for image_root in self.image_root:
for ext in exts:
image_list = glob.glob(os.path.join(image_root, ext))
self.image_list += image_list
# if add lsdir dataset
image_list = glob.glob(os.path.join(image_root, '00*', ext))
self.image_list += image_list
self.crop_preproc = transforms.Compose([
# transforms.CenterCrop(fix_size),
transforms.Resize(fix_size)
# transforms.RandomHorizontalFlip(),
])
self.img_preproc = transforms.Compose([
transforms.ToTensor(),
])
# blur settings for the first degradation
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
self.blur_sigma = opt['blur_sigma']
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
# blur settings for the second degradation
self.blur_kernel_size2 = opt['blur_kernel_size2']
self.kernel_list2 = opt['kernel_list2']
self.kernel_prob2 = opt['kernel_prob2']
self.blur_sigma2 = opt['blur_sigma2']
self.betag_range2 = opt['betag_range2']
self.betap_range2 = opt['betap_range2']
self.sinc_prob2 = opt['sinc_prob2']
# a final sinc filter
self.final_sinc_prob = opt['final_sinc_prob']
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
# TODO: kernel range is now hard-coded, should be in the configure file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1
print(f'The dataset length: {len(self.image_list)}')
def __getitem__(self, index):
image = Image.open(self.image_list[index]).convert('RGB')
# width, height = image.size
# if width > height:
# width_after = self.fix_size
# height_after = int(height*width_after/width)
# elif height > width:
# height_after = self.fix_size
# width_after = int(width*height_after/height)
# elif height == width:
# height_after = self.fix_size
# width_after = self.fix_size
image = image.resize((self.fix_size, self.fix_size),Image.LANCZOS)
# image = self.crop_preproc(image)
image = self.img_preproc(image)
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.opt['sinc_prob']:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.opt['sinc_prob2']:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.opt['final_sinc_prob']:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
# BGR to RGB, HWC to CHW, numpy to tensor
# img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)
return_d = {'gt': image, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'lq_path': self.image_list[index]}
return return_d
def __len__(self):
return len(self.image_list)