import os
from rknn.api import RKNN
from math import exp
from sys import exit
import argparse
import onnxscript
from onnxscript.rewriter import pattern
import onnx.numpy_helper as onh
import numpy as np
import onnx
import onnxruntime as ort
from rknn.utils import onnx_edit

# Resolve all relative paths (model files, dataset.txt, exported .rknn)
# against the script's own directory, regardless of the caller's CWD.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Fixed time-axis length baked into the converted model: used both as the
# value fed to the `speech_lengths` sub-graph and as the second dimension
# of the `speech` input shape in convert_encoder() (RKNN needs static shapes).
speech_length = 171
|
def convert_encoder():
    """Convert the SenseVoice encoder ONNX model to an RK3588 RKNN model.

    Pipeline:
      1. Extract and run the `speech_lengths -> /make_pad_mask/Cast_2_output_0`
         sub-graph with onnxruntime, producing the padding-mask tensor for the
         fixed `speech_length`, which is then frozen into the RKNN model via
         `input_initial_val`.
      2. Rewrite the model with rknn's `onnx_edit` so `encoder_out` is emitted
         transposed (a,b,c -> a,c,b).
      3. Config / load / build / export with the RKNN toolkit, exiting the
         process with the toolkit's return code on any failure.

    Side effects: writes "extract_model.onnx",
    "sense-voice-encoder_edited.onnx" and "sense-voice-encoder.rknn" into the
    script directory.
    """
    rknn = RKNN(verbose=True)

    # Plain literal — the original used an f-string with no placeholders.
    ONNX_MODEL = "sense-voice-encoder.onnx"
    RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
    DATASET = "dataset.txt"
    QUANTIZE = False

    # Precompute the padding mask for the fixed speech_length by running just
    # the mask sub-graph on CPU.
    onnx.utils.extract_model(ONNX_MODEL, "extract_model.onnx",
                             ['speech_lengths'],
                             ['/make_pad_mask/Cast_2_output_0'])
    sess = ort.InferenceSession("extract_model.onnx",
                                providers=['CPUExecutionProvider'])
    extract_result = sess.run(
        None, {"speech_lengths": np.array([speech_length], dtype=np.int64)})[0]

    # Emit encoder_out transposed so downstream consumers get (a, c, b) layout.
    # NOTE(review): onnx_edit's return value was unused in the original and its
    # success convention is not documented here, so it is deliberately ignored.
    onnx_edit(model=ONNX_MODEL,
              export_path=ONNX_MODEL.replace(".onnx", "_edited.onnx"),
              outputs_transform={'encoder_out': 'a,b,c->a,c,b'})
    ONNX_MODEL = ONNX_MODEL.replace(".onnx", "_edited.onnx")

    print('--> Config model')
    rknn.config(quantized_algorithm='normal', quantized_method='channel',
                target_platform='rk3588', optimization_level=3)
    print('done')

    print("--> Loading model")
    # The mask input is pinned to the precomputed constant via
    # input_initial_val; only `speech` remains a real runtime input.
    ret = rknn.load_onnx(
        model=ONNX_MODEL,
        inputs=["speech", "/make_pad_mask/Cast_2_output_0"],
        input_size_list=[[1, speech_length, 560],
                         [extract_result.shape[0], extract_result.shape[1]]],
        input_initial_val=[None, extract_result],
    )
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    print('--> Building model')
    ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET,
                     rknn_batch_size=None)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    print('--> Export RKNN model')
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print('done')
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("model", type=str, help="model to convert", choices=["encoder", "all"], nargs='?') |
| | args = parser.parse_args() |
| | if args.model is None: |
| | args.model = "all" |
| |
|
| | if args.model == "encoder": |
| | convert_encoder() |
| | elif args.model == "all": |
| | convert_encoder() |
| | else: |
| | print(f"Unknown model: {args.model}") |
| | exit(1) |
| |
|