| | import types |
| |
|
| | import argparse |
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import onnx |
| | import onnxsim |
| |
|
| | from basicsr.archs.ddcolor_arch import DDColor |
| |
|
| | from onnx import load_model, save_model, shape_inference |
| | from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference |
| |
|
| |
|
def parse_args():
    """Parse the command-line options of the DDColor ONNX export script.

    Arguments are read from ``sys.argv``.

    Returns:
        argparse.Namespace: Namespace with the attributes ``input_size``,
        ``batch_size``, ``model_path``, ``model_size``, ``decoder_type``,
        ``export_path`` and ``opset``.
    """
    parser = argparse.ArgumentParser(description="Export DDColor model to ONNX.")
    parser.add_argument(
        "--input_size",
        type=int,
        default=512,
        help="Input image dimension.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Input batch size.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model checkpoint to load weights from.",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="tiny",
        help=(
            "Model size: 'tiny' selects the convnext-t encoder, "
            "any other value selects convnext-l."
        ),
    )
    parser.add_argument(
        "--decoder_type",
        type=str,
        default="MultiScaleColorDecoder",
        help="Decoder type: 'MultiScaleColorDecoder' or 'SingleColorDecoder'.",
    )
    parser.add_argument(
        "--export_path",
        type=str,
        default="./model.onnx",
        help="Path to export ONNX model.",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=12,
        help="ONNX opset version.",
    )

    return parser.parse_args()
| |
|
| |
|
def create_onnx_export(args):
    """Build a DDColor model, load its weights and export it to ONNX.

    Args:
        args: Namespace-like object providing ``input_size``, ``batch_size``,
            ``model_path``, ``model_size``, ``decoder_type``, ``export_path``
            and ``opset``. A ``batch_size`` of 0 marks the batch axis as
            dynamic; an ``input_size`` of 0 marks height and width as dynamic.

    Raises:
        NotImplementedError: If ``args.decoder_type`` is not one of
            ``'MultiScaleColorDecoder'`` or ``'SingleColorDecoder'``.
    """
    # Decoder-specific constructor arguments; everything else is shared.
    decoder_options = {
        'MultiScaleColorDecoder': {
            'num_queries': 100,
            'num_scales': 3,
            'dec_layers': 9,
        },
        'SingleColorDecoder': {
            'num_queries': 256,
        },
    }
    # Validate first so an unsupported decoder fails before any model is built.
    if args.decoder_type not in decoder_options:
        raise NotImplementedError(
            f"decoder_type {args.decoder_type!r} not implemented."
        )

    # A value of 0 requests a dynamic axis, but tracing still needs a concrete,
    # non-empty example input, so substitute a fallback shape for the trace.
    # NOTE(review): 512 mirrors the CLI default for --input_size; whether the
    # traced graph is fully resolution-independent has not been verified here.
    fallback_batch = 1
    fallback_size = 512
    trace_batch = args.batch_size or fallback_batch
    trace_size = args.input_size or fallback_size

    device = torch.device('cpu')
    encoder_name = 'convnext-t' if args.model_size == 'tiny' else 'convnext-l'

    model = DDColor(
        encoder_name=encoder_name,
        decoder_name=args.decoder_type,
        input_size=[trace_size, trace_size],
        num_output_channels=2,
        last_norm='Spectral',
        do_normalize=False,
        **decoder_options[args.decoder_type],
    ).to(device)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.model_path, map_location=device)
    # strict=False: keys that do not match the model are ignored silently.
    model.load_state_dict(checkpoint['params'], strict=False)
    model.eval()

    channels = 3
    random_input = torch.rand(
        (trace_batch, channels, trace_size, trace_size),
        dtype=torch.float32,
    )

    # The same axis mapping is applied to both the input and the output.
    dynamic_axes = {}
    if args.batch_size == 0:
        dynamic_axes[0] = "batch"
    if args.input_size == 0:
        dynamic_axes[2] = "height"
        dynamic_axes[3] = "width"

    torch.onnx.export(
        model,
        random_input,
        args.export_path,
        opset_version=args.opset,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": dynamic_axes,
            "output": dynamic_axes,
        },
    )
| |
|
def check_onnx_export(export_path):
    """Run shape inference, validation and simplification on an ONNX file.

    The file at ``export_path`` is rewritten in place after each stage:
    ONNX shape inference, onnxruntime symbolic shape inference, and finally
    onnxsim simplification (after ``onnx.checker.check_model`` has run).

    Args:
        export_path: Path of the ONNX model file; it is overwritten.

    Raises:
        AssertionError: If ``onnxsim.simplify`` reports that its check of the
            simplified model failed.
    """
    # Stage 1: ONNX shape inference with type checking.
    inferred_model = shape_inference.infer_shapes(
        load_model(export_path),
        check_type=True,
        strict_mode=True,
        data_prop=True,
    )
    save_model(inferred_model, export_path)

    # Stage 2: onnxruntime symbolic shape inference on the re-loaded file.
    symbolic_model = SymbolicShapeInference.infer_shapes(
        load_model(export_path),
        auto_merge=True,
        guess_output_rank=True,
    )
    save_model(symbolic_model, export_path)

    # Stage 3: validate the graph, then simplify it.
    model_onnx = onnx.load(export_path)
    onnx.checker.check_model(model_onnx)

    model_onnx, check = onnxsim.simplify(model_onnx)
    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped under ``python -O``; the exception type and message are kept.
    if not check:
        raise AssertionError("assert check failed")
    onnx.save(model_onnx, export_path)
| |
|
| |
|
def main():
    """Export the model to ONNX, then verify and simplify the exported file."""
    args = parse_args()

    create_onnx_export(args)
    print(f'ONNX file successfully created at {args.export_path}')
    check_onnx_export(args.export_path)
    print(f'ONNX file at {args.export_path} verified shapes and simplified')


if __name__ == '__main__':
    main()
| | |