import torch
import sys
import os

# Make the src package importable when running this file directly.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from src.models.audio_jepa_module import AudioJEPAModule
from unittest.mock import MagicMock


def test_scheduler():
    # The optimizer class passed to the constructor is a mock; it only needs
    # to satisfy the module's signature. The real optimizer that drives the
    # scheduler is swapped in via hparams below.
    optimizer_cls = MagicMock()

    # Minimal net config with just the keys the module requires.
    net_config = {
        "spectrogram": {},
        "patch_embed": {},
        "masking": {},
        "encoder": {"embed_dim": 768},
        "predictor": {"embed_dim": 768},
    }

    module = AudioJEPAModule(
        optimizer=optimizer_cls,
        net=net_config,
        warmup_pct=0.1,
        final_lr_ratio=0.001,
    )

    # Mock the trainer so the module can compute its total step count.
    module.trainer = MagicMock()
    module.trainer.max_steps = 1000
    module.trainer.estimated_stepping_batches = 1000

    # The scheduler needs a real optimizer to step, so replace the mocked
    # optimizer factory with one that returns a real SGD instance. The base
    # lr is 1.0 so the recorded LRs equal the lambda's multiplier directly.
    real_optimizer = torch.optim.SGD([torch.nn.Parameter(torch.randn(1))], lr=1.0)
    module.hparams.optimizer = lambda params: real_optimizer

    optim_conf = module.configure_optimizers()
    scheduler = optim_conf["lr_scheduler"]["scheduler"]

    # Simulate 1000 training steps, recording the LR after each scheduler
    # step (Lightning calls scheduler.step() once per optimizer step).
    lrs = []
    for _ in range(1000):
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])

    warmup_steps = 100  # warmup_pct * max_steps = 0.1 * 1000
    print(f"LR at step 0: {lrs[0]}")
    print(f"LR at step {warmup_steps}: {lrs[warmup_steps]}")
    print(f"LR at step 999: {lrs[999]}")

    # Indexing note: LambdaLR's `last_epoch` starts at -1, but the constructor
    # performs one implicit step(), so it is already 0 before the loop above.
    # The first explicit step() advances it to 1, hence lrs[i] == lambda(i + 1).

    # Warmup: lambda(step) = step / warmup_steps, so lrs[0] == lambda(1) == 0.01
    # and the halfway point lrs[49] == lambda(50) == 0.5.
    assert lrs[0] < 0.1, f"LR at step 0 should be small, got {lrs[0]}"
    assert abs(lrs[49] - 0.5) < 0.02, f"LR halfway through warmup should be ~0.5, got {lrs[49]}"

    # Peak: at lambda(warmup_steps) the cosine phase begins with progress = 0,
    # so cosine_part = 0.5 * (1 + cos(0)) = 1 and lr = final + (1 - final) * 1
    # = 1.0. Checking max(lrs) keeps the assertion robust to the off-by-one
    # described above.
    assert abs(max(lrs) - 1.0) < 1e-3, f"Peak LR should be ~1.0, got {max(lrs)}"

    # Decay: midway through the cosine phase, lrs[549] == lambda(550), where
    # progress = 450 / 900 = 0.5, cosine_part = 0.5 * (1 + cos(pi * 0.5)) = 0.5,
    # lr = final + (1 - final) * 0.5 ~= 0.5005.
    assert abs(lrs[549] - 0.5005) < 0.01, f"LR midway through decay should be ~0.5, got {lrs[549]}"

    # End: at lambda(1000), progress = 1, cosine_part = 0, lr == final_lr_ratio.
    assert lrs[999] < 0.01, f"Final LR should approach final_lr_ratio, got {lrs[999]}"

    print("Verification successful!")


if __name__ == "__main__":
    test_scheduler()
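

# ---------------------------------------------------------------------------
# Hedged sketch: the assertions in test_scheduler assume AudioJEPAModule
# builds a LambdaLR with a linear-warmup / cosine-decay multiplier. The
# function below is a *hypothetical* reconstruction of that lambda, derived
# from the expected values worked out in the comments above rather than from
# the module itself; it exists only to make LambdaLR's indexing behaviour
# (the implicit constructor step) easy to verify in isolation.
def _demo_warmup_cosine_lambda(
    total_steps: int = 1000, warmup_pct: float = 0.1, final_lr_ratio: float = 0.001
) -> None:
    import math  # local import keeps this appendix self-contained

    warmup_steps = int(warmup_pct * total_steps)

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            # Linear warmup: 0 -> 1 over warmup_steps.
            return step / warmup_steps
        # Cosine decay: 1 -> final_lr_ratio over the remaining steps.
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        cosine_part = 0.5 * (1 + math.cos(math.pi * progress))
        return final_lr_ratio + (1 - final_lr_ratio) * cosine_part

    opt = torch.optim.SGD([torch.nn.Parameter(torch.randn(1))], lr=1.0)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    # LambdaLR's constructor already performed one implicit step(), so the
    # scheduler starts at lr_lambda(0) == 0 before any explicit stepping.
    assert sched.get_last_lr()[0] == 0.0
    sched.step()  # last_epoch is now 1, matching lrs[0] == lambda(1) above
    assert abs(sched.get_last_lr()[0] - 1 / warmup_steps) < 1e-9
    print("Lambda demo passed.")


if __name__ == "__main__":
    _demo_warmup_cosine_lambda()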