bbkdevops's picture
download
raw
10.5 kB
"""TinyMind INT6 2:4 sparse export.
INT6 is a TinyMind middle-precision format: higher fidelity than INT4 while
remaining smaller than INT8. Each 8-value chunk is split into four 2-value
pairs of adjacent values and keeps two of them (2:4 pair sparsity, four
non-zero values), stores the four signed INT6 values in 24 bits, and reuses a
4-bit pair-selection metadata nibble.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable
import torch
import torch.nn as nn
from model.sparse_int4 import pack_metadata_word, _pad_to_multiple
# Canonical format identifier; emitted as the manifest "format" field by
# export_sparse_int6_model and exposed as INT6SparseLinear.format_name.
FORMAT_NAME = "int6_2x4_pairwise_sparse"
# Short user-facing alias; emitted as the manifest "user_alias" field and
# exposed as INT6SparseLinear.user_alias.
USER_ALIAS = "int6_2:4sp"
@dataclass(frozen=True)
class Int6SparseChunk:
    """Encoded form of one 8-value chunk, as populated by ``pack_sparse_chunk_2x4_int6``."""

    # The four kept signed INT6 values: both values of the lower-index kept
    # pair, then both values of the higher-index kept pair.
    values: tuple[int, int, int, int]
    # 4-bit pair selector: lower kept pair index in bits 0-1, higher in bits 2-3.
    metadata_nibble: int
    # Indices (0-3) of the two kept pairs, in ascending order.
    nonzero_pairs: tuple[int, int]
def _check_int6(x: int) -> None:
if x < -32 or x > 31:
raise ValueError(f"{x} is outside signed INT6 range [-32, 31]")
def _to_s6(x: int) -> int:
    """Validate *x* as signed INT6 and return its 6-bit two's-complement field (0..63)."""
    _check_int6(x)
    low_six_bits = (1 << 6) - 1
    return x & low_six_bits
def _from_s6(x: int) -> int:
x &= 0x3F
return x - 64 if x >= 32 else x
def pack_int6_quad(values: Iterable[int]) -> bytes:
    """Pack four signed INT6 values into 3 little-endian bytes.

    Value ``i`` occupies bits ``6*i .. 6*i + 5`` of a 24-bit word.

    Raises:
        ValueError: if there are not exactly four values, or one is outside
            the signed INT6 range.
    """
    quad = [int(v) for v in values]
    if len(quad) != 4:
        raise ValueError(f"expected 4 INT6 values, got {len(quad)}")
    # Each field is 0..63 and the shifts are disjoint, so summing == OR-ing.
    word = sum(_to_s6(v) << (6 * slot) for slot, v in enumerate(quad))
    return word.to_bytes(3, "little")
def unpack_int6_quad(payload: bytes | Iterable[int]) -> tuple[int, int, int, int]:
    """Decode 3 little-endian bytes into four signed INT6 values.

    Inverse of ``pack_int6_quad``: value ``i`` is read from bits
    ``6*i .. 6*i + 5`` of the 24-bit word.

    Raises:
        ValueError: if *payload* is not exactly 3 bytes long.
    """
    raw = bytes(payload)
    if len(raw) != 3:
        raise ValueError(f"expected 3 packed bytes, got {len(raw)}")
    word = int.from_bytes(raw, "little")
    first, second, third, fourth = (_from_s6((word >> shift) & 0x3F) for shift in (0, 6, 12, 18))
    return first, second, third, fourth
