"""FP4 fake quantization (quantize -> dequantize in higher precision).

Two schemes are implemented as Triton kernels:

* ``rtn_1x16s_fp4`` -- round-to-nearest FP4 with per-group FP8-e4m3 scales
  and an optional "four over six" mode that, per group, picks whichever of
  two candidate scalings (max mapped to 6 vs. max mapped to 4) gives the
  smaller squared error.
* ``eden_1x16s_fp4`` -- Hadamard-rotated FP4 with a per-row correction
  factor and stochastic rounding of the corrected FP8 scales; also tracks
  the rotated tensor's amax for the next call (delayed-amax scheme).

``FakeQuartetFn`` wires these into a linear layer: RTN quantization of
input/weight in forward, EDEN quantization of the operands of both
backward matmuls. ``FakeQuartetLinear`` is the user-facing module.
"""

from random import randint

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from scipy.linalg import hadamard


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device):
    """Return an orthonormal Hadamard matrix of size ``group_size``.

    ``scipy.linalg.hadamard`` produces +-1 entries; multiplying by
    ``group_size**-0.5`` normalizes the rows so that H @ H.T == I, making the
    rotation norm-preserving.
    """
    return torch.tensor(
        hadamard(group_size) * group_size**-0.5,
        dtype=dtype,
        device=device,
        requires_grad=False,
    )


def rerotate_hadamard(hadamard_matrix):
    """Randomize the rotation by flipping the sign of random columns.

    Multiplies on the right by a diagonal +-1 matrix (each entry drawn
    uniformly from {-1, +1}), which keeps the matrix orthonormal.
    """
    signs = torch.diag(
        torch.randint(
            0,
            2,
            (hadamard_matrix.size(0),),
            device=hadamard_matrix.device,
            dtype=hadamard_matrix.dtype,
        )
        * 2
        - 1
    )
    return hadamard_matrix @ signs


@triton.jit
def _rtn_fp4(x):
    """Round-to-nearest onto the FP4 (e2m1) grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}.

    Each threshold is the midpoint between two adjacent representable
    magnitudes (e.g. 5 between 4 and 6, 1.75 between 1.5 and 2), so this is
    exact nearest-value rounding; the sign is re-applied at the end.
    Note x == 0 gets sign -1, which is harmless since the magnitude is 0.
    """
    x_abs = tl.abs(x)
    x_sign = tl.where(x > 0, 1, -1)
    x_fp4_abs = tl.where(
        x_abs >= 5, 6,
        tl.where(x_abs >= 3.5, 4,
        tl.where(x_abs >= 2.5, 3,
        tl.where(x_abs >= 1.75, 2,
        tl.where(x_abs >= 1.25, 1.5,
        tl.where(x_abs >= 0.75, 1,
        tl.where(x_abs >= 0.25, 0.5, 0.0)))))))
    return x_fp4_abs * x_sign


