| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import copy |
| | import heapq |
| | from typing import List, Tuple |
| |
|
| | import torch |
| | from tensordict import TensorDict |
| | from torch import distributed as dist |
| |
|
| |
|
class Set:
    """A bucket of ``(idx, val)`` items that keeps a running sum of the values.

    Buckets are ordered first by ``sum``, then by number of items, then by the
    item list itself, so buckets with different contents always compare
    deterministically.
    """

    def __init__(self) -> None:
        self.items = []  # (idx, val) pairs in insertion order
        self.sum = 0  # running total of every val in self.items

    def add(self, idx: int, val: int):
        """Append one ``(idx, val)`` item and add ``val`` to the running sum."""
        self.items.append((idx, val))
        self.sum += val

    def merge(self, other):
        """Append every item of ``other`` (in its order) to this bucket; ``other`` is unchanged."""
        for idx, val in other.items:
            self.add(idx, val)

    def __lt__(self, other):
        mine = (self.sum, len(self.items), self.items)
        theirs = (other.sum, len(other.items), other.items)
        return mine < theirs
| |
|
| |
|
class State:
    """One node of the k-way Karmarkar-Karp search: ``k`` ``Set`` buckets kept in descending order."""

    def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
        self.k = k
        # Invariant: self.sets is always sorted in descending order (largest bucket first).
        self.sets = [Set() for _ in range(k)]
        assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
        for bucket, (idx, seqlen) in zip(self.sets, items):
            bucket.add(idx=idx, val=seqlen)
        self.sets.sort(reverse=True)

    def get_partitions(self):
        """Return the item indices held by each bucket, in the stored (descending) bucket order."""
        return [[idx for idx, _ in bucket.items] for bucket in self.sets]

    def merge(self, other):
        """Fold ``other`` into this state.

        This state's i-th largest bucket absorbs ``other``'s i-th smallest bucket,
        then descending order is restored.
        """
        last = self.k - 1
        for i, bucket in enumerate(self.sets):
            bucket.merge(other.sets[last - i])
        self.sets.sort(reverse=True)

    @property
    def spread(self) -> int:
        """Difference between the first (largest) and last (smallest) bucket sums."""
        largest, smallest = self.sets[0], self.sets[-1]
        return largest.sum - smallest.sum

    def __lt__(self, other):
        # heapq pops the smallest element, so "less than" here means "merge sooner".
        # The state with the larger spread comes first; on equal spread, the state
        # whose largest bucket is larger comes first.
        if self.spread == other.spread:
            return self.sets[0] > other.sets[0]
        return self.spread > other.spread

    def __repr__(self) -> str:
        buckets = (
            "{" + ",".join(str(seqlen) for _, seqlen in self.sets[i].items) + "}"
            for i in range(self.k)
        )
        return "[" + ",".join(buckets) + "]"