def pack_sparse_chunk_2x4_int6(chunk: Iterable[int]) -> Int6SparseChunk:
    """Compress one 8-value chunk with 2:4 pair sparsity.

    The chunk is read as four pairs of adjacent values. Each pair must be
    entirely zero or entirely non-zero, and exactly two pairs must be
    non-zero. The kept values are returned in pair order, together with a
    metadata nibble ``low | (high << 2)`` naming the kept pairs.

    Raises:
        ValueError: on wrong length, out-of-range values, a half-zero pair,
            or a non-zero pair count other than two.
    """
    values = tuple(int(v) for v in chunk)
    if len(values) != 8:
        raise ValueError(f"expected 8 values, got {len(values)}")
    for v in values:
        _check_int6(v)

    kept_pairs: list[int] = []
    kept_values: list[int] = []
    for idx in range(4):
        first, second = values[2 * idx : 2 * idx + 2]
        # Exactly one zero in the pair cannot be expressed by the metadata.
        if (first == 0) != (second == 0):
            raise ValueError(f"pair {idx} must be both zero or both non-zero, got {(first, second)}")
        if first != 0:
            kept_pairs.append(idx)
            kept_values += (first, second)

    if len(kept_pairs) != 2:
        raise ValueError(f"expected exactly 2 non-zero pairs, got {kept_pairs}")
    low, high = kept_pairs
    return Int6SparseChunk(
        values=tuple(kept_values),  # type: ignore[arg-type]
        metadata_nibble=low | (high << 2),
        nonzero_pairs=(low, high),
    )
def pack_sparse_row_2x4_int6(row: Iterable[int]) -> tuple[bytes, list[int]]:
    """Pack one row of INT6 values into sparse payload bytes and metadata words.

    Every 8-value chunk contributes 3 payload bytes (``pack_int6_quad``).
    Every group of 8 chunks (64 values) contributes one metadata word built
    by ``pack_metadata_word`` from the chunks' nibbles. ``pack_metadata_word``
    is imported from ``model.sparse_int4``. It presumably places nibble ``i``
    at bits ``4*i``, which is how ``INT6SparseLinear.dequantize_weight`` reads
    it — confirm.

    Raises:
        ValueError: if the row length is not a multiple of 64, or any chunk
            is not encodable (see ``pack_sparse_chunk_2x4_int6``).
    """
    values = tuple(int(v) for v in row)
    if len(values) % 64:
        raise ValueError("row length must be a multiple of 64 for metadata words")

    payload = bytearray()
    words: list[int] = []
    for group_start in range(0, len(values), 64):
        nibbles: list[int] = []
        for chunk_start in range(group_start, group_start + 64, 8):
            encoded = pack_sparse_chunk_2x4_int6(values[chunk_start : chunk_start + 8])
            payload += pack_int6_quad(encoded.values)
            nibbles.append(encoded.metadata_nibble)
        words.append(pack_metadata_word(nibbles))
    return bytes(payload), words
