# Audio Embeddings with Lightning & Hydra

This project is a clean, modular, and scalable implementation of audio embedding models using **PyTorch Lightning** and **Hydra**. It is designed to be easily extensible and runnable on local or cluster environments.

It is based on the [Audio-JEPA](https://github.com/LudovicTuncay/Audio-JEPA) implementation and therefore implements the Audio-JEPA architecture. Other architectures can and will be added in the future.

## 🎯 Goal

The goal of this project is to provide a robust codebase for training and experimenting with audio embedding models. Key features include:

- **Modular Architecture**: Components like Spectrogram, Masking, and ViT are decoupled.
- **Configurable Positional Embeddings**: Support for **RoPE** (2D Rotary Embeddings), **SinCos** (2D Sinusoidal), and **Learnable** embeddings.
- **Hydra Configuration**: Flexible experiment management via hierarchical config files.
- **Lightning Trainer**: Simplified training loop, logging, and checkpointing.
- **Modern Tooling**: Uses `uv` for fast and reliable dependency management.

## 🚀 Installation

This project uses [`uv`](https://github.com/astral-sh/uv) for dependency management.

1. **Install `uv`** (if not already installed):

   ```bash
   curl -LsSf https://astral.sh/uv/install.sh | sh
   ```

2. **Clone the repository**:

   ```bash
   git clone <repository-url>
   cd audio-embeddings
   ```

3. **Install dependencies**:

   ```bash
   uv sync
   ```

4. **Enable shared git hooks** (runs `uv sync` after merge/checkout/rewrite):

   ```bash
   git config core.hooksPath .githooks
   ```

## 🏃 Usage

### Basic Training

To start training with the default configuration:

```bash
uv run src/train.py
```

### Common Commands

Run on GPU with Weights & Biases logging:

```bash
uv run src/train.py trainer=gpu logger=wandb
```

Override hyperparameters on the command line:

```bash
uv run src/train.py data.batch_size=64 trainer.max_epochs=50
```

### Configurable Positional Embeddings

You can switch between different positional embedding strategies easily:

**RoPE**:

```bash
uv run src/train.py model.net.encoder.pos_embed_type=rope
```

**2D SinCos**:

```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=sincos ++model.net.predictor.pos_embed_type=sincos
```

**Learnable**:

```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=learnable ++model.net.predictor.pos_embed_type=learnable
```

### Offline WandB Logging with Model Checkpoints

To run training offline while still staging model checkpoints for later upload (which standard WandB offline logging does not do):

```bash
uv run src/train.py \
    logger=wandb \
    logger.wandb.offline=True \
    logger.wandb.log_model=False \
    +callbacks.wandb_offline_checkpoint._target_=src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback \
    trainer=gpu trainer.devices=1 \
    data.batch_size=128 trainer.max_epochs=100
```

These checkpoints will be uploaded when you run `wandb sync`.

## 📂 Project Structure

```text
├── configs/                        # Hydra configuration files
│   ├── callbacks/                  # Callback configs (checkpoints, early stopping)
│   ├── data/                       # Data configs (AudioSet, etc.)
│   ├── logger/                     # Logger configs (WandB, TensorBoard)
│   ├── model/                      # Model configs (AudioJEPA parameters)
│   ├── trainer/                    # Trainer configs (CPU, GPU, strategies)
│   └── train.yaml                  # Main configuration entry point
├── src/
│   ├── data/                       # Data loading logic
│   │   └── audioset_datamodule.py  # AudioSet DataModule & Dataset
│   ├── models/                     # Model architectures
│   │   ├── components/             # Reusable blocks
│   │   │   ├── masking.py          # Masking generators
│   │   │   ├── patch_embed.py      # Patchification
│   │   │   ├── rope.py             # 2D Rotary Embeddings
│   │   │   ├── spectrogram.py      # Audio preprocessing
│   │   │   └── vit.py              # Vision Transformer (Student/Teacher/Predictor)
│   │   └── audio_jepa_module.py    # Main LightningModule
│   ├── utils/                      # Utility functions
│   └── train.py                    # Training entry point
├── scripts/                        # Helper scripts
├── tests/                          # Verification tests
├── pyproject.toml                  # Project dependencies
└── README.md                       # This file
```

## 🛠️ Extensibility

### Adding a New Model

1. Create your model components in `src/models/components/`.
2. Create a new LightningModule in `src/models/` (or update `AudioJEPAModule`).
3. Create a new config file in `configs/model/my_new_model.yaml`.
4. Run with `uv run src/train.py model=my_new_model`.

### Adding a New Dataset

1. Create a new DataModule in `src/data/` (see the sketch after this list).
2. Create a new config file in `configs/data/my_dataset.yaml`.
3. Run with `uv run src/train.py data=my_dataset`.
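As a rough starting point, a new DataModule could follow the sketch below. This is a minimal illustration, not code from this repository: the class names (`MyAudioDataset`, `MyDataModule`), the `lightning` import style, and the constructor arguments are assumptions to adapt to your dataset and to the existing `audioset_datamodule.py`.

```python
# Illustrative LightningDataModule skeleton (hypothetical names throughout).
import torch
from torch.utils.data import DataLoader, Dataset
import lightning as L


class MyAudioDataset(Dataset):
    """Placeholder dataset; replace with real audio loading and preprocessing."""

    def __init__(self, split: str = "train", num_clips: int = 8, clip_len: int = 16000):
        self.split = split
        self.num_clips = num_clips
        self.clip_len = clip_len

    def __len__(self):
        return self.num_clips

    def __getitem__(self, idx):
        # Dummy waveform; a real dataset would load audio from disk here.
        return torch.zeros(self.clip_len)


class MyDataModule(L.LightningDataModule):
    def __init__(self, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.save_hyperparameters()

    def setup(self, stage=None):
        self.train_set = MyAudioDataset("train")
        self.val_set = MyAudioDataset("val")

    def train_dataloader(self):
        return DataLoader(
            self.train_set,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_set,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
        )
```

The matching `configs/data/my_dataset.yaml` would then point Hydra at the class via a `_target_` entry (e.g. `_target_: src.data.my_datamodule.MyDataModule`, a hypothetical module path) plus the constructor arguments, mirroring how Hydra instantiation is configured elsewhere in this project.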
### Adding Functionalities

- **Callbacks**: Add custom callbacks in `src/callbacks/` (if needed) or use existing Lightning callbacks, and configure them in `configs/callbacks/`.
- **Metrics**: Add metric logging in `training_step` or `validation_step` inside `src/models/audio_jepa_module.py`.

## 🧪 Testing

Run the verification scripts to ensure components are working:

```bash
uv run tests/verify_rope.py
uv run tests/verify_custom_rope.py
```
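The scripts above are the project's own checks. For context, the snippet below is a small, self-contained property check in the same spirit: it verifies that a simplified 1D rotary embedding preserves vector norms, the key property RoPE relies on. It is not a reproduction of `tests/verify_rope.py`; the helper `apply_rope_1d` and its frequency schedule are assumptions made for illustration.

```python
# Illustrative property check only; not part of this repository's test suite.
import torch


def apply_rope_1d(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive feature pairs of `x` by position-dependent angles (simplified 1D RoPE)."""
    d = x.shape[-1]
    assert d % 2 == 0, "feature dimension must be even"
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = pos[:, None] * inv_freq[None, :]  # (seq_len, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return rotated.flatten(-2)  # interleave pairs back to (seq_len, d)


if __name__ == "__main__":
    x = torch.randn(16, 64)
    pos = torch.arange(16, dtype=torch.float32)
    y = apply_rope_1d(x, pos)
    # Rotations are orthogonal, so token norms (and hence attention scale) are unchanged.
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-4)
    print("norm-preservation check passed")
```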