import math
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from specforge.distributed import get_tp_group, shard_tensor


class ParallelLMHead(nn.Module):
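    """Vocabulary-parallel (column-parallel) LM head.

    The [out_features, in_features] projection weight is sharded along the output
    (vocabulary) dimension across the tensor-parallel group, so each rank holds one
    contiguous slice of the vocabulary. The vocab size is rounded up to a multiple of
    the TP world size; the padding logits are trimmed again in forward() when
    gather_output=True. Full-vocab checkpoint weights are split on load via a
    load_state_dict pre-hook (see shard_state_dict).
    """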
def __init__(
self,
in_features: int,
out_features: int,
*,
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
self.tp_group = get_tp_group()
self.tp_size = dist.get_world_size(self.tp_group)
self.tp_rank = dist.get_rank(self.tp_group)
        # Tensor-parallel sharding: each rank owns a contiguous, equally sized slice of
        # the vocabulary (ceil-div), so the last slice may contain padding rows.
        self.out_features_per_shard = math.ceil(out_features / self.tp_size)
        # Number of padding rows appended so the sharded vocab is a multiple of tp_size.
        self.padded_out_features = (
            self.out_features_per_shard * self.tp_size - out_features
        )
        assert (
            self.out_features_per_shard * self.tp_size
            == out_features + self.padded_out_features
        )
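        # Example: out_features=32003 with tp_size=4 gives out_features_per_shard=8001
        # (4 * 8001 = 32004), i.e. one padding row, which forward() trims when gather_output=True.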
self.weight = nn.Parameter(
torch.empty(self.out_features_per_shard, self.in_features, **factory_kwargs)
)
self.bias = (
nn.Parameter(torch.zeros(self.out_features_per_shard, **factory_kwargs))
if bias
else None
)
# init params
self.reset_parameters()
# handle weight loading
self._register_load_state_dict_pre_hook(self.shard_state_dict)

    def shard_state_dict(self, state_dict, prefix, *args):
        """load_state_dict pre-hook: pad the full-vocab weight/bias up to a multiple of
        the TP size, then keep only this rank's shard along the vocab dimension."""
        weight_key = f"{prefix}weight"
        if weight_key in state_dict:
            value = state_dict[weight_key]
            # pad the vocab dim if it is not divisible by the TP size
            if value.shape[0] % self.tp_size != 0:
                padding_size = self.tp_size - value.shape[0] % self.tp_size
                value = F.pad(value, (0, 0, 0, padding_size))
            state_dict[weight_key] = shard_tensor(value, self.tp_group, 0)
        bias_key = f"{prefix}bias"
        if bias_key in state_dict:
            value = state_dict[bias_key]
            # pad the vocab dim if it is not divisible by the TP size
            if value.shape[0] % self.tp_size != 0:
                padding_size = self.tp_size - value.shape[0] % self.tp_size
                value = F.pad(value, (0, padding_size))
            state_dict[bias_key] = shard_tensor(value, self.tp_group, 0)

    def forward(self, hidden: torch.Tensor, gather_output: bool = False):
        """
        hidden: [B, T, H] or [N, H]
        returns:
            - if gather_output=False: local logits [*, out_features_per_shard]; this rank's shard
              covers global vocab indices [tp_rank * out_features_per_shard,
              min((tp_rank + 1) * out_features_per_shard, out_features))
            - if gather_output=True: full logits [*, out_features] via all-gather
              (dist.all_gather is not differentiable, so use this path for inference only)
        """
orig_shape = hidden.shape
hidden = hidden.reshape(-1, self.in_features) # [N, H]
local_logits = hidden @ self.weight.T # [N, local_vocab]
if self.bias is not None:
local_logits = local_logits + self.bias
if not gather_output or self.tp_size == 1:
return local_logits.view(
*orig_shape[:-1], self.out_features_per_shard
).contiguous()
else:
# all-gather shards along vocab dim
chunks = [torch.empty_like(local_logits) for _ in range(self.tp_size)]
dist.all_gather(chunks, local_logits, group=self.tp_group)
full = torch.cat(chunks, dim=-1)[
:, : self.out_features
] # trim padding from ceil-div
return full.view(*orig_shape[:-1], self.out_features).contiguous()

    def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)

    def __repr__(self):
        return f"ParallelLMHead(in_features={self.in_features}, out_features={self.out_features}, out_features_per_shard={self.out_features_per_shard}, tp_size={self.tp_size}, tp_rank={self.tp_rank})"