# Provenance (page header captured with the source): uploaded by InPeerReview,
# "Upload 3 files", commit 8b04570 (verified).
import torch
import torch.optim as optim
import data as Data
import models as Model
import torch.nn as nn
import argparse
import logging
import core.logger as Logger
import os
import numpy as np
from misc.metric_tools import ConfuseMatrixMeter
from models.loss import *
from collections import OrderedDict
import core.metrics as Metrics
from misc.torchutils import get_scheduler, save_network
# Training / evaluation entry point for the change-detection model.
# Run with ``--phase train`` (training + periodic validation) or ``--phase test``.

# How prediction visuals are written for each phase: 'single' saves A, B, prediction
# and ground truth as four files; any other value saves one concatenated grid image.
TRAIN_IMG_MODE = 'grid'
VAL_IMG_MODE = 'single'
TEST_IMG_MODE = 'single'


def _build_loss(loss_name):
    """Return the training loss selected by ``opt['model']['loss']``.

    Args:
        loss_name: ``'ce_dice'`` or ``'ce'``.

    Returns:
        ``ce_dice`` or ``cross_entropy`` (both come from ``models.loss``).

    Raises:
        ValueError: for any other name, so a bad config fails before training
            starts instead of leaving the loss unbound until the first batch.
    """
    if loss_name == 'ce_dice':
        return ce_dice
    if loss_name == 'ce':
        return cross_entropy
    raise ValueError("Unsupported loss {!r}; expected 'ce_dice' or 'ce'".format(loss_name))


def _build_optimizer(model, optimizer_opt):
    """Create the optimizer described by ``opt['train']['optimizer']``.

    Args:
        model: module whose ``parameters()`` are optimized.
        optimizer_opt: mapping with ``'type'`` (``'adam'`` or ``'adamw'``) and ``'lr'``.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` instance.

    Raises:
        ValueError: for an unsupported optimizer type.
    """
    optimizer_type = optimizer_opt['type']
    if optimizer_type == 'adam':
        optimizer_cls = optim.Adam
    elif optimizer_type == 'adamw':
        optimizer_cls = optim.AdamW
    else:
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'adam' or 'adamw'".format(optimizer_type))
    return optimizer_cls(model.parameters(), lr=optimizer_opt['lr'])


def _update_running_score(metric, pred_img, gt, log_dict):
    """Add one batch to the confusion matrix and record its score as ``running_acc``.

    Args:
        metric: ``ConfuseMatrixMeter`` accumulating the confusion matrix.
        pred_img: model output; class scores are assumed on dim 1 (argmax axis).
        gt: integer label tensor on the same device as ``pred_img``.
        log_dict: dict updated in place with ``'running_acc'``.

    Returns:
        The predicted label map ``argmax(pred_img, dim=1)`` (still on its device).
    """
    pred_cm = torch.argmax(pred_img.detach(), dim=1)
    current_score = metric.update_cm(pr=pred_cm.cpu().numpy(), gt=gt.detach().cpu().numpy())
    # update_cm's return value is logged as the running mF1 in the messages below.
    log_dict['running_acc'] = current_score.item()
    return pred_cm


def _record_epoch_scores(metric, log_dict):
    """Merge the metric's accumulated scores into ``log_dict``; return the epoch mF1."""
    scores = metric.get_scores()
    # .item() converts the scalar score (presumably a numpy scalar) to a Python float.
    log_dict['epoch_acc'] = scores['mf1'].item()
    log_dict.update(scores)
    return log_dict['epoch_acc']


def _format_summary(header, logs):
    """Append every ``key: value`` pair of ``logs`` (scientific notation) to ``header``."""
    return header + ''.join('{:s}: {:.4e} '.format(k, v) for k, v in logs.items()) + '\n'


def _label_to_rgb(label_map):
    """Add a channel axis and repeat it 3x so a 3-D label map can be saved as RGB.

    The input is presumably (B, H, W), e.g. an argmax over class scores.
    """
    return label_map.unsqueeze(1).repeat(1, 3, 1, 1)


