# Copyright (c) OpenMMLab. All rights reserved.
# Modified from mmdeploy/tools/deploy.py; some code was removed so that the
# script focuses only on the ONNX report.
"""Convert a PyTorch model to an IR / backend model and visualize results."""
import argparse
import logging
import os
import os.path as osp
from functools import partial

import mmengine
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method

from mmdeploy.apis import (create_calib_input_data, extract_model,
                           get_predefined_partition_cfg, torch2onnx,
                           torch2torchscript, visualize_model)
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.apis.utils import to_backend
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
                            get_ir_config, get_partition_config,
                            get_root_logger, load_config, target_wrapper)


def parse_args():
    """Parse the command-line arguments of the export script.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model model')
    parser.add_argument(
        '--test-img',
        default=None,
        type=str,
        nargs='+',
        help='image used to test model')
    parser.add_argument(
        '--work-dir',
        default=os.getcwd(),
        help='the dir to save logs and models')
    # Adjacent literals instead of a backslash continuation inside the
    # string, so that source indentation does not leak into the help text.
    parser.add_argument(
        '--calib-dataset-cfg',
        help='dataset config path used to calibrate in int8 mode. If not '
        'specified, it will use "val" dataset in model config instead.',
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    parser.add_argument(
        '--quant-image-dir',
        default=None,
        help='Image directory for quantize model.')
    parser.add_argument(
        '--quant', action='store_true', help='Quantize model to low bit.')
    parser.add_argument(
        '--uri',
        default='192.168.1.1:60000',
        help='Remote ipv4:port or ipv6:port for inference on edge device.')
    args = parser.parse_args()

    return args


def create_process(name, target, args, kwargs, ret_value=None):
    """Run ``target`` in a child process and wait for it to finish.

    ``target`` is wrapped with ``target_wrapper`` together with the root
    logger's level and ``ret_value`` before being handed to the process.

    Args:
        name (str): Human-readable task name used in the log messages.
        target (Callable): The function to execute in the child process.
        args (tuple): Positional arguments forwarded to ``target``.
        kwargs (dict): Keyword arguments forwarded to ``target``.
        ret_value: Optional shared object exposing a ``value`` attribute.
            When given, a non-zero ``value`` after the child has joined is
            treated as failure. Defaults to None, in which case the outcome
            is not checked.

    Raises:
        SystemExit: With exit status 1 when ``ret_value`` reports a failure.
    """
    logger = get_root_logger()
    logger.info(f'{name} start.')
    log_level = logger.level

    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is not None:
        if ret_value.value != 0:
            logger.error(f'{name} failed.')
            # ``exit`` is injected by the ``site`` module and is not
            # guaranteed to exist; raising SystemExit has the same effect.
            raise SystemExit(1)
        else:
            logger.info(f'{name} success.')


def torch2ir(ir_type: IR):
    """Return the conversion function from torch to the intermediate
    representation.

    Args:
        ir_type (IR): The type of the intermediate representation.

    Returns:
        Callable: ``torch2onnx`` for ``IR.ONNX`` or ``torch2torchscript``
        for ``IR.TORCHSCRIPT``.

    Raises:
        KeyError: If ``ir_type`` is neither of the two supported types.
    """
    if ir_type == IR.ONNX:
        return torch2onnx
    elif ir_type == IR.TORCHSCRIPT:
        return torch2torchscript
    else:
        raise KeyError(f'Unexpected IR type {ir_type}')


def main():
    """Export the model to IR, convert it for the backend and visualize.

    The steps are: parse the arguments, configure logging and the pipeline
    manager, optionally dump SDK information, convert the checkpoint to the
    configured IR, optionally partition the IR model, convert it for the
    configured backend, and finally run the visualization for both the
    backend model and the original PyTorch model in child processes.
    """
    args = parse_args()
    set_start_method('spawn', force=True)
    logger = get_root_logger()
    log_level = logging.getLevelName(args.log_level)
    logger.setLevel(log_level)

    pipeline_funcs = [
        torch2onnx, torch2torchscript, extract_model, create_calib_input_data
    ]
    PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
    PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint

    # load deploy_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # create work_dir if it does not exist
    mmengine.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        export2SDK(
            deploy_cfg,
            model_cfg,
            args.work_dir,
            pth=checkpoint_path,
            device=args.device)

    # Shared value handed to the child processes; create_process treats a
    # non-zero value as failure.
    ret_value = mp.Value('d', 0, lock=False)

    # convert to IR
    ir_config = get_ir_config(deploy_cfg)
    ir_save_file = ir_config['save_file']
    ir_type = IR.get(ir_config['type'])
    torch2ir(ir_type)(
        args.img,
        args.work_dir,
        ir_save_file,
        deploy_cfg_path,
        model_cfg_path,
        checkpoint_path,
        device=args.device)

    # IR file(s) that are passed on to the backend conversion
    ir_files = [osp.join(args.work_dir, ir_save_file)]

    # partition model
    partition_cfgs = get_partition_config(deploy_cfg)

    if partition_cfgs is not None:

        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = get_predefined_partition_cfg(
                deploy_cfg, partition_cfgs['type'])

        # The partitioned sub-models replace the full IR file in ir_files.
        origin_ir_file = ir_files[0]
        ir_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            extract_model(
                origin_ir_file,
                start,
                end,
                dynamic_axes=dynamic_axes,
                save_file=save_path)

            ir_files.append(save_path)

    backend_files = ir_files
    # look up the target backend
    backend = get_backend(deploy_cfg)

    # convert to backend
    PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
    if backend == Backend.TENSORRT:
        PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
    backend_files = to_backend(
        backend,
        ir_files,
        work_dir=args.work_dir,
        deploy_cfg=deploy_cfg,
        log_level=log_level,
        device=args.device,
        uri=args.uri)

    # fall back to the conversion image when no test image was given
    if args.test_img is None:
        args.test_img = args.img

    extra = dict(
        backend=backend,
        output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
        show_result=args.show)
    if backend == Backend.SNPE:
        extra['uri'] = args.uri

    # get backend inference result, try render
    create_process(
        f'visualize {backend.value} model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=extra,
        ret_value=ret_value)

    # get pytorch model inference result, try visualize if possible
    create_process(
        'visualize pytorch model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)
    logger.info('All process success.')


if __name__ == '__main__':
    main()