import glob
import os
import random

import torch.nn.functional as F
from PIL import Image
from torch.utils import data as data
from torchvision import transforms


class PairedCaptionDataset(data.Dataset):
    """Paired LR/GT dataset with per-image tag captions for text-conditioned SR training."""

    def __init__(
        self,
        root_folders=None,
        tokenizer=None,
        null_text_ratio=0.5,
    ):
        super(PairedCaptionDataset, self).__init__()

        self.null_text_ratio = null_text_ratio
        self.lr_list = []
        self.gt_list = []
        self.tag_path_list = []

        # Each root folder is expected to contain `sr_bicubic/`, `gt/`, and
        # `tag/` subfolders holding the LR images, GT images, and tag captions.
        root_folders = root_folders.split(',')
        for root_folder in root_folders:
            lr_path = os.path.join(root_folder, 'sr_bicubic')
            tag_path = os.path.join(root_folder, 'tag')
            gt_path = os.path.join(root_folder, 'gt')

            # Sort so the LR, GT, and tag lists stay aligned by filename;
            # glob alone does not guarantee a stable ordering.
            self.lr_list += sorted(glob.glob(os.path.join(lr_path, '*.png')))
            self.gt_list += sorted(glob.glob(os.path.join(gt_path, '*.png')))
            self.tag_path_list += sorted(glob.glob(os.path.join(tag_path, '*.txt')))

        assert len(self.lr_list) == len(self.gt_list)
        assert len(self.lr_list) == len(self.tag_path_list)

        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

        # ImageNet statistics, used to normalize inputs for the RAM tagging model.
        ram_mean = [0.485, 0.456, 0.406]
        ram_std = [0.229, 0.224, 0.225]
        self.ram_normalize = transforms.Normalize(mean=ram_mean, std=ram_std)

        self.tokenizer = tokenizer

    def tokenize_caption(self, caption=""):
        inputs = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return inputs.input_ids

    def __getitem__(self, index):
        gt_path = self.gt_list[index]
        gt_img = Image.open(gt_path).convert('RGB')
        gt_img = self.img_preproc(gt_img)

        lq_path = self.lr_list[index]
        lq_img = Image.open(lq_path).convert('RGB')
        lq_img = self.img_preproc(lq_img)

        # Drop the caption with probability `null_text_ratio` so the model also
        # sees the unconditional case (useful for classifier-free guidance).
        if random.random() < self.null_text_ratio:
            tag = ''
        else:
            tag_path = self.tag_path_list[index]
            with open(tag_path, 'r') as f:
                tag = f.read()

        example = dict()
        example["conditioning_pixel_values"] = lq_img.squeeze(0)
        # Rescale the GT image from [0, 1] to [-1, 1].
        example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0
        example["input_ids"] = self.tokenize_caption(caption=tag).squeeze(0)

        # Resize the LR image to RAM's 384x384 input resolution; bicubic
        # interpolation can overshoot [0, 1], hence the clamp before normalizing.
        lq_img = lq_img.squeeze()
        ram_values = F.interpolate(lq_img.unsqueeze(0), size=(384, 384), mode='bicubic')
        ram_values = ram_values.clamp(0.0, 1.0)
        example["ram_values"] = self.ram_normalize(ram_values.squeeze(0))

        return example

    def __len__(self):
        return len(self.gt_list)
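

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the dataset class). It assumes any
# Hugging Face tokenizer exposing `model_max_length` will do; the CLIP
# tokenizer checkpoint and the dataset root paths below are hypothetical
# stand-ins, not values fixed by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer  # assumption: any HF tokenizer works

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = PairedCaptionDataset(
        root_folders="/data/train_a,/data/train_b",  # hypothetical roots
        tokenizer=tokenizer,
        null_text_ratio=0.5,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    batch = next(iter(loader))
    # Expected shapes for a batch of size B:
    #   pixel_values:              (B, 3, H, W) in [-1, 1] (GT)
    #   conditioning_pixel_values: (B, 3, H, W) in [0, 1]  (LR)
    #   input_ids:                 (B, model_max_length)
    #   ram_values:                (B, 3, 384, 384), ImageNet-normalized
    print({k: tuple(v.shape) for k, v in batch.items()})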