| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import os.path as osp |
| | import glob |
| | import onnxruntime |
| | from .arcface_onnx import * |
| | from .retinaface import * |
| | |
| | from .landmark import * |
| | from .attribute import Attribute |
| | from .inswapper import INSwapper |
| | from ..utils import download_onnx |
| |
|
| | __all__ = ['get_model'] |
| |
|
| |
|
class PickableInferenceSession(onnxruntime.InferenceSession):
    """An ``onnxruntime.InferenceSession`` that can be pickled.

    Only the information needed to rebuild the session is serialised: the
    model path plus the ``providers`` and ``provider_options`` keyword
    arguments supplied at construction time. Unpickling re-runs ``__init__``
    with those values.
    """

    def __init__(self, model_path, **kwargs):
        """Create the session and remember how to recreate it.

        Args:
            model_path: Forwarded as the first argument of
                ``onnxruntime.InferenceSession``.
            **kwargs: Forwarded unchanged to ``onnxruntime.InferenceSession``.
        """
        super().__init__(model_path, **kwargs)
        self.model_path = model_path
        # Keep only these two kwargs for pickling. Other kwargs (for example
        # a session-options object) are not stored because they may not be
        # picklable themselves.
        self._pickle_providers = kwargs.get('providers')
        self._pickle_provider_options = kwargs.get('provider_options')

    def __getstate__(self):
        """Return the minimal picklable state needed to rebuild the session."""
        return {
            'model_path': self.model_path,
            'providers': getattr(self, '_pickle_providers', None),
            'provider_options': getattr(self, '_pickle_provider_options', None),
        }

    def __setstate__(self, values):
        """Rebuild the session from the state produced by ``__getstate__``.

        States that only contain ``'model_path'`` are still accepted; the
        missing entries default to ``None``, which is the same as not passing
        them at all.
        """
        # Re-apply the original providers so the restored session does not
        # silently fall back to onnxruntime's default provider selection.
        self.__init__(
            values['model_path'],
            providers=values.get('providers'),
            provider_options=values.get('provider_options'),
        )
| |
|
class ModelRouter:
    """Choose the wrapper class for an ONNX file from its input/output signature."""

    def __init__(self, onnx_file):
        """Remember the path of the ONNX file to route."""
        self.onnx_file = onnx_file

    def get_model(self, **kwargs):
        """Open the ONNX file and wrap it in the matching model class.

        Args:
            **kwargs: Forwarded to ``PickableInferenceSession``.

        Returns:
            A ``RetinaFace``, ``Landmark``, ``Attribute``, ``INSwapper`` or
            ``ArcFaceONNX`` instance, or ``None`` when the signature matches
            none of the rules below.
        """
        session = PickableInferenceSession(self.onnx_file, **kwargs)

        inputs = session.get_inputs()
        input_shape = inputs[0].shape
        outputs = session.get_outputs()

        # Five or more outputs is decided before the input shape is inspected.
        if len(outputs) >= 5:
            return RetinaFace(model_file=self.onnx_file, session=session)

        # Every remaining rule reads dimensions 2 and 3 of the first input;
        # a shorter shape cannot match any of them.
        if len(input_shape) < 4:
            return None
        height, width = input_shape[2], input_shape[3]

        if height == 192 and width == 192:
            return Landmark(model_file=self.onnx_file, session=session)
        if height == 96 and width == 96:
            return Attribute(model_file=self.onnx_file, session=session)
        if len(inputs) == 2 and height == 128 and width == 128:
            return INSwapper(model_file=self.onnx_file, session=session)
        # The isinstance guard keeps non-integer (symbolic) dimensions from
        # reaching the ``>=`` / ``%`` arithmetic, where they would raise
        # instead of being reported as unrecognised.
        if (isinstance(height, int) and height == width
                and height >= 112 and height % 16 == 0):
            return ArcFaceONNX(model_file=self.onnx_file, session=session)

        # Unrecognised signature.
        return None
| |
|
def find_onnx_file(dir_path):
    """Return the lexicographically last ``*.onnx`` path directly inside dir_path.

    Args:
        dir_path: Directory to search (not recursive).

    Returns:
        The matching path that sorts last, or ``None`` when ``dir_path`` does
        not exist or contains no ``.onnx`` file.
    """
    if not os.path.exists(dir_path):
        return None
    # Escape the directory part so characters such as '[' or '*' in its name
    # are matched literally rather than interpreted as glob wildcards.
    paths = glob.glob("%s/*.onnx" % glob.escape(dir_path))
    if not paths:
        return None
    return max(paths)
| |
|
def get_default_providers():
    """Return a fresh list of the default provider names: CUDA, CoreML, then CPU."""
    provider_names = (
        'CUDAExecutionProvider',
        'CoreMLExecutionProvider',
        'CPUExecutionProvider',
    )
    # A new list on every call, so callers may mutate the result freely.
    return list(provider_names)
| |
|
def get_default_provider_options() -> None:
    """Return the default provider options, which is ``None``."""
    default_options = None
    return default_options
| |
|
def get_model(name, **kwargs):
    """Load a model by name or by ``.onnx`` path and wrap it via ``ModelRouter``.

    Args:
        name: Either a path ending in ``.onnx`` or the name of a directory
            under ``<root>/models`` that is searched with ``find_onnx_file``.
        **kwargs: Optional settings read by this function:
            ``root`` (default ``'~/.insightface'``), ``download`` (default
            ``False``), ``download_zip`` (default ``False``), ``providers``
            (default ``get_default_providers()``) and ``provider_options``
            (default ``get_default_provider_options()``).

    Returns:
        Whatever ``ModelRouter.get_model`` returns, or ``None`` when ``name``
        is a directory name and no ``.onnx`` file is found for it.

    Raises:
        AssertionError: If the resolved model path does not exist or is not a
            regular file.
    """
    base_dir = os.path.expanduser(kwargs.get('root', '~/.insightface'))
    models_dir = osp.join(base_dir, 'models')
    may_download = kwargs.get('download', False)
    as_zip = kwargs.get('download_zip', False)

    if name.endswith('.onnx'):
        # An explicit .onnx path is used as given.
        onnx_path = name
    else:
        onnx_path = find_onnx_file(os.path.join(models_dir, name))
        if onnx_path is None:
            return None

    # Downloading is attempted only for a missing file and only when allowed.
    if may_download and not osp.exists(onnx_path):
        onnx_path = download_onnx('models', onnx_path, root=base_dir, download_zip=as_zip)
    assert osp.exists(onnx_path), 'model_file %s should exist'%onnx_path
    assert osp.isfile(onnx_path), 'model_file %s should be a file'%onnx_path

    session_kwargs = {
        'providers': kwargs.get('providers', get_default_providers()),
        'provider_options': kwargs.get('provider_options', get_default_provider_options()),
    }
    return ModelRouter(onnx_path).get_model(**session_kwargs)
| |
|