| | from io import BytesIO |
| | import pickle |
| | import time |
| | import torch |
| | from tqdm import tqdm |
| | from collections import OrderedDict |
| |
|
| |
|
def load_inputs(path, device, is_half=False):
    """Load a ``torch.save``-d mapping of tensors and prepare it for *device*.

    The file is deserialized onto the CPU first, then every value is moved to
    *device*. Floating-point values are normalized to the requested precision:
    float32 becomes float16 when *is_half* is true, and float16 becomes float32
    otherwise. Values of any other dtype are moved but never cast.

    Args:
        path: File produced by ``torch.save`` holding a mapping of tensors
            (used elsewhere in this module as ``model(**inputs)`` kwargs).
        device: Target device for every tensor.
        is_half: Select float16 (True) or float32 (False) precision.

    Returns:
        The mapping returned by ``torch.load``, updated in place.
    """
    inputs = torch.load(path, map_location=torch.device("cpu"))
    # Only one floating dtype needs converting for a given precision choice.
    if is_half:
        convert_from, convert_to = torch.float32, torch.float16
    else:
        convert_from, convert_to = torch.float16, torch.float32
    for name in list(inputs):
        tensor = inputs[name].to(device)
        if tensor.dtype == convert_from:
            tensor = tensor.to(convert_to)
        inputs[name] = tensor
    return inputs
| |
|
| |
|
def benchmark(
    model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
):
    """Time ``model(**inputs)`` over *epoch* iterations and report the mean latency.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs_path: File of keyword-argument tensors, loaded via ``load_inputs``.
        device: Device the inputs are moved to.
        epoch: Number of timed iterations; must be at least 1.
        is_half: Normalize float inputs to float16 (True) or float32 (False).

    Returns:
        The average time per iteration in milliseconds (also printed).

    Raises:
        ValueError: If ``epoch`` is less than 1.
    """
    # Validate before loading inputs; the average below divides by epoch.
    if epoch < 1:
        raise ValueError(f"epoch must be >= 1, got {epoch}")
    parm = load_inputs(inputs_path, device, is_half)
    # CUDA kernels are launched asynchronously: without synchronizing around
    # the call, perf_counter would measure launch overhead, not execution time.
    is_cuda = "cuda" in str(device)
    total_ts = 0.0
    # Outputs are discarded, so skip autograd bookkeeping during timing.
    with torch.no_grad():
        for _ in tqdm(range(epoch)):
            if is_cuda:
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()
            model(**parm)
            if is_cuda:
                torch.cuda.synchronize(device)
            total_ts += time.perf_counter() - start_time
    avg_ms = (total_ts * 1000) / epoch
    print(f"num_epoch: {epoch} | avg time(ms): {avg_ms}")
    return avg_ms
| |
|
| |
|
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
    """Warm *model* up by running a short ``benchmark`` pass (5 iterations by default).

    The work is delegated entirely to ``benchmark``, so its latency line is
    printed as usual. Presumably the intent (per the name) is to let a
    TorchScript model settle before real measurements; confirm with callers.
    """
    benchmark(
        model=model,
        inputs_path=inputs_path,
        device=device,
        epoch=epoch,
        is_half=is_half,
    )