def quantize_s6_per_row(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Symmetric per-row INT6 quantization with a least-squares scale refit.

    Expects a 2-D ``(rows, cols)`` float tensor. Each row is first scaled so
    its max magnitude maps to 31, then rounded and clamped to ``[-32, 31]``.
    Non-zero weights that round to 0 are forced to ``±1`` (matching the
    weight's sign), so they stay non-zero; zero weights stay exactly zero.
    The per-row scale is then refitted as ``|<q, w>| / <q, q>``, clamped to
    at least ``1e-6``.

    Returns:
        ``(q, scale)``: int16 codes shaped like *weight*, and float32
        per-row scales.
    """
    step = weight.abs().amax(dim=1).clamp(min=1e-6) / 31.0
    codes = torch.round(weight / step[:, None]).clamp(-32, 31).to(torch.int16)

    keep = weight != 0
    unit = torch.ones_like(codes)
    signed_unit = torch.where(weight >= 0, unit, -unit)
    codes = torch.where(keep & (codes == 0), signed_unit, codes)
    codes = torch.where(keep, codes, torch.zeros_like(codes))

    codes_f = codes.to(torch.float32)
    projection = (codes_f * weight).sum(dim=1)
    energy = (codes_f * codes_f).sum(dim=1).clamp(min=1e-6)
    fitted = (projection / energy).abs().clamp(min=1e-6)
    return codes, fitted.to(torch.float32)
def prune_tensor_pairwise_4x8_l2(weight: torch.Tensor) -> torch.Tensor:
    """Keep, in every 8-column group of each row, the 2 of 4 pairs with the most energy.

    Columns are grouped into pairs of adjacent values, four pairs per
    8-column chunk. The two pairs with the largest sum of squares are kept
    (they need not be adjacent to each other), and the other two are zeroed.
    Ties are resolved by ``torch.topk``. The input is padded with
    ``_pad_to_multiple(..., 8, dim=1)`` (presumably zero padding — confirm)
    before grouping, and the padding is sliced off the result.

    Returns:
        A detached tensor with the same shape as *weight*.

    Raises:
        ValueError: if *weight* is not 2-D.
    """
    if weight.dim() != 2:
        raise ValueError("expected a 2D weight tensor")
    in_cols = weight.shape[1]
    grouped = _pad_to_multiple(weight.detach(), 8, dim=1).clone()
    n_rows, n_cols = grouped.shape

    pairs = grouped.view(n_rows, n_cols // 8, 4, 2)
    energy = (pairs * pairs).sum(dim=-1)
    winners = torch.topk(energy, k=2, dim=-1).indices
    mask = torch.zeros_like(energy, dtype=torch.bool)
    mask.scatter_(-1, winners, True)

    pruned = pairs * mask.unsqueeze(-1)
    return pruned.view(n_rows, n_cols)[:, :in_cols]
def _make_pairwise_chunks_encodable(qweight: torch.Tensor) -> torch.Tensor:
    """Coerce a quantized matrix into a layout ``pack_sparse_chunk_2x4_int6`` accepts.

    The matrix is padded along columns with ``_pad_to_multiple(..., 64, dim=1)``
    (presumably zero padding — confirm) and copied to CPU. Every 8-value chunk
    (four pairs of adjacent values) is then rewritten so that exactly two
    pairs are kept, and both values of each kept pair are non-zero:

    * the two lowest-index pairs containing any non-zero value are kept, and
      every other pair is zeroed;
    * if fewer than two pairs qualify, the lowest-index remaining pairs are
      added until two are kept;
    * inside a kept pair, a zero value becomes ``+1``/``-1`` following the
      sign of its partner, and an all-zero pair becomes ``(1, 1)``.

    All chunks are processed at once with tensor ops. Per-element ``.item()``
    loops are far too slow for real weight matrices.

    NOTE(review): ``from_dense`` fits the per-row scales before this step, so
    every ``±1`` forced inside the real (non-padding) columns adds up to one
    quantization step of error to the exported weight.

    Returns:
        The padded, encodable CPU copy; *qweight* itself is not modified.
    """
    padded = _pad_to_multiple(qweight, 64, dim=1).cpu().clone()
    rows, cols = padded.shape
    view = padded.view(rows, cols // 8, 4, 2)

    # Pairs holding any non-zero value; keep the first two in index order.
    active = (view != 0).any(dim=-1)
    selected = active & (active.cumsum(dim=-1) <= 2)

    # Top up to exactly two kept pairs with the lowest-index unselected pairs.
    missing = 2 - selected.sum(dim=-1, keepdim=True)
    unselected = ~selected
    selected = selected | (unselected & (unselected.cumsum(dim=-1) <= missing))

    # Zero every pair that is not kept (writes through to `padded`).
    view.masked_fill_(~selected.unsqueeze(-1), 0)

    # Inside kept pairs, replace a zero with the sign of its partner. The sign
    # of 0 counts as +1, so an all-zero kept pair becomes (1, 1).
    first = view[..., 0]
    second = view[..., 1]
    unit = torch.ones_like(first)
    first_sign = torch.where(first >= 0, unit, -unit)
    second_sign = torch.where(second >= 0, unit, -unit)
    new_first = torch.where(selected & (first == 0), second_sign, first)
    new_second = torch.where(selected & (second == 0), first_sign, second)
    view[..., 0] = new_first
    view[..., 1] = new_second
    return padded
def _pack_quantized_matrix(qweight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Encode a quantized 2-D matrix into flat packed-byte and metadata tensors.

    The matrix is first made encodable (padded and pair-fixed) by
    ``_make_pairwise_chunks_encodable``, then each row is packed with
    ``pack_sparse_row_2x4_int6``.

    Returns:
        ``(packed_bytes, metadata_words, padded_in_features)``:
        - ``packed_bytes``: a uint8 tensor with every row's payload
          concatenated in row order;
        - ``metadata_words``: an int64 tensor with every row's metadata words
          concatenated in row order;
        - ``padded_in_features``: the padded column count used for the
          encoding.
    """
    encodable = _make_pairwise_chunks_encodable(qweight)
    padded_width = int(encodable.shape[1])
    payload = bytearray()
    words: list[int] = []
    for values in encodable.tolist():
        row_bytes, row_words = pack_sparse_row_2x4_int6(values)
        payload += row_bytes
        words += row_words
    return (
        torch.tensor(list(payload), dtype=torch.uint8),
        torch.tensor(words, dtype=torch.int64),
        padded_width,
    )
class INT6SparseLinear(nn.Module):
    """Linear layer whose weight is stored in the INT6 2:4 pair-sparse packed format.

    Buffers (as produced by ``from_dense``):
        packed_weight: flat uint8 payload. Row ``r`` occupies
            ``padded_in_features // 8 * 3`` bytes starting at
            ``r * padded_in_features // 8 * 3``, i.e. 3 bytes per 8-value chunk.
        metadata: flat int64 metadata words, ``padded_in_features // 64`` per
            row. Chunk ``c`` of a row uses the 4-bit nibble at bit offset
            ``4 * (c % 8)`` of word ``c // 8``.
        scales: per-output-row float32 scales.
        bias: float32 copy of the dense bias, or ``None``.

    The dense weight is rebuilt by ``dequantize_weight`` on every ``forward``
    call (no caching).
    """

    format_name = FORMAT_NAME
    user_alias = USER_ALIAS

    def __init__(
        self,
        packed_weight: torch.Tensor,
        metadata: torch.Tensor,
        scales: torch.Tensor,
        out_features: int,
        in_features: int,
        padded_in_features: int,
        bias: torch.Tensor | None = None,
    ):
        super().__init__()
        self.out_features = out_features
        self.in_features = in_features
        # Column count after padding to a multiple of 64 (one metadata word per 8 chunks).
        self.padded_in_features = padded_in_features
        self.register_buffer("packed_weight", packed_weight.contiguous())
        self.register_buffer("metadata", metadata.contiguous())
        self.register_buffer("scales", scales.contiguous())
        # Registering None keeps `bias` a (possibly empty) buffer slot either way,
        # so a later assignment of a tensor is tracked as a buffer.
        self.register_buffer(
            "bias",
            bias.detach().clone().to(torch.float32) if bias is not None else None,
        )

    @classmethod
    def from_dense(cls, layer: nn.Linear) -> "INT6SparseLinear":
        """Build a sparse INT6 layer from a dense ``nn.Linear``.

        Pipeline: pairwise 2:4 prune (float32) -> per-row INT6 quantization
        -> pair-fixing and packing. The source layer is not modified.
        """
        pruned = prune_tensor_pairwise_4x8_l2(layer.weight.detach().to(torch.float32))
        qweight, scales = quantize_s6_per_row(pruned)
        packed, metadata, padded_in = _pack_quantized_matrix(qweight)
        bias = layer.bias.detach() if layer.bias is not None else None
        return cls(packed, metadata, scales, layer.out_features, layer.in_features, padded_in, bias)

    def dequantize_weight(self) -> torch.Tensor:
        """Rebuild the dense float32 weight of shape ``(out_features, in_features)``.

        Each chunk's 3 bytes form a little-endian 24-bit word holding four
        6-bit two's-complement values. The chunk's metadata nibble names the
        destination pair slots: the low two bits receive values 0-1, and the
        next two bits receive values 2-3. Unnamed slots are zero. Padding
        columns beyond ``in_features`` are dropped, and row ``r`` is
        multiplied by ``scales[r]``. The result is placed on ``scales``'
        device.
        """
        rows = self.out_features
        chunks_per_row = self.padded_in_features // 8
        words_per_row = self.padded_in_features // 64
        device = self.packed_weight.device

        # Four signed 6-bit fields per chunk from its 24-bit little-endian word.
        packed = self.packed_weight[: rows * chunks_per_row * 3].to(torch.int64)
        packed = packed.reshape(rows, chunks_per_row, 3)
        words = packed[..., 0] | (packed[..., 1] << 8) | (packed[..., 2] << 16)
        fields = (words.unsqueeze(-1) >> torch.arange(0, 24, 6, device=device)) & 0x3F
        values = torch.where(fields >= 32, fields - 64, fields).reshape(rows, chunks_per_row, 2, 2)

        # Destination pair slot (0-3) of each compressed pair, from the chunk's nibble.
        meta = self.metadata[: rows * words_per_row].to(device=device, dtype=torch.int64)
        meta = meta.reshape(rows, words_per_row)
        chunk_ids = torch.arange(chunks_per_row, device=device)
        nibbles = (meta[:, chunk_ids // 8] >> ((chunk_ids % 8) * 4)) & 0xF
        first_slot = (nibbles & 0x3)[..., None, None].expand(-1, -1, 1, 2)
        second_slot = ((nibbles >> 2) & 0x3)[..., None, None].expand(-1, -1, 1, 2)

        # Two ordered scatters: if both nibble fields name the same slot, the
        # second compressed pair wins, exactly like a sequential per-pair write.
        dense = torch.zeros(rows, chunks_per_row, 4, 2, dtype=torch.int64, device=device)
        dense.scatter_(2, first_slot, values[:, :, 0:1, :])
        dense.scatter_(2, second_slot, values[:, :, 1:2, :])

        q = dense.reshape(rows, self.padded_in_features)[:, : self.in_features]
        q = q.to(device=self.scales.device, dtype=torch.float32)
        return q * self.scales[:, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ W.T + b`` using the freshly dequantized weight, cast to *x*'s device/dtype."""
        weight = self.dequantize_weight().to(device=x.device, dtype=x.dtype)
        bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
def export_sparse_int6_model(model: nn.Module, quality_gate_delta: float = 0.025) -> dict:
    """Convert eligible linear layers of *model* and return an export manifest.

    A module is converted when it is an ``nn.Linear`` with a 2-D weight and
    ``in_features >= 64``. Each converted ``INT6SparseLinear`` is stored under
    its qualified module name (``named_modules`` order). *model* itself is not
    rewired. *quality_gate_delta* is only recorded in the manifest.
    """
    converted: dict[str, INT6SparseLinear] = {
        name: INT6SparseLinear.from_dense(module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and module.weight.dim() == 2 and module.in_features >= 64
    }
    manifest = {
        "format": FORMAT_NAME,
        "user_alias": USER_ALIAS,
        "quality_gate_delta": quality_gate_delta,
        "bits_per_stored_value": 6,
        "pair_sparsity": "2_of_4_pairs_per_8_values",
        # 4 stored values x 6 bits per 8 dense weights = 3 bits per weight.
        "effective_payload_bits_per_dense_weight_before_metadata": 3.0,
        "layers": converted,
    }
    return manifest

Xet Storage Details

Size:
10.5 kB
·
Xet hash:
f3145ec4f16c1c902281cb08d5e87d7c1a702ae4c476e444fefcbf19995f861a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.