PepGLAD/data/__init__.py
#!/usr/bin/python
# -*- coding:utf-8 -*-
import torch
from torch.utils.data import DataLoader

import utils.register as R
from utils.logger import print_log

from .dataset_wrapper import MixDatasetWrapper
from .codesign import CoDesignDataset    # re-exported for package-level access
from .resample import ClusterResampler  # re-exported for package-level access


def create_dataset(config: dict):
    '''
    Construct the train/valid/test datasets described by the config.
    Each split entry is either a single dataset config or a list of
    configs; a list is merged into one dataset with MixDatasetWrapper.
    A missing split yields None.
    '''
    splits = []
    for split_name in ['train', 'valid', 'test']:
        split_config = config.get(split_name, None)
        if split_config is None:  # split not specified
            splits.append(None)
            continue
        if isinstance(split_config, list):
            # merge multiple datasets into a single one
            dataset = MixDatasetWrapper(
                *[R.construct(cfg) for cfg in split_config]
            )
        else:
            dataset = R.construct(split_config)
        splits.append(dataset)
    return splits  # [train, valid, test]
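
# A minimal usage sketch for create_dataset. The exact config schema is an
# assumption: keys like 'class' and 'mmap_dir' below are hypothetical and
# depend on how the dataset classes are registered in utils.register.
#
#   config = {
#       'train': [
#           {'class': 'CoDesignDataset', 'mmap_dir': 'datasets/train'},    # hypothetical args
#           {'class': 'CoDesignDataset', 'mmap_dir': 'datasets/augment'},  # merged via MixDatasetWrapper
#       ],
#       'valid': {'class': 'CoDesignDataset', 'mmap_dir': 'datasets/valid'},
#       # 'test' omitted -> the returned test split is None
#   }
#   train_set, valid_set, test_set = create_dataset(config)

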
def create_dataloader(dataset, config: dict, n_gpu: int=1, validation: bool=False):
    if 'wrapper' in config:
        # optionally wrap the dataset with a registered wrapper class
        dataset = R.construct(config['wrapper'], dataset=dataset)
    batch_size = config.get('batch_size', n_gpu)  # default: 1 on each GPU
    if validation:
        # allow a separate batch size for validation
        batch_size = config.get('val_batch_size', batch_size)
    shuffle = config.get('shuffle', False)
    num_workers = config.get('num_workers', 4)
    # use the dataset's own collate_fn if it defines one
    collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
    if n_gpu > 1:
        # distributed training: each process draws a disjoint shard
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
        batch_size = batch_size // n_gpu  # split the global batch across GPUs
        print_log(f'Batch size on a single GPU: {batch_size}')
    else:
        sampler = None
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=(shuffle and sampler is None),  # sampler handles shuffling when present
        collate_fn=collate_fn,
        sampler=sampler
    )
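
# A minimal usage sketch for create_dataloader, assuming the config keys read
# above ('batch_size', 'shuffle', 'num_workers'; 'wrapper' and
# 'val_batch_size' are optional) and that torch.distributed is already
# initialized when n_gpu > 1:
#
#   loader_config = {'batch_size': 32, 'shuffle': True, 'num_workers': 8}
#   train_loader = create_dataloader(train_set, loader_config, n_gpu=4)
#   # With n_gpu=4, each process gets batch_size 32 // 4 = 8 and a
#   # DistributedSampler; DataLoader-level shuffling is disabled because
#   # the sampler already shuffles.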