# FIM-ODE base model: training parameters (fim-ode/base_model/train_parameters.yaml).
# Uploaded by jrhuebers with the commit message "Upload FIM-ODE base model"
# (commit b8f4b1f, verified).
---
# Data-loading configuration. Most options are keyed per split
# (train / validation / test).
# NOTE: `!!python/tuple` is a PyYAML-specific tag. `yaml.safe_load` rejects it,
# so this file must be read with a loader that constructs Python tuples
# (e.g. PyYAML's FullLoader).
dataset:
  # Per-split tuples of data keys. Presumably these keys get extra
  # dimension / path handling in the dataset class -- TODO confirm.
  add_dim_keys:
    test: !!python/tuple
    - drift_at_observations
    train: !!python/tuple
    - drift_at_observations
    validation: !!python/tuple
    - drift_at_observations
  add_paths_keys:
    test: !!python/tuple
    - drift_at_observations
    train: !!python/tuple
    - drift_at_observations
    validation: !!python/tuple
    - drift_at_observations
  batch_size:
    test: 32
    train: 64
    validation: 32
  # Absolute paths on the original training cluster's filesystem; adjust them
  # to run anywhere else. Each split reads three directories (deg_3, deg_2,
  # deg_1).
  data_dirs:
    test: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_1
    train: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_1
    validation: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_1
  # The test split uses a different dataset class than train/validation.
  dataset_name:
    test: HeterogeneousFIMSDEDataset
    train: StreamingFIMSDEDataset
    validation: StreamingFIMSDEDataset
  # Data key -> HDF5 file name. Presumably resolved inside each `data_dirs`
  # entry -- TODO confirm.
  files_to_load:
    drift_at_locations: drift_at_locations.h5
    drift_at_observations: drift_at_observations.h5
    locations: locations.h5
    obs_mask: obs_mask.h5
    obs_times: obs_times.h5
    obs_values: obs_values.h5
  max_dim: 3
  name: FIMSDEDataloaderIterableDataset
  # `null` for test presumably means "no limit / use all" -- TODO confirm.
  num_locations:
    test: null
    train: 2000
    validation: 10000
  # The tuples look like (min, max) bounds on the number of observations --
  # TODO confirm against the dataset class.
  num_observations:
    test: null
    train: !!python/tuple
    - 0
    - 1801
    validation: !!python/tuple
    - 1799
    - 1801
  num_workers:
    test: 0
    train: 7
    validation: 5
  shard:
    test: false
    train: true
    validation: true
  shuffle_elements: true
  shuffle_locations:
    test: false
    train: true
    validation: true
  shuffle_paths: true
# Distributed-training options. The key names suggest PyTorch FSDP settings
# (sharding strategy, auto-wrap policy, checkpoint type) -- TODO confirm.
distributed:
  # NOTE(review): "chekpoint" is misspelled, but this config was dumped with
  # that key, so the training code presumably reads this exact spelling.
  # Rename only together with that code.
  activation_chekpoint: false
  checkpoint_type: full_state
  enabled: true
  # Quoted to pin the type. PyYAML (YAML 1.1) already loads a bare `1e5` as
  # the string "1e5" because its float pattern requires a ".". YAML 1.2
  # parsers would instead load a float. The consumer presumably converts
  # the string to a number.
  min_num_params: '1e5'
  sharding_strategy: NO_SHARD
  # NOTE(review): "BAZED" looks like a misspelling of "BASED". Kept as-is
  # because the value is presumably matched verbatim by the training code.
  wrap_policy: SIZE_BAZED
# Run identification and reproducibility settings.
experiment:
  device_map: cuda
  name: big_model_l1_600k_examples
  # Presumably appends a date to `name` for the run directory -- TODO confirm.
  name_add_date: true
  seed: 10
# Model architecture (`model_config`) and training-objective options
# (`train_config`) for the `TrainingWrapper` model type.
model:
  model_config:
    attention_map: softmax
    attention_method: linear
    dim_embed: 256
    dim_feedforward: 1024
    dim_ffn_u_model: 1024
    dim_hidden_u_model: 256
    # Same value as `dataset.max_dim` (3). Presumably the two must agree --
    # TODO confirm.
    dim_max_trajectory: 3
    dropout: 0.1
    num_context_encoder_layers: 2
    num_heads: 8
    num_res_layer_u_model: 6
    num_res_layers_functional_decoder: 8
    use_bias_for_projection: true
    use_bias_in_attention: true
    use_query_residual_in_attention: true
  model_type: TrainingWrapper
  train_config:
    corruption_model_type: odeformer
    loss_filter_nans: true
    loss_type: l1
    max_sigma_trajectory_noise: 0.06
    # NOTE(review): "ration" is presumably a misspelling of "ratio". Kept
    # because the training code presumably reads this exact key.
    max_subsampling_ration: 0.5
    train_type: vector_field
    train_with_normalized_head: true
# Optimizer list, stored as a Python tuple via the PyYAML-specific
# `!!python/tuple` tag (`yaml.safe_load` cannot read it). The single entry
# maps the label `optimizer_d` to its settings.
optimizers: !!python/tuple
- optimizer_d:
    # Presumably the maximum gradient norm used for clipping -- TODO confirm.
    gradient_norm_clipping: 10
    lr: 1.0e-05
    name: torch.optim.AdamW
    weight_decay: 0.0001
# Training-loop settings.
trainer:
  best_metric: loss
  debug_iterations: null
  detect_anomaly: false
  epochs: 2500
  experiment_dir: ./results/
  gradient_accumulation_steps: 1
  # Python `logging` format string. `%(rank)s` is not a standard LogRecord
  # attribute, so the trainer presumably injects it (e.g. via a filter or
  # adapter) -- TODO confirm.
  logging_format: RANK_%(rank)s - %(asctime)s - %(name)s - %(levelname)s - %(message)s
  name: Trainer
  # Presumably bfloat16 mixed precision.
  precision: bf16mixed
  save_every: 1
  # Parameter schedulers, stored as a Python tuple (PyYAML `!!python/tuple`).
  # Judging by the class name, this entry holds `drift_loss_scale` constant
  # at beta = 1.0.
  schedulers: !!python/tuple
  - beta: 1.0
    label: drift_loss_scale
    name: fim.utils.param_scheduler.ConstantScheduler