File size: 516 Bytes
3a1da90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import os
from logging import Logger
import torch.distributed as dist
from meanaudio.utils.logger import TensorboardLogger

# Process indices read once from the environment at import time.
# NOTE(review): presumably populated by a distributed launcher such as
# torchrun — confirm. LOCAL_RANK falls back to 0 and WORLD_SIZE to 1 when the
# variables are absent.
local_rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))

def info_if_rank_zero(logger: Logger, msg: str):
    """Log ``msg`` at INFO level through ``logger``, only when ``local_rank`` is 0.

    Every other process returns without logging, so a message issued by all
    workers is emitted once per process whose ``local_rank`` equals 0.

    NOTE(review): ``local_rank`` comes from the LOCAL_RANK environment
    variable. If that is a per-node index (as under torchrun), multi-node runs
    will log once per node — confirm this is intended.
    """
    if local_rank != 0:
        return
    logger.info(msg)

def string_if_rank_zero(logger: TensorboardLogger, tag: str, msg: str):
    """Forward ``(tag, msg)`` to ``logger.log_string``, only when ``local_rank`` is 0.

    Processes with a non-zero ``local_rank`` return without calling the logger.

    NOTE(review): the gate uses LOCAL_RANK rather than a global rank. If that
    index is per-node, each node's rank-0 process will write — confirm.
    """
    if local_rank != 0:
        return
    logger.log_string(tag, msg)