from collections import OrderedDict
from importlib import import_module
import os
import random
import re
import warnings

import numpy as np
import torch
from torch import distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel

from .dist_util import get_dist_info

MODULE_WRAPPERS = [DataParallel, DistributedDataParallel]

MODEL_ABBR_MAP = {
    's': 'small',
    'b': 'base',
    'l': 'large',
    'h': 'huge',
}

def infer_dataset_by_path(model_path: str) -> str:
    """Infer the dataset name from a checkpoint named ``*-<dataset>.<ext>``."""
    model = os.path.basename(model_path)
    # Match '-<dataset>.' followed by one of the supported extensions.
    p = r'-([a-zA-Z0-9_]+)\.(?:pth|onnx|engine)'
    m = re.search(p, model)
    if not m:
        raise ValueError('Could not infer the dataset from the checkpoint '
                         'name, please specify it explicitly')
    return m.group(1)
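
# Illustrative usage (a sketch kept as a comment; the filename below is a
# hypothetical example of the '<name>-<dataset>.<ext>' pattern the regex
# above expects):
#   >>> infer_dataset_by_path('checkpoints/vitpose-b-coco.pth')
#   'coco'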

def dyn_model_import(dataset: str, model: str):
    """Dynamically import the model config for a dataset and model size."""
    config_name = f'configs.ViTPose_{dataset}'
    imp = import_module(config_name)
    # Map the single-letter size ('s', 'b', 'l', 'h') to the config attribute name.
    model = f'model_{MODEL_ABBR_MAP[model]}'
    return getattr(imp, model)
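
# Illustrative usage (a sketch; assumes a `configs.ViTPose_coco` config module
# that defines a `model_base` dict, matching the naming scheme above):
#   >>> model_cfg = dyn_model_import('coco', 'b')  # -> configs.ViTPose_coco.model_base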

def init_random_seed(seed=None, device='cuda'):
    """Initialize the random seed.

    If the seed is not set, a random seed is generated and then broadcast
    to all processes to prevent potential bugs.

    Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device the seed tensor is placed on for the
            broadcast. Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
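
# Illustrative usage (a sketch): with no explicit seed, rank 0 draws one and
# broadcasts it so every process ends up with the same value; an explicit
# seed is returned unchanged.
#   >>> seed = init_random_seed()    # e.g. 1804289383, identical on all ranks
#   >>> init_random_seed(42)
#   42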

def set_random_seed(seed: int,
                    deterministic: bool = False,
                    use_rank_shift: bool = False) -> None:
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            the CUDNN backend, i.e., set ``torch.backends.cudnn.deterministic``
            to True and ``torch.backends.cudnn.benchmark`` to False.
            Default: False.
        use_rank_shift (bool): Whether to add the rank number to the random
            seed so that different processes use different seeds.
            Default: False.
    """
    if use_rank_shift:
        rank, _ = get_dist_info()
        seed += rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
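
# Illustrative usage (a sketch): seed everything reproducibly, optionally
# shifting by rank so e.g. data augmentation differs across processes.
#   >>> seed = init_random_seed(42)
#   >>> set_random_seed(seed, deterministic=True, use_rank_shift=True)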

def is_module_wrapper(module: nn.Module) -> bool:
    """Check whether a module is wrapped by a parallel wrapper.

    A wrapper is any of the classes listed in ``MODULE_WRAPPERS``
    (e.g. ``DataParallel`` or ``DistributedDataParallel``).
    """
    return isinstance(module, tuple(MODULE_WRAPPERS))

def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, the print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
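
# Illustrative usage (a sketch; `model` is any nn.Module and `weights` a dict
# of tensors): with strict=False mismatching keys are reported via the logger
# (or printed) instead of raising.
#   >>> load_state_dict(model, weights, strict=False)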

def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Local path of the checkpoint file.
        map_location (str): Same as :func:`torch.load`. Default: 'cpu'.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error messages.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict_tmp = checkpoint['state_dict']
    else:
        state_dict_tmp = checkpoint

    state_dict = OrderedDict()
    # strip 'module.' (DataParallel) and 'backbone.' prefixes from the keys
    for k, v in state_dict_tmp.items():
        if k.startswith('module.backbone.'):
            state_dict[k[16:]] = v
        elif k.startswith('module.'):
            state_dict[k[7:]] = v
        elif k.startswith('backbone.'):
            state_dict[k[9:]] = v
        else:
            state_dict[k] = v

    # load the stripped state_dict into the model
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
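
# Illustrative usage (a sketch; the checkpoint path is hypothetical): the raw
# checkpoint dict is returned, while the prefix-stripped weights are loaded
# into the model in place.
#   >>> ckpt = load_checkpoint(model, 'vitpose-b-coco.pth', map_location='cpu')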

def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Resize a (N, C, H, W) tensor, warning about potential misalignment
    when ``align_corners=True`` is used with incompatible sizes."""
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    # Delegate the actual resizing to torch.nn.functional.interpolate.
    return F.interpolate(input, size, scale_factor, mode, align_corners)
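
# Illustrative usage (a sketch): upsample a batch of heatmaps (N, C, H, W) to
# a fixed spatial size; align_corners=False avoids the alignment warning path.
#   >>> heatmaps = torch.zeros(1, 17, 64, 48)
#   >>> resize(heatmaps, size=(256, 192), mode='bilinear', align_corners=False).shape
#   torch.Size([1, 17, 256, 192])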

def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
    """Initialize the module's weight to ``val`` and its bias to ``bias``."""
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def normal_init(module: nn.Module,
                mean: float = 0,
                std: float = 1,
                bias: float = 0) -> None:
    """Initialize the module's weight from N(mean, std) and its bias to ``bias``."""
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
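

if __name__ == '__main__':
    # Minimal smoke test (illustrative only, not part of the module API):
    # seed everything, initialize a small conv layer with the helpers above,
    # and resize its output.
    set_random_seed(init_random_seed(0), deterministic=True)
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    normal_init(conv, mean=0.0, std=0.001, bias=0.0)
    constant_init(nn.BatchNorm2d(8), val=1.0)
    x = torch.randn(1, 3, 32, 32)
    out = resize(conv(x), size=(64, 64), mode='bilinear', align_corners=False)
    print(out.shape)  # torch.Size([1, 8, 64, 64])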