# BEST-RQ-2 / audio-embeddings / src/models/components/random_projection_quantizer.py
# Author: ltuncay
# Submission to the Interspeech 2026 Audio Encoder Capability Challenge
# (commit eca55dc, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class RandomProjectionQuantizer(nn.Module):
    """Quantise latent vectors with a fixed random projection and codebook.

    Implements the random-projection quantiser used by BEST-RQ: inputs are
    projected by a frozen Xavier-initialised matrix, L2-normalised, and each
    time step is mapped to the index of the most similar (cosine) code in a
    frozen, unit-norm random codebook.

    ref: https://arxiv.org/pdf/2202.01855

    Arguments
    ---------
    input_dim: int
        Input dimension (channels).
    cb_dim: int
        Size of each code in the codebook.
    cb_vocab: int
        Number of codes in the codebook

    Example
    -------
    >>> quantiser = RandomProjectionQuantizer(16, 16, 32)
    >>> inputs = torch.rand(10, 12, 16)
    >>> output = quantiser(inputs)
    >>> output.shape
    torch.Size([10, 12])
    """

    def __init__(self, input_dim, cb_dim, cb_vocab):
        super().__init__()
        self.input_dim = input_dim
        self.cb_dim = cb_dim
        self.cb_vocab = cb_vocab

        # Section 3.1 of the paper: the projection matrix A uses Xavier
        # initialisation. Registered as a buffer (not a Parameter) so it is
        # saved with the state dict but never trained.
        proj = torch.empty((input_dim, cb_dim))
        nn.init.xavier_uniform_(proj)
        self.register_buffer("P", proj)

        # Random codebook with unit-norm rows, so cosine similarity reduces
        # to a plain dot product in forward(). Also frozen via buffer.
        codes = torch.randn(cb_vocab, cb_dim)
        self.register_buffer("CB", F.normalize(codes))

    def forward(self, x):
        """Forward the latent vector to obtain a quantised output.

        Returns a LongTensor of codebook indices with the input's leading
        (batch, time) dimensions.
        """
        projected = torch.matmul(x, self.P)
        projected = F.normalize(projected, dim=-1)
        # Both operands are unit-norm, so the largest dot product picks the
        # nearest code under cosine distance.
        similarity = projected @ self.CB.t()
        return similarity.argmax(dim=-1)