SujithPulikodan committed on
Commit
ffb2413
·
verified ·
1 Parent(s): 10740dc

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ decoder_joint-asr.fp16.ts filter=lfs diff=lfs merge=lfs -text
37
+ encoder-asr.fp16.ts filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "fastconformer_tdt",
3
+ "architectures": [
4
+ "FastConformerTDTForCTC"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_fastconformer.FastConformerTDTConfig",
8
+ "AutoModel": "modeling_fastconformer.FastConformerTDTForCTC",
9
+ "AutoModelForCTC": "modeling_fastconformer.FastConformerTDTForCTC",
10
+ "AutoProcessor": "processing_fastconformer.FastConformerProcessor"
11
+ },
12
+ "vocab_size": 5000,
13
+ "blank_id": 5000,
14
+ "durations": [
15
+ 0,
16
+ 1,
17
+ 2,
18
+ 3,
19
+ 4
20
+ ],
21
+ "num_durations": 5,
22
+ "pred_hidden": 640,
23
+ "pred_rnn_layers": 1,
24
+ "max_symbols": 10,
25
+ "enc_d_model": 1024,
26
+ "feat_in": 128,
27
+ "sample_rate": 16000
28
+ }
configuration_fastconformer.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class FastConformerTDTConfig(PretrainedConfig):
    """Config for the FastConformer TDT (RNNT) ASR model wrapped for HF transformers.

    Args:
        vocab_size: Size of the token vocabulary (excluding the blank symbol,
            presumably -- the shipped config uses ``vocab_size == blank_id``).
        blank_id: Index of the blank symbol in the token part of the joint
            output.
        durations: Frame-skip values a TDT duration prediction can select.
        num_durations: Number of duration logits at the tail of the joint
            output. ``None`` derives it from ``durations``. It must equal
            ``len(durations)`` because decoding slices the logits by this
            count and then indexes ``durations`` with the arg-max.
        pred_hidden: Hidden size of the prediction-network recurrent state.
        pred_rnn_layers: Number of layers of the prediction-network state.
        max_symbols: Maximum decoder steps taken on one encoder frame.
        enc_d_model: Encoder model width (not read by the visible wrapper
            code; assumed to describe the exported encoder -- TODO confirm).
        feat_in: Number of input feature channels (assumed mel bins -- TODO
            confirm).
        sample_rate: Audio sampling rate the model expects.
        **kwargs: Forwarded to ``PretrainedConfig``.

    Raises:
        ValueError: If ``num_durations`` is given and disagrees with
            ``len(durations)``.
    """

    model_type = "fastconformer_tdt"

    def __init__(
        self,
        vocab_size=3000,
        blank_id=3000,
        durations=(0, 1, 2, 3, 4),
        num_durations=None,
        pred_hidden=640,
        pred_rnn_layers=1,
        max_symbols=10,
        enc_d_model=1024,
        feat_in=128,
        sample_rate=16000,
        **kwargs,
    ):
        durations = list(durations)
        if num_durations is None:
            # Keep the two fields consistent by construction.
            num_durations = len(durations)
        elif num_durations != len(durations):
            raise ValueError(
                f"num_durations ({num_durations}) must equal len(durations) "
                f"({len(durations)})"
            )
        self.vocab_size = vocab_size
        self.blank_id = blank_id
        self.durations = durations
        self.num_durations = num_durations
        self.pred_hidden = pred_hidden
        self.pred_rnn_layers = pred_rnn_layers
        self.max_symbols = max_symbols
        self.enc_d_model = enc_d_model
        self.feat_in = feat_in
        self.sample_rate = sample_rate
        super().__init__(**kwargs)
decoder_joint-asr.fp16.ts ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6396aa30a8c252a7cf981ce3f4c63b1dc8783c83da7e9390acb04e4e1ec1a5e0
3
+ size 21539480
encoder-asr.fp16.ts ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b2c3789d91b06c0f19cd4b55156688fd291e325990361f7024b7f225444fb0d
3
+ size 887312617
example_inference.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Transcribe the audio files given on the command line.

Usage: python example_inference.py AUDIO [AUDIO ...]

