File size: 5,977 Bytes
8cf4db8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
"""
CD Dataset
"""
import os
from PIL import Image
import numpy as np
from torch.utils import data
import data.util as Util
from torch.utils.data import Dataset
import torchvision
import torch
# Module-level torchvision ToTensor converter. Neither CDDataset nor
# SCDDataset below references it.
totensor = torchvision.transforms.ToTensor()
"""
CD Dataset
├─image
├─image_post
├─label
└─list
"""
# Sub-folder names under the dataset root directory.
IMG_FOLDER_NAME: str = 'A'          # first image of each pair
IMG_POST_FOLDER_NAME: str = 'B'     # "post" image of each pair
LABEL_FOLDER_NAME: str = 'label'    # label images
LABEL1_FOLDER_NAME: str = 'label1'  # label1 maps (SCDDataset)
LABEL2_FOLDER_NAME: str = 'label2'  # label2 maps (SCDDataset)
LIST_FOLDER_NAME: str = 'list'      # split files such as train.txt
# Label file extension; not referenced by the code in this file.
label_suffix: str = ".png"
# The list folder stores the image names; this reads them from a split file.
def load_img_name_list(dataset_path):
    """Read the image names listed in a split file.

    Args:
        dataset_path: path to a whitespace-delimited text file whose first
            column holds the image file names (extra columns are ignored).

    Returns:
        1-D numpy array of str names, one entry per line of the file.
    """
    # ndmin=2 keeps the (lines, columns) shape even for a one-line or
    # one-column file. Without it, a single-line file collapses to a 0-d
    # array (len() then fails), and a single line with several columns would
    # be returned as several "names".
    img_name_list = np.loadtxt(dataset_path, dtype=np.str_, ndmin=2)
    return img_name_list[:, 0]
# Path builders for each dataset sub-folder.
def _dataset_file(root_dir, folder, img_name):
    """Return os.path.join(root_dir, folder, img_name)."""
    return os.path.join(root_dir, folder, img_name)


def get_img_path(root_dir, img_name):
    """Path of img_name inside the IMG_FOLDER_NAME folder."""
    return _dataset_file(root_dir, IMG_FOLDER_NAME, img_name)


def get_img_post_path(root_dir, img_name):
    """Path of img_name inside the IMG_POST_FOLDER_NAME folder."""
    return _dataset_file(root_dir, IMG_POST_FOLDER_NAME, img_name)


def get_label_path(root_dir, img_name):
    """Path of img_name inside the LABEL_FOLDER_NAME folder."""
    return _dataset_file(root_dir, LABEL_FOLDER_NAME, img_name)


def get_label1_path(root_dir, img_name):
    """Path of img_name inside the LABEL1_FOLDER_NAME folder."""
    return _dataset_file(root_dir, LABEL1_FOLDER_NAME, img_name)


def get_label2_path(root_dir, img_name):
    """Path of img_name inside the LABEL2_FOLDER_NAME folder."""
    return _dataset_file(root_dir, LABEL2_FOLDER_NAME, img_name)
class CDDataset(Dataset):
    """Dataset of (A, B, label) image triples listed in a split file.

    Image names are read from ``<root_dir>/list/<split>.txt``. Each item
    loads the A image, the B image and the label image of that name (all
    converted to RGB), passes A/B through ``Util.transform_augment_cd`` with
    ``min_max=(-1, 1)`` and the label with ``min_max=(0, 1)``.

    Args:
        root_dir: dataset root directory.
        resolution: stored on the instance; not used by this class.
        split: split file stem, e.g. 'train', 'val' or 'test'.
        data_len: maximum number of items; values <= 0 mean the whole list.
        label_transform: stored on the instance; not used by this class.
    """

    def __init__(self, root_dir, resolution=256, split='train', data_len=-1, label_transform=None):
        self.root_dir = root_dir
        self.resolution = resolution
        self.split = split
        self.label_transform = label_transform
        self.list_path = os.path.join(root_dir, LIST_FOLDER_NAME, split + '.txt')
        self.img_name_list = load_img_name_list(self.list_path)
        self.dataset_len = len(self.img_name_list)
        # Non-positive data_len selects every listed image; otherwise cap it.
        self.data_len = self.dataset_len if data_len <= 0 else min(self.dataset_len, data_len)

    def __len__(self):
        return self.data_len

    def __getitem__(self, index):
        img_name = self.img_name_list[index % self.data_len]
        # Open A, B and the label in that order, each converted to RGB.
        pre_img, post_img, label_img = (
            Image.open(path_of(self.root_dir, img_name)).convert('RGB')
            for path_of in (get_img_path, get_img_post_path, get_label_path)
        )
        pre = Util.transform_augment_cd(pre_img, min_max=(-1, 1))
        post = Util.transform_augment_cd(post_img, min_max=(-1, 1))
        label = Util.transform_augment_cd(label_img, min_max=(0, 1))
        # Keep the first slice along dim 0 of a multi-dim label; presumably
        # the transform returns (C, H, W), so this is the first channel.
        if label.dim() > 2:
            label = label[0]
        return {'A': pre, 'B': post, 'L': label, 'Index': index}
