| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | import copy
|
| | import heapq
|
| | from typing import List, Tuple
|
| |
|
| | import torch
|
| | from torch import distributed as dist
|
| |
|
| |
|
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Split item indices into ``k_partitions`` groups whose length sums are balanced.

    Multi-way Karmarkar-Karp (largest differencing method): every starting
    partial solution holds ``k_partitions`` buckets. The two partial solutions
    with the largest spread are repeatedly popped from a heap and combined by
    pairing the heaviest bucket of one with the lightest bucket of the other,
    until a single solution remains.

    Args:
        seqlen_list: length of each item; an item's index is its list position.
        k_partitions: number of groups to produce.
        equal_size: if True, every group receives the same number of items
            (``len(seqlen_list)`` must be a multiple of ``k_partitions``);
            otherwise group sizes may differ.

    Returns:
        A list of ``k_partitions`` lists of item indices, ordered from the group
        with the largest length sum to the smallest (ties broken by item count,
        then by the items themselves).

    Raises:
        AssertionError: if ``equal_size`` is set and the item count is not a
            multiple of ``k_partitions``.
        IndexError: if ``seqlen_list`` is empty (no partial solution exists).
    """

    class Bucket:
        """One group under construction: its (index, length) items and their total length."""

        def __init__(self) -> None:
            self.sum = 0
            self.items = []

        def add(self, idx: int, val: int):
            self.sum += val
            self.items.append((idx, val))

        def merge(self, other):
            # Copy every item of ``other`` into this bucket; ``other`` is not modified.
            for entry in other.items:
                self.items.append(entry)
                self.sum += entry[1]

        def __lt__(self, other):
            # Order by total length, then by item count, then by the items themselves.
            return (self.sum, len(self.items), self.items) < (other.sum, len(other.items), other.items)

    class PartialSolution:
        """``k`` buckets, kept sorted from heaviest to lightest."""

        def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
            self.k = k
            buckets = [Bucket() for _ in range(k)]
            assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
            for slot, (idx, seqlen) in enumerate(items):
                buckets[slot].add(idx=idx, val=seqlen)
            self.sets = sorted(buckets, reverse=True)

        def get_partitions(self):
            return [[idx for idx, _ in bucket.items] for bucket in self.sets]

        def merge(self, other):
            # The heaviest bucket here absorbs the lightest bucket of ``other``,
            # the second heaviest absorbs the second lightest, and so on.
            for mine, theirs in zip(self.sets, reversed(other.sets)):
                mine.merge(theirs)
            self.sets = sorted(self.sets, reverse=True)

        @property
        def spread(self) -> int:
            # Gap between the heaviest and the lightest bucket.
            heaviest, lightest = self.sets[0], self.sets[-1]
            return heaviest.sum - lightest.sum

        def __lt__(self, other):
            # heapq pops the smallest element first, so "smaller" means "pop
            # earlier": the larger spread wins; on a tie, the state whose
            # heaviest bucket is larger wins.
            if self.spread == other.spread:
                return self.sets[0] > other.sets[0]
            return self.spread > other.spread

        def __repr__(self) -> str:
            rendered = [
                "{" + ",".join(str(seqlen) for _, seqlen in bucket.items) + "}"
                for bucket in self.sets[: self.k]
            ]
            return "[" + ",".join(rendered) + "]"

    # (length, index) pairs in ascending order; equal lengths fall back to index order.
    ordered = sorted((seqlen, i) for i, seqlen in enumerate(seqlen_list))
    heap = []
    if equal_size:
        assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
        # Each run of k consecutive sorted lengths seeds one k-bucket partial solution.
        for start in range(0, len(ordered), k_partitions):
            group = [(idx, seqlen) for seqlen, idx in ordered[start : start + k_partitions]]
            heapq.heappush(heap, PartialSolution(items=group, k=k_partitions))
    else:
        # Every item seeds its own partial solution: one filled bucket, the rest empty.
        for seqlen, idx in ordered:
            heapq.heappush(heap, PartialSolution(items=[(idx, seqlen)], k=k_partitions))

    while len(heap) > 1:
        first = heapq.heappop(heap)
        second = heapq.heappop(heap)
        first.merge(second)
        heapq.heappush(heap, first)

    partitions = heap[0].get_partitions()
    if equal_size:
        for partition in partitions:
            assert len(partition) * k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
    return partitions
|
| |
|
| |
|
def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Greedily partition item indices into ``k_partitions`` groups with balanced length sums.

    Uses the longest-processing-time-first (LPT) heuristic: items are visited
    from longest to shortest and each one goes to the partition whose current
    sum is smallest (the lowest-numbered partition on ties).

    Args:
        seqlen_list: length of each item; an item's index is its list position.
        k_partitions: number of partitions to produce.
        equal_size: if True, every partition must end up with the same number
            of items (requires ``len(seqlen_list)`` to be a multiple of
            ``k_partitions``).

    Returns:
        A list of ``k_partitions`` lists of item indices, each in the order the
        items were assigned (longest first).

    Raises:
        AssertionError: if ``equal_size`` is set and the partitions do not all
            have ``len(seqlen_list) // k_partitions`` items.
    """
    # With equal_size, every length is inflated by more than the total of all
    # lengths, so a partition holding fewer items always has a smaller sum.
    # The "smallest sum" rule then keeps item counts within one of each other.
    bias = sum(seqlen_list) + 1 if equal_size else 0
    # Visit items longest-first; without this sort the greedy pass balances poorly.
    sorted_seqlen = sorted(((seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)), reverse=True)
    partitions = [[] for _ in range(k_partitions)]
    partition_sums = [0] * k_partitions
    for seqlen, i in sorted_seqlen:
        # min() returns the first index among equal sums, like a left-to-right scan.
        min_idx = min(range(k_partitions), key=partition_sums.__getitem__)
        partitions[min_idx].append(i)
        partition_sums[min_idx] += seqlen
    if equal_size:
        for partition in partitions:
            assert len(partition) * k_partitions == len(seqlen_list), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
    return partitions
