"""HF processor: raw audio -> mel input_features, and token ids -> text (SentencePiece)."""

import os

import numpy as np
import sentencepiece as spm
import torch
from transformers.feature_extraction_utils import BatchFeature


def _normalize_per_feature(x, seq_len, constant):
    """Normalize every feature channel to zero mean / unit std over its valid frames.

    Args:
        x: features indexed as (batch, channel, time); only the first
            ``seq_len[b]`` time steps of item ``b`` contribute to the statistics.
        seq_len: (batch,) integer count of valid time steps per item.
        constant: value added to the std so the division never hits zero.

    Returns:
        Tensor shaped like ``x``. Frames past ``seq_len`` are normalized with
        the same statistics; callers are expected to overwrite them.
    """
    batch, _, max_time = x.shape
    steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch, max_time)
    valid = (steps < seq_len.to(x.device).unsqueeze(1)).unsqueeze(1)  # (B, 1, T)
    count = valid.sum(dim=2)  # (B, 1) valid frames per item
    zeros = torch.zeros_like(x)
    mean = torch.where(valid, x, zeros).sum(dim=2) / count
    centered = torch.where(valid, x - mean.unsqueeze(2), zeros)
    # Divide by (N - 1): unbiased variance, matching torch.std's default.
    var = centered.pow(2).sum(dim=2) / (count - 1.0)
    std = torch.sqrt(var)
    # A single valid frame yields 0/0 = NaN; treat its spread as zero.
    std = std.masked_fill(std.isnan(), 0.0) + constant
    return (x - mean.unsqueeze(2)) / std.unsqueeze(2)


