File size: 5,339 Bytes
476e0f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from typing import *
from argparse import Namespace
from omegaconf import DictConfig
from torch.nn import Module
from accelerate import Accelerator
from src.options import Options

import os
from omegaconf import OmegaConf
from accelerate import load_checkpoint_and_dispatch


def _run_on_main_process(accelerator: Optional[Accelerator], fn: Callable[[], None]) -> None:
    """Run `fn` once on the main process (all ranks barrier after), or directly when no accelerator."""
    if accelerator is not None:
        if accelerator.is_main_process:
            fn()
        accelerator.wait_for_everyone()  # other ranks wait until the main process has finished
    else:
        fn()


def _download_and_unpack_ckpt(ckpt_dir: str, ckpt_iter: int, hdfs_dir: str) -> None:
    """Fetch `<ckpt_iter>.tar` from HDFS into `ckpt_dir`, extract it, and remove the tarball."""
    ensure_sysrun(f"mkdir -p {ckpt_dir}")
    ensure_sysrun(f"hdfs dfs -get {hdfs_dir}/{ckpt_iter:06d}.tar {ckpt_dir}")
    ensure_sysrun(f"tar -xvf {ckpt_dir}/{ckpt_iter:06d}.tar -C {ckpt_dir}")
    ensure_sysrun(f"rm {ckpt_dir}/{ckpt_iter:06d}.tar")


def load_ckpt(
    ckpt_dir: str, ckpt_iter: int,
    hdfs_dir: Optional[str] = None,
    model: Optional[Module] = None,
    accelerator: Optional[Accelerator] = None,
    strict: bool = True,
    use_ema: bool = False,
) -> Union[Module, int]:
    """Resolve (and if needed download) a checkpoint, optionally loading it into `model`.

    Args:
        ckpt_dir: local directory holding `<iter:06d>` checkpoint subdirectories.
        ckpt_iter: checkpoint iteration to load; negative means "latest available".
        hdfs_dir: optional HDFS directory to fetch the checkpoint tar from.
        model: module to load weights into; when None, only the resolved iteration
            number is returned and no weights are loaded.
        accelerator: when given, download/conversion runs on the main process only.
        strict: forwarded to `load_checkpoint_and_dispatch`.
        use_ema: currently unused — kept for interface compatibility.
            NOTE(review): no EMA branch exists in this function; confirm intent.

    Returns:
        The resolved checkpoint iteration (int) when `model` is None,
        otherwise `model` with the checkpoint loaded.
    """
    # Resolve "latest" (negative) to a concrete iteration number. Directory /
    # file names are zero-padded, so lexicographic sort matches numeric order.
    if ckpt_iter < 0:
        if hdfs_dir is not None:
            ckpt_iter = int(sorted(get_hdfs_files(hdfs_dir))[-1].split(".")[0])
        else:
            ckpt_iter = int(sorted(os.listdir(ckpt_dir))[-1])

    # Download and unpack the checkpoint if it is not present locally.
    ckpt_path = f"{ckpt_dir}/{ckpt_iter:06d}"
    if not os.path.exists(ckpt_path):
        if hdfs_dir is None:
            raise FileNotFoundError(
                f"Checkpoint {ckpt_path} not found locally and no hdfs_dir given"
            )
        _run_on_main_process(
            accelerator, lambda: _download_and_unpack_ckpt(ckpt_dir, ckpt_iter, hdfs_dir)
        )

    if model is None:
        return ckpt_iter

    # Load checkpoint into the model.
    try:
        load_checkpoint_and_dispatch(model, ckpt_path, strict=strict)
    except Exception:  # DeepSpeed sharded checkpoint: convert to fp32 weights first
        _run_on_main_process(
            accelerator,
            lambda: ensure_sysrun(
                f"python3 {ckpt_path}/zero_to_fp32.py {ckpt_path} {ckpt_path} --safe_serialization"
            ),
        )
        load_checkpoint_and_dispatch(model, ckpt_path, strict=strict)

    return model


def save_ckpt(ckpt_dir: str, ckpt_iter: int, hdfs_dir: Optional[str] = None):
    """Tar checkpoint `ckpt_iter` from `ckpt_dir`, upload it to `hdfs_dir`, and clean up locally.

    A no-op when `hdfs_dir` is None (nothing is archived or deleted).
    """
    if hdfs_dir is None:
        return
    tag = f"{ckpt_iter:06d}"
    # Archive the checkpoint directory, push the tar to HDFS (-f overwrites),
    # then remove both the tarball and the local checkpoint directory.
    ensure_sysrun(f"tar -cf {ckpt_dir}/{tag}.tar -C {ckpt_dir} {tag}")
    ensure_sysrun(f"hdfs dfs -put -f {ckpt_dir}/{tag}.tar {hdfs_dir}")
    ensure_sysrun(f"rm -rf {ckpt_dir}/{tag}.tar {ckpt_dir}/{tag}")


