File size: 4,043 Bytes
ffb2413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""HF processor: raw audio -> mel input_features, and token ids -> text (SentencePiece)."""
import os
import torch
import sentencepiece as spm
from transformers.feature_extraction_utils import BatchFeature


def _normalize_per_feature(x, seq_len, constant):
    B, _, max_time = x.shape
    steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(B, max_time)
    valid = steps < seq_len.unsqueeze(1)
    denom = valid.sum(dim=1)
    mean = torch.where(valid.unsqueeze(1), x, torch.zeros_like(x)).sum(dim=2) / denom.unsqueeze(1)
    var = torch.sum(torch.where(valid.unsqueeze(1), x - mean.unsqueeze(2), torch.zeros_like(x)) ** 2,
                    dim=2) / (denom.unsqueeze(1) - 1.0)
    std = torch.sqrt(var).masked_fill(torch.sqrt(var).isnan(), 0.0) + constant
    return (x - mean.unsqueeze(2)) / std.unsqueeze(2)


class FastConformerProcessor:
    """Minimal HF-style processor: raw audio -> log-mel features, token ids -> text.

    Audio goes through ``__call__`` (pre-emphasis -> STFT -> filterbank -> log ->
    per-feature normalization -> padding mask). Token ids go back to text through
    ``batch_decode`` / ``decode``; ``blank_id`` tokens are dropped before decoding.

    Args:
        sp: SentencePiece processor (its ``decode`` and ``serialized_model_proto``
            methods are used).
        window: STFT window tensor passed to ``torch.stft``.
        fb: filterbank tensor left-multiplied onto the magnitude spectrogram.
        params: mapping with the front-end settings read in ``__call__``:
            ``n_fft``, ``hop_length``, ``win_length``, ``preemph``, ``mag_power``,
            ``log_zero_guard_value``, ``CONSTANT`` and ``pad_value``.
        blank_id: token id removed from sequences before decoding.
        sample_rate: rate the front end expects; other input rates are resampled.
    """

    def __init__(self, sp, window, fb, params, blank_id=3000, sample_rate=16000):
        self.sp = sp
        self.window = window
        self.fb = fb
        self.p = params
        self.blank_id = blank_id
        self.sample_rate = sample_rate

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoProcessor"):
        """Accept HF's auto-class registration call and intentionally do nothing."""
        return None

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Load ``tokenizer.model`` and ``preproc.pt`` from the directory ``path``.

        ``preproc.pt`` must hold ``window``, ``fb`` and ``params``. ``blank_id``
        (default 3000) and ``sample_rate`` (default 16000) are optional.
        Extra keyword arguments are accepted and ignored.
        """
        sp = spm.SentencePieceProcessor(model_file=os.path.join(path, "tokenizer.model"))
        # NOTE: torch.load unpickles the file -- only load preproc.pt from trusted sources.
        pp = torch.load(os.path.join(path, "preproc.pt"), map_location="cpu")
        blank = pp.get("blank_id", 3000)
        rate = pp.get("sample_rate", 16000)
        return cls(sp, pp["window"], pp["fb"], pp["params"], blank_id=blank, sample_rate=rate)

    def save_pretrained(self, path, **kwargs):
        """Write the files ``from_pretrained(path)`` needs to rebuild this processor.

        Creates ``path`` if necessary and writes ``tokenizer.model`` (the serialized
        SentencePiece model) and ``preproc.pt`` (window, filterbank, params,
        blank id and sample rate). Extra keyword arguments are ignored.
        """
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "tokenizer.model"), "wb") as f:
            f.write(self.sp.serialized_model_proto())
        torch.save(
            {
                "window": self.window,
                "fb": self.fb,
                "params": self.p,
                "blank_id": self.blank_id,
                "sample_rate": self.sample_rate,
            },
            os.path.join(path, "preproc.pt"),
        )

    @staticmethod
    def _to_2d(audio):
        """Convert ``audio`` (tensor or array-like) to a float32 ``(batch, samples)`` tensor.

        A 1-D input becomes a batch of one. Tensors keep their device.
        """
        if isinstance(audio, torch.Tensor):
            a = audio
        else:
            import numpy as np
            a = torch.as_tensor(np.asarray(audio), dtype=torch.float32)
        if a.dim() == 1:
            a = a.unsqueeze(0)
        return a.float()

    @torch.no_grad()
    def __call__(self, audio, sampling_rate=16000, return_tensors="pt"):
        """Compute normalized log filterbank features for ``audio``.

        Args:
            audio: one waveform (1-D) or an equal-length batch (2-D).
            sampling_rate: rate of ``audio``. ``None`` means it is already at
                ``self.sample_rate``. Any other mismatch is resampled with torchaudio.
            return_tensors: forwarded to ``BatchFeature`` as ``tensor_type``.

        Returns:
            ``BatchFeature`` with ``input_features`` (filterbank rows x STFT frames
            per item; frames at index >= ``feature_lengths`` filled with
            ``pad_value``) and ``feature_lengths`` (valid frames per item).
        """
        wav = self._to_2d(audio)
        if sampling_rate is None:
            sampling_rate = self.sample_rate
        if sampling_rate != self.sample_rate:
            import torchaudio
            wav = torchaudio.functional.resample(wav, sampling_rate, self.sample_rate)
        # Every helper tensor follows the input's device so CUDA input works too.
        device = wav.device
        p = self.p
        n_fft, hop, win = p["n_fft"], p["hop_length"], p["win_length"]
        batch, n_samples = wav.shape
        wav_len = torch.full((batch,), n_samples, dtype=torch.long, device=device)

        # torch.stft(center=True) pads n_fft // 2 on each side.
        # NOTE(review): for even n_fft this gives wav_len // hop, one fewer than the
        # 1 + wav_len // hop frames stft returns, so the trailing frame is treated
        # as padding. Kept as-is to preserve existing outputs -- confirm intent.
        pad_amount = (n_fft // 2) * 2
        seq_len = torch.div(wav_len + pad_amount - n_fft, hop, rounding_mode="floor")

        x = wav
        preemph = p.get("preemph")
        if preemph is not None:
            x = torch.cat((wav[:, :1], wav[:, 1:] - preemph * wav[:, :-1]), dim=1)
        # Zero everything past each item's length before the STFT.
        tmask = torch.arange(n_samples, device=device).unsqueeze(0) < wav_len.unsqueeze(1)
        x = x.masked_fill(~tmask, 0.0)

        window = self.window.to(device) if self.window is not None else None
        spec = torch.stft(x, n_fft=n_fft, hop_length=hop, win_length=win, window=window,
                          center=True, pad_mode="constant", return_complex=True)
        x = torch.view_as_real(spec)
        x = torch.sqrt(x.pow(2).sum(-1)).pow(p["mag_power"])
        x = torch.matmul(self.fb.to(device=device, dtype=x.dtype), x)
        x = torch.log(x + p["log_zero_guard_value"])
        x = _normalize_per_feature(x, seq_len, p["CONSTANT"])

        frame_idx = torch.arange(x.size(-1), device=device).unsqueeze(0)
        pad_mask = frame_idx >= seq_len.unsqueeze(1)  # (batch, frames)
        x = x.masked_fill(pad_mask.unsqueeze(1), p["pad_value"])
        return BatchFeature({"input_features": x, "feature_lengths": seq_len}, tensor_type=return_tensors)

    def _clean(self, ids):
        """Return ``ids`` as Python ints with every ``blank_id`` removed."""
        as_ints = (int(i) for i in ids)
        return [i for i in as_ints if i != self.blank_id]

    def batch_decode(self, sequences, **kwargs):
        """Decode a batch of token-id sequences to strings.

        ``sequences`` may be an object with ``token_lists`` (used as-is, no blank
        removal), an object with ``sequences``, a 2-D tensor, or an iterable of
        id iterables. In the last three cases ``blank_id`` is stripped first.
        """
        token_lists = getattr(sequences, "token_lists", None)
        if token_lists is None:
            seqs = getattr(sequences, "sequences", sequences)
            if isinstance(seqs, torch.Tensor):
                token_lists = [self._clean(row.tolist()) for row in seqs]
            else:
                token_lists = [self._clean(row) for row in seqs]
        return [self.sp.decode(t) for t in token_lists]

    def decode(self, sequence, **kwargs):
        """Decode a single token-id sequence (see ``batch_decode``)."""
        return self.batch_decode([sequence], **kwargs)[0]