ltuncay commited on
Commit
eca55dc
·
verified ·
1 Parent(s): 673d263

Submission to the Interspeech 2026 Audio Encoder Capability Challenge

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. BEST-RQ-2.safetensors +3 -0
  2. BEST-RQ-2_encoder.py +233 -0
  3. README.md +35 -3
  4. audio-embeddings/.githooks/post-checkout +13 -0
  5. audio-embeddings/.githooks/post-merge +13 -0
  6. audio-embeddings/.githooks/post-rewrite +13 -0
  7. audio-embeddings/.gitignore +29 -0
  8. audio-embeddings/.pre-commit-config.yaml +39 -0
  9. audio-embeddings/.project-root +0 -0
  10. audio-embeddings/.python-version +1 -0
  11. audio-embeddings/AGENTS.md +150 -0
  12. audio-embeddings/README.md +142 -0
  13. audio-embeddings/THIRD_PARTY_LICENSES.md +13 -0
  14. audio-embeddings/configs/__init__.py +1 -0
  15. audio-embeddings/configs/callbacks/default.yaml +34 -0
  16. audio-embeddings/configs/callbacks/early_stopping.yaml +15 -0
  17. audio-embeddings/configs/callbacks/ema_weight_averaging.yaml +9 -0
  18. audio-embeddings/configs/callbacks/model_checkpoint.yaml +17 -0
  19. audio-embeddings/configs/callbacks/model_summary.yaml +5 -0
  20. audio-embeddings/configs/callbacks/none.yaml +0 -0
  21. audio-embeddings/configs/callbacks/rich_progress_bar.yaml +4 -0
  22. audio-embeddings/configs/callbacks/wandb_offline.yaml +2 -0
  23. audio-embeddings/configs/data/audioset.yaml +12 -0
  24. audio-embeddings/configs/data/mock_audioset.yaml +7 -0
  25. audio-embeddings/configs/data/yt1b.yaml +13 -0
  26. audio-embeddings/configs/experiment/audio_jepa/baseline.yaml +34 -0
  27. audio-embeddings/configs/experiment/audio_jepa/large.yaml +44 -0
  28. audio-embeddings/configs/experiment/audio_jepa/rope.yaml +41 -0
  29. audio-embeddings/configs/experiment/audio_jepa/time_res2x.yaml +54 -0
  30. audio-embeddings/configs/experiment/audio_jepa/time_res4x.yaml +54 -0
  31. audio-embeddings/configs/experiment/best_rq/audioset.yaml +33 -0
  32. audio-embeddings/configs/experiment/best_rq/yt1b.yaml +31 -0
  33. audio-embeddings/configs/experiment/best_rq_2/audioset.yaml +33 -0
  34. audio-embeddings/configs/experiment/best_rq_2/audioset_100k_512bs.yaml +30 -0
  35. audio-embeddings/configs/experiment/best_rq_2/audioset_1m_128bs_4gpu.yaml +30 -0
  36. audio-embeddings/configs/experiment/best_rq_2/audioset_200k_256bs_4gpu.yaml +30 -0
  37. audio-embeddings/configs/experiment/best_rq_2/audioset_400k_128bs.yaml +30 -0
  38. audio-embeddings/configs/experiment/best_rq_2/audioset_400k_128bs_4gpu.yaml +30 -0
  39. audio-embeddings/configs/experiment/best_rq_2/audioset_800k_64bs_4gpu.yaml +30 -0
  40. audio-embeddings/configs/experiment/best_rq_2/audioset_ema.yaml +35 -0
  41. audio-embeddings/configs/experiment/best_rq_2/audioset_ema_600k.yaml +36 -0
  42. audio-embeddings/configs/experiment/best_rq_2/yt1b_ema.yaml +37 -0
  43. audio-embeddings/configs/experiment/local/audio_jepa.yaml +29 -0
  44. audio-embeddings/configs/experiment/local/audio_jepa_rope.yaml +36 -0
  45. audio-embeddings/configs/experiment/local/best_rq.yaml +29 -0
  46. audio-embeddings/configs/experiment/local/best_rq2.yaml +38 -0
  47. audio-embeddings/configs/experiment/local/m4_mock_jepa.yaml +37 -0
  48. audio-embeddings/configs/experiment/local/rqa_jepa.yaml +29 -0
  49. audio-embeddings/configs/experiment/rqa_jepa/audioset.yaml +33 -0
  50. audio-embeddings/configs/experiment/rqa_jepa/yt1b.yaml +31 -0
BEST-RQ-2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7111465e6c868e3d0b55c5fe9a23dc5069ac80beba8444c5ee9db4691d796899
3
+ size 483870856
BEST-RQ-2_encoder.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import sys
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from omegaconf import OmegaConf
8
+ from safetensors.torch import load_file
9
+
10
# Add audio-embeddings to path dynamically
# We assume audio-embeddings is a sibling directory to xares-llm or provided via env var
# Prioritize absolute path if known, otherwise relative
POSSIBLE_PATHS = [
    # "/media/ltuncay/Shared-4TB/dev/audio-embeddings",
    os.path.abspath(os.path.join(os.path.dirname(__file__), "audio-embeddings")),
    # os.path.abspath(os.path.join(os.getcwd(), "../audio-embeddings")),
]

# First existing candidate wins; None means we rely on an installed package.
AUDIO_EMBEDDINGS_PATH = next((p for p in POSSIBLE_PATHS if os.path.exists(p)), None)

if AUDIO_EMBEDDINGS_PATH:
    if AUDIO_EMBEDDINGS_PATH not in sys.path:
        sys.path.append(AUDIO_EMBEDDINGS_PATH)
        print(f"Added {AUDIO_EMBEDDINGS_PATH} to sys.path")
else:
    print(
        "Warning: audio-embeddings path not found. Imports may fail if not installed in environment."
    )

try:
    from src.models.best_rq2_module import BestRQ2Module
except ImportError as e:
    # Chain the original exception so the root cause stays in the traceback.
    raise ImportError(
        f"Could not import src.models.best_rq2_module. Ensure audio-embeddings is correctly located or installed. Error: {e}"
    ) from e
