| import numpy as np |
| import os |
| import math |
| import torch |
| import torch.distributed as dist |
| from torch.utils.data.sampler import Sampler |
| import random |
| from uniperceiver.utils import comm |
| import itertools |
|
|
| from .sampler import TrainingSampler, NaiveSampler, NodeDistributedSampler |
|
|
| from uniperceiver.datasets.unified_dataset import UnifiedDataset |
# deepspeed is an optional dependency: only needed for MoE expert-parallel
# task broadcasting in WeightedBatchSampler.__iter__.
try:
    import deepspeed.utils.groups as groups
    DEEPSPEED_INSTALLED = True
except ImportError:
    # Narrow to ImportError so real errors (e.g. KeyboardInterrupt) propagate.
    DEEPSPEED_INSTALLED = False
| |
|
|
|
|
|
|
|
|
class WeightedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """Infinite batch sampler over a multi-task :class:`UnifiedDataset`.

    Every yielded batch contains indices for exactly one task.  The task for
    each batch is picked either by a deterministic rotation strategy
    (``'uniform'`` / ``'uniformv2'`` / ``'turn'``) or, by default, at random
    with per-task sampling weights.  Per-task indices come from dedicated
    infinite samplers, and a per-task offset translates local sub-dataset
    indices into indices of the concatenated unified dataset.
    """

    def __init__(self,
                 dataset: UnifiedDataset,
                 cfg,
                 task_cfg,
                 stage='train',
                 shuffle=True,
                 drop_last=True):
        """
        Args:
            dataset: unified dataset exposing one sub-dataset per task
                (``datasets`` mapping, ``dataset_list``, ``dataset_scale``).
            cfg: global config; ``cfg.DATALOADER.STRATEGY`` selects the task
                scheduling strategy and ``cfg.MOE.MOE`` enables task
                broadcasting across deepspeed expert-parallel groups.
            task_cfg: mapping of task name -> per-task config node.
            stage: only ``'train'`` is supported; anything else raises
                :class:`NotImplementedError`.
            shuffle, drop_last: accepted for BatchSampler signature
                compatibility; shuffling is handled by the per-task samplers.
        """
        self.dataset = dataset
        self.cfg = cfg
        self.task_cfg = task_cfg

        self._tasks = list(self.task_cfg.keys())

        # Build one index sampler per task; the sampler type depends on the
        # kind of dataset backing the task.
        unit_sampler = dict()
        for name, new_cfg in self.task_cfg.items():
            dataset_name = new_cfg.DATASETS.DATASET_NAME
            if dataset_name in [
                    "MSCOCO", "FLICKR", "ImageNet22k", "ImageNet1k", "VG", "VideoDataSet", "K700", 'K400', 'MiT', 'MSVDDataset', 'MSRVTTDataset',
                    "RTE", "CoLA", "SST-2", "MRPC", "QQP", "QNLI", "MNLI", "MNLI_Match", "VQA"
            ]:
                sampler = TrainingSampler(self.dataset.datasets[name])
            elif dataset_name in ["BooksWiki"]:
                sampler = NaiveSampler(self.dataset.datasets[name])
            elif dataset_name in [
                    'YFCC', 'CC12M', 'CC3M', 'SBU', 'TQAPretrain'
            ]:
                # Web-scale sharded corpora: sample with node-local awareness.
                sampler = NodeDistributedSampler(
                    self.dataset.datasets[name],
                    shuffle=True,
                    num_replicas=comm.get_world_size(),
                    rank=comm.get_rank(),
                    local_rank=comm.get_local_rank(),
                    local_size=comm.get_local_size())
            else:
                raise NotImplementedError(
                    f'please check the sampler used for this dataset {new_cfg.DATASETS.DATASET_NAME}'
                )
            unit_sampler[name] = sampler
        self.unit_sampler = unit_sampler

        # Live iterator per task; restarted in __iter__ when exhausted.
        self.unit_sampler_iter = {
            k: iter(v)
            for k, v in self.unit_sampler.items()
        }

        # Per-task weights for the default weighted-random task choice.
        self.sampling_weights = {
            k: v.DATALOADER.SAMPLING_WEIGHT
            for k, v in self.task_cfg.items()
        }
        # Weights aligned with the order of self._tasks for random.choices.
        self._weights = [self.sampling_weights[k] for k in self._tasks]

        self.stage = stage
        if self.stage == 'train':
            self.task_batch_size = {
                k: v.DATALOADER.TRAIN_BATCH_SIZE
                for k, v in self.task_cfg.items()
            }
        else:
            raise NotImplementedError('only train dataset supportted now!')

        # Nominal batches per task (len(dataset) // batch_size).  Relies on
        # dataset_list and task_batch_size sharing the task ordering of
        # task_cfg.  Used only by __len__; the iterator itself is endless.
        self.len = [
            len_ds // bs
            for len_ds, bs in zip([len(ds) for ds in self.dataset.dataset_list],
                                  self.task_batch_size.values())
        ]

        self.special_strategy = cfg.DATALOADER.STRATEGY

        # Rotation counter for the 'uniformv2' / 'turn' strategies.
        self.count = 0

        # Start offset of each task's sub-dataset inside the concatenated
        # unified dataset; added to every locally sampled index.
        self.task_index_offset = {
            k: v
            for k, v in zip(self.task_cfg.keys(),
                            self.dataset.dataset_scale.tolist())
        }

    def __len__(self):
        """Nominal epoch length: total batch count summed over all tasks."""
        return sum(self.len)

    def __iter__(self):
        """Yield lists of unified-dataset indices, one task per batch.

        Runs forever; per-task samplers are restarted whenever exhausted
        (most samplers here are infinite streams anyway).
        """
        batch = []
        while True:
            # ---- pick the task for this batch --------------------------
            if self.special_strategy == 'uniform':
                # Each local rank is pinned to one task by rank parity.
                task = self._tasks[comm.get_local_rank() % len(self._tasks)]
            elif self.special_strategy == 'uniformv2':
                # Rotate tasks across both successive batches and ranks.
                task = self._tasks[(self.count + comm.get_rank()) %
                                   len(self._tasks)]
                self.count = (self.count + 1) % len(self._tasks)
            elif self.special_strategy == 'turn':
                # Plain round-robin over tasks.
                task = self._tasks[self.count % len(self._tasks)]
                self.count = (self.count + 1) % len(self._tasks)
            else:
                # Default: weighted random task selection.
                task = random.choices(self._tasks,
                                      weights=self._weights)[0]

            # With MoE expert parallelism, every rank in an expert-parallel
            # group must see the same task: broadcast from the group's
            # first rank.
            if self.cfg.MOE.MOE and DEEPSPEED_INSTALLED and groups.expert_parallel_is_initialized(
            ) and groups.get_expert_data_parallel_world_size() > 1:
                task = comm.broadcast_object(
                    task,
                    src=comm.get_rank() -
                    comm.get_rank() % groups.get_expert_parallel_world_size(),
                    group=groups.get_expert_parallel_group())

            # ---- draw one batch of indices for the chosen task ---------
            sample_index_offset = self.task_index_offset[task]
            for _ in range(self.task_batch_size[task]):
                try:
                    batch.append(
                        next(self.unit_sampler_iter[task]) + sample_index_offset)
                except StopIteration:
                    # Sampler exhausted: restart it and take the next index.
                    # (Narrowed from a bare `except` so genuine errors in the
                    # sampler are no longer silently masked.)
                    self.unit_sampler_iter[task] = iter(self.unit_sampler[task])
                    batch.append(
                        next(self.unit_sampler_iter[task]) + sample_index_offset)

            assert len(batch) == self.task_batch_size[task]
            yield batch
            batch = []
|
|