| | """ |
| | This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py |
| | """ |
| | from __future__ import print_function |
| |
|
# Only the downloader entry point is exported by ``from ... import *``.
__all__ = ['get_model_file']
| | import os |
| | import zipfile |
| | import glob |
| |
|
| | from ..utils import download, check_sha1 |
| |
|
# Expected SHA-1 digest of each model's parameter file, keyed by model name.
_model_sha1 = {
    'arcface_r100_v1': '95be21b58e29e9c1237f229dae534bd854009ce0',
    # NOTE(review): no checksum is recorded for this model; how an empty
    # hash is treated depends on ``check_sha1`` in ..utils -- confirm.
    'arcface_mfn_v1': '',
    'retinaface_r50_v1': '39fd1e087a2a2ed70a154ac01fecaa86c315d01b',
    'retinaface_mnet025_v1': '2c9de8116d1f448fd1d4661f90308faae34c990a',
    'retinaface_mnet025_v2': '0db1d07921d005e6c9a5b38e059452fc5645e5a4',
    'genderage_v1': '7dd8111652b7aac2490c5dcddeb268e53ac643e6',
}

# Online model zoo location; an archive URL is built by filling
# ``_url_format`` with this base (slash-terminated) and the model name.
base_repo_url = 'https://insightface.ai/files/'
_url_format = '{repo_url}models/{file_name}.zip'
| |
|
| |
|
def short_hash(name):
    """Return the first eight hex characters of the checksum for ``name``.

    Raises
    ------
    ValueError
        If ``name`` has no entry in ``_model_sha1``.
    """
    digest = _model_sha1.get(name)
    if digest is None:
        raise ValueError(
            'Pretrained model for {name} is not available.'.format(name=name))
    return digest[:8]
| |
|
| |
|
def find_params_file(dir_path):
    """Return the path of the lexicographically last ``*.params`` file.

    Only the top level of ``dir_path`` is searched (no recursion).

    Parameters
    ----------
    dir_path : str
        Directory to search.

    Returns
    -------
    str or None
        ``os.path.join(dir_path, <name>)`` for the greatest matching file
        name, or ``None`` if ``dir_path`` is not a directory or contains no
        ``*.params`` file.
    """
    if not os.path.isdir(dir_path):
        return None
    # List the directory instead of globbing "<dir_path>/*.params": a glob
    # pattern would interpret '[', ']', '*' or '?' inside dir_path as
    # wildcards and silently miss the files. Dot-files are skipped to match
    # glob's behaviour for a '*' pattern.
    candidates = sorted(
        entry for entry in os.listdir(dir_path)
        if entry.endswith('.params') and not entry.startswith('.'))
    if not candidates:
        return None
    return os.path.join(dir_path, candidates[-1])
| |
|
| |
|
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
    r"""Return location for the pretrained model on the local file system.

    If a ``.params`` file for ``name`` already exists under ``root/name``
    and passes ``check_sha1``, its path is returned. Otherwise the model
    archive is downloaded from the online model zoo, extracted into
    ``root/name`` and verified. Missing directories are created.

    Parameters
    ----------
    name : str
        Name of the model; must be a key of ``_model_sha1``.
    root : str, default '~/.insightface/models'
        Location for keeping the model parameters (``~`` is expanded).

    Returns
    -------
    file_path
        Path to the requested pretrained model file.

    Raises
    ------
    KeyError
        If ``name`` has no entry in ``_model_sha1``.
    ValueError
        If the downloaded archive has no top-level ``.params`` file, or the
        extracted file fails the checksum comparison.
    """
    # Look the checksum up first so an unknown model fails before any I/O.
    sha1_hash = _model_sha1[name]
    root = os.path.expanduser(root)
    dir_path = os.path.join(root, name)

    file_path = find_params_file(dir_path)
    if file_path is not None:
        if check_sha1(file_path, sha1_hash):
            return file_path
        print(
            'Mismatch in the content of model file detected. Downloading again.'
        )
    else:
        print('Model file is not found. Downloading.')

    # makedirs creates every missing parent, including ``root`` itself.
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    zip_file_path = os.path.join(root, name + '.zip')
    repo_url = base_repo_url
    if not repo_url.endswith('/'):
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=name),
             path=zip_file_path,
             overwrite=True)
    try:
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall(dir_path)
    finally:
        # Remove the archive even when extraction fails, so a corrupt
        # download is not left lying around.
        if os.path.exists(zip_file_path):
            os.remove(zip_file_path)

    file_path = find_params_file(dir_path)
    if file_path is None:
        # Without this guard ``None`` would be handed to check_sha1.
        raise ValueError(
            'Downloaded archive for {name} contains no .params file in '
            '{path}.'.format(name=name, path=dir_path))
    if check_sha1(file_path, sha1_hash):
        return file_path
    raise ValueError(
        'Downloaded file has different hash. Please try again.')
| |
|
| |
|