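"""Post-training quantization wrappers for nn.Linear layers.

Provides NVFP4 (via FlashInfer, when importable), FP8 W8A8 with dynamic
per-tensor activation scaling, a simpler FP8 variant with unscaled activations,
and quantize_model() to swap eligible layers throughout a model.
"""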

from typing import Optional

import torch
import torch.nn as nn


# Quantization modes supported by this build. None means "leave the model
# unquantized". The FP8 paths only require torch._scaled_mm, so they are always
# listed; "nvfp4" is appended below only when flashinfer is importable.
QUANTS = [
    None,
    "fp8",
    "w8a8",
]


try:
    from flashinfer import SfLayout, mm_fp4, nvfp4_quantize

    QUANTS.append("nvfp4")
except ImportError:
    pass


@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    # Quantize bf16 activations to NVFP4 on the fly, then run the FP4 GEMM
    # against the pre-quantized, transposed weights.
    a_fp4, a_sf = nvfp4_quantize(
        a_bf16,
        a_global_sf,
        sfLayout=SfLayout.layout_128x4,
        do_shuffle=False,
    )
    return mm_fp4(
        a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, out_dtype=torch.bfloat16, backend="cutlass"
    )


@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    # Shape-only stub so torch.compile and meta tracing can see through the op.
    return torch.empty(
        (a_bf16.shape[0], b_fp4_T.shape[1]), device=a_bf16.device, dtype=torch.bfloat16
    )


class FP4Linear(nn.Module):
    """FP4 Linear layer using FlashInfer's NVFP4 quantization."""

    def __init__(self, lin: nn.Linear):
        super().__init__()

        self.in_features = lin.in_features
        self.out_features = lin.out_features

        assert self.in_features % 32 == 0 and self.out_features % 32 == 0, (
            "in/out features must be multiples of 32 for NVFP4"
        )
        # Bias is not supported by this path; refuse rather than silently drop it.
        assert lin.bias is None, "FP4Linear does not support bias"

        self.weight = nn.Parameter(lin.weight.detach().clone())

        # Quantized weight tensors and scales, populated below.
        self._weight_fp4_T: Optional[torch.Tensor] = None
        self._weight_scales_T: Optional[torch.Tensor] = None
        self._alpha: Optional[torch.Tensor] = None
        self._dummy_scale: Optional[torch.Tensor] = None
        self._weight_global_sf = None

        assert self.weight.is_cuda, "Weights must be on GPU before quantization"

        with torch.no_grad():
            # The per-tensor activation scale is fixed at 1.0; only the weights
            # get a measured global scale factor.
            self._dummy_scale = torch.full(
                (1,), 1.0, device=self.weight.device, dtype=torch.float32
            )
            weight_bf16 = self.weight.to(torch.bfloat16).contiguous()
            weight_amax = weight_bf16.float().abs().nan_to_num().max()
            self._weight_global_sf = 1.0 / weight_amax
            # Dequantization factor applied to the GEMM output.
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16,
                self._weight_global_sf,
                sfLayout=SfLayout.layout_128x4,
                do_shuffle=False,
            )
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()

        # Warm-up call so kernel selection happens at init rather than on the
        # first real forward pass.
        lazy_x = torch.zeros(
            (1, lin.in_features), device=self.weight.device, dtype=torch.bfloat16
        )
        fp4_linear(
            lazy_x,
            self._weight_fp4_T,
            self._dummy_scale,
            self._weight_scales_T,
            self._alpha,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass using FP4 quantization and FlashInfer GEMM."""
        # Flatten leading dims to a 2-D [tokens, in_features] matrix for the GEMM.
        x_flat = x.reshape(-1, x.shape[-1])
        y = fp4_linear(
            x_flat.to(torch.bfloat16).contiguous(),
            self._weight_fp4_T,
            self._dummy_scale,
            self._weight_scales_T,
            self._alpha,
        )
        # Restore the original leading dims.
        return y.reshape(x.shape[:-1] + (-1,))
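

# Usage sketch (illustrative; assumes flashinfer is installed and a CUDA device):
#   lin = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
#   fp4 = FP4Linear(lin)
#   y = fp4(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))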


class FP8W8A8Linear(nn.Module):
    """FP8 linear: per-tensor weight scale computed once at init, per-tensor
    activation scale recomputed dynamically on every forward call (W8A8)."""

    __constants__ = ("in_features", "out_features")

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features

        f8 = torch.float8_e4m3fn
        inv = 1.0 / float(torch.finfo(f8).max)
        self._inv = inv

        # Map the abs-max weight onto the fp8 dynamic range.
        w = lin.weight.detach()
        ws = (w.abs().amax() * inv).clamp_min(1e-8).float()
        wf8 = (w / ws.to(w.dtype)).to(f8).contiguous()
        # torch._scaled_mm wants the second operand column-major, hence the
        # transpose of the row-major contiguous tensor.
        self.register_buffer("wT", wf8.t())
        self.register_buffer("ws", ws)

        if lin.bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", lin.bias.detach().to(torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = x.shape
        x2 = x.reshape(-1, s[-1])

        # Dynamic per-tensor activation scale, recomputed on every call.
        xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float()
        xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()

        y = torch._scaled_mm(
            xf8,
            self.wT,
            xs,
            self.ws,
            bias=self.bias,
            out_dtype=torch.float16,
            use_fast_accum=True,
        )
        return y.reshape(*s[:-1], self.out_features).to(x.dtype)
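

# Dequantization identity behind FP8W8A8Linear: with per-tensor scales
#   x ~= xs * xf8  and  W ~= ws * wf8,
# torch._scaled_mm returns xs * ws * (xf8 @ wf8.T), i.e. y ~= x @ W.T up to fp8
# rounding error.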


class FP8Linear(nn.Module):
    """FP8 linear with a static per-tensor weight scale; activations are cast to
    FP8 directly (scale 1.0) rather than rescaled."""

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features

        # torch._scaled_mm requires the bias in fp16/bf16, not fp8; keep it in
        # bf16 to match out_dtype.
        self.bias = (
            nn.Parameter(lin.bias.data.clone().to(torch.bfloat16))
            if lin.bias is not None
            else None
        )
        # Per-tensor weight scale from the abs-max, so the sign of the largest
        # weight cannot flip the scale.
        w_amax = lin.weight.data.clone().abs().amax().float().squeeze()
        w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
        self.register_buffer("w_amax", w_amax)
        self.register_buffer("weightT", w.t())
        # Registered as a buffer so it follows the module across .to()/.cuda().
        self.register_buffer(
            "dummy_scale", torch.ones((), device=lin.weight.device, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass using FP8 matmul.

        Args:
            x: Input tensor of shape [..., in_features] (flattened if > 2D)

        Returns:
            Output tensor of shape [..., out_features] in BF16, unflattened if
            the input was > 2D.
        """
        # Activations are cast straight to fp8; values beyond the e4m3 range
        # saturate.
        x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()

        result = torch._scaled_mm(
            x_fp8,
            self.weightT,
            bias=self.bias,
            scale_a=self.dummy_scale,
            scale_b=self.w_amax,
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        )

        return result.reshape(x.shape[:-1] + (-1,))
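

# Unlike FP8W8A8Linear, FP8Linear skips activation scaling entirely (scale fixed
# at 1.0), saving an abs-max reduction per call at some accuracy cost; it assumes
# activations already fit in the e4m3 range (|x| <= 448, saturating cast).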


def quantize_model(model: nn.Module, quant: Optional[str]):
    if quant is None:
        return model

    def eligible(m: nn.Module) -> bool:
        if not isinstance(m, nn.Linear):
            return False
        w = m.weight
        if w.dtype != torch.bfloat16:
            return False
        o, k = w.shape
        return (o % 32 == 0) and (k % 32 == 0)

    new_linear = {
        "w8a8": FP8W8A8Linear,
        "nvfp4": FP4Linear,
        "fp8": FP8Linear,
    }[quant]

    # Recursively replace eligible nn.Linear modules in place.
    for name, child in model.named_children():
        if eligible(child):
            setattr(model, name, new_linear(child))
        else:
            quantize_model(child, quant)
    return model
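

# Usage sketch (illustrative; `build_model` stands in for your own constructor,
# and "nvfp4" additionally requires flashinfer plus CUDA-resident weights):
#   model = build_model().cuda().to(torch.bfloat16)
#   model = quantize_model(model, "fp8")  # or "w8a8" / "nvfp4" / None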