Prints one ``<path>\t<transcript>`` line per input file.
"""
import sys

import torch
import torchaudio
from transformers import AutoModel, AutoProcessor

# Directory holding config.json, the TorchScript graphs, preproc.pt and tokenizer.model.
REPO = "."
DEV = "cuda" if torch.cuda.is_available() else "cpu"


def main(paths):
    """Load the model once and print a transcript for every path in *paths*.

    Returns a process exit code: 0 on success, 2 when no paths were given.
    """
    if not paths:
        print(f"usage: {sys.argv[0]} AUDIO [AUDIO ...]", file=sys.stderr)
        return 2

    model = AutoModel.from_pretrained(REPO, trust_remote_code=True).to(DEV).eval()
    proc = AutoProcessor.from_pretrained(REPO, trust_remote_code=True)

    for path in paths:
        wav, sr = torchaudio.load(path)
        # Down-mix multi-channel audio to mono; otherwise just drop the channel axis.
        wav = wav.mean(0) if wav.shape[0] > 1 else wav.squeeze(0)
        inputs = proc(wav, sampling_rate=sr, return_tensors="pt").to(DEV)
        print(f"{path}\t{proc.batch_decode(model.generate(**inputs))[0]}")
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
model_meta.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 5000,
3
+ "blank_id": 5000,
4
+ "durations": [
5
+ 0,
6
+ 1,
7
+ 2,
8
+ 3,
9
+ 4
10
+ ],
11
+ "num_durations": 5,
12
+ "subsampling_factor": 8,
13
+ "pred_hidden": 640,
14
+ "pred_rnn_layers": 1,
15
+ "enc_d_model": 1024,
16
+ "max_symbols": 10,
17
+ "feat_in": 128
18
+ }
modeling_fastconformer.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF `transformers`-compatible inference wrapper for the FastConformer TDT (RNNT)
3
+ ASR model. Wraps the exported TorchScript encoder + decoder_joint graphs.
4
+
5
+ Runtime deps: torch, transformers, sentencepiece. **No nemo_toolkit.**
6
+ """
7
+ import os
8
+ import torch
9
+ from transformers import PreTrainedModel
10
+ from transformers.modeling_outputs import ModelOutput
11
+ from dataclasses import dataclass
12
+ from typing import Optional, List
13
+
14
+ from .configuration_fastconformer import FastConformerTDTConfig
15
+
16
+
17
+ def _normalize_per_feature(x, seq_len, constant):
18
+ B, _, max_time = x.shape
19
+ steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(B, max_time)
20
+ valid = steps < seq_len.unsqueeze(1)
21
+ denom = valid.sum(dim=1)
22
+ mean = torch.where(valid.unsqueeze(1), x, torch.zeros_like(x)).sum(dim=2) / denom.unsqueeze(1)
23
+ var = torch.sum(torch.where(valid.unsqueeze(1), x - mean.unsqueeze(2), torch.zeros_like(x)) ** 2,
24
+ dim=2) / (denom.unsqueeze(1) - 1.0)
25
+ std = torch.sqrt(var).masked_fill(torch.sqrt(var).isnan(), 0.0) + constant
26
+ return (x - mean.unsqueeze(2)) / std.unsqueeze(2)
27
+
28
+
29
@dataclass
class ASRGreedyOutput(ModelOutput):
    """Container for the result of greedy decoding.

    Both fields default to ``None`` and describe the same decoded tokens in
    two layouts: a padded tensor and ragged Python lists.
    """

    # Decoded token ids as one padded long tensor, one row per batch item.
    sequences: Optional[torch.LongTensor] = None
    # Decoded token ids without padding, one list per batch item.
    token_lists: Optional[List[List[int]]] = None
