Buckets:
bbkdevops/unicosys-hypergraph-bucket / tinymind-native-8b-remote-handoff /bundle /model /sparse_int6.py
| """TinyMind INT6 2:4 sparse export. | |
| INT6 is a TinyMind middle-precision format: higher fidelity than INT4 while | |
| remaining smaller than INT8. Each 8-value chunk keeps two adjacent 2-value | |
| pairs (2:4 pair sparsity, four non-zero values), stores the four signed INT6 | |
| values in 24 bits, and reuses a 4-bit pair-selection metadata nibble. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Iterable | |
| import torch | |
| import torch.nn as nn | |
| from model.sparse_int4 import pack_metadata_word, _pad_to_multiple | |
| FORMAT_NAME = "int6_2x4_pairwise_sparse" | |
| USER_ALIAS = "int6_2:4sp" | |
@dataclass(frozen=True)
class Int6SparseChunk:
    """One 8-value chunk after 2-of-4 pair compression.

    ``pack_sparse_chunk_2x4_int6`` builds instances by keyword, so the class
    needs the generated dataclass constructor; without the decorator the
    annotations below are inert and construction raises ``TypeError``.
    """

    # The four values of the two kept pairs, in ascending pair order.
    values: tuple[int, int, int, int]
    # Bits 0-1 hold the first kept pair index, bits 2-3 the second.
    metadata_nibble: int
    # Indices (0-3) of the two kept pairs within the 8-value chunk.
    nonzero_pairs: tuple[int, int]
| def _check_int6(x: int) -> None: | |
| if x < -32 or x > 31: | |
| raise ValueError(f"{x} is outside signed INT6 range [-32, 31]") | |
def _to_s6(x: int) -> int:
    """Encode a signed INT6 value as an unsigned 6-bit two's-complement field.

    Raises:
        ValueError: propagated from ``_check_int6`` when *x* is out of range.
    """
    _check_int6(x)
    # Masking to six bits maps -32..-1 onto 32..63 and leaves 0..31 unchanged.
    six_bit_field = x & 0x3F
    return six_bit_field
| def _from_s6(x: int) -> int: | |
| x &= 0x3F | |
| return x - 64 if x >= 32 else x | |
def pack_int6_quad(values: Iterable[int]) -> bytes:
    """Pack four signed INT6 values into three little-endian bytes.

    Value ``i`` occupies bits ``6*i`` to ``6*i + 5`` of a 24-bit word; the
    word is returned least-significant byte first.

    Raises:
        ValueError: if *values* does not hold exactly four items, or any
            item is outside the signed INT6 range.
    """
    xs = tuple(int(x) for x in values)
    if len(xs) != 4:
        raise ValueError(f"expected 4 INT6 values, got {len(xs)}")
    word = 0
    for shift, value in zip(range(0, 24, 6), xs):
        word |= _to_s6(value) << shift
    return word.to_bytes(3, "little")
def unpack_int6_quad(payload: bytes | Iterable[int]) -> tuple[int, int, int, int]:
    """Unpack three little-endian bytes into four signed INT6 values.

    Inverse of ``pack_int6_quad``: value ``i`` is read from bits ``6*i`` to
    ``6*i + 5`` of the 24-bit word.

    Raises:
        ValueError: if *payload* does not convert to exactly three bytes.
    """
    bs = bytes(payload)
    if len(bs) != 3:
        raise ValueError(f"expected 3 packed bytes, got {len(bs)}")
    word = int.from_bytes(bs, "little")
    first, second, third, fourth = (
        _from_s6((word >> shift) & 0x3F) for shift in (0, 6, 12, 18)
    )
    return first, second, third, fourth
def pack_sparse_chunk_2x4_int6(chunk: Iterable[int]) -> Int6SparseChunk:
    """Compress one 8-value chunk down to its two non-zero pairs.

    The chunk is read as four adjacent pairs. Exactly two pairs must be
    fully non-zero and the other two fully zero. The kept values are
    returned in pair order together with a metadata nibble that stores the
    first kept pair index in bits 0-1 and the second in bits 2-3.

    Raises:
        ValueError: if the chunk is not 8 values long, a value is outside
            the signed INT6 range, a pair mixes zero and non-zero, or the
            number of non-zero pairs is not exactly two.
    """
    xs = tuple(int(x) for x in chunk)
    if len(xs) != 8:
        raise ValueError(f"expected 8 values, got {len(xs)}")
    for x in xs:
        _check_int6(x)

    nonzero_pairs: list[int] = []
    kept_values: list[int] = []
    for pair_index, (a, b) in enumerate(zip(xs[0::2], xs[1::2])):
        # The two zero-tests disagree exactly when the pair is half zero.
        if (a == 0) != (b == 0):
            raise ValueError(f"pair {pair_index} must be both zero or both non-zero, got {(a, b)}")
        if a != 0:
            nonzero_pairs.append(pair_index)
            kept_values += (a, b)

    if len(nonzero_pairs) != 2:
        raise ValueError(f"expected exactly 2 non-zero pairs, got {nonzero_pairs}")

    first_pair, second_pair = nonzero_pairs
    return Int6SparseChunk(
        values=tuple(kept_values),  # type: ignore[arg-type]
        metadata_nibble=first_pair | (second_pair << 2),
        nonzero_pairs=(first_pair, second_pair),
    )
def pack_sparse_row_2x4_int6(row: Iterable[int]) -> tuple[bytes, list[int]]:
    """Pack one quantized row into payload bytes and metadata words.

    The row is consumed in groups of 64 values. Each group holds eight
    8-value chunks; every chunk contributes three payload bytes and one
    metadata nibble, and the eight nibbles of a group are combined into a
    single word by ``pack_metadata_word``.

    Returns:
        ``(payload, metadata_words)``: the concatenated payload bytes and
        one metadata word per 64 input values.

    Raises:
        ValueError: if the row length is not a multiple of 64, or a chunk
            is rejected by ``pack_sparse_chunk_2x4_int6``.
    """
    xs = tuple(int(x) for x in row)
    if len(xs) % 64:
        raise ValueError("row length must be a multiple of 64 for metadata words")

    payload = bytearray()
    words: list[int] = []
    for group_start in range(0, len(xs), 64):
        nibbles: list[int] = []
        for chunk_start in range(group_start, group_start + 64, 8):
            sparse = pack_sparse_chunk_2x4_int6(xs[chunk_start : chunk_start + 8])
            payload += pack_int6_quad(sparse.values)
            nibbles.append(sparse.metadata_nibble)
        words.append(pack_metadata_word(nibbles))
    return bytes(payload), words
def quantize_s6_per_row(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D weight tensor to signed INT6 levels with one scale per row.

    The initial step size maps each row's largest magnitude to level 31.
    Non-zero weights that would round to level 0 are forced to +1 or -1
    (matching their sign), so the zero pattern of *weight* is preserved
    exactly. The returned scale is then refit per row as
    ``|<q, w> / <q, q>|``, floored at 1e-6.

    Returns:
        ``(q, scale)``: ``q`` is int16 with values in [-32, 31] and the
        same shape as *weight*; ``scale`` is float32 with one entry per row.
    """
    step = weight.abs().amax(dim=1).clamp(min=1e-6) / 31.0
    levels = (weight / step.unsqueeze(1)).round().clamp(-32, 31).to(torch.int16)

    live = weight != 0
    plus_one = torch.ones_like(levels)
    signed_one = torch.where(weight >= 0, plus_one, -plus_one)
    # Keep every non-zero weight non-zero after rounding ...
    levels = torch.where(live & (levels == 0), signed_one, levels)
    # ... and every zero weight exactly zero.
    levels = torch.where(live, levels, torch.zeros_like(levels))

    levels_f = levels.to(torch.float32)
    dot_qw = (levels_f * weight).sum(dim=1)
    dot_qq = (levels_f * levels_f).sum(dim=1).clamp(min=1e-6)
    refit = (dot_qw / dot_qq).abs().clamp(min=1e-6)
    return levels, refit.to(torch.float32)
def prune_tensor_pairwise_4x8_l2(weight: torch.Tensor) -> torch.Tensor:
    """Zero all but the two highest-energy adjacent pairs in every 8 values.

    Each row is padded to a multiple of 8 columns and split into 8-value
    chunks of four adjacent pairs. A pair's score is the sum of its squared
    entries; the two best-scoring pairs per chunk survive and the rest are
    multiplied by zero. The result is cropped back to the input width.

    Raises:
        ValueError: if *weight* is not 2D.
    """
    if weight.dim() != 2:
        raise ValueError("expected a 2D weight tensor")

    original_cols = weight.shape[1]
    work = _pad_to_multiple(weight.detach(), 8, dim=1).clone()
    n_rows, n_cols = work.shape
    pairs = work.view(n_rows, n_cols // 8, 4, 2)

    energy = torch.mul(pairs, pairs).sum(dim=-1)
    winners = torch.topk(energy, k=2, dim=-1).indices
    keep = torch.zeros_like(energy, dtype=torch.bool).scatter_(-1, winners, True)

    pruned = pairs * keep[..., None]
    return pruned.view(n_rows, n_cols)[:, :original_cols]
def _make_pairwise_chunks_encodable(qweight: torch.Tensor) -> torch.Tensor:
    """Force every 8-value chunk into the shape the chunk packer accepts.

    Works on a CPU copy of *qweight* padded to a multiple of 64 columns.
    After the call each chunk has exactly two pairs whose entries are both
    non-zero and two pairs that are all zero:

    * at most the first two pairs containing any non-zero entry are kept and
      every other pair is zeroed;
    * if fewer than two pairs qualify, the lowest-index remaining pairs are
      added to reach two;
    * in a kept pair, an all-zero pair becomes ``(1, 1)`` and a single zero
      entry becomes +1 or -1, matching the sign of its partner.

    Note that this writes non-zero levels into chunks (including padding)
    that had fewer than two populated pairs.
    """
    padded = _pad_to_multiple(qweight, 64, dim=1).cpu().clone()
    n_rows, n_cols = padded.shape
    grouped = padded.view(n_rows, n_cols // 8, 4, 2)

    for r in range(n_rows):
        for c in range(n_cols // 8):
            pairs = grouped[r, c]
            live = [p for p in range(4) if bool((pairs[p] != 0).any().item())]
            kept = live[:2]

            for p in range(4):
                if p not in kept:
                    pairs[p].zero_()

            # Top up to two kept pairs with the lowest unused indices.
            spare = [p for p in range(4) if p not in kept]
            kept = kept + spare[: 2 - len(kept)]

            for p in kept:
                pair = pairs[p]
                first = int(pair[0].item())
                second = int(pair[1].item())
                if first == 0 and second == 0:
                    pair[0] = 1
                    pair[1] = 1
                elif first == 0:
                    pair[0] = 1 if second >= 0 else -1
                elif second == 0:
                    pair[1] = 1 if first >= 0 else -1
    return padded
def _pack_quantized_matrix(qweight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Pack a quantized weight matrix row by row into flat storage tensors.

    Returns:
        ``(packed, metadata, padded_in_features)``: a flat uint8 tensor with
        the payload bytes of all rows, a flat int64 tensor with the metadata
        words of all rows, and the padded row width that was encoded.
    """
    encodable = _make_pairwise_chunks_encodable(qweight)
    payload_bytes: list[int] = []
    metadata_words: list[int] = []
    for row in encodable.tolist():
        row_bytes, row_words = pack_sparse_row_2x4_int6(row)
        # Extending a list with ``bytes`` appends the individual byte values.
        payload_bytes += row_bytes
        metadata_words += row_words

    packed_tensor = torch.tensor(payload_bytes, dtype=torch.uint8)
    metadata_tensor = torch.tensor(metadata_words, dtype=torch.int64)
    return packed_tensor, metadata_tensor, int(encodable.shape[1])
class INT6SparseLinear(nn.Module):
    """Linear layer whose weight is stored in the INT6 pairwise-sparse format.

    The weight is held as three flat buffers: ``packed_weight`` (three bytes
    per 8-value chunk), ``metadata`` (one word per 64 padded input columns)
    and ``scales`` (one float per output row). ``forward`` rebuilds a dense
    weight on every call and delegates to ``torch.nn.functional.linear``.
    """

    format_name = FORMAT_NAME
    user_alias = USER_ALIAS

    def __init__(
        self,
        packed_weight: torch.Tensor,
        metadata: torch.Tensor,
        scales: torch.Tensor,
        out_features: int,
        in_features: int,
        padded_in_features: int,
        bias: torch.Tensor | None = None,
    ):
        """Store the packed buffers and layer dimensions.

        Args:
            packed_weight: flat payload bytes, rows concatenated.
            metadata: flat metadata words, rows concatenated.
            scales: per-output-row dequantization scale.
            out_features: number of output rows.
            in_features: number of real (unpadded) input columns.
            padded_in_features: encoded row width; ``dequantize_weight``
                divides it by 8 and 64 to locate chunks and metadata words.
            bias: optional bias; stored as a detached float32 copy.
        """
        super().__init__()
        self.out_features = out_features
        self.in_features = in_features
        self.padded_in_features = padded_in_features
        self.register_buffer("packed_weight", packed_weight.contiguous())
        self.register_buffer("metadata", metadata.contiguous())
        self.register_buffer("scales", scales.contiguous())
        if bias is not None:
            self.register_buffer("bias", bias.detach().clone().to(torch.float32))
        else:
            self.bias = None  # type: ignore[assignment]

    # Alternate constructor: it takes ``cls`` and is invoked on the class
    # itself (``INT6SparseLinear.from_dense(module)``), so it must be a
    # classmethod; as a plain method that call binds the layer to ``cls``
    # and fails with a missing ``layer`` argument.
    @classmethod
    def from_dense(cls, layer: nn.Linear) -> "INT6SparseLinear":
        """Build a packed layer from a dense ``nn.Linear``.

        The weight is cast to float32, pruned to two pairs per 8 values,
        quantized per row to INT6 and packed; the bias, if any, is carried
        over.
        """
        pruned = prune_tensor_pairwise_4x8_l2(layer.weight.detach().to(torch.float32))
        qweight, scales = quantize_s6_per_row(pruned)
        packed, metadata, padded_in = _pack_quantized_matrix(qweight)
        bias = layer.bias.detach() if layer.bias is not None else None
        return cls(packed, metadata, scales, layer.out_features, layer.in_features, padded_in, bias)

    def dequantize_weight(self) -> torch.Tensor:
        """Rebuild the dense float32 weight of shape (out_features, in_features).

        Decoding runs in pure Python, one chunk at a time. The result is
        created on the device of ``scales``.
        """
        rows: list[list[int]] = []
        chunks_per_row = self.padded_in_features // 8
        packed_per_row = chunks_per_row * 3
        metadata_words_per_row = self.padded_in_features // 64
        for row_idx in range(self.out_features):
            start = row_idx * packed_per_row
            bytes_row = self.packed_weight[start : start + packed_per_row].tolist()
            meta_start = row_idx * metadata_words_per_row
            meta_row = self.metadata[meta_start : meta_start + metadata_words_per_row].tolist()
            vals = [0] * self.padded_in_features
            for chunk_idx in range(chunks_per_row):
                # Eight chunks share one metadata word; chunk k's nibble is
                # read from bits 4*(k % 8) upward. This assumes
                # pack_metadata_word stores nibbles in that order; confirm
                # against model.sparse_int4.
                metadata_word = int(meta_row[chunk_idx // 8])
                nibble = (metadata_word >> ((chunk_idx % 8) * 4)) & 0xF
                pair_indices = (nibble & 0x3, (nibble >> 2) & 0x3)
                byte_offset = chunk_idx * 3
                compressed = unpack_int6_quad(bytes(bytes_row[byte_offset : byte_offset + 3]))
                chunk_base = chunk_idx * 8
                # Scatter the two stored pairs back to their dense positions;
                # all other positions in the chunk stay zero.
                for value_offset, pair_index in enumerate(pair_indices):
                    src = value_offset * 2
                    dst = chunk_base + pair_index * 2
                    vals[dst] = compressed[src]
                    vals[dst + 1] = compressed[src + 1]
            # Drop the padding columns added for encoding.
            rows.append(vals[: self.in_features])
        q = torch.tensor(rows, dtype=torch.float32, device=self.scales.device)
        return q * self.scales[:, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer: dequantize, then ``linear`` in the dtype and device of *x*."""
        weight = self.dequantize_weight().to(device=x.device, dtype=x.dtype)
        bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
def export_sparse_int6_model(model: nn.Module, quality_gate_delta: float = 0.025) -> dict:
    """Convert eligible linear layers of *model* and describe the export.

    A submodule is converted when it is an ``nn.Linear`` with a 2D weight
    and at least 64 input features. *model* itself is not modified.

    Returns:
        A dict with format descriptors, the *quality_gate_delta* value as
        passed in, and ``"layers"``: a mapping from module name to its
        ``INT6SparseLinear`` replacement.
    """
    converted: dict[str, INT6SparseLinear] = {
        name: INT6SparseLinear.from_dense(module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and module.weight.dim() == 2 and module.in_features >= 64
    }
    manifest = {
        "format": FORMAT_NAME,
        "user_alias": USER_ALIAS,
        "quality_gate_delta": quality_gate_delta,
        "bits_per_stored_value": 6,
        "pair_sparsity": "2_of_4_pairs_per_8_values",
        "effective_payload_bits_per_dense_weight_before_metadata": 3.0,
        "layers": converted,
    }
    return manifest
Xet Storage Details
- Size:
- 10.5 kB
- Xet hash:
- f3145ec4f16c1c902281cb08d5e87d7c1a702ae4c476e444fefcbf19995f861a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.