# Rakhi-2025's picture
# Upload 95 files
# 911c613 verified
import os
import cv2
import json
import random
import datetime
import numpy as np
import matplotlib.pyplot as plt
class DataGen:
    """Yield normalized image/label batches from PNG files on disk.

    The dataset root is expected to contain ``train/images/``,
    ``train/labels/``, ``test/images/`` and ``test/labels/``. Training images
    and labels are paired by their position in the sorted file lists, shuffled
    together, and split: the first ``split_ratio`` fraction of the shuffled
    pairs becomes the validation set and the remainder the training set.
    """

    def __init__(self, path, split_ratio, x, y, color_space='rgb'):
        """Index the dataset and build the train/validation/test file lists.

        Args:
            path: Dataset root directory; a trailing slash is optional.
            split_ratio: Fraction of the training pairs held out for
                validation.
            x: Required size of axis 0 of every image and label.
            y: Required size of axis 1 of every image and label.
            color_space: Stored on the instance; it is not applied while
                batches are generated.
        """
        self.x = x
        self.y = y
        self.path = path
        self.color_space = color_space
        # os.path.join keeps the result identical when ``path`` already ends
        # with a separator and also works when it does not.
        self.path_train_images = os.path.join(path, "train/images/")
        self.path_train_labels = os.path.join(path, "train/labels/")
        self.path_test_images = os.path.join(path, "test/images/")
        self.path_test_labels = os.path.join(path, "test/labels/")
        self.image_file_list = get_png_filename_list(self.path_train_images)
        self.label_file_list = get_png_filename_list(self.path_train_labels)
        self.image_file_list, self.label_file_list = self.shuffle_image_label_lists_together()
        self.split_index = int(split_ratio * len(self.image_file_list))
        self.x_train_file_list = self.image_file_list[self.split_index:]
        self.y_train_file_list = self.label_file_list[self.split_index:]
        self.x_val_file_list = self.image_file_list[:self.split_index]
        self.y_val_file_list = self.label_file_list[:self.split_index]
        self.x_test_file_list = get_png_filename_list(self.path_test_images)
        self.y_test_file_list = get_png_filename_list(self.path_test_labels)

    def generate_data(self, batch_size, train=False, val=False, test=False):
        """Yield ``(image_batch, label_batch)`` tuples forever.

        Replaces Keras' native ImageDataGenerator. The selected split is
        cycled endlessly; a batch may span the end of one pass and the start
        of the next. Pairs that cannot be read or whose shape does not match
        ``(self.x, self.y)`` are skipped as a whole, so a batch can hold fewer
        than ``batch_size`` samples but images and labels always stay aligned.

        Args:
            batch_size: Number of files consumed per batch (at least 1).
            train: Select the training split.
            val: Select the validation split.
            test: Select the test split.

        Yields:
            A tuple of two arrays, each passed through ``normalize``: the
            float32 images and the float32 labels with a trailing channel axis.

        Raises:
            ValueError: If no split is selected, ``batch_size`` is smaller
                than 1, the split has no files, or a full pass over the split
                produced no usable pair.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1')

        if train:
            image_file_list = self.x_train_file_list
            label_file_list = self.y_train_file_list
        elif val:
            image_file_list = self.x_val_file_list
            label_file_list = self.y_val_file_list
        elif test:
            image_file_list = self.x_test_file_list
            label_file_list = self.y_test_file_list
        else:
            raise ValueError('one of train or val or test need to be True')

        # Validation files live in the training directories; only the test
        # split has directories of its own.
        if train or val:
            images_dir = self.path_train_images
            labels_dir = self.path_train_labels
        else:
            images_dir = self.path_test_images
            labels_dir = self.path_test_labels

        pairs = list(zip(image_file_list, label_file_list))
        if not pairs:
            raise ValueError('no image/label pairs available for the requested split')

        i = 0
        skipped_in_a_row = 0
        while True:
            image_batch = []
            label_batch = []
            for _ in range(batch_size):
                image_name, label_name = pairs[i]
                # Wrap around on the length of the split actually in use.
                i = (i + 1) % len(pairs)
                sample = self._load_sample(images_dir, labels_dir, image_name, label_name)
                if sample is None:
                    skipped_in_a_row += 1
                    # Without this guard an unusable dataset would spin
                    # forever without ever yielding.
                    if skipped_in_a_row >= len(pairs):
                        raise ValueError('no usable image/label pair found in the requested split')
                    continue
                skipped_in_a_row = 0
                image_batch.append(sample[0])
                label_batch.append(sample[1])
            if image_batch:
                yield (normalize(np.array(image_batch)), normalize(np.array(label_batch)))

    def _load_sample(self, images_dir, labels_dir, image_name, label_name):
        """Read one image/label pair; return ``None`` when it is unusable."""
        image = cv2.imread(os.path.join(images_dir, image_name), 1)
        label = cv2.imread(os.path.join(labels_dir, label_name), 0)
        if image is None or label is None:
            print('could not read image {} or label {}'.format(image_name, label_name))
            return None
        # The label is read as a single channel; add the channel axis.
        label = np.expand_dims(label, axis=2)
        usable = True
        if image.shape[0] != self.x or image.shape[1] != self.y:
            print('the input image shape is not {}x{}'.format(self.x, self.y))
            usable = False
        if label.shape[0] != self.x or label.shape[1] != self.y:
            print('the input label shape is not {}x{}'.format(self.x, self.y))
            usable = False
        if not usable:
            return None
        return image.astype("float32"), label.astype("float32")

    def get_num_data_points(self, train=False, val=False):
        """Return the number of files in the training or validation split.

        The training count is returned only when ``train`` is True and ``val``
        is False; every other combination returns the validation count.
        """
        if train is True and val is False:
            return len(self.x_train_file_list)
        return len(self.x_val_file_list)

    def shuffle_image_label_lists_together(self):
        """Shuffle the image and label file lists with the same permutation.

        Returns:
            Two lists ``(images, labels)``. If the source lists differ in
            length the result is truncated to the shorter one; an empty
            dataset yields two empty lists.
        """
        combined = list(zip(self.image_file_list, self.label_file_list))
        if not combined:
            return [], []
        random.shuffle(combined)
        images, labels = zip(*combined)
        return list(images), list(labels)

    @staticmethod
    def change_color_space(image, label, color_space):
        """Convert a BGR image (and a three-channel label) to ``color_space``.

        ``'hsi'``/``'hsv'`` map to HSV and ``'lab'`` to LAB (case-insensitive);
        any other value returns the inputs unchanged. A label that does not
        have three channels is returned as-is, because the BGR conversions
        only accept three-channel input.
        """
        conversions = {
            'hsi': cv2.COLOR_BGR2HSV,
            'hsv': cv2.COLOR_BGR2HSV,
            'lab': cv2.COLOR_BGR2LAB,
        }
        code = conversions.get(color_space.lower())
        if code is None:
            return image, label
        image = cv2.cvtColor(image, code)
        if label.ndim == 3 and label.shape[2] == 3:
            label = cv2.cvtColor(label, code)
        return image, label
def normalize(arr):
    """Scale ``arr`` by the spread between its largest and smallest value.

    Every element is divided by ``max - min``; the minimum is not subtracted
    first. When the spread is zero (a constant array) the values are divided
    by 255 instead.
    """
    spread = np.max(arr) - np.min(arr)
    if spread == 0:
        spread = 255
    return arr / np.abs(spread)
def get_png_filename_list(path):
    """Return the sorted names of the PNG files directly inside ``path``.

    Only the directory itself is inspected; sub-directories are not
    descended into. The extension check is case-insensitive. A directory
    that does not exist yields an empty list.

    Args:
        path: Directory to list.

    Returns:
        A sorted list of bare file names (no directory component).
    """
    # The first os.walk() entry describes ``path`` itself; a missing
    # directory produces no entry at all, hence the default.
    _, _, filenames = next(os.walk(path), (path, [], []))
    return sorted(name for name in filenames if name.lower().endswith(".png"))
def get_jpg_filename_list(path):
    """Return the sorted names of the ``.jpg`` files directly inside ``path``.

    Only the directory itself is inspected; sub-directories are not
    descended into. The extension check is case-insensitive. A directory
    that does not exist yields an empty list.

    Args:
        path: Directory to list.

    Returns:
        A sorted list of bare file names (no directory component).
    """
    # The first os.walk() entry describes ``path`` itself; a missing
    # directory produces no entry at all, hence the default.
    _, _, filenames = next(os.walk(path), (path, [], []))
    return sorted(name for name in filenames if name.lower().endswith(".jpg"))
def load_jpg_images(path):
    """Load every ``.jpg`` file in ``path`` as a float32 colour image.

    Args:
        path: Directory holding the images; a trailing slash is optional.

    Returns:
        A tuple ``(images, file_list)``: the images stacked with
        ``np.array`` and the file names in the same order.

    Raises:
        OSError: If a listed file cannot be decoded by ``cv2.imread``.
    """
    file_list = get_jpg_filename_list(path)
    images = []
    for filename in file_list:
        full_path = os.path.join(path, filename)
        img = cv2.imread(full_path, 1)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise OSError('could not read image file: {}'.format(full_path))
        images.append(img.astype("float32"))
    return np.array(images), file_list
def load_png_images(path):
    """Load every PNG file in ``path`` as a float32 colour image.

    Args:
        path: Directory holding the images; a trailing slash is optional.

    Returns:
        A tuple ``(images, file_list)``: the images stacked with
        ``np.array`` and the file names in the same order.

    Raises:
        OSError: If a listed file cannot be decoded by ``cv2.imread``.
    """
    file_list = get_png_filename_list(path)
    images = []
    for filename in file_list:
        full_path = os.path.join(path, filename)
        img = cv2.imread(full_path, 1)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise OSError('could not read image file: {}'.format(full_path))
        images.append(img.astype("float32"))
    return np.array(images), file_list
def load_data(path):
    """Load and normalize the whole train and test sets into memory.

    Reads the PNG files under ``train/images/``, ``train/labels/``,
    ``test/images/`` and ``test/labels/`` below ``path`` and passes each
    stacked array through ``normalize``.

    Args:
        path: Dataset root directory; a trailing slash is optional.

    Returns:
        ``(x_train, y_train, x_test, y_test, test_label_filenames_list)``.
    """
    # The joined paths keep their trailing slash and are identical to plain
    # concatenation whenever ``path`` already ends with a separator.
    path_train_images = os.path.join(path, "train/images/")
    path_train_labels = os.path.join(path, "train/labels/")
    path_test_images = os.path.join(path, "test/images/")
    path_test_labels = os.path.join(path, "test/labels/")

    # Only the test label file names are returned to the caller.
    x_train, _ = load_png_images(path_train_images)
    y_train, _ = load_png_images(path_train_labels)
    x_test, _ = load_png_images(path_test_images)
    y_test, test_label_filenames_list = load_png_images(path_test_labels)

    x_train = normalize(x_train)
    y_train = normalize(y_train)
    x_test = normalize(x_test)
    y_test = normalize(y_test)
    return x_train, y_train, x_test, y_test, test_label_filenames_list
def load_test_images(path):
    """Load and normalize the PNG images under ``test/images/``.

    Args:
        path: Dataset root directory; a trailing slash is optional.

    Returns:
        ``(x_test, test_image_filenames_list)``: the normalized image array
        and the file names in the same order.
    """
    # Identical to plain concatenation when ``path`` ends with a separator.
    path_test_images = os.path.join(path, "test/images/")
    x_test, test_image_filenames_list = load_png_images(path_test_images)
    return normalize(x_test), test_image_filenames_list
def save_results(np_array, color_space, outpath, test_label_filenames_list):
    """Write each prediction to ``outpath + filename`` scaled by 255.

    The i-th entry of ``np_array`` is saved under the i-th name in
    ``test_label_filenames_list``. ``color_space`` is accepted but not used.
    """
    for index, filename in enumerate(test_label_filenames_list):
        prediction = np_array[index]
        cv2.imwrite(outpath + filename, prediction * 255.)
def save_rgb_results(np_array, outpath, test_label_filenames_list):
    """Write each entry of ``np_array`` to ``outpath + filename`` scaled by 255.

    The i-th entry of ``np_array`` is saved under the i-th name in
    ``test_label_filenames_list``.
    """
    for index, filename in enumerate(test_label_filenames_list):
        cv2.imwrite(outpath + filename, np_array[index] * 255.)
def save_history(model, model_name, training_history, dataset, n_filters, epoch, learning_rate, loss,
                 color_space, path=None, temp_name=None):
    """Save the model, its training history as JSON, and a plot of the curves.

    Three files sharing one base name are written: ``<base>.hdf5`` (via
    ``model.save``), ``<base>.json`` (``training_history.history``) and
    ``<base>.png`` (the plot, which is also shown on screen).

    Args:
        model: Object exposing ``save(filename)``.
        model_name: Shown in the plot title.
        training_history: Object whose ``history`` attribute maps metric
            names to per-epoch values.
        dataset: Accepted but not used.
        n_filters: Shown in the plot title.
        epoch: Shown in the plot title.
        learning_rate: Shown in the plot title.
        loss: Accepted but not used.
        color_space: Shown in the plot title.
        path: Prefix prepended to the base name; ``None`` means no prefix.
        temp_name: Base name to use; defaults to the current timestamp.
    """
    # Formatting None directly would put the literal text "None" in front
    # of every file name.
    prefix = '' if path is None else path
    save_weight_filename = temp_name if temp_name else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    base_name = '{}{}'.format(prefix, save_weight_filename)

    model.save('{}.hdf5'.format(base_name))

    history = training_history.history
    with open('{}.json'.format(base_name), 'w') as f:
        # default=float converts NumPy scalars (e.g. float32) that the json
        # module cannot serialize on its own.
        json.dump(history, f, indent=2, default=float)

    # Each curve carries its own label so the legend stays correct even
    # when some of the metrics were not recorded.
    curves = (
        ('loss', 'train_loss'),
        ('val_loss', 'test_loss'),
        ('dice_coef', 'train_dice'),
        ('val_dice_coef', 'test_dice'),
    )
    plotted = False
    for key, label in curves:
        if key in history:
            plt.plot(history[key], label=label)
            plotted = True

    plt.title('model:{} lr:{} epoch:{} #filtr:{} Colorspaces:{}'.format(model_name, learning_rate,
                                                                        epoch, n_filters, color_space))
    plt.ylabel('loss')
    plt.xlabel('epoch')
    if plotted:
        plt.legend(loc='upper left')
    plt.savefig('{}.png'.format(base_name))
    plt.show()
    plt.clf()