# SuperResolution/dataloaders/paired_dataset.py
import glob
import os
import random

import torch.nn.functional as F
from PIL import Image
from torch.utils import data as data
from torchvision import transforms
class PairedCaptionDataset(data.Dataset):
    """Paired LR/HR image dataset with per-image text tags for caption-conditioned super-resolution."""

    def __init__(
        self,
        root_folders=None,
        tokenizer=None,
        null_text_ratio=0.5,
        # use_ram_encoder=False,
        # use_gt_caption=False,
        # caption_type='gt_caption',
    ):
        super().__init__()
        self.null_text_ratio = null_text_ratio
        self.lr_list = []
        self.gt_list = []
        self.tag_path_list = []

        for root_folder in root_folders.split(','):
            lr_path = os.path.join(root_folder, 'sr_bicubic')
            tag_path = os.path.join(root_folder, 'tag')
            gt_path = os.path.join(root_folder, 'gt')

            # Sort each glob so the LR, GT, and tag lists stay aligned by index;
            # glob.glob alone returns files in arbitrary order.
            self.lr_list += sorted(glob.glob(os.path.join(lr_path, '*.png')))
            self.gt_list += sorted(glob.glob(os.path.join(gt_path, '*.png')))
            self.tag_path_list += sorted(glob.glob(os.path.join(tag_path, '*.txt')))

        assert len(self.lr_list) == len(self.gt_list)
        assert len(self.lr_list) == len(self.tag_path_list)

        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

        # ImageNet statistics expected by the RAM image encoder.
        ram_mean = [0.485, 0.456, 0.406]
        ram_std = [0.229, 0.224, 0.225]
        self.ram_normalize = transforms.Normalize(mean=ram_mean, std=ram_std)

        self.tokenizer = tokenizer
    def tokenize_caption(self, caption=""):
        inputs = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return inputs.input_ids
    def __getitem__(self, index):
        # ToTensor yields (C, H, W) float tensors in [0, 1].
        gt_img = self.img_preproc(Image.open(self.gt_list[index]).convert('RGB'))
        lq_img = self.img_preproc(Image.open(self.lr_list[index]).convert('RGB'))

        # Drop the caption with probability null_text_ratio so the model also
        # learns the unconditional branch (classifier-free-guidance style).
        if random.random() < self.null_text_ratio:
            tag = ''
        else:
            with open(self.tag_path_list[index], 'r') as file:
                tag = file.read()

        example = dict()
        example["conditioning_pixel_values"] = lq_img   # LR image in [0, 1]
        example["pixel_values"] = gt_img * 2.0 - 1.0    # HR image rescaled to [-1, 1]
        example["input_ids"] = self.tokenize_caption(caption=tag).squeeze(0)

        # Resize the LR image to the RAM encoder's 384x384 input size and normalize.
        ram_values = F.interpolate(lq_img.unsqueeze(0), size=(384, 384), mode='bicubic')
        ram_values = ram_values.clamp(0.0, 1.0)
        example["ram_values"] = self.ram_normalize(ram_values.squeeze(0))

        return example
    def __len__(self):
        return len(self.gt_list)
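

if __name__ == "__main__":
    # Minimal usage sketch, assuming a CLIP-style tokenizer from Hugging Face
    # `transformers` and a hypothetical data root "/data/sr_train" laid out as
    # <root>/sr_bicubic, <root>/gt, <root>/tag, as read by __init__ above. It
    # also assumes all GT/LR images share one resolution so the default
    # collate function can stack them into a batch.
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = PairedCaptionDataset(root_folders="/data/sr_train", tokenizer=tokenizer)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    batch = next(iter(loader))
    print(batch["pixel_values"].shape)  # (4, 3, H, W), values in [-1, 1]
    print(batch["ram_values"].shape)    # (4, 3, 384, 384), ImageNet-normalized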