Upload 10 files
Browse files- .gitattributes +2 -0
- config.json +28 -0
- configuration_fastconformer.py +33 -0
- decoder_joint-asr.fp16.ts +3 -0
- encoder-asr.fp16.ts +3 -0
- example_inference.py +13 -0
- model_meta.json +18 -0
- modeling_fastconformer.py +168 -0
- preproc.pt +3 -0
- processing_fastconformer.py +92 -0
- tokenizer.model +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
decoder_joint-asr.fp16.ts filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
encoder-asr.fp16.ts filter=lfs diff=lfs merge=lfs -text
|
config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "fastconformer_tdt",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"FastConformerTDTForCTC"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_fastconformer.FastConformerTDTConfig",
|
| 8 |
+
"AutoModel": "modeling_fastconformer.FastConformerTDTForCTC",
|
| 9 |
+
"AutoModelForCTC": "modeling_fastconformer.FastConformerTDTForCTC",
|
| 10 |
+
"AutoProcessor": "processing_fastconformer.FastConformerProcessor"
|
| 11 |
+
},
|
| 12 |
+
"vocab_size": 5000,
|
| 13 |
+
"blank_id": 5000,
|
| 14 |
+
"durations": [
|
| 15 |
+
0,
|
| 16 |
+
1,
|
| 17 |
+
2,
|
| 18 |
+
3,
|
| 19 |
+
4
|
| 20 |
+
],
|
| 21 |
+
"num_durations": 5,
|
| 22 |
+
"pred_hidden": 640,
|
| 23 |
+
"pred_rnn_layers": 1,
|
| 24 |
+
"max_symbols": 10,
|
| 25 |
+
"enc_d_model": 1024,
|
| 26 |
+
"feat_in": 128,
|
| 27 |
+
"sample_rate": 16000
|
| 28 |
+
}
|
configuration_fastconformer.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class FastConformerTDTConfig(PretrainedConfig):
    """Config for the FastConformer TDT (RNNT) ASR model wrapped for HF transformers.

    Each constructor argument is stored as an attribute of the same name
    (``durations`` is copied into a fresh ``list``). Any additional keyword
    arguments are forwarded untouched to ``PretrainedConfig.__init__``.
    """

    model_type = "fastconformer_tdt"

    def __init__(
        self,
        vocab_size=3000,
        blank_id=3000,
        durations=(0, 1, 2, 3, 4),
        num_durations=5,
        pred_hidden=640,
        pred_rnn_layers=1,
        max_symbols=10,
        enc_d_model=1024,
        feat_in=128,
        sample_rate=16000,
        **kwargs,
    ):
        # Assign the model-specific fields first, in declaration order, and
        # only then let the base class consume the generic keyword arguments.
        own_fields = {
            "vocab_size": vocab_size,
            "blank_id": blank_id,
            "durations": list(durations),  # tuple default -> mutable, JSON-friendly list
            "num_durations": num_durations,
            "pred_hidden": pred_hidden,
            "pred_rnn_layers": pred_rnn_layers,
            "max_symbols": max_symbols,
            "enc_d_model": enc_d_model,
            "feat_in": feat_in,
            "sample_rate": sample_rate,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)
