| import random | |
| import math | |
class NoDuplicatesDataLoader:
    """Batch iterator that never places two examples sharing a text in the same batch.

    Intended for use with MultipleNegativesRankingLoss. Each example must expose a
    ``texts`` attribute (an iterable of strings). Texts are compared after
    ``strip().lower()``, so duplicates differing only in case or surrounding
    whitespace are also kept apart.
    """

    def __init__(self, train_examples, batch_size):
        """
        A special data loader to be used with MultipleNegativesRankingLoss.
        The data loader ensures that there are no duplicate sentences within the same batch

        Args:
            train_examples: iterable of examples, each with a ``texts`` attribute.
                It is copied into a new list, so the caller's sequence is never
                reordered.
            batch_size: number of examples per yielded batch; must be >= 1.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.batch_size = batch_size
        # Cursor into train_examples; it persists across iterations (epochs).
        self.data_pointer = 0
        # Optional callable applied to each batch before it is yielded;
        # None yields the raw list of examples.
        self.collate_fn = None
        # Work on a private copy so shuffling does not reorder the caller's list.
        self.train_examples = list(train_examples)
        random.shuffle(self.train_examples)

    def _next_example(self):
        """Return the example at the cursor and advance it, reshuffling on wrap-around."""
        example = self.train_examples[self.data_pointer]
        self.data_pointer += 1
        if self.data_pointer >= len(self.train_examples):
            self.data_pointer = 0
            random.shuffle(self.train_examples)
        return example

    def __iter__(self):
        """Yield ``len(self)`` batches of ``batch_size`` examples each.

        An example is skipped when any of its normalized texts is already present
        in the batch being built. Skipped examples are only reconsidered after the
        cursor wraps around (which also reshuffles the examples).

        Yields:
            ``self.collate_fn(batch)`` if ``collate_fn`` is set, else the list of
            examples in the batch.

        Raises:
            ValueError: if a batch cannot be completed because every example
                conflicts with the texts already placed in it (e.g. too few
                distinct examples for ``batch_size``).
        """
        num_examples = len(self.train_examples)
        # Wherever the cursor starts, 2 * num_examples consecutive skips are
        # guaranteed to include one complete reshuffled pass over all examples
        # with the batch unchanged, so nothing can ever be added: stop instead
        # of looping forever.
        max_consecutive_skips = 2 * num_examples
        for _ in range(len(self)):
            batch = []
            texts_in_batch = set()
            consecutive_skips = 0
            while len(batch) < self.batch_size:
                example = self._next_example()
                keys = [text.strip().lower() for text in example.texts]
                if texts_in_batch.isdisjoint(keys):
                    batch.append(example)
                    texts_in_batch.update(keys)
                    consecutive_skips = 0
                else:
                    consecutive_skips += 1
                    if consecutive_skips >= max_consecutive_skips:
                        raise ValueError(
                            "Cannot fill a batch of size "
                            f"{self.batch_size} without duplicate texts: only "
                            f"{len(batch)} non-conflicting example(s) could be found."
                        )
            yield self.collate_fn(batch) if self.collate_fn is not None else batch

    def __len__(self):
        """Return the number of full batches per iteration (a trailing partial batch is not counted)."""
        return math.floor(len(self.train_examples) / self.batch_size)