from __future__ import annotations

import os
import threading
from typing import List, NamedTuple, Tuple

import torch

from cosmos_transfer1.utils import distributed, log, misc
from cosmos_transfer1.utils.checkpointer import Checkpointer as BaseCheckpointer
from cosmos_transfer1.utils.model import Model

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase


class _IncompatibleKeys(
    NamedTuple(
        "IncompatibleKeys",
        [
            ("missing_keys", List[str]),
            ("unexpected_keys", List[str]),
            ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
        ],
    )
):
    pass


class MultiRankCheckpointer(BaseCheckpointer):
    def save(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
"""Save network weights, optimizer parameters, scheduler parameters to a checkpoint. |
|
|
|
Args: |
|
model (Model): The PyTorch model. |
|
optimizer (torch.optim.Optimizer): The model optimizer. |
|
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. |
|
grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). |
|
iteration (int): Current iteration number. |
|
""" |
|
|
|
        postfix, _, total_ema_num = model.get_ckpt_postfix()
        checkpoint_file = f"iter_{iteration:09}{postfix}.pt"
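        # Each of the first `total_ema_num` ranks writes its own checkpoint so that all EMA weights are kept.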
        save_ranks = list(range(total_ema_num))
        for _rank in save_ranks:
            if distributed.get_rank() == _rank:
                state_dict = dict(
                    model=model.state_dict(),
                    optimizer=optimizer.state_dict(),
                    scheduler=scheduler.state_dict(),
                    grad_scaler=grad_scaler.state_dict(),
                    iteration=iteration,
                )
                state_dict = misc.to(state_dict, device="cpu")
                self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
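                # Wait for any previous asynchronous save to finish before starting a new one.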
                if self.save_thread:
                    self.save_thread.join()
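                # Run the checkpoint saver in a separate thread so training is not blocked by disk I/O.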
                self.save_thread = threading.Thread(
                    target=self._save_worker_local,
                    daemon=False,
                    args=(state_dict, checkpoint_file, distributed.get_rank()),
                )
                self.save_thread.start()

    @misc.timer("checkpoint loading")
    def load(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        grad_scaler: torch.amp.GradScaler | None = None,
    ) -> int:
"""Load network weights and optimizer states from a checkpoint in a single process. |
|
|
|
The priority of the checkpoint loading logic is: |
|
1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. |
|
2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. |
|
- This is typically used for inference mode. |
|
- If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. |
|
3. If none of the above, randomly initialize the model parameters and train from scratch. |
|
|
|
Args: |
|
model (Model): The PyTorch model. |
|
optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). |
|
scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). |
|
grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). |
|
|
|
Returns: |
|
iteration (int): the iteration number to start/resume from. |
|
""" |
|
        latest_checkpoint_file = self._read_latest_checkpoint_file()
        if latest_checkpoint_file is not None:
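            # 1. Resume training from the latest checkpoint of this run.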
            postfix, _, total_ema_num = model.get_ckpt_postfix()
            latest_checkpoint_file = latest_checkpoint_file.replace(".pt", f"{postfix}.pt")
            checkpoint_dir = self.checkpoint_dir_local
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
            resume = True
        else:
            if self.load_path:
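                # 2. Load the checkpoint path given in the config (typically for inference or finetuning).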
                checkpoint_path = self.load_path
                postfix, _, total_ema_num = model.get_ckpt_postfix()
                checkpoint_path = checkpoint_path.replace(".pt", f"{postfix}.pt")
                resume = self.load_training_state
            else:
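                # 3. No checkpoint to load: the model will train from scratch.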
                checkpoint_path = None
                resume = False

        if checkpoint_path is not None:
            self._check_checkpoint_exists(checkpoint_path)
            log.info(f"Loading checkpoint (local): {checkpoint_path}")
            state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
            log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
            self.callbacks.on_load_checkpoint(model, state_dict=state_dict)

            log.info("- Loading the model...")
            log.critical(model.load_state_dict(state_dict["model"], strict=self.strict_resume))
            if resume:
                iteration = state_dict["iteration"]
                assert optimizer and scheduler and grad_scaler
log.info("- Loading the optimizer...") |
|
optimizer.load_state_dict(state_dict["optimizer"]) |
|
log.info("- Loading the scheduler...") |
|
scheduler.load_state_dict(state_dict["scheduler"]) |
|
scheduler.last_epoch = iteration |
|
log.info("- Loading the gradient scaler...") |
|
grad_scaler.load_state_dict(state_dict["grad_scaler"]) |
|
log.success(f"Done with loading the checkpoint (iteration {iteration}).") |
|
else: |
|
iteration = 0 |
|
log.success("Done with loading the checkpoint.") |
|
else: |
|
|
|
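            # No checkpoint available: initialize the model randomly and train from scratch.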
            iteration = 0
            log.info("Training from scratch.")
        torch.cuda.empty_cache()
        return iteration


def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
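    """Load a checkpoint state dict into `model` non-strictly, tolerating shape mismatches.

    Keys whose tensor shapes differ between the checkpoint and the model are dropped from the
    checkpoint state dict and reported, instead of raising an error as strict loading would.
    """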
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            if "_extra_state" in k:
                log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
                continue
            model_param = model_state_dict[k]
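            # Allow mismatches for uninitialized parameters of lazy modules, which have no shape yet.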
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
                continue
            if not isinstance(model_param, torch.Tensor):
                raise ValueError(
                    f"Found non-tensor parameter {k} in the model (type: {type(model_param)} vs {type(checkpoint_state_dict[k])} in the checkpoint); please check whether this key is safe to skip."
                )

            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
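                    # Shape mismatches on quantization observer / fake-quantize modules are expected
                    # (e.g. per-channel observers), so find the module that owns this key and skip it.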
                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # A state-dict key has the form "{module path}.{param name}"; drop the last part.
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
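                        # Observer/fake-quantize modules are expected to have shape mismatches,
                        # so leave them in the state dict rather than dropping them.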
                        continue

                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
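    # TransformerEngine FP8 "_extra_state" entries are expected to be missing or unexpected; ignore them.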
    missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
    unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )