from __future__ import annotations

import importlib.util
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from xqs_stack import choose_attention_backend, choose_quant_backend
from xqs_triton_ops import triton_ternary_linear

# flash-attn is an optional dependency: probe for it before importing so this
# module still loads on machines without the package.
_HAS_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None
if _HAS_FLASH_ATTN:
    from flash_attn import flash_attn_func

# Backend choices are resolved once at import time; the helpers below consult
# these module-level selections instead of re-probing on every call.
_ATTN_BACKEND = choose_attention_backend(prefer_flash=True)
_QUANT_BACKEND = choose_quant_backend(prefer_triton=True)


def ternary_quantize(weight: torch.Tensor) -> torch.Tensor:
    """Quantize weights to {-scale, 0, +scale} with a straight-through estimator.

    Weights below half the mean absolute value are zeroed; the rest snap to
    +/- scale. Gradients flow to ``weight`` as if no quantization happened.
    """
    scale = weight.detach().abs().mean().clamp(min=1e-6)
    pos = weight > (0.5 * scale)
    neg = weight < (-0.5 * scale)
    quantized = torch.zeros_like(weight)
    quantized = torch.where(pos, torch.ones_like(weight), quantized)
    quantized = torch.where(neg, -torch.ones_like(weight), quantized)
    quantized = quantized * scale
    # Straight-through estimator: forward uses `quantized`, backward sees identity.
    return weight + (quantized - weight).detach()
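

# Illustrative only (a minimal sketch; this helper is not called anywhere in
# the module): forward values collapse to {-scale, 0, +scale}, while the
# straight-through estimator makes the backward pass an identity.
def _ternary_quantize_example() -> None:
    w = torch.randn(4, 4, requires_grad=True)
    w_q = ternary_quantize(w)
    scale = w.detach().abs().mean()
    # Every forward value is (numerically) zero or +/- the shared scale.
    assert torch.all(torch.isclose(w_q.abs(), scale) | torch.isclose(w_q, torch.zeros_like(w_q)))
    w_q.sum().backward()  # identity backward: the gradient is all ones
    assert torch.equal(w.grad, torch.ones_like(w))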


class TernaryLinear(nn.Module):
    """Linear layer whose weights are ternarized on the fly during ``forward``.

    Master weights stay in full precision; quantization happens per forward
    pass, either through the fused Triton kernel or the eager
    straight-through path.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = _QUANT_BACKEND
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Mirror nn.Linear's default initialization for the full-precision master weights.
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            bound = 1 / (max(1, self.in_features) ** 0.5)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.backend == "triton":
            return triton_ternary_linear(x, self.weight, self.bias)
        return F.linear(x, ternary_quantize(self.weight), self.bias)
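

# Hypothetical usage sketch (the MLP below is illustrative, not part of this
# module's API); assumes the eager quantization backend, since the Triton
# kernel expects CUDA tensors.
def _ternary_linear_example() -> torch.Tensor:
    mlp = nn.Sequential(
        TernaryLinear(256, 512),
        nn.GELU(),
        TernaryLinear(512, 256, bias=False),
    )
    x = torch.randn(8, 256)
    out = mlp(x)  # weights are re-ternarized on every forward pass
    assert out.shape == (8, 256)
    return out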


def build_linear(in_features: int, out_features: int, bias: bool = True, ternary: bool = False) -> nn.Module:
    """Factory: a TernaryLinear when ``ternary`` is set, otherwise a plain nn.Linear."""
    if ternary:
        return TernaryLinear(in_features, out_features, bias=bias)
    return nn.Linear(in_features, out_features, bias=bias)


def fused_residual_add(x: torch.Tensor, residual: torch.Tensor, gate: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Residual addition with an optional elementwise gate on the residual branch."""
    if gate is None:
        return x + residual
    return x + (gate * residual)
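

# A minimal sketch of the gated form, assuming a per-channel sigmoid gate;
# the gate parameterization here is illustrative, not prescribed by this module.
def _gated_residual_example() -> torch.Tensor:
    x = torch.randn(2, 16, 64)
    residual = torch.randn(2, 16, 64)
    gate = torch.sigmoid(torch.zeros(64))  # broadcasts over batch and sequence dims
    return fused_residual_add(x, residual, gate=gate)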


def causal_scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float = 0.0,
    training: bool = False,
) -> torch.Tensor:
    """Causal attention over ``(batch, heads, seq, head_dim)`` tensors.

    Dispatches to flash-attn when available and applicable, then to PyTorch's
    fused kernel, and finally to an explicit masked-softmax fallback.
    """
    if _ATTN_BACKEND == "flash_attn" and _HAS_FLASH_ATTN and q.is_cuda and q.dtype in {torch.float16, torch.bfloat16}:
        # flash_attn_func expects (batch, seq, heads, head_dim).
        q_flash = q.transpose(1, 2).contiguous()
        k_flash = k.transpose(1, 2).contiguous()
        v_flash = v.transpose(1, 2).contiguous()
        out = flash_attn_func(
            q_flash,
            k_flash,
            v_flash,
            dropout_p=dropout_p if training else 0.0,
            causal=True,
        )
        return out.transpose(1, 2).contiguous()

    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=dropout_p if training else 0.0,
            is_causal=True,
        )

    # Eager fallback: explicit scores, causal mask, softmax, optional dropout.
    scale = q.size(-1) ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    causal_mask = torch.triu(
        torch.ones(scores.size(-2), scores.size(-1), device=scores.device, dtype=torch.bool),
        diagonal=1,
    )
    scores = scores.masked_fill(causal_mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    if training and dropout_p > 0:
        probs = F.dropout(probs, p=dropout_p)
    return torch.matmul(probs, v)
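

# Usage sketch under the assumed (batch, heads, seq_len, head_dim) layout;
# the shapes are illustrative.
def _attention_example() -> torch.Tensor:
    q = torch.randn(2, 8, 128, 64)
    k = torch.randn(2, 8, 128, 64)
    v = torch.randn(2, 8, 128, 64)
    out = causal_scaled_dot_product_attention(q, k, v, dropout_p=0.1, training=True)
    assert out.shape == q.shape
    return out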


def pack_rows(indices: torch.Tensor, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Gather the same subset of rows (along dim 0) from each tensor."""
    return tuple(t.index_select(0, indices) for t in tensors)


def scatter_rows(base: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``base`` with ``updates`` written into the rows at ``indices``."""
    if indices.numel() == 0:
        return base
    out = base.clone()
    out.index_copy_(0, indices, updates)
    return out
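

# Round-trip sketch: pack an active row subset, transform it, and scatter the
# result back without mutating the original tensor (illustrative only).
def _pack_scatter_example() -> torch.Tensor:
    hidden = torch.randn(10, 32)
    active = torch.tensor([1, 4, 7])
    (packed,) = pack_rows(active, hidden)
    updated = scatter_rows(hidden, active, packed * 2.0)
    assert torch.equal(updated[active], hidden[active] * 2.0)
    return updated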


def maybe_compile_module(module: nn.Module, enabled: bool) -> nn.Module:
    """Best-effort ``torch.compile``: fall back to the eager module on any failure."""
    if not enabled:
        return module
    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        # Older PyTorch builds have no torch.compile.
        return module
    try:
        return compile_fn(module)
    except Exception:
        # Compilation is an optimization, never a requirement.
        return module
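

# Illustrative wiring (a sketch, not project configuration): gate compilation
# behind a flag so older-PyTorch or unsupported environments degrade gracefully.
def _compile_example(use_compile: bool = False) -> nn.Module:
    model = build_linear(128, 128, ternary=True)
    return maybe_compile_module(model, enabled=use_compile)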