33
+
34
+
35
class FastConformerTDTForCTC(PreTrainedModel):
    """Named *ForCTC for AutoModel discoverability, but decoding is TDT/RNNT greedy.

    The wrapper owns no trainable weights of its own. The encoder and the
    fused decoder+joint network are TorchScript graphs that are loaded lazily
    from the artifacts directory on first use, together with ``preproc.pt``
    (STFT window, mel filterbank and feature-extraction parameters).

    The artifacts are opened with plain ``os.path.join(dir, name)``, so the
    directory passed to ``from_pretrained`` must exist on the local file
    system.
    """

    config_class = FastConformerTDTConfig
    main_input_name = "input_features"

    def __init__(self, config: FastConformerTDTConfig):
        super().__init__(config)
        # Dummy parameter whose device/dtype follows ``.to()`` / ``.half()``;
        # it is how the wrapper learns where the user wants to run.
        self._anchor = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
        self._artifacts_dir = getattr(config, "_name_or_path", ".") or "."
        self.encoder = None
        self.decoder_joint = None
        self._loaded = False
        self._io_dtype = torch.float32
        self._req_dtype = None  # user-requested compute dtype (torch_dtype / .half())
        self.post_init()

    def _ensure_loaded(self, device=None):
        """Load the TorchScript graphs and preprocessing tensors once.

        Args:
            device: Device to load onto. Defaults to the device of the
                anchor parameter.
        """
        if self._loaded:
            return
        d = self._artifacts_dir
        # ``is None`` rather than ``or`` so that a falsy device index (0) is honoured.
        dev = device if device is not None else self._anchor.device
        on_cpu = torch.device(dev).type == "cpu"
        # The shipped graphs are fp16 on disk (half the size) but exported *unfrozen*,
        # so weights are real parameters. CPU has no fp16 conv kernel, so on CPU we
        # upcast to fp32 (lossless widening). On GPU we keep fp16 only if the user asked
        # for it (torch_dtype=float16 or .half()); otherwise we upcast for exact fp32
        # numerics. A legacy frozen fp32 graph (no parameters) is used as-is.
        want_fp16 = (self._req_dtype == torch.float16) or (self._anchor.dtype == torch.float16)

        def load(fp16_name, fp32_name):
            """Load one graph, preferring the fp16 file, and cast it to the target dtype."""
            p16 = os.path.join(d, fp16_name)
            path = p16 if os.path.exists(p16) else os.path.join(d, fp32_name)
            mod = torch.jit.load(path, map_location=dev).eval()
            params = list(mod.parameters())
            pdtype = params[0].dtype if params else torch.float32
            target = torch.float16 if (want_fp16 and not on_cpu and params) else torch.float32
            if params and pdtype != target:
                mod = mod.half() if target == torch.float16 else mod.float()
            return mod, target

        self.encoder, io = load("encoder-asr.fp16.ts", "encoder-asr.ts")
        self.decoder_joint, _ = load("decoder_joint-asr.fp16.ts", "decoder_joint-asr.ts")
        self._io_dtype = io
        # NOTE(security): torch.load unpickles the file; only point this at a
        # trusted artifacts directory.
        pp = torch.load(os.path.join(d, "preproc.pt"), map_location="cpu")
        self._p = pp["params"]
        self.register_buffer("_window", pp["window"].to(dev), persistent=False)
        self.register_buffer("_fb", pp["fb"].to(dev), persistent=False)
        self._loaded = True

    @classmethod
    def from_pretrained(cls, path, *args, **kwargs):
        """Build the wrapper for the artifacts stored in the local directory *path*.

        Recognised keyword arguments:
            config: A ready ``FastConformerTDTConfig``; otherwise it is read
                from *path*.
            torch_dtype / dtype: ``torch.float16`` or ``torch.float32`` (or
                their string names). ``"auto"`` and any other value are
                ignored. ``dtype`` is accepted as an alias of
                ``torch_dtype``.

        No graph is loaded here; that happens lazily on first use.
        """
        # Normalise path-like objects so the config stays JSON-serialisable.
        path = os.fspath(path)
        config = kwargs.pop("config", None) or FastConformerTDTConfig.from_pretrained(path)
        config._name_or_path = path
        model = cls(config)
        model._artifacts_dir = path
        dtype = kwargs.get("torch_dtype")
        if dtype is None:
            dtype = kwargs.get("dtype")
        if isinstance(dtype, str):
            dtype = None if dtype == "auto" else getattr(torch, dtype, None)
        if dtype in (torch.float16, torch.float32):
            model._req_dtype = dtype
            if dtype == torch.float16:
                model = model.half()
        return model

    @torch.no_grad()
    def extract_features(self, wav, wav_len):
        """Turn raw waveforms into normalised log-mel features.

        Args:
            wav: Float tensor of shape ``(batch, samples)``.
            wav_len: 1-D tensor with the number of valid samples per item.

        Returns:
            ``(features, seq_len)`` where ``features`` is
            ``(batch, mel, frames)`` with frames at or past ``seq_len``
            set to the configured pad value, and ``seq_len`` is
            ``floor(wav_len / hop_length)``.
        """
        self._ensure_loaded()
        p = self._p
        dev = self._anchor.device
        n_fft, hop, win = p["n_fft"], p["hop_length"], p["win_length"]
        wav = wav.to(dev).float()
        wav_len = wav_len.to(dev)
        seq_len = torch.div(wav_len, hop, rounding_mode="floor")
        sample_ok = torch.arange(wav.shape[1], device=dev).unsqueeze(0) < wav_len.unsqueeze(1)
        # Pre-emphasis, then silence everything past each item's true length.
        x = torch.cat((wav[:, :1], wav[:, 1:] - p["preemph"] * wav[:, :-1]), dim=1)
        x = x.masked_fill(~sample_ok, 0.0)
        spec = torch.stft(x, n_fft=n_fft, hop_length=hop, win_length=win,
                          window=self._window.float(), center=True,
                          pad_mode="constant", return_complex=True)
        x = torch.view_as_real(spec)
        x = torch.sqrt(x.pow(2).sum(-1)).pow(p["mag_power"])
        x = torch.matmul(self._fb.float(), x)
        x = torch.log(x + p["log_zero_guard_value"])
        x = _normalize_per_feature(x, seq_len, p["CONSTANT"])
        frame_pad = torch.arange(x.size(-1), device=dev).repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        return x.masked_fill(frame_pad.unsqueeze(1), p["pad_value"]), seq_len

    @torch.no_grad()
    def forward(self, input_features, feature_lengths=None, **kwargs):
        """Run the encoder.

        Args:
            input_features: Feature tensor; its last axis is time.
            feature_lengths: Valid frames per item. Defaults to the full
                time axis for every item.

        Returns:
            ``ModelOutput`` with ``last_hidden_state`` (encoder output) and
            ``encoder_lengths`` (valid encoder frames per item).
        """
        self._ensure_loaded()
        if feature_lengths is None:
            feature_lengths = torch.full((input_features.size(0),), input_features.size(-1),
                                         dtype=torch.long, device=input_features.device)
        feats = input_features.to(self._anchor.device).to(self._io_dtype)
        enc, enc_len = self.encoder(feats, feature_lengths.to(self._anchor.device))
        return ModelOutput(last_hidden_state=enc, encoder_lengths=enc_len)

    @torch.no_grad()
    def _greedy_one(self, enc_out, T):
        """Greedy TDT decoding of a single utterance.

        Args:
            enc_out: Encoder output for one item, time on dim 2.
            T: Number of valid encoder frames.

        Returns:
            List of emitted (non-blank) token ids.
        """
        cfg, dev = self.config, self._anchor.device
        nd, blank, durs = cfg.num_durations, cfg.blank_id, cfg.durations
        max_symbols = cfg.max_symbols
        state_shape = (cfg.pred_rnn_layers, 1, cfg.pred_hidden)
        h = torch.zeros(state_shape, device=dev, dtype=self._io_dtype)
        c = torch.zeros(state_shape, device=dev, dtype=self._io_dtype)
        last, toks = blank, []
        tlen = torch.ones(1, dtype=torch.int32, device=dev)
        # Never index past the frames that actually exist.
        T = min(T, enc_out.size(2))
        t = 0
        while t < T:
            f = enc_out.narrow(2, t, 1)
            added, need = 0, True
            while need and added < max_symbols:
                tgt = torch.tensor([[last]], dtype=torch.int32, device=dev)
                logits, _, h2, c2 = self.decoder_joint(f, tgt, tlen, h, c)
                logits = logits[0, 0, 0]
                # Joint output = token logits followed by ``nd`` duration logits.
                k = int(logits[:-nd].argmax().item())
                skip = durs[int(logits[-nd:].argmax().item())]
                if k != blank:
                    toks.append(k)
                    h, c, last = h2, c2, k
                elif skip == 0:
                    # Blank with zero duration leaves frame, label and state
                    # untouched, so another step would only repeat this exact
                    # prediction until ``max_symbols`` is exhausted. Move on
                    # by one frame right away instead.
                    t += 1
                    break
                added += 1
                t += skip
                need = skip == 0
            if added == max_symbols:
                # Symbol budget spent without leaving the frame: force progress.
                t += 1
        return toks

    @torch.no_grad()
    def generate(self, input_features=None, feature_lengths=None, **kwargs):
        """Encode and greedily decode a batch.

        Args:
            input_features: Feature tensor, as accepted by ``forward``.
            feature_lengths: Valid frames per item (optional).

        Returns:
            ``ASRGreedyOutput`` with ``sequences`` (CPU long tensor, padded
            with ``config.blank_id``, at least one column wide) and
            ``token_lists`` (unpadded ids per item).

        Raises:
            ValueError: If ``input_features`` is not provided.
        """
        if input_features is None:
            raise ValueError("`input_features` is required for generate().")
        out = self.forward(input_features, feature_lengths)
        enc, enc_len = out.last_hidden_state, out.encoder_lengths
        token_lists = [self._greedy_one(enc[i:i + 1], int(enc_len[i].item()))
                       for i in range(enc.size(0))]
        maxlen = max((len(t) for t in token_lists), default=0)
        pad = self.config.blank_id
        seqs = torch.full((len(token_lists), max(maxlen, 1)), pad, dtype=torch.long)
        for i, t in enumerate(token_lists):
            if t:
                seqs[i, :len(t)] = torch.tensor(t, dtype=torch.long)
        return ASRGreedyOutput(sequences=seqs, token_lists=token_lists)
