File size: 2,376 Bytes
4edc9aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import sys
import unittest.mock
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf

# Add Matcha-TTS to python path to access its modules
ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT / "Matcha-TTS"))

sys.path.append(str(ROOT))
import src.training
from src.stage1.medarc_architecture import MultiSubjectConvLinearEncoder
from src.stage2.CFM import CFM
from torch.utils.data import DataLoader, Dataset


class MockDataset(Dataset):
    """Synthetic stand-in dataset that yields freshly sampled random tensors.

    Each ``__getitem__`` call draws new values from ``torch.randn``; the index
    is ignored, so repeated lookups of the same index return different data
    unless the global torch RNG is re-seeded in between.

    Args:
        num_samples: Length reported by ``len()``.
        num_subjects: Size of the leading axis of the ``"fmri"`` tensor.
        time_steps: Time axis length shared by every feature tensor and fMRI.
        voxels: Size of the trailing axis of the ``"fmri"`` tensor.
        feat_dims: One width per tensor in the ``"features"`` list.
    """

    def __init__(
        self, num_samples, num_subjects=4, time_steps=10, voxels=100, feat_dims=(32, 64)
    ):
        # Sizes are stored verbatim and read back on every __getitem__ call.
        self.num_samples = num_samples
        self.num_subjects = num_subjects
        self.time_steps = time_steps
        self.voxels = voxels
        self.feat_dims = feat_dims

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        steps = self.time_steps
        # Feature tensors are sampled before the fMRI tensor; keep this order
        # so a seeded RNG reproduces identical samples.
        feature_tensors = [torch.randn(steps, width) for width in self.feat_dims]
        # Shape: (num_subjects, time_steps, voxels).
        fmri_tensor = torch.randn(self.num_subjects, steps, self.voxels)
        return dict(features=feature_tensors, fmri=fmri_tensor)


def mock_make_data_loaders(cfg):
    """Stand-in for ``src.training.make_data_loaders`` backed by random data.

    Builds a tiny ``MockDataset`` (4 samples, 1000 voxels, feature widths 32
    and 64) and wraps it in one ``DataLoader``, which is returned under both
    the ``"train"`` and ``"val_debug"`` keys.

    Args:
        cfg: Config object; only its ``batch_size`` attribute is read.

    Returns:
        dict mapping ``"train"`` and ``"val_debug"`` to the same DataLoader.
    """
    print("MOCKING DATA LOADERS FOR DEBUG")
    loader_batch = cfg.batch_size
    # Deliberately small sizes so a debug run finishes quickly.
    dataset = MockDataset(num_samples=4, voxels=1000, feat_dims=(32, 64))
    shared_loader = DataLoader(dataset, batch_size=loader_batch)
    # Validation reuses the exact same loader object as training.
    return dict.fromkeys(("train", "val_debug"), shared_loader)


def main():
    """Run ``src.training.main`` end-to-end on mocked data loaders.

    ``src.training.make_data_loaders`` is replaced by ``mock_make_data_loaders``.
    ``sys.argv`` is replaced with arguments pointing at the debug config. Both
    patches are undone when the run finishes, so the caller's argv is left
    intact.

    Any ``Exception`` raised by the training run is reported (message plus
    traceback) instead of propagating.

    Returns:
        0 if the training run completed, 1 if it raised an ``Exception``.
    """
    import traceback

    # Anchor the config at the repository root (ROOT) instead of the current
    # working directory. The original relative path only worked when the
    # script was launched from ROOT.
    cfg_path = str(ROOT / "test" / "debug_config.yml")
    fake_argv = ["training.py", "--cfg-path", cfg_path]

    data_patch = unittest.mock.patch(
        "src.training.make_data_loaders", side_effect=mock_make_data_loaders
    )
    # training.main() parses sys.argv itself; patch it rather than assigning,
    # so the real argv is restored afterwards.
    argv_patch = unittest.mock.patch.object(sys, "argv", fake_argv)

    with data_patch, argv_patch:
        try:
            src.training.main()
        except Exception as e:
            print(f"Caught exception during debug run: {e}")
            traceback.print_exc()
            return 1
    return 0


if __name__ == "__main__":
    # Propagate success/failure to the shell or CI via the exit status.
    sys.exit(main())