40
+
41
+
42
class BestRQ2Encoder(nn.Module):
    """Inference wrapper around a pretrained BEST-RQ-2 ``BestRQ2Module``.

    Loads the OmegaConf config and safetensors checkpoint shipped next to
    this file (or caller-supplied paths), instantiates the module, and
    exposes a ``forward`` mapping raw audio [B, T] to encoder patch
    embeddings. Short inputs are zero-padded and long inputs are processed
    in sequential chunks bounded by the model's positional-embedding span.
    """

    def __init__(self, checkpoint_path=None, model_config_path=None, **kwargs):
        """Build the encoder and load pretrained weights.

        Args:
            checkpoint_path: Path to the model weights. Defaults to
                ``BEST-RQ-2.safetensors`` located next to this file.
            model_config_path: Path to the training config. Defaults to
                ``config.yaml`` located next to this file.
            **kwargs: Ignored; accepted for loader-interface compatibility.

        Raises:
            FileNotFoundError: If the config or checkpoint cannot be found.
        """
        super().__init__()

        base_path = os.path.dirname(__file__)
        # Honor caller-provided paths; previously these arguments were
        # silently overwritten by the bundled defaults.
        if model_config_path is None:
            model_config_path = os.path.join(base_path, "config.yaml")
        if checkpoint_path is None:
            checkpoint_path = os.path.join(base_path, "BEST-RQ-2.safetensors")

        if not os.path.exists(model_config_path):
            raise FileNotFoundError(f"Config not found at {model_config_path}")

        if not checkpoint_path or not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        print(f"Loading BestRQ2 config from {model_config_path}")
        cfg = OmegaConf.load(model_config_path)

        print(f"Loading BestRQ2 checkpoint from {checkpoint_path}")

        # Reconstruct model args from config
        model_cfg = cfg.model
        net_cfg = model_cfg.net

        # Instantiate model
        # Note: BestRQ2Module inherits from LightningModule
        self.module = BestRQ2Module(
            optimizer=None,  # Not needed for inference
            net=net_cfg,
            warmup_pct=model_cfg.get("warmup_pct", 0.1),
            final_lr_ratio=model_cfg.get("final_lr_ratio", 0.001),
            spectrogram_adjustment_mode=model_cfg.get(
                "spectrogram_adjustment_mode", "pad"
            ),
            codebook_dim=model_cfg.get("codebook_dim", 16),
            vocab_size=model_cfg.get("vocab_size", 8192),
            criterion=None,
        )

        # Load weights: prefer safetensors, fall back to a torch pickle.
        try:
            state_dict = load_file(checkpoint_path)
        except Exception as e:
            print(f"Error loading safetensors: {e}. Trying torch.load...")
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]

        # Load non-strictly: checkpoint key prefixes may not match the module
        # exactly, so surface mismatches as warnings instead of hard failures.
        missing, unexpected = self.module.load_state_dict(state_dict, strict=False)
        if missing:
            print(f"Warning: {len(missing)} keys missing during loading.")
        if unexpected:
            print(f"Warning: {len(unexpected)} keys unexpected during loading.")

        self.module.eval()
        self.output_dim = net_cfg.encoder.embed_dim

        # Extract dynamic parameters for length handling
        try:
            # 1. Sample rate & hop length from the spectrogram frontend.
            self.sample_rate = self.module.spectrogram.mel_spec.sample_rate
            self.hop_length = self.module.spectrogram.mel_spec.hop_length

            # 2. Patch size along the time dimension (patch_size is (H, W)).
            self.patch_size_time = self.module.patch_embed.patch_size[1]

            # 3. Maximum spectrogram frames the positional embeddings cover.
            self.max_frames = self.module.patch_embed.img_size[1]

            # Minimum samples needed so the spectrogram yields >= 1 patch:
            # T_spec ~= T_samples // hop_length must reach patch_size_time.
            self.min_samples = self.patch_size_time * self.hop_length

            # Maximum audio samples a single chunk may contain.
            self.chunk_samples = self.max_frames * self.hop_length

            print(
                f"BestRQ2Encoder constraints: Min Samples={self.min_samples}, Chunk Samples={self.chunk_samples}"
            )

        except Exception as e:
            print(f"Warning: Could not extract dynamic length constraints: {e}")
            print("Falling back to safe defaults (1s min, 10s chunk)")
            self.min_samples = 16000
            self.chunk_samples = 16000 * 10

    def _forward_chunk(self, audio_chunk: torch.Tensor) -> torch.Tensor:
        """Encode a single time-chunk of audio into patch embeddings."""
        # Determine target device from the spectrogram window (safest for STFT)
        try:
            target_device = self.module.spectrogram.mel_spec.spectrogram.window.device
        except AttributeError:
            if hasattr(self.module.spectrogram.mel_spec, "window"):
                target_device = self.module.spectrogram.mel_spec.window.device
            else:
                target_device = self.module.device

        if audio_chunk.device != target_device:
            audio_chunk = audio_chunk.to(target_device)

        # BestRQ2Module expects [B, C, T]
        if audio_chunk.ndim == 2:
            audio_chunk = audio_chunk.unsqueeze(1)  # [B, 1, T]

        # _process_audio returns (patches, grid_size)
        patches, grid_size = self.module._process_audio(audio_chunk)

        # Dummy mask (all False = keep every patch).
        B, N, D = patches.shape
        mask = torch.zeros((B, N), dtype=torch.bool, device=patches.device)

        return self.module.compute_encoder(patches, mask, grid_size)

    def forward(
        self, audio: torch.Tensor, audio_attention_mask=None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Encode raw audio into patch embeddings.

        Args:
            audio: Waveform tensor of shape [B, T] (a 1-D [T] tensor is
                promoted to a batch of one).
            audio_attention_mask: Ignored; present for interface parity.

        Returns:
            Tuple of ``(embeddings, None)``. Inputs longer than the model's
            positional span are split into chunks whose outputs are
            concatenated along the sequence dimension.
        """
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)

        B, T = audio.shape

        # 1. Pad the whole batch if shorter than one patch's worth of audio.
        if T < self.min_samples:
            pad_amt = self.min_samples - T
            audio = torch.nn.functional.pad(audio, (0, pad_amt))
            T = self.min_samples  # Update T

        # 2. Single chunk fits within the positional-embedding span.
        if T <= self.chunk_samples:
            return self._forward_chunk(audio), None

        # Otherwise split along time and process sequentially.
        chunks = torch.split(audio, self.chunk_samples, dim=1)
        outputs = []

        for chunk in chunks:
            # The last chunk may be shorter than one patch; pad it up.
            chunk_len = chunk.shape[1]
            if chunk_len < self.min_samples:
                pad_amt = self.min_samples - chunk_len
                chunk = torch.nn.functional.pad(chunk, (0, pad_amt))

            # NOTE: padding added here cannot be sliced off afterwards —
            # one patch is the smallest output unit, so a padded tail still
            # yields (at least) one patch.
            outputs.append(self._forward_chunk(chunk))

        # Concatenate chunk outputs along the sequence dimension.
        return torch.cat(outputs, dim=1), None
218
+
219
+
220
if __name__ == "__main__":
    # Smoke test: build the encoder and push 10 s of random audio through it.
    try:
        encoder = BestRQ2Encoder()
        print("Model initialized successfully")
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        encoder.module.to(run_device)
        dummy_audio = torch.randn(1, 160000).to(run_device)
        embeddings, _ = encoder(dummy_audio)
        print(f"Output shape: {embeddings.shape}")
    except Exception as e:
        print(f"Error testing model: {e}")
        import traceback

        traceback.print_exc()
README.md CHANGED
@@ -1,3 +1,35 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BEST-RQ-2 (xares-llm encoder)
2
+
3
+ This folder contains the BEST-RQ-2 audio encoder integration for `xares-llm`.
4
+
5
+ Benchmark repository: [xiaomi-research/xares-llm](https://github.com/xiaomi-research/xares-llm)
6
+
7
+ ## Setup
8
+
9
+ 1. Make sure the environment to run `xares-llm` is set up properly (virtual environment initialized and the `xares-llm` package downloaded/installed).
10
+
11
+ 2. Add the `BEST-RQ-2` folder to the `xares-llm` (current) directory so it is available at `./BEST-RQ-2`.
12
+
13
+ 3. Before running a `xares-llm` evaluation, you must install the required packages for BEST-RQ-2:
14
+
15
+ ```bash
16
+ uv pip install -r BEST-RQ-2/audio-embeddings/pyproject.toml
17
+ ```
18
+
19
+ ## Run an evaluation
20
+
21
+ Single task (e.g. to test that everything works):
22
+
23
+ ```bash
24
+ uv run -m xares_llm.run BEST-RQ-2.BEST-RQ-2_encoder.BestRQ2Encoder <task> <task> # e.g. replace <task> with esc-50, score should be around 0.48
25
+ ```
26
+
27
+ All tasks:
28
+
29
+ ```bash
30
+ uv run -m xares_llm.run BEST-RQ-2.BEST-RQ-2_encoder.BestRQ2Encoder all
31
+ ```
32
+
33
+ ## Help
34
+
35
+ If you encounter any problems, contact: [ludovic.tuncay@irit.fr](mailto:ludovic.tuncay@irit.fr)
audio-embeddings/.githooks/post-checkout ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Git post-checkout hook: keep the uv environment in sync with the lockfile.
set -euo pipefail

# Always operate from the repository root.
cd "$(git rev-parse --show-toplevel)"

# Skip gracefully on machines without uv installed.
if ! command -v uv >/dev/null 2>&1; then
    echo "[post-checkout] uv not found; skipping uv sync" >&2
    exit 0
fi

echo "[post-checkout] Running uv sync..."
uv sync
audio-embeddings/.githooks/post-merge ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Git post-merge hook: keep the uv environment in sync with the lockfile.
set -euo pipefail

# Always operate from the repository root.
cd "$(git rev-parse --show-toplevel)"

# Skip gracefully on machines without uv installed.
if ! command -v uv >/dev/null 2>&1; then
    echo "[post-merge] uv not found; skipping uv sync" >&2
    exit 0
fi

echo "[post-merge] Running uv sync..."
uv sync
audio-embeddings/.githooks/post-rewrite ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Git post-rewrite hook: keep the uv environment in sync with the lockfile.
set -euo pipefail

# Always operate from the repository root.
cd "$(git rev-parse --show-toplevel)"

# Skip gracefully on machines without uv installed.
if ! command -v uv >/dev/null 2>&1; then
    echo "[post-rewrite] uv not found; skipping uv sync" >&2
    exit 0
fi

echo "[post-rewrite] Running uv sync..."
uv sync
audio-embeddings/.gitignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ # Project logs
13
+ logs/
14
+
15
+ # Large data files
16
+ *.h5
17
+
18
+ # Data
19
+ data/
20
+
21
+ # Jupyter
22
+ .ipynb_checkpoints/
23
+
24
+ # macOS Finder metadata
25
+ .DS_Store
26
+
27
+ # Explicitly untracked large docs
28
+ documents/LeJEPA.pdf
29
+ documents/Audio-LeJEPA.pdf
audio-embeddings/.pre-commit-config.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v5.0.0
4
+ hooks:
5
+ # Trim whitespace at end of lines.
6
+ - id: trailing-whitespace
7
+ # Ensure files end with a single newline.
8
+ - id: end-of-file-fixer
9
+ # Catch unresolved merge conflict markers.
10
+ - id: check-merge-conflict
11
+ # Validate YAML syntax (Hydra configs, etc.).
12
+ - id: check-yaml
13
+ # Validate TOML syntax (e.g., pyproject.toml).
14
+ - id: check-toml
15
+ # Detect mixed CRLF/LF endings without auto-rewriting.
16
+ - id: mixed-line-ending
17
+ args: ["--fix=no"]
18
+ # Catch accidental debug leftovers (breakpoint/pdb).
19
+ - id: debug-statements
20
+ # Block accidental private key commits.
21
+ - id: detect-private-key
22
+ # Prevent case-colliding paths across filesystems.
23
+ - id: check-case-conflict
24
+ # Block newly added large files unless explicitly allowlisted below.
25
+ - id: check-added-large-files
26
+ args: ["--maxkb=5000"]
27
+ exclude: ^(documents/LeJEPA\.pdf|documents/Audio-LeJEPA\.pdf)$
28
+
29
+ - repo: https://github.com/astral-sh/ruff-pre-commit
30
+ # Ruff version.
31
+ rev: v0.15.0
32
+ hooks:
33
+ # Run the linter.
34
+ - id: ruff-check
35
+ types_or: [python, pyi]
36
+ args: [--fix]
37
+ # Run the formatter.
38
+ - id: ruff-format
39
+ types_or: [python, pyi]
audio-embeddings/.project-root ADDED
File without changes
audio-embeddings/.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
audio-embeddings/AGENTS.md ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AGENTS Guide - audio-embeddings
2
+
3
+ This file is for coding agents working in this repository.
4
+ Follow these repo-specific rules over generic defaults.
5
+
6
+ ## 1) Environment Snapshot
7
+ - Python: `>=3.12` (from `pyproject.toml`).
8
+ - Dependency manager: `uv`.
9
+ - Main stack: PyTorch, PyTorch Lightning, Hydra, OmegaConf.
10
+ - Project root marker: `.project-root`.
11
+ - Main entrypoint: `src/train.py`.
12
+
13
+ ## 2) Cursor / Copilot Rule Files
14
+ - Checked `.cursor/rules/`: not present.
15
+ - Checked `.cursorrules`: not present.
16
+ - Checked `.github/copilot-instructions.md`: not present.
17
+ - Therefore, no additional Cursor/Copilot rule files are currently enforced.
18
+
19
+ ## 3) Install / Setup Commands
20
+ ```bash
21
+ uv sync
22
+ uv run <command>
23
+ uv add <package>
24
+ ```
25
+
26
+ ## 4) Build / Train / Eval Commands
27
+ There is no separate "build" step (this is a training codebase).
28
+ Use quick-run training as the integration sanity check.
29
+ ```bash
30
+ uv run src/train.py
31
+ uv run src/train.py trainer.fast_dev_run=True
32
+ uv run src/train.py trainer=cpu trainer.fast_dev_run=True
33
+ uv run src/train.py experiment=local/audio_jepa
34
+ uv run src/train.py trainer.max_epochs=10 data.batch_size=32 model.optimizer.lr=1e-4
35
+ ```
36
+ Cluster-style execution (existing project pattern):
37
+ ```bash
38
+ srun .venv/bin/python -u -O src/train.py experiment=cluster_jepa_audioset_rope +trainer.max_time="00:19:50:00"
39
+ ```
40
+
41
+ ## 5) Lint / Formatting / Static Checks
42
+ Use the commands below as pragmatic checks:
43
+ ```bash
44
+ uv run pre-commit run --all-files
45
+ uv run pre-commit run ruff --all-files
46
+ uv run pre-commit run ruff-format --all-files
47
+ uv run python -m compileall src
48
+ ```
49
+ Ruff is configured via `.pre-commit-config.yaml` and runs both lint fixes and formatting.
50
+
51
+ ## 6) Test Commands (Including Single Test)
52
+ Primary validation in this repo is script-based verification under `tests/`.
53
+ Run test files directly as native Python files:
54
+ ```bash
55
+ uv run tests/verify_rope.py
56
+ uv run tests/verify_custom_rope.py
57
+ uv run tests/verify_data.py
58
+ ```
59
+ Useful single-file checks (native execution):
60
+ ```bash
61
+ uv run src/train.py trainer.fast_dev_run=True
62
+ uv run src/train.py trainer=cpu trainer.fast_dev_run=True
63
+ uv run scripts/verify_shapes.py
64
+ uv run scripts/verify_scheduler.py
65
+ ```
66
+ Notes:
67
+ - `tests/test_*.py` are pytest-style and are not part of the default native-file workflow.
68
+ - Prefer `tests/verify_*.py` and `scripts/verify_*.py` for lightweight checks.
69
+
70
+ ## 7) Repository Architecture Expectations
71
+ - `configs/`: Hydra composition (trainer/data/model/logger/callbacks/experiment).
72
+ - `src/train.py`: orchestration only (instantiate and run).
73
+ - `src/models/`: LightningModules (high-level training logic).
74
+ - `src/models/components/`: reusable `nn.Module` building blocks.
75
+ - `src/data/`: DataModules/Datasets and collate logic.
76
+ - `src/utils/`: logging, instantiation, wrappers, scheduler helpers.
77
+ When possible, prefer config changes over hardcoded Python changes.
78
+
79
+ ## 8) Code Style Guidelines
80
+ ### Imports
81
+ - Group imports as: standard library -> third-party -> local `src.*`.
82
+ - Keep one import per line unless importing multiple names from same module.
83
+ - Avoid wildcard imports.
84
+ - Prefer absolute imports from `src...`.
85
+
86
+ ### Formatting
87
+ - Use 4-space indentation and readable line lengths.
88
+ - Keep functions small; extract helpers for complex logic.
89
+ - Do not introduce unrelated reformatting in touched files.
90
+ - Keep comments for non-obvious intent, not obvious mechanics.
91
+
92
+ ### Typing
93
+ - Type hints are expected for function arguments and return values.
94
+ - Use concrete tensor/container types when practical.
95
+ - Use `Optional[T]` / `T | None` consistently within a file.
96
+ - For dict-like configs, type as `DictConfig` when passing Hydra config objects.
97
+
98
+ ### Naming
99
+ - `snake_case`: functions, variables, module filenames.
100
+ - `PascalCase`: classes (`AudioJEPAModule`, `AudioSetDataModule`).
101
+ - `UPPER_SNAKE_CASE`: constants.
102
+ - Prefer descriptive names (`mask_indices`) over short names (`m2`) except local math temporaries.
103
+
104
+ ### PyTorch / Lightning / Hydra Conventions
105
+ - Keep heavy compute out of `__init__` where possible.
106
+ - `forward()` for inference logic; training behavior in `training_step()`.
107
+ - Use `self.log(...)` with explicit flags (`on_step`, `on_epoch`, `prog_bar`, `batch_size`).
108
+ - Instantiate components through Hydra (`hydra.utils.instantiate`).
109
+ - Expose tunable parameters in config files, not hardcoded literals.
110
+
111
+ ### Error Handling and Validation
112
+ - Raise informative `ValueError` / `RuntimeError` for invalid config/state.
113
+ - Validate critical tensor assumptions with assertions or explicit checks.
114
+ - Prefer logger/warnings over bare `print()` in new code.
115
+ - For file I/O, prefer `pathlib.Path` and existence checks.
116
+
117
+ ### Data and Paths
118
+ - Do not hardcode absolute machine paths.
119
+ - Use `rootutils.setup_root(..., indicator=".project-root", pythonpath=True)` in entrypoints/scripts when needed.
120
+ - Respect `cfg.paths.*` outputs for logs/checkpoints/artifacts.
121
+
122
+ ## 9) Agent Workflow Rules
123
+ - Reuse existing components before adding new abstractions.
124
+ - Keep `src/train.py` generic; place model/data logic in dedicated modules.
125
+ - Prefer minimal, focused diffs.
126
+ - Update configs and docs when behavior changes.
127
+ - Validate with the smallest meaningful command first (`fast_dev_run`, single test), then broader checks.
128
+
129
+ ## 10) Git / Change Hygiene
130
+ - Do not revert unrelated local changes.
131
+ - Keep commits scoped to one concern.
132
+ - Write clear commit messages describing intent.
133
+ - Prefer Conventional Commit-like format: `type(scope): intent`.
134
+ - Common types in this repo: `feat`, `fix`, `conf`, `build`, `docs`, `style`, `chore`.
135
+ - Never commit secrets, credentials, or environment-specific absolute paths.
136
+
137
+ ## 11) Practical Agent Defaults
138
+ - Prefer reusing existing modules over creating new abstractions.
139
+ - Keep edits local to the requested change; avoid drive-by refactors.
140
+ - Run the smallest useful verification command after changes.
141
+ - If you touch training logic, run at least one fast training sanity check.
142
+ - If you touch model components, run relevant verify script(s) in `tests/`.
143
+ - If you touch Hydra config wiring, run a config-backed entry command via `uv run src/train.py ...`.
144
+
145
+ ## 12) Common Pitfalls
146
+ - Avoid hardcoding data paths; use config (`cfg.paths`, data config fields).
147
+ - Avoid printing in new code paths; use ranked loggers/warnings.
148
+ - Avoid putting heavy tensor compute in constructors.
149
+ - Avoid bypassing Hydra by manually instantiating configurable components.
150
+ - Avoid changing unrelated formatting in files you touch.
audio-embeddings/README.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Audio Embeddings with Lightning & Hydra
2
+
3
+ This project is a clean, modular, and scalable implementation of audio embedding models using **PyTorch Lightning** and **Hydra**. It is designed to be easily extensible and runnable on local or cluster environments. It is based on the [Audio-JEPA](https://github.com/LudovicTuncay/Audio-JEPA) implementation and therefore implements the Audio-JEPA architecture. Other architectures can and will be added in the future.
4
+
5
+ ## 🎯 Goal
6
+
7
+ The goal of this project is to provide a robust codebase for training and experimenting with audio embedding models. Key features include:
8
+ - **Modular Architecture**: Components like Spectrogram, Masking, and ViT are decoupled.
9
+ - **Configurable Positional Embeddings**: Support for **RoPE** (2D Rotary Embeddings), **SinCos** (2D Sinusoidal), and **Learnable** embeddings.
10
+ - **Hydra Configuration**: Flexible experiment management via hierarchical config files.
11
+ - **Lightning Trainer**: Simplified training loop, logging, and checkpointing.
12
+ - **Modern Tooling**: Uses `uv` for fast and reliable dependency management.
13
+
14
+ ## 🚀 Installation
15
+
16
+ This project uses [`uv`](https://github.com/astral-sh/uv) for dependency management.
17
+
18
+ 1. **Install `uv`** (if not already installed):
19
+ ```bash
20
+ curl -LsSf https://astral.sh/uv/install.sh | sh
21
+ ```
22
+
23
+ 2. **Clone the repository**:
24
+ ```bash
25
+ git clone <repository_url>
26
+ cd audio-embeddings
27
+ ```
28
+
29
+ 3. **Install dependencies**:
30
+ ```bash
31
+ uv sync
32
+ ```
33
+
34
+ 4. **Enable shared git hooks** (runs `uv sync` after merge/checkout/rewrite):
35
+ ```bash
36
+ git config core.hooksPath .githooks
37
+ ```
38
+
39
+ ## 🏃 Usage
40
+
41
+ ### Basic Training
42
+ To start training with the default configuration:
43
+ ```bash
44
+ uv run src/train.py
45
+ ```
46
+
47
+ ### Common Commands
48
+ Run on GPU with Weights & Biases logging:
49
+ ```bash
50
+ uv run src/train.py trainer=gpu logger=wandb
51
+ ```
52
+
53
+ Override hyperparameters on the command line:
54
+ ```bash
55
+ uv run src/train.py data.batch_size=64 trainer.max_epochs=50
56
+ ```
57
+
58
+ ### Configurable Positional Embeddings
59
+ You can switch between different positional embedding strategies easily:
60
+
61
+ **RoPE**:
62
+ ```bash
63
+ uv run src/train.py model.net.encoder.pos_embed_type=rope
64
+ ```
65
+
66
+ ### Offline WandB Logging with Model Checkpoints
67
+ To run training offline but still have model checkpoints staged for upload (which standard WandB restricts):
68
+
69
+ ```bash
70
+ uv run src/train.py \
71
+ logger=wandb \
72
+ logger.wandb.offline=True \
73
+ logger.wandb.log_model=False \
74
+ +callbacks.wandb_offline_checkpoint._target_=src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback \
75
+ trainer=gpu trainer.devices=1 \
76
+ data.batch_size=128 trainer.max_epochs=100
77
+ ```
78
+ These checkpoints will be uploaded when you run `wandb sync`.
79
+
80
+
81
+ **2D SinCos**:
82
+ ```bash
83
+ uv run src/train.py ++model.net.encoder.pos_embed_type=sincos ++model.net.predictor.pos_embed_type=sincos
84
+ ```
85
+
86
+ **Learnable**:
87
+ ```bash
88
+ uv run src/train.py ++model.net.encoder.pos_embed_type=learnable ++model.net.predictor.pos_embed_type=learnable
89
+ ```
90
+
91
+ ## 📂 Project Structure
92
+
93
+ ```text
94
+ ├── configs/ # Hydra configuration files
95
+ │ ├── callbacks/ # Callback configs (checkpoints, early stopping)
96
+ │ ├── data/ # Data configs (AudioSet, etc.)
97
+ │ ├── logger/ # Logger configs (WandB, Tensorboard)
98
+ │ ├── model/ # Model configs (AudioJEPA parameters)
99
+ │ ├── trainer/ # Trainer configs (CPU, GPU, strategies)
100
+ │ └── train.yaml # Main configuration entry point
101
+ ├── src/
102
+ │ ├── data/ # Data loading logic
103
+ │ │ └── audioset_datamodule.py # AudioSet DataModule & Dataset
104
+ │ ├── models/ # Model architectures
105
+ │ │ ├── components/ # Reusable blocks
106
+ │ │ │ ├── masking.py # Masking generators
107
+ │ │ │ ├── patch_embed.py # Patchification
108
+ │ │ │ ├── rope.py # 2D Rotary Embeddings
109
+ │ │ │ ├── spectrogram.py # Audio preprocessing
110
+ │ │ │ └── vit.py # Vision Transformer (Student/Teacher/Predictor)
111
+ │ │ └── audio_jepa_module.py # Main LightningModule
112
+ │ ├── utils/ # Utility functions
113
+ │ └── train.py # Training entry point
114
+ ├── scripts/ # Helper scripts
115
+ ├── tests/ # Verification tests
116
+ ├── pyproject.toml # Project dependencies
117
+ └── README.md # This file
118
+ ```
119
+
120
+ ## 🛠️ Extensibility
121
+
122
+ ### Adding a New Model
123
+ 1. Create your model components in `src/models/components/`.
124
+ 2. Create a new LightningModule in `src/models/` (or update `AudioJEPAModule`).
125
+ 3. Create a new config file in `configs/model/my_new_model.yaml`.
126
+ 4. Run with `uv run src/train.py model=my_new_model`.
127
+
128
+ ### Adding a New Dataset
129
+ 1. Create a new DataModule in `src/data/`.
130
+ 2. Create a new config file in `configs/data/my_dataset.yaml`.
131
+ 3. Run with `uv run src/train.py data=my_dataset`.
132
+
133
+ ### Adding Functionalities
134
+ - **Callbacks**: Add custom callbacks in `src/callbacks/` (if needed) or use existing Lightning callbacks, and configure them in `configs/callbacks/`.
135
+ - **Metrics**: Add metrics logging in `training_step` or `validation_step` inside `src/models/audio_jepa_module.py`.
136
+
137
+ ## 🧪 Testing
138
+ Run verification scripts to ensure components are working:
139
+ ```bash
140
+ uv run tests/verify_rope.py
141
+ uv run tests/verify_custom_rope.py
142
+ ```
audio-embeddings/THIRD_PARTY_LICENSES.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Third-Party Licenses
2
+
3
+ This project vendors third-party source code listed below.
4
+
5
+ ## Lightning-AI / pytorch-lightning
6
+
7
+ - Component: `src/callbacks/lightning_weight_averaging.py`
8
+ - Source repository: `https://github.com/Lightning-AI/pytorch-lightning`
9
+ - Source file: `src/lightning/pytorch/callbacks/weight_averaging.py`
10
+ - Pinned commit: `9bcba1c1e82b45e10f948dc28fc12f4cf04ab736`
11
+ - License: Apache License 2.0
12
+
13
+ See `licenses/APACHE-2.0-LIGHTNING.txt` for the full license text.
audio-embeddings/configs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # this file is needed here to include configs when building project as a package
audio-embeddings/configs/callbacks/default.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model_checkpoint
3
+ - model_summary
4
+ - rich_progress_bar
5
+ - _self_
6
+
7
+ model_checkpoint:
8
+ dirpath: ${paths.output_dir}/checkpoints
9
+ filename: "best-step-{step:06d}"
10
+ monitor: "val/loss"
11
+ mode: "min"
12
+ save_last: True
13
+ save_weights_only: False
14
+ save_top_k: 0
15
+ auto_insert_metric_name: False
16
+
17
+ safetensors:
18
+ _target_: src.callbacks.safetensors_callback.SafetensorsCallback
19
+ cleanup_orphan_safetensors: True
20
+
21
+ # early_stopping:
22
+ # monitor: "val/loss"
23
+ # patience: 100
24
+ # mode: "min"
25
+
26
+ model_summary:
27
+ max_depth: 1
28
+
29
+ device_stats:
30
+ _target_: lightning.pytorch.callbacks.DeviceStatsMonitor
31
+
32
+ visualization:
33
+ _target_: src.callbacks.visualization_callback.VisualizationCallback
34
+ num_samples: 4
audio-embeddings/configs/callbacks/early_stopping.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
2
+
3
+ early_stopping:
4
+ _target_: lightning.pytorch.callbacks.EarlyStopping
5
+ monitor: ??? # quantity to be monitored, must be specified !!!
6
+ min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
7
+ patience: 3 # number of checks with no improvement after which training will be stopped
8
+ verbose: False # verbosity mode
9
+ mode: "min" # "max" means higher metric value is better, can be also "min"
10
+ strict: True # whether to crash the training if monitor is not found in the validation metrics
11
+ check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
12
+ stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
13
+ divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
14
+ check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
15
+ # log_rank_zero_only: False # this keyword argument isn't available in stable version
audio-embeddings/configs/callbacks/ema_weight_averaging.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ema_weight_averaging:
2
+ _target_: src.callbacks.ema_weight_averaging.WarmupEMAWeightAveraging
3
+ warmup_pct: ${model.warmup_pct}
4
+ enabled: ${model.ema.enabled}
5
+ decay: ${model.ema.decay} # decay rate is 1 - decay_numerator / ( total_steps - warmup_steps )
6
+ decay_numerator: ${model.ema.decay_numerator} # decay rate is 1 - decay_numerator / ( total_steps - warmup_steps )
7
+ update_every_n_steps: ${model.ema.update_every_n_steps}
8
+ use_buffers: ${model.ema.use_buffers}
9
+ update_starting_at_step: null
audio-embeddings/configs/callbacks/model_checkpoint.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
2
+
3
+ model_checkpoint:
4
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
5
+ dirpath: null # directory to save the model file
6
+ filename: null # checkpoint filename
7
+ monitor: null # name of the logged metric which determines when model is improving
8
+ verbose: False # verbosity mode
9
+ save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
10
+ save_top_k: 1 # save k best models (determined by above metric)
11
+ mode: "min" # "max" means higher metric value is better, can be also "min"
12
+ auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
13
+ save_weights_only: False # if True, then only the model’s weights will be saved
14
+ every_n_train_steps: null # number of training steps between checkpoints
15
+ train_time_interval: null # checkpoints are monitored at the specified time interval
16
+ every_n_epochs: null # number of epochs between checkpoints
17
+ save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
audio-embeddings/configs/callbacks/model_summary.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
2
+
3
+ model_summary:
4
+ _target_: lightning.pytorch.callbacks.RichModelSummary
5
+ max_depth: 1 # the maximum depth of layer nesting that the summary will include
audio-embeddings/configs/callbacks/none.yaml ADDED
File without changes
audio-embeddings/configs/callbacks/rich_progress_bar.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
2
+
3
+ rich_progress_bar:
4
+ _target_: lightning.pytorch.callbacks.RichProgressBar
audio-embeddings/configs/callbacks/wandb_offline.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ wandb_offline_checkpoint:
2
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/data/audioset.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.data.audioset_datamodule.AudioSetDataModule
2
+ data_dir: ${paths.data_dir}/AudioSet
3
+ batch_size: 64
4
+ num_workers: 4
5
+ pin_memory: True
6
+ train_h5: full_unbal_bal_train_wav.h5
7
+ train_csv: silent_files_full_unbal_bal_train_wav.csv
8
+ val_h5: eval_soxrhq.h5
9
+ val_csv: silent_files_eval_soxrhq.csv
10
+ max_audio_length_sec: 10.0 # 10 seconds
11
+ target_sample_rate: 16000
12
+ collate_mode: pad
audio-embeddings/configs/data/mock_audioset.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ _target_: src.data.mock_audioset_datamodule.MockAudioSetDataModule
2
+ batch_size: 16
3
+ num_workers: 0 # 0 is often better for simple debugging/validation on local Mac
4
+ pin_memory: True
5
+ max_audio_length_sec: 10.0
6
+ target_sample_rate: 16000
7
+ collate_mode: pad
audio-embeddings/configs/data/yt1b.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.data.yt1b_datamodule.YT1BDataModule
2
+ data_dir: ${paths.data_dir}/YT-Temporal-1B
3
+ batch_size: 64
4
+ num_workers: 4
5
+ pin_memory: True
6
+ train_parquet: train_metadata.parquet
7
+ val_parquet: val_metadata.parquet
8
+ test_parquet: val_metadata.parquet
9
+ max_audio_length_sec: 10.0
10
+ min_duration_sec: 10.0
11
+ target_sample_rate: 16000
12
+ collate_mode: pad
13
+ decode_window_sec: null
audio-embeddings/configs/experiment/audio_jepa/baseline.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: audioset
5
+ - override /model: audio_jepa
6
+ - override /trainer: gpu
7
+ - override /logger: wandb
8
+ - override /callbacks: default
9
+
10
+ # all parameters below will be merged with parameters from default configurations set above
11
+ # this allows you to overwrite only specified parameters
12
+
13
+ tags: ["audioset", "jepa", "baseline", "cluster"]
14
+
15
+ trainer:
16
+ max_steps: 200000
17
+
18
+ data:
19
+ data_dir: ${paths.data_dir}
20
+ train_h5: AudioSet/full_unbal_bal_train_wav.h5
21
+ val_h5: AudioSet/eval_soxrhq.h5
22
+ batch_size: 256
23
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
24
+
25
+ logger:
26
+ wandb:
27
+ offline: True
28
+ log_model: False
29
+
30
+ callbacks:
31
+ rich_progress_bar: null
32
+
33
+ wandb_offline_checkpoint:
34
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/audio_jepa/large.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: audioset
5
+ - override /model: audio_jepa
6
+ - override /trainer: gpu
7
+ - override /logger: wandb
8
+ - override /callbacks: default
9
+
10
+ # all parameters below will be merged with parameters from default configurations set above
11
+ # this allows you to overwrite only specified parameters
12
+
13
+ tags: ["audioset", "jepa", "large", "cluster", "1GPU"]
14
+
15
+ trainer:
16
+ max_steps: 200000
17
+
18
+ data:
19
+ data_dir: ${paths.data_dir}
20
+ train_h5: AudioSet/full_unbal_bal_train_wav.h5
21
+ val_h5: AudioSet/eval_soxrhq.h5
22
+ batch_size: 256
23
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
24
+
25
+ logger:
26
+ wandb:
27
+ offline: True
28
+ log_model: False
29
+
30
+ callbacks:
31
+ rich_progress_bar: null
32
+
33
+ wandb_offline_checkpoint:
34
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
35
+
36
+ # ViT Large Configuration
37
+ model:
38
+ net:
39
+ patch_embed:
40
+ embed_dim: 1024
41
+ encoder:
42
+ embed_dim: 1024
43
+ depth: 24
44
+ num_heads: 16
audio-embeddings/configs/experiment/audio_jepa/rope.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: audioset
5
+ - override /model: audio_jepa
6
+ - override /trainer: gpu
7
+ - override /logger: wandb
8
+ - override /callbacks: default
9
+
10
+ # all parameters below will be merged with parameters from default configurations set above
11
+ # this allows you to overwrite only specified parameters
12
+
13
+ tags: ["audioset", "jepa", "RoPE", "cluster"]
14
+
15
+ trainer:
16
+ max_steps: 200000
17
+
18
+ data:
19
+ data_dir: ${paths.data_dir}
20
+ train_h5: AudioSet/full_unbal_bal_train_wav.h5
21
+ val_h5: AudioSet/eval_soxrhq.h5
22
+ batch_size: 256
23
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
24
+
25
+ logger:
26
+ wandb:
27
+ offline: True
28
+ log_model: False
29
+
30
+ callbacks:
31
+ rich_progress_bar: null
32
+
33
+ wandb_offline_checkpoint:
34
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
35
+
36
+ model:
37
+ net:
38
+ encoder:
39
+ pos_embed_type: rope
40
+ predictor:
41
+ pos_embed_type: rope
audio-embeddings/configs/experiment/audio_jepa/time_res2x.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: audioset
5
+ - override /model: audio_jepa
6
+ - override /trainer: gpu
7
+ - override /logger: wandb
8
+ - override /callbacks: default
9
+
10
+ # all parameters below will be merged with parameters from default configurations set above
11
+ # this allows you to overwrite only specified parameters
12
+
13
+ tags: ["audioset", "jepa", "time_res2x", "cluster"]
14
+
15
+ trainer:
16
+ max_steps: 200000
17
+
18
+ data:
19
+ data_dir: ${paths.data_dir}
20
+ train_h5: AudioSet/full_unbal_bal_train_wav.h5
21
+ val_h5: AudioSet/eval_soxrhq.h5
22
+ batch_size: 256
23
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
24
+
25
+ logger:
26
+ wandb:
27
+ offline: True
28
+ log_model: False
29
+
30
+ callbacks:
31
+ rich_progress_bar: null
32
+
33
+ wandb_offline_checkpoint:
34
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
35
+
36
+ model:
37
+ net:
38
+ spectrogram:
39
+ win_length_ms: 64 # halved
40
+ hop_length_ms: 19.53125 # halved
41
+
42
+ patch_embed:
43
+ img_size: [128, 512] # width doubled because hop is halved
44
+
45
+ masking:
46
+ input_size: [128, 512]
47
+
48
+ encoder:
49
+ num_patches: 256 # (128/16) * (512/16) = 8 * 32 = 256
50
+ img_size: [128, 512] # Explicitly set img_size for ViT to match patch_embed
51
+
52
+ predictor:
53
+ num_patches: 256
54
+ img_size: [128, 512]
audio-embeddings/configs/experiment/audio_jepa/time_res4x.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: audioset
5
+ - override /model: audio_jepa
6
+ - override /trainer: gpu
7
+ - override /logger: wandb
8
+ - override /callbacks: default
9
+
10
+ # all parameters below will be merged with parameters from default configurations set above
11
+ # this allows you to overwrite only specified parameters
12
+
13
+ tags: ["audioset", "jepa", "time_res4x", "cluster"]
14
+
15
+ trainer:
16
+ max_steps: 200000
17
+
18
+ data:
19
+ data_dir: ${paths.data_dir}
20
+ train_h5: AudioSet/full_unbal_bal_train_wav.h5
21
+ val_h5: AudioSet/eval_soxrhq.h5
22
+ batch_size: 256
23
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
24
+
25
+ logger:
26
+ wandb:
27
+ offline: True
28
+ log_model: False
29
+
30
+ callbacks:
31
+ rich_progress_bar: null
32
+
33
+ wandb_offline_checkpoint:
34
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
35
+
36
+ model:
37
+ net:
38
+ spectrogram:
39
+ win_length_ms: 32 # quartered
40
+ hop_length_ms: 9.765625 # quartered
41
+
42
+ patch_embed:
43
+ img_size: [128, 1024] # width quadrupled because hop is quartered
44
+
45
+ masking:
46
+ input_size: [128, 1024]
47
+
48
+ encoder:
49
+ num_patches: 512 # (128/16) * (1024/16) = 8 * 64 = 512
50
+ img_size: [128, 1024]
51
+
52
+ predictor:
53
+ num_patches: 512
54
+ img_size: [128, 1024]
audio-embeddings/configs/experiment/best_rq/audioset.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq/audioset
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["audioset", "best-rq", "baseline", "cluster GPU"]
16
+
17
+ trainer:
18
+ max_steps: 200000
19
+
20
+ data:
21
+ batch_size: 256
22
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
23
+
24
+ logger:
25
+ wandb:
26
+ offline: True
27
+ log_model: False
28
+
29
+ callbacks:
30
+ rich_progress_bar: null
31
+
32
+ wandb_offline_checkpoint:
33
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq/yt1b.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: yt1b
5
+ - override /model: best_rq
6
+ - override /trainer: gpu
7
+ - override /logger: wandb
8
+ - override /callbacks: default
9
+
10
+ # all parameters below will be merged with parameters from default configurations set above
11
+ # this allows you to overwrite only specified parameters
12
+
13
+ tags: ["yt1b", "best-rq", "baseline", "cluster GPU"]
14
+
15
+ trainer:
16
+ max_steps: 200000
17
+
18
+ data:
19
+ batch_size: 256
20
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
21
+
22
+ logger:
23
+ wandb:
24
+ offline: True
25
+ log_model: False
26
+
27
+ callbacks:
28
+ rich_progress_bar: null
29
+
30
+ wandb_offline_checkpoint:
31
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq_2/audioset.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq2
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["audioset", "best-rq-2", "baseline", "cluster GPU"]
16
+
17
+ trainer:
18
+ max_steps: 200000
19
+
20
+ data:
21
+ batch_size: 256
22
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
23
+
24
+ logger:
25
+ wandb:
26
+ offline: True
27
+ log_model: False
28
+
29
+ callbacks:
30
+ rich_progress_bar: null
31
+
32
+ wandb_offline_checkpoint:
33
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq_2/audioset_100k_512bs.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset_100k_512bs
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq2
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ tags: ["audioset", "best-rq-2", "100k", "512bs", "cluster GPU"]
13
+
14
+ trainer:
15
+ max_steps: 100000
16
+
17
+ data:
18
+ batch_size: 512
19
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
20
+
21
+ logger:
22
+ wandb:
23
+ offline: True
24
+ log_model: False
25
+
26
+ callbacks:
27
+ rich_progress_bar: null
28
+
29
+ wandb_offline_checkpoint:
30
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq_2/audioset_1m_128bs_4gpu.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset_1m_128bs_4gpu
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq2
9
+ - override /trainer: ddp
10
+ - override /logger: wandb
11
+
12
+ tags: ["audioset", "best-rq-2", "1m", "128bs", "4GPU", "cluster GPU"]
13
+
14
+ trainer:
15
+ max_steps: 1000000
16
+
17
+ data:
18
+ batch_size: 128
19
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
20
+
21
+ logger:
22
+ wandb:
23
+ offline: True
24
+ log_model: False
25
+
26
+ callbacks:
27
+ rich_progress_bar: null
28
+
29
+ wandb_offline_checkpoint:
30
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq_2/audioset_200k_256bs_4gpu.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset_200k_256bs_4gpu
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq2
9
+ - override /trainer: ddp
10
+ - override /logger: wandb
11
+
12
+ tags: ["audioset", "best-rq-2", "200k", "256bs", "4GPU", "cluster GPU"]
13
+
14
+ trainer:
15
+ max_steps: 200000
16
+
17
+ data:
18
+ batch_size: 256
19
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
20
+
21
+ logger:
22
+ wandb:
23
+ offline: True
24
+ log_model: False
25
+
26
+ callbacks:
27
+ rich_progress_bar: null
28
+
29
+ wandb_offline_checkpoint:
30
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq_2/audioset_400k_128bs.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset_400k_128bs
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq2
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ tags: ["audioset", "best-rq-2", "400k", "128bs", "cluster GPU"]
13
+
14
+ trainer:
15
+ max_steps: 400000
16
+
17
+ data:
18
+ batch_size: 128
19
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
20
+
21
+ logger:
22
+ wandb:
23
+ offline: True
24
+ log_model: False
25
+
26
+ callbacks:
27
+ rich_progress_bar: null
28
+
29
+ wandb_offline_checkpoint:
30
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq_2/audioset_400k_128bs_4gpu.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset_400k_128bs_4gpu
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq2
9
+ - override /trainer: ddp
10
+ - override /logger: wandb
11
+
12
+ tags: ["audioset", "best-rq-2", "400k", "128bs", "4GPU", "cluster GPU"]
13
+
14
+ trainer:
15
+ max_steps: 400000
16
+
17
+ data:
18
+ batch_size: 128
19
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
20
+
21
+ logger:
22
+ wandb:
23
+ offline: True
24
+ log_model: False
25
+
26
+ callbacks:
27
+ rich_progress_bar: null
28
+
29
+ wandb_offline_checkpoint:
30
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq_2/audioset_800k_64bs_4gpu.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset_800k_64bs_4gpu
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq2
9
+ - override /trainer: ddp
10
+ - override /logger: wandb
11
+
12
+ tags: ["audioset", "best-rq-2", "800k", "64bs", "4GPU", "cluster GPU"]
13
+
14
+ trainer:
15
+ max_steps: 800000
16
+
17
+ data:
18
+ batch_size: 64
19
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
20
+
21
+ logger:
22
+ wandb:
23
+ offline: True
24
+ log_model: False
25
+
26
+ callbacks:
27
+ rich_progress_bar: null
28
+
29
+ wandb_offline_checkpoint:
30
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/best_rq_2/audioset_ema.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset_ema
5
+
6
+ defaults:
7
+ - /callbacks/ema_weight_averaging
8
+ - override /data: audioset
9
+ - override /model: best_rq2
10
+ - override /trainer: gpu
11
+ - override /logger: wandb
12
+
13
+ tags: ["audioset", "best-rq-2", "ema", "cluster GPU"]
14
+
15
+ trainer:
16
+ max_steps: 200000
17
+
18
+ data:
19
+ batch_size: 256
20
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
21
+
22
+ logger:
23
+ wandb:
24
+ offline: True
25
+ log_model: False
26
+
27
+ callbacks:
28
+ rich_progress_bar: null
29
+
30
+ wandb_offline_checkpoint:
31
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
32
+
33
+ model:
34
+ ema:
35
+ enabled: true
audio-embeddings/configs/experiment/best_rq_2/audioset_ema_600k.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/audioset_ema_600k
5
+
6
+ defaults:
7
+ - /callbacks/ema_weight_averaging
8
+ - override /data: audioset
9
+ - override /model: best_rq2
10
+ - override /trainer: gpu
11
+ - override /logger: wandb
12
+
13
+ tags: ["audioset", "best-rq-2", "ema", "600k", "cluster GPU"]
14
+
15
+ trainer:
16
+ max_steps: 600000
17
+
18
+ data:
19
+ batch_size: 256
20
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
21
+
22
+ logger:
23
+ wandb:
24
+ offline: True
25
+ log_model: False
26
+
27
+ callbacks:
28
+ rich_progress_bar: null
29
+
30
+ wandb_offline_checkpoint:
31
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
32
+
33
+ model:
34
+ ema:
35
+ enabled: true
36
+ decay_numerator: 40.0
audio-embeddings/configs/experiment/best_rq_2/yt1b_ema.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=best_rq_2/yt1b_ema
5
+
6
+ defaults:
7
+ - /callbacks/ema_weight_averaging
8
+ - override /data: yt1b
9
+ - override /model: best_rq2
10
+ - override /trainer: gpu
11
+ - override /logger: wandb
12
+
13
+ tags: ["yt1b", "best-rq-2", "ema", "cluster GPU"]
14
+
15
+ trainer:
16
+ max_steps: 600000
17
+
18
+ data:
19
+ batch_size: 256
20
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
21
+
22
+ logger:
23
+ wandb:
24
+ offline: True
25
+ log_model: False
26
+
27
+ callbacks:
28
+ rich_progress_bar: null
29
+
30
+ wandb_offline_checkpoint:
31
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
32
+
33
+ model:
34
+ warmup_pct: 0.03
35
+ ema:
36
+ enabled: true
37
+ decay_numerator: 40.0
audio-embeddings/configs/experiment/local/audio_jepa.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=audioset_baseline
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: audio_jepa
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["audioset", "jepa", "baseline", "local GPU"]
16
+
17
+ trainer:
18
+ devices: 1
19
+ max_steps: 200000
20
+
21
+ data:
22
+ batch_size: 128
23
+ num_workers: 16
24
+
25
+ logger:
26
+ wandb:
27
+ name: "local-jepa-audioset-baseline-200k-128x1bs"
28
+ offline: False
29
+ log_model: True
audio-embeddings/configs/experiment/local/audio_jepa_rope.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=audioset_rope
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: audio_jepa
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["audioset", "jepa", "rope", "local GPU"]
16
+
17
+ trainer:
18
+ devices: 1
19
+ max_steps: 200000
20
+
21
+ data:
22
+ batch_size: 128
23
+ num_workers: 16
24
+
25
+ logger:
26
+ wandb:
27
+ name: "local-jepa-audioset-rope-200k-128x1bs"
28
+ offline: False
29
+ log_model: True
30
+
31
+ model:
32
+ net:
33
+ encoder:
34
+ pos_embed_type: rope
35
+ predictor:
36
+ pos_embed_type: rope
audio-embeddings/configs/experiment/local/best_rq.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=local/best_rq
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["audioset", "best-rq", "baseline", "local GPU"]
16
+
17
+ trainer:
18
+ devices: 1
19
+ max_steps: 200000
20
+
21
+ data:
22
+ batch_size: 128
23
+ num_workers: 16
24
+
25
+ logger:
26
+ wandb:
27
+ name: "local-best-rq-vit-audioset-200k-128x1bs"
28
+ offline: False
29
+ log_model: True
audio-embeddings/configs/experiment/local/best_rq2.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=local/best_rq2
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: best_rq2
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["audioset", "best-rq-2", "baseline", "local GPU"]
16
+
17
+ trainer:
18
+ devices: 1
19
+ max_steps: 100000
20
+
21
+ data:
22
+ batch_size: 128
23
+ num_workers: 16
24
+
25
+ logger:
26
+ wandb:
27
+ name: "cluster-best-rq-2-audioset-100k-128bs"
28
+ offline: True
29
+ log_model: False
30
+
31
+ callbacks:
32
+ # model_checkpoint:
33
+ # save_weights_only: True
34
+
35
+ rich_progress_bar: null
36
+
37
+ wandb_offline_checkpoint:
38
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/local/m4_mock_jepa.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=local/m4_mock_jepa
5
+
6
+ defaults:
7
+ - override /data: mock_audioset
8
+ - override /model: audio_jepa
9
+ - override /trainer: mps
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["mock", "jepa", "local", "mps"]
16
+
17
+ trainer:
18
+ # devices: 1 # set in trainer/mps.yaml
19
+ max_steps: 100
20
+ log_every_n_steps: 1
21
+ val_check_interval: 50
22
+ limit_val_batches: 5
23
+
24
+ data:
25
+ batch_size: 8
26
+
27
+ logger:
28
+ wandb:
29
+ name: "local-m4-mock-jepa"
30
+ offline: True
31
+ log_model: False
32
+
33
+ callbacks:
34
+ model_checkpoint:
35
+ save_weights_only: True
36
+ monitor: null # verify we can save without monitoring
37
+ device_stats: null
audio-embeddings/configs/experiment/local/rqa_jepa.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=local_rqa_jepa_audioset
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: rqa_jepa
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["audioset", "rqa-jepa", "baseline", "local GPU"]
16
+
17
+ trainer:
18
+ devices: 1
19
+ max_steps: 200000
20
+
21
+ data:
22
+ batch_size: 128
23
+ num_workers: 16
24
+
25
+ logger:
26
+ wandb:
27
+ name: "local-rqa-jepa-audioset-200k-128x1bs"
28
+ offline: False
29
+ log_model: True
audio-embeddings/configs/experiment/rqa_jepa/audioset.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=cluster_rqa_jepa_audioset
5
+
6
+ defaults:
7
+ - override /data: audioset
8
+ - override /model: rqa_jepa
9
+ - override /trainer: gpu
10
+ - override /logger: wandb
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["audioset", "rqa-jepa", "baseline", "cluster GPU"]
16
+
17
+ trainer:
18
+ max_steps: 200000
19
+
20
+ data:
21
+ batch_size: 256
22
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
23
+
24
+ logger:
25
+ wandb:
26
+ offline: True
27
+ log_model: False
28
+
29
+ callbacks:
30
+ rich_progress_bar: null
31
+
32
+ wandb_offline_checkpoint:
33
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback
audio-embeddings/configs/experiment/rqa_jepa/yt1b.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: yt1b
5
+ - override /model: rqa_jepa
6
+ - override /trainer: gpu
7
+ - override /logger: wandb
8
+ - override /callbacks: default
9
+
10
+ # all parameters below will be merged with parameters from default configurations set above
11
+ # this allows you to overwrite only specified parameters
12
+
13
+ tags: ["yt1b", "rqa-jepa", "baseline", "cluster GPU"]
14
+
15
+ trainer:
16
+ max_steps: 200000
17
+
18
+ data:
19
+ batch_size: 256
20
+ num_workers: ${oc.decode:${oc.env:SLURM_CPUS_PER_TASK}}
21
+
22
+ logger:
23
+ wandb:
24
+ offline: True
25
+ log_model: False
26
+
27
+ callbacks:
28
+ rich_progress_bar: null
29
+
30
+ wandb_offline_checkpoint:
31
+ _target_: src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback