# Hosting-page header captured together with this file:
# wizzseen's picture
# Upload 948 files
# 8a6df40 verified
import argparse
import os
import numpy as np
from PIL import Image
from scipy.stats import entropy
from skimage.metrics import structural_similarity as ssim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as Transforms
from torchvision.models.inception import inception_v3
import eval_models as models
def get_opt(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and programmatic use).

    Returns:
        argparse.Namespace with ``evaluation``, ``predict_dir``,
        ``ground_truth_dir`` and ``resolution`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Compute SSIM, MSE, LPIPS and Inception Score for predicted images.')
    # NOTE(review): not read by Evaluation() in this file; all metrics are always computed.
    parser.add_argument('--evaluation', default='LPIPS')
    parser.add_argument('--predict_dir', default='./result/bg_ver1/output/',
                        help='directory containing the predicted images')
    parser.add_argument('--ground_truth_dir', default='./data/zalando-hd-resize/test/image',
                        help='directory containing the ground-truth images')
    # Only these resolutions are handled when resizing the ground truth, so
    # reject anything else here instead of after the networks are loaded.
    parser.add_argument('--resolution', type=int, default=1024, choices=(256, 512, 1024),
                        help='1024 uses ground truths unchanged; 512 / 256 resize them '
                             'to 384x512 / 192x256 (width x height)')
    return parser.parse_args(argv)
def _load_ground_truth(opt, pred_name):
    """Open the ground-truth image paired with prediction file *pred_name*.

    The ground-truth filename is the part of *pred_name* before its first
    underscore followed by ``_00.jpg``. For ``opt.resolution`` 512 / 256 the
    image is bilinearly resized to 384x512 / 192x256 (width x height); for
    1024 it is returned unchanged.

    Raises:
        NotImplementedError: if ``opt.resolution`` is not 1024, 512 or 256.
    """
    gt_name = pred_name.split('_')[0] + '_00.jpg'
    gt_img = Image.open(os.path.join(opt.ground_truth_dir, gt_name))
    if opt.resolution == 512:
        gt_img = gt_img.resize((384, 512), Image.BILINEAR)
    elif opt.resolution == 256:
        gt_img = gt_img.resize((192, 256), Image.BILINEAR)
    elif opt.resolution != 1024:
        raise NotImplementedError(f"unsupported resolution: {opt.resolution}")
    return gt_img


def _inception_score(preds, splits):
    """Return the (mean, std) Inception Score of *preds* over *splits* chunks.

    Args:
        preds: (N, C) array whose rows are per-image softmax class probabilities.
        splits: Number of equal-sized chunks (1 <= splits <= N); trailing rows
            that do not fill a whole chunk are ignored.
    """
    split_size = preds.shape[0] // splits
    split_scores = []
    for k in range(splits):
        part = preds[k * split_size:(k + 1) * split_size, :]
        py = np.mean(part, axis=0)  # marginal class distribution of this chunk
        # scipy's entropy(pk, qk) is the KL divergence KL(pk || qk).
        kl = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(kl)))
    return np.mean(split_scores), np.std(split_scores)


def Evaluation(opt, pred_list, gt_list, splits=1):
    """Compute SSIM, MSE, LPIPS and Inception Score for the predicted images.

    Each prediction in ``opt.predict_dir`` is compared with its ground truth in
    ``opt.ground_truth_dir`` (paired by name, see ``_load_ground_truth``).
    Per-image LPIPS scores are appended to ``<predict_dir>/lpips.txt`` in
    descending order and the averages to ``<predict_dir>/eval.txt``.
    Requires a CUDA device.

    Args:
        opt: Namespace with ``predict_dir``, ``ground_truth_dir`` and ``resolution``.
        pred_list: Filenames of the predicted images inside ``opt.predict_dir``.
        gt_list: Ground-truth filenames. Only used for a count sanity check;
            ground truths are located from the prediction names.
        splits: Number of chunks for the Inception Score (IS_std is over chunks).

    Returns:
        (avg_ssim, avg_mse, avg_distance, IS_mean, IS_std); the averages are
        taken over the evaluated predictions and avg_distance is the mean LPIPS.

    Raises:
        ValueError: if pred_list is empty, splits is out of range, or a
            prediction's size differs from its ground truth.
        NotImplementedError: if opt.resolution is not 1024, 512 or 256.
    """
    n = len(pred_list)
    # Validate cheaply before loading any network.
    if n == 0:
        raise ValueError("pred_list is empty; nothing to evaluate")
    if not 1 <= splits <= n:
        raise ValueError(f"splits must be between 1 and {n}, got {splits}")
    if opt.resolution not in (256, 512, 1024):
        raise NotImplementedError(f"unsupported resolution: {opt.resolution}")
    if len(gt_list) != n:
        print(f"Warning: {n} predictions but {len(gt_list)} ground-truth files; "
              f"averaging over the {n} predictions.")

    T1 = Transforms.ToTensor()  # MSE input: [0, 1] tensor at native size
    T2 = Transforms.Compose([Transforms.Resize((128, 128)),  # LPIPS input
                             Transforms.ToTensor(),
                             Transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                                  std=(0.5, 0.5, 0.5))])
    T3 = Transforms.Compose([Transforms.Resize((299, 299)),  # Inception input
                             Transforms.ToTensor(),
                             Transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                                  std=(0.5, 0.5, 0.5))])

    model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
    model.eval()
    inception_model = inception_v3(pretrained=True, transform_input=False).type(torch.cuda.FloatTensor)
    inception_model.eval()

    ssim_sum, mse_sum, lpips_sum = 0.0, 0.0, 0.0
    preds = np.zeros((n, 1000))  # Inception softmax output, one row per prediction
    lpips_list = []
    with torch.no_grad():
        print("Calculate SSIM, MSE, LPIPS...")
        for i, img_pred in enumerate(pred_list):
            gt_img = _load_ground_truth(opt, img_pred)
            pred_img = Image.open(os.path.join(opt.predict_dir, img_pred))
            if gt_img.size != pred_img.size:
                raise ValueError(f"{img_pred}: {gt_img.size} vs {pred_img.size}")

            # SSIM is computed on 8-bit grayscale versions of both images.
            ssim_sum += ssim(np.asarray(gt_img.convert('L')),
                             np.asarray(pred_img.convert('L')),
                             data_range=255, gaussian_weights=True,
                             use_sample_covariance=False)

            # The tensor metrics need 3-channel input (T2/T3 normalize three
            # channels), so grayscale / RGBA files are converted first.
            gt_rgb = gt_img.convert('RGB')
            pred_rgb = pred_img.convert('RGB')

            lpips_score = model.forward(T2(gt_rgb).unsqueeze(0).cuda(),
                                        T2(pred_rgb).unsqueeze(0).cuda()).item()
            lpips_list.append((img_pred, lpips_score))
            lpips_sum += lpips_score

            logits = inception_model(T3(pred_rgb).unsqueeze(0).cuda())
            preds[i] = F.softmax(logits, dim=1)[0].cpu().numpy()

            # .item() keeps the running sum a Python float rather than a CUDA
            # tensor, so eval.txt receives a number instead of a tensor repr.
            mse_sum += F.mse_loss(T1(gt_rgb).unsqueeze(0).cuda(),
                                  T1(pred_rgb).unsqueeze(0).cuda()).item()
            print(f"step: {i+1} evaluation... lpips:{lpips_score}")

    avg_ssim = ssim_sum / n
    avg_mse = mse_sum / n
    avg_distance = lpips_sum / n

    lpips_list.sort(key=lambda x: x[1], reverse=True)
    with open(os.path.join(opt.predict_dir, 'lpips.txt'), 'a') as f:
        f.writelines(f"{name} {score}\n" for name, score in lpips_list)

    print("Calculate Inception Score...")
    IS_mean, IS_std = _inception_score(preds, splits)

    with open(os.path.join(opt.predict_dir, 'eval.txt'), 'a') as f:
        f.write(f"SSIM : {avg_ssim} / MSE : {avg_mse} / LPIPS : {avg_distance}\n")
        f.write(f"IS_mean : {IS_mean} / IS_std : {IS_std}\n")
    return avg_ssim, avg_mse, avg_distance, IS_mean, IS_std
def main():
    """Evaluate every image in --predict_dir against --ground_truth_dir and print the metrics."""
    opt = get_opt()
    # Predicted outputs and ground-truth data. Only image files are kept:
    # Evaluation() appends lpips.txt / eval.txt to predict_dir, and a rerun
    # would otherwise treat those reports as predictions.
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
    pred_list = sorted(name for name in os.listdir(opt.predict_dir)
                       if name.lower().endswith(image_exts))
    gt_list = sorted(name for name in os.listdir(opt.ground_truth_dir)
                     if name.lower().endswith(image_exts))
    avg_ssim, avg_mse, avg_distance, IS_mean, IS_std = Evaluation(opt, pred_list, gt_list)
    print("SSIM : %f / MSE : %f / LPIPS : %f" % (avg_ssim, avg_mse, avg_distance))
    print("IS_mean : %f / IS_std : %f" % (IS_mean, IS_std))
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()