| |
|
| |
|
def to_jit_model(
    model_path,
    model_type: str,
    mode: str = "trace",
    inputs_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Load a model by type and compile it with TorchScript.

    Args:
        model_path: Checkpoint path passed to the type-specific loader.
        model_type: One of "synthesizer", "rmvpe", "hubert" (case-insensitive).
        mode: "trace" (requires *inputs_path*) or "script".
        inputs_path: Example-input file for tracing, loaded via ``load_inputs``.
        device: Device used to load the model and the example inputs.
        is_half: Cast the model (and traced inputs) to float16; otherwise float32.

    Returns:
        ``(model, model_jit)``: the eager model (eval mode, cast) and its
        TorchScript counterpart.

    Raises:
        ValueError: Unknown *mode*, missing *inputs_path* in trace mode, or
            unknown *model_type*.
    """
    # Validate cheaply before any (potentially expensive) model load.
    if mode not in ("trace", "script"):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'trace' or 'script'")
    if mode == "trace" and not inputs_path:
        raise ValueError("inputs_path is required when mode='trace'")

    kind = model_type.lower()
    if kind == "synthesizer":
        from .get_synthesizer import get_synthesizer

        model, _ = get_synthesizer(model_path, device)
        # Compile the inference path rather than the training forward.
        model.forward = model.infer
    elif kind == "rmvpe":
        from .get_rmvpe import get_rmvpe

        model = get_rmvpe(model_path, device)
    elif kind == "hubert":
        from .get_hubert import get_hubert_model

        model = get_hubert_model(model_path, device)
        model.forward = model.infer
    else:
        raise ValueError(f"No model type named {model_type}")

    model = model.eval()
    model = model.half() if is_half else model.float()
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)
    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()
    return (model, model_jit)
| |
|
| |
|
def export(
    model: torch.nn.Module,
    mode: str = "trace",
    inputs: dict = None,
    device=torch.device("cpu"),
    is_half: bool = False,
) -> dict:
    """Compile *model* with TorchScript and serialize it into an in-memory checkpoint.

    Args:
        model: Module to export. It is cast (half/float) and put in eval mode
            in place before compilation.
        mode: "trace" (requires *inputs*) or "script".
        inputs: Keyword example inputs for ``torch.jit.trace``.
        device: Device the compiled module is moved to before saving.
        is_half: Export in float16 precision; otherwise float32.

    Returns:
        An ``OrderedDict`` with ``"model"`` (bytes written by ``torch.jit.save``)
        and ``"is_half"``.

    Raises:
        ValueError: Unknown *mode*, or *inputs* missing in trace mode.
    """
    # Validate first so an invalid call does not mutate the caller's model.
    if mode == "trace":
        if inputs is None:
            raise ValueError("inputs are required when mode='trace'")
    elif mode != "script":
        raise ValueError(f"Unsupported mode {mode!r}; expected 'trace' or 'script'")

    model = model.half() if is_half else model.float()
    model.eval()
    if mode == "trace":
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)
    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()

    buffer = BytesIO()
    torch.jit.save(model_jit, buffer)
    # Release the compiled module early; only its serialized bytes are kept.
    del model_jit
    cpt = OrderedDict()
    cpt["model"] = buffer.getvalue()
    cpt["is_half"] = is_half
    return cpt
| |
|
| |
|
def load(path: str):
    """Read and return the pickled checkpoint stored at *path*.

    Security: unpickling can execute arbitrary code. Only load files from
    trusted sources.
    """
    with open(path, "rb") as stream:
        checkpoint = pickle.Unpickler(stream).load()
    return checkpoint
| |
|
| |
|
def save(ckpt: dict, save_path: str):
    """Pickle *ckpt* into *save_path*, replacing any existing file."""
    with open(save_path, "wb") as stream:
        pickle.Pickler(stream).dump(ckpt)
| |
|
| |
|
def rmvpe_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Export an RMVPE model to a pickled TorchScript checkpoint on disk.

    Args:
        model_path: RMVPE checkpoint path.
        mode: "script" or "trace" (requires *inputs_path*).
        inputs_path: Example-input file for tracing.
        save_path: Output file. Defaults to *model_path* with a trailing
            ".pth"/".pt" removed, plus ".half.jit" or ".jit".
        device: Export device. A bare "cuda" is pinned to "cuda:0".
        is_half: Export in float16; otherwise float32.

    Returns:
        The checkpoint dict that was saved: ``export``'s output plus
        ``"device"`` (as a string).

    Raises:
        ValueError: Unknown *mode*, or *inputs_path* missing in trace mode.
    """
    # Fail fast, before the model is loaded.
    if mode not in ("trace", "script"):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'trace' or 'script'")
    if mode == "trace" and not inputs_path:
        raise ValueError("inputs_path is required when mode='trace'")
    if not save_path:
        # Remove a literal suffix. str.rstrip(".pth") would strip any trailing
        # run of '.', 'p', 't', 'h' characters (e.g. "depth.pth" -> "de").
        save_path = model_path
        for suffix in (".pth", ".pt"):
            if save_path.endswith(suffix):
                save_path = save_path[: -len(suffix)]
                break
        save_path += ".half.jit" if is_half else ".jit"
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    from .get_rmvpe import get_rmvpe

    model = get_rmvpe(model_path, device)
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    ckpt["device"] = str(device)
    save(ckpt, save_path)
    return ckpt
| |
|
| |
|
def synthesizer_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Export a synthesizer to TorchScript and save it with its original checkpoint entries.

    The loaded checkpoint dict has its ``"weight"`` entry dropped. It gains
    ``"model"`` (the serialized TorchScript bytes) and ``"device"``, and is
    then pickled to *save_path*.

    Args:
        model_path: Synthesizer checkpoint path.
        mode: "script" or "trace" (requires *inputs_path*).
        inputs_path: Example-input file for tracing.
        save_path: Output file. Defaults to *model_path* with a trailing
            ".pth"/".pt" removed, plus ".half.jit" or ".jit".
        device: Export device. A bare "cuda" is pinned to "cuda:0".
        is_half: Export in float16; otherwise float32.

    Returns:
        The checkpoint dict that was saved.

    Raises:
        ValueError: Unknown *mode*, or *inputs_path* missing in trace mode.
    """
    # Fail fast, before the model is loaded.
    if mode not in ("trace", "script"):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'trace' or 'script'")
    if mode == "trace" and not inputs_path:
        raise ValueError("inputs_path is required when mode='trace'")
    if not save_path:
        # Remove a literal suffix. str.rstrip(".pth") would strip any trailing
        # run of '.', 'p', 't', 'h' characters (e.g. "hubert.pth" -> "huber").
        save_path = model_path
        for suffix in (".pth", ".pt"):
            if save_path.endswith(suffix):
                save_path = save_path[: -len(suffix)]
                break
        save_path += ".half.jit" if is_half else ".jit"
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    from .get_synthesizer import get_synthesizer

    model, cpt = get_synthesizer(model_path, device)
    assert isinstance(cpt, dict)
    # Compile the inference path rather than the training forward.
    model.forward = model.infer
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    # Swap the raw weights for the serialized TorchScript module.
    cpt.pop("weight")
    cpt["model"] = ckpt["model"]
    # NOTE(review): stored as the device object itself rather than str(device);
    # kept as-is so existing loaders of this file format are unaffected.
    cpt["device"] = device
    save(cpt, save_path)
    return cpt
| |
|