Spaces:
Running
Running
File size: 9,936 Bytes
911c613 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
import os
import cv2
import json
import random
import datetime
import numpy as np
import matplotlib.pyplot as plt
class DataGen:
    """Index a segmentation dataset on disk and serve it as normalized batches.

    The dataset root is expected to contain ``train/images/``,
    ``train/labels/``, ``test/images/`` and ``test/labels/``. Sub-paths are
    built by string concatenation, so ``path`` must end with a separator.
    Images and labels are paired by position in their sorted ``.png``
    filename lists, so matching files must sort identically.
    """

    def __init__(self, path, split_ratio, x, y, color_space='rgb'):
        """Collect filenames and split the training set into train/validation.

        Args:
            path: dataset root directory, ending with a path separator.
            split_ratio: fraction of the shuffled training pairs held out for
                validation (the first ``int(split_ratio * n)`` pairs).
            x: required size of axis 0 of every image and label.
            y: required size of axis 1 of every image and label.
            color_space: stored on the instance; the colour conversion step in
                ``generate_data`` is currently disabled.
        """
        self.x = x
        self.y = y
        self.path = path
        self.color_space = color_space
        self.path_train_images = path + "train/images/"
        self.path_train_labels = path + "train/labels/"
        self.path_test_images = path + "test/images/"
        self.path_test_labels = path + "test/labels/"
        self.image_file_list = get_png_filename_list(self.path_train_images)
        self.label_file_list = get_png_filename_list(self.path_train_labels)
        if len(self.image_file_list) != len(self.label_file_list):
            # Pairing is positional, so a count mismatch usually means misaligned pairs.
            print('warning: {} training images but {} training labels; extra files are ignored'
                  .format(len(self.image_file_list), len(self.label_file_list)))
        self.image_file_list, self.label_file_list = self.shuffle_image_label_lists_together()
        self.split_index = int(split_ratio * len(self.image_file_list))
        # Validation takes the head of the shuffled list, training the remainder.
        self.x_train_file_list = self.image_file_list[self.split_index:]
        self.y_train_file_list = self.label_file_list[self.split_index:]
        self.x_val_file_list = self.image_file_list[:self.split_index]
        self.y_val_file_list = self.label_file_list[:self.split_index]
        self.x_test_file_list = get_png_filename_list(self.path_test_images)
        self.y_test_file_list = get_png_filename_list(self.path_test_labels)

    def generate_data(self, batch_size, train=False, val=False, test=False):
        """Replaces Keras' native ImageDataGenerator.

        Infinite generator that cycles through the selected split and yields
        ``(image_batch, label_batch)`` numpy arrays. Images are read in colour
        mode (flag 1), and labels in grayscale (flag 0) with a trailing channel
        axis added. Each batch is scaled with ``normalize``, so the scale is
        per-batch. A pair is kept only if both the image and the label have
        shape ``(x, y, ...)``. Mismatches are reported and skipped, so a batch
        may hold fewer than ``batch_size`` pairs.

        Args:
            batch_size: number of file pairs read per batch.
            train, val, test: select the split; checked in that order.

        Raises:
            ValueError: on the first ``next()`` if no split is selected, or if
                the selected split has no samples.
            IOError: if OpenCV cannot read one of the files.
        """
        if train:
            image_file_list = self.x_train_file_list
            label_file_list = self.y_train_file_list
        elif val:
            image_file_list = self.x_val_file_list
            label_file_list = self.y_val_file_list
        elif test:
            image_file_list = self.x_test_file_list
            label_file_list = self.y_test_file_list
        else:
            raise ValueError('one of train or val or test need to be True')

        # Validation files live in the training directories (see __init__).
        if train or val:
            image_dir, label_dir = self.path_train_images, self.path_train_labels
        else:
            image_dir, label_dir = self.path_test_images, self.path_test_labels

        num_samples = min(len(image_file_list), len(label_file_list))
        if num_samples == 0:
            # Without this the loop below would spin forever without yielding.
            raise ValueError('the selected split contains no samples')

        i = 0
        while True:
            image_batch = []
            label_batch = []
            for _ in range(batch_size):
                # Wrap around the *selected* split, not the training split.
                if i >= num_samples:
                    i = 0
                image_path = image_dir + image_file_list[i]
                label_path = label_dir + label_file_list[i]
                i += 1
                # NOTE: colour-space conversion (change_color_space) is currently disabled.
                image = cv2.imread(image_path, 1)
                label = cv2.imread(label_path, 0)
                # cv2.imread reports failure by returning None instead of raising.
                if image is None:
                    raise IOError('could not read image: {}'.format(image_path))
                if label is None:
                    raise IOError('could not read label: {}'.format(label_path))
                label = np.expand_dims(label, axis=2)
                image_ok = image.shape[0] == self.x and image.shape[1] == self.y
                label_ok = label.shape[0] == self.x and label.shape[1] == self.y
                if not image_ok:
                    print('the input image shape is not {}x{}'.format(self.x, self.y))
                if not label_ok:
                    print('the input label shape is not {}x{}'.format(self.x, self.y))
                # Append both or neither so images and labels stay aligned.
                if image_ok and label_ok:
                    image_batch.append(image.astype("float32"))
                    label_batch.append(label.astype("float32"))
            if image_batch and label_batch:
                image_batch = normalize(np.array(image_batch))
                label_batch = normalize(np.array(label_batch))
                yield (image_batch, label_batch)

    def get_num_data_points(self, train=False, val=False):
        """Return the number of samples in the training or validation split.

        ``val`` takes precedence when both flags are set.

        Raises:
            ValueError: if neither ``train`` nor ``val`` is set.
        """
        if train and not val:
            return len(self.x_train_file_list)
        if val:
            return len(self.x_val_file_list)
        raise ValueError('one of train or val need to be True')

    def shuffle_image_label_lists_together(self):
        """Shuffle the image and label filename lists with one shared permutation.

        Returns:
            ``(images, labels)`` as two lists. As with ``zip``, pairs beyond the
            shorter input list are dropped; empty inputs give ``([], [])``.
        """
        combined = list(zip(self.image_file_list, self.label_file_list))
        random.shuffle(combined)
        if not combined:
            # zip(*[]) yields nothing, which would break two-way unpacking.
            return [], []
        images, labels = zip(*combined)
        return list(images), list(labels)

    @staticmethod
    def change_color_space(image, label, color_space):
        """Convert ``image`` and ``label`` from BGR to HSV or LAB.

        ``'hsi'`` is treated the same as ``'hsv'``. Any other value returns the
        inputs unchanged. NOTE(review): ``generate_data`` reads labels in
        grayscale, and a BGR conversion of a single-channel label is expected to
        fail in OpenCV. Confirm before re-enabling this step.
        """
        color_space = color_space.lower()
        if color_space == 'hsi' or color_space == 'hsv':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            label = cv2.cvtColor(label, cv2.COLOR_BGR2HSV)
        elif color_space == 'lab':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
            label = cv2.cvtColor(label, cv2.COLOR_BGR2LAB)
        return image, label
