File size: 4,741 Bytes
b8f4b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137

dataset:
  # Most settings below are keyed per split: test / train / validation.
  # NOTE: the `!!python/tuple` tags require a PyYAML loader that constructs Python
  # objects (e.g. FullLoader); `yaml.safe_load` cannot build them.
  add_dim_keys:
    test: !!python/tuple [drift_at_observations]
    train: !!python/tuple [drift_at_observations]
    validation: !!python/tuple [drift_at_observations]
  add_paths_keys:
    test: !!python/tuple [drift_at_observations]
    train: !!python/tuple [drift_at_observations]
    validation: !!python/tuple [drift_at_observations]
  batch_size:
    test: 32
    train: 64
    validation: 32
  # Three directories per split, listed in the order deg_3, deg_2, deg_1.
  data_dirs:
    test: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_1
    train: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_1
    validation: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_1
  # The test split uses a different dataset class than train/validation.
  dataset_name:
    test: 'HeterogeneousFIMSDEDataset'
    train: 'StreamingFIMSDEDataset'
    validation: 'StreamingFIMSDEDataset'
  # Data key -> HDF5 file name; presumably resolved inside each data_dirs entry (confirm).
  files_to_load:
    drift_at_locations: 'drift_at_locations.h5'
    drift_at_observations: 'drift_at_observations.h5'
    locations: 'locations.h5'
    obs_mask: 'obs_mask.h5'
    obs_times: 'obs_times.h5'
    obs_values: 'obs_values.h5'
  max_dim: 3
  name: 'FIMSDEDataloaderIterableDataset'
  # NOTE(review): `null` for test presumably means "no limit / use all" — confirm in the loader.
  num_locations:
    test: null
    train: 2000
    validation: 10000
  # Two-element tuples, presumably (lower, upper) bounds on the observation count — confirm.
  num_observations:
    test: null
    train: !!python/tuple [0, 1801]
    validation: !!python/tuple [1799, 1801]
  num_workers:
    test: 0
    train: 7
    validation: 5
  shard:
    test: false
    train: true
    validation: true
  shuffle_elements: true
  shuffle_locations:
    test: false
    train: true
    validation: true
  shuffle_paths: true

distributed:
  # NOTE(review): "chekpoint" looks like a misspelling of "checkpoint"; the key is kept
  # verbatim because the consumer reads this exact name — rename only together with it.
  activation_chekpoint: false
  checkpoint_type: 'full_state'
  enabled: true
  # Quoted on purpose. A bare `1e5` is loader-dependent: PyYAML (YAML 1.1) loads it
  # as the string "1e5" (its float pattern needs a '.' and a signed exponent), while
  # YAML 1.2 loaders read it as the float 100000.0. Quoting pins the string type that
  # PyYAML already produces, for every loader. Switch to the integer 100000 once the
  # consumer is confirmed to accept a number here.
  min_num_params: '1e5'
  sharding_strategy: 'NO_SHARD'
  # NOTE(review): "BAZED" looks like a misspelling of "BASED"; kept verbatim because the
  # value is presumably matched as a string by the consumer — confirm before renaming.
  wrap_policy: 'SIZE_BAZED'

experiment:
  device_map: 'cuda'
  # Run name; with name_add_date enabled, a date is presumably appended to it — confirm.
  name: 'big_model_l1_600k_examples'
  name_add_date: true
  seed: 10

model:
  # Architecture hyper-parameters; key names must match what the model code reads.
  model_config:
    attention_map: 'softmax'
    attention_method: 'linear'
    dim_embed: 256
    dim_feedforward: 1024
    dim_ffn_u_model: 1024
    dim_hidden_u_model: 256
    # Same value as dataset.max_dim (3); presumably the two must agree — confirm.
    dim_max_trajectory: 3
    dropout: 0.1
    num_context_encoder_layers: 2
    num_heads: 8
    num_res_layer_u_model: 6
    num_res_layers_functional_decoder: 8
    use_bias_for_projection: true
    use_bias_in_attention: true
    use_query_residual_in_attention: true
  model_type: 'TrainingWrapper'
  # Training-time options (loss, input corruption) consumed by the model wrapper.
  train_config:
    corruption_model_type: 'odeformer'
    loss_filter_nans: true
    loss_type: 'l1'
    max_sigma_trajectory_noise: 0.06
    # NOTE(review): "ration" is probably meant as "ratio"; kept verbatim because the
    # consumer reads this exact key — rename only together with the code.
    max_subsampling_ration: 0.5
    train_type: 'vector_field'
    train_with_normalized_head: true

# Tuple of optimizer specs; this run defines a single entry, `optimizer_d`.
optimizers: !!python/tuple
- optimizer_d:
    gradient_norm_clipping: 10
    # Both literals below are valid YAML 1.1 floats (dot plus signed exponent).
    lr: 1.0e-5
    name: 'torch.optim.AdamW'
    weight_decay: 1.0e-4

trainer:
  best_metric: loss
  debug_iterations: null
  detect_anomaly: false
  epochs: 2500
  experiment_dir: ./results/
  gradient_accumulation_steps: 1
  logging_format: RANK_%(rank)s - %(asctime)s - %(name)s - %(levelname)s - %(message)s
  name: Trainer
  precision: bf16mixed
  save_every: 1
  schedulers: !!python/tuple
  - beta: 1.0
    label: drift_loss_scale
    name: fim.utils.param_scheduler.ConstantScheduler