import math
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from specforge.distributed import get_tp_group, shard_tensor


class VocabParallelEmbedding(nn.Module):
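    """Embedding layer whose vocabulary is sharded across the tensor-parallel group.

    Each rank stores a contiguous slice of the embedding table. In the forward pass,
    token ids outside the local slice are masked, looked up as index 0, zeroed out,
    and the partial results are summed with an all-reduce, so every rank ends up with
    the same output a full ``nn.Embedding`` would produce.
    """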

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx

        # tp-related: resolve the tensor-parallel group, rank, and world size
        self.tp_group = get_tp_group()
        self.tp_rank = dist.get_rank(self.tp_group)
        self.tp_size = dist.get_world_size(self.tp_group)

        # deal with the case where the vocab size is not divisible by the TP size:
        # every rank gets an equally sized shard and the tail of the last shard is padding
        self.num_embeddings_per_shard = math.ceil(num_embeddings / self.tp_size)
        # number of padding rows needed so that shard_size * tp_size covers the vocab
        self.padded_num_embeddings = (
            self.num_embeddings_per_shard * self.tp_size - self.num_embeddings
        )
        self.vocab_start_index = self.tp_rank * self.num_embeddings_per_shard
        self.vocab_end_index = min(
            self.vocab_start_index + self.num_embeddings_per_shard,
            self.num_embeddings,
        )
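        # Worked example (illustrative numbers only): with num_embeddings=32000 and
        # tp_size=3, every rank holds ceil(32000/3)=10667 rows; rank 2 covers global
        # ids [21334, 32000) and the final row of its shard is padding.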

        # keep padding_idx only if it falls inside this rank's shard, re-indexed
        # relative to the start of the shard; otherwise another rank owns it
        if (
            padding_idx is not None
            and padding_idx >= self.vocab_start_index
            and padding_idx < self.vocab_end_index
        ):
            self.padding_idx = padding_idx - self.vocab_start_index
        else:
            self.padding_idx = None

        self.weight = nn.Parameter(
            torch.empty(
                (self.num_embeddings_per_shard, self.embedding_dim), **factory_kwargs
            ),
            requires_grad=True,
        )
        self.reset_parameters()
        self._register_load_state_dict_pre_hook(self.shard_state_dict)

    def shard_state_dict(self, state_dict, *args):
        # load_state_dict pre-hook: pad the full-vocabulary weight from the checkpoint
        # (if needed) and keep only this rank's shard. The key is looked up without a
        # prefix, i.e. the hook assumes the state dict is loaded directly on this module.
        if "weight" in state_dict:
            value = state_dict["weight"]

            # pad the vocab dimension so it is divisible by the TP size
            if value.shape[0] % self.tp_size != 0:
                padding_size = self.tp_size - value.shape[0] % self.tp_size
                value = F.pad(value, (0, 0, 0, padding_size))
            state_dict["weight"] = shard_tensor(value, self.tp_group, 0)

    def reset_parameters(self) -> None:
        torch.nn.init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def generate_mask(self, input_):
        # mask of token ids that fall inside the vocab shard owned by this rank
        mask = (input_ >= self.vocab_start_index) & (input_ < self.vocab_end_index)
        return mask

    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask of tokens owned by this rank, then shift their global
            # ids to shard-local indices; tokens owned by other ranks are clamped
            # to index 0 and their lookups are zeroed out below.
            mask = self.generate_mask(input_)
            masked_input = input_ - self.vocab_start_index
            masked_input[~mask] = 0
        else:
            masked_input = input_

        output_parallel = F.embedding(
            masked_input,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )

        # Mask the output embedding and reduce across the tensor-parallel group.
        if self.tp_size > 1:
            output_parallel[~mask] = 0
            # Each rank now holds partial embeddings (zeros for tokens it does not
            # own); summing across the group reconstructs the full result everywhere.
            dist.all_reduce(output_parallel, op=dist.ReduceOp.SUM, group=self.tp_group)
        return output_parallel
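

# Illustrative usage sketch (not part of the module; assumptions noted inline): it
# presumes torch.distributed and specforge's tensor-parallel group have already been
# initialized by the caller, and uses made-up vocab/hidden sizes and shapes.
#
#   embedding = VocabParallelEmbedding(32000, 4096, device="cuda")
#   token_ids = torch.randint(0, 32000, (2, 128), device="cuda")
#   hidden = embedding(token_ids)  # (2, 128, 4096), identical on every TP rank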