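"""Distributed sampling / evaluation script for MeanAudio.

Runs a flow-matching or MeanFlow runner in eval mode over the test
dataloader, generates audio on every rank, and dumps the output metrics
to a JSON file on rank 0.
"""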
import json
import logging
import os
import random

import numpy as np
import torch
import torch.distributed as distributed
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from tqdm import tqdm

from meanaudio.data.data_setup import setup_test_datasets
from meanaudio.runner_flowmatching import RunnerFlowMatching
from meanaudio.runner_meanflow import RunnerMeanFlow
from meanaudio.utils.dist_utils import info_if_rank_zero
from meanaudio.utils.logger import TensorboardLogger

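# LOCAL_RANK / WORLD_SIZE are injected by the distributed launcher
# (e.g. torchrun); running this script without one fails here with a KeyError.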
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])


def sample(cfg: DictConfig):
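    """Generate audio over the test set under DDP, then compute eval metrics.

    Assumes the process group is already initialised (distributed.barrier()
    is called below) and that the rank environment variables are set by the
    launcher.
    """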
    # initial setup
    num_gpus = world_size
    run_dir = HydraConfig.get().run.dir  # ./output/$exp_name

    # wrap python logger with a tensorboard logger
    log = TensorboardLogger(cfg.exp_id,
                            run_dir,
                            logging.getLogger(),
                            is_rank0=(local_rank == 0),
                            enable_email=cfg.enable_email and not cfg.debug)

    info_if_rank_zero(log, f'All configuration: {cfg}')
    info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}')

    # cuda setup
    torch.cuda.set_device(local_rank)
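    # cudnn benchmark autotunes conv algorithms; faster when input shapes are fixed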
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark

    # number of dataloader workers
    info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}')

    # Set seeds to ensure the same initialization
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    # batch size used for evaluation
    info_if_rank_zero(log, f'Batch size (per GPU): {cfg.eval_batch_size}')

    # construct the runner (REPA support is not implemented)
    if cfg.use_repa:
        raise NotImplementedError('REPA is not supported yet')
    if cfg.use_meanflow:
        runner = RunnerMeanFlow(cfg, log=log, run_path=run_dir, for_training=False).enter_val()
    else:
        runner = RunnerFlowMatching(cfg, log=log, run_path=run_dir, for_training=False).enter_val()

    # only the EMA checkpoint is loaded for the final evaluation
    weights = runner.get_final_ema_weight_path()
    if weights is not None:
        info_if_rank_zero(log, f'Automatically found EMA weights: {weights}')
        runner.load_weights(weights)

    # setup datasets
    dataset, sampler, loader = setup_test_datasets(cfg)
    data_cfg = cfg.data.AudioCaps_test_npz  # AudioCaps test-set (npz) data config
    with open_dict(data_cfg):
        if cfg.output_name is not None:
            # append to the tag
            data_cfg.tag = f'{data_cfg.tag}-{cfg.output_name}'

    # generation loop over the test set
    audio_path = None
    for curr_iter, data in enumerate(tqdm(loader)):
        new_audio_path = runner.inference_pass(data, curr_iter, data_cfg)  # generate audio
        if audio_path is None:
            audio_path = new_audio_path
        else:
            assert audio_path == new_audio_path, 'Different audio path detected'

    distributed.barrier()  # wait until all ranks have finished generation
    info_if_rank_zero(log, f'Inference completed. Audio path: {audio_path}')
    output_metrics = runner.eval(audio_path, curr_iter, data_cfg)

    if local_rank == 0:
        # write the output metrics to run_dir
        output_metrics_path = os.path.join(run_dir, f'{data_cfg.tag}-output_metrics.json')
        with open(output_metrics_path, 'w') as f:
            json.dump(output_metrics, f, indent=4)
        print(f"Results saved in {output_metrics_path}")
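
# Example launch (a sketch; the entry script name and the Hydra overrides are
# assumptions, the actual ones depend on this repository's configs):
#   torchrun --standalone --nproc_per_node=2 sample.py exp_id=my_run use_meanflow=True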