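"""Quantization and attention building blocks for the XQS stack.

Straight-through ternary weight quantization, a TernaryLinear layer with an
optional Triton backend, a causal attention helper that prefers flash-attn and
PyTorch's fused SDPA kernels, plus small row gather/scatter and torch.compile
utilities.
"""
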
from __future__ import annotations

import importlib.util
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from xqs_stack import choose_attention_backend, choose_quant_backend
from xqs_triton_ops import triton_ternary_linear


_HAS_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None
if _HAS_FLASH_ATTN:
    from flash_attn import flash_attn_func

_ATTN_BACKEND = choose_attention_backend(prefer_flash=True)
_QUANT_BACKEND = choose_quant_backend(prefer_triton=True)


def ternary_quantize(weight: torch.Tensor) -> torch.Tensor:
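    """Quantize ``weight`` to {-scale, 0, +scale} with a straight-through estimator.

    ``scale`` is the mean absolute weight; entries above ``0.5 * scale`` in
    magnitude snap to ``+/-scale`` and the rest to zero. The forward value is
    the quantized weight, while gradients flow to the full-precision weight.
    """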
    scale = weight.detach().abs().mean().clamp(min=1e-6)
    pos = weight > (0.5 * scale)
    neg = weight < (-0.5 * scale)
    quantized = torch.zeros_like(weight)
    quantized = torch.where(pos, torch.ones_like(weight), quantized)
    quantized = torch.where(neg, -torch.ones_like(weight), quantized)
    quantized = quantized * scale
    return weight + (quantized - weight).detach()


class TernaryLinear(nn.Module):
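    """Linear layer with ternary weight quantization.

    Dispatches to ``triton_ternary_linear`` when the resolved quantization
    backend is ``"triton"``; otherwise applies ``F.linear`` on the
    straight-through ternary-quantized weights.
    """
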
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = _QUANT_BACKEND
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Same initialization scheme as nn.Linear.
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        if self.bias is not None:
            bound = 1 / (max(1, self.in_features) ** 0.5)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
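        # Triton path: the kernel is assumed to quantize the weights internally.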
        if self.backend == "triton":
            return triton_ternary_linear(x, self.weight, self.bias)
        return F.linear(x, ternary_quantize(self.weight), self.bias)


def build_linear(in_features: int, out_features: int, bias: bool = True, ternary: bool = False) -> nn.Module:
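    """Build a ``TernaryLinear`` when ``ternary`` is True, else a standard ``nn.Linear``."""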
    if ternary:
        return TernaryLinear(in_features, out_features, bias=bias)
    return nn.Linear(in_features, out_features, bias=bias)


def fused_residual_add(x: torch.Tensor, residual: torch.Tensor, gate: Optional[torch.Tensor] = None) -> torch.Tensor:
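    """Residual addition, optionally scaling the residual by an elementwise ``gate``."""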
    if gate is None:
        return x + residual
    return x + (gate * residual)


def causal_scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float = 0.0,
    training: bool = False,
) -> torch.Tensor:
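    """Causal self-attention over ``(batch, heads, seq_len, head_dim)`` tensors.

    Prefers the flash-attn kernel (CUDA, fp16/bf16 only), then PyTorch's fused
    ``scaled_dot_product_attention``, then an explicit masked-softmax fallback.
    Dropout is applied only when ``training`` is True.
    """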
    if _ATTN_BACKEND == "flash_attn" and _HAS_FLASH_ATTN and q.is_cuda and q.dtype in {torch.float16, torch.bfloat16}:
        q_flash = q.transpose(1, 2).contiguous()
        k_flash = k.transpose(1, 2).contiguous()
        v_flash = v.transpose(1, 2).contiguous()
        out = flash_attn_func(
            q_flash,
            k_flash,
            v_flash,
            dropout_p=dropout_p if training else 0.0,
            causal=True,
        )
        return out.transpose(1, 2).contiguous()

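    # PyTorch fused kernel (available from 2.0 onward).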
    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=dropout_p if training else 0.0,
            is_causal=True,
        )

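    # Pure-PyTorch fallback: explicit upper-triangular causal mask and softmax.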
    scale = q.size(-1) ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    causal_mask = torch.triu(
        torch.ones(scores.size(-2), scores.size(-1), device=scores.device, dtype=torch.bool),
        diagonal=1,
    )
    scores = scores.masked_fill(causal_mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    if training and dropout_p > 0:
        probs = F.dropout(probs, p=dropout_p)
    return torch.matmul(probs, v)


def pack_rows(indices: torch.Tensor, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
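    """Gather the rows selected by ``indices`` (along dim 0) from each tensor."""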
    return tuple(t.index_select(0, indices) for t in tensors)


def scatter_rows(base: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
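    """Return a copy of ``base`` with the rows at ``indices`` replaced by ``updates``."""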
    if indices.numel() == 0:
        return base
    out = base.clone()
    out.index_copy_(0, indices, updates)
    return out


def maybe_compile_module(module: nn.Module, enabled: bool) -> nn.Module:
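    """Wrap ``module`` in ``torch.compile`` when enabled and available; otherwise return it as-is."""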
    if not enabled:
        return module
    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        return module
    try:
        return compile_fn(module)
    except Exception:
        return module
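

# Illustrative smoke test (a sketch, not part of the module's API): exercises the
# backend-independent helpers on CPU/float32 inputs. TernaryLinear is omitted here
# because its forward pass depends on the quantization backend resolved at import.
if __name__ == "__main__":
    torch.manual_seed(0)
    wq = ternary_quantize(torch.randn(8, 8))
    print("zero fraction:", (wq == 0).float().mean().item(), "| scale:", wq.abs().max().item())

    # (batch, heads, seq_len, head_dim) layout, as expected by the attention helper.
    q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
    out = causal_scaled_dot_product_attention(q, k, v)
    print("attention output shape:", tuple(out.shape))  # (2, 4, 16, 32)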