import datetime
import json
import os
import re
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import fsspec
import torch
from coqpit import Coqpit

from trainer.logger import logger


def copy_model_files(config: Coqpit, out_path: str, new_fields: Dict) -> None:
    """Copy config.json and other model files to the training folder and add new fields.

    Args:
        config (Coqpit): Coqpit config defining the training run.
        out_path (str): Output path to copy the file to.
        new_fields (dict): New fields to be added or edited in the config file.
    """
    copy_config_path = os.path.join(out_path, "config.json")
    # Merge the new fields over the original config values.
    new_config = {**config.to_dict(), **new_fields}
    with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
        json.dump(new_config, f, indent=4)
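
# Illustrative usage (the run directory and extra field are hypothetical):
#
#   copy_model_files(config, "runs/run-0", {"restore_path": "runs/run-0/checkpoint_10000.pth"})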


def load_fsspec(
    path: str,
    map_location: Optional[
        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
    ] = None,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3://, gs://).

    Args:
        path: Any path or URL supported by fsspec.
        map_location: torch.device or str.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)
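
# Illustrative usage (the S3 bucket and file names are hypothetical):
#
#   state = load_fsspec("s3://my-bucket/run-0/checkpoint_10000.pth", map_location="cpu")
#   model.load_state_dict(state["model"])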


def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):
    """Load `state["model"]` weights from a checkpoint into `model`.

    Optionally moves the model to GPU and switches it to eval mode.
    Returns the model and the full checkpoint state dict.
    """
    state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state
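
# Illustrative usage for inference (the checkpoint path is hypothetical):
#
#   model, state = load_checkpoint(model, "runs/run-0/best_model.pth", use_cuda=True, eval=True)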


def save_fsspec(state: Any, path: str, **kwargs):
    """Like torch.save but can save to other locations (e.g. s3://, gs://).

    Args:
        state: State object to save.
        path: Any path or URL supported by fsspec.
        **kwargs: Keyword arguments forwarded to torch.save.
    """
    with fsspec.open(path, "wb") as f:
        torch.save(state, f, **kwargs)
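
# Illustrative usage (the GCS bucket is hypothetical):
#
#   save_fsspec({"model": model.state_dict()}, "gs://my-bucket/run-0/checkpoint_10000.pth")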


def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, save_func, **kwargs):
    """Assemble a checkpoint dict and save it to `output_path`.

    Handles DDP-wrapped models, lists of optimizers/scalers, and Coqpit configs.
    Extra keyword arguments are stored in the checkpoint as-is.
    """
    # Unwrap DDP/DataParallel models before taking the state dict.
    if hasattr(model, "module"):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    if isinstance(optimizer, list):
        optimizer_state = [optim.state_dict() for optim in optimizer]
    else:
        optimizer_state = optimizer.state_dict() if optimizer is not None else None

    if isinstance(scaler, list):
        scaler_state = [s.state_dict() for s in scaler]
    else:
        scaler_state = scaler.state_dict() if scaler is not None else None

    if isinstance(config, Coqpit):
        config = config.to_dict()

    state = {
        "config": config,
        "model": model_state,
        "optimizer": optimizer_state,
        "scaler": scaler_state,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    state.update(kwargs)
    if save_func:
        save_func(state, output_path)
    else:
        save_fsspec(state, output_path)
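
# Checkpoints written above can be inspected with `load_fsspec`; an illustrative
# sketch (the path is hypothetical):
#
#   state = load_fsspec("runs/run-0/checkpoint_10000.pth", map_location="cpu")
#   state["step"], state["epoch"], state["date"]  # metadata saved next to the weights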


def save_checkpoint(
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    output_folder,
    save_n_checkpoints=None,
    save_func=None,
    **kwargs,
):
    """Save a training checkpoint as `checkpoint_<step>.pth` under `output_folder`.

    If `save_n_checkpoints` is set, older checkpoints are pruned so that only
    the last `save_n_checkpoints` remain.
    """
    file_name = f"checkpoint_{current_step}.pth"
    checkpoint_path = os.path.join(output_folder, file_name)

    logger.info("\n > CHECKPOINT : %s", checkpoint_path)
    save_model(
        config,
        model,
        optimizer,
        scaler,
        current_step,
        epoch,
        checkpoint_path,
        save_func=save_func,
        **kwargs,
    )
    if save_n_checkpoints is not None:
        keep_n_checkpoints(output_folder, save_n_checkpoints)
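
# Illustrative usage inside a training loop (all names are assumed to exist in
# the caller's scope):
#
#   save_checkpoint(config, model, optimizer, scaler, step, epoch, "runs/run-0", save_n_checkpoints=5)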


def save_best_model(
    current_loss,
    best_loss,
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    out_path,
    keep_all_best=False,
    keep_after=10000,
    save_func=None,
    **kwargs,
):
    """Save the model as `best_model_<step>.pth` if `current_loss` improves on `best_loss`.

    Unless `keep_all_best` is set and training is past `keep_after` steps,
    earlier best models are removed. A `best_model.pth` shortcut always points
    to the latest best checkpoint. Returns the (possibly updated) best loss.
    """
    if current_loss < best_loss:
        best_model_name = f"best_model_{current_step}.pth"
        checkpoint_path = os.path.join(out_path, best_model_name)
        logger.info(" > BEST MODEL : %s", checkpoint_path)
        save_model(
            config,
            model,
            optimizer,
            scaler,
            current_step,
            epoch,
            checkpoint_path,
            model_loss=current_loss,
            save_func=save_func,
            **kwargs,
        )
        fs = fsspec.get_mapper(out_path).fs

        # Remove earlier best models unless they should all be kept.
        if not keep_all_best or (current_step < keep_after):
            model_names = fs.glob(os.path.join(out_path, "best_model*.pth"))
            for model_name in model_names:
                if os.path.basename(model_name) != best_model_name:
                    fs.rm(model_name)

        # Keep a shortcut that always points to the latest best model.
        shortcut_name = "best_model.pth"
        shortcut_path = os.path.join(out_path, shortcut_name)
        fs.copy(checkpoint_path, shortcut_path)
        best_loss = current_loss
    return best_loss
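
# Illustrative usage after an evaluation pass; the returned value must be
# carried forward (names are assumed to exist in the caller's scope):
#
#   best_loss = save_best_model(eval_loss, best_loss, config, model, optimizer,
#                               scaler, step, epoch, "runs/run-0")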


def get_last_checkpoint(path: str) -> Tuple[str, str]:
    """Get the latest checkpoint and/or best model in path.

    It is based on globbing for `*.pth` and the RegEx
    `(checkpoint|best_model)_([0-9]+)`.

    Args:
        path: Path to files to be compared.

    Raises:
        ValueError: If no checkpoint or best_model files are found.

    Returns:
        Path to the last checkpoint.
        Path to the best checkpoint.
    """
    fs = fsspec.get_mapper(path).fs
    file_names = fs.glob(os.path.join(path, "*.pth"))
    scheme = urlparse(path).scheme
    if scheme:
        # fsspec glob results drop the scheme, so add it back (e.g. "s3://").
        file_names = [scheme + "://" + file_name for file_name in file_names]
    last_models = {}
    last_model_nums = {}
    for key in ["checkpoint", "best_model"]:
        last_model_num = None
        last_model = None

        # Pick the file with the highest step number embedded in its name.
        for file_name in file_names:
            match = re.search(f"{key}_([0-9]+)", file_name)
            if match is not None:
                model_num = int(match.groups()[0])
                if last_model_num is None or model_num > last_model_num:
                    last_model_num = model_num
                    last_model = file_name

        # If no file name carries a step number (e.g. a bare `best_model.pth`),
        # fall back to the most recently created file and read the step from it.
        key_file_names = [fn for fn in file_names if key in fn]
        if last_model is None and len(key_file_names) > 0:
            last_model = max(key_file_names, key=os.path.getctime)
            last_model_num = load_fsspec(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if "checkpoint" not in last_models:
        # Only a best model was found; use it as the checkpoint too.
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:
        # Only checkpoints were found; use the last one as the best model.
        last_models["best_model"] = last_models["checkpoint"]
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        # The best model is newer than the last checkpoint; resume from it.
        last_models["checkpoint"] = last_models["best_model"]

    return last_models["checkpoint"], last_models["best_model"]
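
# Illustrative usage when resuming a run (the run directory is hypothetical):
#
#   last_checkpoint, best_model = get_last_checkpoint("runs/run-0")
#   model, state = load_checkpoint(model, last_checkpoint)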


def keep_n_checkpoints(path: str, n: int) -> None:
    """Keep only the last n checkpoints in path.

    Args:
        path: Path to files to be compared.
        n: Number of checkpoints to keep.
    """
    fs = fsspec.get_mapper(path).fs
    file_names = sort_checkpoints(path, "checkpoint")
    if len(file_names) > n:
        for file_name in file_names[:-n]:
            fs.rm(file_name)
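
# Illustrative usage (hypothetical run directory): keep only the five newest
# `checkpoint_<step>.pth` files:
#
#   keep_n_checkpoints("runs/run-0", 5)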


def sort_checkpoints(output_path: str, checkpoint_prefix: str, use_mtime: bool = False) -> List[str]:
    """Sort checkpoint paths based on the checkpoint step number.

    Args:
        output_path (str): Path to directory containing checkpoints.
        checkpoint_prefix (str): Prefix of the checkpoint files.
        use_mtime (bool): If True, use modification dates to determine checkpoint order.
    """
    ordering_and_checkpoint_path = []

    glob_checkpoints = [str(x) for x in Path(output_path).glob(f"{checkpoint_prefix}_*")]

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(f".*{checkpoint_prefix}_([0-9]+)", path)
            if regex_match is not None and regex_match.groups() is not None:
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted