# Audio Embeddings with Lightning & Hydra
This project is a clean, modular, and scalable implementation of audio embedding models using PyTorch Lightning and Hydra. It is designed to be easily extensible and to run on local or cluster environments. It is based on the Audio-JEPA implementation and therefore implements the Audio-JEPA architecture; other architectures will be added in the future.
## 🎯 Goal

The goal of this project is to provide a robust codebase for training and experimenting with audio embedding models. Key features include:

- Modular Architecture: components such as the spectrogram frontend, masking, and the ViT are decoupled.
- Configurable Positional Embeddings: support for RoPE (2D rotary), SinCos (2D sinusoidal), and learnable embeddings.
- Hydra Configuration: flexible experiment management via hierarchical config files.
- Lightning Trainer: simplified training loop, logging, and checkpointing.
- Modern Tooling: uses `uv` for fast and reliable dependency management.
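Of the positional embedding variants above, the 2D sinusoidal scheme is simple enough to sketch in a few lines. This is a minimal pure-Python illustration of the idea (concatenate 1D sinusoidal embeddings of the row and column index), not the repo's actual implementation; all function names are illustrative.

```python
import math

def sincos_1d(positions, dim):
    """Standard 1D sinusoidal embeddings: half sines, half cosines."""
    emb = []
    for pos in positions:
        row = []
        for i in range(dim // 2):
            freq = 1.0 / (10000 ** (2 * i / dim))
            row.append(math.sin(pos * freq))
        for i in range(dim // 2):
            freq = 1.0 / (10000 ** (2 * i / dim))
            row.append(math.cos(pos * freq))
        emb.append(row)
    return emb

def sincos_2d(grid_h, grid_w, dim):
    """2D variant: concatenate 1D embeddings of the row (e.g. frequency)
    and column (e.g. time) index. Returns (grid_h * grid_w) rows of length dim."""
    assert dim % 4 == 0, "dim must split evenly across the two axes"
    h_emb = sincos_1d(range(grid_h), dim // 2)
    w_emb = sincos_1d(range(grid_w), dim // 2)
    table = []
    for r in range(grid_h):
        for c in range(grid_w):
            table.append(h_emb[r] + w_emb[c])  # list concat -> length dim
    return table

pos = sincos_2d(4, 8, 64)
print(len(pos), len(pos[0]))  # one embedding row per patch: 32 64
```

Because the table depends only on the grid shape, it can be precomputed once and added to the patch embeddings without any learned parameters.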
## 🚀 Installation

This project uses `uv` for dependency management.

1. Install `uv` (if not already installed):

   ```bash
   curl -LsSf https://astral.sh/uv/install.sh | sh
   ```

2. Clone the repository:

   ```bash
   git clone <repository_url>
   cd audio-embeddings
   ```

3. Install dependencies:

   ```bash
   uv sync
   ```

4. Enable shared git hooks (runs `uv sync` after merge/checkout/rewrite):

   ```bash
   git config core.hooksPath .githooks
   ```
## 📖 Usage

### Basic Training

To start training with the default configuration:

```bash
uv run src/train.py
```

### Common Commands

Run on GPU with Weights & Biases logging:

```bash
uv run src/train.py trainer=gpu logger=wandb
```

Override hyperparameters on the command line:

```bash
uv run src/train.py data.batch_size=64 trainer.max_epochs=50
```
### Configurable Positional Embeddings

You can switch between positional embedding strategies easily.

RoPE:

```bash
uv run src/train.py model.net.encoder.pos_embed_type=rope
```

2D SinCos:

```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=sincos ++model.net.predictor.pos_embed_type=sincos
```

Learnable:

```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=learnable ++model.net.predictor.pos_embed_type=learnable
```

### Offline WandB Logging with Model Checkpoints

To run training offline while still staging model checkpoints for upload (which standard WandB restricts):

```bash
uv run src/train.py \
  logger=wandb \
  logger.wandb.offline=True \
  logger.wandb.log_model=False \
  +callbacks.wandb_offline_checkpoint._target_=src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback \
  trainer=gpu trainer.devices=1 \
  data.batch_size=128 trainer.max_epochs=100
```

These checkpoints will be uploaded when you run `wandb sync`.
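For intuition, the rotation at the heart of RoPE can be sketched in pure Python. This illustrates the 1D mechanism only (a 2D variant typically applies separate rotations along the time and frequency axes); the function name and shapes are illustrative, not the repo's API.

```python
import math

def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive pairs (x0,x1), (x2,x3), ... of a query/key vector
    by position-dependent angles. Relative offsets then fall out of the
    dot product between rotated queries and keys."""
    dim = len(vec)
    assert dim % 2 == 0
    out = []
    for i in range(0, dim, 2):
        theta = pos * base ** (-i / dim)   # lower pairs spin faster
        x, y = vec[i], vec[i + 1]
        out.append(x * math.cos(theta) - y * math.sin(theta))
        out.append(x * math.sin(theta) + y * math.cos(theta))
    return out

q = [1.0, 0.0, 0.5, -0.5]
q5 = rope_rotate(q, pos=5)
# Rotation is norm-preserving, so attention scores stay well scaled.
print(sum(v * v for v in q5))  # same squared norm as q: 1.5
```

Because the encoding is applied inside attention rather than added to the input, it injects relative position information at every layer.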
## 📁 Project Structure

```
├── configs/                       # Hydra configuration files
│   ├── callbacks/                 # Callback configs (checkpoints, early stopping)
│   ├── data/                      # Data configs (AudioSet, etc.)
│   ├── logger/                    # Logger configs (WandB, TensorBoard)
│   ├── model/                     # Model configs (AudioJEPA parameters)
│   ├── trainer/                   # Trainer configs (CPU, GPU, strategies)
│   └── train.yaml                 # Main configuration entry point
├── src/
│   ├── data/                      # Data loading logic
│   │   └── audioset_datamodule.py # AudioSet DataModule & Dataset
│   ├── models/                    # Model architectures
│   │   ├── components/            # Reusable blocks
│   │   │   ├── masking.py         # Masking generators
│   │   │   ├── patch_embed.py     # Patchification
│   │   │   ├── rope.py            # 2D Rotary Embeddings
│   │   │   ├── spectrogram.py     # Audio preprocessing
│   │   │   └── vit.py             # Vision Transformer (Student/Teacher/Predictor)
│   │   └── audio_jepa_module.py   # Main LightningModule
│   ├── utils/                     # Utility functions
│   └── train.py                   # Training entry point
├── scripts/                       # Helper scripts
├── tests/                         # Verification tests
├── pyproject.toml                 # Project dependencies
└── README.md                      # This file
```
## 🛠️ Extensibility

### Adding a New Model

- Create your model components in `src/models/components/`.
- Create a new LightningModule in `src/models/` (or update `AudioJEPAModule`).
- Create a new config file in `configs/model/my_new_model.yaml`.
- Run with `uv run src/train.py model=my_new_model`.
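A model config typically instantiates the LightningModule via Hydra's `_target_` mechanism. The sketch below is illustrative only: the module paths, class names, and fields are hypothetical and must match your own constructor signature.

```yaml
# configs/model/my_new_model.yaml -- illustrative sketch
_target_: src.models.my_new_model_module.MyNewModelModule

optimizer:
  _target_: torch.optim.AdamW
  _partial_: true            # Hydra returns a partial; the module binds params
  lr: 1.0e-4

net:
  _target_: src.models.components.my_net.MyNet
  embed_dim: 768
  depth: 12
```

Selecting `model=my_new_model` on the command line swaps this whole group in without touching any other config.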
### Adding a New Dataset

- Create a new DataModule in `src/data/`.
- Create a new config file in `configs/data/my_dataset.yaml`.
- Run with `uv run src/train.py data=my_dataset`.
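As with models, the data config targets your DataModule class; its keys must match the `__init__` parameters. Everything below (paths, class names, defaults) is a hypothetical sketch, not the repo's actual config.

```yaml
# configs/data/my_dataset.yaml -- illustrative sketch
_target_: src.data.my_dataset_datamodule.MyDatasetDataModule
data_dir: data/my_dataset   # adjust, or interpolate from a shared paths config
batch_size: 64
num_workers: 4
```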
### Adding Functionalities

- Callbacks: add custom callbacks in `src/callbacks/` (if needed) or use existing Lightning callbacks, and configure them in `configs/callbacks/`.
- Metrics: add metric logging in `training_step` or `validation_step` inside `src/models/audio_jepa_module.py`.
## 🧪 Testing

Run verification scripts to ensure components are working:

```bash
uv run tests/verify_rope.py
uv run tests/verify_custom_rope.py
```