| |
|
| |
|
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Split item indices into ``k_partitions`` groups with balanced length sums.

    This is the k-way largest differencing method (Karmarkar-Karp).
    - Start from one ``State`` per item, or, when ``equal_size`` is set, one per
      run of ``k_partitions`` consecutive items in ascending length order.
    - Repeatedly pop the two states that compare smallest (largest spread) and
      merge them, until one state is left.

    Parameters:
        seqlen_list (List[int]): length of each item.
        k_partitions (int): number of groups to produce.
        equal_size (bool): if True, every group receives the same number of items;
            ``len(seqlen_list)`` must be divisible by ``k_partitions``.

    Returns:
        List[List[int]]: ``k_partitions`` lists of item indices, in the final state's
        bucket order; indices inside each list are not sorted.
    """
    by_length = sorted((seqlen, i) for i, seqlen in enumerate(seqlen_list))
    heap: List[State] = []
    if equal_size:
        assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
        # Each initial state places one item in each of its k buckets. Merges pair
        # buckets one-to-one, so every final bucket gets one item per initial state.
        for start in range(0, len(by_length), k_partitions):
            chunk = by_length[start : start + k_partitions]
            heapq.heappush(heap, State(items=[(idx, seqlen) for seqlen, idx in chunk], k=k_partitions))
    else:
        for seqlen, idx in by_length:
            heapq.heappush(heap, State(items=[(idx, seqlen)], k=k_partitions))

    while len(heap) > 1:
        merged = heapq.heappop(heap)
        merged.merge(heapq.heappop(heap))
        heapq.heappush(heap, merged)

    partitions = heap[0].get_partitions()
    if equal_size:
        for partition in partitions:
            assert len(partition) * k_partitions == len(seqlen_list), (
                f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
            )
    return partitions
| |
|
| |
|
def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Partition item indices into ``k_partitions`` groups with balanced length sums.

    Greedy "longest processing time first" heuristic: items are visited in
    descending order of length (ties by index). Each item goes to the partition
    with the smallest current weighted sum; the first such partition wins ties.

    Parameters:
        seqlen_list (List[int]): length of each item.
        k_partitions (int): number of partitions to produce.
        equal_size (bool): if True, every partition must receive the same number
            of items (``len(seqlen_list)`` must be divisible by ``k_partitions``).

    Returns:
        List[List[int]]: ``k_partitions`` lists of item indices, each sorted ascending.

    Raises:
        AssertionError: if ``equal_size`` is set and the partition sizes differ.
    """
    # With equal_size, each item's weight is offset by more than the total of all
    # lengths. The partition with the fewest items therefore always has the
    # smallest weighted sum, so item counts stay level and lengths only break ties.
    bias = sum(seqlen_list) + 1 if equal_size else 0
    # Placing the largest items first is what makes the greedy choice balance well.
    order = sorted(range(len(seqlen_list)), key=lambda i: (-seqlen_list[i], i))
    partitions = [[] for _ in range(k_partitions)]
    partition_sums = [0] * k_partitions
    for i in order:
        # min() returns the first minimal partition, matching a strict-< scan.
        target = min(range(k_partitions), key=lambda j: partition_sums[j])
        partitions[target].append(i)
        partition_sums[target] += seqlen_list[i] + bias
    # Keep indices ascending within each partition, as an in-order scan would produce.
    for partition in partitions:
        partition.sort()
    if equal_size:
        for partition in partitions:
            assert len(partition) * k_partitions == len(seqlen_list), (
                f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
            )
    return partitions
