| | from datasets import Dataset |
| | from datasets import interleave_datasets |
| | from torch.utils.data import DataLoader |
| |
|
| | def convert_to_str(batch, dataset_name): |
| | batch['a'] = [f"{dataset_name}-{e}" for e in batch['a']] |
| | return batch |
| |
|
| | def gen1(): |
| | for ii in range(1, 25): |
| | yield {"a": ii} |
| |
|
| | def gen2(): |
| | for ii in range(1, 25): |
| | yield {"a": ii} |
| |
|
| | |
| | if __name__ == '__main__': |
| | dataset1 = Dataset.from_generator(gen1).to_iterable_dataset(num_shards=2) |
| | dataset2 = Dataset.from_generator(gen2).to_iterable_dataset(num_shards=2) |
| | dataset1 = dataset1.map(lambda x: convert_to_str(x, dataset_name="a"), batched=True, batch_size=10, drop_last_batch=True) |
| | dataset2 = dataset2.map(lambda x: convert_to_str(x, dataset_name="b"), batched=True, batch_size=10, drop_last_batch=True) |
| |
|
| | interleaved = interleave_datasets([dataset1, dataset2], stopping_strategy="all_exhausted") |
| |
|
| | print(f"num_workers=0") |
| | loader = DataLoader(interleaved, batch_size=5, num_workers=0) |
| | i = 0 |
| | for b in loader: |
| | print(i, b['a']) |
| | i += 1 |
| |
|
| | print('=-' * 20) |
| | print(f"num_workers=1") |
| | loader = DataLoader(interleaved, batch_size=5, num_workers=1) |
| | i = 0 |
| | for b in loader: |
| | print(i, b['a']) |
| | i += 1 |
| |
|
| | print('=-' * 20) |
| | print(f"num_workers=2") |
| | loader = DataLoader(interleaved, batch_size=5, num_workers=2) |
| | i = 0 |
| | for b in loader: |
| | print(i, b['a']) |
| | i += 1 |
| |
|
| | print('=-' * 20) |
| | print(f"num_workers=3") |
| | loader = DataLoader(interleaved, batch_size=5, num_workers=3) |
| | i = 0 |
| | for b in loader: |
| | print(i, b['a']) |
| | i += 1 |