Spaces:
Running
on
Zero
Running
on
Zero
import cv2 | |
import os | |
import glob | |
import torch | |
from torch.utils.data import Dataset | |
from torchvision import transforms | |
import random | |
import numpy as np | |
import math | |
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels | |
from basicsr.data.transforms import augment | |
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor | |
from PIL import Image | |
class SimpleDataset(Dataset):
    """Dataset yielding a resized ground-truth image plus randomly sampled blur kernels.

    Each item contains the ground-truth image and three kernels: a first-degradation
    blur kernel, a second-degradation blur kernel, and a final-stage sinc (or identity
    pulse) kernel. All are zero-padded / built at ``_MAX_KERNEL_SIZE`` x
    ``_MAX_KERNEL_SIZE``. This class only samples the kernels; it does not apply them.

    Args:
        opt (dict): Options. Keys read here: ``gt_path`` (a directory or a list of
            directories), ``blur_kernel_size``, ``kernel_list``, ``kernel_prob``,
            ``blur_sigma``, ``betag_range``, ``betap_range``, ``sinc_prob``, the
            matching ``*2`` keys for the second degradation, and ``final_sinc_prob``.
        fix_size (int): Side length, in pixels, of the square every image is resized to.
    """

    # All kernels share this spatial size (presumably so they can be stacked into a batch).
    _MAX_KERNEL_SIZE = 21

    def __init__(self, opt, fix_size=512):
        self.opt = opt
        self.image_root = opt['gt_path']
        self.fix_size = fix_size
        self.image_list = self._collect_image_list(self.image_root)

        # Not used by __getitem__ (which resizes with PIL directly); kept in case
        # external code references this attribute.
        self.crop_preproc = transforms.Compose([
            transforms.Resize(fix_size),
        ])
        # PIL image -> tensor.
        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

        # blur settings for the first degradation
        # NOTE: blur_kernel_size / blur_kernel_size2 are stored but not used by this
        # class; kernel sizes are drawn from self.kernel_range instead.
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters
        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']
        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        # Odd kernel sizes 7, 9, ..., 21.
        # TODO: kernel range is hard-coded; it should come from the config file.
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]
        # Centred unit impulse: convolving with it has no blurring effect. Used as
        # the final kernel when the final sinc filter is skipped.
        size = self._MAX_KERNEL_SIZE
        self.pulse_tensor = torch.zeros(size, size).float()
        self.pulse_tensor[size // 2, size // 2] = 1
        print(f'The dataset length: {len(self.image_list)}')

    @staticmethod
    def _collect_image_list(image_roots, exts=('*.jpg', '*.png')):
        """Return the image file paths found under ``image_roots``.

        For each root and each extension pattern, this matches files directly in the
        root, then files one level down in sub-directories whose names start with
        ``00`` (the LSDIR-style layout). Each glob result is sorted, so the
        index -> file mapping does not depend on filesystem listing order.

        Args:
            image_roots (str | os.PathLike | list): One directory, or a list of directories.
            exts (tuple[str]): Glob patterns for the file names to include.

        Returns:
            list[str]: Matching file paths.
        """
        # A bare str would otherwise be iterated character by character, and a
        # Path is not iterable at all.
        if isinstance(image_roots, (str, os.PathLike)):
            image_roots = [image_roots]
        image_list = []
        for image_root in image_roots:
            for ext in exts:
                image_list += sorted(glob.glob(os.path.join(image_root, ext)))
                # if add lsdir dataset
                image_list += sorted(glob.glob(os.path.join(image_root, '00*', ext)))
        return image_list

    def _random_kernel(self, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range, sinc_prob):
        """Sample one blur kernel and zero-pad it to ``_MAX_KERNEL_SIZE`` x ``_MAX_KERNEL_SIZE``.

        The kernel size is drawn from ``self.kernel_range``. With probability
        ``sinc_prob`` the kernel is a circular low-pass (sinc) kernel. Otherwise it
        comes from ``random_mixed_kernels`` using ``kernel_list`` / ``kernel_prob``.

        Returns:
            np.ndarray: The padded kernel.
        """
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)
        # Both sizes are odd, so equal padding on each side keeps the kernel centred.
        pad_size = (self._MAX_KERNEL_SIZE - kernel_size) // 2
        return np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    def _final_sinc_kernel(self):
        """Return the final-stage kernel as a float tensor.

        With probability ``opt['final_sinc_prob']`` this is a circular low-pass (sinc)
        kernel requested with ``pad_to=_MAX_KERNEL_SIZE``. Otherwise it is
        ``self.pulse_tensor``: the same tensor object on every call, so callers must
        not modify it in place.
        """
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=self._MAX_KERNEL_SIZE)
            return torch.FloatTensor(sinc_kernel)
        return self.pulse_tensor

    def __getitem__(self, index):
        """Load image ``index`` and sample its degradation kernels.

        Returns:
            dict:
                - ``gt``: the image converted to RGB, resized to ``fix_size`` x
                  ``fix_size`` (Lanczos) and passed through ``ToTensor``.
                - ``kernel1`` / ``kernel2``: first- and second-degradation blur kernels
                  as float tensors.
                - ``sinc_kernel``: the final-stage kernel.
                - ``lq_path``: the source image path.
        """
        image_path = self.image_list[index]
        # The context manager closes the file; convert() returns a new image that
        # stays valid afterwards.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # Resizes straight to a square, so the aspect ratio is not preserved.
        image = image.resize((self.fix_size, self.fix_size), Image.LANCZOS)
        image = self.img_preproc(image)

        # Sinc probabilities are read from self.opt on every call, so later edits to
        # opt take effect.
        kernel = self._random_kernel(
            self.kernel_list, self.kernel_prob, self.blur_sigma,
            self.betag_range, self.betap_range, self.opt['sinc_prob'])
        kernel2 = self._random_kernel(
            self.kernel_list2, self.kernel_prob2, self.blur_sigma2,
            self.betag_range2, self.betap_range2, self.opt['sinc_prob2'])
        sinc_kernel = self._final_sinc_kernel()

        return {
            'gt': image,
            'kernel1': torch.FloatTensor(kernel),
            'kernel2': torch.FloatTensor(kernel2),
            'sinc_kernel': sinc_kernel,
            'lq_path': image_path,
        }

    def __len__(self):
        """Return the number of images found at construction time."""
        return len(self.image_list)