import torch
from torch import nn
from torch.autograd import Variable
import torchmetrics
import numpy as np
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp

from train import *

import argparse
from utils.config import Config
from tools.mask_convert import mask_save

def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    return parser.parse_args()

if __name__ == "__main__":
    args = get_args()
    cfg = Config.fromfile(args.config)

    ckpt = args.ckpt
    if ckpt is None:
        ckpt = cfg.test_ckpt_path
    assert ckpt is not None

    if args.output_dir:
        base_dir = args.output_dir
    else:
        base_dir = os.path.dirname(ckpt)
    masks_output_dir = os.path.join(base_dir, "mask_rgb")

    # Restore the trained model from the checkpoint; remap tensors saved on cuda:1 to cuda:0.
    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')

    model.eval()

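    # torchmetrics accumulators: overall accuracy uses metric_cfg1, the per-class metrics use metric_cfg2.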
    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2
    
    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    results = []
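    # Run the model over the test split without gradients; metrics update per batch and masks are queued for saving.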
    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        for batch in tqdm(test_loader):
            # Each batch holds the two temporal images, the ground-truth change mask and the image id.
            raw_predictions = model(batch[0].cuda(), batch[1].cuda())
            mask, img_id = batch[2].cuda(), batch[3]

            if cfg.net == 'SARASNet':
                # Resize the ground-truth mask to the prediction's spatial size, then bias the
                # change logit to balance precision and recall for a higher F1-score.
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                             size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param
                
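            # Turn the raw network outputs into a binary change map: argmax over classes, or a 0.5 threshold.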
            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5

            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

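            # Queue each ground-truth/prediction pair so the RGB masks can be written after the loop.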
            for i in range(raw_predictions.shape[0]):
                mask_real = mask[i].cpu().numpy()
                mask_pred = pred[i].cpu().numpy()
                mask_name = str(img_id[i])
                results.append((mask_real, mask_pred, masks_output_dir, mask_name))

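    # Reduce the accumulated metric states over the whole test set.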
    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]
    
    total_metrics = [test_oa.compute().cpu().numpy(),
                     np.mean([item.cpu() for item in metrics[0]]),
                     np.mean([item.cpu() for item in metrics[1]]),
                     np.mean([item.cpu() for item in metrics[2]]),
                     np.mean([item.cpu() for item in metrics[3]])]

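    # Tabulate per-class metrics and the class-averaged totals.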
    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

    for i in range(2):
        item = [i, '--']
        for j in range(len(metrics)):
            item.append(np.round(metrics[j][i].cpu().numpy(), 4))
        result_table.add_row(item)

    total = [np.round(v, 4) for v in total_metrics]
    total.insert(0, 'total')
    result_table.add_row(total)

    print(result_table)

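    # Append the timestamped result table to test_res.txt in the output directory.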
    file_name = os.path.join(base_dir, "test_res.txt")
    current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
    with open(file_name, "a") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')
 
    os.makedirs(masks_output_dir, exist_ok=True)
    print(masks_output_dir)

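    # Write the colored masks in parallel, one worker per CPU core.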
    t0 = time.time()
    with mpp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(mask_save, results)
    t1 = time.time()
    img_write_time = t1 - t0
    print('writing images took {:.2f} s'.format(img_write_time))