Submission to the Interspeech 2026 Audio Encoder Capability Challenge
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- BEST-RQ-2.safetensors +3 -0
- BEST-RQ-2_encoder.py +233 -0
- README.md +35 -3
- audio-embeddings/.githooks/post-checkout +13 -0
- audio-embeddings/.githooks/post-merge +13 -0
- audio-embeddings/.githooks/post-rewrite +13 -0
- audio-embeddings/.gitignore +29 -0
- audio-embeddings/.pre-commit-config.yaml +39 -0
- audio-embeddings/.project-root +0 -0
- audio-embeddings/.python-version +1 -0
- audio-embeddings/AGENTS.md +150 -0
- audio-embeddings/README.md +142 -0
- audio-embeddings/THIRD_PARTY_LICENSES.md +13 -0
- audio-embeddings/configs/__init__.py +1 -0
- audio-embeddings/configs/callbacks/default.yaml +34 -0
- audio-embeddings/configs/callbacks/early_stopping.yaml +15 -0
- audio-embeddings/configs/callbacks/ema_weight_averaging.yaml +9 -0
- audio-embeddings/configs/callbacks/model_checkpoint.yaml +17 -0
- audio-embeddings/configs/callbacks/model_summary.yaml +5 -0
- audio-embeddings/configs/callbacks/none.yaml +0 -0
- audio-embeddings/configs/callbacks/rich_progress_bar.yaml +4 -0
- audio-embeddings/configs/callbacks/wandb_offline.yaml +2 -0
- audio-embeddings/configs/data/audioset.yaml +12 -0
- audio-embeddings/configs/data/mock_audioset.yaml +7 -0
- audio-embeddings/configs/data/yt1b.yaml +13 -0
- audio-embeddings/configs/experiment/audio_jepa/baseline.yaml +34 -0
- audio-embeddings/configs/experiment/audio_jepa/large.yaml +44 -0
- audio-embeddings/configs/experiment/audio_jepa/rope.yaml +41 -0
- audio-embeddings/configs/experiment/audio_jepa/time_res2x.yaml +54 -0
- audio-embeddings/configs/experiment/audio_jepa/time_res4x.yaml +54 -0
- audio-embeddings/configs/experiment/best_rq/audioset.yaml +33 -0
- audio-embeddings/configs/experiment/best_rq/yt1b.yaml +31 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset.yaml +33 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset_100k_512bs.yaml +30 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset_1m_128bs_4gpu.yaml +30 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset_200k_256bs_4gpu.yaml +30 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset_400k_128bs.yaml +30 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset_400k_128bs_4gpu.yaml +30 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset_800k_64bs_4gpu.yaml +30 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset_ema.yaml +35 -0
- audio-embeddings/configs/experiment/best_rq_2/audioset_ema_600k.yaml +36 -0
- audio-embeddings/configs/experiment/best_rq_2/yt1b_ema.yaml +37 -0
- audio-embeddings/configs/experiment/local/audio_jepa.yaml +29 -0
- audio-embeddings/configs/experiment/local/audio_jepa_rope.yaml +36 -0
- audio-embeddings/configs/experiment/local/best_rq.yaml +29 -0
- audio-embeddings/configs/experiment/local/best_rq2.yaml +38 -0
- audio-embeddings/configs/experiment/local/m4_mock_jepa.yaml +37 -0
- audio-embeddings/configs/experiment/local/rqa_jepa.yaml +29 -0
- audio-embeddings/configs/experiment/rqa_jepa/audioset.yaml +33 -0
- audio-embeddings/configs/experiment/rqa_jepa/yt1b.yaml +31 -0
BEST-RQ-2.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7111465e6c868e3d0b55c5fe9a23dc5069ac80beba8444c5ee9db4691d796899
|
| 3 |
+
size 483870856
|
BEST-RQ-2_encoder.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Module-level setup: make the sibling ``audio-embeddings`` project importable,
# then import the BEST-RQ-2 LightningModule from it.

import glob  # NOTE(review): imported but unused in this file — TODO confirm/remove
import os
import sys

import torch
import torch.nn as nn
from omegaconf import OmegaConf
from safetensors.torch import load_file

# Add audio-embeddings to path dynamically
# We assume audio-embeddings is a sibling directory to xares-llm or provided via env var
# Prioritize absolute path if known, otherwise relative
POSSIBLE_PATHS = [
    # "/media/ltuncay/Shared-4TB/dev/audio-embeddings",
    os.path.abspath(os.path.join(os.path.dirname(__file__), "audio-embeddings")),
    # os.path.abspath(os.path.join(os.getcwd(), "../audio-embeddings")),
]

# First existing candidate wins; stays None when none of the paths exist.
AUDIO_EMBEDDINGS_PATH = None
for p in POSSIBLE_PATHS:
    if os.path.exists(p):
        AUDIO_EMBEDDINGS_PATH = p
        break

if AUDIO_EMBEDDINGS_PATH:
    # Only append once so repeated imports of this module don't grow sys.path.
    if AUDIO_EMBEDDINGS_PATH not in sys.path:
        sys.path.append(AUDIO_EMBEDDINGS_PATH)
        print(f"Added {AUDIO_EMBEDDINGS_PATH} to sys.path")
else:
    # Not fatal here: the package may still be installed in the environment,
    # in which case the import below succeeds without the path hack.
    print(
        "Warning: audio-embeddings path not found. Imports may fail if not installed in environment."
    )

try:
    from src.models.best_rq2_module import BestRQ2Module
