#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
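"""AutoModel: builds and runs FunCineForge models for batched inference.

Construction resolves the model from the hub (or a local dir), builds the
tokenizer and face encoder, loads pretrained weights, and moves the model to
the requested device; ``inference()`` then runs it batch by batch over a list
of input items.
"""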
import time
import torch
import logging
import os
from tqdm import tqdm
from funcineforge.utils.misc import deep_update
from funcineforge.utils.set_all_random_seed import set_all_random_seed
from funcineforge.utils.load_pretrained_model import load_pretrained_model
from funcineforge.download.download_model_from_hub import download_model
from funcineforge.tokenizer import FunCineForgeTokenizer
from funcineforge.face import FaceRecIR101
import importlib
def prepare_data_iterator(data_in, input_len=None):
    """Split a list of input items into parallel (key, data) lists."""
    # inference() may pass input_len=None; default to the full list
    input_len = len(data_in) if input_len is None else input_len
    data_list = []
    key_list = []
    for idx in range(input_len):
        item = data_in[idx]
        utt = item["utt"]
        data_list.append(item)
        key_list.append(utt)
    return key_list, data_list
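
# Each item handed to prepare_data_iterator is expected to be a dict with at
# least an "utt" key (used as the sample key); the remaining fields are
# model-specific. A hypothetical example:
#   data_in = [{"utt": "sample_0001", "wav": "/path/to/a.wav"}, ...]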


class AutoModel:
    def __init__(self, **kwargs):
        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
        logging.basicConfig(level=log_level)

        model, kwargs = self.build_model(**kwargs)
        self.kwargs = kwargs
        self.model = model
        self.model_path = kwargs.get("model_path")

    @staticmethod
    def build_model(**kwargs):
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info("download models from {} or local dir".format(kwargs.get("hub", "ms")))
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device

        torch.set_num_threads(kwargs.get("ncpu", 4))

        # build tokenizer
        vocab_size = -1  # stays -1 when no tokenizer / vocab is available
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            tokenizer = FunCineForgeTokenizer(**kwargs.get("tokenizer_conf", {}))
            kwargs["token_list"] = (
                tokenizer.token_list if hasattr(tokenizer, "token_list") else None
            )
            kwargs["token_list"] = (
                tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
            )
            if kwargs["token_list"] is not None:
                vocab_size = len(kwargs["token_list"])
            elif hasattr(tokenizer, "get_vocab_size"):
                vocab_size = tokenizer.get_vocab_size()
        kwargs["tokenizer"] = tokenizer

        # build face_encoder
        face_encoder = kwargs.get("face_encoder", None)
        if face_encoder is not None:
            face_encoder = FaceRecIR101(**kwargs.get("face_encoder_conf", {}))
        kwargs["face_encoder"] = face_encoder

        # build model
        model_conf = {}
        model_class_name = kwargs["model"]
        deep_update(model_conf, kwargs.get("model_conf", {}))
        deep_update(model_conf, kwargs)
        module = importlib.import_module("funcineforge.models")
        model_class = getattr(module, model_class_name)
        model = model_class(**model_conf, vocab_size=vocab_size)

        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None and os.path.exists(init_param):
            logging.info(f"Loading pretrained params from ckpt: {init_param}")
            load_pretrained_model(
                path=init_param,
                model=model,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                scope_map=kwargs.get("scope_map", []),
                excludes=kwargs.get("excludes", None),
                use_deepspeed=kwargs.get("train_conf", {}).get("use_deepspeed", False),
                save_deepspeed_zero_fp32=kwargs.get("save_deepspeed_zero_fp32", True),
            )

        # fp16
        if kwargs.get("fp16", False):
            model.to(torch.float16)
        elif kwargs.get("bf16", False):
            model.to(torch.bfloat16)
        model.to(device)

        return model, kwargs

    def __call__(self, *args, **cfg):
        kwargs = self.kwargs
        deep_update(kwargs, cfg)
        res = self.model(*args, **kwargs)  # forward kwargs as keyword arguments
        return res

    def inference(self, input, input_len=None, model=None, kwargs=None, **cfg):
        kwargs = self.kwargs if kwargs is None else kwargs
        deep_update(kwargs, cfg)
        model = self.model if model is None else model
        model.eval()

        batch_size = kwargs.get("batch_size", 1)

        key_list, data_list = prepare_data_iterator(input, input_len=input_len)

        speed_stats = {}
        num_samples = len(data_list)
        disable_pbar = self.kwargs.get("disable_pbar", False)
        pbar = (
            tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
        )

        time_speech_total = 0.0
        time_escape_total = 0.0
        count = 0
        log_interval = kwargs.get("log_interval", None)
        result_list = []  # accumulate per-batch results so inference() can return them

        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "data_lengths": end_idx - beg_idx, "key": key_batch}

            time1 = time.perf_counter()
            with torch.no_grad():
                res = model.inference(**batch, **kwargs)
            if isinstance(res, (list, tuple)):
                results = res[0] if len(res) > 0 else [{"text": ""}]
                meta_data = res[1] if len(res) > 1 else {}
            else:
                # some models may return a bare result instead of (results, meta)
                results, meta_data = res, {}
            result_list.extend(results if isinstance(results, list) else [results])
            time2 = time.perf_counter()

            batch_data_time = meta_data.get("batch_data_time", -1)
            time_escape = time2 - time1
            speed_stats["forward"] = f"{time_escape:0.3f}"
            speed_stats["batch_size"] = f"{len(results)}"
            # batch_data_time is -1 when the model reports no audio duration;
            # skip rtf rather than divide by a non-positive value
            if batch_data_time > 0:
                speed_stats["rtf"] = f"{time_escape / batch_data_time:0.3f}"

            description = f"{speed_stats}, "
            if pbar:
                pbar.update(len(key_batch))  # last batch may be smaller than batch_size
                pbar.set_description(description)
            elif log_interval is not None and count % log_interval == 0:
                logging.info(
                    f"processed {end_idx}/{num_samples} samples: {key_batch[0]}"
                )

            if batch_data_time > 0:
                time_speech_total += batch_data_time
            time_escape_total += time_escape
            count += 1

        if pbar and time_speech_total > 0:
            pbar.set_description(f"rtf_avg: {time_escape_total / time_speech_total:0.3f}")
        torch.cuda.empty_cache()

        return result_list
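

if __name__ == "__main__":
    # Minimal usage sketch. The model name and input fields below are
    # hypothetical placeholders (not identifiers shipped with this repo),
    # shown only to illustrate the AutoModel / inference() call pattern.
    auto_model = AutoModel(
        model="SomeFunCineForgeModel",  # hypothetical model id / class name
        device="cuda",
        batch_size=2,
    )
    demo_inputs = [
        {"utt": "sample_0001", "wav": "/path/to/a.wav"},  # hypothetical fields
        {"utt": "sample_0002", "wav": "/path/to/b.wav"},
    ]
    print(auto_model.inference(demo_inputs))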