import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import lr_scheduler
from typing import Iterable, Set

import logging
import os

logger = logging.getLogger('base')

def simplex(t: Tensor, axis=1) -> bool:
    """Check that `t` sums to one along `axis`, i.e. lies on the probability simplex."""
    _sum = t.sum(axis).float()
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones)

def one_hot(t: Tensor, axis=1) -> bool:
    """Check that `t` is one-hot along `axis`: it sums to one and contains only 0s and 1s."""
    return simplex(t, axis) and sset(t, [0, 1])

def uniq(a: Tensor) -> Set:
    """Return the set of unique values in `a`."""
    return set(torch.unique(a.cpu()).numpy())

def sset(a: Tensor, sub: Iterable) -> bool:
    """Check that every value in `a` belongs to `sub`."""
    return uniq(a).issubset(sub)
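
# A minimal usage sketch of the validation helpers above (the tensor values
# are illustrative only, not from the original codebase):
#
#   probs = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
#   assert simplex(probs)             # softmax output sums to 1 along dim 1
#   labels = torch.tensor([[0, 1], [2, 1]])
#   assert sset(labels, range(3))     # all values fall in {0, 1, 2}
#   assert uniq(labels) == {0, 1, 2}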

def class2one_hot(seg: Tensor, C: int) -> Tensor:
    """Convert an integer label map into a one-hot encoding with C channels."""
    if len(seg.shape) == 2:
        # Add a batch dimension: (H, W) -> (1, H, W)
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C))), "Class index in the input tensor is out of range!"

    if seg.ndim == 4:
        # Drop a singleton channel dimension: (B, 1, H, W) -> (B, H, W)
        seg = seg.squeeze(dim=1)

    b, w, h = seg.shape
    res = torch.stack([seg == c for c in range(C)], dim=1).int()
    assert res.shape == (b, C, w, h)
    assert one_hot(res), "The converted tensor is not one-hot encoded!"

    return res
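
# A minimal sketch of class2one_hot on a toy 2x2 label map with C=3 classes
# (values below are illustrative only):
#
#   seg = torch.tensor([[0, 1], [2, 1]])   # (H, W) label map
#   oh = class2one_hot(seg, C=3)           # -> shape (1, 3, 2, 2)
#   assert oh[0, 1, 0, 1] == 1             # pixel (0, 1) belongs to class 1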

def get_scheduler(optimizer, args):
    """Return a learning-rate scheduler configured by `args`."""
    if args['scheduler']['lr_policy'] == 'linear':
        def lambda_rule(epoch):
            # Decay the LR linearly from 1.0 towards 0 over n_epoch epochs.
            return 1.0 - epoch / float(args['n_epoch'] + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

    elif args['scheduler']['lr_policy'] == 'step':
        step_size = args['n_epoch'] // args['scheduler']['n_steps']
        return lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=args['scheduler']['gamma'])

    else:
        raise NotImplementedError(f"Learning rate policy [{args['scheduler']['lr_policy']}] is not implemented!")
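
# A minimal sketch of get_scheduler; the args dict below is a hypothetical
# stand-in for the project's real configuration:
#
#   model = nn.Linear(4, 2)
#   optim = torch.optim.SGD(model.parameters(), lr=0.01)
#   args = {'n_epoch': 100,
#           'scheduler': {'lr_policy': 'step', 'n_steps': 4, 'gamma': 0.5}}
#   sched = get_scheduler(optim, args)   # StepLR: halve the LR every 25 epochs
#   for epoch in range(args['n_epoch']):
#       ...                              # train one epoch
#       sched.step()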

def save_network(opt, epoch, cd_model, optimizer, is_best_model=False):
    """Save the model and optimizer state for the current epoch."""
    os.makedirs(opt['path_cd']['checkpoint'], exist_ok=True)

    cd_gen_path = os.path.join(opt['path_cd']['checkpoint'], f'cd_model_E{epoch}_gen.pth')
    cd_opt_path = os.path.join(opt['path_cd']['checkpoint'], f'cd_model_E{epoch}_opt.pth')

    best_cd_gen_path = os.path.join(opt['path_cd']['checkpoint'], 'best_cd_model_gen.pth')
    best_cd_opt_path = os.path.join(opt['path_cd']['checkpoint'], 'best_cd_model_opt.pth')

    # Unwrap DataParallel and move parameters to CPU so the checkpoint loads anywhere.
    network = cd_model.module if isinstance(cd_model, nn.DataParallel) else cd_model
    state_dict = {key: param.cpu() for key, param in network.state_dict().items()}

    torch.save(state_dict, cd_gen_path)
    if is_best_model:
        torch.save(state_dict, best_cd_gen_path)

    opt_state = {
        'epoch': epoch,
        'scheduler': None,
        'optimizer': optimizer.state_dict()
    }
    torch.save(opt_state, cd_opt_path)
    if is_best_model:
        torch.save(opt_state, best_cd_opt_path)

    logger.info(f'✅ Current model saved to [{cd_gen_path}]')
    if is_best_model:
        logger.info(f'🏆 Best model updated at [{best_cd_gen_path}]')
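
# A minimal sketch of save_network; the opt dict below is a hypothetical
# stand-in for the project's real configuration:
#
#   opt = {'path_cd': {'checkpoint': './checkpoints'}}
#   save_network(opt, epoch=10, cd_model=model, optimizer=optim, is_best_model=True)
#   # -> ./checkpoints/cd_model_E10_gen.pth and cd_model_E10_opt.pth,
#   #    plus best_cd_model_gen.pth / best_cd_model_opt.pth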