import os
import cv2
import sys
import tqdm
import torch
import datetime

import torch.nn as nn
import torch.distributed as dist
import torch.cuda as cuda

from torch.utils.data.dataloader import DataLoader
from torch.optim import Adam, SGD
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp.grad_scaler import GradScaler
from torch.cuda.amp.autocast_mode import autocast

filepath = os.path.split(os.path.abspath(__file__))[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)

from lib import *
from lib.optim import *
from data.dataloader import *
from utils.misc import *

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


def train(opt, args):
    # Build the training dataset; the dataset class is resolved by name from the config.
    train_dataset = eval(opt.Train.Dataset.type)(
        root=opt.Train.Dataset.root,
        sets=opt.Train.Dataset.sets,
        tfs=opt.Train.Dataset.transforms)

    # Multi-GPU: initialize the NCCL process group and shard the data with a DistributedSampler.
    if args.device_num > 1:
        cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl',
                                rank=args.local_rank,
                                world_size=args.device_num,
                                timeout=datetime.timedelta(seconds=3600))
        train_sampler = DistributedSampler(train_dataset, shuffle=True)
    else:
        train_sampler = None

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.Train.Dataloader.batch_size,
                              shuffle=train_sampler is None,
                              sampler=train_sampler,
                              num_workers=opt.Train.Dataloader.num_workers,
                              pin_memory=opt.Train.Dataloader.pin_memory,
                              drop_last=True)

    # Optionally resume model weights and optimizer/scheduler state from the checkpoint directory.
    model_ckpt = None
    state_ckpt = None

    if args.resume is True:
        if os.path.isfile(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth')):
            model_ckpt = torch.load(os.path.join(
                opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'), map_location='cpu')
            if args.local_rank <= 0:
                print('Resume from checkpoint')
        if os.path.isfile(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth')):
            state_ckpt = torch.load(os.path.join(
                opt.Train.Checkpoint.checkpoint_dir, 'state.pth'), map_location='cpu')
            if args.local_rank <= 0:
                print('Resume from state')

    # Build the model (class resolved by name) and load resumed weights if available.
    model = eval(opt.Model.name)(**opt.Model)
    if model_ckpt is not None:
        model.load_state_dict(model_ckpt)

    if args.device_num > 1:
        # Convert BatchNorm to SyncBatchNorm and wrap the model with DistributedDataParallel.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    else:
        model = model.cuda()

    # Split parameters into backbone and decoder groups; the decoder uses a 10x larger learning rate.
    backbone_params = nn.ParameterList()
    decoder_params = nn.ParameterList()

    for name, param in model.named_parameters():
        if 'backbone' in name:
            backbone_params.append(param)
        else:
            decoder_params.append(param)

    params_list = [{'params': backbone_params},
                   {'params': decoder_params, 'lr': opt.Train.Optimizer.lr * 10}]

    optimizer = eval(opt.Train.Optimizer.type)(
        params_list, opt.Train.Optimizer.lr, weight_decay=opt.Train.Optimizer.weight_decay)

    if state_ckpt is not None:
        optimizer.load_state_dict(state_ckpt['optimizer'])

    # Gradient scaler for mixed-precision training.
    if opt.Train.Optimizer.mixed_precision is True:
        scaler = GradScaler()
    else:
        scaler = None

    scheduler = eval(opt.Train.Scheduler.type)(optimizer,
                                               gamma=opt.Train.Scheduler.gamma,
                                               minimum_lr=opt.Train.Scheduler.minimum_lr,
                                               max_iteration=len(train_loader) * opt.Train.Scheduler.epoch,
                                               warmup_iteration=opt.Train.Scheduler.warmup_iteration)

    if state_ckpt is not None:
        scheduler.load_state_dict(state_ckpt['scheduler'])

    model.train()

    # Resume from the saved epoch if a training state was loaded.
    start = 1
    if state_ckpt is not None:
        start = state_ckpt['epoch']

    epoch_iter = range(start, opt.Train.Scheduler.epoch + 1)
    if args.local_rank <= 0 and args.verbose is True:
        epoch_iter = tqdm.tqdm(epoch_iter, desc='Epoch', total=opt.Train.Scheduler.epoch,
                               initial=start - 1, position=0,
                               bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:40}{r_bar}')
    for epoch in epoch_iter:
        # Reseed the DistributedSampler every epoch on every rank so shuffling stays in sync.
        if args.device_num > 1 and train_sampler is not None:
            train_sampler.set_epoch(epoch)

        if args.local_rank <= 0 and args.verbose is True:
            step_iter = tqdm.tqdm(enumerate(train_loader, start=1), desc='Iter',
                                  total=len(train_loader), position=1, leave=False,
                                  bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:40}{r_bar}')
        else:
            step_iter = enumerate(train_loader, start=1)

        for i, sample in step_iter:
            optimizer.zero_grad()
            if opt.Train.Optimizer.mixed_precision is True and scaler is not None:
                # Mixed-precision path: forward under autocast, scale the loss for backward.
                with autocast():
                    sample = to_cuda(sample)
                    out = model(sample)

                scaler.scale(out['loss']).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
            else:
                # Full-precision path.
                sample = to_cuda(sample)
                out = model(sample)

                out['loss'].backward()
                optimizer.step()
                scheduler.step()

            if args.local_rank <= 0 and args.verbose is True:
                step_iter.set_postfix({'loss': out['loss'].item()})

        # Periodic checkpointing and debug visualization (main process only).
        if args.local_rank <= 0:
            os.makedirs(opt.Train.Checkpoint.checkpoint_dir, exist_ok=True)
            os.makedirs(os.path.join(
                opt.Train.Checkpoint.checkpoint_dir, 'debug'), exist_ok=True)

            if epoch % opt.Train.Checkpoint.checkpoint_epoch == 0:
                if args.device_num > 1:
                    model_ckpt = model.module.state_dict()
                else:
                    model_ckpt = model.state_dict()

                state_ckpt = {'epoch': epoch + 1,
                              'optimizer': optimizer.state_dict(),
                              'scheduler': scheduler.state_dict()}

                torch.save(model_ckpt, os.path.join(
                    opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))
                torch.save(state_ckpt, os.path.join(
                    opt.Train.Checkpoint.checkpoint_dir, 'state.pth'))

            if args.debug is True:
                debout = debug_tile(sum([out[k] for k in opt.Train.Debug.keys], []),
                                    activation=torch.sigmoid)
                cv2.imwrite(os.path.join(opt.Train.Checkpoint.checkpoint_dir,
                                         'debug', str(epoch) + '.png'), debout)

    # Save the final weights once training is complete (main process only).
    if args.local_rank <= 0:
        torch.save(model.module.state_dict() if args.device_num > 1 else model.state_dict(),
                   os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))


if __name__ == '__main__':
    args = parse_args()
    opt = load_config(args.config)
    train(opt, args)
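
# Hypothetical launch examples (not part of this script). The actual CLI flags are defined
# by parse_args() in utils/misc; the flag names below are only assumed to mirror the args
# attributes used above (config, verbose, device_num, local_rank), and the config path is a
# placeholder.
#
#   Single GPU:
#     python Train.py --config configs/example.yaml --verbose
#
#   Multi GPU (one process per GPU; the launcher passes --local_rank to each process):
#     python -m torch.distributed.launch --nproc_per_node=4 Train.py \
#         --config configs/example.yaml --device_num 4 --verbose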