import os
import sys
import argparse
import multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel

import pointcept.utils.comm as comm
from pointcept.utils.env import get_random_seed, set_seed
from pointcept.utils.config import Config, DictAction


def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
    """
    if comm.get_world_size() == 1:
        return model
    # kwargs['find_unused_parameters'] = True
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    if "output_device" not in kwargs:
        kwargs["output_device"] = comm.get_local_rank()
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Worker init func for dataloader.

    The seed of each worker equals num_workers * rank + worker_id + user_seed.

    Args:
        worker_id (int): Worker id.
        num_workers (int): Number of workers.
        rank (int): The rank of the current process.
        seed (int): The random seed to use.
    """
    worker_seed = num_workers * rank + worker_id + seed
    set_seed(worker_seed)


def default_argument_parser(epilog=None):
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
    )
    parser.add_argument(
        "--num-machines", type=int, default=1, help="total number of machines"
    )
    parser.add_argument(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        # default="tcp://127.0.0.1:{}".format(port),
        default="auto",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "--options", nargs="+", action=DictAction, help="custom options"
    )
    return parser
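
# --------------------------------------------------------------------------- #
# Illustrative usage sketch (not part of the original helpers): shows how
# `worker_init_fn` is usually bound with functools.partial so the DataLoader
# only has to supply `worker_id`, and how `create_ddp_model` wraps a model.
# `cfg`, `dataset`, and `model` are assumed to come from the project's own
# builders; the function name below is hypothetical.
# --------------------------------------------------------------------------- #
def example_build_loader_and_model(cfg, dataset, model):
    import functools

    from torch.utils.data import DataLoader

    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size_per_gpu,
        num_workers=cfg.num_worker_per_gpu,
        shuffle=True,
        # Bind the static arguments; the DataLoader passes worker_id itself.
        worker_init_fn=functools.partial(
            worker_init_fn,
            num_workers=cfg.num_worker_per_gpu,
            rank=comm.get_rank(),
            seed=cfg.seed,
        ),
    )
    # Wraps with DistributedDataParallel only when more than one process runs.
    model = create_ddp_model(model.cuda(), fp16_compression=False)
    return loader, model
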
See " "https://pytorch.org/docs/stable/distributed.html for details.", ) parser.add_argument( "--options", nargs="+", action=DictAction, help="custom options" ) return parser def default_config_parser(file_path, options): # config name protocol: dataset_name/model_name-exp_name if os.path.isfile(file_path): cfg = Config.fromfile(file_path) else: sep = file_path.find("-") cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :])) if options is not None: cfg.merge_from_dict(options) if cfg.seed is None: cfg.seed = get_random_seed() cfg.data.train.loop = cfg.epoch // cfg.eval_epoch os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True) if not cfg.resume: cfg.dump(os.path.join(cfg.save_path, "config.py")) return cfg def default_setup(cfg): # scalar by world size world_size = comm.get_world_size() cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count() cfg.num_worker_per_gpu = cfg.num_worker // world_size assert cfg.batch_size % world_size == 0 assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0 assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0 cfg.batch_size_per_gpu = cfg.batch_size // world_size cfg.batch_size_val_per_gpu = ( cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1 ) cfg.batch_size_test_per_gpu = ( cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1 ) # update data loop assert cfg.epoch % cfg.eval_epoch == 0 # settle random seed rank = comm.get_rank() seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank set_seed(seed) return cfg