import os
from argparse import Namespace
from typing import Any, Dict, List, Optional, Union

from accelerate import Accelerator, load_checkpoint_and_dispatch
from omegaconf import DictConfig, OmegaConf
from torch.nn import Module

from src.options import Options


def load_ckpt(
    ckpt_dir: str,
    ckpt_iter: int,
    hdfs_dir: Optional[str] = None,
    model: Optional[Module] = None,
    accelerator: Optional[Accelerator] = None,
    strict: bool = True,
    use_ema: bool = False,  # accepted for API compatibility; not used here
) -> Union[Module, int]:
    # Find the latest checkpoint; names are zero-padded, so a lexicographic
    # sort is also a numeric sort
    if ckpt_iter < 0:
        if hdfs_dir is not None:
            ckpt_iter = int(sorted(get_hdfs_files(hdfs_dir))[-1].split(".")[0])
        else:
            ckpt_iter = int(sorted(os.listdir(ckpt_dir))[-1])

    # Download and unpack the checkpoint from HDFS if it is not cached locally
    ckpt_path = f"{ckpt_dir}/{ckpt_iter:06d}"
    if not os.path.exists(ckpt_path):
        assert hdfs_dir is not None
        if accelerator is None or accelerator.is_main_process:
            ensure_sysrun(f"mkdir -p {ckpt_dir}")
            ensure_sysrun(f"hdfs dfs -get {hdfs_dir}/{ckpt_iter:06d}.tar {ckpt_dir}")
            ensure_sysrun(f"tar -xvf {ckpt_dir}/{ckpt_iter:06d}.tar -C {ckpt_dir}")
            ensure_sysrun(f"rm {ckpt_dir}/{ckpt_iter:06d}.tar")
        if accelerator is not None:
            accelerator.wait_for_everyone()  # other ranks wait for the main process to fetch

    # Without a model, only report which iteration was found
    if model is None:
        return ckpt_iter

    # Load the checkpoint weights into the model
    try:
        load_checkpoint_and_dispatch(model, ckpt_path, strict=strict)
    except Exception:
        # Fall back to a DeepSpeed ZeRO checkpoint: consolidate the sharded
        # states into fp32 safetensors with the bundled script, then retry
        if accelerator is None or accelerator.is_main_process:
            ensure_sysrun(f"python3 {ckpt_path}/zero_to_fp32.py {ckpt_path} {ckpt_path} --safe_serialization")
        if accelerator is not None:
            accelerator.wait_for_everyone()  # other ranks wait for the conversion
        load_checkpoint_and_dispatch(model, ckpt_path, strict=strict)
    return model


def save_ckpt(ckpt_dir: str, ckpt_iter: int, hdfs_dir: Optional[str] = None) -> None:
    # Pack the local checkpoint directory, upload it to HDFS, then clean up
    if hdfs_dir is not None:
        ensure_sysrun(f"tar -cf {ckpt_dir}/{ckpt_iter:06d}.tar -C {ckpt_dir} {ckpt_iter:06d}")
        ensure_sysrun(f"hdfs dfs -put -f {ckpt_dir}/{ckpt_iter:06d}.tar {hdfs_dir}")
        ensure_sysrun(f"rm -rf {ckpt_dir}/{ckpt_iter:06d}.tar {ckpt_dir}/{ckpt_iter:06d}")


def get_configs(yaml_path: str, cli_configs: Optional[List[str]] = None, **kwargs) -> DictConfig:
    # Merge precedence (lowest to highest): YAML file < CLI overrides < kwargs
    yaml_configs = OmegaConf.load(yaml_path)
    cli_overrides = OmegaConf.from_cli(cli_configs if cli_configs is not None else [])
    configs = OmegaConf.merge(yaml_configs, cli_overrides, kwargs)
    OmegaConf.resolve(configs)  # resolve ${...} interpolations
    return configs


def save_experiment_params(args: Namespace, configs: DictConfig, opt: Options, save_dir: str) -> Dict[str, Any]:
    # Record the full experiment setup (configs + CLI args + options) as YAML
    os.makedirs(save_dir, exist_ok=True)
    params = OmegaConf.merge(configs, {k: str(v) for k, v in vars(args).items()})
    params = OmegaConf.merge(params, OmegaConf.create(vars(opt)))
    OmegaConf.save(params, os.path.join(save_dir, "params.yaml"))
    return dict(params)


def save_model_architecture(model: Module, save_dir: str) -> None:
    os.makedirs(save_dir, exist_ok=True)
    num_buffers = sum(b.numel() for b in model.buffers())
    num_params = sum(p.numel() for p in model.parameters())
    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    message = (
        f"Number of buffers: {num_buffers}\n"
        f"Number of trainable / all parameters: {num_trainable_params} / {num_params}\n\n"
        f"Model architecture:\n{model}"
    )
    with open(os.path.join(save_dir, "model.txt"), "w") as f:
        f.write(message)
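
# A minimal usage sketch of the checkpoint round-trip (illustrative only: the
# paths, the `model` object, and the iteration numbers are hypothetical, and
# `save_ckpt` assumes the trainer already wrote the local checkpoint directory
# before uploading):
#
#   configs = get_configs("configs/train.yaml", cli_configs=["lr=1e-4"])
#   latest = load_ckpt("ckpts/run0", ckpt_iter=-1, hdfs_dir="hdfs://x/ckpts")
#   model = load_ckpt("ckpts/run0", latest, hdfs_dir="hdfs://x/ckpts", model=model)
#   save_ckpt("ckpts/run0", latest + 1000, hdfs_dir="hdfs://x/ckpts")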
def ensure_sysrun(cmd: str) -> None:
    # Run a shell command, retrying until it exits with status 0 (e.g., to
    # ride out transient HDFS hiccups); loops forever if it can never succeed
    while True:
        if os.system(cmd) == 0:
            break
        print(f"Retry running {cmd}")


def get_hdfs_files(hdfs_path: str) -> List[str]:
    lines = get_hdfs_lines(hdfs_path)
    if len(lines) == 0:
        raise ValueError(f"No files found in {hdfs_path}")
    # The last field of each `hdfs dfs -ls` line is the full path; keep the basename
    return [line.split()[-1].split("/")[-1] for line in lines]


def get_hdfs_size(hdfs_path: str, unit: str = "B") -> Union[int, float]:
    lines = get_hdfs_lines(hdfs_path)
    if len(lines) == 0:
        raise ValueError(f"No files found in {hdfs_path}")
    divisors = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3, "TB": 1024**4}
    if unit not in divisors:
        raise ValueError(f"Invalid unit: {unit}")
    byte_size = sum(int(line.split()[4]) for line in lines)  # 5th field is the size in bytes
    return byte_size if unit == "B" else byte_size / divisors[unit]


def get_hdfs_lines(hdfs_path: str) -> List[str]:
    # Drop the leading "Found N items" summary line of `hdfs dfs -ls`
    return os.popen(f"hdfs dfs -ls {hdfs_path}").read().strip().split("\n")[1:]
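

# Smoke test for the HDFS helpers when the module is run directly; a sketch
# assuming the `hdfs` CLI is on PATH, with a hypothetical path to list
if __name__ == "__main__":
    demo_path = "hdfs://example/checkpoints"  # hypothetical; replace with a real path
    print(get_hdfs_files(demo_path))
    print(f"Total size: {get_hdfs_size(demo_path, unit='GB'):.2f} GB")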