| |
| |
| import argparse |
| import logging |
| import os |
| import os.path as osp |
| from functools import partial |
|
|
| import mmengine |
| import torch.multiprocessing as mp |
| from torch.multiprocessing import Process, set_start_method |
|
|
| from mmdeploy.apis import (create_calib_input_data, extract_model, |
| get_predefined_partition_cfg, torch2onnx, |
| torch2torchscript, visualize_model) |
| from mmdeploy.apis.core import PIPELINE_MANAGER |
| from mmdeploy.apis.utils import to_backend |
| from mmdeploy.backend.sdk.export_info import export2SDK |
| from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename, |
| get_ir_config, get_partition_config, |
| get_root_logger, load_config, target_wrapper) |
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the model export tool.

    Args:
        argv (list[str] | None): Argument strings to parse. Defaults to
            None, in which case argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Export model to backends.')

    # Required positional inputs.
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model')

    # Optional settings.
    parser.add_argument(
        '--test-img',
        default=None,
        type=str,
        nargs='+',
        help='image used to test model')
    parser.add_argument(
        '--work-dir',
        default=os.getcwd(),
        help='the dir to save logs and models')
    # Adjacent literals instead of a backslash continuation inside the
    # string, so no source indentation leaks into the help text.
    parser.add_argument(
        '--calib-dataset-cfg',
        help='dataset config path used to calibrate in int8 mode. If not '
        'specified, it will use "val" dataset in model config instead.',
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    parser.add_argument(
        '--quant-image-dir',
        default=None,
        help='Image directory for quantize model.')
    parser.add_argument(
        '--quant', action='store_true', help='Quantize model to low bit.')
    parser.add_argument(
        '--uri',
        default='192.168.1.1:60000',
        help='Remote ipv4:port or ipv6:port for inference on edge device.')

    return parser.parse_args(argv)
|
|
|
|
def create_process(name, target, args, kwargs, ret_value=None):
    """Run ``target`` in a child process and block until it finishes.

    Args:
        name (str): Label used in the start/success/failure log messages.
        target (Callable): Function to execute in the child process.
        args (tuple): Positional arguments forwarded to ``target``.
        kwargs (dict): Keyword arguments forwarded to ``target``.
        ret_value: Optional shared object exposing a ``value`` attribute. It
            is handed to ``target_wrapper`` together with ``target``
            (presumably the wrapper stores the status there; confirm against
            ``target_wrapper``). When given, a non-zero ``value`` after the
            child has been joined is treated as failure. Defaults to None,
            in which case no result check is performed.

    Raises:
        SystemExit: With exit code 1 when ``ret_value`` is given and its
            ``value`` is non-zero after the child process has finished.
    """
    logger = get_root_logger()
    logger.info(f'{name} start.')
    log_level = logger.level

    # Bind the bookkeeping arguments up front; Process supplies the rest.
    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is None:
        return

    if ret_value.value != 0:
        logger.error(f'{name} failed.')
        # Raise SystemExit directly: the ``exit`` builtin is injected by the
        # ``site`` module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)

    logger.info(f'{name} success.')
|
|
|
|
def torch2ir(ir_type: IR):
    """Pick the exporter that converts a torch model to ``ir_type``.

    Args:
        ir_type (IR): The intermediate representation to export to.

    Returns:
        Callable: ``torch2onnx`` for ``IR.ONNX`` and ``torch2torchscript``
        for ``IR.TORCHSCRIPT``.

    Raises:
        KeyError: If ``ir_type`` matches neither supported IR.
    """
    converters = (
        (IR.ONNX, torch2onnx),
        (IR.TORCHSCRIPT, torch2torchscript),
    )
    for supported_ir, converter in converters:
        if ir_type == supported_ir:
            return converter
    raise KeyError(f'Unexpected IR type {ir_type}')
|
|
|
|
def main():
    """Export a checkpoint to the backend selected by the deploy config.

    The steps run in this order:

    1. Parse the command line, then configure the root logger and the
       pipeline manager.
    2. Load the deploy and model configs and create the work directory.
    3. Optionally export SDK information (``--dump-info``).
    4. Convert the torch model to the intermediate representation (ONNX or
       TorchScript) named in the deploy config.
    5. Optionally split the IR file into partitions.
    6. Convert the IR file(s) with ``to_backend``.
    7. Visualize the backend model and the PyTorch model, each in its own
       child process.
    """
    args = parse_args()
    # Child processes are started with the 'spawn' start method.
    set_start_method('spawn', force=True)
    logger = get_root_logger()
    # getLevelName maps a registered level name to its numeric level.
    log_level = logging.getLevelName(args.log_level)
    logger.setLevel(log_level)

    # Pipeline functions that get multiprocess mode enabled and the chosen
    # log level applied through the pipeline manager.
    pipeline_funcs = [
        torch2onnx, torch2torchscript, extract_model, create_calib_input_data
    ]
    PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
    PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint
    # NOTE(review): `quant` and `quant_image_dir` are read here but never
    # used later in this function.
    quant = args.quant
    quant_image_dir = args.quant_image_dir

    # Load the deploy config and the model config.
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # Create the work directory if it does not exist.
    mmengine.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        export2SDK(
            deploy_cfg,
            model_cfg,
            args.work_dir,
            pth=checkpoint_path,
            device=args.device)

    # Shared value (typecode 'd', initial 0, no lock) passed to
    # create_process, which treats a non-zero value as failure.
    ret_value = mp.Value('d', 0, lock=False)

    # Convert the torch model to the intermediate representation.
    ir_config = get_ir_config(deploy_cfg)
    ir_save_file = ir_config['save_file']
    ir_type = IR.get(ir_config['type'])
    torch2ir(ir_type)(
        args.img,
        args.work_dir,
        ir_save_file,
        deploy_cfg_path,
        model_cfg_path,
        checkpoint_path,
        device=args.device)

    # Path of the IR file written by the conversion above.
    ir_files = [osp.join(args.work_dir, ir_save_file)]

    # Partition the IR model when the deploy config asks for it.
    partition_cfgs = get_partition_config(deploy_cfg)

    if partition_cfgs is not None:

        # Use an explicit 'partition_cfg' entry when present; otherwise look
        # up a predefined partition config by its 'type'.
        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = get_predefined_partition_cfg(
                deploy_cfg, partition_cfgs['type'])

        # The partition files replace the original IR file in `ir_files`.
        origin_ir_file = ir_files[0]
        ir_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            extract_model(
                origin_ir_file,
                start,
                end,
                dynamic_axes=dynamic_axes,
                save_file=save_path)

            ir_files.append(save_path)

    # NOTE(review): this initial value is overwritten by the to_backend call
    # below before it is read.
    backend_files = ir_files
    # Backend selected by the deploy config.
    backend = get_backend(deploy_cfg)

    # Convert the IR file(s) to the backend; multiprocess mode for
    # to_backend is enabled only for TensorRT.
    PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
    if backend == Backend.TENSORRT:
        PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
    backend_files = to_backend(
        backend,
        ir_files,
        work_dir=args.work_dir,
        deploy_cfg=deploy_cfg,
        log_level=log_level,
        device=args.device,
        uri=args.uri)

    # Fall back to the conversion image when no test image was given.
    if args.test_img is None:
        args.test_img = args.img

    extra = dict(
        backend=backend,
        output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
        show_result=args.show)
    # Only SNPE additionally receives the remote device URI.
    if backend == Backend.SNPE:
        extra['uri'] = args.uri

    # Visualize the converted backend model in a child process.
    create_process(
        f'visualize {backend.value} model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=extra,
        ret_value=ret_value)

    # Visualize the original PyTorch model in a child process.
    create_process(
        'visualize pytorch model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)
    logger.info('All process success.')
|
|
|
|
# Run the export only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
|