import os

import torch
import torch.distributed
import torch.utils.data


def setup_ddp(rank, world_size, port=12357):
    # Rendezvous over localhost; every process must agree on the address/port.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    # Initialize the process group with the NCCL backend (GPU collectives).
    torch.distributed.init_process_group(
        "nccl",
        rank=rank,
        world_size=world_size,
    )
    # Bind this process to its own GPU before any CUDA work happens.
    torch.cuda.set_device(rank)
    # Wait until every rank has finished initialization.
    torch.distributed.barrier()


def cleanup_ddp():
    torch.distributed.destroy_process_group()


def is_main_process():
    # Rank 0 handles logging, checkpointing, and other once-only work.
    return torch.distributed.get_rank() == 0


def distribute_loader(loader):
    # Rebuild a single-process DataLoader as a distributed one: shrink the
    # per-process batch so the global batch size stays the same, and let a
    # DistributedSampler hand each rank a disjoint shard of the dataset.
    return torch.utils.data.DataLoader(
        loader.dataset,
        batch_size=loader.batch_size // torch.distributed.get_world_size(),
        sampler=torch.utils.data.distributed.DistributedSampler(
            loader.dataset,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank(),
        ),
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
    )
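

# Usage sketch (illustrative, not part of the original helpers): the functions
# above are typically driven by torch.multiprocessing.spawn, one process per
# GPU. The train() function, toy TensorDataset, model, and hyperparameters
# below are hypothetical placeholders.
import torch.multiprocessing as mp


def train(rank, world_size):
    setup_ddp(rank, world_size)

    # Hypothetical model; DDP averages gradients across ranks on backward().
    model = torch.nn.Linear(10, 1).to(rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    # Hypothetical single-process loader, re-sharded across ranks.
    dataset = torch.utils.data.TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
    loader = distribute_loader(torch.utils.data.DataLoader(dataset, batch_size=32))

    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-2)
    for epoch in range(2):
        # Reshuffle the per-rank shards; without set_epoch every epoch
        # sees the same ordering.
        loader.sampler.set_epoch(epoch)
        for x, y in loader:
            x, y = x.to(rank), y.to(rank)
            optimizer.zero_grad()
            loss_fn(ddp_model(x), y).backward()
            optimizer.step()
        if is_main_process():
            print(f"epoch {epoch} done")

    cleanup_ddp()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)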