import torch from torch import nn from tqdm import tqdm import prettytable import time import os import multiprocessing.pool as mpp import multiprocessing as mp from train import * import argparse from utils.config import Config from tools.mask_convert import mask_save def get_args(): parser = argparse.ArgumentParser('description=Change detection of remote sensing images') parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py") parser.add_argument("--ckpt", type=str, default=None) parser.add_argument("--output_dir", type=str, default=None) return parser.parse_args() if __name__ == "__main__": args = get_args() cfg = Config.fromfile(args.config) ckpt = args.ckpt if ckpt is None: ckpt = cfg.test_ckpt_path assert ckpt is not None if args.output_dir: base_dir = args.output_dir else: base_dir = os.path.dirname(ckpt) masks_output_dir = os.path.join(base_dir, "mask_rgb") model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1':'cuda:0'}, cfg = cfg) model = model.to('cuda') model.eval() metric_cfg_1 = cfg.metric_cfg1 metric_cfg_2 = cfg.metric_cfg2 test_oa=torchmetrics.Accuracy(**metric_cfg_1).to('cuda') test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda') test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda') test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda') test_iou=torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda') results = [] with torch.no_grad(): test_loader = build_dataloader(cfg.dataset_config, mode='test') for input in tqdm(test_loader): raw_predictions, mask, img_id = model(input[0].cuda(), input[1].cuda()), input[2].cuda(), input[3] if cfg.net == 'SARASNet': mask = Variable(resize_label(mask.data.cpu().numpy(), \ size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long() param = 1 # This parameter is balance precision and recall to get higher F1-score raw_predictions[:,1,:,:] = raw_predictions[:,1,:,:] + param if cfg.argmax: pred = raw_predictions.argmax(dim=1) else: if cfg.net == 'maskcd': pred = raw_predictions[0] pred = pred > 0.5 pred.squeeze_(1) else: pred = raw_predictions.squeeze(1) pred = pred > 0.5 test_oa(pred, mask) test_iou(pred, mask) test_prec(pred, mask) test_f1(pred, mask) test_recall(pred, mask) for i in range(raw_predictions.shape[0]): mask_real = mask[i].cpu().numpy() mask_pred = pred[i].cpu().numpy() mask_name = str(img_id[i]) results.append((mask_real, mask_pred, masks_output_dir, mask_name)) metrics = [test_prec.compute(), test_recall.compute(), test_f1.compute(), test_iou.compute()] total_metrics = [test_oa.compute().cpu().numpy(), np.mean([item.cpu() for item in metrics[0]]), np.mean([item.cpu() for item in metrics[1]]), np.mean([item.cpu() for item in metrics[2]]), np.mean([item.cpu() for item in metrics[3]])] result_table = prettytable.PrettyTable() result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU'] for i in range(2): item = [i, '--'] for j in range(len(metrics)): item.append(np.round(metrics[j][i].cpu().numpy(), 4)) result_table.add_row(item) total = [np.round(v, 4) for v in total_metrics] total.insert(0, 'total') result_table.add_row(total) print(result_table) file_name = os.path.join(base_dir, "test_res.txt") f = open(file_name,"a") current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net),time.localtime(time.time())) f.write(current_time+'\n') f.write(str(result_table)+'\n') if not os.path.exists(masks_output_dir): os.makedirs(masks_output_dir) print(masks_output_dir) t0 = time.time() mpp.Pool(processes=mp.cpu_count()).map(mask_save, results) t1 = time.time() img_write_time = t1 - t0 print('images writing spends: {} s'.format(img_write_time))