# MONAI bundle config: trains a 3D UNet segmentation network on paired
# img*/lbl* NIfTI volumes and validates it every `val_interval` epochs.

# Python modules made available to the `$` expressions below.
imports:
  - $import os
  - $import datetime
  - $import torch
  - $import glob

# Dictionary keys shared by the transforms, metrics and handlers.
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED
# Keys for transforms that must act on the image and its label together.
both_keys:
  - '@image'
  - '@label'

# Process rank; the checkpoint savers are disabled whenever it is non-zero.
rank: 0
is_not_rank0: '$@rank > 0'

# Training hyperparameters.
val_interval: 1        # epochs between validation runs
ckpt_interval: 1       # epochs between periodic training checkpoints
rand_prob: 0.5         # probability used by each random augmentation
batch_size: 5
num_epochs: 20
num_substeps: 1        # passed to ThreadDataLoader as `repeats`
num_workers: 4
learning_rate: 0.001
num_classes: 4         # network output channels and one-hot size
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Filesystem locations, all relative to the bundle root.
bundle_root: .
# Weights loaded by the trainer's CheckpointLoader when this file exists.
ckpt_path: $@bundle_root + '/models/model.pt'
dataset_dir: $@bundle_root + '/train_data'
results_dir: $@bundle_root + '/results'
# Per-run output directory stamped with the start time. Only the suffix is
# passed through strftime, so a '%' inside results_dir is never interpreted
# as a format directive.
output_dir: '$@results_dir + datetime.datetime.now().strftime(''/output_%y%m%d_%H%M%S'')'

# 3D UNet: one input channel, one output channel per class.
network_def:
  _target_: UNet
  spatial_dims: 3
  in_channels: 1
  out_channels: '@num_classes'
  channels: [8, 16, 32, 64]
  strides: [2, 2, 2]
  num_res_units: 2
# The instantiated network, moved onto the configured device.
network: $@network_def.to(@device)

# Dataset discovery: every img*.nii.gz in dataset_dir is paired with the
# matching lbl* file, then the pairs are split 4:1 into train/validation.
imgs: '$sorted(glob.glob(@dataset_dir + ''/img*.nii.gz''))'
# Rewrite only the file name (first 'img' in the basename). Replacing 'img'
# across the full path would also alter any directory whose name contains it.
lbls: '$[os.path.join(os.path.dirname(i), os.path.basename(i).replace(''img'', ''lbl'', 1)) for i in @imgs]'
all_pairs: '$[{@image: i, @label: l} for i, l in zip(@imgs, @lbls)]'
# Seeded shuffle so the split is reproducible between runs.
partitions: '$monai.data.partition_dataset(@all_pairs, (4, 1), shuffle=True, seed=0)'
train_sub: '$@partitions[0]'
val_sub: '$@partitions[1]'

# Loading steps shared by training and validation: read both volumes
# (`image_only: true`) and move the channel dimension first.
base_transforms:
  - _target_: LoadImaged
    keys: '@both_keys'
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: '@both_keys'

# Training-only augmentation: random flips and 90-degree rotations applied
# jointly to image and label, Gaussian noise on the image only, then
# intensity rescaling of the image.
train_transforms:
  - _target_: RandAxisFlipd
    keys: '@both_keys'
    prob: '@rand_prob'
  - _target_: RandRotate90d
    keys: '@both_keys'
    prob: '@rand_prob'
  - _target_: RandGaussianNoised
    keys: '@image'
    prob: '@rand_prob'
    std: 0.05
  - _target_: ScaleIntensityd
    keys: '@image'

# Validation uses only the deterministic intensity rescaling.
val_transforms:
  - _target_: ScaleIntensityd
    keys: '@image'

# Complete pipelines: the shared loading steps followed by the
# stage-specific transforms.
preprocessing:
  _target_: Compose
  transforms: $@base_transforms + @train_transforms

val_preprocessing:
  _target_: Compose
  transforms: $@base_transforms + @val_transforms

# Datasets over the train and validation partitions.
train_dataset:
  _target_: Dataset
  data: '@train_sub'
  transform: '@preprocessing'

val_dataset:
  _target_: Dataset
  data: '@val_sub'
  transform: '@val_preprocessing'

