from __future__ import annotations
import importlib.util
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from xqs_stack import choose_attention_backend, choose_quant_backend
from xqs_triton_ops import triton_ternary_linear
_HAS_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None
if _HAS_FLASH_ATTN:
    from flash_attn import flash_attn_func
_ATTN_BACKEND = choose_attention_backend(prefer_flash=True)
_QUANT_BACKEND = choose_quant_backend(prefer_triton=True)
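# Backend selection happens once, at import time. (choose_attention_backend and
# choose_quant_backend are project-local helpers from xqs_stack; their exact
# fallback behaviour when flash-attn or Triton is unavailable is assumed here,
# not shown in this file.)
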
def ternary_quantize(weight: torch.Tensor) -> torch.Tensor:
    """Quantize weights to {-scale, 0, +scale} with a straight-through estimator."""
    # Per-tensor scale: mean absolute value, clamped so the threshold is never zero.
    scale = weight.detach().abs().mean().clamp(min=1e-6)
    pos = weight > (0.5 * scale)
    neg = weight < (-0.5 * scale)
    quantized = torch.zeros_like(weight)
    quantized = torch.where(pos, torch.ones_like(weight), quantized)
    quantized = torch.where(neg, -torch.ones_like(weight), quantized)
    quantized = quantized * scale
    # Straight-through estimator: the forward pass uses the quantized values,
    # while gradients flow unchanged to the full-precision weight.
    return weight + (quantized - weight).detach()

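# Worked example (sketch): for w = [[0.9, -0.7, 0.1]], scale = mean|w| ~= 0.567,
# so the thresholds are +/-0.283 and the forward value is [[0.567, -0.567, 0.0]].
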
class TernaryLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = _QUANT_BACKEND
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Same scheme as nn.Linear: Kaiming-uniform weights, fan-in-scaled bias.
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            bound = 1 / max(1, self.in_features) ** 0.5
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.backend == "triton":
            return triton_ternary_linear(x, self.weight, self.bias)
        return F.linear(x, ternary_quantize(self.weight), self.bias)

def build_linear(in_features: int, out_features: int, bias: bool = True, ternary: bool = False) -> nn.Module:
    if ternary:
        return TernaryLinear(in_features, out_features, bias=bias)
    return nn.Linear(in_features, out_features, bias=bias)

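# Usage sketch (hypothetical sizes):
#   proj = build_linear(1024, 4096, ternary=True)   # quant-aware TernaryLinear
#   head = build_linear(1024, 32000, bias=False)    # plain nn.Linear
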
def fused_residual_add(x: torch.Tensor, residual: torch.Tensor, gate: Optional[torch.Tensor] = None) -> torch.Tensor:
    if gate is None:
        return x + residual
    return x + (gate * residual)

def causal_scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float = 0.0,
    training: bool = False,
) -> torch.Tensor:
    # Tier 1: flash-attn, which requires CUDA tensors in fp16/bf16 and the
    # (batch, seq_len, num_heads, head_dim) layout, hence the transposes.
    if _ATTN_BACKEND == "flash_attn" and _HAS_FLASH_ATTN and q.is_cuda and q.dtype in {torch.float16, torch.bfloat16}:
        q_flash = q.transpose(1, 2).contiguous()
        k_flash = k.transpose(1, 2).contiguous()
        v_flash = v.transpose(1, 2).contiguous()
        out = flash_attn_func(
            q_flash,
            k_flash,
            v_flash,
            dropout_p=dropout_p if training else 0.0,
            causal=True,
        )
        return out.transpose(1, 2).contiguous()
    # Tier 2: PyTorch's fused SDPA kernel (available since PyTorch 2.0).
    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=dropout_p if training else 0.0,
            is_causal=True,
        )
    # Tier 3: plain PyTorch fallback with an explicit upper-triangular mask.
    scale = q.size(-1) ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    causal_mask = torch.triu(torch.ones(scores.size(-2), scores.size(-1), device=scores.device, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    if training and dropout_p > 0:
        probs = F.dropout(probs, p=dropout_p)
    return torch.matmul(probs, v)

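# Usage sketch: q, k and v are (batch, num_heads, seq_len, head_dim), the layout
# F.scaled_dot_product_attention expects; the flash-attn branch transposes into
# its own (batch, seq_len, num_heads, head_dim) layout internally.
#   q = k = v = torch.randn(2, 8, 128, 64)
#   out = causal_scaled_dot_product_attention(q, k, v, dropout_p=0.1, training=True)
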
def pack_rows(indices: torch.Tensor, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    return tuple(t.index_select(0, indices) for t in tensors)


def scatter_rows(base: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
    if indices.numel() == 0:
        return base
    out = base.clone()
    out.index_copy_(0, indices, updates)
    return out

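# Round-trip sketch: gather a subset of rows, update them, and write them back.
#   idx = torch.tensor([0, 2])
#   (sub,) = pack_rows(idx, hidden)              # select rows 0 and 2
#   hidden = scatter_rows(hidden, idx, sub * 2)  # scatter the updates back
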
def maybe_compile_module(module: nn.Module, enabled: bool) -> nn.Module:
    if not enabled:
        return module
    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        return module
    try:
        return compile_fn(module)
    except Exception:
        return module
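

if __name__ == "__main__":
    # Minimal smoke test (sketch). It assumes choose_quant_backend falls back to
    # a non-Triton backend on machines without a GPU, so the F.linear path in
    # TernaryLinear.forward is the one exercised here.
    layer = build_linear(16, 32, ternary=True)
    y = layer(torch.randn(4, 16))
    assert y.shape == (4, 32)
    q = k = v = torch.randn(2, 4, 8, 8)
    out = causal_scaled_dot_product_attention(q, k, v)
    assert out.shape == q.shape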