preproc.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8119ff424215f42c11eaf4e745c12bbe3076d5a037593ab523c1f15ef32f5b2f
3
+ size 135269
processing_fastconformer.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HF processor: raw audio -> mel input_features, and token ids -> text (SentencePiece)."""
2
+ import os
3
+ import torch
4
+ import sentencepiece as spm
5
+ from transformers.feature_extraction_utils import BatchFeature
6
+
7
+
8
+ def _normalize_per_feature(x, seq_len, constant):
9
+ B, _, max_time = x.shape
10
+ steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(B, max_time)
11
+ valid = steps < seq_len.unsqueeze(1)
12
+ denom = valid.sum(dim=1)
13
+ mean = torch.where(valid.unsqueeze(1), x, torch.zeros_like(x)).sum(dim=2) / denom.unsqueeze(1)
14
+ var = torch.sum(torch.where(valid.unsqueeze(1), x - mean.unsqueeze(2), torch.zeros_like(x)) ** 2,
15
+ dim=2) / (denom.unsqueeze(1) - 1.0)
16
+ std = torch.sqrt(var).masked_fill(torch.sqrt(var).isnan(), 0.0) + constant
17
+ return (x - mean.unsqueeze(2)) / std.unsqueeze(2)
18
+
19
+
20
class FastConformerProcessor:
    """Turn raw audio into log-mel ``input_features`` and token ids into text.

    Args:
        sp: SentencePiece processor used for decoding ids to text.
        window: STFT window tensor.
        fb: Mel filterbank tensor, multiplied onto the magnitude spectrogram.
        params: Dict of feature-extraction parameters (``n_fft``,
            ``hop_length``, ``win_length``, ``preemph``, ``mag_power``,
            ``log_zero_guard_value``, ``CONSTANT``, ``pad_value``).
        blank_id: Token id removed from sequences before decoding.
        sample_rate: Sampling rate the features are computed at; other rates
            are resampled to it.
    """

    def __init__(self, sp, window, fb, params, blank_id=3000, sample_rate=16000):
        self.sp = sp
        self.window = window
        self.fb = fb
        self.p = params
        self.blank_id = blank_id
        self.sample_rate = sample_rate

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoProcessor"):
        """No-op hook; present so the class can be used where this method is expected."""
        return None

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Load the processor from the local directory *path*.

        Reads ``tokenizer.model`` and ``preproc.pt``. ``blank_id`` and
        ``sample_rate`` are taken from ``preproc.pt`` when present, otherwise
        from ``config.json`` in the same directory, otherwise the class
        defaults (3000 / 16000) apply.
        """
        import json

        sp = spm.SentencePieceProcessor(model_file=os.path.join(path, "tokenizer.model"))
        # NOTE(security): torch.load unpickles the file; only load from a
        # trusted directory.
        pp = torch.load(os.path.join(path, "preproc.pt"), map_location="cpu")

        # Fall back to the model config so processor and model agree on the
        # blank id even when preproc.pt does not carry it.
        cfg = {}
        cfg_path = os.path.join(path, "config.json")
        if os.path.isfile(cfg_path):
            with open(cfg_path, encoding="utf-8") as fh:
                loaded = json.load(fh)
            if isinstance(loaded, dict):
                cfg = loaded

        blank = pp.get("blank_id", cfg.get("blank_id", 3000))
        sample_rate = pp.get("sample_rate", cfg.get("sample_rate", 16000))
        return cls(sp, pp["window"], pp["fb"], pp["params"],
                   blank_id=blank, sample_rate=sample_rate)

    def save_pretrained(self, path, **kwargs):
        """Write ``tokenizer.model`` and ``preproc.pt`` into *path*.

        The directory is created if needed. The written files are sufficient
        for ``from_pretrained`` to rebuild an equivalent processor.
        """
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "tokenizer.model"), "wb") as fh:
            fh.write(self.sp.serialized_model_proto())
        torch.save(
            {
                "window": self.window,
                "fb": self.fb,
                "params": self.p,
                "blank_id": self.blank_id,
                "sample_rate": self.sample_rate,
            },
            os.path.join(path, "preproc.pt"),
        )

    @staticmethod
    def _to_2d(audio):
        """Return *audio* as a float tensor; 1-D input gets a leading batch axis."""
        if isinstance(audio, torch.Tensor):
            a = audio
        else:
            import numpy as np
            a = torch.as_tensor(np.asarray(audio), dtype=torch.float32)
        if a.dim() == 1:
            a = a.unsqueeze(0)
        return a.float()

    @torch.no_grad()
    def __call__(self, audio, sampling_rate=16000, return_tensors="pt"):
        """Compute normalised log-mel features for *audio*.

        Args:
            audio: 1-D waveform or 2-D ``(batch, samples)`` array/tensor.
                Every row is treated as fully valid (no per-item padding).
            sampling_rate: Rate of *audio*; resampled if it differs from
                ``self.sample_rate``.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.

        Returns:
            ``BatchFeature`` with ``input_features`` ``(batch, mel, frames)``
            and ``feature_lengths`` (``floor(samples / hop_length)``).

        Raises:
            ValueError: If the audio contains no samples.
        """
        wav = self._to_2d(audio)
        if wav.shape[-1] == 0:
            raise ValueError("audio is empty")
        if sampling_rate != self.sample_rate:
            import torchaudio
            wav = torchaudio.functional.resample(wav, sampling_rate, self.sample_rate)
        p = self.p
        dev = wav.device  # build helper tensors next to the audio
        n_fft, hop, win = p["n_fft"], p["hop_length"], p["win_length"]
        wav_len = torch.tensor([wav.shape[1]] * wav.shape[0], device=dev)
        seq_len = torch.div(wav_len, hop, rounding_mode="floor")
        # Pre-emphasis. All rows are full length, so no sample masking is needed.
        x = torch.cat((wav[:, :1], wav[:, 1:] - p["preemph"] * wav[:, :-1]), dim=1)
        spec = torch.stft(x, n_fft=n_fft, hop_length=hop, win_length=win,
                          window=self.window.to(device=dev, dtype=x.dtype),
                          center=True, pad_mode="constant", return_complex=True)
        x = torch.view_as_real(spec)
        x = torch.sqrt(x.pow(2).sum(-1)).pow(p["mag_power"])
        x = torch.matmul(self.fb.to(device=dev, dtype=x.dtype), x)
        x = torch.log(x + p["log_zero_guard_value"])
        x = _normalize_per_feature(x, seq_len, p["CONSTANT"])
        # Overwrite frames at or past ``seq_len`` with the pad value.
        frame_pad = torch.arange(x.size(-1), device=dev).repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        x = x.masked_fill(frame_pad.unsqueeze(1), p["pad_value"])
        return BatchFeature({"input_features": x, "feature_lengths": seq_len},
                            tensor_type=return_tensors)

    def _clean(self, ids):
        """Return *ids* as ints with every occurrence of the blank id removed."""
        return [int(i) for i in ids if int(i) != self.blank_id]

    def batch_decode(self, sequences, **kwargs):
        """Decode a batch of token id sequences to text.

        Accepts an object exposing ``token_lists`` (used as-is), an object
        exposing ``sequences``, a 2-D tensor, or a list of id lists. In the
        last three cases the blank id is stripped before decoding.
        """
        token_lists = getattr(sequences, "token_lists", None)
        if token_lists is None:
            seqs = getattr(sequences, "sequences", sequences)
            if isinstance(seqs, torch.Tensor):
                token_lists = [self._clean(row.tolist()) for row in seqs]
            else:
                token_lists = [self._clean(row) for row in seqs]
        return [self.sp.decode(t) for t in token_lists]

    def decode(self, sequence, **kwargs):
        """Decode a single token id sequence to text."""
        return self.batch_decode([sequence], **kwargs)[0]
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0af4086ec53482ac0cef0375369cf4ce7bafaf8b0a7203e97d126d0599ab90a6
3
+ size 325287