OrthoReg / src /distributed.py
gezi2333's picture
Upload folder using huggingface_hub
3589275 verified
import os
import torch
def setup_ddp(rank, world_size, port=12357):
    """Join the NCCL process group for this worker and select its GPU.

    Args:
        rank: Index of this process within the group; also used as the
            CUDA device index passed to ``torch.cuda.set_device``.
        world_size: Total number of processes in the group.
        port: Value written to ``MASTER_PORT`` for the rendezvous.
    """
    # Rendezvous address is fixed to localhost; this overwrites any
    # MASTER_ADDR / MASTER_PORT already present in the environment.
    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(port),
        }
    )

    dist = torch.distributed

    # Initialize the process group.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)

    # Wait until every process in the group has reached this point.
    dist.barrier()
def cleanup_ddp():
    """Destroy the default process group if one is currently initialized.

    Does nothing when ``torch.distributed`` is unavailable or when no
    process group has been initialized, so it is safe to call from a
    ``finally`` block even if setup failed or never ran.
    """
    dist = torch.distributed
    # ``is_initialized`` is only queried when the distributed package is
    # available in this build.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def is_main_process():
    """Return whether the current process should act as the main process.

    Returns:
        ``True`` when the process has rank 0 in the default process group,
        or when no process group is initialized (single-process run);
        ``False`` otherwise.
    """
    dist = torch.distributed
    # Without an initialized process group there is only this process, so
    # treat it as the main one instead of letting ``get_rank`` raise.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0
def distribute_loader(loader):
    """Build a per-process DataLoader over the same dataset as ``loader``.

    The source loader's batch size is treated as the global batch size and
    divided (integer division) by the world size. Samples are assigned to
    processes by a ``DistributedSampler`` created with its default options;
    the source loader's own sampler is not consulted.

    ``num_workers``, ``pin_memory``, ``collate_fn``, ``drop_last`` and
    ``worker_init_fn`` are carried over from the source loader so batches
    are assembled the same way as before.

    When no process group is initialized, a world size of 1 and rank 0 are
    used, so the function also works in single-process runs.

    Args:
        loader: The ``torch.utils.data.DataLoader`` to re-create.

    Returns:
        A new ``torch.utils.data.DataLoader`` for the current process.

    Raises:
        TypeError: If ``loader.batch_size`` is ``None``.
        ValueError: If the global batch size is smaller than the world
            size, which would leave each process with an empty batch.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        # Single-process fallback: one replica, which is rank 0.
        world_size, rank = 1, 0

    if loader.batch_size is None:
        raise TypeError(
            "distribute_loader requires a loader with an integer batch_size; "
            "got batch_size=None"
        )

    per_process_batch = loader.batch_size // world_size
    if per_process_batch < 1:
        raise ValueError(
            f"batch_size={loader.batch_size} is too small to split across "
            f"{world_size} processes"
        )

    sampler = torch.utils.data.distributed.DistributedSampler(
        loader.dataset,
        num_replicas=world_size,
        rank=rank,
    )
    return torch.utils.data.DataLoader(
        loader.dataset,
        batch_size=per_process_batch,
        sampler=sampler,
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
        # Preserve how the source loader forms batches and seeds workers.
        collate_fn=loader.collate_fn,
        drop_last=loader.drop_last,
        worker_init_fn=loader.worker_init_fn,
    )