def get_configs(yaml_path: str, cli_configs: Optional[List[str]] = None, **kwargs) -> DictConfig:
    """Load a YAML config, overlay CLI dotlist overrides and keyword overrides, and resolve it.

    Args:
        yaml_path: path to the base YAML configuration file.
        cli_configs: optional dotlist-style overrides (e.g. ["a.b=1"]); the
            previous mutable-default `[]` is replaced by the None sentinel.
        **kwargs: additional key/value overrides merged last.

    Returns:
        The merged `DictConfig` with all `${...}` interpolations resolved.
    """
    yaml_configs = OmegaConf.load(yaml_path)
    cli_overrides = OmegaConf.from_cli(cli_configs if cli_configs is not None else [])

    configs = OmegaConf.merge(yaml_configs, cli_overrides, kwargs)
    OmegaConf.resolve(configs)  # resolve ${...} placeholders
    return configs


def save_experiment_params(args: Namespace, configs: DictConfig, opt: Options, save_dir: str) -> Dict[str, Any]:
    """Merge argparse args, configs, and options into one record; save it as params.yaml.

    Returns the merged parameters as a plain dict.
    """
    os.makedirs(save_dir, exist_ok=True)

    # argparse values may not be YAML-serializable, so stringify them first.
    stringified_args = {key: str(value) for key, value in vars(args).items()}
    merged = OmegaConf.merge(configs, stringified_args)
    merged = OmegaConf.merge(merged, OmegaConf.create(vars(opt)))

    OmegaConf.save(merged, os.path.join(save_dir, "params.yaml"))
    return dict(merged)


def save_model_architecture(model: Module, save_dir: str) -> None:
    """Write buffer/parameter counts and the model's repr to `save_dir`/model.txt."""
    os.makedirs(save_dir, exist_ok=True)

    buffer_count = sum(buf.numel() for buf in model.buffers())
    param_count = sum(par.numel() for par in model.parameters())
    trainable_count = sum(par.numel() for par in model.parameters() if par.requires_grad)

    summary = (
        f"Number of buffers: {buffer_count}\n"
        f"Number of trainable / all parameters: {trainable_count} / {param_count}\n\n"
        f"Model architecture:\n{model}"
    )

    with open(os.path.join(save_dir, "model.txt"), "w") as f:
        f.write(summary)


def ensure_sysrun(cmd: str) -> None:
    """Run `cmd` through the shell, retrying (forever, without backoff) until it exits 0.

    NOTE(review): `cmd` is interpolated into a shell command unescaped —
    callers must only pass trusted strings (shell-injection risk otherwise).
    """
    while os.system(cmd) != 0:
        print(f"Retry running {cmd}")


def get_hdfs_files(hdfs_path: str) -> List[str]:
    """Return the basenames of every entry listed under `hdfs_path` on HDFS.

    Raises:
        ValueError: if the listing comes back empty.
    """
    listing = get_hdfs_lines(hdfs_path)
    if not listing:
        raise ValueError(f"No files found in {hdfs_path}")

    basenames = []
    for entry in listing:
        full_path = entry.split()[-1]  # the last whitespace-separated field is the path
        basenames.append(full_path.split("/")[-1])
    return basenames


def get_hdfs_size(hdfs_path: str, unit: str = "B") -> Union[int, float]:
    """Return the total size of all entries under `hdfs_path` in the requested unit.

    Args:
        hdfs_path: HDFS path to measure.
        unit: one of "B", "KB", "MB", "GB", "TB" (binary, i.e. 1024-based).

    Returns:
        An int for unit="B"; otherwise a float. (The original annotation said
        `int`, but every non-"B" branch performs true division and returns float.)

    Raises:
        ValueError: if the listing is empty or `unit` is not recognized.
    """
    lines = get_hdfs_lines(hdfs_path)
    if len(lines) == 0:
        raise ValueError(f"No files found in {hdfs_path}")

    # `hdfs dfs -ls` places the byte size in the 5th whitespace-separated column.
    byte_size = sum(int(line.split()[4]) for line in lines)

    divisors = {"B": 1, "KB": 1024, "MB": 1024 ** 2, "GB": 1024 ** 3, "TB": 1024 ** 4}
    if unit not in divisors:
        raise ValueError(f"Invalid unit: {unit}")
    # Preserve the original int return for bytes; true-divide for larger units.
    return byte_size if unit == "B" else byte_size / divisors[unit]


def get_hdfs_lines(hdfs_path: str) -> List[str]:
    """Run `hdfs dfs -ls` on `hdfs_path` and return its output lines, minus the header line."""
    output = os.popen(f"hdfs dfs -ls {hdfs_path}").read()
    return output.strip().split("\n")[1:]