def _save_single(data, pred_cm, gt_cm, paths):
    """Save input A, input B, prediction and ground truth as four separate images.

    Args:
        data: batch dict with image tensors under ``'A'`` and ``'B'``.
        pred_cm: predicted label map.
        gt_cm: ground-truth label map.
        paths: ``(path_A, path_B, path_pred, path_gt)``.
    """
    path_a, path_b, path_pred, path_gt = paths
    img_a = Metrics.tensor2img(data['A'], out_type=np.uint8, min_max=(-1, 1))
    img_b = Metrics.tensor2img(data['B'], out_type=np.uint8, min_max=(-1, 1))
    # Label maps hold {0, 1}, hence min_max=(0, 1).
    img_gt = Metrics.tensor2img(_label_to_rgb(gt_cm), out_type=np.uint8, min_max=(0, 1))
    img_pred = Metrics.tensor2img(_label_to_rgb(pred_cm), out_type=np.uint8, min_max=(0, 1))
    Metrics.save_img(img_a, path_a)
    Metrics.save_img(img_b, path_b)
    Metrics.save_img(img_pred, path_pred)
    Metrics.save_img(img_gt, path_gt)


def _save_grid(data, pred_cm, gt_cm, device, path):
    """Save A, B, prediction and ground truth concatenated on dim 0 as one image."""
    # Map {0, 1} labels to {-1, 1} so they share the range tensor2img is called with
    # for the inputs (its default min_max; presumably (-1, 1) - confirm in core.metrics).
    # Everything is moved to `device` so torch.cat never mixes CPU and CUDA tensors.
    grid_img = torch.cat((data['A'].to(device),
                          data['B'].to(device),
                          _label_to_rgb(pred_cm * 2.0 - 1.0),
                          _label_to_rgb(gt_cm * 2.0 - 1.0)),
                         dim=0)
    Metrics.save_img(Metrics.tensor2img(grid_img), path)


def _save_visuals(data, pred_cm, gt_cm, device, img_mode, single_paths, grid_path):
    """Save one batch's visuals: four files in 'single' mode, otherwise one grid image."""
    if img_mode == 'single':
        _save_single(data, pred_cm, gt_cm, single_paths)
    else:
        _save_grid(data, pred_cm, gt_cm, device, grid_path)


def _epoch_image_paths(result_path, epoch, step):
    """Return ``(single_paths, grid_path)`` for a train/val batch, tagged ``_e{epoch}_b{step}``."""
    suffix = '_e{}_b{}'.format(epoch, step)
    single_paths = tuple('{}/{}{}.png'.format(result_path, stem, suffix)
                         for stem in ('img_A', 'img_B', 'img_pred', 'img_gt'))
    grid_path = '{}/img_A_B_pred_gt{}.png'.format(result_path, suffix)
    return single_paths, grid_path


