Instructions to use Entrit/tritllm-codec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Entrit/tritllm-codec with Transformers:
  # Load model directly
  from transformers import AutoModel
  model = AutoModel.from_pretrained("Entrit/tritllm-codec", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """Ternary quantizer v2 — multi-config single-pass with per-matrix checkpointing. | |
| Key improvements over v1: | |
| 1. Multi-config: --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3 | |
| Computes per-group MSE-best scales (over a fixed 4-candidate set) ONCE per | |
| matrix, derives all configs. ~3x faster than running v1 four times. | |
| 2. Per-matrix checkpoint: each matrix's quantized output saved to .checkpoint/ | |
| dir as soon as it's done. Crash-resume picks up where it left off. | |
| 3. Durable atomic writes (write to .tmp, fsync, rename) — no half-written or | |
| post-power-loss-truncated checkpoints. | |
| 4. Streaming progress.json — monitors can poll without parsing logs. | |
| 5. Per-config HF model assembled at the end from checkpoints. | |
| 6. Resume validation: a fingerprint of (codec version, model id, revision, | |
| group size, depth-power mapping) plus the tensor shape is stored in each | |
| checkpoint and re-checked on resume. A mismatch causes the stale checkpoint | |
| to be discarded and re-quantized rather than silently mixed. | |
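| What this codec quantizes (and what it does not): | |
| - Quantized: every 2-D weight matrix whose parameter name contains none of | |
| 'norm', 'embed' or 'lm_head' (in practice, the linear layers). | |
| - Kept in the load dtype (FP16 or BF16, per --dtype): token embeddings, | |
| all *_norm layers, lm_head, and every non-2-D parameter (e.g. biases). | |
| This matches the convention used by GPTQ/AWQ/NF4 and is what the paper's | |
| bits-per-weight figures account for. | |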
| Usage: | |
| python quantize_model_v2.py --model Qwen/Qwen2.5-7B \ | |
| --configs uniform-d2,uniform-d3 \ | |
| --output /path/to/output_root \ | |
| --revision <git-sha-of-source-model> \ | |
| --workers 8 --dtype float16 | |
| Output structure: | |
| output_root/ | |
| .checkpoint/ | |
| matrix_00000__model_layers_0_self_attn_q_proj.npz # all configs in one file ('.' -> '_', '/' -> '__') | |
| matrix_00001__model_layers_0_self_attn_k_proj.npz | |
| ... | |
| progress.json # live status | |
| <config>/ | |
| model/ # HF-format output | |
| config.json # codec params + BPW stats (sits in <config>/, next to model/) | |
| """ | |
| import os, sys, time, json, gc, argparse, tempfile | |
| from multiprocessing import Pool | |
| import numpy as np | |
# ============================================================
# CODEC CORE (unchanged from v1)
# ============================================================
# Number of consecutive weights that share a single scale.
GS = 16
# Exponent applied to the normalised level grid by build_levels(), per trit
# depth d (3**d levels). 1.0 keeps the grid uniform; >1.0 pulls interior
# levels toward zero while the outermost levels stay at +/-half.
DEPTH_POWERS = {
    1: 1.0,  # 3 levels
    2: 1.5,  # 9 levels
    3: 1.2,  # 27 levels
    4: 1.0,  # 81 levels
}
def build_levels(half, power):
    """Return the 2*half+1 reconstruction levels of a symmetric grid.

    The integer grid -half..half is normalised to [-1, 1], warped by
    sign(x) * |x|**power and scaled back to [-half, half]. With power == 1.0
    the plain integer grid (as float64) is returned unchanged.
    """
    grid = np.arange(-half, half + 1).astype(np.float64)
    if power == 1.0:
        return grid
    denom = max(half, 1)  # guards half == 0
    unit = grid / denom
    return np.sign(unit) * np.abs(unit) ** power * denom
def make_boundaries(level_map, zero_boundary=None):
    """Return the decision thresholds for nearest-level rounding over level_map.

    Each threshold is the midpoint of two adjacent levels. If zero_boundary
    is given, the two thresholds on either side of the level closest to zero
    become -|zero_boundary| and +|zero_boundary|. This resizes the zero bin,
    as d1 does with a custom zero-zone width.
    """
    bounds = (level_map[:-1] + level_map[1:]) / 2
    if zero_boundary is None:
        return bounds
    width = abs(zero_boundary)
    centre = int(np.argmin(np.abs(level_map)))
    if centre > 0:
        bounds[centre - 1] = -width
    if centre < len(level_map) - 1:
        bounds[centre] = width
    return bounds
def compute_best_scale_4cand(groups, depth, power, zero_boundary=None):
    """Choose, per group, the scale with the lowest reconstruction MSE among
    four order-statistic candidates.

    The candidates come from each row's sorted absolute weights at columns
    [gs-6, gs-4, gs-2, gs-1], clipped to [0, gs-1]. For gs=16 these are roughly
    the 69th/81st/94th/100th percentiles. Each candidate is divided by
    max(half, 1) so that value maps to the outermost level.

    This is a deliberately small candidate set, not an exhaustive sweep.
    Empirically <1% PPL gap from a dense sweep on Qwen2.5-7B; in exchange
    quantization is ~50x faster than evaluating every percentile.

    Args:
        groups: 2-D array, one weight group per row.
        depth, power: select the level grid (see build_levels).
        zero_boundary: optional zero-bin override (see make_boundaries).

    Returns:
        (best_scale, best_mse): float64 arrays with one entry per row. Ties keep
        the earliest candidate.
    """
    half = (3 ** depth) // 2
    n_groups, gs = groups.shape
    denom = max(half, 1)
    sorted_abs = np.sort(np.abs(groups), axis=1)
    candidates = np.clip(np.array([gs - 6, gs - 4, gs - 2, gs - 1]), 0, gs - 1)
    levels = build_levels(half, power)
    bounds = make_boundaries(levels, zero_boundary)
    top = len(levels) - 1

    best_scale = np.zeros(n_groups)
    best_mse = np.full(n_groups, np.inf)
    for col in candidates:
        trial = np.maximum(sorted_abs[:, col] / denom, 1e-30)
        codes = np.clip(np.searchsorted(bounds, (groups / trial[:, None]).ravel()), 0, top)
        recon = levels[codes].reshape(n_groups, gs) * trial[:, None]
        err = np.mean((groups - recon) ** 2, axis=1)
        improved = err < best_mse  # strict: an equal later candidate never wins
        best_mse = np.where(improved, err, best_mse)
        best_scale = np.where(improved, trial, best_scale)
    return best_scale, best_mse
# Backwards-compatible alias. Older scripts and the published paper repo call
# the scale search by this "MSE-optimal" name. It is the same function object,
# so the 4-candidate caveat in compute_best_scale_4cand's docstring applies.
compute_optimal_scale = compute_best_scale_4cand
def trit_quantize_scales(scales, sd):
    """Snap each per-group scale to a 3**sd-entry log-uniform codebook.

    The codebook spans [0.1th percentile, max] of log(scales). Clipping the low
    end keeps a few tiny scales (e.g. from all-zero groups) from stretching the
    grid. The top is the true maximum, so large scales are never clipped.

    Args:
        scales: 1-D array of per-group scales; values below 1e-30 are treated
            as 1e-30.
        sd: scale depth; the codebook has 2 * (3**sd // 2) + 1 == 3**sd entries.

    Returns:
        Array with one entry per input scale: the codebook value nearest in log
        space. On an exact tie the lower codeword is chosen. Empty input gives
        an empty array.
    """
    scales = np.asarray(scales)
    if scales.size == 0:
        return np.zeros(0)
    log_scales = np.log(np.maximum(scales, 1e-30))
    half = (3 ** sd) // 2
    n_levels = 2 * half + 1
    log_min = np.percentile(log_scales, 0.1)
    log_max = np.max(log_scales)  # 100th pct — never clip large scales
    if log_max - log_min < 1e-9:
        log_max = log_min + 1e-9
    codebook_log = np.linspace(log_min, log_max, n_levels)
    # Find the nearest codeword by binary search over the midpoints between
    # neighbouring codewords. This matches argmin(|log_scales[:, None] -
    # codebook_log|) but needs O(N) temporary memory instead of two dense
    # N x n_levels float64 arrays (27 values per group at sd=3).
    midpoints = (codebook_log[:-1] + codebook_log[1:]) / 2
    idx = np.searchsorted(midpoints, log_scales, side='left')
    return np.exp(codebook_log[idx])
def quantize_with_scale(groups, scale, depth, power, zero_boundary=None):
    """Round every weight in groups to its nearest level using the given
    per-group scale, and return the dequantized reconstruction.

    Args:
        groups: 2-D array, one weight group per row.
        scale: 1-D array with one scale per row; values below 1e-30 are
            treated as 1e-30.
        depth, power, zero_boundary: select the level grid and thresholds
            (see build_levels / make_boundaries).

    Returns:
        Array shaped like groups holding level * scale for every weight.
    """
    half = (3 ** depth) // 2
    levels = build_levels(half, power)
    thresholds = make_boundaries(levels, zero_boundary)
    safe_scale = np.maximum(scale, 1e-30)[:, None]
    codes = np.searchsorted(thresholds, (groups / safe_scale).ravel())
    codes = np.clip(codes, 0, len(levels) - 1)
    return levels[codes].reshape(groups.shape) * safe_scale
# ============================================================
# CODEC CONFIGS
# ============================================================
# mode='uniform'  : every group of a matrix is coded at weight depth `depth`.
# mode='adaptive' : each group gets the shallowest of d2/d3 whose
#                   mse / group-variance is below threshold * 5.5; groups that
#                   miss both fall through to d4 (see quantize_matrix_multi).
# scale_depth     : trit depth charged for each group's scale in the bit count.
# zero_boundary   : optional half-width of the zero bin, in units of the scale.
CODECS = {
    'd3scale-sens002': {
        'mode': 'adaptive', 'scale_depth': 3, 'threshold': 0.002,
    },
    'd3scale-sens003': {
        'mode': 'adaptive', 'scale_depth': 3, 'threshold': 0.003,
    },
    # d1 with a narrow zero zone (zw=0.25): levels {-1, 0, +1}, and a weight
    # becomes 0 only when |w| < 0.25*scale. The previous default, zw=0.5,
    # rounded 97.5% of weights to 0 (random-chance MMLU).
    'uniform-d1': {
        'mode': 'uniform', 'scale_depth': 3, 'depth': 1, 'zero_boundary': 0.25,
    },
    'uniform-d2': {'mode': 'uniform', 'scale_depth': 3, 'depth': 2},
    'uniform-d3': {'mode': 'uniform', 'scale_depth': 3, 'depth': 3},
    'uniform-d4': {'mode': 'uniform', 'scale_depth': 3, 'depth': 4},
}
# ============================================================
# MULTI-CONFIG MATRIX QUANTIZATION
# ============================================================
def quantize_matrix_multi(args):
    """Quantize one matrix for ALL requested configs in a single pass.

    The expensive work (choosing per-group scales, snapping them to the scale
    codebook and reconstructing the weights) runs once per distinct
    (depth, zero_boundary, scale_depth) combination. Every config that needs
    that combination shares the result.

    Args:
        args: tuple (w_flat, rows, cols, config_names). It is packed into one
            argument so the function works with Pool.imap. w_flat holds
            rows * cols weights in row-major order; config_names are CODECS keys.

    Returns:
        dict mapping each config name to a dict with:
          'recon_w'      float32 array of shape (rows, cols)
          'depth_counts' {1: n, 2: n, 3: n, 4: n} groups coded at each depth
          'weight_bits'  float, trits * log2(3) for the weights
          'scale_bits'   float, trits * log2(3) for the per-group scales
          'n_groups'     number of GS-wide groups, after zero-padding the
                         columns to a multiple of GS
    """
    w_flat, rows, cols, config_names = args
    w = w_flat.reshape(rows, cols)
    # Zero-pad columns so each row splits into whole groups of GS weights.
    pad = (GS - cols % GS) % GS
    if pad > 0:
        w = np.pad(w, ((0, 0), (0, pad)))
    groups = w.reshape(-1, GS).astype(np.float64)
    N = len(groups)
    group_var = np.maximum(np.var(groups, axis=1), 1e-30)
    bits_per_trit = np.log2(3)

    # Collect every (depth, zero_boundary, scale_depth) combination any
    # requested config needs. Adaptive configs use the default boundaries for
    # d2/d3/d4; uniform configs may override the zero bin (e.g. d1 zw=0.25).
    needed_keys = set()
    for cn in config_names:
        cfg = CODECS[cn]
        sd = cfg['scale_depth']
        if cfg['mode'] == 'adaptive':
            needed_keys.update((d, None, sd) for d in (2, 3, 4))
        else:
            needed_keys.add((cfg['depth'], cfg.get('zero_boundary'), sd))

    mse_per_key = {}
    recon_per_key = {}
    for key in sorted(needed_keys, key=lambda k: (k[0], k[1] or 0, k[2])):
        d, zb, sd = key
        power = DEPTH_POWERS[d]
        opt_s, _ = compute_optimal_scale(groups, d, power, zero_boundary=zb)
        # Snap scales with the config's own scale depth, so the stored scale
        # precision matches what the bit accounting below charges for it.
        use_s = trit_quantize_scales(opt_s, sd)
        r = quantize_with_scale(groups, use_s, d, power, zero_boundary=zb)
        mse_per_key[key] = np.mean((groups - r) ** 2, axis=1)
        recon_per_key[key] = r

    out = {}
    for cn in config_names:
        cfg = CODECS[cn]
        sd = cfg['scale_depth']
        depth_counts = {1: 0, 2: 0, 3: 0, 4: 0}
        if cfg['mode'] == 'uniform':
            d = cfg['depth']
            recon = recon_per_key[(d, cfg.get('zero_boundary'), sd)]
            depth_counts[d] = N
            wb = N * GS * d * bits_per_trit
            sb = N * sd * bits_per_trit
        else:  # adaptive: shallowest depth whose relative MSE clears the bar
            eff_thresh = cfg['threshold'] * 5.5
            recon = np.zeros_like(groups)
            assigned = np.zeros(N, dtype=bool)
            wb = 0.0
            sb = 0.0
            for d in (2, 3, 4):
                unassigned = ~assigned
                if not np.any(unassigned):
                    break
                key = (d, None, sd)
                if d == 4:
                    # Deepest level: every group still unassigned lands here.
                    chosen = np.where(unassigned)[0]
                else:
                    rel_mse = mse_per_key[key][unassigned] / group_var[unassigned]
                    chosen = np.where(unassigned)[0][rel_mse < eff_thresh]
                recon[chosen] = recon_per_key[key][chosen]
                assigned[chosen] = True
                n_d = len(chosen)
                depth_counts[d] = n_d
                wb += n_d * GS * d * bits_per_trit
                sb += n_d * sd * bits_per_trit
        # Drop the padding columns again before handing the matrix back.
        recon_w = recon.reshape(rows, -1)[:, :cols].astype(np.float32)
        out[cn] = {
            'recon_w': recon_w,
            'depth_counts': depth_counts,
            'weight_bits': float(wb),
            'scale_bits': float(sb),
            'n_groups': N,
        }
    return out
# ============================================================
# CHECKPOINTING
# ============================================================
def matrix_ckpt_path(ckpt_dir, idx, name):
    """Return the checkpoint file path for matrix number idx, named name.

    In the name, '/' becomes '__' and '.' becomes '_'. The 5-digit zero-padded
    index prefix keeps files distinct and sorts them in index order.
    """
    flat_name = name.replace('/', '__').replace('.', '_')
    filename = 'matrix_{:05d}__{}.npz'.format(idx, flat_name)
    return os.path.join(ckpt_dir, filename)
def atomic_save_npz(path, data):
    """Write data to path as a compressed .npz, atomically and durably.

    The arrays go to a temporary file in the destination directory. That file
    is fsync'ed and moved into place with os.replace, then the directory is
    fsync'ed. A reader therefore sees either the old file or the complete new
    one, and the checkpoint survives power loss / SIGKILL once this returns.

    Args:
        path: destination path; it should end in '.npz'.
        data: mapping of array name -> array-like. It is passed to
            np.savez_compressed as keyword arguments, so keys must be str.

    Raises:
        Whatever np.savez_compressed or the OS calls raise. The temporary file
        is removed before the exception propagates.
    """
    parent = os.path.dirname(path) or '.'
    # np.savez_compressed silently appends '.npz' to names lacking it, so give
    # the temp file that suffix up front and pass the exact same path.
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', suffix='.npz', dir=parent)
    os.close(fd)
    try:
        np.savez_compressed(tmp, **data)
        # Make the file contents durable before the rename publishes them.
        fd = os.open(tmp, os.O_RDONLY)
        try:
            os.fsync(fd)
        finally:
            os.close(fd)
        # Atomic on POSIX when source and target share a filesystem.
        os.replace(tmp, path)
    except BaseException:
        # Don't leave orphaned .tmp_*.npz files behind in the checkpoint dir.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
    # fsync the parent directory so the rename itself is durable. Opening or
    # fsync'ing a directory is not supported everywhere (e.g. some FUSE
    # mounts), so this step is best-effort.
    try:
        dir_fd = os.open(parent, os.O_RDONLY)
    except OSError:
        return
    try:
        os.fsync(dir_fd)
    except OSError:
        pass
    finally:
        os.close(dir_fd)
# Algorithm version baked into every checkpoint fingerprint. Bump it whenever a
# change (depth-power mapping, scale-codebook range, group size, ...) would make
# checkpoints written earlier incompatible, so that a resume discards them.
CODEC_VERSION = 'v2.0'
def codec_fingerprint(model_id, revision, depth_powers, group_size, codec_version):
    """Build a stable string identifying the algorithmic state behind a checkpoint.

    Format: 'codec=..|model=..|revision=..|gs=..|powers=d:p,...'. Depth powers
    are listed in ascending depth order, and a falsy revision is written as
    'unspecified'. Checkpoints with equal fingerprints can be mixed safely. On
    resume, a mismatch makes the caller discard and re-quantize the checkpoint.
    """
    power_spec = ','.join(f'{depth}:{exp}' for depth, exp in sorted(depth_powers.items()))
    fields = (
        ('codec', codec_version),
        ('model', model_id),
        ('revision', revision or 'unspecified'),
        ('gs', group_size),
        ('powers', power_spec),
    )
    return '|'.join(f'{key}={value}' for key, value in fields)
def load_ckpt(path):
    """Read every array in the .npz checkpoint at path into a plain dict.

    Each array is fully read before the archive is closed, so the returned
    dict stays valid afterwards.
    """
    # NOTE(review): checkpoints written by this script hold only float32 and
    # unicode-string arrays, so allow_pickle=True looks unnecessary. Consider
    # dropping it if a checkpoint dir can be written by untrusted parties.
    arrays = {}
    with np.load(path, allow_pickle=True) as archive:
        for name in archive.files:
            arrays[name] = archive[name]
    return arrays
def write_progress(out_root, state):
    """Atomically replace <out_root>/progress.json with state, serialized as JSON.

    The JSON goes to a temp file in out_root, which is then renamed over the
    target, so pollers never read a half-written file. Unlike the matrix
    checkpoints, it is not fsync'ed. If serialization or the rename fails, the
    temp file is removed and the exception propagates.
    """
    path = os.path.join(out_root, 'progress.json')
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', dir=out_root)
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(state, f, indent=2)
        os.replace(tmp, path)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
# ============================================================
# MAIN
# ============================================================
def _ckpt_payload(idx, name, shape, config_names, fingerprint, result):
    """Build the dict saved as one matrix's .npz checkpoint.

    It holds a JSON '_meta' entry (name, index, shape, configs and codec
    fingerprint, all re-checked on resume). For each config it also holds the
    reconstructed weights under '<config>__w' and JSON stats under
    '<config>__stats'.
    """
    save_data = {'_meta': np.array(json.dumps({
        'name': name, 'idx': idx, 'shape': list(shape),
        'configs': config_names,
        'fingerprint': fingerprint,
    }))}
    for cn, info in result.items():
        save_data[f'{cn}__w'] = info['recon_w']
        save_data[f'{cn}__stats'] = np.array(json.dumps({
            'depth_counts': info['depth_counts'],
            'weight_bits': info['weight_bits'],
            'scale_bits': info['scale_bits'],
            'n_groups': info['n_groups'],
        }))
    return save_data


def _stale_ckpt_reason(cp, config_names, expected_fp, cur_shape):
    """Decide whether checkpoint cp can be reused by this run.

    Returns None if it can, otherwise the log line explaining why it must be
    redone. Raises if the file cannot be read or lacks a parsable '_meta'.
    """
    with np.load(cp, allow_pickle=True) as z:
        meta = json.loads(str(z['_meta'][()]))
    have_configs = set(meta.get('configs', []))
    ckpt_fp = meta.get('fingerprint')
    ckpt_shape = tuple(meta.get('shape', ()))
    if ckpt_fp != expected_fp:
        return f' fingerprint mismatch on {cp}: stale={ckpt_fp!r} expected={expected_fp!r} — discarding'
    if ckpt_shape != cur_shape:
        return f' shape mismatch on {cp}: stale={ckpt_shape} current={cur_shape} — discarding'
    if not all(cn in have_configs for cn in config_names):
        return f' missing configs in {cp}: have={have_configs}, need={config_names} — redoing'
    return None


def _discard_ckpt(cp):
    """Delete a stale or corrupt checkpoint on a best-effort basis.

    Failure is harmless: the fresh checkpoint for this matrix is written with
    os.replace (in atomic_save_npz), which overwrites whatever remains at cp.
    """
    try:
        os.remove(cp)
    except OSError as e:
        print(f' could not remove {cp} ({e}); it will be overwritten', flush=True)


def main():
    """Command-line entry point.

    Loads the source model on CPU and quantizes every eligible 2-D matrix for
    all requested configs, reusing valid per-matrix checkpoints. Unless
    --matrix-range or --skip-assembly is given, it then writes one HF model
    directory plus a codec config.json per config.
    """
    parser = argparse.ArgumentParser(description='Multi-config ternary quantizer with checkpointing')
    parser.add_argument('--model', required=True)
    parser.add_argument('--configs', required=True,
                        help='Comma-separated codec names: ' + ','.join(CODECS.keys()))
    parser.add_argument('--output', required=True, help='Output root dir')
    parser.add_argument('--workers', type=int, default=1)
    parser.add_argument('--dtype', default='float16', choices=['float16', 'bfloat16'])
    parser.add_argument('--skip-assembly', action='store_true',
                        help='Quantize matrices and checkpoint only; skip final HF model assembly.')
    parser.add_argument('--matrix-range', default=None,
                        help='Slice of matrices to process: "start:end" (0-indexed, end exclusive). '
                             'Use to manually parallelize across processes/machines via shared checkpoint dir.')
    parser.add_argument('--revision', default=None,
                        help='HuggingFace revision (commit SHA or tag) to pin the source model. '
                             'Recommended for reproducibility — without it, the upstream repo can move under you.')
    args = parser.parse_args()

    config_names = [c.strip() for c in args.configs.split(',') if c.strip()]
    if not config_names:
        parser.error('--configs must name at least one codec')
    for cn in config_names:
        if cn not in CODECS:
            print(f'ERROR: unknown codec {cn}', file=sys.stderr)
            sys.exit(2)

    # Validate --matrix-range before the slow model load. Either side may be
    # empty: ":100" means [0, 100) and "100:" means [100, end).
    range_spec = None
    if args.matrix_range:
        lo, sep, hi = args.matrix_range.partition(':')
        try:
            range_spec = (int(lo) if lo else None, int(hi) if hi else None)
        except ValueError:
            sep = ''
        if not sep:
            parser.error(f'--matrix-range must look like "start:end", got {args.matrix_range!r}')

    os.makedirs(args.output, exist_ok=True)
    ckpt_dir = os.path.join(args.output, '.checkpoint')
    os.makedirs(ckpt_dir, exist_ok=True)

    print(f'=== Quantizing {args.model} ===', flush=True)
    print(f' configs: {config_names}', flush=True)
    print(f' workers: {args.workers}', flush=True)

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel
    dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16

    print(' loading model (CPU)...', flush=True)
    t_load = time.time()
    load_kwargs = dict(revision=args.revision, torch_dtype=dtype, device_map='cpu',
                       trust_remote_code=True, low_cpu_mem_usage=True)
    _cfg = AutoConfig.from_pretrained(args.model, revision=args.revision, trust_remote_code=True)
    _arch = ((getattr(_cfg, 'architectures', None) or [''])[0] or '').lower()
    if 't5' in _arch or 'encoder' in _arch:
        from transformers import T5EncoderModel
        print(' loading as T5EncoderModel (encoder-only)', flush=True)
        model = T5EncoderModel.from_pretrained(args.model, **load_kwargs)
    else:
        try:
            model = AutoModelForCausalLM.from_pretrained(args.model, **load_kwargs)
        except ValueError:
            print(' fallback to generic AutoModel', flush=True)
            model = AutoModel.from_pretrained(args.model, **load_kwargs)
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:  # deliberate best-effort: encoder-only models may lack one
        print(f' tokenizer load failed (ok for encoder-only): {e}', flush=True)
        tokenizer = None
    print(f' loaded in {time.time()-t_load:.0f}s', flush=True)

    # Matrices to quantize: every 2-D parameter except embeddings, norms, lm_head.
    matrices = [(pn, p) for pn, p in model.named_parameters()
                if p.dim() == 2 and not any(tag in pn for tag in ('norm', 'embed', 'lm_head'))]
    print(f' {len(matrices)} matrices to quantize', flush=True)

    # Apply --matrix-range slice (for parallel sharded processing).
    range_start, range_end = 0, len(matrices)
    if range_spec is not None:
        lo, hi = range_spec
        range_start = lo if lo is not None else 0
        range_end = min(hi if hi is not None else len(matrices), len(matrices))
        print(f' matrix-range: [{range_start}:{range_end})', flush=True)

    # Codec fingerprint for this run, used to validate resumed checkpoints.
    expected_fp = codec_fingerprint(args.model, args.revision, DEPTH_POWERS, GS, CODEC_VERSION)

    # Decide which matrices still need work, reusing valid checkpoints.
    todo = []
    done_count = 0
    discarded_count = 0
    for idx, (pn, p) in enumerate(matrices):
        if not range_start <= idx < range_end:
            continue
        cp = matrix_ckpt_path(ckpt_dir, idx, pn)
        if os.path.exists(cp):
            try:
                reason = _stale_ckpt_reason(cp, config_names, expected_fp, tuple(p.shape))
            except Exception as e:
                print(f' bad checkpoint {cp}: {e}, will redo', flush=True)
                _discard_ckpt(cp)
            else:
                if reason is None:
                    done_count += 1
                    continue
                print(reason, flush=True)
                discarded_count += 1
                _discard_ckpt(cp)
        todo.append((idx, pn, p))
    if discarded_count:
        print(f' discarded {discarded_count} stale checkpoint(s)', flush=True)
    print(f' {done_count} matrices already checkpointed, {len(todo)} to do', flush=True)

    t0 = time.time()
    state = {
        'model': args.model, 'configs': config_names,
        'total_matrices': len(matrices),
        'done_matrices': done_count,
        'started_at': t0, 'updated_at': t0,
    }
    write_progress(args.output, state)

    def report(i):
        # i is the 0-based position in `todo` of the matrix that just finished.
        now = time.time()
        state['done_matrices'] = done_count
        state['updated_at'] = now
        state['elapsed_s'] = now - t0
        if (i + 1) % 5 == 0 or (i + 1) == len(todo):
            write_progress(args.output, state)
            eta = (len(todo) - (i + 1)) * (now - t0) / (i + 1)
            print(f' {done_count}/{len(matrices)} ({now-t0:.0f}s, ETA {eta:.0f}s)', flush=True)

    def process_one(idx, pn, p):
        w = p.data.float().numpy()
        result = quantize_matrix_multi((w.ravel(), w.shape[0], w.shape[1], config_names))
        atomic_save_npz(matrix_ckpt_path(ckpt_dir, idx, pn),
                        _ckpt_payload(idx, pn, w.shape, config_names, expected_fp, result))

    if args.workers > 1 and len(todo) > 1:
        # Stream matrices to the pool one at a time. CRITICAL: do NOT pre-build
        # all matrices in a list. For large models (Llama 70B = 140GB) that
        # OOMs the box; the generator is consumed lazily by Pool.imap.
        idx_name = [(idx, pn, list(p.shape)) for idx, pn, p in todo]

        def gen():
            for _, _, p in todo:
                w = p.data.float().numpy()
                yield (w.ravel(), w.shape[0], w.shape[1], config_names)
                # Free the source tensor once it has been handed off; the Pool
                # worker has its own copy via pickle. Assembly later refills
                # p.data from the checkpoint.
                p.data = torch.zeros(1, dtype=p.dtype)

        with Pool(args.workers) as pool:
            results = pool.imap(quantize_matrix_multi, gen(), chunksize=1)
            for i, result in enumerate(results):
                idx, pn, shape = idx_name[i]
                atomic_save_npz(matrix_ckpt_path(ckpt_dir, idx, pn),
                                _ckpt_payload(idx, pn, shape, config_names, expected_fp, result))
                done_count += 1
                report(i)
    else:
        for i, (idx, pn, p) in enumerate(todo):
            process_one(idx, pn, p)
            done_count += 1
            report(i)
    print(f' Quantization complete in {time.time()-t0:.0f}s', flush=True)

    # If we processed only a slice, don't assemble; leave that for the merge step.
    if args.matrix_range:
        slice_done = sum(1 for idx, (pn, p) in enumerate(matrices)
                         if range_start <= idx < range_end
                         and os.path.exists(matrix_ckpt_path(ckpt_dir, idx, pn)))
        print(f' slice [{range_start}:{range_end}): {slice_done} checkpointed', flush=True)
        return
    if args.skip_assembly:
        print(' --skip-assembly: not building HF model dirs', flush=True)
        return

    # ============================================================
    # ASSEMBLY: load each config from checkpoints, write HF model
    # ============================================================
    print(' Assembling HF models per config...', flush=True)
    for cn in config_names:
        cfg_dir = os.path.join(args.output, cn)
        os.makedirs(cfg_dir, exist_ok=True)
        model_dir = os.path.join(cfg_dir, 'model')
        total_groups = 0
        total_depth = {1: 0, 2: 0, 3: 0, 4: 0}
        total_wb = 0.0
        total_sb = 0.0
        # Swap this config's reconstruction into the model's tensors in place.
        for idx, (pn, p) in enumerate(matrices):
            cp = matrix_ckpt_path(ckpt_dir, idx, pn)
            with np.load(cp, allow_pickle=True) as z:
                recon_w = z[f'{cn}__w']
                stats = json.loads(str(z[f'{cn}__stats'][()]))
            p.data = torch.from_numpy(recon_w).to(p.dtype)
            total_groups += stats['n_groups']
            # JSON turned the integer depth keys into strings; accept both.
            for d in (1, 2, 3, 4):
                total_depth[d] += stats['depth_counts'].get(str(d), stats['depth_counts'].get(d, 0))
            total_wb += stats['weight_bits']
            total_sb += stats['scale_bits']
        tg = max(total_groups, 1)
        trit_bpw = total_wb / (tg * GS)
        scale_bpw = total_sb / (tg * GS)
        total_bpw = trit_bpw + scale_bpw
        print(f' [{cn}] BPW={total_bpw:.3f} (trit={trit_bpw:.3f}+scale={scale_bpw:.3f})', flush=True)
        print(f' [{cn}] Saving to {model_dir}...', flush=True)
        model.save_pretrained(model_dir, safe_serialization=True)
        if tokenizer is not None:
            tokenizer.save_pretrained(model_dir)
        config = {
            'model': os.path.basename(args.model.rstrip('/')),
            'model_revision': args.revision,
            'codec_version': CODEC_VERSION,
            'codec_fingerprint': expected_fp,
            'codec': cn,
            'bpw': total_bpw, 'trit_bpw': trit_bpw, 'scale_bpw': scale_bpw,
            'depth_pcts': {str(d): total_depth[d]/tg for d in [1, 2, 3, 4]},
            'n_matrices': len(matrices),
            'group_size': GS,
            'fp16_layers': ['lm_head', 'embed_tokens', '*_norm'],
            'codec_params': CODECS[cn],
        }
        with open(os.path.join(cfg_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=2)
        print(f' [{cn}] DONE: {cfg_dir}', flush=True)
    print(f' ALL CONFIGS COMPLETE in {time.time()-t0:.0f}s total', flush=True)
if __name__ == '__main__':
    # Start a run only when executed as a script. Importing this module (for
    # instance from Pool worker processes) must not launch the CLI.
    main()