import torch.distributed as dist


def rank0_print(*args):
    """Print only on rank 0 when distributed is initialized; otherwise print normally."""
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(f"Rank {dist.get_rank()}: ", *args)
    else:
        print(*args)


def rank_print(*args):
    """Print on every rank, prefixed with the calling rank; otherwise print normally."""
    if dist.is_initialized():
        print(f"Rank {dist.get_rank()}: ", *args)
    else:
        print(*args)
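
# A minimal usage sketch (an illustrative assumption, not part of the original
# snippet): when a script is launched with e.g. `torchrun --nproc_per_node=2
# script.py`, torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT in the
# environment, so init_process_group can pick them up via the default
# env:// init method. The "gloo" backend is chosen only for portability here.
if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    rank_print("hello")  # printed by every rank, each tagged with its own rank
    rank0_print("world size:", dist.get_world_size())  # printed once, by rank 0
    dist.destroy_process_group()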