| |
|
| |
|
def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Get the order of seq lengths that makes partitions balanced.

    This is used for balancing the sum of sequence lengths across dp ranks and
    micro-batches.

    Parameters:
        seqlen_list (List[int]):
            seq lengths of each item
        k_partitions (int):
            resulting number of partitions
        equal_size (bool):
            if True, number of items in each partition must be equal.
            if False, only consider balancing the sum; each partition can have
            a variable number of items
    Returns:
        partitions (List[List[int]]):
            ``k_partitions`` lists of item indices, each sorted ascending.
    Raises:
        AssertionError: if there are fewer items than partitions, or the computed
            grouping is not an exact partition of all indices into non-empty groups.
    """
    assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

    def _check_and_sort_partitions(partitions):
        # Verify `partitions` splits range(len(seqlen_list)) into exactly
        # k_partitions non-empty groups, then sort the indices inside each group.
        assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
        for i, partition in enumerate(partitions):
            assert len(partition) > 0, f"the {i}-th partition is empty"
        seen_idx = {idx for partition in partitions for idx in partition}
        assert seen_idx == set(range(len(seqlen_list)))
        # The set comparison alone cannot detect an index placed in two groups.
        num_assigned = sum(len(partition) for partition in partitions)
        assert num_assigned == len(seqlen_list), (
            f"{num_assigned} indices assigned for {len(seqlen_list)} items"
        )
        return [sorted(partition) for partition in partitions]

    partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
    return _check_and_sort_partitions(partitions)
| |
|
| |
|
def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
    """Compute metrics comparing per-partition length sums before and after balancing.

    The unbalanced baseline splits ``seqlen_list`` into ``len(partitions)``
    contiguous chunks in its original order. The balanced figures use the index
    groups given in ``partitions``.

    Parameters:
        seqlen_list (List[int]): length of each item.
        partitions (List[List[int]]): balanced groups of indices into ``seqlen_list``.
        prefix: key prefix for the returned metrics.

    Returns:
        dict: ``{prefix}/min``, ``{prefix}/max`` and ``{prefix}/minmax_diff`` for the
        contiguous baseline; ``{prefix}/balanced_min`` and ``{prefix}/balanced_max``
        for ``partitions``; ``{prefix}/mean`` is the total length divided by the
        number of partitions.
    """
    k_partition = len(partitions)
    num_items = len(seqlen_list)

    # Exactly k_partition contiguous chunks whose sizes differ by at most one.
    # When num_items is divisible by k_partition these are the usual fixed-size
    # chunks. Otherwise, fixed-size stepping would create an extra short chunk,
    # or a zero step when num_items < k_partition.
    bounds = [num_items * r // k_partition for r in range(k_partition + 1)]
    chunk_sums = [sum(seqlen_list[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]
    min_sum_seqlen = min(chunk_sums)
    max_sum_seqlen = max(chunk_sums)
    total_sum_seqlen = sum(chunk_sums)

    balanced_sum_seqlen_list = [sum(seqlen_list[i] for i in partition) for partition in partitions]
    min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
    max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)

    return {
        f"{prefix}/min": min_sum_seqlen,
        f"{prefix}/max": max_sum_seqlen,
        f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen,
        f"{prefix}/balanced_min": min_sum_seqlen_balanced,
        f"{prefix}/balanced_max": max_sum_seqlen_balanced,
        f"{prefix}/mean": total_sum_seqlen / len(partitions),
    }
| |
|
| |
|
def ceildiv(a, b):
    """Return ``ceil(a / b)`` using floor division only (exact for integers)."""
    # floor(-a / b) == -ceil(a / b), so negating the floored quotient of -a rounds up.
    return -((-a) // b)
| |
|
| |
|
def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
    """Split ``batch`` into micro-batches whose valid-token counts are balanced.

    How it works:
    - The valid-token count of each row is ``batch["attention_mask"].sum(dim=1)``.
    - The number of micro-batches is ``ceildiv(total valid tokens, max_token_len)``.
    - When torch.distributed is initialized, the maximum of that number across
      ``dp_group`` is used instead.
    - Rows are grouped with ``get_seqlen_balanced_partitions`` (groups may differ
      in size) so that per-micro-batch token sums are balanced.

    Note that this function does not check that each micro-batch's token sum stays
    within ``max_token_len``.

    Args:
        batch: TensorDict with an ``"attention_mask"`` entry.
        max_token_len: per-micro-batch token budget; must be at least the last
            dimension of ``attention_mask``.
        dp_group: process group passed to ``dist.all_reduce`` (``None`` = default group).

    Returns:
        (micro_batches, micro_bsz_idx): the list of micro-batches (rows of ``batch``
        concatenated with ``torch.cat``) and, for each, the ascending list of row
        indices it contains.
    """
    max_seq_len = batch["attention_mask"].shape[-1]
    assert max_token_len >= max_seq_len, (
        f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
    )

    valid_token_counts: torch.Tensor = batch["attention_mask"].sum(dim=1)
    num_micro_batches = ceildiv(valid_token_counts.sum().item(), max_token_len)
    if dist.is_initialized():
        # Take the max over the group so every rank uses the same micro-batch count.
        count_tensor = torch.tensor([num_micro_batches], device="cuda")
        dist.all_reduce(count_tensor, op=dist.ReduceOp.MAX, group=dp_group)
        num_micro_batches = count_tensor.cpu().item()

    per_row_lengths = valid_token_counts.tolist()
    assert num_micro_batches <= len(per_row_lengths)

    micro_bsz_idx = get_seqlen_balanced_partitions(per_row_lengths, num_micro_batches, equal_size=False)

    micro_batches = [
        torch.cat([batch[row : row + 1] for row in partition])
        for partition in micro_bsz_idx
    ]
    return micro_batches, micro_bsz_idx
| |
|
| |
|
def get_reverse_idx(idx_map):
    """Return the inverse of the index permutation ``idx_map``.

    For a permutation, the result satisfies ``result[idx_map[i]] == i``. The
    result is written into a deep copy of ``idx_map``, so ``idx_map`` itself is
    not modified. No validation is done: with repeated values the last write
    wins, and positions never targeted keep their copied value.
    """
    inverse = copy.deepcopy(idx_map)
    for position, target in enumerate(idx_map):
        inverse[target] = position
    return inverse
| |
|