class FastConformerProcessor:
    """Feature extraction (audio -> normalized log-mel) and SentencePiece decoding.

    Args:
        sp: a loaded ``sentencepiece.SentencePieceProcessor``.
        window: STFT analysis window handed to ``torch.stft``.
        fb: mel filterbank, matrix-multiplied onto the (batch, freq, time)
            magnitude spectrogram.
        params: dict of preprocessing settings; this class reads ``n_fft``,
            ``hop_length``, ``win_length``, ``preemph``, ``mag_power``,
            ``log_zero_guard_value``, ``CONSTANT`` and ``pad_value``.
        blank_id: token id dropped from raw id sequences before decoding.
        sample_rate: rate features are computed at; other input rates are
            resampled to it.
    """

    def __init__(self, sp, window, fb, params, blank_id=3000, sample_rate=16000):
        self.sp = sp
        self.window = window
        self.fb = fb
        self.p = params
        self.blank_id = blank_id
        self.sample_rate = sample_rate

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoProcessor"):
        """Accept the Hugging Face auto-class registration call and do nothing."""
        return None

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Load ``tokenizer.model`` and ``preproc.pt`` from directory ``path``.

        ``kwargs`` is accepted for Hugging Face API compatibility and ignored.
        """
        sp = spm.SentencePieceProcessor(model_file=os.path.join(path, "tokenizer.model"))
        # NOTE(security): torch.load may unpickle arbitrary objects (weights_only
        # defaults to False before torch 2.6) -- only load trusted checkpoints.
        pp = torch.load(os.path.join(path, "preproc.pt"), map_location="cpu")
        return cls(
            sp,
            pp["window"],
            pp["fb"],
            pp["params"],
            blank_id=pp.get("blank_id", 3000),
            # Older checkpoints have no sample_rate entry; keep the old default.
            sample_rate=pp.get("sample_rate", 16000),
        )

    def save_pretrained(self, path, **kwargs):
        """Write the two files ``from_pretrained`` reads back into directory ``path``."""
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "tokenizer.model"), "wb") as f:
            f.write(self.sp.serialized_model_proto())
        torch.save(
            {
                "window": self.window,
                "fb": self.fb,
                "params": self.p,
                "blank_id": self.blank_id,
                "sample_rate": self.sample_rate,
            },
            os.path.join(path, "preproc.pt"),
        )

    @staticmethod
    def _to_2d(audio):
        """Return ``audio`` as a float32 tensor; a 1-D input gains a leading batch axis."""
        if isinstance(audio, torch.Tensor):
            a = audio
        else:
            a = torch.as_tensor(np.asarray(audio), dtype=torch.float32)
        if a.dim() == 1:
            a = a.unsqueeze(0)
        return a.float()

    @staticmethod
    def _is_clip_list(audio):
        """True when ``audio`` is a non-empty list/tuple whose items are themselves clips."""
        if not isinstance(audio, (list, tuple)) or len(audio) == 0:
            return False
        first = audio[0]
        ndim = first.dim() if isinstance(first, torch.Tensor) else np.ndim(first)
        return ndim >= 1

    def _resample(self, wav, sampling_rate):
        """Resample along the last axis to ``self.sample_rate`` (no-op when rates match)."""
        if sampling_rate == self.sample_rate:
            return wav
        import torchaudio  # only needed when resampling

        return torchaudio.functional.resample(wav, sampling_rate, self.sample_rate)

    def _collate(self, audio, sampling_rate):
        """Resample ``audio`` and stack it into ``(wav, wav_len)``.

        ``wav`` is (batch, samples), zero-padded to the longest clip; ``wav_len``
        is the (batch,) true sample count of each clip. A list/tuple of clips may
        have differing lengths; any other input must already be a rectangular
        (samples,) or (batch, samples) batch.

        Raises:
            ValueError: if an item of a clip list is not a single mono clip.
        """
        if self._is_clip_list(audio):
            clips = []
            for clip in audio:
                c = self._to_2d(clip)
                if c.dim() != 2 or c.shape[0] != 1:
                    raise ValueError(f"each clip must be 1-D (mono), got shape {tuple(c.shape)}")
                clips.append(self._resample(c[0], sampling_rate))
            wav_len = torch.tensor([c.shape[0] for c in clips], dtype=torch.long)
            wav = torch.nn.utils.rnn.pad_sequence(clips, batch_first=True)
        else:
            wav = self._resample(self._to_2d(audio), sampling_rate)
            wav_len = torch.full((wav.shape[0],), wav.shape[1], dtype=torch.long)
        return wav, wav_len

    @torch.no_grad()
    def __call__(self, audio, sampling_rate=16000, return_tensors="pt"):
        """Compute normalized log-mel features.

        Args:
            audio: one 1-D clip, a rectangular (batch, samples) batch (tensor,
                array or nested list), or a list of 1-D clips of any lengths.
            sampling_rate: sample rate of ``audio``; resampled when it differs
                from ``self.sample_rate``.
            return_tensors: tensor type passed to ``BatchFeature``.

        Returns:
            ``BatchFeature`` with ``input_features`` (batch, mel, frames) and
            ``feature_lengths`` (batch,) valid-frame counts. Frames at or past
            an item's length hold ``params["pad_value"]``.
        """
        p = self.p
        n_fft, hop, win = p["n_fft"], p["hop_length"], p["win_length"]
        wav, wav_len = self._collate(audio, sampling_rate)
        device = wav.device
        wav_len = wav_len.to(device)

        # torch.stft(center=True) pads n_fft // 2 on each side. This count is one
        # less than the number of frames stft emits, so the final frame is always
        # treated as padding; kept to preserve the existing feature lengths.
        pad_amount = 2 * (n_fft // 2)
        seq_len = torch.div(wav_len + pad_amount - n_fft, hop, rounding_mode="floor")

        # Pre-emphasis, then zero every sample past each clip's true length so
        # batch padding contributes nothing to the spectrogram.
        sample_mask = torch.arange(wav.shape[1], device=device).unsqueeze(0) < wav_len.unsqueeze(1)
        preemph = p["preemph"]
        x = wav
        if preemph is not None:
            x = torch.cat((wav[:, :1], wav[:, 1:] - preemph * wav[:, :-1]), dim=1)
        x = x.masked_fill(~sample_mask, 0.0)

        window = None if self.window is None else self.window.to(device=device, dtype=x.dtype)
        spec = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop,
            win_length=win,
            window=window,
            center=True,
            pad_mode="constant",
            return_complex=True,
        )
        # |STFT| ** mag_power -> mel projection -> guarded log -> per-channel norm.
        x = torch.view_as_real(spec)
        x = torch.sqrt(x.pow(2).sum(-1)).pow(p["mag_power"])
        x = torch.matmul(self.fb.to(device=device, dtype=x.dtype), x)
        x = torch.log(x + p["log_zero_guard_value"])
        x = _normalize_per_feature(x, seq_len, p["CONSTANT"])

        frame_pad = torch.arange(x.size(-1), device=device).unsqueeze(0) >= seq_len.unsqueeze(1)
        x = x.masked_fill(frame_pad.unsqueeze(1), p["pad_value"])
        return BatchFeature(
            {"input_features": x, "feature_lengths": seq_len}, tensor_type=return_tensors
        )

    def _clean(self, ids):
        """Return ``ids`` as Python ints with every ``blank_id`` removed."""
        return [int(i) for i in ids if int(i) != self.blank_id]

    def batch_decode(self, sequences, **kwargs):
        """Decode a batch of id sequences to strings.

        ``sequences`` may expose ``token_lists`` (used as-is, without blank
        removal) or ``sequences``; otherwise it is itself a 2-D tensor or an
        iterable of id rows, which are blank-stripped before decoding.
        """
        token_lists = getattr(sequences, "token_lists", None)
        if token_lists is None:
            seqs = getattr(sequences, "sequences", sequences)
            if isinstance(seqs, torch.Tensor):
                token_lists = [self._clean(row.tolist()) for row in seqs]
            else:
                token_lists = [self._clean(row) for row in seqs]
        return [self.sp.decode(t) for t in token_lists]

    def decode(self, sequence, **kwargs):
        """Decode a single id sequence to a string."""
        return self.batch_decode([sequence], **kwargs)[0]