from typing import Callable, Optional

import numpy as np
import sympy
import torch
from tqdm import tqdm

from utils import register as R


class MixDatasetWrapper(torch.utils.data.Dataset):
    """Concatenates multiple datasets into a single map-style dataset.

    Global indices are resolved to (dataset, local index) pairs via the
    cumulative lengths computed in __init__.
    """

    def __init__(self, *datasets, collate_fn: Optional[Callable] = None) -> None:
        super().__init__()
        self.datasets = datasets
        self.cum_len = []
        self.total_len = 0
        for dataset in datasets:
            self.total_len += len(dataset)
            self.cum_len.append(self.total_len)
        # fall back to the collate_fn of the first dataset if none is given
        self.collate_fn = self.datasets[0].collate_fn if collate_fn is None else collate_fn
        # expose per-item lengths when the wrapped datasets provide them
        if hasattr(datasets[0], '_lengths'):
            self._lengths = []
            for dataset in datasets:
                self._lengths.extend(dataset._lengths)

    def update_epoch(self):
        for dataset in self.datasets:
            if hasattr(dataset, 'update_epoch'):
                dataset.update_epoch()

    def get_len(self, idx):
        return self._lengths[idx]

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        # locate the dataset that owns the global index idx
        last_cum_len = 0
        for i, cum_len in enumerate(self.cum_len):
            if idx < cum_len:
                return self.datasets[i][idx - last_cum_len]
            last_cum_len = cum_len
        # fail loudly instead of silently returning None
        raise IndexError(f'index {idx} out of range for dataset of size {self.total_len}')
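
# A minimal usage sketch (illustrative, not from the original source;
# DatasetA and DatasetB stand for any datasets exposing __len__,
# __getitem__ and collate_fn as assumed above):
#
#   mixed = MixDatasetWrapper(DatasetA(...), DatasetB(...))
#   loader = torch.utils.data.DataLoader(
#       mixed, batch_size=4, shuffle=True, collate_fn=mixed.collate_fn)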


@R.register('DynamicBatchWrapper')
class DynamicBatchWrapper(torch.utils.data.Dataset):
    """Groups items into variable-sized batches so that the total complexity
    of each batch stays below ubound_per_batch.

    complexity is a sympy-parsable expression in n (the item length), e.g.
    'n' for linear cost or 'n**2' for quadratic attention-style cost.
    """

    def __init__(self, dataset, complexity, ubound_per_batch) -> None:
        super().__init__()
        self.dataset = dataset
        self.indexes = list(range(len(dataset)))
        self.complexity = complexity
        # compile the complexity expression once into a fast callable
        self.eval_func = sympy.lambdify('n', sympy.simplify(complexity))
        self.ubound_per_batch = ubound_per_batch
        self.total_size = None
        self.batch_indexes = []
        self._form_batch()

    def __getattr__(self, attr):
        # __getattr__ is only called when normal lookup fails, so anything
        # reaching here is either on the wrapped dataset or missing entirely.
        # Go through __dict__ to avoid infinite recursion if 'dataset' has
        # not been set yet (e.g. during unpickling).
        dataset = self.__dict__.get('dataset')
        if dataset is not None and hasattr(dataset, attr):
            return getattr(dataset, attr)
        raise AttributeError(f"'DynamicBatchWrapper' (or wrapped '{type(dataset)}') object has no attribute '{attr}'")

    def update_epoch(self):
        if hasattr(self.dataset, 'update_epoch'):
            self.dataset.update_epoch()
        self._form_batch()

    def _form_batch(self):
        # reshuffle the items, then greedily pack them into batches whose
        # summed complexity stays below the upper bound
        np.random.shuffle(self.indexes)
        last_batch_indexes = self.batch_indexes
        self.batch_indexes = []

        cur_complexity = 0
        batch = []

        for i in tqdm(self.indexes):
            item_len = self.eval_func(self.dataset.get_len(i))
            # drop items whose complexity alone already exceeds the bound
            if item_len > self.ubound_per_batch:
                continue
            cur_complexity += item_len
            if cur_complexity > self.ubound_per_batch:
                self.batch_indexes.append(batch)
                batch = []
                cur_complexity = item_len
            batch.append(i)
        if batch:  # guard against appending an empty trailing batch
            self.batch_indexes.append(batch)

        # keep the number of batches per epoch constant so that downstream
        # consumers (e.g. schedulers) see a fixed epoch length: pad with
        # batches from the previous epoch, or truncate
        if self.total_size is None:
            self.total_size = len(self.batch_indexes)
        elif len(self.batch_indexes) < self.total_size:
            num_add = self.total_size - len(self.batch_indexes)
            self.batch_indexes = self.batch_indexes + last_batch_indexes[:num_add]
        else:
            self.batch_indexes = self.batch_indexes[:self.total_size]

    def __len__(self):
        return len(self.batch_indexes)

    def __getitem__(self, idx):
        # each item of this dataset is a full batch of wrapped-dataset items
        return [self.dataset[i] for i in self.batch_indexes[idx]]

    def collate_fn(self, batched_batch):
        # flatten the batch of batches produced by DataLoader(batch_size=k)
        # before delegating to the wrapped dataset's collate_fn
        batch = []
        for minibatch in batched_batch:
            batch.extend(minibatch)
        return self.dataset.collate_fn(batch)
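

if __name__ == '__main__':
    # A minimal, self-contained smoke test (an illustrative sketch, not part
    # of the original module): _ToyDataset is a hypothetical dataset exposing
    # the interface both wrappers assume (__len__, __getitem__, get_len,
    # _lengths and collate_fn).
    class _ToyDataset(torch.utils.data.Dataset):
        def __init__(self, lengths):
            self._lengths = lengths

        def __len__(self):
            return len(self._lengths)

        def get_len(self, idx):
            return self._lengths[idx]

        def __getitem__(self, idx):
            return torch.zeros(self._lengths[idx])

        def collate_fn(self, batch):
            return torch.cat(batch)

    mixed = MixDatasetWrapper(_ToyDataset([3, 5, 2]), _ToyDataset([4, 1]))
    print('mixed size:', len(mixed))  # 5

    # with complexity 'n' the bound caps the summed item length per batch
    batched = DynamicBatchWrapper(_ToyDataset([3, 5, 2, 8, 1, 4]),
                                  complexity='n', ubound_per_batch=10)
    for i in range(len(batched)):
        print(f'batch {i} lengths:', [len(x) for x in batched[i]])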