import torch import matplotlib.pyplot as plt import numpy as np import lightning as L from lightning.pytorch.callbacks import Callback from lightning.pytorch.loggers import WandbLogger from typing import Any, Dict, Optional class VisualizationCallback(Callback): """ Callback to visualize spectrograms, patches, and masks. Logs the first 4 samples of the first 2 batches. """ def __init__(self, num_samples: int = 4): super().__init__() self.num_samples = num_samples self.batches_logged = 0 def on_train_batch_end( self, trainer: L.Trainer, pl_module: L.LightningModule, outputs: Any, batch: Any, batch_idx: int, ) -> None: if self.batches_logged >= 2: return # Log for the first 2 batches if batch_idx < 2: self._log_visualizations(trainer, pl_module, batch, batch_idx) self.batches_logged += 1 def _log_visualizations( self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Dict[str, Any], batch_idx: int, ) -> None: logger = trainer.logger if not isinstance(logger, WandbLogger): return waveform = batch["waveform"][: self.num_samples] # [B, 1, T] sample_rate = self._resolve_sample_rate(trainer, pl_module) # Get spectrograms with torch.no_grad(): spec = pl_module.spectrogram(waveform.to(pl_module.device)) # [B, 1, F, T] # Get grid size and patch info patch_size = pl_module.patch_embed.patch_embed.patch_size F_pix = spec.shape[2] T_pix = spec.shape[3] H_grid = F_pix // patch_size[0] W_grid = T_pix // patch_size[1] current_grid_size = (H_grid, W_grid) # Generate mask # Using the same logic as training step (shared mask across batch) # But we want to see if it's the same across batches (it should be random each step) mask = pl_module.mask_generator( 1, device=pl_module.device, grid_size=current_grid_size ) # [1, N] mask = mask.expand(self.num_samples, -1) # [B, N] # Log to WandB import wandb columns = [ "Batch Idx", "Sample Idx", "Audio", "Spectrogram", "Masked Spectrogram (Context)", "Inverse Masked Spectrogram (Targets)", ] data = [] for i in range(self.num_samples): # Audio audio_data = waveform[i].squeeze().cpu().numpy() audio = wandb.Audio( audio_data, sample_rate=sample_rate, caption=f"B{batch_idx}_S{i}" ) # Spectrograms spec_data = spec[i].squeeze().cpu().numpy() mask_data = mask[i].cpu().numpy() # 1. Original fig_orig = self._plot_spectrogram(spec_data, patch_size, current_grid_size) img_orig = wandb.Image(fig_orig, caption=f"Spec B{batch_idx}_S{i}") plt.close(fig_orig) # 2. Masked (Context) - Masked parts are dark fig_masked = self._plot_spectrogram_with_mask( spec_data, mask_data, patch_size, current_grid_size, invert_mask=False ) img_masked = wandb.Image(fig_masked, caption=f"Masked B{batch_idx}_S{i}") plt.close(fig_masked) # 3. Inverse Masked (Targets) - Context parts are dark fig_inv_masked = self._plot_spectrogram_with_mask( spec_data, mask_data, patch_size, current_grid_size, invert_mask=True ) img_inv_masked = wandb.Image( fig_inv_masked, caption=f"InvMasked B{batch_idx}_S{i}" ) plt.close(fig_inv_masked) data.append([batch_idx, i, audio, img_orig, img_masked, img_inv_masked]) # Log Table table = wandb.Table(columns=columns, data=data) logger.experiment.log({f"train/visualizations_batch_{batch_idx}": table}) @staticmethod def _resolve_sample_rate(trainer: L.Trainer, pl_module: L.LightningModule) -> int: """Resolve audio logging sample rate, preferring data target sample rate.""" sample_rate = 32000 datamodule = getattr(trainer, "datamodule", None) if datamodule is not None: dm_sr = getattr(datamodule, "target_sample_rate", None) if dm_sr is None and hasattr(datamodule, "hparams"): hparams = datamodule.hparams if isinstance(hparams, dict): dm_sr = hparams.get("target_sample_rate") else: dm_sr = getattr(hparams, "target_sample_rate", None) if dm_sr is not None: return int(dm_sr) spectrogram = getattr(pl_module, "spectrogram", None) module_sr = getattr(spectrogram, "sample_rate", None) if module_sr is not None: return int(module_sr) hparams = getattr(pl_module, "hparams", None) if isinstance(hparams, dict): net_cfg = hparams.get("net") if isinstance(net_cfg, dict): spectrogram_cfg = net_cfg.get("spectrogram") if isinstance(spectrogram_cfg, dict): config_sr = spectrogram_cfg.get("sample_rate") if config_sr is not None: return int(config_sr) return sample_rate def _plot_spectrogram( self, spec: np.ndarray, patch_size: tuple[int, int], grid_size: tuple[int, int] ) -> plt.Figure: """Plots spectrogram with grid lines.""" return self._plot_spectrogram_with_mask(spec, None, patch_size, grid_size) def _plot_spectrogram_with_mask( self, spec: np.ndarray, mask: Optional[np.ndarray], patch_size: tuple[int, int], grid_size: tuple[int, int], invert_mask: bool = False, ) -> plt.Figure: """ Plots spectrogram with dashed grid lines and darker masked patches. If mask is None, just plots spectrogram and grid. If invert_mask is True, darkens the unmasked parts instead. """ H_grid, W_grid = grid_size Ph, Pw = patch_size H, W = spec.shape fig, ax = plt.subplots(figsize=(10, 4)) ax.imshow(spec, origin="lower", aspect="auto", cmap="viridis") # Overlay Grid for h in range(0, H + 1, Ph): ax.axhline(h - 0.5, color="white", linestyle="--", linewidth=0.5, alpha=0.5) for w in range(0, W + 1, Pw): ax.axvline(w - 0.5, color="white", linestyle="--", linewidth=0.5, alpha=0.5) # Overlay Mask if mask is not None: mask_grid = mask.reshape(H_grid, W_grid) if invert_mask: mask_grid = ~mask_grid overlay = np.zeros((H, W, 4)) # RGBA for r in range(H_grid): for c in range(W_grid): if mask_grid[r, c]: y_start = r * Ph y_end = (r + 1) * Ph x_start = c * Pw x_end = (c + 1) * Pw overlay[y_start:y_end, x_start:x_end, 3] = 0.7 ax.imshow(overlay, origin="lower", aspect="auto") ax.set_title("Spectrogram") ax.set_xlabel("Time Frames") ax.set_ylabel("Frequency Bins") plt.tight_layout() return fig