import torch
import sys
import os
from unittest.mock import MagicMock

# Add src to path so the module under test is importable
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from src.models.audio_jepa_module import AudioJEPAModule
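
# A minimal sketch of the warmup + cosine-decay schedule that the checks below
# assume AudioJEPAModule builds from warmup_pct and final_lr_ratio. It is
# reconstructed from the reasoning in the comments, not imported from the
# module, so treat it as an assumption rather than the module's actual code.
import math


def expected_lr(step, total_steps=1000, warmup_pct=0.1, final_lr_ratio=0.001):
    warmup_steps = int(warmup_pct * total_steps)
    if step < warmup_steps:
        return step / warmup_steps  # linear warmup from 0 toward the base LR
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine_part = 0.5 * (1 + math.cos(math.pi * progress))
    return final_lr_ratio + (1 - final_lr_ratio) * cosine_part

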
def test_scheduler():
    # Mock dependencies
    optimizer_cls = MagicMock()
    optimizer_instance = MagicMock()
    optimizer_cls.return_value = optimizer_instance

    # Mock net config
    net_config = {
        "spectrogram": {},
        "patch_embed": {},
        "masking": {},
        "encoder": {"embed_dim": 768},
        "predictor": {"embed_dim": 768},
    }

    # Instantiate module
    module = AudioJEPAModule(
        optimizer=optimizer_cls, net=net_config, warmup_pct=0.1, final_lr_ratio=0.001
    )

    # Mock trainer so the module can read its step budget
    module.trainer = MagicMock()
    module.trainer.max_steps = 1000
    module.trainer.estimated_stepping_batches = 1000
    # Call configure_optimizers. The scheduler can only step a real torch
    # optimizer, so swap the mocked factory for one returning a real SGD
    real_optimizer = torch.optim.SGD([torch.nn.Parameter(torch.randn(1))], lr=1.0)
    module.hparams.optimizer = lambda params: real_optimizer
    optim_conf = module.configure_optimizers()
    scheduler = optim_conf["lr_scheduler"]["scheduler"]
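
    # Sanity-check the Lightning return format. The nested "lr_scheduler" dict
    # is standard Lightning; the interval check assumes the module schedules
    # per step rather than per epoch, so relax it if the module is configured
    # differently.
    assert "optimizer" in optim_conf
    assert optim_conf["lr_scheduler"].get("interval") == "step", (
        "a per-step warmup schedule should set interval='step'"
    )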

    # Step the scheduler across the full schedule and record the LR after
    # each step
    lrs = []
    for _ in range(1000):
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])

    # Verify key points of the schedule
    warmup_steps = 100  # 0.1 * 1000
    print(f"LR at step 0: {lrs[0]}")
    print(f"LR at step {warmup_steps}: {lrs[warmup_steps]}")
    print(f"LR at step 999: {lrs[999]}")

    # Check warmup: at step 50 (halfway through warmup) the LR should be ~0.5,
    # and right after the first step it should still be near zero.
    assert lrs[0] < 0.1, f"LR at step 0 should be small, got {lrs[0]}"

    # Indexing note: LambdaLR evaluates the lambda at `last_epoch`, which is
    # advanced from -1 to 0 during construction and incremented by each later
    # step(). The loop above therefore records lrs[k] = lambda(k + 1), one
    # ahead of the raw list index.
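
    # Quick self-contained demo of that stepping behavior (plain LambdaLR,
    # nothing specific to AudioJEPAModule):
    demo_opt = torch.optim.SGD([torch.nn.Parameter(torch.randn(1))], lr=1.0)
    demo_sched = torch.optim.lr_scheduler.LambdaLR(demo_opt, lr_lambda=lambda s: float(s))
    assert demo_sched.get_last_lr()[0] == 0.0  # construction evaluates lambda(0)
    demo_sched.step()
    assert demo_sched.get_last_lr()[0] == 1.0  # first explicit step() sees lambda(1)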

    # Peak: when the lambda receives step 100, warmup is over (100 < 100 is
    # False), progress = (100 - 100) / 900 = 0, cosine_part = 0.5 * (1 + 1) = 1,
    # and lr = final + (1 - final) * 1 = 1.0. Given the indexing note above,
    # the peak lands at lrs[warmup_steps - 1]. (In Lightning, scheduler.step()
    # runs after optimizer.step() on every training step, matching this loop.)

    # Decay: when the lambda receives step 550 (midway through decay),
    # progress = 450 / 900 = 0.5, cos(pi * 0.5) = 0, cosine_part = 0.5, and
    # lr = final + (1 - final) * 0.5 ~ 0.5.
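
    # Numeric checks against the expected_lr sketch defined at the top. These
    # assume the module's schedule matches that sketch; the loose tolerance
    # absorbs minor implementation differences, and the checks should be
    # relaxed if the module uses a different schedule.
    for k in (50, 100, 550, 999):
        assert abs(lrs[k - 1] - expected_lr(k)) < 0.05, (
            f"LR mismatch at lambda step {k}: got {lrs[k - 1]}, "
            f"expected ~{expected_lr(k):.4f}"
        )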
print("Verification successful!")
if __name__ == "__main__":
    test_scheduler()