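# Training entry point for the SE(3)-Transformer on QM9: supports multi-GPU training
# via DistributedDataParallel, mixed precision (AMP) and gradient accumulation.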
import logging
import pathlib
from typing import List

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
    PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
    using_tensor_cores, increase_l2_fetch_granularity


def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only once per node) """
    if get_local_rank() == 0:
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        checkpoint = {
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        for callback in callbacks:
            callback.on_checkpoint_save(checkpoint)

        torch.save(checkpoint, str(path))
        logging.info(f'Saved checkpoint to {str(path)}')


def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Loads model, optimizer and epoch states from path """
    checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
    if isinstance(model, DistributedDataParallel):
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for callback in callbacks:
        callback.on_checkpoint_load(checkpoint)

    logging.info(f'Loaded checkpoint from {str(path)}')
    return checkpoint['epoch']


def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
    losses = []
    for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                         desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)
            loss = loss_fn(pred, target) / args.accumulate_grad_batches

        grad_scaler.scale(loss).backward()

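        # Gradient accumulation: the loss is pre-divided by accumulate_grad_batches, so the
        # optimizer only steps (with optional gradient clipping) once per accumulation window
        # or on the final batch of the epoch.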
        if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
            if args.gradient_clip:
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)

            grad_scaler.step(optimizer)
            grad_scaler.update()
            optimizer.zero_grad()

        losses.append(loss.item())

    return np.mean(losses)


def train(model: nn.Module,
          loss_fn: _Loss,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          callbacks: List[BaseCallback],
          logger: Logger,
          args):
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

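    # In multi-GPU runs, wrap the model in DistributedDataParallel so gradients are
    # synchronized across ranks on every backward pass.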
    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
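    # Optimizer selection: fused Apex Adam/LAMB when requested, plain SGD otherwise.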
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank,
                           callbacks, args)
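        # Average the epoch training loss across all distributed ranks before logging it.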
        if dist.is_initialized():
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

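        # Periodic checkpointing and validation (both skipped in benchmark mode).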
        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)

        if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
            evaluate(model, val_dataloader, callbacks, args)
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)

    for callback in callbacks:
        callback.on_fit_end()


def print_parameters_count(model):
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f'Number of trainable parameters: {num_params_trainable}')


if __name__ == '__main__':
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

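    # Only the local rank 0 process logs at INFO level; other ranks (and silent mode) stay quiet.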
    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|      Training procedure     |')
    logging.info('===============================')

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    logger = LoggerCollection([
        DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
        WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
    ])

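    # Build the data module and model: node/edge feature sizes come from the QM9 data module,
    # and the pooled SE(3)-Transformer regresses a single scalar target per graph.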
    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(args.amp),
        **vars(args)
    )
    loss_fn = nn.L1Loss()

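    # Benchmark mode only measures throughput; otherwise track the validation QM9 metric
    # and drive the learning-rate schedule.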
    if args.benchmark:
        logging.info('Running benchmark mode')
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
    else:
        callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
                     QM9LRSchedulerCallback(logger, epochs=args.epochs)]

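    # Bind each distributed process to the CPU cores associated with its GPU.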
    if is_distributed:
        gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())

    print_parameters_count(model)
    logger.log_hyperparams(vars(args))
    increase_l2_fetch_granularity()
    train(model,
          loss_fn,
          datamodule.train_dataloader(),
          datamodule.val_dataloader(),
          callbacks,
          logger,
          args)

    logging.info('Training finished successfully')