import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import lr_scheduler
from typing import Iterable, Set
import logging
import os
logger = logging.getLogger('base')

def simplex(t: Tensor, axis=1) -> bool:
    """Check that `t` sums to 1 along `axis` (i.e. values form a probability simplex)."""
    _sum = t.sum(axis).float()
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones)

def one_hot(t: Tensor, axis=1) -> bool:
    """Check that `t` is one-hot along `axis`: binary entries that sum to 1."""
    return simplex(t, axis) and sset(t, [0, 1])

def uniq(a: Tensor) -> Set:
    """Return the set of unique values in `a`."""
    return set(torch.unique(a.cpu()).numpy())

def sset(a: Tensor, sub: Iterable) -> bool:
    """Check that every value in `a` belongs to `sub`."""
    return uniq(a).issubset(sub)
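# Quick check of the validators above (illustrative values, not from the original file):
#   >>> probs = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
#   >>> simplex(probs)                          # channel probabilities sum to 1
#   True
#   >>> sset(torch.tensor([0, 1, 1]), [0, 1])   # values drawn from {0, 1}
#   True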
def class2one_hot(seg: Tensor, C: int) -> Tensor:
    """Convert an integer label map into a one-hot encoding with C channels."""
    if len(seg.shape) == 2:  # (H, W) case: add a batch dimension
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C))), "Class indices in the input tensor are out of range!"
    if seg.ndim == 4:
        seg = seg.squeeze(dim=1)
    b, w, h = seg.shape  # batch size, width, height
    res = torch.stack([seg == c for c in range(C)], dim=1).int()
    assert res.shape == (b, C, w, h)
    assert one_hot(res), "Converted tensor is not one-hot encoded!"
    return res
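# Usage sketch for class2one_hot (shapes are illustrative):
#   >>> seg = torch.randint(0, 3, (2, 64, 64))  # (B, H, W) integer label map
#   >>> oh = class2one_hot(seg, C=3)            # (B, 3, 64, 64), int tensor
#   >>> one_hot(oh)
#   True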
def get_scheduler(optimizer, args):
    """Return a learning-rate scheduler configured from `args`."""
    if args['scheduler']['lr_policy'] == 'linear':
        def lambda_rule(epoch):
            # Decay the LR linearly from 1.0 towards 0 over n_epoch epochs
            return 1.0 - epoch / float(args['n_epoch'] + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif args['scheduler']['lr_policy'] == 'step':
        step_size = args['n_epoch'] // args['scheduler']['n_steps']
        return lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=args['scheduler']['gamma'])
    else:
        raise NotImplementedError(f"Learning-rate policy [{args['scheduler']['lr_policy']}] is not implemented!")
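# Usage sketch for get_scheduler (the args layout is inferred from the keys read above):
#   >>> model = nn.Linear(4, 2)
#   >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   >>> args = {'n_epoch': 100, 'scheduler': {'lr_policy': 'step', 'n_steps': 4, 'gamma': 0.5}}
#   >>> scheduler = get_scheduler(optimizer, args)
#   >>> scheduler.step()  # call once per epoch, after optimizer.step()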
def save_network(opt, epoch, cd_model, optimizer, is_best_model=False):
    """Save the model and optimizer state for the current epoch."""
    os.makedirs(opt['path_cd']['checkpoint'], exist_ok=True)
    cd_gen_path = os.path.join(opt['path_cd']['checkpoint'], f'cd_model_E{epoch}_gen.pth')
    cd_opt_path = os.path.join(opt['path_cd']['checkpoint'], f'cd_model_E{epoch}_opt.pth')
    best_cd_gen_path = os.path.join(opt['path_cd']['checkpoint'], 'best_cd_model_gen.pth')
    best_cd_opt_path = os.path.join(opt['path_cd']['checkpoint'], 'best_cd_model_opt.pth')
    # Unwrap DataParallel so the saved keys carry no 'module.' prefix
    network = cd_model.module if isinstance(cd_model, nn.DataParallel) else cd_model
    # Move parameters to CPU so the checkpoint loads on any device
    state_dict = {key: param.cpu() for key, param in network.state_dict().items()}
    torch.save(state_dict, cd_gen_path)
    if is_best_model:
        torch.save(state_dict, best_cd_gen_path)
    opt_state = {
        'epoch': epoch,
        'scheduler': None,  # scheduler state is not tracked here
        'optimizer': optimizer.state_dict()
    }
    torch.save(opt_state, cd_opt_path)
    if is_best_model:
        torch.save(opt_state, best_cd_opt_path)
    logger.info(f'✅ Current model saved to [{cd_gen_path}]')
    if is_best_model:
        logger.info(f'🏆 Best model updated at [{best_cd_gen_path}]')
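
# Loading counterpart: a minimal sketch, not part of the original file. It assumes
# the same `opt` layout and the checkpoint filenames written by save_network above.
def load_network(opt, epoch, cd_model, optimizer):
    """Restore model and optimizer state saved by save_network (hypothetical helper)."""
    ckpt_dir = opt['path_cd']['checkpoint']
    gen_path = os.path.join(ckpt_dir, f'cd_model_E{epoch}_gen.pth')
    opt_path = os.path.join(ckpt_dir, f'cd_model_E{epoch}_opt.pth')
    network = cd_model.module if isinstance(cd_model, nn.DataParallel) else cd_model
    network.load_state_dict(torch.load(gen_path, map_location='cpu'))
    opt_state = torch.load(opt_path, map_location='cpu')
    optimizer.load_state_dict(opt_state['optimizer'])
    return opt_state['epoch']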