|
""" |
|
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py |
|
""" |
|
from __future__ import print_function |
|
|
|
__all__ = ['get_model_file'] |
|
import os |
|
import zipfile |
|
import glob |
|
|
|
from ..utils import download, check_sha1 |
|
|
|
_model_sha1 = { |
|
name: checksum |
|
for checksum, name in [ |
|
('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'), |
|
('', 'arcface_mfn_v1'), |
|
('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'), |
|
('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'), |
|
('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'), |
|
('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'), |
|
] |
|
} |
|
|
|
base_repo_url = 'https://insightface.ai/files/' |
|
_url_format = '{repo_url}models/{file_name}.zip' |
|
|
|
|
|
def short_hash(name): |
|
if name not in _model_sha1: |
|
raise ValueError( |
|
'Pretrained model for {name} is not available.'.format(name=name)) |
|
return _model_sha1[name][:8] |
|
|
|
|
|
def find_params_file(dir_path): |
|
if not os.path.exists(dir_path): |
|
return None |
|
paths = glob.glob("%s/*.params" % dir_path) |
|
if len(paths) == 0: |
|
return None |
|
paths = sorted(paths) |
|
return paths[-1] |
|
|
|
|
|
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')): |
|
r"""Return location for the pretrained on local file system. |
|
|
|
This function will download from online model zoo when model cannot be found or has mismatch. |
|
The root directory will be created if it doesn't exist. |
|
|
|
Parameters |
|
---------- |
|
name : str |
|
Name of the model. |
|
root : str, default '~/.mxnet/models' |
|
Location for keeping the model parameters. |
|
|
|
Returns |
|
------- |
|
file_path |
|
Path to the requested pretrained model file. |
|
""" |
|
|
|
file_name = name |
|
root = os.path.expanduser(root) |
|
dir_path = os.path.join(root, name) |
|
file_path = find_params_file(dir_path) |
|
|
|
sha1_hash = _model_sha1[name] |
|
if file_path is not None: |
|
if check_sha1(file_path, sha1_hash): |
|
return file_path |
|
else: |
|
print( |
|
'Mismatch in the content of model file detected. Downloading again.' |
|
) |
|
else: |
|
print('Model file is not found. Downloading.') |
|
|
|
if not os.path.exists(root): |
|
os.makedirs(root) |
|
if not os.path.exists(dir_path): |
|
os.makedirs(dir_path) |
|
|
|
zip_file_path = os.path.join(root, file_name + '.zip') |
|
repo_url = base_repo_url |
|
if repo_url[-1] != '/': |
|
repo_url = repo_url + '/' |
|
download(_url_format.format(repo_url=repo_url, file_name=file_name), |
|
path=zip_file_path, |
|
overwrite=True) |
|
with zipfile.ZipFile(zip_file_path) as zf: |
|
zf.extractall(dir_path) |
|
os.remove(zip_file_path) |
|
file_path = find_params_file(dir_path) |
|
|
|
if check_sha1(file_path, sha1_hash): |
|
return file_path |
|
else: |
|
raise ValueError( |
|
'Downloaded file has different hash. Please try again.') |
|
|
|
|