|
| |
|
| |
|
def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """Get order of seq lengths to make partitions balanced; this is used in
    balancing the sum of sequence lengths across dp ranks and micro-batches.

    Parameters:
        seqlen_list (List[int]):
            seq lengths of each item
        k_partitions (int):
            resulting number of partitions
        equal_size (bool):
            if True, the number of items in each partition must be equal.
            if False, only consider balancing the sum; each partition can have
            a variable number of items
    Returns:
        partitions (List[List[int]]):
            k_partitions lists containing the indices of items, each sorted ascending.
    Raises:
        AssertionError: if there are fewer items than partitions, or the computed
            partitioning is not an exact cover of the item indices (wrong number
            of partitions, an empty partition, or an index missing or repeated).
    """
    assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

    def _check_and_sort_partitions(partitions):
        # Verify that partitions cover range(len(seqlen_list)) exactly once, then sort each one.
        assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
        seen_idx = set()
        num_assigned = 0
        for i, partition in enumerate(partitions):
            assert len(partition) > 0, f"the {i}-th partition is empty"
            seen_idx.update(partition)
            num_assigned += len(partition)
        assert seen_idx == set(range(len(seqlen_list)))
        # The set comparison alone cannot detect an index assigned to two partitions.
        assert num_assigned == len(seqlen_list), f"{num_assigned} indices assigned for {len(seqlen_list)} items"
        return [sorted(partition) for partition in partitions]

    partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
    return _check_and_sort_partitions(partitions)
|
| |
|
| |
|
def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
    """Summarize how balanced ``partitions`` are compared with a naive contiguous split.

    The baseline cuts ``seqlen_list`` in order into ``len(partitions)``
    contiguous chunks whose sizes differ by at most one (the earlier chunks get
    the extra items). When the item count is divisible by the number of
    partitions, every chunk has exactly ``len(seqlen_list) // len(partitions)``
    items.

    Args:
        seqlen_list: length of each item.
        partitions: lists of item indices, one list per partition.
        prefix: string prepended (with "/") to every metric key.

    Returns:
        dict with the baseline ``{prefix}/min``, ``{prefix}/max`` and
        ``{prefix}/minmax_diff`` chunk sums; the ``{prefix}/balanced_min`` and
        ``{prefix}/balanced_max`` sums of the given partitions; and
        ``{prefix}/mean``, the total length divided by the number of partitions.

    Raises:
        ZeroDivisionError: if ``partitions`` is empty.
    """
    k_partition = len(partitions)

    # Baseline: exactly k_partition contiguous chunks. A fixed step of n // k
    # would crash when n < k (step 0) and would create an extra remainder chunk
    # when n % k != 0.
    base_size, extra = divmod(len(seqlen_list), k_partition)
    unbalanced_sums = []
    start = 0
    for chunk_idx in range(k_partition):
        end = start + base_size + (1 if chunk_idx < extra else 0)
        unbalanced_sums.append(sum(seqlen_list[start:end]))
        start = end
    min_sum_seqlen = min(unbalanced_sums)
    max_sum_seqlen = max(unbalanced_sums)
    total_sum_seqlen = sum(unbalanced_sums)

    balanced_sum_seqlen_list = [sum(seqlen_list[i] for i in partition) for partition in partitions]
    min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
    max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)

    return {
        f"{prefix}/min": min_sum_seqlen,
        f"{prefix}/max": max_sum_seqlen,
        f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen,
        f"{prefix}/balanced_min": min_sum_seqlen_balanced,
        f"{prefix}/balanced_max": max_sum_seqlen_balanced,
        f"{prefix}/mean": total_sum_seqlen / len(partitions),
    }
