import math
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from specforge.distributed import get_tp_group, shard_tensor


class VocabParallelEmbedding(nn.Module):
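    """Embedding layer whose vocabulary dimension is sharded across the
    tensor-parallel (TP) group.

    Each rank stores ``ceil(num_embeddings / tp_size)`` rows of the embedding
    table. In ``forward``, token ids outside this rank's shard are masked to
    zero, looked up locally, zeroed out again in the output, and the partial
    results are summed across ranks with an all-reduce.
    """
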
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        # tp-related
        self.tp_group = get_tp_group()
        self.tp_rank = dist.get_rank(self.tp_group)
        self.tp_size = dist.get_world_size(self.tp_group)

        # handle the case where the vocab size is not divisible by the TP size:
        # every rank gets ceil(num_embeddings / tp_size) rows, and
        # padded_num_embeddings is the number of extra padding rows this implies
        self.num_embeddings_per_shard = math.ceil(num_embeddings / self.tp_size)
        self.padded_num_embeddings = (
            self.num_embeddings_per_shard * self.tp_size - self.num_embeddings
        )
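        # each rank owns the half-open id range [vocab_start_index, vocab_end_index);
        # the last rank's range may be shorter when the vocab does not divide evenly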
        self.vocab_start_index = self.tp_rank * self.num_embeddings_per_shard
        self.vocab_end_index = min(
            self.vocab_start_index + self.num_embeddings_per_shard,
            self.num_embeddings,
        )
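        # remap the global padding_idx to a shard-local index if this rank owns it;
        # other ranks treat padding_idx as absent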
        if (
            padding_idx is not None
            and padding_idx >= self.vocab_start_index
            and padding_idx < self.vocab_end_index
        ):
            self.padding_idx = padding_idx - self.vocab_start_index
        else:
            self.padding_idx = None

        self.weight = nn.Parameter(
            torch.empty(
                (self.num_embeddings_per_shard, self.embedding_dim), **factory_kwargs
            ),
            requires_grad=True,
        )
        self.reset_parameters()
        self._register_load_state_dict_pre_hook(self.shard_state_dict)

    def shard_state_dict(self, state_dict, *args):
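        # load_state_dict pre-hook: the checkpoint holds the full (unsharded)
        # embedding table, so pad it to a multiple of tp_size along dim 0 and
        # keep only this rank's slice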
if "weight" in state_dict:
value = state_dict["weight"]
# pad this value if it is not divisible by the TP size
if value.shape[0] % self.tp_size != 0:
padding_size = self.tp_size - value.shape[0] % self.tp_size
value = F.pad(value, (0, 0, 0, padding_size))
state_dict["weight"] = shard_tensor(value, self.tp_group, 0)
    def reset_parameters(self) -> None:
        torch.nn.init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)
    def generate_mask(self, input_):
        # mask of token ids that fall inside this rank's vocab shard
        mask = (input_ >= self.vocab_start_index) & (input_ < self.vocab_end_index)
        return mask

    def forward(self, input_):
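        # 1) shift token ids into shard-local coordinates and zero ids this rank
        #    does not own, 2) do the local lookup, 3) zero the rows for foreign
        #    ids, 4) all-reduce so every rank sees the full embeddings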
        if self.tp_size > 1:
            # Build the mask.
            mask = self.generate_mask(input_)
            masked_input = input_ - self.vocab_start_index
            masked_input[~mask] = 0
        else:
            masked_input = input_

        output_parallel = F.embedding(
            masked_input,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )

        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel[~mask] = 0
            # Reduce across all the model parallel GPUs.
            dist.all_reduce(output_parallel, op=dist.ReduceOp.SUM, group=self.tp_group)
            output = output_parallel
        else:
            output = output_parallel
        return output
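

# Usage sketch (illustrative, not part of the module): assumes the tensor-parallel
# process group has already been initialised via specforge.distributed so that
# get_tp_group() returns a valid group on every rank; vocab_size, hidden_size,
# full_embedding_weight and input_ids are placeholder names. The layer then acts
# as a drop-in replacement for nn.Embedding with the vocab sharded over TP ranks:
#
#   embedding = VocabParallelEmbedding(vocab_size, hidden_size, padding_idx=0)
#   embedding.load_state_dict({"weight": full_embedding_weight})  # sharded by the pre-hook
#   hidden_states = embedding(input_ids)  # identical full output on every TP rank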