|
decoder_joint-asr.fp16.ts
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6396aa30a8c252a7cf981ce3f4c63b1dc8783c83da7e9390acb04e4e1ec1a5e0
|
| 3 |
+
size 21539480
|
encoder-asr.fp16.ts
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b2c3789d91b06c0f19cd4b55156688fd291e325990361f7024b7f225444fb0d
|
| 3 |
+
size 887312617
|
example_inference.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Transcribe the audio files given on the command line, one result per line."""
import sys

import torch
import torchaudio
from transformers import AutoModel, AutoProcessor

REPO = "."
DEV = "cuda" if torch.cuda.is_available() else "cpu"


def main(paths):
    """Print ``<path>\\t<transcript>`` for every audio file in *paths*.

    Args:
        paths: audio file paths readable by ``torchaudio.load``.

    Returns:
        Process exit status: ``0`` on success, ``2`` when no path was given.
    """
    # Bail out before loading the model when there is nothing to transcribe.
    if not paths:
        print(f"usage: {sys.argv[0]} AUDIO [AUDIO ...]", file=sys.stderr)
        return 2

    model = AutoModel.from_pretrained(REPO, trust_remote_code=True).to(DEV).eval()
    proc = AutoProcessor.from_pretrained(REPO, trust_remote_code=True)

    with torch.no_grad():
        for path in paths:
            wav, sr = torchaudio.load(path)
            # Down-mix multi-channel audio by averaging; otherwise drop the channel axis.
            wav = wav.mean(0) if wav.shape[0] > 1 else wav.squeeze(0)
            inputs = proc(wav, sampling_rate=sr, return_tensors="pt").to(DEV)
            print(f"{path}\t{proc.batch_decode(model.generate(**inputs))[0]}")
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
|
model_meta.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 5000,
|
| 3 |
+
"blank_id": 5000,
|
| 4 |
+
"durations": [
|
| 5 |
+
0,
|
| 6 |
+
1,
|
| 7 |
+
2,
|
| 8 |
+
3,
|
| 9 |
+
4
|
| 10 |
+
],
|
| 11 |
+
"num_durations": 5,
|
| 12 |
+
"subsampling_factor": 8,
|
| 13 |
+
"pred_hidden": 640,
|
| 14 |
+
"pred_rnn_layers": 1,
|
| 15 |
+
"enc_d_model": 1024,
|
| 16 |
+
"max_symbols": 10,
|
| 17 |
+
"feat_in": 128
|
| 18 |
+
}
|
modeling_fastconformer.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HF `transformers`-compatible inference wrapper for the FastConformer TDT (RNNT)
|
| 3 |
+
ASR model. Wraps the exported TorchScript encoder + decoder_joint graphs.
|
| 4 |
+
|
| 5 |
+
Runtime deps: torch, transformers, sentencepiece. **No nemo_toolkit.**
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import PreTrainedModel
|
| 10 |
+
from transformers.modeling_outputs import ModelOutput
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional, List
|
| 13 |
+
|
| 14 |
+
from .configuration_fastconformer import FastConformerTDTConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _normalize_per_feature(x, seq_len, constant):
|
| 18 |
+
B, _, max_time = x.shape
|
| 19 |
+
steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(B, max_time)
|
| 20 |
+
valid = steps < seq_len.unsqueeze(1)
|
| 21 |
+
denom = valid.sum(dim=1)
|
| 22 |
+
mean = torch.where(valid.unsqueeze(1), x, torch.zeros_like(x)).sum(dim=2) / denom.unsqueeze(1)
|
| 23 |
+
var = torch.sum(torch.where(valid.unsqueeze(1), x - mean.unsqueeze(2), torch.zeros_like(x)) ** 2,
|
| 24 |
+
dim=2) / (denom.unsqueeze(1) - 1.0)
|
| 25 |
+
std = torch.sqrt(var).masked_fill(torch.sqrt(var).isnan(), 0.0) + constant
|
| 26 |
+
return (x - mean.unsqueeze(2)) / std.unsqueeze(2)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class ASRGreedyOutput(ModelOutput):
    """Container for greedy decoding results.

    Both fields default to ``None``; the model's ``generate`` method in this
    file fills them with the padded id matrix and the raw per-utterance ids.
    """

    # Token ids as one 2-D tensor, one row per utterance.
    sequences: Optional[torch.LongTensor] = None
    # The same ids as plain Python lists, one list per utterance.
    token_lists: Optional[List[List[int]]] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class FastConformerTDTForCTC(PreTrainedModel):
    """Named *ForCTC for AutoModel discoverability, but decoding is TDT/RNNT greedy.

    The TorchScript encoder and decoder/joint graphs plus the preprocessing
    tensors are loaded lazily from ``_artifacts_dir`` on first use, not in
    ``__init__``.
    """

    config_class = FastConformerTDTConfig
    main_input_name = "input_features"

    def __init__(self, config: FastConformerTDTConfig):
        super().__init__(config)
        # Single dummy parameter: its device/dtype is what the rest of the class
        # reads to learn where the model lives after .to() / .half().
        self._anchor = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
        self._artifacts_dir = getattr(config, "_name_or_path", ".") or "."
        self.encoder = None
        self.decoder_joint = None
        self._loaded = False
        self._io_dtype = torch.float32
        self._req_dtype = None  # user-requested compute dtype (torch_dtype / .half())
        self.post_init()

    def _ensure_loaded(self, device=None):
        """Load the TorchScript graphs and preprocessing tensors once.

        Args:
            device: target device; defaults to the device of ``_anchor``.
        """
        if self._loaded:
            return
        d = self._artifacts_dir
        dev = device or self._anchor.device
        on_cpu = (torch.device(dev).type == "cpu")
        # The shipped graphs are fp16 on disk (half the size) but exported *unfrozen*,
        # so weights are real parameters. CPU has no fp16 conv kernel, so on CPU we
        # upcast to fp32 (lossless widening). On GPU we keep fp16 only if the user asked
        # for it (torch_dtype=float16 or .half()); otherwise we upcast for exact fp32
        # numerics. A legacy frozen fp32 graph (no parameters) is used as-is.
        want_fp16 = (self._req_dtype == torch.float16) or (self._anchor.dtype == torch.float16)

        def load(fp16_name, fp32_name):
            # Prefer the fp16 artifact when present, else fall back to the fp32 name.
            p16 = os.path.join(d, fp16_name)
            path = p16 if os.path.exists(p16) else os.path.join(d, fp32_name)
            mod = torch.jit.load(path, map_location=dev).eval()
            params = list(mod.parameters())
            pdtype = params[0].dtype if params else torch.float32
            target = torch.float16 if (want_fp16 and not on_cpu and params) else torch.float32
            if params and pdtype != target:
                mod = mod.half() if target == torch.float16 else mod.float()
            return mod, target

        self.encoder, io = load("encoder-asr.fp16.ts", "encoder-asr.ts")
        self.decoder_joint, _ = load("decoder_joint-asr.fp16.ts", "decoder_joint-asr.ts")
        # Inputs and recurrent states are cast to the encoder's dtype.
        self._io_dtype = io
        pp = torch.load(os.path.join(d, "preproc.pt"), map_location="cpu")
        self._p = pp["params"]
        self.register_buffer("_window", pp["window"].to(dev), persistent=False)
        self.register_buffer("_fb", pp["fb"].to(dev), persistent=False)
        self._loaded = True

    @classmethod
    def from_pretrained(cls, path, *args, **kwargs):
        """Build the wrapper for the artifacts stored under *path*.

        Args:
            path: directory holding ``config.json`` and the exported artifacts.
            **kwargs: ``config`` (optional pre-built config) and ``torch_dtype``
                or ``dtype`` (``torch.dtype`` or its string name, or ``"auto"``)
                are honoured; everything else is ignored.

        Returns:
            The model instance; graphs are loaded lazily on first use.
        """
        config = kwargs.pop("config", None) or FastConformerTDTConfig.from_pretrained(path)
        config._name_or_path = path
        model = cls(config)
        model._artifacts_dir = path
        # Accept both the legacy `torch_dtype` and the newer `dtype` keyword.
        dtype = kwargs.get("torch_dtype", None)
        if dtype is None:
            dtype = kwargs.get("dtype", None)
        if isinstance(dtype, str):
            dtype = None if dtype == "auto" else getattr(torch, dtype, None)
        if dtype in (torch.float16, torch.float32):
            model._req_dtype = dtype
            if dtype == torch.float16:
                model = model.half()
        return model

    @torch.no_grad()
    def extract_features(self, wav, wav_len):
        """Turn raw waveforms into normalised log-mel features.

        Args:
            wav: 2-D tensor indexed as ``(batch, samples)``.
            wav_len: 1-D tensor with the number of valid samples per row.

        Returns:
            ``(features, seq_len)``: features with frames at or beyond
            ``seq_len`` set to the configured pad value, and the frame counts.
        """
        self._ensure_loaded()
        p = self._p
        dev = self._anchor.device
        n_fft, hop, win = p["n_fft"], p["hop_length"], p["win_length"]
        wav = wav.to(dev).float()
        wav_len = wav_len.to(dev)
        # The two n_fft terms cancel: the frame count is wav_len // hop.
        seq_len = torch.div(wav_len + n_fft - n_fft, hop, rounding_mode="floor")
        tmask = torch.arange(wav.shape[1], device=dev).unsqueeze(0) < wav_len.unsqueeze(1)
        # Pre-emphasis filter, then zero out samples past each row's length.
        x = torch.cat((wav[:, :1], wav[:, 1:] - p["preemph"] * wav[:, :-1]), dim=1).masked_fill(~tmask, 0.0)
        spec = torch.stft(x, n_fft=n_fft, hop_length=hop, win_length=win, window=self._window.float(),
                          center=True, pad_mode="constant", return_complex=True)
        x = torch.view_as_real(spec)
        x = torch.sqrt(x.pow(2).sum(-1)).pow(p["mag_power"])  # magnitude ** mag_power
        x = torch.matmul(self._fb.float(), x)                 # apply the stored filterbank
        x = torch.log(x + p["log_zero_guard_value"])
        x = _normalize_per_feature(x, seq_len, p["CONSTANT"])
        max_len = x.size(-1)
        m = torch.arange(max_len, device=dev).repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        return x.masked_fill(m.unsqueeze(1), p["pad_value"]), seq_len

    @torch.no_grad()
    def forward(self, input_features, feature_lengths=None, **kwargs):
        """Run the encoder graph.

        Args:
            input_features: feature tensor whose first axis is batch and last
                axis is time.
            feature_lengths: valid frame count per item; defaults to the full
                time dimension for every item.

        Returns:
            ``ModelOutput`` with ``last_hidden_state`` and ``encoder_lengths``.
        """
        self._ensure_loaded()
        if feature_lengths is None:
            feature_lengths = torch.full((input_features.size(0),), input_features.size(-1),
                                         dtype=torch.long, device=input_features.device)
        feats = input_features.to(self._anchor.device).to(self._io_dtype)
        enc, enc_len = self.encoder(feats, feature_lengths.to(self._anchor.device))
        return ModelOutput(last_hidden_state=enc, encoder_lengths=enc_len)

    @torch.no_grad()
    def _greedy_one(self, enc_out, T):
        """Greedy-decode one utterance.

        Args:
            enc_out: encoder output for a single item; time is axis 2.
            T: number of valid encoder frames to visit.

        Returns:
            List of emitted (non-blank) token ids.
        """
        cfg, dev = self.config, self._anchor.device
        nd, blank, durs = cfg.num_durations, cfg.blank_id, cfg.durations
        # Recurrent state tensors fed to and returned by the decoder/joint graph.
        h = torch.zeros(cfg.pred_rnn_layers, 1, cfg.pred_hidden, device=dev, dtype=self._io_dtype)
        c = torch.zeros(cfg.pred_rnn_layers, 1, cfg.pred_hidden, device=dev, dtype=self._io_dtype)
        last, toks = blank, []
        tlen = torch.ones(1, dtype=torch.int32, device=dev)
        t = 0
        while t < T:
            f = enc_out.narrow(2, t, 1)
            added, need, stalled = 0, True, False
            while need and added < cfg.max_symbols:
                tgt = torch.tensor([[last]], dtype=torch.int32, device=dev)
                logits, _, h2, c2 = self.decoder_joint(f, tgt, tlen, h, c)
                logits = logits[0, 0, 0]
                # The last `nd` entries score durations; the rest score tokens.
                k = int(logits[:-nd].argmax().item())
                skip = durs[int(logits[-nd:].argmax().item())]
                if k != blank:
                    toks.append(k)
                    h, c, last = h2, c2, k
                elif skip == 0:
                    # Blank with zero duration leaves frame, state and target
                    # unchanged, so every further call would return the same
                    # prediction until max_symbols is hit. Stop now and take
                    # the same one-frame advance instead of repeating the call.
                    stalled = True
                    break
                added += 1
                t += skip
                need = (skip == 0)
            if stalled or added == cfg.max_symbols:
                t += 1
        return toks

    @torch.no_grad()
    def generate(self, input_features=None, feature_lengths=None, **kwargs):
        """Encode a batch and greedy-decode every item.

        Args:
            input_features: feature tensor as accepted by ``forward`` (required).
            feature_lengths: optional valid frame count per item.

        Returns:
            ``ASRGreedyOutput`` with ``sequences`` (ids padded with
            ``config.blank_id``, at least one column wide) and ``token_lists``.

        Raises:
            ValueError: if ``input_features`` is ``None``.
        """
        if input_features is None:
            raise ValueError("generate() requires `input_features`.")
        out = self.forward(input_features, feature_lengths)
        enc, enc_len = out.last_hidden_state, out.encoder_lengths
        token_lists = [self._greedy_one(enc[i:i + 1], int(enc_len[i].item())) for i in range(enc.size(0))]
        maxlen = max((len(t) for t in token_lists), default=0)
        pad = self.config.blank_id
        seqs = torch.full((len(token_lists), max(maxlen, 1)), pad, dtype=torch.long)
        for i, t in enumerate(token_lists):
            if t:
                seqs[i, :len(t)] = torch.tensor(t, dtype=torch.long)
        return ASRGreedyOutput(sequences=seqs, token_lists=token_lists)
|
preproc.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8119ff424215f42c11eaf4e745c12bbe3076d5a037593ab523c1f15ef32f5b2f
|
| 3 |
+
size 135269
|
processing_fastconformer.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HF processor: raw audio -> mel input_features, and token ids -> text (SentencePiece)."""
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import sentencepiece as spm
|
| 5 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _normalize_per_feature(x, seq_len, constant):
|
| 9 |
+
B, _, max_time = x.shape
|
| 10 |
+
steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(B, max_time)
|
| 11 |
+
valid = steps < seq_len.unsqueeze(1)
|
| 12 |
+
denom = valid.sum(dim=1)
|
| 13 |
+
mean = torch.where(valid.unsqueeze(1), x, torch.zeros_like(x)).sum(dim=2) / denom.unsqueeze(1)
|
| 14 |
+
var = torch.sum(torch.where(valid.unsqueeze(1), x - mean.unsqueeze(2), torch.zeros_like(x)) ** 2,
|
| 15 |
+
dim=2) / (denom.unsqueeze(1) - 1.0)
|
| 16 |
+
std = torch.sqrt(var).masked_fill(torch.sqrt(var).isnan(), 0.0) + constant
|
| 17 |
+
return (x - mean.unsqueeze(2)) / std.unsqueeze(2)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class FastConformerProcessor:
    """Convert raw audio to log-mel ``input_features`` and token ids back to text.

    Attributes:
        sp: SentencePiece processor used by ``batch_decode``.
        window: STFT window tensor.
        fb: filterbank matrix applied to the spectrogram.
        p: dict of preprocessing parameters (``n_fft``, ``hop_length``, ...).
        blank_id: id stripped from padded sequences before decoding.
        sample_rate: rate the audio is resampled to before feature extraction.
    """

    def __init__(self, sp, window, fb, params, blank_id=3000, sample_rate=16000):
        self.sp = sp
        self.window = window
        self.fb = fb
        self.p = params
        self.blank_id = blank_id
        self.sample_rate = sample_rate

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoProcessor"):
        """No-op that always returns ``None``."""
        return None

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Load tokenizer and preprocessing tensors from the directory *path*.

        ``blank_id`` is taken from ``preproc.pt`` when present, otherwise from
        ``config.json`` in the same directory, otherwise 3000. ``sample_rate``
        comes from ``config.json`` when available, otherwise 16000.
        """
        import json

        sp = spm.SentencePieceProcessor(model_file=os.path.join(path, "tokenizer.model"))
        pp = torch.load(os.path.join(path, "preproc.pt"), map_location="cpu")

        # The model config sitting next to the artifacts is the authoritative
        # source for blank_id / sample_rate; use it before hard-coded defaults.
        cfg = {}
        cfg_path = os.path.join(path, "config.json")
        if os.path.isfile(cfg_path):
            with open(cfg_path, encoding="utf-8") as fh:
                cfg = json.load(fh)

        blank = pp.get("blank_id", cfg.get("blank_id", 3000))
        sample_rate = cfg.get("sample_rate", 16000)
        return cls(sp, pp["window"], pp["fb"], pp["params"], blank_id=blank, sample_rate=sample_rate)

    def save_pretrained(self, path, **kwargs):
        """Create *path* if needed; no artifacts are currently written."""
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def _to_2d(audio):
        """Return *audio* as a float32 tensor with at least two dimensions.

        Non-tensor input goes through ``numpy.asarray``; a 1-D signal gains a
        leading axis of size one.
        """
        if isinstance(audio, torch.Tensor):
            a = audio
        else:
            import numpy as np
            a = torch.as_tensor(np.asarray(audio), dtype=torch.float32)
        if a.dim() == 1:
            a = a.unsqueeze(0)
        return a.float()

    @torch.no_grad()
    def __call__(self, audio, sampling_rate=16000, return_tensors="pt"):
        """Compute normalised log-mel features for *audio*.

        Args:
            audio: tensor or array-like, a single signal or a 2-D stack of
                equally long signals.
            sampling_rate: rate of *audio*; resampled when it differs from
                ``self.sample_rate``.
            return_tensors: forwarded to ``BatchFeature`` as ``tensor_type``.

        Returns:
            ``BatchFeature`` with ``input_features`` and ``feature_lengths``.
        """
        wav = self._to_2d(audio)
        if sampling_rate != self.sample_rate:
            import torchaudio
            wav = torchaudio.functional.resample(wav, sampling_rate, self.sample_rate)
        p = self.p
        n_fft, hop, win = p["n_fft"], p["hop_length"], p["win_length"]
        # Every row is treated as fully valid (no per-row padding information).
        wav_len = torch.tensor([wav.shape[1]] * wav.shape[0])
        # The two n_fft terms cancel: the frame count is wav_len // hop.
        seq_len = torch.div(wav_len + n_fft - n_fft, hop, rounding_mode="floor")
        tmask = torch.arange(wav.shape[1]).unsqueeze(0) < wav_len.unsqueeze(1)
        # Pre-emphasis filter, then zero out samples past each row's length.
        x = torch.cat((wav[:, :1], wav[:, 1:] - p["preemph"] * wav[:, :-1]), dim=1).masked_fill(~tmask, 0.0)
        # Cast the window like the filterbank below so a non-fp32 stored window
        # cannot cause a dtype mismatch inside torch.stft.
        spec = torch.stft(x, n_fft=n_fft, hop_length=hop, win_length=win, window=self.window.to(x.dtype),
                          center=True, pad_mode="constant", return_complex=True)
        x = torch.view_as_real(spec)
        x = torch.sqrt(x.pow(2).sum(-1)).pow(p["mag_power"])  # magnitude ** mag_power
        x = torch.matmul(self.fb.to(x.dtype), x)
        x = torch.log(x + p["log_zero_guard_value"])
        x = _normalize_per_feature(x, seq_len, p["CONSTANT"])
        max_len = x.size(-1)
        m = torch.arange(max_len).repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        x = x.masked_fill(m.unsqueeze(1), p["pad_value"])
        return BatchFeature({"input_features": x, "feature_lengths": seq_len}, tensor_type=return_tensors)

    def _clean(self, ids):
        """Return *ids* as ints with every ``blank_id`` entry removed."""
        return [int(i) for i in ids if int(i) != self.blank_id]

    def batch_decode(self, sequences, **kwargs):
        """Decode a batch of id sequences to strings.

        *sequences* may be an object exposing ``token_lists`` (used as-is), an
        object exposing ``sequences``, a 2-D tensor, or an iterable of id rows;
        in the latter cases blank ids are stripped first.
        """
        token_lists = getattr(sequences, "token_lists", None)
        if token_lists is None:
            seqs = getattr(sequences, "sequences", sequences)
            if isinstance(seqs, torch.Tensor):
                token_lists = [self._clean(row.tolist()) for row in seqs]
            else:
                token_lists = [self._clean(row) for row in seqs]
        return [self.sp.decode(t) for t in token_lists]

    def decode(self, sequence, **kwargs):
        """Decode a single id sequence via ``batch_decode``."""
        return self.batch_decode([sequence], **kwargs)[0]
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0af4086ec53482ac0cef0375369cf4ce7bafaf8b0a7203e97d126d0599ab90a6
|
| 3 |
+
size 325287
|