# Data loaders. The training loader reshuffles each epoch; without
# `shuffle: true` the batches would arrive in the same order every epoch.
train_dataloader:
  _target_: ThreadDataLoader
  dataset: '@train_dataset'
  batch_size: '@batch_size'
  shuffle: true                # forwarded to the underlying DataLoader
  repeats: '@num_substeps'
  num_workers: '@num_workers'

# Validation order does not matter, so it keeps the default (no shuffle).
val_dataloader:
  _target_: DataLoader
  dataset: '@val_dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'

# Dice loss on softmaxed network output. Labels are converted to one-hot
# (`to_onehot_y`), and the background channel is included.
lossfn:
  _target_: DiceLoss
  include_background: true
  to_onehot_y: true
  softmax: true

# Adam over all network parameters.
optimizer:
  _target_: torch.optim.Adam
  params: $@network.parameters()
  lr: '@learning_rate'

# Inferer handed to the trainer.
inferer:
  _target_: SimpleInferer

# Evaluator post-processing: softmax on predictions, argmax on predictions
# only, then one-hot encode predictions and labels to `num_classes` channels.
postprocessing:
  _target_: Compose
  transforms:
    - _target_: Activationsd
      keys: '@pred'
      softmax: true
    - _target_: AsDiscreted
      keys: ['@pred', '@label']
      argmax: [true, false]
      to_onehot: '@num_classes'

# Evaluator handlers: metric logging through StatsHandler and LogfileHandler,
# plus a checkpoint driven by the `val_mean_dice` key metric (rank 0 only).
val_handlers:
  - _target_: StatsHandler
    name: null
    output_transform: '$lambda x: None'   # suppress per-iteration output logging
  - _target_: LogfileHandler
    output_dir: '@output_dir'
  - _target_: CheckpointSaver
    _disabled_: '@is_not_rank0'
    save_dir: '@output_dir'
    save_dict:
      model: '@network'
    save_interval: 0
    save_final: false
    epoch_level: false
    save_key_metric: true
    key_metric_name: val_mean_dice

# Validation engine. `key_val_metric` holds the single metric used to rank
# checkpoints (val_mean_dice, matching the CheckpointSaver key_metric_name in
# val_handlers). Every other metric is reported via `additional_metrics`.
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@val_dataloader'
  network: '@network'
  postprocessing: '@postprocessing'
  key_val_metric:
    val_mean_dice:
      _target_: MeanDice
      include_background: false
      output_transform: $monai.handlers.from_engine([@pred, @label])
  additional_metrics:
    val_mean_iou:
      _target_: MeanIoUHandler
      include_background: false
      output_transform: $monai.handlers.from_engine([@pred, @label])
    val_mae:
      _target_: MeanAbsoluteError
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@val_handlers'

# Metric logger linked to the evaluator. It is registered as a trainer handler
# and saved inside the periodic training checkpoint.
metriclogger:
  _target_: MetricLogger
  evaluator: '@evaluator'

# Handlers attached to the trainer.
handlers:
  - '@metriclogger'
  # Restore network weights from ckpt_path, only when that file exists.
  - _target_: CheckpointLoader
    _disabled_: $not os.path.exists(@ckpt_path)
    load_path: '@ckpt_path'
    load_dict:
      model: '@network'
  # Run the evaluator every `val_interval` epochs.
  - _target_: ValidationHandler
    validator: '@evaluator'
    epoch_level: true
    interval: '@val_interval'
  # Periodic and final checkpoints of network + metric logger (rank 0 only).
  - _target_: CheckpointSaver
    _disabled_: '@is_not_rank0'
    save_dir: '@output_dir'
    save_dict:
      model: '@network'
      logger: '@metriclogger'
    save_interval: '@ckpt_interval'
    save_final: true
    epoch_level: true
  # Log the training loss under the tag `train_loss`.
  - _target_: StatsHandler
    name: null
    tag_name: train_loss
    output_transform: $monai.handlers.from_engine(['loss'], first=True)
  - _target_: LogfileHandler
    output_dir: '@output_dir'

# Training engine; no training-time metric is tracked.
trainer:
  _target_: SupervisedTrainer
  max_epochs: '@num_epochs'
  device: '@device'
  train_data_loader: '@train_dataloader'
  network: '@network'
  inferer: '@inferer'
  loss_function: '@lossfn'
  optimizer: '@optimizer'
  key_train_metric: null
  train_handlers: '@handlers'

# Entry point: start training.
run:
  - $@trainer.run()