@triton.jit
def _get_scales(x, amax, val_max, scales_max):
    """Compute the two-level decode scales for one block of groups.

    Returns ``(s_dec_b_e4m3, s_dec)``:
    * ``s_dec``   -- global per-tensor scale derived from the tensor amax
                     (1.0 when amax == 0 to avoid division by zero);
    * ``s_dec_b_e4m3`` -- per-group scale (group absmax / val_max, expressed
                     relative to ``s_dec``), rounded through FP8 e4m3 and
                     replaced by 1.0 where it rounded to 0 so the later
                     division ``x / (s_dec_b_e4m3 * s_dec)`` stays finite.
    """
    s_dec = tl.where(amax == 0.0, 1.0, amax / scales_max / val_max)
    s_dec_b = tl.max(tl.abs(x), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    return s_dec_b_e4m3, s_dec


@triton.jit
def _get_alt_scales(x, val_max, s_dec):
    """Alternative per-group scale mapping the group's absmax to 4 (not 6).

    Identical to the per-group part of ``_get_scales`` except the extra
    factor ``6 / 4``, which makes the largest magnitude in the group land on
    FP4 value 4 instead of 6 (finer spacing for the rest of the group).
    """
    s_dec_b = tl.max(tl.abs(x), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b * (6 / 4) / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    return s_dec_b_e4m3


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def _rtn_1x16s_fp4_kernel(
    x_ptr,
    amax_ptr,
    output_ptr,
    n_elements: tl.constexpr,
    scale_override: tl.constexpr,
    group_size: tl.constexpr,
    four_over_six: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """RTN FP4 fake-quantize: load a block, quantize per group, store dequant.

    Each program handles BLOCK_SIZE contiguous elements viewed as
    (BLOCK_SIZE // group_size, group_size) groups. When ``four_over_six`` is
    set, both the standard (max -> 6) and the alternative (max -> 4) scalings
    are evaluated and the one with smaller per-group squared error wins.
    """
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Out-of-range lanes load 0.0 and are masked on store.
    x_flat = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    x_grouped = tl.reshape(x_flat, (BLOCK_SIZE // group_size, group_size))
    # NOTE(review): 448 is the FP8 e4m3 max; 256 in four-over-six mode looks
    # like a deliberate headroom choice -- confirm against the quant recipe.
    scales_max = 256.00 if four_over_six else 448.00
    val_max = 6.0 / scale_override
    amax = tl.load(amax_ptr)
    s_dec_b_e4m3, s_dec = _get_scales(x_grouped, amax, val_max, scales_max)
    x_scaled = x_grouped / (s_dec_b_e4m3 * s_dec)
    x_fp4 = _rtn_fp4(x_scaled)
    x_dequantized = x_fp4 * (s_dec_b_e4m3 * s_dec)
    if not four_over_six:
        best_x_dequantized = x_dequantized
    else:
        # Second candidate: group absmax mapped to 4; keep per group
        # whichever candidate has the lower sum of squared errors.
        alt_s_dec_b_e4m3 = _get_alt_scales(x_grouped, val_max, s_dec)
        alt_x_scaled = x_grouped / (alt_s_dec_b_e4m3 * s_dec)
        alt_x_fp4 = _rtn_fp4(alt_x_scaled)
        alt_x_dequantized = alt_x_fp4 * (alt_s_dec_b_e4m3 * s_dec)
        error_six = tl.sum((x_grouped - x_dequantized) * (x_grouped - x_dequantized), axis=-1, keep_dims=True)
        error_four = tl.sum((x_grouped - alt_x_dequantized) * (x_grouped - alt_x_dequantized), axis=-1, keep_dims=True)
        best_x_dequantized = tl.where(error_six <= error_four, x_dequantized, alt_x_dequantized)
    x_dequantized_flat = tl.reshape(best_x_dequantized, (BLOCK_SIZE,))
    tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask)


@torch.compiler.disable()
def rtn_1x16s_fp4(x, scale_override: float, group_size: int, four_over_six: bool):
    """Fake-quantize ``x`` with the RTN FP4 kernel; returns a tensor like ``x``.

    Computes the tensor amax eagerly on the host side and hands it to the
    kernel as a 0-dim tensor. Disabled under torch.compile because the
    caller (``FakeQuartetFn.forward``) is itself compiled.
    """
    x = x.contiguous()
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _rtn_1x16s_fp4_kernel[grid](
        x_ptr=x,
        amax_ptr=x.abs().max(),
        output_ptr=output,
        n_elements=n_elements,
        scale_override=scale_override,
        group_size=group_size,
        four_over_six=four_over_six,
    )
    return output


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def _eden_1x16s_fp4_kernel(
    x_ptr,
    hadamard_matrix_ptr,
    current_amax_ptr,
    output_ptr,
    next_amax_ptr,
    n_elements: tl.constexpr,
    hadamard_dim: tl.constexpr,
    scale_override: tl.constexpr,
    group_size: tl.constexpr,
    seed: int,
    BLOCK_SIZE: tl.constexpr,
):
    """EDEN-style FP4 fake-quantize with Hadamard rotation.

    Pipeline per program:
    1. rotate rows of the block by the (pre-transposed) Hadamard matrix;
    2. record the rotated amax into ``next_amax_ptr`` (for the next call's
       delayed-amax scheme) via a relaxed atomic max;
    3. two-level scale + FP4 round as in the RTN kernel, but using the
       *previous* amax from ``current_amax_ptr``;
    4. per Hadamard row, compute a correction factor <x,x>/<x,q> applied to
       the group scales (presumably the EDEN unbiasedness correction --
       TODO confirm against the EDEN paper);
    5. stochastically round the corrected scales to FP8 e4m3 by choosing
       between the two neighbouring representable values with probability
       proportional to the distance;
    6. store the dequantized result in the input dtype.
    """
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # The full hadamard_dim x hadamard_dim matrix is loaded by every program.
    offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim)
    hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim)
    x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_had = tl.dot(x, hadamard_matrix)
    # Publish this call's rotated amax for the *next* call (delayed amax).
    tl.atomic_max(next_amax_ptr, tl.max(tl.abs(x_had)).to(tl.float32), sem="relaxed")
    x_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size))
    # NOTE(review): 255.99 (vs 256.00 in the RTN kernel) -- presumably to
    # round down at the boundary; confirm intent.
    scales_max = 255.99
    val_max = 6.0 / scale_override
    amax = tl.load(current_amax_ptr)
    # Same two-level scale construction as _get_scales (inlined here).
    s_dec = tl.where(amax == 0.0, 1.0, amax / scales_max / val_max)
    s_dec_b = tl.max(tl.abs(x_grouped), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    x_scaled = x_grouped / (s_dec_b_e4m3 * s_dec)
    # Inlined copy of _rtn_fp4: nearest FP4 (e2m1) value, thresholds at
    # midpoints between representable magnitudes.
    x_scaled_abs = tl.abs(x_scaled)
    x_scaled_sign = tl.where(x_scaled > 0, 1, -1)
    x_fp4 = tl.where(
        x_scaled_abs >= 5, 6,
        tl.where(x_scaled_abs >= 3.5, 4,
        tl.where(x_scaled_abs >= 2.5, 3,
        tl.where(x_scaled_abs >= 1.75, 2,
        tl.where(x_scaled_abs >= 1.25, 1.5,
        tl.where(x_scaled_abs >= 0.75, 1,
        tl.where(x_scaled_abs >= 0.25, 0.5, 0))))))) * x_scaled_sign
    # Per-Hadamard-row correction: c = <x,x> / <x,q> (1.0 when <x,q> == 0).
    x_scaled = tl.reshape(x_scaled, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    num = tl.sum(x_scaled * x_scaled, axis=-1, keep_dims=True)
    denom = tl.sum(x_scaled * x_fp4, axis=-1, keep_dims=True)
    correction = tl.where(denom == 0.0, 1.0, num / denom)
    # Broadcast the row correction over that row's group scales.
    scales = tl.reshape(s_dec_b_e4m3, (BLOCK_SIZE // hadamard_dim, hadamard_dim // group_size))
    corrected_scales = tl.reshape(scales * correction, (BLOCK_SIZE // group_size, 1))
    # Neighbouring FP8 e4m3 values via the uint8 bit pattern: +-1 on the
    # bits steps to the adjacent representable magnitude.
    # NOTE(review): this assumes the pattern never sits at a boundary
    # (0, sign flip, NaN encoding) where +-1 is not a valid neighbour --
    # verify the reachable scale range guarantees that.
    bitscales = tl.cast(corrected_scales.to(tl.float8e4nv), tl.uint8, bitcast=True)
    prevscale = tl.cast((bitscales - 1), tl.float8e4nv, bitcast=True).to(tl.float32)
    currscale = tl.cast((bitscales), tl.float8e4nv, bitcast=True).to(tl.float32)
    nextscale = tl.cast((bitscales + 1), tl.float8e4nv, bitcast=True).to(tl.float32)
    # Bracket the exact corrected scale between its two FP8 neighbours.
    up = tl.where(currscale > corrected_scales, currscale, nextscale)
    down = tl.where(currscale > corrected_scales, prevscale, currscale)
    # NOTE(review): if up == down this divides by zero -- confirm unreachable.
    prob_up = (corrected_scales - down) / (up - down)
    # One independent uniform sample per group, indexed globally so
    # different programs use disjoint Philox offsets.
    scale_start_idx = pid * (BLOCK_SIZE // group_size)
    scale_offsets = scale_start_idx + tl.arange(0, BLOCK_SIZE // group_size)
    sampled_prob = tl.rand(seed, scale_offsets).reshape(BLOCK_SIZE // group_size, 1)
    scales = tl.where(sampled_prob < prob_up, up, down)
    scales = tl.reshape(scales, (BLOCK_SIZE // group_size, 1))
    x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // group_size, group_size))
    x_dequantized = x_fp4 * scales * s_dec
    x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,))
    tl.store(output_ptr + offsets, x_dequantized_flat.to(x_ptr.dtype.element_ty), mask=mask)


@torch.compiler.disable()
def eden_1x16s_fp4(x, hadamard_matrix, scale_override: float, group_size: int, current_amax):
    """Fake-quantize ``x`` with the EDEN kernel.

    Returns ``(output, next_amax)`` where ``next_amax`` is the amax of the
    Hadamard-rotated input, to be fed back as ``current_amax`` on the next
    call (delayed-amax scheme). The matrix is transposed here so the kernel's
    row-major load of ``hadamard_matrix_ptr`` sees H^T; a fresh host RNG seed
    is drawn per call for the stochastic scale rounding.
    """
    hadamard_dim = hadamard_matrix.size(0)
    x = x.contiguous()
    hadamard_matrix = hadamard_matrix.T.contiguous()
    output = torch.empty_like(x)
    seed = randint(0, 1_000_000)
    next_amax = torch.zeros_like(current_amax)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _eden_1x16s_fp4_kernel[grid](
        x_ptr=x,
        hadamard_matrix_ptr=hadamard_matrix,
        current_amax_ptr=current_amax,
        output_ptr=output,
        next_amax_ptr=next_amax,
        n_elements=n_elements,
        hadamard_dim=hadamard_dim,
        scale_override=scale_override,
        group_size=group_size,
        seed=seed,
    )
    return output, next_amax


class AmaxStorage:
    """Per-layer holder for the four delayed-amax values used in backward.

    Each slot starts as None (forcing an eager amax computation on first use)
    and is thereafter a 0-dim float tensor produced by ``eden_1x16s_fp4``.
    """

    __slots__ = ("e_ht_amax", "weght_tht_amax", "e_tht_amax", "input_tht_amax")

    def __init__(self):
        # amax of grad_output @ H^T (grad_input matmul, left operand)
        self.e_ht_amax = None
        # amax of weight^T rotated (grad_input matmul, right operand)
        self.weght_tht_amax = None
        # amax of grad_output^T rotated (grad_weight matmul, left operand)
        self.e_tht_amax = None
        # amax of input^T rotated (grad_weight matmul, right operand)
        self.input_tht_amax = None


class FakeQuartetFn(torch.autograd.Function):
    """Linear layer with FP4 fake quantization in forward and backward.

    Forward: RTN-quantize input and weight, then F.linear on the fake-
    quantized operands. Backward: EDEN-quantize (with a freshly re-signed
    Hadamard rotation) both operands of each of the two backward matmuls.
    The rotation cancels in the products since H H^T = I.
    """

    group_size = 16
    forward_scale_override = 1.0
    backward_scale_override = (17 / 16) * 0.93
    # Shared across all layers; installed by the first FakeQuartetLinear
    # constructed and re-signed on every backward pass.
    hadamard_matrix = None

    # NOTE(review): decorator order puts torch.compile *around* staticmethod;
    # this relies on staticmethod objects being callable (Python >= 3.10) --
    # confirm this composes as intended with autograd.Function.
    @torch.compile(dynamic=False)
    @staticmethod
    def forward(ctx, input, weight, amax_storage, delayed_amax, disable_forward_quant, disable_backward_quant, four_over_six):
        # input is assumed (batch, seq, in_dim); weight is (out_dim, in_dim).
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]
        ctx.delayed_amax = delayed_amax
        ctx.amax_storage = amax_storage
        ctx.disable_backward_quant = disable_backward_quant
        if disable_forward_quant:
            input_fq = input
            weight_fq = weight
        else:
            input_fq = rtn_1x16s_fp4(input, FakeQuartetFn.forward_scale_override, FakeQuartetFn.group_size, four_over_six)
            weight_fq = rtn_1x16s_fp4(weight, FakeQuartetFn.forward_scale_override, FakeQuartetFn.group_size, four_over_six)
        # The *fake-quantized* operands are saved: backward differentiates
        # through them, not the originals.
        ctx.save_for_backward(input_fq, weight_fq)
        return F.linear(input_fq, weight_fq)

    @staticmethod
    def backward(ctx, grad_output):
        input_fq, weight_fq = ctx.saved_tensors
        dtype = grad_output.dtype
        # Flatten (batch, seq, ...) to 2-D for the matmuls.
        input_fq = input_fq.to(dtype).reshape(ctx.batch * ctx.seq, ctx.in_dim)
        weight_fq = weight_fq.to(dtype)
        grad_output = grad_output.reshape(ctx.batch * ctx.seq, ctx.out_dim)
        # Fresh random column signs each backward pass (shared class-wide).
        FakeQuartetFn.hadamard_matrix = rerotate_hadamard(FakeQuartetFn.hadamard_matrix)
        if ctx.disable_backward_quant:
            # Plain (unquantized) backward; 7 outputs match the 7 forward args.
            grad_input = F.linear(grad_output, weight_fq.T, None).view(ctx.batch, ctx.seq, ctx.in_dim)
            grad_weight = F.linear(grad_output.T, input_fq.T, None)
            return grad_input, grad_weight, None, None, None, None, None
        had = FakeQuartetFn.hadamard_matrix.to(grad_output.dtype)
        bso = FakeQuartetFn.backward_scale_override
        gs = FakeQuartetFn.group_size
        # EW: grad_output @ weight^T
        # Each operand's amax is recomputed eagerly (first call, or when
        # delayed_amax is off); otherwise the value published by the
        # previous eden_1x16s_fp4 call is reused. eden returns the next amax.
        if ctx.amax_storage.e_ht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.e_ht_amax = (grad_output.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        e_ht_fp4, ctx.amax_storage.e_ht_amax = eden_1x16s_fp4(grad_output, had, bso, gs, ctx.amax_storage.e_ht_amax)
        if ctx.amax_storage.weght_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.weght_tht_amax = (weight_fq.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        weight_tht_fp4, ctx.amax_storage.weght_tht_amax = eden_1x16s_fp4(weight_fq.T, had, bso, gs, ctx.amax_storage.weght_tht_amax)
        # Both operands are rotated along out_dim, so H H^T = I cancels:
        # (G H)(W^T H)^T = G W.
        grad_input = F.linear(e_ht_fp4, weight_tht_fp4, None).view(ctx.batch, ctx.seq, ctx.in_dim)
        # EtX: grad_output^T @ input
        if ctx.amax_storage.e_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.e_tht_amax = (grad_output.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        e_tht_fp4, ctx.amax_storage.e_tht_amax = eden_1x16s_fp4(grad_output.T, had, bso, gs, ctx.amax_storage.e_tht_amax)
        if ctx.amax_storage.input_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.input_tht_amax = (input_fq.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        input_tht_fp4, ctx.amax_storage.input_tht_amax = eden_1x16s_fp4(input_fq.T, had, bso, gs, ctx.amax_storage.input_tht_amax)
        grad_weight = F.linear(e_tht_fp4, input_tht_fp4, None)
        return grad_input, grad_weight, None, None, None, None, None


class FakeQuartetLinear(torch.nn.Linear):
    """nn.Linear whose forward routes through FakeQuartetFn.

    Extra keyword args control quantization; positional args go to
    nn.Linear unchanged. Each instance owns its AmaxStorage, but the
    Hadamard matrix lives on FakeQuartetFn and is created once by the
    first layer constructed.
    """

    def __init__(self, *args, hadamard_dim=32, delayed_amax=False, disable_forward_quant=False, disable_backward_quant=False, four_over_six=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.hadamard_dim = hadamard_dim
        self.delayed_amax = delayed_amax
        self.disable_forward_quant = disable_forward_quant
        self.disable_backward_quant = disable_backward_quant
        self.four_over_six = four_over_six
        self.amax_storage = AmaxStorage()
        # NOTE(review): only the FIRST layer's hadamard_dim takes effect;
        # later layers with a different hadamard_dim silently reuse the
        # existing class-level matrix.
        if FakeQuartetFn.hadamard_matrix is None:
            FakeQuartetFn.hadamard_matrix = get_hadamard_matrix(
                self.hadamard_dim,
                dtype=torch.float32,
                device="cuda",
            )

    def forward(self, x):
        return FakeQuartetFn.apply(
            x,
            self.weight,
            self.amax_storage,
            self.delayed_amax,
            self.disable_forward_quant,
            self.disable_backward_quant,
            self.four_over_six,
        )