from __future__ import annotations |
|
|
|
import os |
|
import threading |
|
|
|
import torch |
|
from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig |
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
from torch.distributed.fsdp import StateDictType |
|
|
|
from cosmos_transfer1.checkpointer.fsdp_optim_fix import scatter_full_optim_state_dict |
|
from cosmos_transfer1.utils import callback, distributed, log, misc |
|
from cosmos_transfer1.utils.config import CheckpointConfig, JobConfig |
|
from cosmos_transfer1.utils.model import Model |
|
|
|
|
|
class FSDPCheckpointer: |
|
"""The checkpointer class. Supports checkpoint saving/loading to local disk.""" |
|
|
|
def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): |
|
"""Constructor of the checkpointer. |
|
|
|
Args: |
|
config_checkpoint (CheckpointConfig): The config object for the checkpointer. |
|
""" |
|
|
|
self.callbacks = callbacks |
|
self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" |
|
self.strict_resume = config_checkpoint.strict_resume |
|
self.load_path = config_checkpoint.load_path |
|
self.load_training_state = config_checkpoint.load_training_state |
|
self.save_thread = None |
|
self.config_checkpoint = config_checkpoint |
|
|
|
def _load_ckpt_file_during_init(self): |
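        """Resolve which checkpoint file to load during initialization.

        Returns:
            tuple[str | None, bool]: The checkpoint path (or None to start from scratch) and whether the
            full training state (optimizer/scheduler/iteration) should be resumed.
        """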
|
latest_checkpoint_file = self._read_latest_checkpoint_file() |
|
if latest_checkpoint_file is not None: |
|
|
|
checkpoint_dir = self.checkpoint_dir_local |
|
checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) |
|
resume = True |
|
log.critical(f"[Checkpoint] Found latest checkpoint file: {latest_checkpoint_file}") |
|
log.critical(f"[Checkpoint] Loading from local path: {checkpoint_path}") |
|
log.critical("[Checkpoint] Will resume full training state (model, optimizer, scheduler)") |
|
else: |
|
if self.load_path: |
|
|
|
checkpoint_path = self.load_path |
|
resume = self.load_training_state |
|
log.critical(f"[Checkpoint] Using specified checkpoint path: {checkpoint_path}") |
|
if resume: |
|
log.critical("[Checkpoint] Will load complete training state (model, optimizer, scheduler)") |
|
else: |
|
log.critical("[Checkpoint] Will load model weights only (no optimizer/scheduler state)") |
|
else: |
|
|
|
checkpoint_path = None |
|
resume = False |
|
log.critical("[Checkpoint] No checkpoint path specified") |
|
log.critical("[Checkpoint] Starting fresh training with random initialization") |
|
return checkpoint_path, resume |
|
|
|
@misc.timer("FSDP.load_model_during_init") |
|
def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): |
|
if ema_id > 0: |
|
assert is_ema, "ema_id should be used with is_ema=True" |
|
checkpoint_path, _ = self._load_ckpt_file_during_init() |
|
if checkpoint_path is not None: |
|
tag = "reg" if not is_ema else "ema" |
|
default_checkpoint_path = checkpoint_path.replace(".pt", f"_{tag}_model.pt") |
|
if not os.path.exists(default_checkpoint_path): |
|
default_checkpoint_path = checkpoint_path |
|
log.warning(f"is_ema={is_ema} model is not found. Loading from {default_checkpoint_path}") |
|
if tag == "ema" and ema_id > 0: |
|
_checkpoint_path = checkpoint_path.replace(".pt", f"_RANK{ema_id}.pt") |
|
_checkpoint_path = _checkpoint_path.replace(".pt", f"_{tag}_model.pt") |
|
if self._check_checkpoint_exists(_checkpoint_path, is_raise=False): |
|
default_checkpoint_path = _checkpoint_path |
|
else: |
|
print( |
|
f"{distributed.get_rank()}: Checkpoint not found: {_checkpoint_path} " |
|
f"(fallback to {default_checkpoint_path})" |
|
) |
|
checkpoint_path = default_checkpoint_path |
|
self._check_checkpoint_exists(checkpoint_path) |
|
|
|
log.info(f"Loading checkpoint (local): {checkpoint_path}") |
|
state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) |
|
log.success(f"Complete loading checkpoint (local): {checkpoint_path}") |
|
log.info("- Loading the model...") |
|
if self.strict_resume: |
|
log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) |
|
else: |
|
log.critical("\t Using non-strict model") |
|
from cosmos_transfer1.diffusion.inference.inference_utils import non_strict_load_model |
|
|
|
log.info(non_strict_load_model(model, state_dict)) |
|
log.info("-finish model loading") |
|
else: |
|
log.info(f"is_ema={is_ema} model is not found and loaded.") |
|
|
|
@misc.timer("FSDP.load_optim_scheduler_during_init") |
|
def load_optim_scheduler_during_init(self, fsdp_model, optimizer, scheduler): |
|
checkpoint_path, resume = self._load_ckpt_file_during_init() |
|
log.critical(f"Loading optimizer and scheduler: {checkpoint_path} (resume: {resume}") |
|
if checkpoint_path is not None: |
|
if resume: |
|
checkpoint_path = checkpoint_path.replace(".pt", "_optim.pt") |
|
self._check_checkpoint_exists(checkpoint_path) |
|
if distributed.get_rank() == 0: |
|
log.info(f"Loading checkpoint (local): {checkpoint_path}") |
|
state_dict = torch.load( |
|
checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False |
|
) |
|
log.success(f"Complete loading checkpoint (local): {checkpoint_path}") |
|
log.info("- Loading the optimizer (FSDP scatter)...") |
|
else: |
|
state_dict = { |
|
"optimizer": None, |
|
"scheduler": None, |
|
} |
|
distributed.barrier() |
|
sharded_optimizer_state_dict = scatter_full_optim_state_dict( |
|
state_dict["optimizer"], |
|
fsdp_model, |
|
) |
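                # Hedged note: only rank 0 holds the full optimizer state dict at this point (other
                # ranks passed None); scatter_full_optim_state_dict is expected to re-shard it so each
                # rank receives the slice matching its local FSDP parameters.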
|
log.info("- Loading the optimizer (FSDP load_state_dict)...") |
|
log.info(optimizer.load_state_dict(sharded_optimizer_state_dict)) |
|
log.critical("Skip loading the scheduler...") |
|
return |
|
log.info("- Loading the scheduler...") |
|
scheduler.load_state_dict(state_dict["scheduler"]) |
|
|
|
@misc.timer("FSDP get_optim_scheduler_state") |
|
def get_optim_scheduler_state(self, optim, fsdp_model, scheduler): |
|
with FSDP.state_dict_type( |
|
fsdp_model, |
|
StateDictType.FULL_STATE_DICT, |
|
FullStateDictConfig(offload_to_cpu=True, rank0_only=True), |
|
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), |
|
): |
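            # Hedged note: with FULL_STATE_DICT and rank0_only=True, the consolidated (CPU-offloaded)
            # state is materialized on rank 0; other ranks typically receive an empty dict.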
|
optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) |
|
scheduler_statedict = scheduler.state_dict() |
|
return { |
|
"optimizer": optim_statedict, |
|
"scheduler": scheduler_statedict, |
|
} |
|
|
|
def save( |
|
self, |
|
model: Model, |
|
optimizer: torch.optim.Optimizer, |
|
scheduler: torch.optim.lr_scheduler.LRScheduler, |
|
grad_scaler: torch.amp.GradScaler, |
|
iteration: int, |
|
) -> None: |
|
"""Save network weights, optimizer parameters, scheduler parameters to a checkpoint. |
|
|
|
Args: |
|
model (Model): The PyTorch model. |
|
optimizer (torch.optim.Optimizer): The model optimizer. |
|
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. |
|
grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). |
|
iteration (int): Current iteration number. |
|
""" |
|
self.callbacks.on_save_checkpoint_start(model, iteration) |
|
|
|
model_state_dict = model.state_dict_model() |
|
optim_scheduler_state_dict = self.get_optim_scheduler_state(optimizer, model.model, scheduler) |
|
torch.cuda.empty_cache() |
|
state_dict = dict( |
|
iteration=iteration, |
|
) |
|
self.callbacks.on_save_checkpoint(model, state_dict=state_dict) |
|
|
|
        postfix, replicate_idx, shard_idx, total_ema_num = model.get_ckpt_postfix()
        if replicate_idx == 0 and shard_idx == 0:
            # Primary rank: saves regular model, EMA model, optimizer/scheduler, and trainer state.
            pass
        elif replicate_idx < total_ema_num and shard_idx == 0:
            # Additional EMA replicas: save only their EMA weights.
            model_state_dict["model"] = None
            optim_scheduler_state_dict = None
            state_dict = None
        else:
            # Remaining ranks have nothing to write.
            return
|
|
|
checkpoint_file = f"iter_{iteration:09}{postfix}.pt" |
|
|
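        # Wait for any previous asynchronous save to finish before launching a new save thread.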
|
if self.save_thread: |
|
self.save_thread.join() |
|
|
|
self.save_thread = threading.Thread( |
|
target=self._save_worker_local, |
|
daemon=False, |
|
args=(model_state_dict, optim_scheduler_state_dict, state_dict, checkpoint_file, distributed.get_rank()), |
|
) |
|
self.save_thread.start() |
|
|
|
|
|
|
|
self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) |
|
|
|
@misc.timer("checkpoint saving (local)") |
|
def _save_worker_local( |
|
self, |
|
model_state_dict: dict[str, torch.Tensor], |
|
optim_scheduler_state_dict: dict[str, torch.Tensor], |
|
state_dict: dict[str, torch.Tensor], |
|
checkpoint_file: str, |
|
rank: int = 0, |
|
) -> None: |
|
"""Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). |
|
|
|
Args: |
|
state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. |
|
checkpoint_file (str): The file name of the model checkpoint. |
|
rank (int): GPU device (default: 0). |
|
""" |
|
checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) |
|
os.makedirs(self.checkpoint_dir_local, exist_ok=True) |
|
try: |
|
model_state_dict, ema_model_state_dict = model_state_dict["model"], model_state_dict["ema"] |
|
if model_state_dict is not None: |
|
torch.save(model_state_dict, checkpoint_path.replace(".pt", "_reg_model.pt")) |
|
if ema_model_state_dict is not None: |
|
torch.save(ema_model_state_dict, checkpoint_path.replace(".pt", "_ema_model.pt")) |
|
if optim_scheduler_state_dict is not None: |
|
torch.save(optim_scheduler_state_dict, checkpoint_path.replace(".pt", "_optim.pt")) |
|
if state_dict is not None: |
|
torch.save(state_dict, checkpoint_path) |
|
if rank == 0: |
|
self._write_latest_checkpoint_file(checkpoint_file) |
|
log.success(f"Saved checkpoint (local): {checkpoint_path}") |
|
iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) |
|
self.callbacks.on_save_checkpoint_success(iteration=iteration) |
|
except Exception as e: |
|
log.exception(f"Checkpoint failed to save (local): {e}") |
|
|
|
@misc.timer("checkpoint loading") |
|
def load( |
|
self, |
|
model: Model, |
|
optimizer: torch.optim.Optimizer | None = None, |
|
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, |
|
grad_scaler: torch.amp.GradScaler | None = None, |
|
) -> int: |
|
"""Load network weights and optimizer states from a checkpoint in a single process. |
|
|
|
The priority of the checkpoint loading logic is: |
|
1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. |
|
2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. |
|
- This is typically used for inference mode. |
|
- If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. |
|
3. If none of the above, randomly initialize the model parameters and train from scratch. |
|
|
|
Args: |
|
model (FSDPDiffModle): The PyTorch model. |
|
optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). |
|
scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). |
|
grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). |
|
|
|
Returns: |
|
iteration (int): the iteration number to start/resume from. |
|
""" |
|
self.callbacks.on_load_checkpoint_start(model) |
|
|
|
del optimizer, grad_scaler |
|
checkpoint_path, resume = self._load_ckpt_file_during_init() |
|
iteration = 0 |
|
if checkpoint_path is not None: |
|
self._check_checkpoint_exists(checkpoint_path) |
|
log.info(f"Loading checkpoint (local): {checkpoint_path}") |
|
state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) |
|
log.success(f"Complete loading checkpoint (local): {checkpoint_path}") |
|
self.callbacks.on_load_checkpoint(model, state_dict=state_dict) |
|
if resume: |
|
iteration = state_dict["iteration"] |
|
log.success("Done with loading the checkpoint.") |
|
else: |
|
log.info("Training from scratch.") |
|
torch.cuda.empty_cache() |
|
|
|
self.callbacks.on_load_checkpoint_end(model) |
|
|
|
if scheduler is not None: |
|
scheduler.last_epoch = iteration |
|
log.critical(f"resume scheduler from {iteration}", rank0_only=False) |
|
|
|
return iteration |
|
|
|
def _read_latest_checkpoint_file(self) -> str | None: |
|
"""Get the file name of the latest saved checkpoint. If it doesn't exist, return None. |
|
|
|
Returns: |
|
checkpoint_file (str | None): file name of the latest saved checkpoint. |
|
""" |
|
checkpoint_file = None |
|
latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") |
|
        if os.path.isfile(latest_path):
            with open(latest_path) as file:
                checkpoint_file = file.read().strip()
|
if checkpoint_file is None: |
|
log.warning(f"Latest ckpt file not found: {latest_path}") |
|
else: |
|
log.info(f"Found latest checkpoint: {checkpoint_file}") |
|
return checkpoint_file |
|
|
|
def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: |
|
"""Track the file name of the latest saved checkpoint. |
|
|
|
Args: |
|
checkpoint_file (str): file name of the latest saved checkpoint. |
|
""" |
|
content = f"{checkpoint_file}\n" |
|
latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") |
|
with open(latest_path, "w") as file: |
|
file.write(content) |
|
|
|
    def _check_checkpoint_exists(self, checkpoint_path: str, is_raise: bool = True) -> bool:
        """Check whether checkpoint_path exists.

        Args:
            checkpoint_path (str): Full path to the checkpoint.
            is_raise (bool): Whether to raise FileNotFoundError when the checkpoint is missing (default: True).

        Returns:
            bool: True if the checkpoint exists, False if it is missing and is_raise is False.
        """
        if not os.path.exists(checkpoint_path):
            if is_raise:
                raise FileNotFoundError(f"File not found (local): {checkpoint_path}")
            return False
        return True
|
|
|
def finalize(self) -> None: |
|
"""Finalize the checkpointer.""" |
|
if self.save_thread: |
|
self.save_thread.join() |
|
|
|
|
|
class FSDPInferenceCheckpointer:
    """Checkpointer for inference with FSDP: loads model weights from a fixed path; saving and optimizer/scheduler loading are no-ops."""
|
def __init__( |
|
self, |
|
ckpt_path: str, |
|
strict_resume: bool = True, |
|
): |
|
self.ckpt_path = ckpt_path |
|
self.strict_resume = strict_resume |
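    # A rough usage sketch (hedged): `model` and the checkpoint path below are illustrative and
    # assumed to be provided by the caller.
    #
    #     checkpointer = FSDPInferenceCheckpointer(ckpt_path="checkpoints/model_ema.pt", strict_resume=True)
    #     checkpointer.load_model_during_init(model)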
|
|
|
@misc.timer("FSDPInferenceCheckpointer.load_model_during_init") |
|
def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): |
|
del ema_id |
|
if is_ema: |
|
log.warning("EMA model is not supported in inference mode.") |
|
return |
|
assert os.path.exists(self.ckpt_path) |
|
log.info(f"Loading from {self.ckpt_path}") |
|
state_dict = torch.load(self.ckpt_path, map_location=lambda storage, loc: storage, weights_only=False) |
|
if self.strict_resume: |
|
log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) |
|
else: |
|
log.critical("\t Using non-strict model") |
|
from cosmos_transfer1.checkpointer.fsdp_checkpointer import non_strict_load_model |
|
|
|
log.info(non_strict_load_model(model, state_dict)) |
|
log.info("-finish model loading") |
|
|
|
def load_optim_scheduler_during_init(self, *args, **kwargs): |
|
""" |
|
        We do not load anything in inference mode. This method exists only to maintain the same interface.
|
""" |
|
pass |
|
|
|
def save(self, *args, **kwargs): |
|
""" |
|
        We do not save anything in inference mode. This method exists only to maintain the same interface.
|
""" |
|
pass |
|
|
|
def load(self, *args, **kwargs): |
|
""" |
|
        We do not load anything in inference mode. This method exists only to maintain the same interface.
|
""" |
|
return 0 |
|
|