except ImportError as e:
    # Re-raise with context so the failure is actionable for users who have
    # not placed/installed the audio-embeddings project.
    raise ImportError(
        f"Could not import src.models.best_rq2_module. Ensure audio-embeddings is correctly located or installed. Error: {e}"
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class BestRQ2Encoder(nn.Module):
    """Inference wrapper around a pretrained BEST-RQ-2 ``BestRQ2Module``.

    Loads an OmegaConf config and a ``.safetensors`` checkpoint (by default the
    ``config.yaml`` / ``BEST-RQ-2.safetensors`` files sitting next to this
    module), rebuilds the LightningModule, and exposes a ``forward`` that maps
    raw waveforms to patch embeddings. Audio longer than the model's
    positional-embedding capacity is processed in sequential chunks whose
    outputs are concatenated along the patch dimension.
    """

    def __init__(self, checkpoint_path=None, model_config_path=None, **kwargs):
        """Build the encoder and load pretrained weights.

        Args:
            checkpoint_path: Optional path to the model weights
                (``.safetensors``, or a torch/Lightning checkpoint as a
                fallback). Defaults to ``BEST-RQ-2.safetensors`` next to this
                file.
            model_config_path: Optional path to the OmegaConf config.
                Defaults to ``config.yaml`` next to this file.
            **kwargs: Ignored; accepted for loader-interface compatibility.

        Raises:
            FileNotFoundError: If the config or checkpoint path does not exist.
        """
        super().__init__()

        # Bug fix: the original implementation unconditionally overwrote both
        # constructor arguments with the bundled defaults, silently ignoring
        # caller-supplied paths. Defaults now apply only when no path is given.
        base_path = os.path.dirname(__file__)
        if model_config_path is None:
            model_config_path = os.path.join(base_path, "config.yaml")
        if checkpoint_path is None:
            checkpoint_path = os.path.join(base_path, "BEST-RQ-2.safetensors")

        if not os.path.exists(model_config_path):
            raise FileNotFoundError(f"Config not found at {model_config_path}")

        if not checkpoint_path or not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        print(f"Loading BestRQ2 config from {model_config_path}")
        cfg = OmegaConf.load(model_config_path)

        print(f"Loading BestRQ2 checkpoint from {checkpoint_path}")

        # Reconstruct model args from config.
        model_cfg = cfg.model
        net_cfg = model_cfg.net

        # Instantiate model. Note: BestRQ2Module inherits from LightningModule;
        # optimizer/criterion are irrelevant for inference and passed as None.
        self.module = BestRQ2Module(
            optimizer=None,  # Not needed for inference
            net=net_cfg,
            warmup_pct=model_cfg.get("warmup_pct", 0.1),
            final_lr_ratio=model_cfg.get("final_lr_ratio", 0.001),
            spectrogram_adjustment_mode=model_cfg.get(
                "spectrogram_adjustment_mode", "pad"
            ),
            codebook_dim=model_cfg.get("codebook_dim", 16),
            vocab_size=model_cfg.get("vocab_size", 8192),
            criterion=None,
        )

        # Load weights. Primary path is safetensors; fall back to torch.load
        # so plain Lightning ``.ckpt`` files also work.
        try:
            state_dict = load_file(checkpoint_path)
        except Exception as e:
            print(f"Error loading safetensors: {e}. Trying torch.load...")
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]

        # strict=False: exported checkpoints may use different key prefixes
        # than the freshly instantiated module; report mismatches instead of
        # failing hard.
        missing, unexpected = self.module.load_state_dict(state_dict, strict=False)
        if missing:
            print(f"Warning: {len(missing)} keys missing during loading.")
        if unexpected:
            print(f"Warning: {len(unexpected)} keys unexpected during loading.")

        self.module.eval()
        # Embedding dimensionality of the produced patch features.
        self.output_dim = net_cfg.encoder.embed_dim

        # Extract dynamic parameters for input-length handling from the
        # instantiated sub-modules; fall back to conservative defaults if the
        # module internals differ from the expected layout.
        try:
            # 1. Sample rate & hop length (from the mel spectrogram frontend).
            self.sample_rate = self.module.spectrogram.mel_spec.sample_rate
            self.hop_length = self.module.spectrogram.mel_spec.hop_length

            # 2. Patch size along the time axis (patch_size is (H, W); W=time).
            self.patch_size_time = self.module.patch_embed.patch_size[1]

            # 3. Maximum spectrogram frames supported (img_size (H, W); W=time).
            self.max_frames = self.module.patch_embed.img_size[1]

            # Minimum samples needed to yield at least one patch width:
            # T_spec ~= T_samples // hop_length, so require
            # T_samples >= patch_size_time * hop_length.
            self.min_samples = self.patch_size_time * self.hop_length

            # Maximum audio length (in samples) the positional embeddings can
            # handle in one pass: max_frames * hop_length.
            self.chunk_samples = self.max_frames * self.hop_length

            print(
                f"BestRQ2Encoder constraints: Min Samples={self.min_samples}, Chunk Samples={self.chunk_samples}"
            )

        except Exception as e:
            print(f"Warning: Could not extract dynamic length constraints: {e}")
            print("Falling back to safe defaults (1s min, 10s chunk)")
            # Assumes 16 kHz input — matches the fallback message above.
            self.min_samples = 16000
            self.chunk_samples = 16000 * 10

    def _forward_chunk(self, audio_chunk: torch.Tensor) -> torch.Tensor:
        """Encode a single time-chunk of audio into patch embeddings.

        Args:
            audio_chunk: Waveform tensor of shape ``[B, T]`` or ``[B, 1, T]``.

        Returns:
            Encoder output of shape ``[B, N, D]`` (patches x embed dim).
        """
        # Determine target device from the spectrogram STFT window, which is
        # the safest proxy for where the frontend actually computes.
        try:
            target_device = self.module.spectrogram.mel_spec.spectrogram.window.device
        except AttributeError:
            if hasattr(self.module.spectrogram.mel_spec, "window"):
                target_device = self.module.spectrogram.mel_spec.window.device
            else:
                target_device = self.module.device

        if audio_chunk.device != target_device:
            audio_chunk = audio_chunk.to(target_device)

        # BestRQ2Module expects [B, C, T]; add a channel dim if needed.
        if audio_chunk.ndim == 2:
            audio_chunk = audio_chunk.unsqueeze(1)  # [B, 1, T]

        # _process_audio returns (patches, grid_size).
        patches, grid_size = self.module._process_audio(audio_chunk)

        # Dummy mask (all False = keep all patches; no BEST-RQ masking at
        # inference time).
        B, N, D = patches.shape
        mask = torch.zeros((B, N), dtype=torch.bool, device=patches.device)

        return self.module.compute_encoder(patches, mask, grid_size)

    def forward(
        self, audio: torch.Tensor, audio_attention_mask=None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Encode raw audio into patch embeddings.

        Args:
            audio: Waveform tensor of shape ``[T]`` or ``[B, T]``. Audio
                shorter than one patch is right-padded with zeros; audio
                longer than the model's capacity is split into sequential
                chunks whose outputs are concatenated along dim 1.
            audio_attention_mask: Currently unused; accepted for interface
                compatibility.

        Returns:
            Tuple of (embeddings ``[B, N, D]``, None). The second element is
            a placeholder for an output attention mask.
        """
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)

        B, T = audio.shape

        # 1. Pad whole batch up to the minimum length (one patch width).
        if T < self.min_samples:
            pad_amt = self.min_samples - T
            audio = torch.nn.functional.pad(audio, (0, pad_amt))
            T = self.min_samples

        # 2. Single pass when the audio fits the model's capacity.
        if T <= self.chunk_samples:
            return self._forward_chunk(audio), None

        # 3. Otherwise split along time and encode chunk by chunk.
        chunks = torch.split(audio, self.chunk_samples, dim=1)
        outputs = []
        for chunk in chunks:
            # The last chunk may be shorter than one patch; pad it up.
            chunk_len = chunk.shape[1]
            if chunk_len < self.min_samples:
                pad_amt = self.min_samples - chunk_len
                chunk = torch.nn.functional.pad(chunk, (0, pad_amt))

            # NOTE: a padded final chunk still yields at least one full patch;
            # sub-patch trimming of the output is not possible, so the extra
            # (mostly-silence) patch is kept.
            outputs.append(self._forward_chunk(chunk))

        # Concatenate along the patch/sequence dimension.
        return torch.cat(outputs, dim=1), None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == "__main__":
    # Smoke test: build the encoder and run a 10 s random waveform through it.
    try:
        encoder = BestRQ2Encoder()
        print("Model initialized successfully")
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        encoder.module.to(run_device)
        waveform = torch.randn(1, 160000).to(run_device)
        embeddings, _ = encoder(waveform)
        print(f"Output shape: {embeddings.shape}")
    except Exception as e:
        print(f"Error testing model: {e}")
        import traceback

        traceback.print_exc()
|
README.md
CHANGED
|
@@ -1,3 +1,35 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BEST-RQ-2 (xares-llm encoder)
|
| 2 |
+
|
| 3 |
+
This folder contains the BEST-RQ-2 audio encoder integration for `xares-llm`.
|
| 4 |
+
|
| 5 |
+
Benchmark repository: [xiaomi-research/xares-llm](https://github.com/xiaomi-research/xares-llm)
|
| 6 |
+
|
| 7 |
+
## Setup
|
| 8 |
+
|
| 9 |
+
1. Make sure the environment to run `xares-llm` is set up properly (virtual environment initialized and the `xares-llm` package downloaded/installed).
|
| 10 |
+
|
| 11 |
+
2. Add the `BEST-RQ-2` folder to the `xares-llm` (current) directory so it is available at `./BEST-RQ-2`.
|
| 12 |
+
|
| 13 |
+
3. Before running a `xares-llm` evaluation, you must install the required packages for BEST-RQ-2:
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
uv pip install -r BEST-RQ-2/audio-embeddings/pyproject.toml
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Run an evaluation
|
| 20 |
+
|
| 21 |
+
Single task (e.g. to test that everything works):
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
uv run -m xares_llm.run BEST-RQ-2.BEST-RQ-2_encoder.BestRQ2Encoder <task> <task> # e.g. replace <task> with esc-50, score should be around 0.48
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
All tasks:
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
uv run -m xares_llm.run BEST-RQ-2.BEST-RQ-2_encoder.BestRQ2Encoder all
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Help
|
| 34 |
+
|
| 35 |
+
If you encounter any problems, contact: [ludovic.tuncay@irit.fr](mailto:ludovic.tuncay@irit.fr)
|
audio-embeddings/.githooks/post-checkout
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
repo_root="$(git rev-parse --show-toplevel)"
|
| 5 |
+
cd "$repo_root"
|
| 6 |
+
|
| 7 |
+
if ! command -v uv >/dev/null 2>&1; then
|
| 8 |
+
echo "[post-checkout] uv not found; skipping uv sync" >&2
|
| 9 |
+
exit 0
|
| 10 |
+
fi
|
| 11 |
+
|
| 12 |
+
echo "[post-checkout] Running uv sync..."
|
| 13 |
+
uv sync
|
audio-embeddings/.githooks/post-merge
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
repo_root="$(git rev-parse --show-toplevel)"
|
| 5 |
+
cd "$repo_root"
|
| 6 |
+
|
| 7 |
+
if ! command -v uv >/dev/null 2>&1; then
|
| 8 |
+
echo "[post-merge] uv not found; skipping uv sync" >&2
|
| 9 |
+
exit 0
|
| 10 |
+
fi
|
| 11 |
+
|
| 12 |
+
echo "[post-merge] Running uv sync..."
|
| 13 |
+
uv sync
|
audio-embeddings/.githooks/post-rewrite
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
repo_root="$(git rev-parse --show-toplevel)"
|
| 5 |
+
cd "$repo_root"
|
| 6 |
+
|
| 7 |
+
if ! command -v uv >/dev/null 2>&1; then
|
| 8 |
+
echo "[post-rewrite] uv not found; skipping uv sync" >&2
|
| 9 |
+
exit 0
|
| 10 |
+
fi
|
| 11 |
+
|
| 12 |
+
echo "[post-rewrite] Running uv sync..."
|
| 13 |
+
uv sync
|
audio-embeddings/.gitignore
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
|
| 12 |
+
# Project logs
|
| 13 |
+
logs/
|
| 14 |
+
|
| 15 |
+
# Large data files
|
| 16 |
+
*.h5
|
| 17 |
+
|
| 18 |
+
# Data
|
| 19 |
+
data/
|
| 20 |
+
|
| 21 |
+
# Jupyter
|
| 22 |
+
.ipynb_checkpoints/
|
| 23 |
+
|
| 24 |
+
# macOS Finder metadata
|
| 25 |
+
.DS_Store
|
| 26 |
+
|
| 27 |
+
# Explicitly untracked large docs
|
| 28 |
+
documents/LeJEPA.pdf
|
| 29 |
+
documents/Audio-LeJEPA.pdf
|
audio-embeddings/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v5.0.0
|
| 4 |
+
hooks:
|
| 5 |
+
# Trim whitespace at end of lines.
|
| 6 |
+
- id: trailing-whitespace
|
| 7 |
+
# Ensure files end with a single newline.
|
| 8 |
+
- id: end-of-file-fixer
|
| 9 |
+
# Catch unresolved merge conflict markers.
|
| 10 |
+
- id: check-merge-conflict
|
| 11 |
+
# Validate YAML syntax (Hydra configs, etc.).
|
| 12 |
+
- id: check-yaml
|
| 13 |
+
# Validate TOML syntax (e.g., pyproject.toml).
|
| 14 |
+
- id: check-toml
|
| 15 |
+
# Detect mixed CRLF/LF endings without auto-rewriting.
|
| 16 |
+
- id: mixed-line-ending
|
| 17 |
+
args: ["--fix=no"]
|
| 18 |
+
# Catch accidental debug leftovers (breakpoint/pdb).
|
| 19 |
+
- id: debug-statements
|
| 20 |
+
# Block accidental private key commits.
|
| 21 |
+
- id: detect-private-key
|
| 22 |
+
# Prevent case-colliding paths across filesystems.
|
| 23 |
+
- id: check-case-conflict
|
| 24 |
+
# Block newly added large files unless explicitly allowlisted below.
|
| 25 |
+
- id: check-added-large-files
|
| 26 |
+
args: ["--maxkb=5000"]
|
| 27 |
+
exclude: ^(documents/LeJEPA\.pdf|documents/Audio-LeJEPA\.pdf)$
|
| 28 |
+
|
| 29 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 30 |
+
# Ruff version.
|
| 31 |
+
rev: v0.15.0
|
| 32 |
+
hooks:
|
| 33 |
+
# Run the linter.
|
| 34 |
+
- id: ruff-check
|
| 35 |
+
types_or: [python, pyi]
|
| 36 |
+
args: [--fix]
|
| 37 |
+
# Run the formatter.
|
| 38 |
+
- id: ruff-format
|
| 39 |
+
types_or: [python, pyi]
|
audio-embeddings/.project-root
ADDED
|
File without changes
|
audio-embeddings/.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
audio-embeddings/AGENTS.md
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AGENTS Guide - audio-embeddings
|
| 2 |
+
|
| 3 |
+
This file is for coding agents working in this repository.
|
| 4 |
+
Follow these repo-specific rules over generic defaults.
|
| 5 |
+
|
| 6 |
+
## 1) Environment Snapshot
|
| 7 |
+
- Python: `>=3.12` (from `pyproject.toml`).
|
| 8 |
+
- Dependency manager: `uv`.
|
| 9 |
+
- Main stack: PyTorch, PyTorch Lightning, Hydra, OmegaConf.
|
| 10 |
+
- Project root marker: `.project-root`.
|
| 11 |
+
- Main entrypoint: `src/train.py`.
|
| 12 |
+
|
| 13 |
+
## 2) Cursor / Copilot Rule Files
|
| 14 |
+
- Checked `.cursor/rules/`: not present.
|
| 15 |
+
- Checked `.cursorrules`: not present.
|
| 16 |
+
- Checked `.github/copilot-instructions.md`: not present.
|
| 17 |
+
- Therefore, no additional Cursor/Copilot rule files are currently enforced.
|
| 18 |
+
|
| 19 |
+
## 3) Install / Setup Commands
|
| 20 |
+
```bash
|
| 21 |
+
uv sync
|
| 22 |
+
uv run <command>
|
| 23 |
+
uv add <package>
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
## 4) Build / Train / Eval Commands
|
| 27 |
+
There is no separate "build" step (this is a training codebase).
|
| 28 |
+
Use quick-run training as the integration sanity check.
|
| 29 |
+
```bash
|
| 30 |
+
uv run src/train.py
|
| 31 |
+
uv run src/train.py trainer.fast_dev_run=True
|
| 32 |
+
uv run src/train.py trainer=cpu trainer.fast_dev_run=True
|
| 33 |
+
uv run src/train.py experiment=local/audio_jepa
|
| 34 |
+
uv run src/train.py trainer.max_epochs=10 data.batch_size=32 model.optimizer.lr=1e-4
|
| 35 |
+
```
|
| 36 |
+
Cluster-style execution (existing project pattern):
|
| 37 |
+
```bash
|
| 38 |
+
srun .venv/bin/python -u -O src/train.py experiment=cluster_jepa_audioset_rope +trainer.max_time="00:19:50:00"
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## 5) Lint / Formatting / Static Checks
|
| 42 |
+
Use the commands below as pragmatic checks:
|
| 43 |
+
```bash
|
| 44 |
+
uv run pre-commit run --all-files
|
| 45 |
+
uv run pre-commit run ruff --all-files
|
| 46 |
+
uv run pre-commit run ruff-format --all-files
|
| 47 |
+
uv run python -m compileall src
|
| 48 |
+
```
|
| 49 |
+
Ruff is configured via `.pre-commit-config.yaml` and runs both lint fixes and formatting.
|
| 50 |
+
|
| 51 |
+
## 6) Test Commands (Including Single Test)
|
| 52 |
+
Primary validation in this repo is script-based verification under `tests/`.
|
| 53 |
+
Run test files directly as native Python files:
|
| 54 |
+
```bash
|
| 55 |
+
uv run tests/verify_rope.py
|
| 56 |
+
uv run tests/verify_custom_rope.py
|
| 57 |
+
uv run tests/verify_data.py
|
| 58 |
+
```
|
| 59 |
+
Useful single-file checks (native execution):
|
| 60 |
+
```bash
|
| 61 |
+
uv run src/train.py trainer.fast_dev_run=True
|
| 62 |
+
uv run src/train.py trainer=cpu trainer.fast_dev_run=True
|
| 63 |
+
uv run scripts/verify_shapes.py
|
| 64 |
+
uv run scripts/verify_scheduler.py
|
| 65 |
+
```
|
| 66 |
+
Notes:
|
| 67 |
+
- `tests/test_*.py` are pytest-style and are not part of the default native-file workflow.
|
| 68 |
+
- Prefer `tests/verify_*.py` and `scripts/verify_*.py` for lightweight checks.
|
| 69 |
+
|
| 70 |
+
## 7) Repository Architecture Expectations
|
| 71 |
+
- `configs/`: Hydra composition (trainer/data/model/logger/callbacks/experiment).
|
| 72 |
+
- `src/train.py`: orchestration only (instantiate and run).
|
| 73 |
+
- `src/models/`: LightningModules (high-level training logic).
|
| 74 |
+
- `src/models/components/`: reusable `nn.Module` building blocks.
|
| 75 |
+
- `src/data/`: DataModules/Datasets and collate logic.
|
| 76 |
+
- `src/utils/`: logging, instantiation, wrappers, scheduler helpers.
|
| 77 |
+
When possible, prefer config changes over hardcoded Python changes.
|
| 78 |
+
|
| 79 |
+
## 8) Code Style Guidelines
|
| 80 |
+
### Imports
|
| 81 |
+
- Group imports as: standard library -> third-party -> local `src.*`.
|
| 82 |
+
- Keep one import per line unless importing multiple names from same module.
|
| 83 |
+
- Avoid wildcard imports.
|
| 84 |
+
- Prefer absolute imports from `src...`.
|
| 85 |
+
|
| 86 |
+
### Formatting
|
| 87 |
+
- Use 4-space indentation and readable line lengths.
|
| 88 |
+
- Keep functions small; extract helpers for complex logic.
|
| 89 |
+
- Do not introduce unrelated reformatting in touched files.
|
| 90 |
+
- Keep comments for non-obvious intent, not obvious mechanics.
|
| 91 |
+
|
| 92 |
+
### Typing
|
| 93 |
+
- Type hints are expected for function arguments and return values.
|
| 94 |
+
- Use concrete tensor/container types when practical.
|
| 95 |
+
- Use `Optional[T]` / `T | None` consistently within a file.
|
| 96 |
+
- For dict-like configs, type as `DictConfig` when passing Hydra config objects.
|
| 97 |
+
|
| 98 |
+
### Naming
|
| 99 |
+
- `snake_case`: functions, variables, module filenames.
|
| 100 |
+
- `PascalCase`: classes (`AudioJEPAModule`, `AudioSetDataModule`).
|
| 101 |
+
- `UPPER_SNAKE_CASE`: constants.
|
| 102 |
+
- Prefer descriptive names (`mask_indices`) over short names (`m2`) except local math temporaries.
|
| 103 |
+
|
| 104 |
+
### PyTorch / Lightning / Hydra Conventions
|
| 105 |
+
- Keep heavy compute out of `__init__` where possible.
|
| 106 |
+
- `forward()` for inference logic; training behavior in `training_step()`.
|
| 107 |
+
- Use `self.log(...)` with explicit flags (`on_step`, `on_epoch`, `prog_bar`, `batch_size`).
|
| 108 |
+
- Instantiate components through Hydra (`hydra.utils.instantiate`).
|
| 109 |
+
- Expose tunable parameters in config files, not hardcoded literals.
|
| 110 |
+
|
| 111 |
+
### Error Handling and Validation
|
| 112 |
+
- Raise informative `ValueError` / `RuntimeError` for invalid config/state.
|
| 113 |
+
- Validate critical tensor assumptions with assertions or explicit checks.
|
| 114 |
+
- Prefer logger/warnings over bare `print()` in new code.
|
| 115 |
+
- For file I/O, prefer `pathlib.Path` and existence checks.
|
| 116 |
+
|
| 117 |
+
### Data and Paths
|
| 118 |
+
- Do not hardcode absolute machine paths.
|
| 119 |
+
- Use `rootutils.setup_root(..., indicator=".project-root", pythonpath=True)` in entrypoints/scripts when needed.
|
| 120 |
+
- Respect `cfg.paths.*` outputs for logs/checkpoints/artifacts.
|
| 121 |
+
|
| 122 |
+
## 9) Agent Workflow Rules
|
| 123 |
+
- Reuse existing components before adding new abstractions.
|
| 124 |
+
- Keep `src/train.py` generic; place model/data logic in dedicated modules.
|
| 125 |
+
- Prefer minimal, focused diffs.
|
| 126 |
+
- Update configs and docs when behavior changes.
|
| 127 |
+
- Validate with the smallest meaningful command first (`fast_dev_run`, single test), then broader checks.
|
| 128 |
+
|
| 129 |
+
## 10) Git / Change Hygiene
|
| 130 |
+
- Do not revert unrelated local changes.
|
| 131 |
+
- Keep commits scoped to one concern.
|
| 132 |
+
- Write clear commit messages describing intent.
|
| 133 |
+
- Prefer Conventional Commit-like format: `type(scope): intent`.
|
| 134 |
+
- Common types in this repo: `feat`, `fix`, `conf`, `build`, `docs`, `style`, `chore`.
|
| 135 |
+
- Never commit secrets, credentials, or environment-specific absolute paths.
|
| 136 |
+
|
| 137 |
+
## 11) Practical Agent Defaults
|
| 138 |
+
- Prefer reusing existing modules over creating new abstractions.
|
| 139 |
+
- Keep edits local to the requested change; avoid drive-by refactors.
|
| 140 |
+
- Run the smallest useful verification command after changes.
|
| 141 |
+
- If you touch training logic, run at least one fast training sanity check.
|
| 142 |
+
- If you touch model components, run relevant verify script(s) in `tests/`.
|
| 143 |
+
- If you touch Hydra config wiring, run a config-backed entry command via `uv run src/train.py ...`.
|
| 144 |
+
|
| 145 |
+
## 12) Common Pitfalls
|
| 146 |
+
- Avoid hardcoding data paths; use config (`cfg.paths`, data config fields).
|
| 147 |
+
- Avoid printing in new code paths; use ranked loggers/warnings.
|
| 148 |
+
- Avoid putting heavy tensor compute in constructors.
|
| 149 |
+
- Avoid bypassing Hydra by manually instantiating configurable components.
|
| 150 |
+
- Avoid changing unrelated formatting in files you touch.
|
audio-embeddings/README.md
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Audio Embeddings with Lightning & Hydra
|
| 2 |
+
|
| 3 |
+
This project is a clean, modular, and scalable implementation of audio embedding models using **PyTorch Lightning** and **Hydra**. It is designed to be easily extensible and runnable on local or cluster environments. It is based on the [Audio-JEPA](https://github.com/LudovicTuncay/Audio-JEPA) implementation and therefore implements the Audio-JEPA architecture. Other architectures can and will be added in the future.
|
| 4 |
+
|
| 5 |
+
## 🎯 Goal
|
| 6 |
+
|
| 7 |
+
The goal of this project is to provide a robust codebase for training and experimenting with audio embedding models. Key features include:
|
| 8 |
+
- **Modular Architecture**: Components like Spectrogram, Masking, and ViT are decoupled.
|
| 9 |
+
- **Configurable Positional Embeddings**: Support for **RoPE** (2D Rotary Embeddings), **SinCos** (2D Sinusoidal), and **Learnable** embeddings.
|
| 10 |
+
- **Hydra Configuration**: flexible experiment management via hierarchical config files.
|
| 11 |
+
- **Lightning Trainer**: Simplified training loop, logging, and checkpointing.
|
| 12 |
+
- **Modern Tooling**: Uses `uv` for fast and reliable dependency management.
|
| 13 |
+
|
| 14 |
+
## 🚀 Installation
|
| 15 |
+
|
| 16 |
+
This project uses [`uv`](https://github.com/astral-sh/uv) for dependency management.
|
| 17 |
+
|
| 18 |
+
1. **Install `uv`** (if not already installed):
|
| 19 |
+
```bash
|
| 20 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
2. **Clone the repository**:
|
| 24 |
+
```bash
|
| 25 |
+
git clone <repository_url>
|
| 26 |
+
cd audio-embeddings
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
3. **Install dependencies**:
|
| 30 |
+
```bash
|
| 31 |
+
uv sync
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
4. **Enable shared git hooks** (runs `uv sync` after merge/checkout/rewrite):
|
| 35 |
+
```bash
|
| 36 |
+
git config core.hooksPath .githooks
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## 🏃 Usage
|
| 40 |
+
|
| 41 |
+
### Basic Training
|
| 42 |
+
To start training with the default configuration:
|
| 43 |
+
```bash
|
| 44 |
+
uv run src/train.py
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### Common Commands
|
| 48 |
+
Run on GPU with Weights & Biases logging:
|
| 49 |
+
```bash
|
| 50 |
+
uv run src/train.py trainer=gpu logger=wandb
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
Override hyperparameters on the command line:
|
| 54 |
+
```bash
|
| 55 |
+
uv run src/train.py data.batch_size=64 trainer.max_epochs=50
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Configurable Positional Embeddings
|
| 59 |
+
You can switch between different positional embedding strategies easily:
|
| 60 |
+
|
| 61 |
+
**RoPE**:
|
| 62 |
+
```bash
|
| 63 |
+
uv run src/train.py model.net.encoder.pos_embed_type=rope
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Offline WandB Logging with Model Checkpoints
|
| 67 |
+
To run training offline while still staging model checkpoints for upload (standard WandB offline mode does not stage checkpoints by default):
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
uv run src/train.py \
|
| 71 |
+
logger=wandb \
|
| 72 |
+
logger.wandb.offline=True \
|
| 73 |
+
logger.wandb.log_model=False \
|
| 74 |
+
+callbacks.wandb_offline_checkpoint._target_=src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback \
|
| 75 |
+
trainer=gpu trainer.devices=1 \
|
| 76 |
+
data.batch_size=128 trainer.max_epochs=100
|
| 77 |
+
```
|
| 78 |
+
These checkpoints will be uploaded when you run `wandb sync`.
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
**2D SinCos**:
|
| 82 |
+
```bash
|
| 83 |
+
uv run src/train.py ++model.net.encoder.pos_embed_type=sincos ++model.net.predictor.pos_embed_type=sincos
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
**Learnable**:
|
| 87 |
+
```bash
|
| 88 |
+
uv run src/train.py ++model.net.encoder.pos_embed_type=learnable ++model.net.predictor.pos_embed_type=learnable
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
## 📂 Project Structure
|
| 92 |
+
|
| 93 |
+
```text
|
| 94 |
+
├── configs/ # Hydra configuration files
|
| 95 |
+
│ ├── callbacks/ # Callback configs (checkpoints, early stopping)
|
| 96 |
+
│ ├── data/ # Data configs (AudioSet, etc.)
|
| 97 |
+
│ ├── logger/ # Logger configs (WandB, Tensorboard)
|
| 98 |
+
│ ├── model/ # Model configs (AudioJEPA parameters)
|
| 99 |
+
│ ├── trainer/ # Trainer configs (CPU, GPU, strategies)
|
| 100 |
+
│ └── train.yaml # Main configuration entry point
|
| 101 |
+
├── src/
|
| 102 |
+
│ ├── data/ # Data loading logic
|
| 103 |
+
│ │ └── audioset_datamodule.py # AudioSet DataModule & Dataset
|
| 104 |
+
│ ├── models/ # Model architectures
|
| 105 |
+
│ │ ├── components/ # Reusable blocks
|
| 106 |
+
│ │ │ ├── masking.py # Masking generators
|
| 107 |
+
│ │ │ ├── patch_embed.py # Patchification
|
| 108 |
+
│ │ │ ├── rope.py # 2D Rotary Embeddings
|
| 109 |
+
│ │ │ ├── spectrogram.py # Audio preprocessing
|
| 110 |
+
│ │ │ └── vit.py # Vision Transformer (Student/Teacher/Predictor)
|
| 111 |
+
│ │ └── audio_jepa_module.py # Main LightningModule
|
| 112 |
+
│ ├── utils/ # Utility functions
|
| 113 |
+
│ └── train.py # Training entry point
|
| 114 |
+
├── scripts/ # Helper scripts
|
| 115 |
+
├── tests/ # Verification tests
|
| 116 |
+
├── pyproject.toml # Project dependencies
|
| 117 |
+
└── README.md # This file
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
## 🛠️ Extensibility
|
| 121 |
+
|
| 122 |
+
### Adding a New Model
|
| 123 |
+
1. Create your model components in `src/models/components/`.
|
| 124 |
+
2. Create a new LightningModule in `src/models/` (or update `AudioJEPAModule`).
|
| 125 |
+
3. Create a new config file in `configs/model/my_new_model.yaml`.
|
| 126 |
+
4. Run with `uv run src/train.py model=my_new_model`.
|
| 127 |
+
|
| 128 |
+
### Adding a New Dataset
|
| 129 |
+
1. Create a new DataModule in `src/data/`.
|
| 130 |
+
2. Create a new config file in `configs/data/my_dataset.yaml`.
|
| 131 |
+
3. Run with `uv run src/train.py data=my_dataset`.
|
| 132 |
+
|
| 133 |
+
### Adding Functionalities
|
| 134 |
+
- **Callbacks**: Add custom callbacks in `src/callbacks/` (if needed) or use existing Lightning callbacks, and configure them in `configs/callbacks/`.
|
| 135 |
+
- **Metrics**: Add metrics logging in `training_step` or `validation_step` inside `src/models/audio_jepa_module.py`.
|
| 136 |
+
|
| 137 |
+
## 🧪 Testing
|
| 138 |
+
Run verification scripts to ensure components are working:
|
| 139 |
+
```bash
|
| 140 |
+
uv run tests/verify_rope.py
|
| 141 |
+
uv run tests/verify_custom_rope.py
|
| 142 |
+
```
|
audio-embeddings/THIRD_PARTY_LICENSES.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Third-Party Licenses
|
| 2 |
+
|
| 3 |
+
This project vendors third-party source code listed below.
|
| 4 |
+
|
| 5 |
+
## Lightning-AI / pytorch-lightning
|
| 6 |
+
|
| 7 |
+
- Component: `src/callbacks/lightning_weight_averaging.py`
|
| 8 |
+
- Source repository: `https://github.com/Lightning-AI/pytorch-lightning`
|
| 9 |
+
- Source file: `src/lightning/pytorch/callbacks/weight_averaging.py`
|
| 10 |
+
- Pinned commit: `9bcba1c1e82b45e10f948dc28fc12f4cf04ab736`
|
| 11 |
+
- License: Apache License 2.0
|
| 12 |
+
|
| 13 |
+
See `licenses/APACHE-2.0-LIGHTNING.txt` for the full license text.
|
audio-embeddings/configs/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# this file is needed here to include configs when building project as a package
|
audio-embeddings/configs/callbacks/default.yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model_checkpoint
|
| 3 |
+
- model_summary
|
| 4 |
+
- rich_progress_bar
|
| 5 |
+
- _self_
|
| 6 |
+
|
| 7 |
+
model_checkpoint:
|
| 8 |
+
dirpath: ${paths.output_dir}/checkpoints
|
| 9 |
+
filename: "best-step-{step:06d}"
|
| 10 |
+
monitor: "val/loss"
|
| 11 |
+
mode: "min"
|
| 12 |
+
save_last: True
|
| 13 |
+
save_weights_only: False
|
| 14 |
+
save_top_k: 0
|
| 15 |
+
auto_insert_metric_name: False
|
| 16 |
+
|
| 17 |
+
safetensors:
|
| 18 |
+
_target_: src.callbacks.safetensors_callback.SafetensorsCallback
|
| 19 |
+
cleanup_orphan_safetensors: True
|
| 20 |
+
|
| 21 |
+
# early_stopping:
|
| 22 |
+
# monitor: "val/loss"
|
| 23 |
+
# patience: 100
|
| 24 |
+
# mode: "min"
|
| 25 |
+
|
| 26 |
+
model_summary:
|
| 27 |
+
max_depth: 1
|
| 28 |
+
|
| 29 |
+
device_stats:
|
| 30 |
+
_target_: lightning.pytorch.callbacks.DeviceStatsMonitor
|
| 31 |
+
|
| 32 |
+
visualization:
|
| 33 |
+
_target_: src.callbacks.visualization_callback.VisualizationCallback
|
| 34 |
+
num_samples: 4
|
audio-embeddings/configs/callbacks/early_stopping.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
|
| 2 |
+
|
| 3 |
+
early_stopping:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.EarlyStopping
|
| 5 |
+
monitor: ??? # quantity to be monitored, must be specified !!!
|
| 6 |
+
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
| 7 |
+
patience: 3 # number of checks with no improvement after which training will be stopped
|
| 8 |
+
verbose: False # verbosity mode
|
| 9 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 10 |
+
strict: True # whether to crash the training if monitor is not found in the validation metrics
|
| 11 |
+
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
|
| 12 |
+
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
|
| 13 |
+
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
|
| 14 |
+
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
|
| 15 |
+
# log_rank_zero_only: False # this keyword argument isn't available in stable version
|
audio-embeddings/configs/callbacks/ema_weight_averaging.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ema_weight_averaging:
|
| 2 |
+
_target_: src.callbacks.ema_weight_averaging.WarmupEMAWeightAveraging
|
| 3 |
+
warmup_pct: ${model.warmup_pct}
|
| 4 |
+
enabled: ${model.ema.enabled}
|
| 5 |
+
decay: ${model.ema.decay} # decay rate is 1 - decay_numerator / ( total_steps - warmup_steps )
|
| 6 |
+
decay_numerator: ${model.ema.decay_numerator} # decay rate is 1 - decay_numerator / ( total_steps - warmup_steps )
|
| 7 |
+
update_every_n_steps: ${model.ema.update_every_n_steps}
|
| 8 |
+
use_buffers: ${model.ema.use_buffers}
|
| 9 |
+
update_starting_at_step: null
|
audio-embeddings/configs/callbacks/model_checkpoint.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
| 2 |
+
|
| 3 |
+
model_checkpoint:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
|
| 5 |
+
dirpath: null # directory to save the model file
|
| 6 |
+
filename: null # checkpoint filename
|
| 7 |
+
monitor: null # name of the logged metric which determines when model is improving
|
| 8 |
+
verbose: False # verbosity mode
|
| 9 |
+
save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
| 10 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 11 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 12 |
+
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
|
| 13 |
+
save_weights_only: False # if True, then only the model’s weights will be saved
|
| 14 |
+
every_n_train_steps: null # number of training steps between checkpoints
|
| 15 |
+
train_time_interval: null # checkpoints are monitored at the specified time interval
|
| 16 |
+
every_n_epochs: null # number of epochs between checkpoints
|
| 17 |
+
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
audio-embeddings/configs/callbacks/model_summary.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
|
| 2 |
+
|
| 3 |
+
model_summary:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.RichModelSummary
|
| 5 |
+
max_depth: 1 # the maximum depth of layer nesting that the summary will include
|
audio-embeddings/configs/callbacks/none.yaml
ADDED
|
File without changes
|
audio-embeddings/configs/callbacks/rich_progress_bar.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
|
| 2 |
+
|
| 3 |
+
rich_progress_bar:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.RichProgressBar
|
audio-embeddings/configs/callbacks/wandb_offline.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_offline_checkpoint:
|
| 2 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/data/audioset.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.data.audioset_datamodule.AudioSetDataModule
|
| 2 |
+
data_dir: ${paths.data_dir}/AudioSet
|
| 3 |
+
batch_size: 64
|
| 4 |
+
num_workers: 4
|
| 5 |
+
pin_memory: True
|
| 6 |
+
train_h5: full_unbal_bal_train_wav.h5
|
| 7 |
+
train_csv: silent_files_full_unbal_bal_train_wav.csv
|
| 8 |
+
val_h5: eval_soxrhq.h5
|
| 9 |
+
val_csv: silent_files_eval_soxrhq.csv
|
| 10 |
+
max_audio_length_sec: 10.0 # 10 seconds
|
| 11 |
+
target_sample_rate: 16000
|
| 12 |
+
collate_mode: pad
|
audio-embeddings/configs/data/mock_audioset.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.data.mock_audioset_datamodule.MockAudioSetDataModule
|
| 2 |
+
batch_size: 16
|
| 3 |
+
num_workers: 0 # 0 is often better for simple debugging/validation on local Mac
|
| 4 |
+
pin_memory: True
|
| 5 |
+
max_audio_length_sec: 10.0
|
| 6 |
+
target_sample_rate: 16000
|
| 7 |
+
collate_mode: pad
|
audio-embeddings/configs/data/yt1b.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.data.yt1b_datamodule.YT1BDataModule
|
| 2 |
+
data_dir: ${paths.data_dir}/YT-Temporal-1B
|
| 3 |
+
batch_size: 64
|
| 4 |
+
num_workers: 4
|
| 5 |
+
pin_memory: True
|
| 6 |
+
train_parquet: train_metadata.parquet
|
| 7 |
+
val_parquet: val_metadata.parquet
|
| 8 |
+
test_parquet: val_metadata.parquet
|
| 9 |
+
max_audio_length_sec: 10.0
|
| 10 |
+
min_duration_sec: 10.0
|
| 11 |
+
target_sample_rate: 16000
|
| 12 |
+
collate_mode: pad
|
| 13 |
+
decode_window_sec: null
|
audio-embeddings/configs/experiment/audio_jepa/baseline.yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: audioset
|
| 5 |
+
- override /model: audio_jepa
|
| 6 |
+
- override /trainer: gpu
|
| 7 |
+
- override /logger: wandb
|
| 8 |
+
- override /callbacks: default
|
| 9 |
+
|
| 10 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 11 |
+
# this allows you to overwrite only specified parameters
|
| 12 |
+
|
| 13 |
+
tags: ["audioset", "jepa", "baseline", "cluster"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 200000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
data_dir: ${paths.data_dir}
|
| 20 |
+
train_h5: AudioSet/full_unbal_bal_train_wav.h5
|
| 21 |
+
val_h5: AudioSet/eval_soxrhq.h5
|
| 22 |
+
batch_size: 256
|
| 23 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
offline: True
|
| 28 |
+
log_model: False
|
| 29 |
+
|
| 30 |
+
callbacks:
|
| 31 |
+
rich_progress_bar: null
|
| 32 |
+
|
| 33 |
+
wandb_offline_checkpoint:
|
| 34 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/audio_jepa/large.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: audioset
|
| 5 |
+
- override /model: audio_jepa
|
| 6 |
+
- override /trainer: gpu
|
| 7 |
+
- override /logger: wandb
|
| 8 |
+
- override /callbacks: default
|
| 9 |
+
|
| 10 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 11 |
+
# this allows you to overwrite only specified parameters
|
| 12 |
+
|
| 13 |
+
tags: ["audioset", "jepa", "large", "cluster", "1GPU"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 200000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
data_dir: ${paths.data_dir}
|
| 20 |
+
train_h5: AudioSet/full_unbal_bal_train_wav.h5
|
| 21 |
+
val_h5: AudioSet/eval_soxrhq.h5
|
| 22 |
+
batch_size: 256
|
| 23 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
offline: True
|
| 28 |
+
log_model: False
|
| 29 |
+
|
| 30 |
+
callbacks:
|
| 31 |
+
rich_progress_bar: null
|
| 32 |
+
|
| 33 |
+
wandb_offline_checkpoint:
|
| 34 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
| 35 |
+
|
| 36 |
+
# ViT Large Configuration
|
| 37 |
+
model:
|
| 38 |
+
net:
|
| 39 |
+
patch_embed:
|
| 40 |
+
embed_dim: 1024
|
| 41 |
+
encoder:
|
| 42 |
+
embed_dim: 1024
|
| 43 |
+
depth: 24
|
| 44 |
+
num_heads: 16
|
audio-embeddings/configs/experiment/audio_jepa/rope.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: audioset
|
| 5 |
+
- override /model: audio_jepa
|
| 6 |
+
- override /trainer: gpu
|
| 7 |
+
- override /logger: wandb
|
| 8 |
+
- override /callbacks: default
|
| 9 |
+
|
| 10 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 11 |
+
# this allows you to overwrite only specified parameters
|
| 12 |
+
|
| 13 |
+
tags: ["audioset", "jepa", "RoPE", "cluster"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 200000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
data_dir: ${paths.data_dir}
|
| 20 |
+
train_h5: AudioSet/full_unbal_bal_train_wav.h5
|
| 21 |
+
val_h5: AudioSet/eval_soxrhq.h5
|
| 22 |
+
batch_size: 256
|
| 23 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
offline: True
|
| 28 |
+
log_model: False
|
| 29 |
+
|
| 30 |
+
callbacks:
|
| 31 |
+
rich_progress_bar: null
|
| 32 |
+
|
| 33 |
+
wandb_offline_checkpoint:
|
| 34 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
| 35 |
+
|
| 36 |
+
model:
|
| 37 |
+
net:
|
| 38 |
+
encoder:
|
| 39 |
+
pos_embed_type: rope
|
| 40 |
+
predictor:
|
| 41 |
+
pos_embed_type: rope
|
audio-embeddings/configs/experiment/audio_jepa/time_res2x.yaml
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: audioset
|
| 5 |
+
- override /model: audio_jepa
|
| 6 |
+
- override /trainer: gpu
|
| 7 |
+
- override /logger: wandb
|
| 8 |
+
- override /callbacks: default
|
| 9 |
+
|
| 10 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 11 |
+
# this allows you to overwrite only specified parameters
|
| 12 |
+
|
| 13 |
+
tags: ["audioset", "jepa", "time_res2x", "cluster"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 200000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
data_dir: ${paths.data_dir}
|
| 20 |
+
train_h5: AudioSet/full_unbal_bal_train_wav.h5
|
| 21 |
+
val_h5: AudioSet/eval_soxrhq.h5
|
| 22 |
+
batch_size: 256
|
| 23 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
offline: True
|
| 28 |
+
log_model: False
|
| 29 |
+
|
| 30 |
+
callbacks:
|
| 31 |
+
rich_progress_bar: null
|
| 32 |
+
|
| 33 |
+
wandb_offline_checkpoint:
|
| 34 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
| 35 |
+
|
| 36 |
+
model:
|
| 37 |
+
net:
|
| 38 |
+
spectrogram:
|
| 39 |
+
win_length_ms: 64 # halved
|
| 40 |
+
hop_length_ms: 19.53125 # halved
|
| 41 |
+
|
| 42 |
+
patch_embed:
|
| 43 |
+
img_size: [128, 512] # width doubled because hop is halved
|
| 44 |
+
|
| 45 |
+
masking:
|
| 46 |
+
input_size: [128, 512]
|
| 47 |
+
|
| 48 |
+
encoder:
|
| 49 |
+
num_patches: 256 # (128/16) * (512/16) = 8 * 32 = 256
|
| 50 |
+
img_size: [128, 512] # Explicitly set img_size for ViT to match patch_embed
|
| 51 |
+
|
| 52 |
+
predictor:
|
| 53 |
+
num_patches: 256
|
| 54 |
+
img_size: [128, 512]
|
audio-embeddings/configs/experiment/audio_jepa/time_res4x.yaml
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: audioset
|
| 5 |
+
- override /model: audio_jepa
|
| 6 |
+
- override /trainer: gpu
|
| 7 |
+
- override /logger: wandb
|
| 8 |
+
- override /callbacks: default
|
| 9 |
+
|
| 10 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 11 |
+
# this allows you to overwrite only specified parameters
|
| 12 |
+
|
| 13 |
+
tags: ["audioset", "jepa", "time_res4x", "cluster"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 200000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
data_dir: ${paths.data_dir}
|
| 20 |
+
train_h5: AudioSet/full_unbal_bal_train_wav.h5
|
| 21 |
+
val_h5: AudioSet/eval_soxrhq.h5
|
| 22 |
+
batch_size: 256
|
| 23 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
offline: True
|
| 28 |
+
log_model: False
|
| 29 |
+
|
| 30 |
+
callbacks:
|
| 31 |
+
rich_progress_bar: null
|
| 32 |
+
|
| 33 |
+
wandb_offline_checkpoint:
|
| 34 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
| 35 |
+
|
| 36 |
+
model:
|
| 37 |
+
net:
|
| 38 |
+
spectrogram:
|
| 39 |
+
win_length_ms: 32 # quartered
|
| 40 |
+
hop_length_ms: 9.765625 # quartered
|
| 41 |
+
|
| 42 |
+
patch_embed:
|
| 43 |
+
img_size: [128, 1024] # width quadrupled because hop is quartered
|
| 44 |
+
|
| 45 |
+
masking:
|
| 46 |
+
input_size: [128, 1024]
|
| 47 |
+
|
| 48 |
+
encoder:
|
| 49 |
+
num_patches: 512 # (128/16) * (1024/16) = 8 * 64 = 512
|
| 50 |
+
img_size: [128, 1024]
|
| 51 |
+
|
| 52 |
+
predictor:
|
| 53 |
+
num_patches: 512
|
| 54 |
+
img_size: [128, 1024]
|
audio-embeddings/configs/experiment/best_rq/audioset.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq/audioset
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["audioset", "best-rq", "baseline", "cluster GPU"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
max_steps: 200000
|
| 19 |
+
|
| 20 |
+
data:
|
| 21 |
+
batch_size: 256
|
| 22 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 23 |
+
|
| 24 |
+
logger:
|
| 25 |
+
wandb:
|
| 26 |
+
offline: True
|
| 27 |
+
log_model: False
|
| 28 |
+
|
| 29 |
+
callbacks:
|
| 30 |
+
rich_progress_bar: null
|
| 31 |
+
|
| 32 |
+
wandb_offline_checkpoint:
|
| 33 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq/yt1b.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: yt1b
|
| 5 |
+
- override /model: best_rq
|
| 6 |
+
- override /trainer: gpu
|
| 7 |
+
- override /logger: wandb
|
| 8 |
+
- override /callbacks: default
|
| 9 |
+
|
| 10 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 11 |
+
# this allows you to overwrite only specified parameters
|
| 12 |
+
|
| 13 |
+
tags: ["yt1b", "best-rq", "baseline", "cluster GPU"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 200000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
batch_size: 256
|
| 20 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 21 |
+
|
| 22 |
+
logger:
|
| 23 |
+
wandb:
|
| 24 |
+
offline: True
|
| 25 |
+
log_model: False
|
| 26 |
+
|
| 27 |
+
callbacks:
|
| 28 |
+
rich_progress_bar: null
|
| 29 |
+
|
| 30 |
+
wandb_offline_checkpoint:
|
| 31 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq_2/audioset.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq2
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["audioset", "best-rq-2", "baseline", "cluster GPU"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
max_steps: 200000
|
| 19 |
+
|
| 20 |
+
data:
|
| 21 |
+
batch_size: 256
|
| 22 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 23 |
+
|
| 24 |
+
logger:
|
| 25 |
+
wandb:
|
| 26 |
+
offline: True
|
| 27 |
+
log_model: False
|
| 28 |
+
|
| 29 |
+
callbacks:
|
| 30 |
+
rich_progress_bar: null
|
| 31 |
+
|
| 32 |
+
wandb_offline_checkpoint:
|
| 33 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq_2/audioset_100k_512bs.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset_100k_512bs
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq2
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
tags: ["audioset", "best-rq-2", "100k", "512bs", "cluster GPU"]
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
max_steps: 100000
|
| 16 |
+
|
| 17 |
+
data:
|
| 18 |
+
batch_size: 512
|
| 19 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 20 |
+
|
| 21 |
+
logger:
|
| 22 |
+
wandb:
|
| 23 |
+
offline: True
|
| 24 |
+
log_model: False
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
rich_progress_bar: null
|
| 28 |
+
|
| 29 |
+
wandb_offline_checkpoint:
|
| 30 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq_2/audioset_1m_128bs_4gpu.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset_1m_128bs_4gpu
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq2
|
| 9 |
+
- override /trainer: ddp
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
tags: ["audioset", "best-rq-2", "1m", "128bs", "4GPU", "cluster GPU"]
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
max_steps: 1000000
|
| 16 |
+
|
| 17 |
+
data:
|
| 18 |
+
batch_size: 128
|
| 19 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 20 |
+
|
| 21 |
+
logger:
|
| 22 |
+
wandb:
|
| 23 |
+
offline: True
|
| 24 |
+
log_model: False
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
rich_progress_bar: null
|
| 28 |
+
|
| 29 |
+
wandb_offline_checkpoint:
|
| 30 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq_2/audioset_200k_256bs_4gpu.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset_200k_256bs_4gpu
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq2
|
| 9 |
+
- override /trainer: ddp
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
tags: ["audioset", "best-rq-2", "200k", "256bs", "4GPU", "cluster GPU"]
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
max_steps: 200000
|
| 16 |
+
|
| 17 |
+
data:
|
| 18 |
+
batch_size: 256
|
| 19 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 20 |
+
|
| 21 |
+
logger:
|
| 22 |
+
wandb:
|
| 23 |
+
offline: True
|
| 24 |
+
log_model: False
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
rich_progress_bar: null
|
| 28 |
+
|
| 29 |
+
wandb_offline_checkpoint:
|
| 30 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq_2/audioset_400k_128bs.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset_400k_128bs
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq2
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
tags: ["audioset", "best-rq-2", "400k", "128bs", "cluster GPU"]
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
max_steps: 400000
|
| 16 |
+
|
| 17 |
+
data:
|
| 18 |
+
batch_size: 128
|
| 19 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 20 |
+
|
| 21 |
+
logger:
|
| 22 |
+
wandb:
|
| 23 |
+
offline: True
|
| 24 |
+
log_model: False
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
rich_progress_bar: null
|
| 28 |
+
|
| 29 |
+
wandb_offline_checkpoint:
|
| 30 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq_2/audioset_400k_128bs_4gpu.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset_400k_128bs_4gpu
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq2
|
| 9 |
+
- override /trainer: ddp
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
tags: ["audioset", "best-rq-2", "400k", "128bs", "4GPU", "cluster GPU"]
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
max_steps: 400000
|
| 16 |
+
|
| 17 |
+
data:
|
| 18 |
+
batch_size: 128
|
| 19 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 20 |
+
|
| 21 |
+
logger:
|
| 22 |
+
wandb:
|
| 23 |
+
offline: True
|
| 24 |
+
log_model: False
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
rich_progress_bar: null
|
| 28 |
+
|
| 29 |
+
wandb_offline_checkpoint:
|
| 30 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq_2/audioset_800k_64bs_4gpu.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset_800k_64bs_4gpu
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq2
|
| 9 |
+
- override /trainer: ddp
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
tags: ["audioset", "best-rq-2", "800k", "64bs", "4GPU", "cluster GPU"]
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
max_steps: 800000
|
| 16 |
+
|
| 17 |
+
data:
|
| 18 |
+
batch_size: 64
|
| 19 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 20 |
+
|
| 21 |
+
logger:
|
| 22 |
+
wandb:
|
| 23 |
+
offline: True
|
| 24 |
+
log_model: False
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
rich_progress_bar: null
|
| 28 |
+
|
| 29 |
+
wandb_offline_checkpoint:
|
| 30 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/best_rq_2/audioset_ema.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset_ema
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- /callbacks/ema_weight_averaging
|
| 8 |
+
- override /data: audioset
|
| 9 |
+
- override /model: best_rq2
|
| 10 |
+
- override /trainer: gpu
|
| 11 |
+
- override /logger: wandb
|
| 12 |
+
|
| 13 |
+
tags: ["audioset", "best-rq-2", "ema", "cluster GPU"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 200000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
batch_size: 256
|
| 20 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 21 |
+
|
| 22 |
+
logger:
|
| 23 |
+
wandb:
|
| 24 |
+
offline: True
|
| 25 |
+
log_model: False
|
| 26 |
+
|
| 27 |
+
callbacks:
|
| 28 |
+
rich_progress_bar: null
|
| 29 |
+
|
| 30 |
+
wandb_offline_checkpoint:
|
| 31 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
| 32 |
+
|
| 33 |
+
model:
|
| 34 |
+
ema:
|
| 35 |
+
enabled: true
|
audio-embeddings/configs/experiment/best_rq_2/audioset_ema_600k.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/audioset_ema_600k
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- /callbacks/ema_weight_averaging
|
| 8 |
+
- override /data: audioset
|
| 9 |
+
- override /model: best_rq2
|
| 10 |
+
- override /trainer: gpu
|
| 11 |
+
- override /logger: wandb
|
| 12 |
+
|
| 13 |
+
tags: ["audioset", "best-rq-2", "ema", "600k", "cluster GPU"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 600000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
batch_size: 256
|
| 20 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 21 |
+
|
| 22 |
+
logger:
|
| 23 |
+
wandb:
|
| 24 |
+
offline: True
|
| 25 |
+
log_model: False
|
| 26 |
+
|
| 27 |
+
callbacks:
|
| 28 |
+
rich_progress_bar: null
|
| 29 |
+
|
| 30 |
+
wandb_offline_checkpoint:
|
| 31 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
| 32 |
+
|
| 33 |
+
model:
|
| 34 |
+
ema:
|
| 35 |
+
enabled: true
|
| 36 |
+
decay_numerator: 40.0
|
audio-embeddings/configs/experiment/best_rq_2/yt1b_ema.yaml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=best_rq_2/yt1b_ema
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- /callbacks/ema_weight_averaging
|
| 8 |
+
- override /data: yt1b
|
| 9 |
+
- override /model: best_rq2
|
| 10 |
+
- override /trainer: gpu
|
| 11 |
+
- override /logger: wandb
|
| 12 |
+
|
| 13 |
+
tags: ["yt1b", "best-rq-2", "ema", "cluster GPU"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 600000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
batch_size: 256
|
| 20 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 21 |
+
|
| 22 |
+
logger:
|
| 23 |
+
wandb:
|
| 24 |
+
offline: True
|
| 25 |
+
log_model: False
|
| 26 |
+
|
| 27 |
+
callbacks:
|
| 28 |
+
rich_progress_bar: null
|
| 29 |
+
|
| 30 |
+
wandb_offline_checkpoint:
|
| 31 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
| 32 |
+
|
| 33 |
+
model:
|
| 34 |
+
warmup_pct: 0.03
|
| 35 |
+
ema:
|
| 36 |
+
enabled: true
|
| 37 |
+
decay_numerator: 40.0
|
audio-embeddings/configs/experiment/local/audio_jepa.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=local/audio_jepa
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: audio_jepa
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["audioset", "jepa", "baseline", "local GPU"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
devices: 1
|
| 19 |
+
max_steps: 200000
|
| 20 |
+
|
| 21 |
+
data:
|
| 22 |
+
batch_size: 128
|
| 23 |
+
num_workers: 16
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
name: "local-jepa-audioset-baseline-200k-128x1bs"
|
| 28 |
+
offline: False
|
| 29 |
+
log_model: True
|
audio-embeddings/configs/experiment/local/audio_jepa_rope.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=local/audio_jepa_rope
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: audio_jepa
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["audioset", "jepa", "rope", "local GPU"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
devices: 1
|
| 19 |
+
max_steps: 200000
|
| 20 |
+
|
| 21 |
+
data:
|
| 22 |
+
batch_size: 128
|
| 23 |
+
num_workers: 16
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
name: "local-jepa-audioset-rope-200k-128x1bs"
|
| 28 |
+
offline: False
|
| 29 |
+
log_model: True
|
| 30 |
+
|
| 31 |
+
model:
|
| 32 |
+
net:
|
| 33 |
+
encoder:
|
| 34 |
+
pos_embed_type: rope
|
| 35 |
+
predictor:
|
| 36 |
+
pos_embed_type: rope
|
audio-embeddings/configs/experiment/local/best_rq.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=local/best_rq
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["audioset", "best-rq", "baseline", "local GPU"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
devices: 1
|
| 19 |
+
max_steps: 200000
|
| 20 |
+
|
| 21 |
+
data:
|
| 22 |
+
batch_size: 128
|
| 23 |
+
num_workers: 16
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
name: "local-best-rq-vit-audioset-200k-128x1bs"
|
| 28 |
+
offline: False
|
| 29 |
+
log_model: True
|
audio-embeddings/configs/experiment/local/best_rq2.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=local/best_rq2
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: best_rq2
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["audioset", "best-rq-2", "baseline", "local GPU"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
devices: 1
|
| 19 |
+
max_steps: 100000
|
| 20 |
+
|
| 21 |
+
data:
|
| 22 |
+
batch_size: 128
|
| 23 |
+
num_workers: 16
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
name: "cluster-best-rq-2-audioset-100k-128bs"
|
| 28 |
+
offline: True
|
| 29 |
+
log_model: False
|
| 30 |
+
|
| 31 |
+
callbacks:
|
| 32 |
+
# model_checkpoint:
|
| 33 |
+
# save_weights_only: True
|
| 34 |
+
|
| 35 |
+
rich_progress_bar: null
|
| 36 |
+
|
| 37 |
+
wandb_offline_checkpoint:
|
| 38 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/local/m4_mock_jepa.yaml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=local/m4_mock_jepa
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: mock_audioset
|
| 8 |
+
- override /model: audio_jepa
|
| 9 |
+
- override /trainer: mps
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["mock", "jepa", "local", "mps"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
# devices: 1 # set in trainer/mps.yaml
|
| 19 |
+
max_steps: 100
|
| 20 |
+
log_every_n_steps: 1
|
| 21 |
+
val_check_interval: 50
|
| 22 |
+
limit_val_batches: 5
|
| 23 |
+
|
| 24 |
+
data:
|
| 25 |
+
batch_size: 8
|
| 26 |
+
|
| 27 |
+
logger:
|
| 28 |
+
wandb:
|
| 29 |
+
name: "local-m4-mock-jepa"
|
| 30 |
+
offline: True
|
| 31 |
+
log_model: False
|
| 32 |
+
|
| 33 |
+
callbacks:
|
| 34 |
+
model_checkpoint:
|
| 35 |
+
save_weights_only: True
|
| 36 |
+
monitor: null # verify we can save without monitoring
|
| 37 |
+
device_stats: null
|
audio-embeddings/configs/experiment/local/rqa_jepa.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=local/rqa_jepa
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: rqa_jepa
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["audioset", "rqa-jepa", "baseline", "local GPU"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
devices: 1
|
| 19 |
+
max_steps: 200000
|
| 20 |
+
|
| 21 |
+
data:
|
| 22 |
+
batch_size: 128
|
| 23 |
+
num_workers: 16
|
| 24 |
+
|
| 25 |
+
logger:
|
| 26 |
+
wandb:
|
| 27 |
+
name: "local-rqa-jepa-audioset-200k-128x1bs"
|
| 28 |
+
offline: False
|
| 29 |
+
log_model: True
|
audio-embeddings/configs/experiment/rqa_jepa/audioset.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=rqa_jepa/audioset
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: audioset
|
| 8 |
+
- override /model: rqa_jepa
|
| 9 |
+
- override /trainer: gpu
|
| 10 |
+
- override /logger: wandb
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["audioset", "rqa-jepa", "baseline", "cluster GPU"]
|
| 16 |
+
|
| 17 |
+
trainer:
|
| 18 |
+
max_steps: 200000
|
| 19 |
+
|
| 20 |
+
data:
|
| 21 |
+
batch_size: 256
|
| 22 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 23 |
+
|
| 24 |
+
logger:
|
| 25 |
+
wandb:
|
| 26 |
+
offline: True
|
| 27 |
+
log_model: False
|
| 28 |
+
|
| 29 |
+
callbacks:
|
| 30 |
+
rich_progress_bar: null
|
| 31 |
+
|
| 32 |
+
wandb_offline_checkpoint:
|
| 33 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|
audio-embeddings/configs/experiment/rqa_jepa/yt1b.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: yt1b
|
| 5 |
+
- override /model: rqa_jepa
|
| 6 |
+
- override /trainer: gpu
|
| 7 |
+
- override /logger: wandb
|
| 8 |
+
- override /callbacks: default
|
| 9 |
+
|
| 10 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 11 |
+
# this allows you to overwrite only specified parameters
|
| 12 |
+
|
| 13 |
+
tags: ["yt1b", "rqa-jepa", "baseline", "cluster GPU"]
|
| 14 |
+
|
| 15 |
+
trainer:
|
| 16 |
+
max_steps: 200000
|
| 17 |
+
|
| 18 |
+
data:
|
| 19 |
+
batch_size: 256
|
| 20 |
+
num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
|
| 21 |
+
|
| 22 |
+
logger:
|
| 23 |
+
wandb:
|
| 24 |
+
offline: True
|
| 25 |
+
log_model: False
|
| 26 |
+
|
| 27 |
+
callbacks:
|
| 28 |
+
rich_progress_bar: null
|
| 29 |
+
|
| 30 |
+
wandb_offline_checkpoint:
|
| 31 |
+
_target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
|