# PepGLAD/data/dataset_wrapper.py
from typing import Callable

import numpy as np
import sympy
import torch
from tqdm import tqdm

from utils import register as R

class MixDatasetWrapper(torch.utils.data.Dataset):
    '''Concatenate multiple datasets and expose them as a single dataset.'''

    def __init__(self, *datasets, collate_fn: Callable=None) -> None:
        super().__init__()
        self.datasets = datasets
        # cumulative lengths, used to map a global index to (dataset, local index)
        self.cum_len = []
        self.total_len = 0
        for dataset in datasets:
            self.total_len += len(dataset)
            self.cum_len.append(self.total_len)
        # fall back to the first dataset's collate_fn if none is given
        self.collate_fn = self.datasets[0].collate_fn if collate_fn is None else collate_fn
        # merge per-item lengths only when every dataset provides them
        if all(hasattr(dataset, '_lengths') for dataset in datasets):
            self._lengths = []
            for dataset in datasets:
                self._lengths.extend(dataset._lengths)

    def update_epoch(self):
        for dataset in self.datasets:
            if hasattr(dataset, 'update_epoch'):
                dataset.update_epoch()

    def get_len(self, idx):
        return self._lengths[idx]

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        last_cum_len = 0
        for i, cum_len in enumerate(self.cum_len):
            if idx < cum_len:
                return self.datasets[i][idx - last_cum_len]
            last_cum_len = cum_len
        raise IndexError(idx)  # idx is out of range
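
# Usage sketch (hypothetical `pep_set` / `prot_set`; assumes each dataset
# exposes `collate_fn` and, optionally, per-item `_lengths`):
#
#   mixed = MixDatasetWrapper(pep_set, prot_set)
#   loader = torch.utils.data.DataLoader(
#       mixed, batch_size=8, shuffle=True, collate_fn=mixed.collate_fn)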


@R.register('DynamicBatchWrapper')
class DynamicBatchWrapper(torch.utils.data.Dataset):
    '''Pack items into batches whose summed complexity stays under a fixed bound.'''

    def __init__(self, dataset, complexity, ubound_per_batch) -> None:
        super().__init__()
        self.dataset = dataset
        self.indexes = list(range(len(dataset)))
        # `complexity` is an expression in the item length n, e.g. 'n**2'
        self.complexity = complexity
        self.eval_func = sympy.lambdify('n', sympy.simplify(complexity))
        self.ubound_per_batch = ubound_per_batch
        self.total_size = None
        self.batch_indexes = []
        self._form_batch()
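
    # For instance, with complexity='n**2' (an illustrative value, not a
    # project default) eval_func(20) == 400, i.e. an item of length 20
    # consumes 400 complexity units of `ubound_per_batch`.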

    def __getattr__(self, attr):
        # only invoked when normal attribute lookup fails, so delegate
        # unknown attributes to the wrapped dataset; guard 'dataset' itself
        # to avoid infinite recursion before __init__ has set it
        if attr == 'dataset':
            raise AttributeError(attr)
        if hasattr(self.dataset, attr):
            return getattr(self.dataset, attr)
        raise AttributeError(f"'DynamicBatchWrapper' (or '{type(self.dataset)}') object has no attribute '{attr}'")

    def update_epoch(self):
        if hasattr(self.dataset, 'update_epoch'):
            self.dataset.update_epoch()
        # re-shuffle and re-form the batches at each epoch boundary
        self._form_batch()

    ########## override with your own criterion ##########
    def _form_batch(self):
        np.random.shuffle(self.indexes)
        last_batch_indexes = self.batch_indexes
        self.batch_indexes = []

        cur_complexity = 0
        batch = []

        for i in tqdm(self.indexes):
            item_len = self.eval_func(self.dataset.get_len(i))
            if item_len > self.ubound_per_batch:
                continue  # skip items too large to fit in any batch
            cur_complexity += item_len
            if cur_complexity > self.ubound_per_batch:
                # adding this item would overflow the bound: close the
                # current batch and start a new one with this item
                self.batch_indexes.append(batch)
                batch = []
                cur_complexity = item_len
            batch.append(i)
        if len(batch) > 0:  # avoid appending an empty trailing batch
            self.batch_indexes.append(batch)

        if self.total_size is None:
            self.total_size = len(self.batch_indexes)
        else:
            # keep the number of batches fixed across epochs,
            # otherwise the dataloader will raise an error
            if len(self.batch_indexes) < self.total_size:
                num_add = self.total_size - len(self.batch_indexes)
                self.batch_indexes = self.batch_indexes + last_batch_indexes[:num_add]
            else:
                self.batch_indexes = self.batch_indexes[:self.total_size]

    def __len__(self):
        return len(self.batch_indexes)

    def __getitem__(self, idx):
        # an "item" of this wrapper is a whole pre-formed batch
        return [self.dataset[i] for i in self.batch_indexes[idx]]

    def collate_fn(self, batched_batch):
        # the dataloader delivers a list of batches; flatten it before
        # delegating to the wrapped dataset's collate_fn
        batch = []
        for minibatch in batched_batch:
            batch.extend(minibatch)
        return self.dataset.collate_fn(batch)
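

# Usage sketch (hedged: 'n**2' and the 4e5 bound are illustrative, and the
# wrapped dataset is assumed to provide `get_len` and `collate_fn`). With
# batch_size=1 the dataloader yields one pre-formed batch per step, which
# the wrapper's collate_fn flattens before delegating:
#
#   wrapped = DynamicBatchWrapper(dataset, complexity='n**2', ubound_per_batch=4e5)
#   loader = torch.utils.data.DataLoader(
#       wrapped, batch_size=1, shuffle=True, collate_fn=wrapped.collate_fn)
#   for epoch in range(num_epochs):
#       for batch in loader:
#           ...
#       wrapped.update_epoch()  # re-shuffle and re-form the batches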