| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import lightning as L |
| from typing import Any, Dict, Tuple |
|
|
| from src.models.components.spectrogram import Spectrogram |
| from src.models.components.masking import MaskingGenerator |
| from src.models.components.patch_embed import PatchEmbed |
| from src.models.components.vit import ViT |
| from src.models.components.random_projection_quantizer import RandomProjectionQuantizer |
| from src.utils.lr_schedulers import LinearWarmupCosineDecay |
|
|
|
|
class BestRQModule(L.LightningModule):
    """
    Best-RQ Lightning Module.

    Implements a single-stage Masked Audio Modeling approach using Random
    Projection Quantization targets. Spectrogram patch embeddings at masked
    positions are replaced by a learned mask token and encoded by a ViT. The
    encoder outputs at those positions are then classified into the codebook
    indices that a frozen RandomProjectionQuantizer assigns to the raw
    (un-embedded) patches.

    Args:
        optimizer: Optimizer factory (e.g. a partial); it is called as
            ``optimizer(params=self.parameters())`` in ``configure_optimizers``.
        net: Configuration for sub-modules. Optional keys "spectrogram",
            "patch_embed", "masking" and "encoder" each hold a kwargs dict.
        warmup_pct: Fraction of total steps used for LR warmup.
        final_lr_ratio: Ratio of final learning rate to initial learning rate.
        spectrogram_adjustment_mode: 'pad' or 'truncate'; how the spectrogram
            time dimension is made divisible by the patch time size.
        codebook_dim: Codebook dimension for RandomProjectionQuantizer.
        vocab_size: Vocabulary size for RandomProjectionQuantizer (also the
            number of output classes of the prediction head).

    Raises:
        ValueError: If ``spectrogram_adjustment_mode`` is not 'pad' or 'truncate'.
    """

    _ADJUSTMENT_MODES = ("pad", "truncate")

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        net: Dict[str, Any],
        warmup_pct: float = 0.1,
        final_lr_ratio: float = 0.001,
        spectrogram_adjustment_mode: str = "pad",
        codebook_dim: int = 16,
        vocab_size: int = 8192,
    ):
        super().__init__()
        self.save_hyperparameters(logger=False, ignore=["net", "optimizer"])

        # Validate eagerly: _adjust_spectrogram only inspects the mode when the
        # time axis actually needs adjusting, so a typo could otherwise go
        # unnoticed for inputs that happen to be divisible.
        if spectrogram_adjustment_mode not in self._ADJUSTMENT_MODES:
            raise ValueError(
                f"Unknown spectrogram_adjustment_mode: {spectrogram_adjustment_mode}"
            )

        self.warmup_pct = warmup_pct
        self.final_lr_ratio = final_lr_ratio
        self.spectrogram_adjustment_mode = spectrogram_adjustment_mode
        self.vocab_size = vocab_size

        # Kept un-called; instantiated with this module's parameters in
        # configure_optimizers.
        self.optimizer_config = optimizer

        # Front-end (waveform -> spectrogram -> patch tokens) and mask sampler.
        self.spectrogram = Spectrogram(**net.get("spectrogram", {}))
        self.patch_embed = PatchEmbed(**net.get("patch_embed", {}))
        self.mask_generator = MaskingGenerator(**net.get("masking", {}))

        self.encoder = ViT(**net.get("encoder", {}))

        # Learned token substituted for masked patch embeddings. Its width
        # falls back to 768 when the encoder config omits "embed_dim".
        encoder_dim = net.get("encoder", {}).get("embed_dim", 768)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        # The quantizer consumes raw flattened patches (see _get_raw_patches),
        # so its input size is patch_height * patch_width * in_chans.
        patch_size = self._get_patch_size()
        in_chans = self.patch_embed.in_chans
        quantizer_input_dim = patch_size[0] * patch_size[1] * in_chans

        self.quantizer = RandomProjectionQuantizer(
            input_dim=quantizer_input_dim, cb_dim=codebook_dim, cb_vocab=vocab_size
        )
        # The quantizer only produces targets; it must never be trained.
        for p in self.quantizer.parameters():
            p.requires_grad = False

        # Prediction head: encoder features -> logits over codebook entries.
        self.output_proj = nn.Linear(encoder_dim, vocab_size)

        self.criterion = nn.CrossEntropyLoss()

    def _get_patch_size(self) -> Tuple[int, int]:
        """
        Return the patch size as ``(frequency, time)``.

        Reads ``self.patch_embed.patch_size``, the same attribute used to size
        the quantizer, so all patch geometry in this module has one source.
        """
        return self.patch_embed.patch_size

    def _adjust_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Make the spectrogram time (last) dimension divisible by the patch time size.

        In "pad" mode the time axis is right-padded with zeros; in "truncate"
        mode the trailing frames are dropped. Other dimensions are unchanged.

        Raises:
            ValueError: If the time axis needs adjusting and the mode is unknown.
        """
        patch_time_dim = self._get_patch_size()[1]
        num_frames = spec.shape[-1]
        remainder = num_frames % patch_time_dim
        if remainder == 0:
            return spec

        if self.spectrogram_adjustment_mode == "pad":
            return F.pad(spec, (0, patch_time_dim - remainder))
        if self.spectrogram_adjustment_mode == "truncate":
            return spec[..., : num_frames - remainder]
        raise ValueError(
            f"Unknown spectrogram_adjustment_mode: {self.spectrogram_adjustment_mode}"
        )

    def _spectrogram_to_patches(
        self, spec: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Embed an already-adjusted spectrogram into patch tokens.

        Args:
            spec: 4-D spectrogram whose dims 2 and 3 are treated as frequency
                and time respectively.

        Returns:
            ``(patches, grid_size)``. ``grid_size`` is
            ``(freq_patches, time_patches)``, i.e. the number of whole patches
            that fit along each axis (floor division).
        """
        patches = self.patch_embed(spec)
        patch_size = self._get_patch_size()
        grid_size = (spec.shape[2] // patch_size[0], spec.shape[3] // patch_size[1])
        return patches, grid_size

    def _process_audio(
        self, waveform: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Convert a raw waveform into patch embeddings and the patch grid size.
        """
        spec = self._adjust_spectrogram(self.spectrogram(waveform))
        return self._spectrogram_to_patches(spec)

    def _get_raw_patches(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Extract flattened, non-overlapping raw spectrogram patches for quantization.

        ``F.unfold`` with kernel == stride == patch size yields
        ``(B, C * patch_h * patch_w, num_patches)``, which is transposed to
        ``(B, num_patches, C * patch_h * patch_w)``.

        NOTE(review): assumes PatchEmbed orders its tokens the same way as
        ``F.unfold`` orders blocks, so index i refers to the same patch in both.
        Confirm against PatchEmbed.
        """
        patch_size = self._get_patch_size()
        patches = F.unfold(spec, kernel_size=patch_size, stride=patch_size)
        return patches.transpose(1, 2)

    def _apply_mask(
        self, patches: torch.Tensor, mask_bool: torch.Tensor
    ) -> torch.Tensor:
        """
        Return a copy of ``patches`` with masked positions set to the mask token.

        Args:
            patches: ``(B, N, D)`` patch embeddings.
            mask_bool: ``(B, N)`` boolean mask; True marks positions to replace.
        """
        # Cast so the result keeps the patches' dtype (e.g. under mixed
        # precision, where the parameter may be fp32 and patches half precision).
        mask_token = self.mask_token.to(dtype=patches.dtype)
        return torch.where(mask_bool.unsqueeze(-1), mask_token, patches)

    def _shared_step(self, waveform: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """
        Compute the masked-prediction loss for a batch of waveforms.

        A single mask is sampled and shared across the whole batch. This lets
        masked positions be gathered with one index tensor.

        Returns:
            ``(loss, batch_size)``.
        """
        # Compute the spectrogram once and reuse it for both the encoder input
        # and the quantizer targets, so the targets are derived from exactly
        # the features the encoder was given.
        spec = self._adjust_spectrogram(self.spectrogram(waveform))
        patches, grid_size = self._spectrogram_to_patches(spec)
        batch_size = patches.shape[0]

        mask = self.mask_generator(1, device=self.device, grid_size=grid_size)
        mask_bool = mask.expand(batch_size, -1).bool()

        encoder_out = self.encoder(
            self._apply_mask(patches, mask_bool), grid_size=grid_size
        )

        mask_indices = torch.nonzero(mask_bool[0]).flatten()
        logits = self.output_proj(encoder_out[:, mask_indices, :])

        if mask_indices.numel() == 0:
            # Nothing was masked: the mean cross-entropy over zero elements
            # would be NaN. Return a zero that stays attached to the graph.
            return logits.sum() * 0.0, batch_size

        with torch.no_grad():
            raw_patches = self._get_raw_patches(spec)
            targets = self.quantizer(raw_patches[:, mask_indices, :])

        loss = self.criterion(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        return loss, batch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for inference/eval: encode the unmasked patch sequence.

        Returns:
            The encoder representation.
        """
        patches, grid_size = self._process_audio(x)
        return self.encoder(patches, grid_size=grid_size)

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute and log the training loss for a batch holding a "waveform" entry."""
        loss, batch_size = self._shared_step(batch["waveform"])
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=batch_size,
        )
        return loss

    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute and log the validation loss for a batch holding a "waveform" entry."""
        loss, batch_size = self._shared_step(batch["waveform"])
        self.log(
            "val/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=batch_size,
        )
        return loss

    def test_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute and log the test loss (under "test/loss", not "val/loss")."""
        loss, batch_size = self._shared_step(batch["waveform"])
        self.log(
            "test/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=batch_size,
        )
        return loss

    def configure_optimizers(self) -> Dict[str, Any]:
        """
        Build the optimizer and a per-step LinearWarmupCosineDecay LR schedule.

        The schedule length is ``trainer.max_steps`` when positive, otherwise
        ``trainer.estimated_stepping_batches``.

        Raises:
            ValueError: If the total number of training steps is unbounded.
        """
        optimizer = self.optimizer_config(params=self.parameters())

        max_steps = self.trainer.max_steps
        if max_steps is not None and max_steps > 0:
            total_steps = max_steps
        else:
            total_steps = self.trainer.estimated_stepping_batches
        if total_steps == float("inf"):
            # int(inf * warmup_pct) below would otherwise fail with an
            # opaque OverflowError.
            raise ValueError(
                "Cannot build the LR schedule for unbounded training; "
                "set trainer max_steps or max_epochs."
            )

        lr_lambda = LinearWarmupCosineDecay(
            warmup_steps=int(total_steps * self.warmup_pct),
            total_steps=total_steps,
            final_lr_ratio=self.final_lr_ratio,
        )
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Name matches the metric logged in validation_step.
                "monitor": "val/loss",
                "interval": "step",
                "frequency": 1,
            },
        }
|
|