class SCDDataset(Dataset):
    """Dataset of image pairs with three label maps and class-presence vectors.

    Image names are read from ``<root_dir>/list/<split>.txt``. Each item is a
    dict with:

    * ``A`` / ``B``: the A and B images converted to RGB and passed through
      ``Util.transform_augment_cd`` with ``min_max=(-1, 1)``;
    * ``L`` / ``L1`` / ``L2``: the label, label1 and label2 images as 2-D
      uint8 tensors (first channel only, if an image has several);
    * ``cls1`` / ``cls2``: int64 vectors of length ``num_classes`` with a 1
      at every pixel value present in ``L1`` / ``L2``;
    * ``Index``: the requested index;
    * ``name``: the image file name up to its first '.'.

    Args:
        root_dir: dataset root directory.
        resolution: stored on the instance; not used by this class.
        split: split file stem, e.g. 'train', 'val' or 'test'.
        data_len: maximum number of items; values <= 0 mean the whole list.
        label_transform: stored on the instance; not used by this class.
        num_classes: length of the cls1/cls2 vectors (default 7). Every
            pixel value in label1/label2 must be below it, otherwise
            ``__getitem__`` raises IndexError.
    """

    def __init__(self, root_dir, resolution=512, split='train', data_len=-1,
                 label_transform=None, num_classes=7):
        self.root_dir = root_dir
        self.resolution = resolution
        self.data_len = data_len
        self.split = split  # train / val / test
        self.label_transform = label_transform
        self.num_classes = num_classes
        self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split + '.txt')
        self.img_name_list = load_img_name_list(self.list_path)
        self.dataset_len = len(self.img_name_list)
        # A non-positive data_len means "use every listed image".
        if self.data_len <= 0:
            self.data_len = self.dataset_len
        else:
            self.data_len = min(self.dataset_len, self.data_len)

    def __len__(self):
        return self.data_len

    @staticmethod
    def _image_stem(img_name):
        """Return the file name of img_name up to its first '.', OS-independently."""
        return os.path.basename(img_name).split('.')[0]

    @staticmethod
    def _read_label(path):
        """Load a label image as a 2-D uint8 tensor (first channel if multi-channel)."""
        with Image.open(path) as img:
            label = np.array(img, dtype=np.uint8)
        # np.array(PIL image) is (H, W) or (H, W, C): the channel axis is
        # last, so select it with [..., 0] ([0] would pick the first row).
        if label.ndim > 2:
            label = label[..., 0]
        return torch.from_numpy(label)

    @staticmethod
    def _class_presence(label, num_classes=7):
        """Return an int64 vector with 1 at every value occurring in ``label``.

        Raises IndexError if ``label`` contains a value >= num_classes.
        """
        presence = torch.zeros(num_classes, dtype=torch.int64)
        # .long() is required: a uint8 index tensor would act as a bool mask.
        presence[torch.unique(label).long()] = 1
        return presence

    def __getitem__(self, index):
        img_name = self.img_name_list[index % self.data_len]
        name = self._image_stem(img_name)

        with Image.open(get_img_path(self.root_dir, img_name)) as img:
            img_A = img.convert('RGB')
        with Image.open(get_img_post_path(self.root_dir, img_name)) as img:
            img_B = img.convert('RGB')
        img_label = self._read_label(get_label_path(self.root_dir, img_name))
        img_label1 = self._read_label(get_label1_path(self.root_dir, img_name))
        img_label2 = self._read_label(get_label2_path(self.root_dir, img_name))

        img_A = Util.transform_augment_cd(img_A, min_max=(-1, 1))
        img_B = Util.transform_augment_cd(img_B, min_max=(-1, 1))

        # Per-image class-presence vectors for label1 / label2.
        cls_label1 = self._class_presence(img_label1, self.num_classes)
        cls_label2 = self._class_presence(img_label2, self.num_classes)

        return {'A': img_A, 'B': img_B, 'L': img_label, 'L1': img_label1, 'L2': img_label2,
                'Index': index, 'name': name, 'cls1': cls_label1, 'cls2': cls_label2}
if __name__ == '__main__':
    # Manual check: print the cls1/cls2 class-presence vectors of every
    # sample. The dataset root may be given as the first CLI argument;
    # otherwise the author's local default path is used.
    import sys

    root_dir = sys.argv[1] if len(sys.argv) > 1 else r'E:\cddataset\mmcd\Second_my'
    cddata = SCDDataset(root_dir=root_dir)
    # Iterate the real dataset length; a fixed count would wrap around
    # (index % data_len) or stop early.
    for i in range(len(cddata)):
        sample = cddata[i]  # load each sample once, read both vectors from it
        print(sample['cls1'])
        print(sample['cls2'])
|