# BEST-RQ-2 / audio-embeddings: src/models/rqa_jepa_module.py
import functools
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.audio_jepa_module import AudioJEPAModule
from src.models.components.random_projection_quantizer import RandomProjectionQuantizer


class RQAJEPAModule(AudioJEPAModule):
"""
RQA-JEPA Lightning Module.
Extends AudioJEPAModule with Random Projection Quantization loss.
Args:
optimizer (torch.optim.Optimizer): Optimizer configuration.
net (Dict[str, Any]): Configuration for sub-modules.
        warmup_pct (float): Fraction of total training steps used for
            learning-rate warmup.
final_lr_ratio (float): Ratio of final learning rate to initial learning rate.
ema_decay (float): Initial EMA decay rate.
ema_end_decay (float): Final EMA decay rate.
ema_anneal_end_step (int): Step at which EMA decay reaches ema_end_decay.
spectrogram_adjustment_mode (str): 'pad' or 'truncate' for spectrogram time dimension.
jepa_criterion (torch.nn.Module): Loss function for JEPA (defaults to MSELoss).
rq_criterion (torch.nn.Module): Loss function for RQ (defaults to CrossEntropyLoss).
        rq_lambda (float): Weight applied to the JEPA loss; the RQ loss is
            weighted by (1 - rq_lambda).
codebook_dim (int): Codebook dimension for RandomProjectionQuantizer.
vocab_size (int): Vocabulary size for RandomProjectionQuantizer.
rq_input_type (str): 'teacher' or 'spectrogram'. Source for quantization targets.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
net: Dict[str, Any],
warmup_pct: float = 0.1,
final_lr_ratio: float = 0.001,
ema_decay: float = 0.996,
ema_end_decay: float = 1.0,
ema_anneal_end_step: Optional[int] = None,
spectrogram_adjustment_mode: str = "pad",
jepa_criterion: Optional[torch.nn.Module] = None,
rq_criterion: Optional[torch.nn.Module] = None,
rq_lambda: float = 0.5,
codebook_dim: int = 16,
vocab_size: int = 8192,
rq_input_type: str = "teacher",
):
super().__init__(
optimizer=optimizer,
net=net,
warmup_pct=warmup_pct,
final_lr_ratio=final_lr_ratio,
ema_decay=ema_decay,
ema_end_decay=ema_end_decay,
ema_anneal_end_step=ema_anneal_end_step,
spectrogram_adjustment_mode=spectrogram_adjustment_mode,
criterion=jepa_criterion, # Pass jepa_criterion as criterion to base class
)
self.save_hyperparameters(
logger=False, ignore=["jepa_criterion", "rq_criterion", "net", "optimizer"]
)
self.rq_lambda = rq_lambda
        # Resolve the RQ criterion: configs may pass a class or a
        # functools.partial rather than an instantiated module.
        if rq_criterion is None:
            self.rq_criterion = nn.CrossEntropyLoss()
        elif isinstance(rq_criterion, nn.Module):
            # Already an instantiated loss module
            self.rq_criterion = rq_criterion
        elif isinstance(rq_criterion, (type, functools.partial)):
            # A class or partial: instantiate it
            self.rq_criterion = rq_criterion()
        else:
            # Any other zero-argument callable factory
            self.rq_criterion = rq_criterion()
self.rq_input_type = rq_input_type
if self.rq_input_type not in ["teacher", "spectrogram"]:
raise ValueError(
f"rq_input_type must be 'teacher' or 'spectrogram', got {self.rq_input_type}"
)
# Random Projection Quantizer
# Determine input dimension for quantizer
        if self.rq_input_type == "teacher":
            # Quantizer input is the teacher output, which has the encoder's
            # embedding dimension.
            quantizer_input_dim = net.get("encoder", {}).get("embed_dim", 768)
        else:  # spectrogram
            # Quantizer input is raw spectrogram patches; patch_embed is
            # inherited from the AudioJEPAModule base class.
            patch_size = self.patch_embed.patch_size
            in_chans = self.patch_embed.in_chans
            quantizer_input_dim = patch_size[0] * patch_size[1] * in_chans
self.quantizer = RandomProjectionQuantizer(
input_dim=quantizer_input_dim, cb_dim=codebook_dim, cb_vocab=vocab_size
)
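        # As in BEST-RQ, the quantizer is meant to apply a fixed random
        # projection down to codebook_dim and look up the nearest entry in a
        # fixed random codebook of vocab_size vectors, returning integer code
        # indices (see RandomProjectionQuantizer for the actual details).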
# Freeze quantizer (it is random and fixed)
for p in self.quantizer.parameters():
p.requires_grad = False
        # Projection head for RQ prediction: maps predictor outputs
        # (predictor embed_dim) to logits over the vocab_size code indices.
predictor_config = net.get("predictor", {})
predictor_embed_dim = predictor_config.get("embed_dim", 768)
self.rq_proj = nn.Linear(predictor_embed_dim, vocab_size)
def _calculate_combined_loss(
self,
predictions_raw: torch.Tensor,
teacher_targets: torch.Tensor,
rq_logits: torch.Tensor,
rq_targets: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calculates both JEPA and RQ losses and combines them.
"""
# --- JEPA Loss ---
# Project back to encoder dimension for JEPA loss
predictions_jepa = self.predictor_output_proj(
predictions_raw
) # [B, N_mask, encoder_dim]
jepa_loss = self.criterion(
predictions_jepa, teacher_targets
) # Uses self.criterion (mapped from jepa_criterion)
        # --- RQ Loss ---
        # Cross-entropy over the quantized code indices; flatten the batch
        # and mask dimensions for the loss calculation.
rq_loss = self.rq_criterion(
rq_logits.reshape(-1, self.hparams.vocab_size), rq_targets.reshape(-1)
)
# --- Combine ---
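        # With the default rq_lambda = 0.5 both objectives contribute equally;
        # rq_lambda = 1.0 recovers plain JEPA training and rq_lambda = 0.0 a
        # purely quantization-driven (BEST-RQ-style) objective.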
total_loss = self.rq_lambda * jepa_loss + (1 - self.rq_lambda) * rq_loss
return total_loss, jepa_loss, rq_loss
def _get_raw_patches(self, spec: torch.Tensor) -> torch.Tensor:
"""
        Extract raw non-overlapping patches from the spectrogram.
Args:
spec (torch.Tensor): Adjusted spectrogram [B, C, F, T].
Returns:
torch.Tensor: Flattened patches [B, N, patch_dim]
"""
patch_size = self.patch_embed.patch_size # (H, W)
# Using kernel_size=patch_size, stride=patch_size ensures non-overlapping patches
# F.unfold returns [B, C*pH*pW, L]
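        # Shape example (illustrative numbers, not from any config): a
        # [B, 1, 128, 1024] spectrogram with patch_size (16, 16) yields
        # (128 // 16) * (1024 // 16) = 512 patches of dim 16 * 16 * 1 = 256.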
patches = F.unfold(spec, kernel_size=patch_size, stride=patch_size) # [B, D, N]
patches = patches.transpose(1, 2) # [B, N, D]
return patches
def _get_rq_targets_input(
self, spec: torch.Tensor, teacher_full: torch.Tensor, mask_indices: torch.Tensor
) -> torch.Tensor:
"""
Helper to get the input for the RQ quantizer (either teacher embeddings or raw patches).
Only returns the targets for the MASKED locations.
"""
if self.rq_input_type == "teacher":
# Teacher targets at masked locations
return teacher_full[:, mask_indices, :] # [B, N_mask, encoder_dim]
        else:
            # Raw patches at masked locations; a None spec here indicates a
            # logic error in the caller.
            if spec is None:
                raise ValueError(
                    "Spectrogram cannot be None when rq_input_type is 'spectrogram'"
                )
raw_patches = self._get_raw_patches(spec) # [B, N, patch_dim]
return raw_patches[:, mask_indices, :] # [B, N_mask, patch_dim]
def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
waveform = batch["waveform"]
patches, current_grid_size = self._process_audio(waveform)
B, N, D = patches.shape
mask = self.mask_generator(1, device=self.device, grid_size=current_grid_size)
mask = mask.expand(B, -1)
student_out = self.compute_student(patches, mask, current_grid_size)
predictions_raw = self.compute_predictor(student_out, mask, current_grid_size)
self._update_teacher()
with torch.no_grad():
teacher_full = self.teacher(patches, grid_size=current_grid_size)
        # Prepare targets and logits for RQA-JEPA. The mask is generated once
        # and expanded across the batch, so mask[0] holds the shared pattern.
        m = mask[0]
        mask_indices = torch.nonzero(m).flatten()
# Teacher targets always needed for JEPA loss
teacher_targets = teacher_full[:, mask_indices, :] # [B, N_mask, encoder_dim]
# RQ Targets (Quantized)
with torch.no_grad():
            # The 'spectrogram' mode needs the adjusted spectrogram; recompute
            # it here since _process_audio does not expose it.
            spec = None
            if self.rq_input_type == "spectrogram":
                spec = self.spectrogram(waveform)
                spec = self._adjust_spectrogram(spec)
rq_targets_input = self._get_rq_targets_input(
spec, teacher_full, mask_indices
)
rq_targets = self.quantizer(rq_targets_input) # [B, N_mask]
# RQ Logits
rq_logits = self.rq_proj(predictions_raw) # [B, N_mask, vocab_size]
loss, jepa_loss, rq_loss = self._calculate_combined_loss(
predictions_raw, teacher_targets, rq_logits, rq_targets
)
self.log(
"train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=B
)
self.log(
"train/jepa_loss", jepa_loss, on_step=True, on_epoch=True, batch_size=B
)
self.log("train/rq_loss", rq_loss, on_step=True, on_epoch=True, batch_size=B)
return loss
def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
waveform = batch["waveform"]
patches, current_grid_size = self._process_audio(waveform)
B, N, D = patches.shape
mask = self.mask_generator(1, device=self.device, grid_size=current_grid_size)
mask = mask.expand(B, -1)
student_out = self.compute_student(patches, mask, current_grid_size)
predictions_raw = self.compute_predictor(student_out, mask, current_grid_size)
with torch.no_grad():
teacher_full = self.teacher(patches, grid_size=current_grid_size)
            # Prepare targets and logits for RQA-JEPA; as in training_step,
            # the mask pattern is shared across the batch.
            m = mask[0]
            mask_indices = torch.nonzero(m).flatten()
# Teacher targets at masked locations
teacher_targets = teacher_full[
:, mask_indices, :
] # [B, N_mask, encoder_dim]
            # RQ Targets (Quantized). The 'spectrogram' mode needs the
            # adjusted spectrogram; recompute it as in training_step.
            spec = None
            if self.rq_input_type == "spectrogram":
                spec = self.spectrogram(waveform)
                spec = self._adjust_spectrogram(spec)
rq_targets_input = self._get_rq_targets_input(
spec, teacher_full, mask_indices
)
rq_targets = self.quantizer(rq_targets_input) # [B, N_mask]
# RQ Logits
rq_logits = self.rq_proj(predictions_raw) # [B, N_mask, vocab_size]
loss, jepa_loss, rq_loss = self._calculate_combined_loss(
predictions_raw, teacher_targets, rq_logits, rq_targets
)
self.log(
"val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=B
)
self.log("val/jepa_loss", jepa_loss, on_step=False, on_epoch=True, batch_size=B)
self.log("val/rq_loss", rq_loss, on_step=False, on_epoch=True, batch_size=B)
return loss
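

# ---------------------------------------------------------------------------
# Minimal self-contained sanity check (a sketch with assumed shapes, not part
# of the training pipeline): it exercises the two mechanisms this module adds
# on top of AudioJEPAModule, non-overlapping patch extraction via F.unfold
# and the weighted JEPA + RQ loss combination, without building any nets.
# All sizes below are illustrative assumptions, not values from any config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, C, n_mels, T = 2, 1, 128, 1024  # assumed spectrogram shape
    patch_size = (16, 16)
    spec = torch.randn(B, C, n_mels, T)

    # Same unfold call as _get_raw_patches: [B, C*pH*pW, N] -> [B, N, D]
    patches = F.unfold(spec, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
    assert patches.shape == (B, (n_mels // 16) * (T // 16), 16 * 16 * C)

    # Loss combination as in _calculate_combined_loss, with dummy tensors
    rq_lambda, vocab_size, n_mask, dim = 0.5, 8192, 10, 768
    jepa_loss = F.mse_loss(torch.randn(B, n_mask, dim), torch.randn(B, n_mask, dim))
    rq_logits = torch.randn(B, n_mask, vocab_size)
    rq_targets = torch.randint(0, vocab_size, (B, n_mask))
    rq_loss = F.cross_entropy(
        rq_logits.reshape(-1, vocab_size), rq_targets.reshape(-1)
    )
    total = rq_lambda * jepa_loss + (1 - rq_lambda) * rq_loss
    print(
        f"jepa={jepa_loss.item():.4f} rq={rq_loss.item():.4f} "
        f"total={total.item():.4f}"
    )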