| | import math |
| | import logging |
| | import random |
| |
|
class MultiDatasetDataLoader:
    """Yield batches drawn from several datasets, each batch from a single dataset.

    Datasets are visited following a shuffled schedule in which dataset ``i``
    appears ``weights[i]`` times; the schedule is reshuffled once exhausted.
    Within a batch, an example is skipped when one of its texts (compared after
    ``strip().lower()``) or its non-None ``guid`` already occurs in an accepted
    example of that batch, or when it repeats a text within itself.

    Args:
        datasets: Iterable of datasets. Each dataset is an iterable of example
            objects with a mutable, indexable ``texts`` sequence of strings and
            a ``guid`` attribute (``None`` disables the guid check). Each
            dataset is copied, so the caller's sequences are never reordered.
        batch_size_pairs: Batch size for datasets whose first example has
            exactly two texts; also the divisor used by ``__len__``.
        batch_size_triplets: Batch size for every other dataset; defaults to
            ``batch_size_pairs``.
        dataset_size_temp: When > 0, a non-empty dataset's weight is
            ``max(1, int((len / total_len) ** (1 / dataset_size_temp) * 1000))``;
            otherwise every non-empty dataset gets weight 100. Empty datasets
            always get weight 0 and are never sampled.
    """

    # Optional callable applied to every batch before it is yielded. Callers
    # may assign it on the instance after construction; defaulting it here
    # avoids an AttributeError in __iter__ when nobody does.
    collate_fn = None

    def __init__(self, datasets, batch_size_pairs, batch_size_triplets=None, dataset_size_temp=-1):
        # Materialize once: the datasets are traversed several times below.
        datasets = list(datasets)

        # When True, the first two texts of an accepted example are swapped
        # with probability ~0.5.
        self.allow_swap = True
        self.batch_size_pairs = batch_size_pairs
        self.batch_size_triplets = batch_size_pairs if batch_size_triplets is None else batch_size_triplets

        self.dataset_lengths = [len(dataset) for dataset in datasets]
        self.dataset_lengths_sum = sum(self.dataset_lengths)

        weights = self._compute_weights(self.dataset_lengths, dataset_size_temp)
        logging.info("Dataset lenghts and weights: %s", list(zip(self.dataset_lengths, weights)))

        # Sampling schedule: each dataset index repeated `weight` times.
        self.dataset_idx = [idx for idx, weight in enumerate(weights) for _ in range(weight)]
        self.dataset_idx_pointer = 0
        random.shuffle(self.dataset_idx)

        self.datasets = []
        for dataset in datasets:
            # Private copy: it is shuffled now and again on every wrap-around.
            elements = list(dataset)
            random.shuffle(elements)
            self.datasets.append({'elements': elements, 'pointer': 0})

    @staticmethod
    def _compute_weights(dataset_lengths, dataset_size_temp):
        """Return one integer sampling weight per dataset (0 for empty ones)."""
        total = sum(dataset_lengths)
        weights = []
        for length in dataset_lengths:
            if length == 0:
                # An empty dataset cannot supply a batch; never schedule it.
                # (Also avoids dividing by a zero total below.)
                weights.append(0)
            elif dataset_size_temp > 0:
                prob = length / total
                weights.append(max(1, int(math.pow(prob, 1 / dataset_size_temp) * 1000)))
            else:
                weights.append(100)
        return weights

    def __iter__(self):
        """Yield up to ``len(self)`` batches.

        Each batch is a list of examples from one dataset, passed through
        ``self.collate_fn`` when that is not None. A batch can be smaller than
        its target size when the chosen dataset lacks enough mutually
        non-duplicate examples; a batch that ends up empty is not yielded.
        """
        for _ in range(len(self)):
            dataset = self.datasets[self._next_dataset_idx()]
            if len(dataset['elements'][0].texts) == 2:
                batch_size = self.batch_size_pairs
            else:
                batch_size = self.batch_size_triplets

            batch = self._fill_batch(dataset, batch_size)
            if not batch:
                continue
            yield self.collate_fn(batch) if self.collate_fn is not None else batch

    def _next_dataset_idx(self):
        """Return the next scheduled dataset index, reshuffling after a full pass."""
        if self.dataset_idx_pointer >= len(self.dataset_idx):
            self.dataset_idx_pointer = 0
            random.shuffle(self.dataset_idx)
        idx = self.dataset_idx[self.dataset_idx_pointer]
        self.dataset_idx_pointer += 1
        return idx

    def _fill_batch(self, dataset, batch_size):
        """Collect up to ``batch_size`` non-conflicting examples from ``dataset``.

        Advances the dataset's pointer and reshuffles its elements on
        wrap-around. Gives up once every element has been examined at least
        once since the call began, since further steps would only revisit
        examples already accepted or rejected.
        """
        elements = dataset['elements']
        # The remaining tail (<= len) plus one full reshuffled pass (len)
        # covers every element at least once.
        max_steps = 2 * len(elements)

        batch = []
        texts_in_batch = set()
        guid_in_batch = set()
        steps = 0
        while len(batch) < batch_size and steps < max_steps:
            steps += 1
            example = elements[dataset['pointer']]

            norm_texts = [text.strip().lower() for text in example.texts]
            # Reject texts repeated inside the example or already in the batch.
            valid = len(set(norm_texts)) == len(norm_texts) and texts_in_batch.isdisjoint(norm_texts)
            if example.guid is not None and example.guid in guid_in_batch:
                valid = False

            if valid:
                # Only accepted examples constrain later candidates.
                texts_in_batch.update(norm_texts)
                if example.guid is not None:
                    guid_in_batch.add(example.guid)
                if self.allow_swap and len(example.texts) >= 2 and random.random() > 0.5:
                    example.texts[0], example.texts[1] = example.texts[1], example.texts[0]
                batch.append(example)

            dataset['pointer'] += 1
            if dataset['pointer'] >= len(elements):
                dataset['pointer'] = 0
                random.shuffle(elements)
        return batch

    def __len__(self):
        """Return the number of batches per epoch: total examples // ``batch_size_pairs``."""
        return int(self.dataset_lengths_sum // self.batch_size_pairs)
| |
|