| | --- |
| | device: "$torch.device('cuda:' + os.environ['LOCAL_RANK'])" |
| | network: |
| | _target_: torch.nn.parallel.DistributedDataParallel |
| | module: "$@network_def.to(@device)" |
| | find_unused_parameters: true |
| | device_ids: |
| | - "@device" |
| | optimizer#lr: "$0.025*dist.get_world_size()" |
| | lr_scheduler#step_size: "$80*dist.get_world_size()" |
| | train#handlers: |
| | - _target_: LrScheduleHandler |
| | lr_scheduler: "@lr_scheduler" |
| | print_lr: true |
| | - _target_: ValidationHandler |
| | validator: "@validate#evaluator" |
| | epoch_level: true |
| | interval: "$10*dist.get_world_size()" |
| | - _target_: StatsHandler |
| | tag_name: train_loss |
| | output_transform: "$monai.handlers.from_engine(['loss'], first=True)" |
| | - _target_: TensorBoardStatsHandler |
| | log_dir: "@output_dir" |
| | tag_name: train_loss |
| | output_transform: "$monai.handlers.from_engine(['loss'], first=True)" |
| | train#trainer#max_epochs: "$400*dist.get_world_size()" |
| | train#trainer#train_handlers: "$@train#handlers[: -2 if dist.get_rank() > 0 else None]" |
| | validate#evaluator#val_handlers: "$None if dist.get_rank() > 0 else @validate#handlers" |
| | initialize: |
| | - "$import torch.distributed as dist" |
| | - "$dist.is_initialized() or dist.init_process_group(backend='nccl')" |
| | - "$torch.cuda.set_device(@device)" |
| | - "$monai.utils.set_determinism(seed=123)" |
| | - "$setattr(torch.backends.cudnn, 'benchmark', True)" |
| | run: |
| | - "$@train#trainer.run()" |
| | finalize: |
| | - "$dist.is_initialized() and dist.destroy_process_group()" |
| | train_data_partition: "$monai.data.partition_dataset(data=@train_datalist, num_partitions=dist.get_world_size(), shuffle=True, even_divisible=True,)[dist.get_rank()]" |
| | train#dataset: |
| | _target_: CacheDataset |
| | data: "@train_data_partition" |
| | transform: "@train#preprocessing" |
| | cache_rate: 1 |
| | num_workers: 4 |
| | val_data_partition: "$monai.data.partition_dataset(data=@val_datalist, num_partitions=dist.get_world_size(), shuffle=False, even_divisible=False,)[dist.get_rank()]" |
| | validate#dataset: |
| | _target_: CacheDataset |
| | data: "@val_data_partition" |
| | transform: "@validate#preprocessing" |
| | cache_rate: 1 |
| | num_workers: 4 |
| |
|