import os
from argparse import Namespace
from typing import Any, Dict, List, Optional, Union

from accelerate import Accelerator, load_checkpoint_and_dispatch
from omegaconf import DictConfig, OmegaConf
from torch.nn import Module

from src.options import Options


def load_ckpt(
    ckpt_dir: str,
    ckpt_iter: int,
    hdfs_dir: Optional[str] = None,
    model: Optional[Module] = None,
    accelerator: Optional[Accelerator] = None,
    strict: bool = True,
    use_ema: bool = False,
) -> Union[Module, int]:
    """Load a checkpoint into `model`, fetching it from HDFS if it is not on disk.

    If `model` is None, only resolve and return the checkpoint iteration number.
    """
    # Find the latest checkpoint iteration if none is specified
    if ckpt_iter < 0:
        if hdfs_dir is not None:
            ckpt_iter = int(sorted(get_hdfs_files(hdfs_dir))[-1].split(".")[0])
        else:
            ckpt_iter = int(sorted(os.listdir(ckpt_dir))[-1])

    # Download and unpack the checkpoint if it is not available locally
    ckpt_path = f"{ckpt_dir}/{ckpt_iter:06d}"
    if not os.path.exists(ckpt_path):
        assert hdfs_dir is not None, "checkpoint not found locally and no hdfs_dir given"
        if accelerator is not None:
            if accelerator.is_main_process:
                ensure_sysrun(f"mkdir -p {ckpt_dir}")
                ensure_sysrun(f"hdfs dfs -get {hdfs_dir}/{ckpt_iter:06d}.tar {ckpt_dir}")
                ensure_sysrun(f"tar -xvf {ckpt_dir}/{ckpt_iter:06d}.tar -C {ckpt_dir}")
                ensure_sysrun(f"rm {ckpt_dir}/{ckpt_iter:06d}.tar")
            accelerator.wait_for_everyone()  # wait until the main process has prepared the checkpoint
        else:
            ensure_sysrun(f"mkdir -p {ckpt_dir}")
            ensure_sysrun(f"hdfs dfs -get {hdfs_dir}/{ckpt_iter:06d}.tar {ckpt_dir}")
            ensure_sysrun(f"tar -xvf {ckpt_dir}/{ckpt_iter:06d}.tar -C {ckpt_dir}")
            ensure_sysrun(f"rm {ckpt_dir}/{ckpt_iter:06d}.tar")

    if model is None:
        return ckpt_iter

    # Load the checkpoint into the model
    try:
        load_checkpoint_and_dispatch(model, ckpt_path, strict=strict)
    except Exception:  # DeepSpeed ZeRO checkpoint: consolidate the sharded states into fp32 weights first
        ckpt_dir = f"{ckpt_dir}/{ckpt_iter:06d}"
        if accelerator is not None:
            if accelerator.is_main_process:
                ensure_sysrun(f"python3 {ckpt_dir}/zero_to_fp32.py {ckpt_dir} {ckpt_dir} --safe_serialization")
            accelerator.wait_for_everyone()  # wait until the main process has converted the checkpoint
        else:
            ensure_sysrun(f"python3 {ckpt_dir}/zero_to_fp32.py {ckpt_dir} {ckpt_dir} --safe_serialization")
        load_checkpoint_and_dispatch(model, ckpt_path, strict=strict)
    return model
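
# A minimal usage sketch for load_ckpt; the directories, iteration, and `MyModel` below
# are hypothetical stand-ins for whatever the surrounding training code provides:
#
#   latest_iter = load_ckpt("out/ckpts", ckpt_iter=-1)                        # only resolve the latest iteration
#   model = load_ckpt("out/ckpts", ckpt_iter=latest_iter, model=MyModel())    # load weights into the model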


def save_ckpt(ckpt_dir: str, ckpt_iter: int, hdfs_dir: Optional[str] = None) -> None:
    """Archive a local checkpoint directory, upload it to HDFS, and remove the local copy."""
    if hdfs_dir is not None:
        ensure_sysrun(f"tar -cf {ckpt_dir}/{ckpt_iter:06d}.tar -C {ckpt_dir} {ckpt_iter:06d}")
        ensure_sysrun(f"hdfs dfs -put -f {ckpt_dir}/{ckpt_iter:06d}.tar {hdfs_dir}")
        ensure_sysrun(f"rm -rf {ckpt_dir}/{ckpt_iter:06d}.tar {ckpt_dir}/{ckpt_iter:06d}")


def get_configs(yaml_path: str, cli_configs: List[str] = [], **kwargs) -> DictConfig:
    """Merge a YAML config file with CLI-style overrides and extra keyword arguments."""
    yaml_configs = OmegaConf.load(yaml_path)
    cli_configs = OmegaConf.from_cli(cli_configs)
    configs = OmegaConf.merge(yaml_configs, cli_configs, kwargs)
    OmegaConf.resolve(configs)  # resolve ${...} interpolations
    return configs
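
# A minimal usage sketch for get_configs; the YAML path and the override key below are
# hypothetical and only illustrate that CLI-style entries take precedence over the YAML file:
#
#   configs = get_configs("configs/example.yaml", ["train.batch_size=8"], seed=42)
#   print(configs.train.batch_size)   # 8, regardless of the value in the YAML file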


def save_experiment_params(args: Namespace, configs: DictConfig, opt: Options, save_dir: str) -> Dict[str, Any]:
    """Merge configs, stringified CLI args, and options into one record and save it as `params.yaml`."""
    os.makedirs(save_dir, exist_ok=True)
    params = OmegaConf.merge(configs, {k: str(v) for k, v in vars(args).items()})
    params = OmegaConf.merge(params, OmegaConf.create(vars(opt)))
    OmegaConf.save(params, os.path.join(save_dir, "params.yaml"))
    return dict(params)


def save_model_architecture(model: Module, save_dir: str) -> None:
    """Write parameter/buffer counts and the full module structure to `model.txt`."""
    os.makedirs(save_dir, exist_ok=True)
    num_buffers = sum(b.numel() for b in model.buffers())
    num_params = sum(p.numel() for p in model.parameters())
    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    message = (
        f"Number of buffers: {num_buffers}\n"
        f"Number of trainable / all parameters: {num_trainable_params} / {num_params}\n\n"
        f"Model architecture:\n{model}"
    )
    with open(os.path.join(save_dir, "model.txt"), "w") as f:
        f.write(message)


def ensure_sysrun(cmd: str) -> None:
    """Run a shell command and retry until it exits with status 0."""
    while True:
        result = os.system(cmd)
        if result == 0:
            break
        print(f"Retry running {cmd}")


def get_hdfs_files(hdfs_path: str) -> List[str]:
    """List the file names (without their directory prefix) under an HDFS path."""
    lines = get_hdfs_lines(hdfs_path)
    if len(lines) == 0:
        raise ValueError(f"No files found in {hdfs_path}")
    return [line.split()[-1].split("/")[-1] for line in lines]


def get_hdfs_size(hdfs_path: str, unit: str = "B") -> Union[int, float]:
    """Return the total size of the files under an HDFS path, in B/KB/MB/GB/TB."""
    lines = get_hdfs_lines(hdfs_path)
    if len(lines) == 0:
        raise ValueError(f"No files found in {hdfs_path}")
    byte_size = sum(int(line.split()[4]) for line in lines)  # 5th column of `hdfs dfs -ls` is the size in bytes
    if unit == "B":
        return byte_size
    elif unit == "KB":
        return byte_size / 1024
    elif unit == "MB":
        return byte_size / 1024 ** 2
    elif unit == "GB":
        return byte_size / 1024 ** 3
    elif unit == "TB":
        return byte_size / 1024 ** 4
    else:
        raise ValueError(f"Invalid unit: {unit}")


def get_hdfs_lines(hdfs_path: str) -> List[str]:
    """Run `hdfs dfs -ls` and return its output lines, skipping the leading "Found N items" summary."""
    return os.popen(f"hdfs dfs -ls {hdfs_path}").read().strip().split("\n")[1:]
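

# A minimal end-to-end sketch of how these helpers might be composed in a training entry
# point. Everything here is illustrative: the config path, override, checkpoint directory,
# and `build_model` are hypothetical, not part of this module:
#
#   if __name__ == "__main__":
#       configs = get_configs("configs/example.yaml", ["train.batch_size=8"])
#       model = build_model(configs)
#       model = load_ckpt("out/ckpts", ckpt_iter=-1, model=model)   # resume from the latest checkpoint
#       save_model_architecture(model, "out/logs")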