def normalize(arr):
    """Scale ``arr`` by its value range (max - min).

    The minimum is *not* subtracted. The result is ``arr / (max - min)``, so it
    lies in [0, 1] only when the minimum is 0. When the range is 0 (constant
    input), or the input is empty, 255 is used as the divisor instead.

    Args:
        arr: array-like of numbers.

    Returns:
        A floating-point numpy array with the same shape as ``arr``.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # np.amax/np.amin raise on zero-size arrays; use the constant-input divisor.
        return arr / 255
    # Compute the range in Python floats so small signed int dtypes cannot overflow.
    diff = float(np.amax(arr)) - float(np.amin(arr))
    if diff == 0:
        diff = 255
    return arr / diff
def get_png_filename_list(path):
    """Return the sorted names of the ``.png`` files directly inside ``path``.

    Subdirectories are not searched, and the extension check is
    case-insensitive. A missing or unreadable directory yields an empty list,
    because ``os.walk`` ignores such errors by default.

    Args:
        path: directory to scan.

    Returns:
        Bare filenames (not full paths) in lexicographic order.
    """
    # Only the top level of the walk is needed. If the directory does not
    # exist, os.walk yields nothing, and the default means "no files".
    _, _, filenames = next(os.walk(path), (path, [], []))
    return sorted(name for name in filenames if name.lower().endswith(".png"))
def get_jpg_filename_list(path):
    """Return the sorted names of the ``.jpg`` files directly inside ``path``.

    Subdirectories are not searched, and the extension check is
    case-insensitive (``.jpeg`` is not matched). A missing or unreadable
    directory yields an empty list, because ``os.walk`` ignores such errors by
    default.

    Args:
        path: directory to scan.

    Returns:
        Bare filenames (not full paths) in lexicographic order.
    """
    # Only the top level of the walk is needed. If the directory does not
    # exist, os.walk yields nothing, and the default means "no files".
    _, _, filenames = next(os.walk(path), (path, [], []))
    return sorted(name for name in filenames if name.lower().endswith(".jpg"))
def load_jpg_images(path):
    """Read every ``.jpg`` file directly inside ``path`` as float32 pixels.

    Files are read with OpenCV flag 1 (colour), in the sorted order returned
    by ``get_jpg_filename_list``.

    Args:
        path: directory ending with a path separator (joined by concatenation).

    Returns:
        ``(images, file_list)``: ``np.array`` of the images and the list of
        filenames in the same order.

    Raises:
        IOError: if OpenCV cannot read one of the files.
    """
    file_list = get_jpg_filename_list(path)
    images = []
    for filename in file_list:
        img = cv2.imread(path + filename, 1)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise IOError('could not read image: {}'.format(path + filename))
        images.append(img.astype("float32"))
    return np.array(images), file_list
def load_png_images(path):
    """Read every ``.png`` file directly inside ``path`` as float32 pixels.

    Files are read with OpenCV flag 1 (colour), even when they are label masks,
    in the sorted order returned by ``get_png_filename_list``.

    Args:
        path: directory ending with a path separator (joined by concatenation).

    Returns:
        ``(images, file_list)``: ``np.array`` of the images and the list of
        filenames in the same order.

    Raises:
        IOError: if OpenCV cannot read one of the files.
    """
    file_list = get_png_filename_list(path)
    images = []
    for filename in file_list:
        img = cv2.imread(path + filename, 1)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise IOError('could not read image: {}'.format(path + filename))
        images.append(img.astype("float32"))
    return np.array(images), file_list
def load_data(path):
    """Load the whole train/test split from ``path`` into memory.

    Reads ``train/images``, ``train/labels``, ``test/images`` and
    ``test/labels`` (in that order) under ``path``. Each array is then scaled
    with ``normalize``.

    Args:
        path: dataset root, ending with a path separator.

    Returns:
        ``(x_train, y_train, x_test, y_test, test_label_filenames_list)``.
    """
    subdirectories = ("train/images/", "train/labels/", "test/images/", "test/labels/")
    loaded = [load_png_images(path + subdirectory) for subdirectory in subdirectories]
    x_train, y_train, x_test, y_test = (normalize(arrays) for arrays, _names in loaded)
    _, test_label_filenames_list = loaded[-1]
    return x_train, y_train, x_test, y_test, test_label_filenames_list
def load_test_images(path):
    """Load and scale only the test images under ``path``.

    Args:
        path: dataset root, ending with a path separator.

    Returns:
        ``(x_test, filenames)``: images scaled with ``normalize`` and their
        filenames, in matching order.
    """
    images, filenames = load_png_images(path + "test/images/")
    return normalize(images), filenames
def save_results(np_array, color_space, outpath, test_label_filenames_list):
    """Write predictions to disk, one file per test label filename.

    Entry ``k`` of ``np_array`` is multiplied by 255 and written to
    ``outpath + test_label_filenames_list[k]``. ``color_space`` is accepted but
    currently unused; no colour conversion is applied.
    """
    for index, filename in enumerate(test_label_filenames_list):
        prediction = np_array[index]
        cv2.imwrite(outpath + filename, prediction * 255.)
def save_rgb_results(np_array, outpath, test_label_filenames_list):
    """Write each entry of ``np_array`` (multiplied by 255) under the matching filename.

    Entry ``k`` is written to ``outpath + test_label_filenames_list[k]``.
    """
    for index, filename in enumerate(test_label_filenames_list):
        target = outpath + filename
        cv2.imwrite(target, np_array[index] * 255.)
def save_history(model, model_name, training_history, dataset, n_filters, epoch, learning_rate, loss,
                 color_space, path=None, temp_name=None):
    """Save a trained model, its history as JSON, and a plot of the curves.

    Writes ``<path><name>.hdf5`` (via ``model.save``), ``<path><name>.json``
    (``training_history.history``) and ``<path><name>.png`` (the plot). Here
    ``<name>`` is ``temp_name`` or, if that is falsy, the current local time.
    The plot is also shown with ``plt.show()`` and then cleared.

    Args:
        model: object exposing ``save(filename)``.
        model_name, n_filters, epoch, learning_rate, color_space: used only in
            the plot title.
        training_history: object whose ``history`` attribute is a
            JSON-serializable dict of metric name -> list of values.
        dataset, loss: accepted for interface compatibility; unused.
        path: output prefix (directory ending with a separator); ``None``
            means the current directory.
        temp_name: optional base file name.
    """
    # Without this, the default path=None would be formatted into file names as "None...".
    path = path or ''
    save_weight_filename = temp_name if temp_name else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    model.save('{}{}.hdf5'.format(path, save_weight_filename))
    history = training_history.history
    with open('{}{}.json'.format(path, save_weight_filename), 'w') as f:
        json.dump(history, f, indent=2)

    # A missing metric is plotted as an empty line so that the fixed legend
    # labels below stay aligned with the curves.
    for item in ['loss', 'val_loss', 'dice_coef', 'val_dice_coef']:
        plt.plot(list(history.get(item, [])))
    plt.title('model:{} lr:{} epoch:{} #filtr:{} Colorspaces:{}'.format(model_name, learning_rate,
                                                                        epoch, n_filters, color_space))
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train_loss', 'test_loss', 'train_dice', 'test_dice'], loc='upper left')
    plt.savefig('{}{}.png'.format(path, save_weight_filename))
    plt.show()
    plt.clf()
|