|
| |
|
| |
|
def ceildiv(a, b):
    """Return ``a / b`` rounded up, computed with floor division only.

    Floor-dividing by the negated divisor and negating the quotient turns
    rounding toward -infinity into rounding toward +infinity; for Python ints
    this is exact (no float conversion).
    """
    floored_negative = a // -b
    return -floored_negative
|
| |
|
| |
|
def rearrange_micro_batches(batch, max_token_len, dp_group=None, same_micro_num_in_dp=True, min_num_micro_batch=None):
    """
    Split a batch into token-balanced micro-batches, optionally syncing the count across DP ranks.

    Args:
        batch (TensorDict): must include "attention_mask" (B*S); other fields are sliced similarly.
        max_token_len (int): target max sum of attention_mask per micro-batch; must be >= S.
            It only sets how many micro-batches are created (ceil(total_tokens / max_token_len),
            capped at B); the balancing heuristic does not strictly guarantee every
            micro-batch stays under it.
        dp_group (optional): torch.distributed group for the data-parallel sync; None means
            the default process group.
        same_micro_num_in_dp (bool): if True and torch.distributed is initialized, every rank
            in dp_group uses the maximum micro-batch count found across those ranks.
        min_num_micro_batch (int, optional): use at least this many micro-batches. Micro-batches
            are never padded, so the final count must not exceed B.

    Returns:
        List[TensorDict]: the micro-batches.
        List[List[int]]: for each micro-batch, the (ascending) original row indices it holds.
    """

    max_seq_len = batch["attention_mask"].shape[-1]
    assert max_token_len >= max_seq_len, f"max_token_len must be greater than or equal to the sequence length. Got {max_token_len=} and {max_seq_len=}"
    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
    total_seqlen = seq_len_effective.sum().item()

    # Fewest micro-batches that fit the token budget, capped at one row per micro-batch.
    # max(1, ...) keeps an all-padding batch (total_seqlen == 0) at one micro-batch
    # instead of asking the partitioner for zero partitions.
    num_micro_batches = min(len(seq_len_effective), max(1, ceildiv(total_seqlen, max_token_len)))
    if min_num_micro_batch is not None:
        num_micro_batches = max(min_num_micro_batch, num_micro_batches)
    if dist.is_initialized() and same_micro_num_in_dp:
        # NCCL can only reduce CUDA tensors; other backends (e.g. gloo) reduce CPU tensors.
        device = "cuda" if dist.get_backend(dp_group) == "nccl" else "cpu"
        num_micro_batches = torch.tensor([num_micro_batches], device=device)
        dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
        num_micro_batches = num_micro_batches.cpu().item()

    seq_len_effective = seq_len_effective.tolist()
    assert num_micro_batches <= len(seq_len_effective)

    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)

    # Gather each partition's rows as 1-row slices and concatenate them.
    micro_batches = [torch.cat([batch[idx : idx + 1] for idx in partition]) for partition in micro_bsz_idx]

    return micro_batches, micro_bsz_idx
|
| |
|
| |
|
def get_reverse_idx(idx_map):
    """Invert an index mapping.

    Returns a container ``inverse`` with ``inverse[idx_map[i]] == i`` for every
    position ``i``. The output buffer is a deep copy of ``idx_map``, so the
    result has the same container type and the input is left untouched.
    ``idx_map`` is presumably a permutation of ``range(len(idx_map))``; any slot
    that no entry points to keeps its copied value.
    """
    inverse = copy.deepcopy(idx_map)
    for position, original in zip(range(len(idx_map)), idx_map):
        inverse[original] = position
    return inverse
|
| |
|