import os
import sys
from unittest.mock import MagicMock

import torch

# Make the project root importable when this file is run directly,
# before the local `src` package is imported below.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from src.models.audio_jepa_module import AudioJEPAModule
|
def test_scheduler():
    """Verify the LR schedule produced by AudioJEPAModule.configure_optimizers.

    Builds the module with a mocked optimizer factory and trainer, swaps in a
    real SGD optimizer (base lr=1.0) so scheduler LRs are easy to interpret,
    then walks the full schedule and asserts the LR starts small, rises
    through warmup, and decays afterwards.
    """
    # Mock optimizer factory so the module constructor accepts it; it is
    # replaced with a real-optimizer factory below.
    optimizer_cls = MagicMock()
    optimizer_instance = MagicMock()
    optimizer_cls.return_value = optimizer_instance

    # Minimal net config — presumably only the keys the module constructor
    # reads; TODO confirm against AudioJEPAModule.__init__.
    net_config = {
        "spectrogram": {},
        "patch_embed": {},
        "masking": {},
        "encoder": {"embed_dim": 768},
        "predictor": {"embed_dim": 768},
    }

    module = AudioJEPAModule(
        optimizer=optimizer_cls, net=net_config, warmup_pct=0.1, final_lr_ratio=0.001
    )

    # Fake trainer so configure_optimizers can size the schedule.
    total_steps = 1000
    module.trainer = MagicMock()
    module.trainer.max_steps = total_steps
    module.trainer.estimated_stepping_batches = total_steps

    # Real optimizer with base lr=1.0: scheduler LR values are then direct
    # multiples of the schedule's scaling factors.
    real_optimizer = torch.optim.SGD([torch.nn.Parameter(torch.randn(1))], lr=1.0)
    module.hparams.optimizer = lambda params: real_optimizer

    optim_conf = module.configure_optimizers()
    scheduler = optim_conf["lr_scheduler"]["scheduler"]

    # Record the LR *before* each step so lrs[i] is the LR in effect at
    # step i (the original recorded after stepping, shifting every label by
    # one). Stepping total_steps - 1 times also avoids overrunning
    # fixed-length schedules such as OneCycleLR.
    lrs = [scheduler.get_last_lr()[0]]
    for _ in range(total_steps - 1):
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])

    warmup_steps = int(total_steps * 0.1)  # warmup_pct=0.1 -> 100 steps

    print(f"LR at step 0: {lrs[0]}")
    print(f"LR at step {warmup_steps}: {lrs[warmup_steps]}")
    print(f"LR at step {total_steps - 1}: {lrs[-1]}")

    # Warmup starts small...
    assert lrs[0] < 0.1, f"LR at step 0 should be small, got {lrs[0]}"
    # ...rises through the warmup phase...
    assert lrs[warmup_steps] > lrs[0], (
        f"LR should increase during warmup: step 0 = {lrs[0]}, "
        f"step {warmup_steps} = {lrs[warmup_steps]}"
    )
    # ...and decays after warmup toward the final LR.
    assert lrs[-1] < lrs[warmup_steps], (
        f"LR should decay after warmup: step {warmup_steps} = "
        f"{lrs[warmup_steps]}, final = {lrs[-1]}"
    )

    print("Verification successful!")
|
|
|
|
# Allow running this test file directly as a script (outside pytest).
if __name__ == "__main__":
    test_scheduler()