| import torch |
| import torch.nn as nn |
| import sys |
| import os |
|
|
| |
| sys.path.append(os.getcwd()) |
|
|
| from src.models.rqa_jepa_module import RQAJEPAModule |
|
|
|
|
def _make_net_config():
    """Build a fresh, small RQA-JEPA ``net`` config for the smoke test.

    A new dict is returned on every call so each model instance gets its own
    copy, in case the module mutates its config (not visible from here).
    The masking grid is derived from ``img_size // patch_size`` so it cannot
    drift out of sync with the patch embedding: (1024, 128) / (16, 16) -> (64, 8).
    """
    img_size = (1024, 128)
    patch_size = (16, 16)
    patch_grid = tuple(dim // patch for dim, patch in zip(img_size, patch_size))
    return {
        "spectrogram": {},
        "patch_embed": {
            "img_size": img_size,
            "patch_size": patch_size,
            "in_chans": 1,
            "embed_dim": 768,
        },
        "masking": {
            "input_size": patch_grid,
            "mask_ratio": (0.4, 0.6),
        },
        "encoder": {
            "img_size": img_size,
            "patch_size": patch_size,
            "embed_dim": 768,
            "depth": 2,
            "num_heads": 4,
        },
        "predictor": {
            "img_size": img_size,
            "patch_size": patch_size,
            "embed_dim": 384,
            "depth": 1,
            "num_heads": 4,
        },
    }


def _optimizer_partial(params):
    """Optimizer factory handed to the module: ``torch.optim.Adam`` with default settings."""
    return torch.optim.Adam(params)


def _check_training_step(rq_input_type, waveform):
    """Build an ``RQAJEPAModule`` in one ``rq_input_type`` mode and run one training step.

    Args:
        rq_input_type: value forwarded to ``RQAJEPAModule(rq_input_type=...)``,
            e.g. ``"teacher"`` or ``"spectrogram"``.
        waveform: input tensor placed in the batch under the ``"waveform"`` key.

    Returns:
        The loss tensor returned by ``training_step``.

    Raises:
        AssertionError: if the loss is not a finite 0-dim tensor.
    """
    print(f"\n--- Verifying RQA-JEPA (Mode: {rq_input_type}) ---")
    model = RQAJEPAModule(
        optimizer=_optimizer_partial,
        net=_make_net_config(),
        jepa_criterion=nn.MSELoss(),
        rq_criterion=nn.CrossEntropyLoss(),
        rq_input_type=rq_input_type,
        codebook_dim=16,
        vocab_size=100,
    )
    model.to("cpu")

    # Fresh batch dict (and tensor copy) per mode, so an in-place change made by
    # one training_step cannot leak into the next mode's run.
    batch = {"waveform": waveform.clone()}

    print(f"Model ({rq_input_type}) instantiated. Running training_step...")
    loss = model.training_step(batch, 0)

    # Validate before calling .item(), so a non-tensor return (e.g. a dict)
    # fails with a clear message instead of an AttributeError.
    assert isinstance(loss, torch.Tensor), f"expected a Tensor loss, got {type(loss).__name__}"
    assert loss.ndim == 0, f"expected a scalar loss, got shape {tuple(loss.shape)}"
    # A NaN/inf loss would otherwise pass the type and shape checks.
    assert torch.isfinite(loss).item(), f"loss is not finite: {loss.item()}"

    print(f"Training step successful. Loss: {loss.item()}")
    return loss


def verify_rqa_jepa():
    """Smoke-test ``RQAJEPAModule.training_step`` in both ``rq_input_type`` modes.

    For each mode ("teacher", then "spectrogram"), build a model on CPU, run a
    single training step on a random waveform batch, and check that the
    returned loss is a finite scalar tensor. Stops at the first failing check
    with an AssertionError; prints "Verification Passed!" otherwise.
    """
    print("Verifying RQA-JEPA Implementation...")

    # Fixed seed so a failing run (e.g. a non-finite loss) is reproducible.
    torch.manual_seed(0)

    # Dummy input of shape (batch, channels, samples) = (2, 1, 16000).
    batch_size, channels, num_samples = 2, 1, 16000
    waveform = torch.rand(batch_size, channels, num_samples)

    for rq_input_type in ("teacher", "spectrogram"):
        _check_training_step(rq_input_type, waveform)

    print("\nVerification Passed!")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the RQA-JEPA smoke test only when executed
    # directly, not when this module is imported.
    verify_rqa_jepa()
|
|