import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from specforge.distributed import get_tp_group, shard_tensor


class RowParallelLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        kv_head_replicas=False,
        layout_type: str = "normal",
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.layout_type = layout_type
        self.tp_group = get_tp_group()
        self.tp_size = dist.get_world_size(self.tp_group)
        self.tp_rank = dist.get_rank(self.tp_group)
        self.in_features = in_features
        self.out_features = out_features
        if kv_head_replicas:
            self.in_features_per_shard = in_features
        else:
            self.in_features_per_shard = in_features // self.tp_size
        self.weight = nn.Parameter(
            torch.empty(self.out_features, self.in_features_per_shard, **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()
        self._register_load_state_dict_pre_hook(self.shard_state_dict)

    def shard_state_dict(self, state_dict, *args):
        """
        State-dict pre-hook triggered before loading. It shards the weights and
        biases according to the layout type.
        """
        if self.layout_type == "normal":
            self.handle_normal_layout(state_dict, *args)
        else:
            raise ValueError(f"Invalid layout type: {self.layout_type}")

    def handle_normal_layout(self, state_dict, *args):
        # shard the weights along the input (last) dimension
        if "weight" in state_dict:
            state_dict["weight"] = shard_tensor(state_dict["weight"], self.tp_group, -1)
        # only rank 0 keeps the bias; the other ranks zero it out
        if "bias" in state_dict and self.tp_rank != 0:
            state_dict["bias"] = torch.zeros_like(state_dict["bias"])

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def reset_parameters(self):
        nn.init.xavier_normal_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def __repr__(self):
        return (
            f"RowParallelLinear(in_features={self.in_features_per_shard}, "
            f"out_features={self.out_features}, tp_size={self.tp_size}, tp_rank={self.tp_rank})"
        )
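

# Illustrative usage sketch (hypothetical sizes; assumes torch.distributed and the
# SpecForge tensor-parallel group are already initialised, and that shard_tensor
# keeps this rank's slice along the given dimension). It shows how the load-time
# pre-hook turns a full checkpoint weight into a per-rank shard.
def _example_row_parallel_load():
    # Full (unsharded) checkpoint tensors for a 4096 -> 1024 projection.
    layer = RowParallelLinear(4096, 1024, bias=True)
    full_state = {
        "weight": torch.randn(1024, 4096),
        "bias": torch.randn(1024),
    }
    # shard_state_dict slices the weight along dim -1, so each rank loads a
    # [1024, 4096 // tp_size] shard; the bias is zeroed on every rank except
    # rank 0, so it is counted only once when the partial outputs are later
    # summed across the tensor-parallel ranks.
    layer.load_state_dict(full_state)
    return layer.weight.shape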


class ColumnParallelLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        layout_type: str = "normal",
        kv_head_replicas=False,
        kv_head_idx=None,
        total_num_kv_heads=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.layout_type = layout_type
        self.tp_group = get_tp_group()
        self.tp_size = dist.get_world_size(self.tp_group)
        self.tp_rank = dist.get_rank(self.tp_group)
        self.in_features = in_features
        self.out_features = out_features
        self.kv_head_replicas = kv_head_replicas
        self.kv_head_idx = kv_head_idx
        self.total_num_kv_heads = total_num_kv_heads
        if self.kv_head_replicas:
            self.out_features_per_shard = out_features
        else:
            self.out_features_per_shard = out_features // self.tp_size
        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_shard, self.in_features, **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(
                torch.empty(self.out_features_per_shard, **factory_kwargs)
            )
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()
        self._register_load_state_dict_pre_hook(self.shard_state_dict)

    def shard_state_dict(self, state_dict, *args):
        """
        State-dict pre-hook triggered before loading. It shards the weights and
        biases according to the layout type.
        """
        if self.kv_head_replicas:
            assert self.kv_head_idx is not None
            assert self.layout_type == "normal"
            self.handle_kv_head_replicas(state_dict, *args)
        else:
            if self.layout_type == "normal":
                self.handle_normal_layout(state_dict, *args)
            elif self.layout_type == "merged_qkv":
                self.handle_merged_qkv(state_dict, *args)
            elif self.layout_type == "gate_up":
                self.handle_gate_up_layout(state_dict, *args)
            else:
                raise ValueError(f"Invalid layout type: {self.layout_type}")

    def handle_kv_head_replicas(self, state_dict, *args):
        """
        Special case for GQA: the key/value weights are split according to the
        number of kv heads, and each rank keeps only the head assigned to it.
        Because the TP size is larger than the number of kv heads, a single kv
        head is kept (and replicated) per rank.
        """
        if "weight" in state_dict:
            state_dict["weight"] = state_dict["weight"].chunk(
                self.total_num_kv_heads, dim=0
            )[self.kv_head_idx]
        if "bias" in state_dict and state_dict["bias"] is not None:
            state_dict["bias"] = state_dict["bias"].chunk(
                self.total_num_kv_heads, dim=0
            )[self.kv_head_idx]
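
    # Worked example for handle_kv_head_replicas (hypothetical sizes): with
    # total_num_kv_heads=4, head_dim=128 and tp_size=8, a checkpoint k_proj
    # weight of shape [4 * 128, hidden_size] is chunked into four per-head
    # slices of shape [128, hidden_size]; each rank keeps only the slice for
    # its kv_head_idx, so every rank pointing at the same head holds an
    # identical replica of that slice.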

    def handle_normal_layout(self, state_dict, *args):
        """
        Shards the weights and biases along the column (output) dimension.
        """
        # shard the weights
        if "weight" in state_dict:
            state_dict["weight"] = shard_tensor(state_dict["weight"], self.tp_group, 0)
        if "bias" in state_dict and state_dict["bias"] is not None:
            state_dict["bias"] = shard_tensor(state_dict["bias"], self.tp_group, 0)
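
    # Worked example for handle_normal_layout (hypothetical sizes): with
    # out_features=4096 and tp_size=4, shard_tensor is expected to keep this
    # rank's 1024 rows of the checkpoint weight, so the local weight shard has
    # shape [1024, in_features] and the bias shard has shape [1024].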

    def handle_gate_up_layout(self, state_dict, *args):
        """
        Handles the gate_up layout, where the gate and up projection weights are
        concatenated along the column dimension.
        """
        if "weight" in state_dict:
            # split the fused tensor, shard each half, then re-fuse so that this
            # rank keeps its own gate and up slices
            gate, up = state_dict["weight"].chunk(2, dim=0)
            gate = shard_tensor(gate, self.tp_group, 0)
            up = shard_tensor(up, self.tp_group, 0)
            state_dict["weight"] = torch.cat((gate, up), dim=0)
        if "bias" in state_dict and state_dict["bias"] is not None:
            gate, up = state_dict["bias"].chunk(2, dim=0)
            gate = shard_tensor(gate, self.tp_group, 0)
            up = shard_tensor(up, self.tp_group, 0)
            state_dict["bias"] = torch.cat((gate, up), dim=0)
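
    # Note on the gate_up layout: sharding the fused [gate; up] tensor directly
    # along dim 0 would give some ranks only gate columns and others only up
    # columns. Splitting first, sharding each part, and re-concatenating keeps a
    # matching gate shard and up shard on every rank.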

    def handle_merged_qkv(self, state_dict, *args):
        """
        Handles the merged QKV layout, where the q, k and v projection weights
        are concatenated along the column dimension.
        """
        if "weight" in state_dict:
            # split into q, k and v, shard each one, then re-fuse so that this
            # rank keeps its own q, k and v slices
            q, k, v = state_dict["weight"].chunk(3, dim=0)
            q = shard_tensor(q, self.tp_group, 0)
            k = shard_tensor(k, self.tp_group, 0)
            v = shard_tensor(v, self.tp_group, 0)
            state_dict["weight"] = torch.cat((q, k, v), dim=0)
        if "bias" in state_dict and state_dict["bias"] is not None:
            q, k, v = state_dict["bias"].chunk(3, dim=0)
            q = shard_tensor(q, self.tp_group, 0)
            k = shard_tensor(k, self.tp_group, 0)
            v = shard_tensor(v, self.tp_group, 0)
            state_dict["bias"] = torch.cat((q, k, v), dim=0)
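
    # Note on the merged QKV layout: chunk(3, dim=0) splits the fused tensor
    # into three equal parts, so this path assumes the q, k and v projections
    # all have the same output size (i.e. the same number of heads).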

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def reset_parameters(self):
        nn.init.xavier_normal_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def __repr__(self):
        return (
            f"ColumnParallelLinear(in_features={self.in_features}, "
            f"out_features={self.out_features_per_shard}, tp_size={self.tp_size}, tp_rank={self.tp_rank})"
        )
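

# Illustrative usage sketch (hypothetical sizes; assumes torch.distributed and the
# SpecForge tensor-parallel group are already initialised, and that shard_tensor
# keeps this rank's slice along the given dimension). It shows how a fused gate/up
# checkpoint tensor is loaded into a ColumnParallelLinear with the gate_up layout.
def _example_gate_up_load():
    hidden_size, intermediate_size = 4096, 11008
    layer = ColumnParallelLinear(
        hidden_size, 2 * intermediate_size, bias=False, layout_type="gate_up"
    )
    fused = {"weight": torch.randn(2 * intermediate_size, hidden_size)}
    # The pre-hook splits the fused tensor into gate and up halves, shards each
    # along dim 0, and re-concatenates, so every rank ends up with a
    # [2 * intermediate_size // tp_size, hidden_size] weight holding its own
    # gate and up slices.
    layer.load_state_dict(fused)
    return layer.weight.shape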