# RubiRLM-1B-Base / x_quantum_sparse_ops.py
from __future__ import annotations
import importlib.util
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from xqs_stack import choose_attention_backend, choose_quant_backend
from xqs_triton_ops import triton_ternary_linear

# flash-attn is an optional dependency; attention falls back to PyTorch's
# fused SDPA (or an explicit implementation) when it is not installed.
_HAS_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None
if _HAS_FLASH_ATTN:
    from flash_attn import flash_attn_func

_ATTN_BACKEND = choose_attention_backend(prefer_flash=True)
_QUANT_BACKEND = choose_quant_backend(prefer_triton=True)

def ternary_quantize(weight: torch.Tensor) -> torch.Tensor:
    """Quantize weights to {-scale, 0, +scale} with a straight-through estimator."""
    scale = weight.detach().abs().mean().clamp(min=1e-6)
    pos = weight > (0.5 * scale)
    neg = weight < (-0.5 * scale)
    quantized = torch.zeros_like(weight)
    quantized = torch.where(pos, torch.ones_like(weight), quantized)
    quantized = torch.where(neg, -torch.ones_like(weight), quantized)
    quantized = quantized * scale
    # Straight-through estimator: the forward pass uses the quantized weights,
    # while gradients flow unchanged to the full-precision parameter.
    return weight + (quantized - weight).detach()
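
# Illustrative sketch (not part of the original module; the helper name and toy
# values are ours): shows how ternary_quantize snaps weights onto three levels
# while the straight-through estimator keeps gradients identical to the
# full-precision case.
def _ternary_quantize_example() -> None:
    w = torch.tensor([[0.9, -0.02, -1.3], [0.1, 0.7, -0.6]], requires_grad=True)
    q = ternary_quantize(w)
    # Every entry of q is one of {-scale, 0, +scale}, with scale = mean(|w|).
    q.sum().backward()
    # Thanks to the straight-through estimator, the gradient of the sum is all
    # ones, exactly as if w had been used directly.
    assert torch.allclose(w.grad, torch.ones_like(w))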

class TernaryLinear(nn.Module):
    """Linear layer whose weights are ternarized on the fly at forward time."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = _QUANT_BACKEND
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Same initialization scheme as nn.Linear.
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            bound = 1 / (max(1, self.in_features) ** 0.5)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.backend == "triton":
            # Dedicated Triton kernel path for the ternary matmul.
            return triton_ternary_linear(x, self.weight, self.bias)
        return F.linear(x, ternary_quantize(self.weight), self.bias)

def build_linear(in_features: int, out_features: int, bias: bool = True, ternary: bool = False) -> nn.Module:
    """Factory returning either a TernaryLinear or a plain nn.Linear."""
    if ternary:
        return TernaryLinear(in_features, out_features, bias=bias)
    return nn.Linear(in_features, out_features, bias=bias)
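
# Illustrative sketch (assumes the pure-PyTorch quantization path, i.e. no
# Triton backend selected; shapes are arbitrary): the factory lets a model flip
# between full-precision and ternary projections from a single config flag.
def _build_linear_example() -> None:
    dense = build_linear(64, 128, ternary=False)   # plain nn.Linear
    packed = build_linear(64, 128, ternary=True)   # TernaryLinear with STE quantization
    x = torch.randn(2, 64)
    # Both layers share the same call signature and output shape.
    assert dense(x).shape == packed(x).shape == (2, 128)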

def fused_residual_add(x: torch.Tensor, residual: torch.Tensor, gate: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Residual addition with an optional elementwise gate on the residual branch."""
    if gate is None:
        return x + residual
    return x + (gate * residual)
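
# Illustrative sketch (helper name and sizes are ours): gating the residual
# lets a learned per-channel scalar, e.g. a sigmoid output, damp or amplify
# the skip connection.
def _gated_residual_example() -> None:
    x = torch.randn(2, 8)
    residual = torch.randn(2, 8)
    gate = torch.sigmoid(torch.zeros(8))  # 0.5 everywhere -> half-strength residual
    out = fused_residual_add(x, residual, gate)
    assert torch.allclose(out, x + 0.5 * residual)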

def causal_scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float = 0.0,
    training: bool = False,
) -> torch.Tensor:
    """Causal attention over (batch, heads, seq, head_dim) tensors.

    Dispatches to flash-attn when available, then to torch's fused SDPA,
    and finally to an explicit masked-softmax fallback.
    """
    if _ATTN_BACKEND == "flash_attn" and _HAS_FLASH_ATTN and q.is_cuda and q.dtype in {torch.float16, torch.bfloat16}:
        # flash_attn_func expects (batch, seq, heads, head_dim).
        q_flash = q.transpose(1, 2).contiguous()
        k_flash = k.transpose(1, 2).contiguous()
        v_flash = v.transpose(1, 2).contiguous()
        out = flash_attn_func(
            q_flash,
            k_flash,
            v_flash,
            dropout_p=dropout_p if training else 0.0,
            causal=True,
        )
        return out.transpose(1, 2).contiguous()
    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=dropout_p if training else 0.0,
            is_causal=True,
        )
    # Explicit fallback: scaled scores, upper-triangular causal mask, softmax.
    scale = q.size(-1) ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    causal_mask = torch.triu(torch.ones(scores.size(-2), scores.size(-1), device=scores.device, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    if training and dropout_p > 0:
        probs = F.dropout(probs, p=dropout_p)
    return torch.matmul(probs, v)
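
# Illustrative sketch (toy CPU tensors, so the flash-attn branch is skipped):
# the helper is drop-in regardless of which backend is selected, and causal
# masking means position 0 can only attend to itself.
def _causal_attention_example() -> None:
    batch, heads, seq, dim = 1, 2, 4, 8
    q = torch.randn(batch, heads, seq, dim)
    k = torch.randn(batch, heads, seq, dim)
    v = torch.randn(batch, heads, seq, dim)
    out = causal_scaled_dot_product_attention(q, k, v)
    assert out.shape == (batch, heads, seq, dim)
    # With a causal mask, the first position's output is exactly v[..., 0, :].
    assert torch.allclose(out[..., 0, :], v[..., 0, :], atol=1e-5)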

def pack_rows(indices: torch.Tensor, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Gather the selected rows (dimension 0) from each tensor."""
    return tuple(t.index_select(0, indices) for t in tensors)


def scatter_rows(base: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
    """Return a copy of `base` with `updates` written into the given row indices."""
    if indices.numel() == 0:
        return base
    out = base.clone()
    out.index_copy_(0, indices, updates)
    return out
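
# Illustrative sketch (indices and shapes are arbitrary): pack_rows and
# scatter_rows form a gather/scatter pair, e.g. for routing only the active
# token rows through an expensive sub-module and writing the results back.
def _pack_scatter_example() -> None:
    hidden = torch.zeros(5, 3)
    active = torch.tensor([1, 3])
    (packed,) = pack_rows(active, hidden)  # (2, 3) slice holding the active rows
    updated = packed + 1.0                 # stand-in for a heavy computation
    merged = scatter_rows(hidden, active, updated)
    assert merged[1].sum() == 3.0
    assert merged[0].sum() == 0.0          # untouched rows keep their values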

def maybe_compile_module(module: nn.Module, enabled: bool) -> nn.Module:
    """Wrap the module with torch.compile when requested and available."""
    if not enabled:
        return module
    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        return module
    try:
        return compile_fn(module)
    except Exception:
        # Compilation is best-effort; fall back to eager execution on failure.
        return module
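
# Illustrative sketch (assumes a PyTorch build where torch.compile is usable;
# on older builds the wrapper simply returns the eager module): compilation is
# opt-in and degrades gracefully, so any block can be wrapped behind a flag.
def _maybe_compile_example() -> None:
    layer = build_linear(32, 32, ternary=True)
    layer = maybe_compile_module(layer, enabled=True)
    out = layer(torch.randn(4, 32))
    assert out.shape == (4, 32)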