# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
import os
from functools import wraps
from logging import Logger
from typing import Any, Callable, Optional

import torch
import torch.distributed as dist

def get_rank() -> int:
    """Return this process's global rank from the environment, or 0 if unset."""
    rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0
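
# A minimal usage sketch, assuming a torchrun- or SLURM-style launcher that
# exports one of the rank variables per process (values are illustrative):
#   RANK=3 python train.py         -> get_rank() == 3
#   SLURM_PROCID=1 srun train.py   -> get_rank() == 1
#   no rank variable set           -> get_rank() == 0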

def rank_zero_only(fn: Callable) -> Callable:
    """Decorator that runs `fn` only on rank 0; all other ranks return None."""

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped_fn


# Cache the rank once at import time so decorated calls stay cheap.
rank_zero_only.rank = getattr(rank_zero_only, "rank", get_rank())
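
# Usage sketch (the decorated function below is hypothetical, not part of
# this module):
#
#   @rank_zero_only
#   def save_checkpoint(path: str) -> None:
#       ...  # only the rank-0 process runs this; other ranks get None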

@rank_zero_only
def rank_zero_print(message: str, *args, **kwargs) -> None:  # pylint: disable=unused-argument
    """Print `message` on rank 0 only."""
    print(message)

@rank_zero_only
def rank_zero_logger_info(
    message: str, logger: Logger, *args, **kwargs
) -> None:  # pylint: disable=unused-argument
    """Log `message` at INFO level on rank 0 only."""
    logger.info(message)

def reduce_tensor(tensor: torch.Tensor, num_gpus: int) -> torch.Tensor:
    """All-reduce `tensor` across processes and average over `num_gpus`."""
    rt = tensor.clone()
    # `dist.reduce_op` was removed in newer PyTorch; `dist.ReduceOp` is current.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt
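
# Usage sketch, assuming the process group is already initialized and `loss`
# is a CUDA tensor that exists on every rank (names are illustrative):
#
#   avg_loss = reduce_tensor(loss.detach(), dist.get_world_size())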

def init_distributed(
    rank: int, num_gpus: int, group_name: str, dist_backend: str, dist_url: str
) -> None:
    """Initialize the process group and bind this process to its GPU."""
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    # Set the CUDA device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())
    # Initialize distributed communication.
    dist.init_process_group(
        dist_backend,
        init_method=dist_url,
        world_size=num_gpus,
        rank=rank,
        group_name=group_name,
    )
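
# Minimal launch sketch (illustrative, not part of the original file): spawn
# one worker per GPU and call init_distributed() in each. The NCCL backend,
# file:// rendezvous path, and group name are arbitrary example values.
#
#   import torch.multiprocessing as mp
#
#   def worker(rank: int, num_gpus: int) -> None:
#       init_distributed(rank, num_gpus, "example_group", "nccl",
#                        "file:///tmp/example_init")
#
#   if __name__ == "__main__":
#       n = torch.cuda.device_count()
#       mp.spawn(worker, args=(n,), nprocs=n)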