def _test_image_paths(result_path, step):
    """Return ``(single_paths, grid_path)`` for a test batch (legacy test file names)."""
    single_paths = ('{}/img_A_{}.png'.format(result_path, step),
                    '{}/img_B_{}.png'.format(result_path, step),
                    '{}/img_pred_cm{}.png'.format(result_path, step),
                    '{}/img_gt_cm{}.png'.format(result_path, step))
    grid_path = '{}/img_A_B_pred_gt_{}.png'.format(result_path, step)
    return single_paths, grid_path


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config/whu/whu_test.json',
                        help='JSON file for configuration')
    parser.add_argument('--phase', type=str, default='test',
                        choices=['train', 'test'], help='Run either train(training + validation) or testing',)
    parser.add_argument('--gpu_ids', type=str, default=None)
    parser.add_argument('-log_eval', action='store_true')
    args = parser.parse_args()
    opt = Logger.parse(args)
    # Missing keys read as None instead of raising KeyError.
    opt = Logger.dict_to_nonedict(opt)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    Logger.setup_logger(logger_name=None, root=opt['path_cd']['log'], phase='train',
                        level=logging.INFO, screen=True)
    Logger.setup_logger(logger_name='test', root=opt['path_cd']['log'], phase='test',
                        level=logging.INFO)
    logger = logging.getLogger('base')
    logger.info(Logger.dict2str(opt))

    # Build only the loaders the requested phase needs.
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train' and args.phase != 'test':
            print("Create [train] change-detection dataloader")
            train_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            train_loader = Data.create_cd_dataloader(train_set, dataset_opt, phase)
            opt['len_train_dataloader'] = len(train_loader)
        elif phase == 'val' and args.phase != 'test':
            print("Create [val] change-detection dataloader")
            val_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            val_loader = Data.create_cd_dataloader(val_set, dataset_opt, phase)
            opt['len_val_dataloader'] = len(val_loader)
        elif phase == 'test' and args.phase == 'test':
            print("Create [test] change-detection dataloader")
            test_set = Data.create_cd_dataset(dataset_opt=dataset_opt, phase=phase)
            test_loader = Data.create_cd_dataloader(test_set, dataset_opt, phase)
            opt['len_test_dataloader'] = len(test_loader)
    logger.info('Initial Dataset Finished')

    cd_model = Model.create_CD_model(opt)
    gpu_ids = opt['gpu_ids']
    device = torch.device('cuda' if gpu_ids is not None else 'cpu')
    cd_model.to(device)
    # Truthiness also covers gpu_ids=None (CPU), where len(None) would raise TypeError.
    if gpu_ids:
        cd_model = nn.DataParallel(cd_model)

    metric = ConfuseMatrixMeter(n_class=2)
    log_dict = OrderedDict()

    if opt['phase'] == 'train':
        train_opt = opt['train']
        n_epoch = train_opt['n_epoch']
        loss_fun = _build_loss(opt['model']['loss'])
        # Built after .to(device); DataParallel exposes the wrapped module's parameters,
        # so the same tensors are optimized either way.
        optimer = _build_optimizer(cd_model, train_opt['optimizer'])
        # Create the scheduler ONCE and step it per epoch: a fresh scheduler each epoch
        # restarts its epoch counter, so the configured schedule is never followed.
        scheduler = get_scheduler(optimizer=optimer, args=train_opt)
        best_mF1 = 0.0
        for current_epoch in range(0, n_epoch):
            # ---------------- training ----------------
            print("......Begin Training......")
            metric.clear()
            cd_model.train()
            train_result_path = '{}/train/{}'.format(opt['path_cd']['result'], current_epoch)
            os.makedirs(train_result_path, exist_ok=True)
            message = 'lr: %0.7f\n \n' % optimer.param_groups[0]['lr']
            logger.info(message)
            for current_step, train_data in enumerate(train_loader):
                train_im1 = train_data['A'].to(device)
                train_im2 = train_data['B'].to(device)
                pred_img = cd_model(train_im1, train_im2)
                gt = train_data['L'].to(device).long()
                train_loss = loss_fun(pred_img, gt)
                optimer.zero_grad()
                train_loss.backward()
                optimer.step()
                log_dict['loss'] = train_loss.item()
                pred_cm = _update_running_score(metric, pred_img, gt, log_dict)
                if current_step % train_opt['train_print_iter'] == 0:
                    message = '[Training CD]. epoch: [%d/%d]. Itter: [%d/%d], CD_loss: %.5f, running_mf1: %.5f\n' % \
                              (current_epoch, n_epoch - 1, current_step, len(train_loader), log_dict['loss'],
                               log_dict['running_acc'])
                    logger.info(message)
                    single_paths, grid_path = _epoch_image_paths(train_result_path, current_epoch, current_step)
                    _save_visuals(train_data, pred_cm, gt, device, TRAIN_IMG_MODE, single_paths, grid_path)
            epoch_acc = _record_epoch_scores(metric, log_dict)
            logger.info(_format_summary(
                '[Training CD (epoch summary)]: epoch: [%d/%d]. epoch_mF1=%.5f \n' %
                (current_epoch, n_epoch - 1, epoch_acc),
                log_dict))
            metric.clear()

            # ---------------- validation ----------------
            cd_model.eval()
            with torch.no_grad():
                if current_epoch % train_opt['val_freq'] == 0:
                    val_result_path = '{}/val/{}'.format(opt['path_cd']['result'], current_epoch)
                    os.makedirs(val_result_path, exist_ok=True)
                    for current_step, val_data in enumerate(val_loader):
                        val_img1 = val_data['A'].to(device)
                        val_img2 = val_data['B'].to(device)
                        pred_img = cd_model(val_img1, val_img2)
                        gt = val_data['L'].to(device).long()
                        log_dict['loss'] = loss_fun(pred_img, gt).item()
                        pred_cm = _update_running_score(metric, pred_img, gt, log_dict)
                        if current_step % train_opt['val_print_iter'] == 0:
                            message = '[Validation CD]. epoch: [%d/%d]. Itter: [%d/%d], running_mf1: %.5f\n' % \
                                      (current_epoch, n_epoch - 1, current_step, len(val_loader),
                                       log_dict['running_acc'])
                            logger.info(message)
                            single_paths, grid_path = _epoch_image_paths(val_result_path, current_epoch,
                                                                         current_step)
                            _save_visuals(val_data, pred_cm, gt, device, VAL_IMG_MODE, single_paths, grid_path)
                    epoch_acc = _record_epoch_scores(metric, log_dict)
                    # n_epoch - 1 keeps the epoch display consistent with the other messages.
                    logger.info(_format_summary(
                        '[Validation CD (epoch summary)]: epoch: [%d/%d]. epoch_mF1=%.5f \n' %
                        (current_epoch, n_epoch - 1, epoch_acc),
                        log_dict))
                    is_best_model = epoch_acc > best_mF1
                    if is_best_model:
                        best_mF1 = epoch_acc
                        logger.info('[Validation CD] Best model updated. Saving the models (current + best) and training states.')
                    else:
                        logger.info('[Validation CD] Saving the current cd model and training states.')
                    logger.info('--- Proceed To The Next Epoch ----\n \n')
                    save_network(opt, current_epoch, cd_model, optimer, is_best_model)
                    metric.clear()
            scheduler.step()
        logger.info('End of training.')
    else:
        logger.info('Begin model evaluation (testing phase)')
        test_result_path = '{}/test/'.format(opt['path_cd']['result'])
        os.makedirs(test_result_path, exist_ok=True)
        logger_test = logging.getLogger('test')
        load_path = opt["path_cd"]["resume_state"]
        if load_path is not None:
            logger.info('Loading pre-trained change detection model [{:s}] ...'.format(load_path))
            gen_path = '{}_gen.pth'.format(load_path)
            cd_model = Model.create_CD_model(opt)
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            # NOTE(review): torch.load unpickles; only load checkpoints from trusted paths.
            cpkt_state = torch.load(gen_path, map_location=device)
            missing_keys, unexpected_keys = cd_model.load_state_dict(cpkt_state, strict=False)
            # strict=False skips mismatched keys silently; surface them so a wrong
            # checkpoint (e.g. differently prefixed keys) is visible in the log.
            if missing_keys:
                logger.warning('Missing keys when loading checkpoint: {}'.format(missing_keys))
            if unexpected_keys:
                logger.warning('Unexpected keys when loading checkpoint: {}'.format(unexpected_keys))
            cd_model.to(device)
        else:
            logger.warning('No resume_state configured; evaluating the freshly initialised model.')
        metric.clear()
        cd_model.eval()
        with torch.no_grad():
            for current_step, test_data in enumerate(test_loader):
                test_img1 = test_data['A'].to(device)
                test_img2 = test_data['B'].to(device)
                pred_img = cd_model(test_img1, test_img2)
                # Some model variants return a tuple; the first element is used as the prediction.
                if isinstance(pred_img, tuple):
                    pred_img = pred_img[0]
                gt = test_data['L'].to(device).long()
                pred_cm = _update_running_score(metric, pred_img, gt, log_dict)
                message = '[Test Change Detection] Iteration: [%d/%d], current mF1: %.5f\n' % \
                          (current_step, len(test_loader), log_dict['running_acc'])
                logger_test.info(message)
                single_paths, grid_path = _test_image_paths(test_result_path, current_step)
                _save_visuals(test_data, pred_cm, gt, device, TEST_IMG_MODE, single_paths, grid_path)
            epoch_acc = _record_epoch_scores(metric, log_dict)
            logger_test.info(_format_summary(
                '[Test Change Detection Summary]: Test mF1=%.5f \n' % epoch_acc, log_dict))
        logger.info('Testing finished...')