| import base64, collections, copy, fcntl, glob, io, lzma, math, os |
| from pathlib import Path |
| import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F |
| from torch import Tensor, nn |
| from flash_attn_interface import ( |
| flash_attn_func as flash_attn_3_func, |
| flash_attn_varlen_func, |
| ) |
| from concurrent.futures import ThreadPoolExecutor |
| import triton |
| import triton.language as tl |
| from triton.tools.tensor_descriptor import TensorDescriptor |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| _FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" |
| _FUSED_CE_BLOCK_SIZE = 1024 |
| _FUSED_CE_NUM_WARPS = 4 |
|
|
|
|
@triton.jit
def _softcapped_ce_fwd_kernel(
    logits_ptr, losses_ptr, lse_ptr, targets_ptr,
    stride_logits_n, stride_logits_v,
    n_rows, n_cols, softcap,
    block_size: tl.constexpr,
):
    # Forward pass of softcapped cross-entropy. The launcher uses a grid of
    # (n_rows,), so each program instance handles one row of logits.
    #
    # The softcapped logit is computed as z = 2*softcap * sigmoid(2*x / softcap).
    # Because tanh(u) = 2*sigmoid(2u) - 1, this equals softcap*tanh(x/softcap) + softcap.
    # The constant +softcap cancels in the softmax, so the per-row loss equals
    # cross-entropy on softcap*tanh(x/softcap). The stored lse is shifted by +softcap.
    #
    # Outputs for this row:
    #   lse_ptr[row]    = logsumexp_j z[row, j]
    #   losses_ptr[row] = lse - z[row, target]
    # n_rows is not read inside the kernel.
    # NOTE(review): the target index is dereferenced without a bounds check.
    # An out-of-range target (e.g. an ignore index like -100) would read outside
    # the row. Presumably callers guarantee 0 <= target < n_cols; confirm.
    row_idx = tl.program_id(0).to(tl.int64)  # widened so row_idx * stride is 64-bit
    logits_row_ptr = logits_ptr + row_idx * stride_logits_n
    # Streaming log-sum-exp state, carried across column blocks.
    max_val = -float("inf")
    sum_exp = 0.0
    A = 2.0 * softcap
    inv_C = 2.0 / softcap
    for off in range(0, n_cols, block_size):
        cols = off + tl.arange(0, block_size)
        mask = cols < n_cols
        val = tl.load(
            logits_row_ptr + cols * stride_logits_v,
            mask=mask, other=-float("inf"),
        ).to(tl.float32)
        z = A * tl.sigmoid(val * inv_C)
        # Padding lanes load -inf, and sigmoid(-inf) * A = 0. Force them back to
        # -inf so they contribute nothing to the running max or sum.
        z = tl.where(mask, z, -float("inf"))
        curr_max = tl.max(z, axis=0)
        new_max = tl.maximum(max_val, curr_max)
        # Rescale the running sum to the new max before adding this block's terms.
        sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0)
        max_val = new_max
    lse = max_val + tl.log(sum_exp)
    tl.store(lse_ptr + row_idx, lse)
    target = tl.load(targets_ptr + row_idx).to(tl.int32)
    target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32)
    # Recompute the target's softcapped logit with the same transform as above.
    target_z = A * tl.sigmoid(target_val * inv_C)
    tl.store(losses_ptr + row_idx, lse - target_z)
|
|
|
|
@triton.jit
def _softcapped_ce_bwd_kernel(
    grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr,
    stride_logits_n, stride_logits_v,
    stride_grad_n, stride_grad_v,
    n_rows, n_cols, softcap,
    block_size: tl.constexpr,
):
    # Backward pass of softcapped cross-entropy. The launcher uses a grid of
    # (n_rows,), so each program instance handles one row.
    #
    # Recomputes z = 2*softcap * sigmoid(2*x / softcap) block by block and writes
    #   grad_x = grad_loss * (softmax(z) - onehot(target)) * dz/dx
    # where softmax(z) = exp(z - lse) reuses the forward pass's per-row lse, and
    #   dz/dx = (2*softcap) * (2/softcap) * s * (1 - s) = 4 * s * (1 - s),
    #   with s = sigmoid(2*x / softcap).
    # n_rows is not read inside the kernel.
    row_idx = tl.program_id(0).to(tl.int64)  # widened so row offsets are 64-bit
    logits_row_ptr = logits_ptr + row_idx * stride_logits_n
    grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n
    lse = tl.load(lse_ptr + row_idx)
    grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32)
    target = tl.load(targets_ptr + row_idx).to(tl.int32)
    A = 2.0 * softcap
    inv_C = 2.0 / softcap
    dz_dx_scale = A * inv_C  # == 4.0 for any softcap
    for off in range(0, n_cols, block_size):
        cols = off + tl.arange(0, block_size)
        mask = cols < n_cols
        # Padding lanes read a finite placeholder; the masked store below drops them.
        val = tl.load(
            logits_row_ptr + cols * stride_logits_v,
            mask=mask, other=0.0,
        ).to(tl.float32)
        sigmoid_u = tl.sigmoid(val * inv_C)
        z = A * sigmoid_u
        probs = tl.exp(z - lse)
        # Subtract 1 only in the target column (one-hot of the label).
        grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0))
        grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u))
        tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask)
|
|
|
|
| def _validate_softcapped_ce_inputs( |
| logits: Tensor, targets: Tensor, softcap: float, |
| ) -> tuple[Tensor, Tensor]: |
| if logits.ndim != 2: |
| raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") |
| if targets.ndim != 1: |
| raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") |
| if logits.shape[0] != targets.shape[0]: |
| raise ValueError( |
| f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" |
| ) |
| if not logits.is_cuda or not targets.is_cuda: |
| raise ValueError("softcapped_cross_entropy requires CUDA tensors") |
| if softcap <= 0.0: |
| raise ValueError(f"softcap must be positive, got {softcap}") |
| if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): |
| raise ValueError(f"Unsupported logits dtype: {logits.dtype}") |
| logits = logits.contiguous() |
| targets = targets.contiguous() |
| if targets.dtype != torch.int64: |
| targets = targets.to(dtype=torch.int64) |
| return logits, targets |
|
|
|
|
@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=())
def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]:
    """Run the fused softcapped cross-entropy forward kernel.

    Returns ``(losses, lse)``: both are float32 vectors with one entry per
    logits row. ``lse`` is the per-row log-sum-exp computed by the forward
    kernel; the backward op consumes it.
    """
    cap = float(softcap)
    logits, targets = _validate_softcapped_ce_inputs(logits, targets, cap)
    rows, cols = logits.shape
    per_row = {"device": logits.device, "dtype": torch.float32}
    losses = torch.empty((rows,), **per_row)
    lse = torch.empty((rows,), **per_row)
    # One Triton program per row.
    launch_grid = (rows,)
    _softcapped_ce_fwd_kernel[launch_grid](
        logits, losses, lse, targets,
        logits.stride(0), logits.stride(1),
        rows, cols, cap,
        block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS,
    )
    return losses, lse
|
|
|
|
@softcapped_ce_op.register_fake
def _(logits: Tensor, targets: Tensor, softcap: float):
    """Fake (shape-only) implementation of ``softcapped_ce``.

    Produces the same output contract as the real op without launching the
    kernel: two float32 vectors with one entry per logits row.
    """
    ranks_ok = logits.ndim == 2 and targets.ndim == 1
    if not ranks_ok:
        raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets")
    rows = logits.shape[0]
    if rows != targets.shape[0]:
        raise ValueError(
            f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}"
        )
    fake_losses = logits.new_empty((rows,), dtype=torch.float32)
    fake_lse = logits.new_empty((rows,), dtype=torch.float32)
    return fake_losses, fake_lse
|
|
|
|
@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=())
def softcapped_ce_backward_op(
    logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float,
) -> Tensor:
    """Gradient of the per-row softcapped CE losses with respect to ``logits``.

    ``lse`` is the per-row log-sum-exp returned by ``softcapped_ce``, and
    ``grad_losses`` is the upstream gradient (one value per row, cast to
    float32). The result has the same shape and dtype as the contiguous logits.
    """
    cap = float(softcap)
    logits, targets = _validate_softcapped_ce_inputs(logits, targets, cap)
    lse = lse.contiguous()
    grad_losses = grad_losses.contiguous().to(dtype=torch.float32)
    if lse.ndim != 1 or grad_losses.ndim != 1:
        raise ValueError("Expected 1D lse and grad_losses")
    rows, cols = logits.shape
    if lse.shape[0] != rows or grad_losses.shape[0] != rows:
        raise ValueError(
            f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} "
            f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}"
        )
    grad_logits = torch.empty_like(logits)
    # One Triton program per row.
    _softcapped_ce_bwd_kernel[(rows,)](
        grad_logits, grad_losses, lse, logits, targets,
        logits.stride(0), logits.stride(1),
        grad_logits.stride(0), grad_logits.stride(1),
        rows, cols, cap,
        block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS,
    )
    return grad_logits
|
|
|
|
@softcapped_ce_backward_op.register_fake
def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float):
    """Fake (shape-only) implementation of ``softcapped_ce_backward``.

    Checks ranks and row alignment, then returns an uninitialized tensor
    shaped like ``logits``, matching the real op's output.
    """
    per_row_tensors = (targets, lse, grad_losses)
    if logits.ndim != 2 or any(t.ndim != 1 for t in per_row_tensors):
        raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors")
    rows = logits.shape[0]
    if any(rows != t.shape[0] for t in per_row_tensors):
        raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors")
    return logits.new_empty(logits.shape)
|
|
|
|
| def _softcapped_ce_setup_context( |
| ctx: torch.autograd.function.FunctionCtx, inputs, output, |
| ) -> None: |
| logits, targets, softcap = inputs |
| _losses, lse = output |
| ctx.save_for_backward(logits, targets, lse) |
| ctx.softcap = float(softcap) |
|
|
|
|
| def _softcapped_ce_backward( |
| ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", |
| ): |
| del grad_lse |
| logits, targets, lse = ctx.saved_tensors |
| grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( |
| logits, targets, lse, grad_losses, ctx.softcap |
| ) |
| return grad_logits, None, None |
|
|
|
|
# Route gradients of softcapped_ce through the Triton backward op. The setup
# hook stashes the tensors on ctx, and the backward hook consumes them.
softcapped_ce_op.register_autograd(
    _softcapped_ce_backward,
    setup_context=_softcapped_ce_setup_context,
)
|
|
|
|
def softcapped_cross_entropy(
    logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean",
) -> Tensor:
    """Fused softcapped cross-entropy loss.

    Args:
        logits: 2-D CUDA tensor (float16/bfloat16/float32), one row per example.
        targets: 1-D CUDA integer tensor with one class index per row.
        softcap: positive softcap applied to the logits before the softmax.
        reduction: ``"none"`` (per-row float32 losses), ``"sum"`` or ``"mean"``.

    Raises:
        ValueError: for an unsupported ``reduction``, or for inputs rejected
            by the op's validation.
    """
    # Reject a bad reduction before launching any GPU work.
    if reduction not in ("none", "sum", "mean"):
        raise ValueError(f"Unsupported reduction={reduction!r}")
    # Call the registered op object rather than a hard-coded torch.ops
    # namespace, so the namespace is defined only by _FUSED_CE_LIBRARY.
    losses, _lse = softcapped_ce_op(logits, targets, float(softcap))
    if reduction == "none":
        return losses
    if reduction == "sum":
        return losses.sum()
    return losses.mean()
|
|
|
|
class Hyperparameters:
    """Run configuration read from environment variables.

    Every attribute is a class attribute, evaluated once when this class body
    executes. Later module code may override individual attributes by
    assignment. Derived paths (dataset/tokenizer/log/model locations) use the
    environment values present when the class body runs.
    """

    data_dir = os.environ.get("DATA_DIR", "./data/")
    seed = int(os.environ.get("SEED", 1337))
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432))
    fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1")))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2))
    val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288))
    eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000))
    vocab_size = int(os.environ.get("VOCAB_SIZE", 8192))
    num_layers = int(os.environ.get("NUM_LAYERS", 11))
    xsa_last_n = int(os.environ.get("XSA_LAST_N", 11))
    model_dim = int(os.environ.get("MODEL_DIM", 512))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
    mlp_mult = float(os.environ.get("MLP_MULT", 4.0))
    skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1")))
    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1))
    rope_base = float(os.environ.get("ROPE_BASE", 1e4))
    rope_dims = int(os.environ.get("ROPE_DIMS", 16))
    rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048))
    rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0")))
    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0))
    num_loops = int(os.environ.get("NUM_LOOPS", 2))
    loop_start = int(os.environ.get("LOOP_START", 3))
    loop_end = int(os.environ.get("LOOP_END", 5))
    enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35))
    parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8))
    parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean")
    min_lr = float(os.environ.get("MIN_LR", 0.0))
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.026))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.02))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start = float(
        os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)
    )
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500))
    muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1")))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-08))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
    eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
    adam_wd = float(os.environ.get("ADAM_WD", 0.02))
    muon_wd = float(os.environ.get("MUON_WD", 0.095))
    embed_wd = float(os.environ.get("EMBED_WD", 0.085))
    ema_decay = float(os.environ.get("EMA_DECAY", 0.9965))
    ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1")))
    ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96))
    ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001))
    ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48))
    ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048))
    ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64))
    ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1))
    ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 1.0))
    ttt_beta1 = float(os.environ.get("TTT_BETA1", 0))
    ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.999))
    ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1")))
    ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1")))
    ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1")))
    ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam")
    ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "")
    val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0))
    compressor = os.environ.get("COMPRESSOR", "brotli")
    gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16))
    gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 4.0))
    phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2000))
    phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 1))
    global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001))
    global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9))
    global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1))
    global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768))
    global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32))
    global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0))
    global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0))
    global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0))
    global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1")))
    matrix_bits = int(os.environ.get("MATRIX_BITS", 6))
    embed_bits = int(os.environ.get("EMBED_BITS", 8))
    matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85))
    embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 2e1))
    mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 10.0))
    attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0))

    attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0")))
    attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj")

    smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "0")))
    gate_window = int(os.environ.get("GATE_WINDOW", 12))

    gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0")))
    gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01))
    gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "0")))

    sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "0")))
    sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0))
    sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 1.0))

    lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1")))
    lqer_rank = int(os.environ.get("LQER_RANK", 4))
    lqer_top_k = int(os.environ.get("LQER_TOP_K", 3))
    lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4))
    lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1")))
    lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64"))

    # Distributed launch settings (set by the launcher when present).
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    is_main_process = rank == 0
    if world_size < 1:
        raise ValueError(f"WORLD_SIZE must be >= 1, got {world_size}")
    # Target 8 micro-batches per optimizer step in total (world_size * accum).
    # The loaders divide by world_size * grad_accum_steps, so more than 8 ranks
    # would give 0 and a divide-by-zero there; clamp to one micro-batch instead.
    grad_accum_steps = max(1, 8 // world_size)

    caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "0")))
    # CaseOps dataset/tokenizer defaults follow vocab_size, like the non-CaseOps
    # defaults below. They are identical to the old hard-coded sp8192 paths at
    # the default VOCAB_SIZE=8192.
    _default_caseops_data = os.path.join(
        data_dir,
        "datasets",
        f"fineweb10B_sp{vocab_size}_caseops",
        "datasets",
        "datasets",
        f"fineweb10B_sp{vocab_size}_lossless_caps_caseops_v1_reserved",
    )
    _default_caseops_tok = os.path.join(
        data_dir,
        "datasets",
        f"fineweb10B_sp{vocab_size}_caseops",
        "datasets",
        "tokenizers",
        f"fineweb_{vocab_size}_bpe_lossless_caps_caseops_v1_reserved.model",
    )
    if caseops_enabled:
        datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data)
        tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok)
    else:
        datasets_dir = os.environ.get(
            "DATA_PATH",
            os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"),
        )
        tokenizer_path = os.environ.get(
            "TOKENIZER_PATH",
            os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"),
        )
    train_files = os.path.join(datasets_dir, "fineweb_train_*.bin")
    val_files = os.path.join(datasets_dir, "fineweb_val_*.bin")
    val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin")
    artifact_dir = os.environ.get("ARTIFACT_DIR", "")
    logfile = (
        os.path.join(artifact_dir, f"{run_id}.txt")
        if artifact_dir
        else f"logs/{run_id}.txt"
    )
    model_path = (
        os.path.join(artifact_dir, "final_model.pt")
        if artifact_dir
        else "final_model.pt"
    )
    quantized_model_path = (
        os.path.join(artifact_dir, "final_model.int6.ptz")
        if artifact_dir
        else "final_model.int6.ptz"
    )
|
|
|
|
| |
| |
| |
# Identity and provenance of this experiment run.
TEST_ID = "2026-04-30_pr1855_sp10240_caseops_repro_8x"
TEST_DATE = TEST_ID[:10]  # the id is prefixed with its ISO date
RUN_LABEL = "standard_8x"
RUN_KIND = "new_experiment"
PARENT_RUN = "2026-04-30_pr1855_sp8192_lqer_smeargate_repro_8x"
SOURCE_PARENT = f"legs/{PARENT_RUN}/run.py"
SOURCE_PARENT_SHA256 = "454f710d174be80f4603069ca952833d694f60d1d34c0c25703528323bc8878b"
SOURCE_TOKENIZER_LANE = "scripts/prepare_sp10240_caseops_data.py"
HYPOTHESIS = (
    "Port the accepted PR1855 CaseOps/LQER/pergroup/phased-TTT stack to the new "
    "SP10240 CaseOps tokenizer/data sidecar. Keep PR1855 body shape and loop "
    "policy fixed; only vocab/tokenizer/data changes from SP8192 to SP10240."
)
# Artifact size cap and the build/eval time budgets (seconds).
SIZE_CAP_BYTES = 16_000_000
BUILD_SECONDS = EVAL_SECONDS = 600
|
|
# Expose the run identity/provenance constants as Hyperparameters attributes.
for _attr_name, _attr_value in (
    ("test_id", TEST_ID),
    ("test_date", TEST_DATE),
    ("run_label", RUN_LABEL),
    ("run_kind", RUN_KIND),
    ("source_parent", SOURCE_PARENT),
    ("source_parent_sha256", SOURCE_PARENT_SHA256),
    ("source_tokenizer_lane", SOURCE_TOKENIZER_LANE),
    ("parent_run", PARENT_RUN),
    ("hypothesis", HYPOTHESIS),
    ("size_cap_bytes", SIZE_CAP_BYTES),
    ("build_seconds", BUILD_SECONDS),
    ("eval_seconds", EVAL_SECONDS),
):
    setattr(Hyperparameters, _attr_name, _attr_value)
del _attr_name, _attr_value
|
|
# Pin this run to the SP10240 CaseOps dataset and tokenizer under a fixed data
# root. These assignments replace the env-derived values computed in the
# Hyperparameters class body.
Hyperparameters.data_dir = "/workspace/SOTA_FINAL/data"
Hyperparameters.vocab_size = 10240
Hyperparameters.caseops_enabled = True
_caseops_root = os.path.join(
    Hyperparameters.data_dir, "datasets", "fineweb10B_sp10240_caseops", "datasets"
)
Hyperparameters.tokenizer_path = os.path.join(
    _caseops_root, "tokenizers", "fineweb_10240_bpe_lossless_caps_caseops_v1_reserved.model"
)
Hyperparameters.datasets_dir = os.path.join(
    _caseops_root, "datasets", "fineweb10B_sp10240_lossless_caps_caseops_v1_reserved"
)
for _split_attr, _split_glob in (
    ("train_files", "fineweb_train_*.bin"),
    ("val_files", "fineweb_val_*.bin"),
    ("val_bytes_files", "fineweb_val_bytes_*.bin"),
):
    setattr(
        Hyperparameters, _split_attr, os.path.join(Hyperparameters.datasets_dir, _split_glob)
    )
del _split_attr, _split_glob
|
|
# Pinned configuration for this run. Each entry replaces the value the
# Hyperparameters class body read from the environment.
_RUN_OVERRIDES = {
    "seed": 42,
    "run_id": "pr1855_sp10240_caseops_repro_8x_seed42",
    "artifact_dir": "logs",
    "iterations": 20000,
    "max_wallclock_seconds": float(BUILD_SECONDS),
    "num_layers": 11,
    "xsa_last_n": 11,
    "model_dim": 512,
    "num_heads": 8,
    "num_kv_heads": 4,
    "mlp_mult": 4.0,
    "num_loops": 2,
    "loop_start": 3,
    "loop_end": 5,
    "enable_looping_at": 0.35,
    "parallel_start_layer": 8,
    "qk_gain_init": 5.0,
    "warmdown_frac": 0.85,
    "warmup_steps": 20,
    "min_lr": 0.1,
    "matrix_lr": 0.026,
    "beta2": 0.99,
    "muon_backend_steps": 5,
    "grad_clip_norm": 0.3,
    "val_loss_every": 0,
    "ttt_enabled": True,
    "ttt_lora_rank": 80,
    "ttt_chunk_size": 48,
    "ttt_weight_decay": 0.5,
    "ttt_beta2": 0.99,
    "phased_ttt_prefix_docs": 2500,
    "phased_ttt_num_phases": 3,
    "global_ttt_momentum": 0.9,
    "compressor": "pergroup",
    "gptq_reserve_seconds": 0.5,
    "gptq_calibration_batches": 16,
    "matrix_bits": 6,
    "embed_bits": 7,
    "mlp_clip_sigmas": 11.5,
    "attn_clip_sigmas": 13.0,
    "embed_clip_sigmas": 14.0,
    "gated_attn_quant_gate": True,
    "sparse_attn_gate_enabled": True,
    "sparse_attn_gate_scale": 0.5,
    "gate_window": 12,
    "smear_gate_enabled": True,
    "lqer_enabled": True,
    "lqer_asym_enabled": True,
    "lqer_rank": 4,
    "lqer_factor_bits": 4,
    "lqer_asym_group": 64,
    "lqer_top_k": 3,
    "fused_ce_enabled": True,
}
for _name, _value in _RUN_OVERRIDES.items():
    setattr(Hyperparameters, _name, _value)
# Output locations derive from the run id and artifact directory pinned above.
Hyperparameters.logfile = os.path.join(Hyperparameters.artifact_dir, f"{Hyperparameters.run_id}.txt")
Hyperparameters.model_path = os.path.join(Hyperparameters.artifact_dir, "final_model.pt")
Hyperparameters.quantized_model_path = os.path.join(Hyperparameters.artifact_dir, "final_model.int6.ptz")
del _RUN_OVERRIDES, _name, _value
|
|
# Hyperparameters object consulted by log(); None until set_logging_hparams() runs.
_logger_hparams = None


def set_logging_hparams(h):
    """Install ``h`` as the hyperparameter object that ``log`` consults.

    ``log`` reads ``h.is_main_process`` and ``h.logfile`` from it. Passing
    ``None`` restores the default print-everything behavior.
    """
    global _logger_hparams
    _logger_hparams = h
|
|
|
|
def log(msg, console=True):
    """Print and/or append ``msg`` to the run log.

    Before ``set_logging_hparams`` is called, every process prints ``msg``
    and ``console`` is ignored. Afterwards, only the main process does
    anything: it prints when ``console`` is true and appends a line to
    ``logfile`` when one is configured.
    """
    hparams = _logger_hparams
    if hparams is None:
        print(msg)
        return
    if not hparams.is_main_process:
        return
    if console:
        print(msg)
    if hparams.logfile is None:
        return
    with open(hparams.logfile, "a", encoding="utf-8") as log_fh:
        print(msg, file=log_fh)
|
|
|
|
class ValidationData:
    """Validation tokens plus the per-token byte accounting used for BPB.

    In CaseOps mode, byte counts come from a per-token byte sidecar
    (``val_bytes``) and the SentencePiece lookup tables are left as ``None``.
    Otherwise the lookup tables are built from the tokenizer and ``val_bytes``
    is ``None``.
    """

    def __init__(self, h, device):
        self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path)
        tokenizer_vocab = int(self.sp.vocab_size())
        if tokenizer_vocab != h.vocab_size:
            raise ValueError(
                f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={tokenizer_vocab}"
            )
        self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len)
        self.caseops_enabled = bool(getattr(h, "caseops_enabled", False))
        if self.caseops_enabled:
            self.base_bytes_lut = self.has_leading_space_lut = self.is_boundary_token_lut = None
            self.val_bytes = load_validation_byte_sidecar(
                h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel()
            )
        else:
            lookup_tables = build_sentencepiece_luts(self.sp, h.vocab_size, device)
            self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = lookup_tables
            self.val_bytes = None
|
|
|
|
def build_sentencepiece_luts(sp, vocab_size, device):
    """Build per-token lookup tables for byte accounting (BPB).

    Args:
        sp: a SentencePiece processor (anything providing ``vocab_size``,
            ``piece_to_id``, ``unk_id``, ``is_control``, ``is_unknown``,
            ``is_unused``, ``is_byte`` and ``id_to_piece``).
        vocab_size: model vocabulary size. Tables are sized to
            ``max(sp.vocab_size(), vocab_size)`` so every model id is indexable.
        device: torch device for the returned tensors.

    Returns:
        ``(base_bytes, has_leading_space, is_boundary_token)``:
          * ``base_bytes`` (int16): UTF-8 byte length of the piece without its
            leading "▁"; 1 for byte-fallback pieces; 0 for control/unknown/
            unused ids and for ids past the tokenizer's vocabulary.
          * ``has_leading_space`` (bool): the piece starts with "▁".
          * ``is_boundary_token`` (bool): True for control/unknown/unused ids
            and for ids past the tokenizer's vocabulary.

    Raises:
        ValueError: if the tokenizer has no standalone "▁" piece.
    """
    sp_vocab_size = int(sp.vocab_size())
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if sp.piece_to_id("▁") == sp.unk_id():
        raise ValueError(
            "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting"
        )
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )
|
|
|
|
def load_validation_tokens(pattern, seq_len):
    """Load and concatenate validation token shards, trimmed to whole sequences.

    Files matching ``pattern`` whose name contains ``_bytes_`` are skipped,
    because the byte-sidecar shards share the ``fineweb_val_*`` prefix.

    Returns:
        A 1-D tensor of ``k * seq_len + 1`` tokens, for the largest ``k >= 1``
        that fits. The extra token supplies the final next-token target.

    Raises:
        FileNotFoundError: if no token shard matches ``pattern``.
        ValueError: if fewer than ``seq_len + 1`` tokens are available.
    """
    files = [
        Path(p)
        for p in sorted(glob.glob(pattern))
        if "_bytes_" not in Path(p).name
    ]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = (tokens.numel() - 1) // seq_len * seq_len
    if usable <= 0:
        # Name the actual argument: callers pass the eval sequence length here,
        # so the previous "TRAIN_SEQ_LEN" wording pointed at the wrong knob.
        raise ValueError(
            f"Validation split is too short for seq_len={seq_len} ({tokens.numel()} tokens)"
        )
    return tokens[: usable + 1]
|
|
|
|
def load_validation_byte_sidecar(pattern, seq_len, expected_len):
    """Load the CaseOps per-token byte sidecar shard(s).

    The sidecar shards use the token-shard layout read by ``load_data_shard``
    (256 int32 header words followed by a uint16 array). Each entry is the
    canonical raw-text byte budget for the token at the same position in the
    validation token stream.

    Args:
        pattern: glob pattern for the sidecar shards. Matches are read in
            sorted order and concatenated.
        seq_len: unused; kept so existing call sites keep working.
        expected_len: number of validation tokens to align with.

    Returns:
        A CPU int32 tensor of exactly ``expected_len`` entries: the concatenated
        sidecar truncated to the validation token count.

    Raises:
        FileNotFoundError: if ``pattern`` matches no files.
        ValueError: if the sidecar holds fewer than ``expected_len`` entries.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}")
    bytes_full = torch.cat([load_data_shard(file) for file in files]).contiguous()
    if bytes_full.numel() < expected_len:
        raise ValueError(
            f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}"
        )
    # Widen from the on-disk uint16 to int32 for downstream arithmetic.
    return bytes_full[:expected_len].to(torch.int32)
|
|
|
|
| def load_data_shard(file): |
| header_bytes = 256 * np.dtype("<i4").itemsize |
| token_bytes = np.dtype("<u2").itemsize |
| header = np.fromfile(file, dtype="<i4", count=256) |
| if header.size != 256 or int(header[0]) != 20240520 or int(header[1]) != 1: |
| raise ValueError(f"Unexpected shard header for {file}") |
| num_tokens = int(header[2]) |
| expected_size = header_bytes + num_tokens * token_bytes |
| if file.stat().st_size != expected_size: |
| raise ValueError( |
| f"Shard size mismatch for {file}: expected {expected_size} bytes" |
| ) |
| tokens_np = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes) |
| if tokens_np.size != num_tokens: |
| raise ValueError(f"Short read for {file}") |
| return torch.from_numpy(tokens_np.astype(np.uint16, copy=False)) |
|
|
|
|
# Size in bytes of the fixed shard header (256 little-endian int32 words).
_SHARD_HEADER_BYTES = np.dtype("<i4").itemsize * 256
# Module-level memo tables keyed by str(path): shard token counts and open memmaps.
_SHARD_NTOKENS_CACHE = dict()
_MMAP_CACHE = dict()
|
|
|
|
def _read_num_tokens(file):
    """Return the token count stored in a shard's header, memoized per path string.

    Raises:
        ValueError: if the header is short or has the wrong magic/version.
    """
    cache_key = str(file)
    if (known := _SHARD_NTOKENS_CACHE.get(cache_key)) is not None:
        return known
    header = np.fromfile(file, dtype="<i4", count=256)
    if header.size != 256 or int(header[0]) != 20240520 or int(header[1]) != 1:
        raise ValueError(f"Unexpected shard header for {file}")
    token_count = int(header[2])
    _SHARD_NTOKENS_CACHE[cache_key] = token_count
    return token_count
|
|
|
|
def _get_shard_memmap(file):
    """Return a read-only uint16 memmap over a shard's token payload, cached per path string.

    The mapping starts right after the fixed header and covers exactly the
    token count reported by ``_read_num_tokens``.
    """
    cache_key = str(file)
    mapped = _MMAP_CACHE.get(cache_key)
    if mapped is None:
        mapped = np.memmap(
            file,
            mode="r",
            dtype="<u2",
            offset=_SHARD_HEADER_BYTES,
            shape=(_read_num_tokens(file),),
        )
        _MMAP_CACHE[cache_key] = mapped
    return mapped
|
|
|
|
# Document-start token id. Left unset here; DocumentPackingLoader assigns 1
# the first time it initializes a shard while this is still None.
BOS_ID: "int | None" = None
|
|
|
|
def get_next_multiple_of_n(v, n):
    """Round ``v`` up to the nearest multiple of ``n`` (for positive ``n``)."""
    num_chunks = (v + n - 1) // n
    return num_chunks * n
|
|
|
|
def _build_cu_seqlens(bos_pos, total_len, device, max_doc_len=0, bucket_size=64):
    """Build padded cumulative sequence boundaries for a packed token row.

    Args:
        bos_pos: list of document start offsets within the row. A start at 0
            is added if missing.
        total_len: number of tokens in the row.
        device: device for the returned boundary tensor.
        max_doc_len: if > 0, documents are split into pieces of at most this
            many tokens.
        bucket_size: the boundary tensor's length is rounded up to a multiple
            of this, padding with ``total_len``.

    Returns:
        ``(cu, max_seqlen)``: an int32 tensor of segment starts followed by
        ``total_len`` (then padding), and the longest segment length.
    """
    doc_starts = bos_pos if (bos_pos and bos_pos[0] == 0) else [0] + bos_pos
    doc_ends = doc_starts[1:] + [total_len]
    seg_starts = []
    for doc_start, doc_end in zip(doc_starts, doc_ends):
        if max_doc_len > 0:
            seg_starts.extend(range(doc_start, doc_end, max_doc_len))
        else:
            seg_starts.append(doc_start)
    boundaries = seg_starts + [total_len]
    padded_len = get_next_multiple_of_n(len(boundaries), bucket_size)
    cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device)
    cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device)
    max_seqlen = max(seg_end - seg_start for seg_start, seg_end in zip(seg_starts, boundaries[1:]))
    return cu, max_seqlen
|
|
class DocumentPackingLoader:
    """Training-batch loader that packs documents back to back from token shards.

    Each ``next_batch`` call returns one flat row of ``num_tokens_local`` input
    tokens for this rank, the next-token targets, and ``cu_seqlens``-style
    boundaries from ``_build_cu_seqlens``. The boundaries mark where documents
    (runs starting at a ``BOS_ID`` token) begin, with documents longer than
    ``train_seq_len`` split into pieces. Presumably these feed the varlen
    flash-attention path; confirm against the caller.

    Every rank's loader walks the full, unsharded file list from the same
    starting point. A rank reads its own contiguous slice
    (``rank * per_rank_span`` past the shared cursor), after which the cursor
    jumps past the whole global span. Shard loading and batch preparation run
    on single-worker thread pools so they can run ahead of the caller.
    """

    # Shared by all instances: loads the next shard file in the background.
    _shard_pool = ThreadPoolExecutor(1)

    def __init__(self, h, device, cu_bucket_size=64):
        self.rank = h.rank
        self.world_size = h.world_size
        self.device = device
        self.cu_bucket_size = cu_bucket_size
        self.max_seq_len = h.train_seq_len
        all_files = [Path(p) for p in sorted(glob.glob(h.train_files))]
        if not all_files:
            raise FileNotFoundError(f"No files found for pattern: {h.train_files}")
        self.files = all_files
        self.file_iter = iter(self.files)
        # First shard is loaded synchronously; the following one is prefetched.
        self._init_shard(load_data_shard(next(self.file_iter)))
        self._next_shard = self._submit_next_shard()
        # Single worker: queued _prepare_batch calls run one at a time, in order.
        self._batch_pool = ThreadPoolExecutor(1)
        self._prefetch_queue = []

    def _init_shard(self, tokens):
        """Make ``tokens`` the current shard and index its BOS positions."""
        global BOS_ID
        self.tokens = tokens
        self.shard_size = tokens.numel()
        # Default the document-start token id to 1 if nothing set it earlier.
        if BOS_ID is None:
            BOS_ID = 1
        self.bos_idx = (
            (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
        )
        # Start reading at the first document boundary in the shard.
        # NOTE(review): raises IndexError if the shard contains no BOS token.
        self.cursor = int(self.bos_idx[0])

    def _submit_next_shard(self):
        """Start loading the next file in the background; None when the list is exhausted."""
        try:
            path = next(self.file_iter)
            return self._shard_pool.submit(load_data_shard, path)
        except StopIteration:
            return None

    def _advance_shard(self):
        """Switch to the prefetched shard, restarting from the first file when exhausted."""
        if self._next_shard is None:
            self.file_iter = iter(self.files)
            self._next_shard = self._shard_pool.submit(
                load_data_shard, next(self.file_iter)
            )
        self._init_shard(self._next_shard.result())
        self._next_shard = self._submit_next_shard()

    def _local_doc_starts(self, local_start, total_len):
        """BOS offsets in [local_start, local_start + total_len), relative to local_start."""
        lo = np.searchsorted(self.bos_idx, local_start, side="left")
        hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left")
        return (self.bos_idx[lo:hi] - local_start).tolist()

    def _prepare_batch(self, num_tokens_local, max_seq_len):
        """Build this rank's (inputs, targets, cu_seqlens, max_seqlen) on the CPU.

        Runs on the batch-pool worker thread and advances the shared cursor.
        """
        # +1 token so inputs and targets can be the same span shifted by one.
        per_rank_span = num_tokens_local + 1
        global_span = per_rank_span * self.world_size
        # NOTE(review): if no shard can hold global_span tokens past its first
        # BOS, this loop cycles through the files forever.
        while self.cursor + global_span > self.shard_size:
            self._advance_shard()
        local_start = self.cursor + self.rank * per_rank_span
        buf = self.tokens[local_start : local_start + per_rank_span]
        # Pinned host buffers so the later non_blocking .to(device) can be async.
        inputs = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True)
        targets = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True)
        inputs.copy_(buf[:-1])
        targets.copy_(buf[1:])
        starts = self._local_doc_starts(local_start, inputs.numel())
        cu_seqlens, max_seqlen = _build_cu_seqlens(
            starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size
        )
        cu_seqlens = cu_seqlens.pin_memory()
        self.cursor += global_span
        return inputs, targets, cu_seqlens, max_seqlen

    def next_batch(self, global_tokens, grad_accum_steps):
        """Return the next micro-batch on ``self.device``.

        Keeps two batches queued on the worker. The sizes derived from this
        call's arguments only apply to batches submitted from now on; batches
        already queued keep the sizes they were submitted with.
        """
        num_tokens_local = global_tokens // (self.world_size * grad_accum_steps)
        while len(self._prefetch_queue) < 2:
            self._prefetch_queue.append(
                self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len))
        inputs, targets, cu_seqlens, max_seqlen = self._prefetch_queue.pop(0).result()
        self._prefetch_queue.append(
            self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len))
        # inputs/targets gain a leading batch dimension of 1.
        return (
            inputs[None].to(self.device, non_blocking=True),
            targets[None].to(self.device, non_blocking=True),
            cu_seqlens.to(self.device, non_blocking=True),
            max_seqlen,
        )
|
|
|
|
class ShuffledSequenceLoader:
    """Draw random fixed-length training sequences from this rank's token shards.

    Files matching ``h.train_files`` are sorted and dealt round-robin to ranks
    (``all_files[rank::world_size]``). Each shard is cut into windows of
    ``seq_len + 1`` tokens whose start offsets are ``seq_len`` apart, beginning at a
    random phase, so consecutive windows share one boundary token. Windows are consumed
    without replacement. The next shard is picked with probability proportional to how
    many windows it has left. Once every window has been used, all shards are re-cut.
    """

    def __init__(self, h, device):
        self.world_size = h.world_size
        self.seq_len = h.train_seq_len
        self.device = device
        all_files = [Path(p) for p in sorted(glob.glob(h.train_files))]
        if not all_files:
            raise FileNotFoundError(f"No files found for pattern: {h.train_files}")
        self.files = all_files[h.rank :: h.world_size]
        # Seeded by rank: every rank gets its own deterministic stream.
        self.rng = np.random.Generator(np.random.PCG64(h.rank))
        self.num_tokens = [_read_num_tokens(f) for f in self.files]
        self.start_inds = [[] for _ in self.files]
        for si in range(len(self.files)):
            self._reset_shard(si)

    def _reset_shard(self, si):
        """Refill ``start_inds[si]`` with a freshly shuffled list of window start offsets.

        The phase is drawn from ``[0, seq_len - 1]``. It is capped at
        ``num_tokens - seq_len - 1`` so at least one window still fits when the shard
        has ``seq_len + 1`` tokens or more.
        """
        max_phase = min(
            self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1)
        )
        phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0
        # Largest k with phase + k * seq_len + 1 <= num_tokens, i.e. the last window
        # [phase + (k - 1) * seq_len, phase + k * seq_len] still ends inside the shard.
        num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len
        sequence_order = self.rng.permutation(num_sequences)
        self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist()

    def next_batch(self, global_tokens, grad_accum_steps):
        """Return ``(x, y)`` int64 tensors of shape ``(device_tokens // seq_len, seq_len)``.

        ``device_tokens = global_tokens // (world_size * grad_accum_steps)`` is this
        rank's share for one micro-step. ``y`` is ``x`` shifted by one token. Both are
        moved to ``self.device``.

        Raises:
            ValueError: if, even after re-cutting every shard, no shard on this rank
                holds a full ``seq_len + 1``-token window (or the rank has no shards).
        """
        device_tokens = global_tokens // (self.world_size * grad_accum_steps)
        device_batch_size = device_tokens // self.seq_len
        remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64)
        x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64)
        y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64)
        for bi in range(device_batch_size):
            total = remaining.sum()
            if total <= 0:
                # Epoch boundary: every window was consumed, so re-cut all shards.
                for si in range(len(self.files)):
                    self._reset_shard(si)
                remaining = np.array(
                    [len(s) for s in self.start_inds], dtype=np.float64
                )
                total = remaining.sum()
                if total <= 0:
                    # Without this check, remaining / 0 yields NaN (or empty)
                    # probabilities and rng.choice fails with an opaque message.
                    raise ValueError(
                        "ShuffledSequenceLoader: no shard on this rank has at least "
                        f"{self.seq_len + 1} tokens ({len(self.files)} shard(s) assigned)"
                    )
            probs = remaining / total
            si = int(self.rng.choice(len(self.files), p=probs))
            start_ind = self.start_inds[si].pop()
            remaining[si] -= 1
            mm = _get_shard_memmap(self.files[si])
            window = torch.as_tensor(
                np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64)
            )
            x[bi] = window[:-1]
            y[bi] = window[1:]
        return x.to(self.device, non_blocking=True), y.to(
            self.device, non_blocking=True
        )
|
|
|
|
| class RMSNorm(nn.Module): |
| def __init__(self, eps=None): |
| super().__init__() |
| self.eps = eps |
|
|
| def forward(self, x): |
| return F.rms_norm(x, (x.size(-1),), eps=self.eps) |
|
|
|
|
class CastedLinear(nn.Linear):
    """``nn.Linear`` that casts its weight (and bias, if any) to the input's dtype per call."""

    def forward(self, x):
        target = x.dtype
        weight = self.weight.to(target)
        bias = None if self.bias is None else self.bias.to(target)
        return F.linear(x, weight, bias)
|
|
|
|
# Persistent Triton GEMM ``C = A @ B.T`` fused with a leaky-ReLU(0.5)-square epilogue.
#
# A is read as [M, K] tiles (a_desc.load([m, k])) and B as [N, K] tiles
# (b_desc.load([n, k])). c_desc / aux_desc are accessed in half-width tiles of
# BLOCK_SIZE_N // 2 columns, so each output tile takes two loads/stores.
# Each program starts at tile ``program_id`` and strides by NUM_SMS over all
# cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) tiles (row-major tile order).
#
# FORWARD=True:  c   <- pre = A @ B.T                (cast to bf16)
#                aux <- leaky_relu(pre, 0.5) ** 2    (computed from the bf16 pre)
# FORWARD=False: aux is read as the saved pre-activation and
#                c   <- (A @ B.T) * where(pre > 0, 2 * pre, 0.5 * pre),
#                i.e. multiplied by d/dpre [leaky_relu(pre, 0.5) ** 2].
# Output tiles are always cast to bf16 (``dtype`` below).
# NOTE(review): presumably all tensors must be bf16 — confirm with callers.
@triton.jit
def linear_leaky_relu_square_kernel(
    a_desc,
    b_desc,
    c_desc,
    aux_desc,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FORWARD: tl.constexpr,
):
    dtype = tl.bfloat16
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    # Epilogue tile counter; advanced each iteration but never read for addressing
    # (the epilogue reuses offs_am / offs_bn directly).
    tile_id_c = start_pid - NUM_SMS
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n
        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N
        # fp32 accumulation of A[tile_m, :] @ B[tile_n, :].T over K.
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K
            a = a_desc.load([offs_am, offs_k])
            b = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)
        tile_id_c += NUM_SMS
        offs_am_c = offs_am
        offs_bn_c = offs_bn
        # Split the accumulator into its left (acc0) and right (acc1) halves of
        # BLOCK_SIZE_N // 2 columns each: reshape to (M, 2, N/2), move the "2" axis
        # last, then tl.split along it.
        acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
        acc = tl.permute(acc, (0, 2, 1))
        acc0, acc1 = tl.split(acc)
        c0 = acc0.to(dtype)
        c1 = acc1.to(dtype)
        if not FORWARD:
            # Backward: aux holds the pre-activation; scale the incoming gradient by
            # the derivative of leaky_relu(pre, 0.5) ** 2.
            pre0 = aux_desc.load([offs_am_c, offs_bn_c])
            pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2])
            c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.5 * pre0)
            c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.5 * pre1)
        c_desc.store([offs_am_c, offs_bn_c], c0)
        c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1)
        if FORWARD:
            # Forward: also emit the activation leaky_relu(pre, 0.5) ** 2 into aux.
            aux0 = tl.where(c0 > 0, c0, 0.5 * c0)
            aux1 = tl.where(c1 > 0, c1, 0.5 * c1)
            aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0)
            aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1)
|
|
|
|
def linear_leaky_relu_square(a, b, aux=None):
    """Launch ``linear_leaky_relu_square_kernel`` on ``a`` (M x K) and ``b`` (N x K).

    Forward mode (``aux is None``): returns ``(pre, act)`` with ``pre = a @ b.T`` and
    ``act = leaky_relu(pre, 0.5) ** 2``, both newly allocated (M x N) tensors.

    Backward mode (``aux`` given, holding the saved pre-activation): returns only
    ``(a @ b.T) * where(aux > 0, 2 * aux, 0.5 * aux)``.

    The launch uses at most one program per streaming multiprocessor of ``a.device``.
    """
    M, K = a.shape
    N, K2 = b.shape
    assert K == K2
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    is_forward = aux is None
    if is_forward:
        aux = torch.empty((M, N), device=a.device, dtype=a.dtype)
    sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 256, 128, 64
    half_n = BLOCK_SIZE_N // 2
    stages = 4 if is_forward else 3

    a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K])
    b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K])
    # Output / aux are written in half-width column tiles by the kernel epilogue.
    c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, half_n])
    aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, half_n])

    def grid(_meta):
        n_tiles = triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)
        return (min(sm_count, n_tiles),)

    linear_leaky_relu_square_kernel[grid](
        a_desc, b_desc, c_desc, aux_desc,
        M, N, K,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        NUM_SMS=sm_count,
        FORWARD=is_forward,
        num_stages=stages,
        num_warps=8,
    )
    return (c, aux) if is_forward else c
|
|
|
|
class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function):
    """Autograd op for ``(leaky_relu(x @ w1.T, 0.5) ** 2) @ w2.T`` on the fused Triton kernel.

    The forward pass saves the pre-activation and the squared activation so the backward
    pass reuses them instead of recomputing the up projection.
    """

    @staticmethod
    def forward(ctx, x, w1, w2):
        """x: (..., d_in); w1: (hidden, d_in); w2: (d_out, hidden). Returns (..., d_out)."""
        # The kernel builds TMA tensor descriptors, which require a dense row-major
        # operand (unit stride on the last dim). ``.contiguous()`` is a no-op if it is.
        x_flat = x.reshape(-1, x.shape[-1]).contiguous()
        pre, post = linear_leaky_relu_square(x_flat, w1)
        out = F.linear(post, w2)
        ctx.save_for_backward(x, w1, w2, pre, post)
        return out.view(*x.shape[:-1], out.shape[-1])

    @staticmethod
    def backward(ctx, grad_output):
        x, w1, w2, pre, post = ctx.saved_tensors
        x_flat = x.reshape(-1, x.shape[-1])
        # Autograd may hand over a non-contiguous (e.g. expanded) gradient. The kernel's
        # tensor descriptors need a dense operand, so normalize it here.
        grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
        dw2 = grad_output_flat.T @ post
        # dL/dpre = (dL/dpost) * d/dpre[leaky_relu(pre, 0.5) ** 2], fused in the kernel.
        dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre)
        dw1 = dpre.T @ x_flat
        dx = dpre @ w1
        return dx.view_as(x), dw1, dw2
|
|
|
|
# Functional entry point: FusedLeakyReLUSquareMLP(x, w1, w2) returns
# (leaky_relu(x @ w1.T, 0.5) ** 2) @ w2.T, computed by the fused Triton kernel.
FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply
|
|
|
|
class Rotary(nn.Module):
    """Build and cache RoPE cos/sin tables, optionally stretched with YaRN-style base scaling.

    Args:
        dim: head dimension. Used as the rotated width when ``rope_dims`` is 0.
        base: RoPE base frequency.
        train_seq_len: length above which ``yarn`` rescales the base.
        rope_dims: number of rotated channels (0 means all ``dim`` channels).
        yarn: when True and ``seq_len > train_seq_len``, frequencies use base
            ``base * scale ** (rd / (rd - 2))`` with ``scale = seq_len / train_seq_len``.
    """

    def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True):
        super().__init__()
        self.dim = dim
        self.base = base
        self.train_seq_len = train_seq_len
        self.yarn = yarn
        self.rope_dims = rope_dims if rope_dims > 0 else dim
        inv_freq = 1.0 / base ** (
            torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        # Frequency scale the cached tables were built with (1.0 = unscaled).
        self._scale_cached = 1.0
        self._cos_cached = None
        self._sin_cached = None

    def _freq_scale(self, seq_len):
        """Return the YaRN stretch factor for ``seq_len`` (1.0 when no stretching applies)."""
        if self.yarn and seq_len > self.train_seq_len:
            return seq_len / self.train_seq_len
        return 1.0

    def forward(self, seq_len, device, dtype):
        """Return ``(cos, sin)`` for positions ``0 .. seq_len - 1`` cast to ``dtype``.

        Both tensors have shape ``(1, seq_len, 1, ceil(rope_dims / 2))``.

        A cached table is sliced and reused for a shorter request only if it was built
        with the same frequency scale. With YaRN the frequencies depend on ``seq_len``,
        so a table stretched for a long sequence must not be reused for a short one.
        """
        scale = self._freq_scale(seq_len)
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached < seq_len
            or self._cos_cached.device != device
            or self._scale_cached != scale
        ):
            rd = self.rope_dims
            if scale != 1.0:
                # NTK-style base stretch. NOTE: rd == 2 would divide by zero here.
                new_base = self.base * scale ** (rd / (rd - 2))
                inv_freq = 1.0 / new_base ** (
                    torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd
                )
            else:
                inv_freq = self.inv_freq.float().to(device)
            t = torch.arange(seq_len, device=device, dtype=torch.float32)
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = freqs.cos()[None, :, None, :]
            self._sin_cached = freqs.sin()[None, :, None, :]
            self._seq_len_cached = seq_len
            self._scale_cached = scale
        return (
            self._cos_cached[:, :seq_len].to(dtype=dtype),
            self._sin_cached[:, :seq_len].to(dtype=dtype),
        )
|
|
|
|
def apply_rotary_emb(x, cos, sin, rope_dims=0):
    """Apply half-split rotary embedding to the last dimension of ``x``.

    The rotated channels are split into halves ``(lo, hi)`` and mapped to
    ``(lo * cos + hi * sin, -lo * sin + hi * cos)``. When ``0 < rope_dims < x.size(-1)``
    only the first ``rope_dims`` channels rotate and the remainder pass through
    unchanged. Otherwise the whole last dimension rotates.
    """
    partial = 0 < rope_dims < x.size(-1)
    rot = x[..., :rope_dims] if partial else x
    half = rot.size(-1) // 2
    lo, hi = rot[..., :half], rot[..., half:]
    rotated = torch.cat((lo * cos + hi * sin, lo * -sin + hi * cos), dim=-1)
    if not partial:
        return rotated
    return torch.cat((rotated, x[..., rope_dims:]), dim=-1)
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal grouped-KV self-attention with QK RMS-norm, RoPE and optional output gating.

    The Q/K/V/output projection weights are not owned by this module. ``forward``
    receives them as arguments (GPT passes slices of its parameter banks). At most
    one of ``attn_out_gate``, ``gated_attn`` and ``sparse_attn_gate`` may be enabled.
    """

    def __init__(
        self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True,
        attn_out_gate=False, attn_out_gate_src="proj", gate_window=12,
        gated_attn=False, gated_attn_init_std=0.01,
        sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1:
            raise ValueError(
                "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive"
            )
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        # Learnable per-head multiplier applied to q after QK-norm and RoPE.
        self.q_gain = nn.Parameter(
            torch.full((num_heads,), qk_gain_init, dtype=torch.float32)
        )
        # 0 = rotate the full head_dim. GPT.__init__ overrides rope_dims/rotary when
        # partial RoPE is configured.
        self.rope_dims = 0
        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn)
        # XSA: when enabled, _xsa_efficient removes each head output's component along
        # its (normalized) value vector.
        self.use_xsa = False
        # Option 1: per-head output gate 2 * sigmoid(W @ src[..., :gate_window]).
        # Flagged _zero_init, so GPT._init_weights zeros W and the gate starts at 1.
        self.attn_out_gate = attn_out_gate
        self.attn_out_gate_src = attn_out_gate_src
        self.gate_window = gate_window
        if attn_out_gate:
            self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False)
            self.attn_gate_proj._zero_init = True
        # Option 2: per-head output gate sigmoid(W @ x) over the full model dim,
        # W ~ N(0, gated_attn_init_std).
        self.gated_attn = gated_attn
        if gated_attn:
            W = torch.empty(num_heads, dim, dtype=torch.float32)
            nn.init.normal_(W, mean=0.0, std=gated_attn_init_std)
            self.attn_gate_w = nn.Parameter(W)
        # Option 3: per-head gate sigmoid(scale * W @ x[..., :gate_window]). With the
        # default init std 0.0, W is zero and the gate starts at sigmoid(0) = 0.5.
        self.sparse_attn_gate = sparse_attn_gate
        self.sparse_attn_gate_scale = sparse_attn_gate_scale
        if sparse_attn_gate:
            W = torch.empty(num_heads, gate_window, dtype=torch.float32)
            if sparse_attn_gate_init_std > 0:
                nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std)
            else:
                nn.init.zeros_(W)
            self.attn_gate_w = nn.Parameter(W)

    def _xsa_efficient(self, y, v):
        """Remove from each head output its projection onto the unit value vector of its KV group.

        y: (B, T, H, D) attention output; v: (B, T, Hkv, D) values. Query heads are
        grouped H // Hkv per KV head, and the projection is taken per position.
        """
        B, T, H, D = y.shape
        Hkv = v.size(-2)
        group = H // Hkv
        y_g = y.reshape(B, T, Hkv, group, D)
        vn = F.normalize(v, dim=-1).unsqueeze(-2)
        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
        return (y_g - proj).reshape(B, T, H, D)

    def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0):
        """Attend over ``x`` of shape (bsz, seqlen, dim) and return (bsz, seqlen, dim).

        q_w/k_w/v_w/out_w are applied with ``F.linear`` after casting to ``x.dtype``.
        When ``cu_seqlens`` is given, the varlen flash-attention path runs on
        ``q[0], k[0], v[0]`` only.
        NOTE(review): that path assumes bsz == 1 (packed documents) — confirm callers.
        """
        bsz, seqlen, dim = x.shape
        q_raw = F.linear(x, q_w.to(x.dtype))
        q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim)
        k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        # QK-norm: per-head RMS normalization of q and k before RoPE.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
        if cu_seqlens is not None:
            # Packed variable-length batch: attention stays within each cu_seqlens segment.
            y = flash_attn_varlen_func(
                q[0],
                k[0],
                v[0],
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=(-1, -1),
            )[None]
        else:
            y = flash_attn_3_func(q, k, v, causal=True)
        if self.use_xsa:
            y = self._xsa_efficient(y, v)
        if self.attn_out_gate:
            # Gate source is q_raw when attn_out_gate_src == "q", otherwise the block input x.
            gate_src = q_raw if self.attn_out_gate_src == "q" else x
            gate_in = gate_src[..., : self.gate_window].contiguous()
            g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in))
            y = y * g[..., None]
        if self.gated_attn:
            x_c = x.contiguous()
            g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype)))
            y = y * g[..., None]
        if self.sparse_attn_gate:
            gate_in = x[..., : self.gate_window].contiguous()
            g = torch.sigmoid(
                self.sparse_attn_gate_scale
                * F.linear(gate_in, self.attn_gate_w.to(x.dtype))
            )
            y = y * g[..., None]
        y = y.reshape(bsz, seqlen, dim)
        # Calibration hook: keep the output-projection input when self._calib is set.
        self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None
        return F.linear(y, out_w.to(x.dtype))
|
|
|
|
class MLP(nn.Module):
    """Feed-forward ``down(leaky_relu(up(x), 0.5) ** 2)`` with caller-supplied weights.

    ``dim`` and ``mlp_mult`` are accepted for signature compatibility but unused. The
    weights arrive per call (GPT passes slices of its parameter banks).
    """

    def __init__(self, dim, mlp_mult):
        super().__init__()
        self.use_fused = True

    def forward(self, x, up_w, down_w):
        act_dtype = x.dtype
        if self.training and self.use_fused:
            # Training path: fused Triton kernel with a custom backward.
            return FusedLeakyReLUSquareMLP(x, up_w.to(act_dtype), down_w.to(act_dtype))
        pre = F.linear(x, up_w.to(act_dtype))
        hidden = F.leaky_relu(pre, negative_slope=0.5).square()
        # Calibration hook: keep the down-projection input when self._calib is set.
        self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None
        return F.linear(hidden, down_w.to(act_dtype))
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each added back with a learned per-channel scale.

    Before attention, the running stream ``x`` is blended with the embedding stream ``x0``
    via ``resid_mix``, which is initialised to keep ``x`` and ignore ``x0``. Normalized
    inputs are further scaled by ``1 / sqrt(layer_idx + 1)`` when ``ln_scale`` is True.
    """

    def __init__(
        self,
        dim,
        num_heads,
        num_kv_heads,
        mlp_mult,
        rope_base,
        qk_gain_init,
        train_seq_len,
        layer_idx=0,
        ln_scale=False,
        yarn=True,
        attn_out_gate=False,
        attn_out_gate_src="proj",
        gate_window=12,
        gated_attn=False,
        gated_attn_init_std=0.01,
        sparse_attn_gate=False,
        sparse_attn_gate_init_std=0.0,
        sparse_attn_gate_scale=1.0,
    ):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        gate_options = dict(
            attn_out_gate=attn_out_gate,
            attn_out_gate_src=attn_out_gate_src,
            gate_window=gate_window,
            gated_attn=gated_attn,
            gated_attn_init_std=gated_attn_init_std,
            sparse_attn_gate=sparse_attn_gate,
            sparse_attn_gate_init_std=sparse_attn_gate_init_std,
            sparse_attn_gate_scale=sparse_attn_gate_scale,
        )
        self.attn = CausalSelfAttention(
            dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len,
            yarn=yarn, **gate_options,
        )
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # Row 0 weights the running stream x, row 1 the embedding stream x0.
        self.resid_mix = nn.Parameter(
            torch.stack((torch.ones(dim), torch.zeros(dim))).float()
        )
        if ln_scale:
            self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1)
        else:
            self.ln_scale_factor = 1.0

    def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0):
        keep_w, embed_w = self.resid_mix.to(dtype=x.dtype)
        x_in = keep_w[None, None, :] * x + embed_w[None, None, :] * x0
        norm_scale = self.ln_scale_factor

        attn_out = self.attn(
            self.attn_norm(x_in) * norm_scale,
            q_w, k_w, v_w, out_w,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out

        mlp_out = self.mlp(self.mlp_norm(hidden) * norm_scale, up_w, down_w)
        return hidden + self.mlp_scale.to(dtype=hidden.dtype)[None, None, :] * mlp_out
|
|
class GPT(nn.Module):
    """Transformer LM whose per-layer projection weights live in stacked parameter banks.

    Bank layout (n = num_layers), as returned by ``_bank_weights(i)``:
        qo_bank[i]      -> attention Q projection of layer i
        kv_bank[i]      -> attention K projection of layer i
        kv_bank[n + i]  -> attention V projection of layer i
        qo_bank[n + i]  -> attention output projection of layer i
        mlp_up_bank[i], mlp_down_bank[i] -> MLP up / down projection of layer i

    The layer schedule is split into an encoder half, whose outputs are pushed on a skip
    stack, and a decoder half, which pops them (LIFO) and mixes them in with learned
    ``skip_weights`` / ``skip_gates``. Optional features: layer looping (a repeated segment
    of layers, active only while ``looping_active`` is True), a two-lane parallel residual
    for decoder layers ``>= parallel_start_layer``, a smear gate on token embeddings, XSA on
    the last ``xsa_last_n`` layers, and tanh logit soft-capping.
    """

    def __init__(self, h):
        super().__init__()
        if h.logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}")
        self.tie_embeddings = h.tie_embeddings
        self.tied_embed_init_std = h.tied_embed_init_std
        self.logit_softcap = h.logit_softcap
        self.fused_ce_enabled = bool(h.fused_ce_enabled)
        self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim)
        self.num_layers = h.num_layers
        head_dim = h.model_dim // h.num_heads
        kv_dim = h.num_kv_heads * head_dim
        hidden_dim = int(h.mlp_mult * h.model_dim)
        # Stacked weight banks (uninitialised here; filled by _init_weights).
        self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim))
        self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim))
        self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim))
        self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim))
        self.num_encoder_layers = h.num_layers // 2
        self.num_decoder_layers = h.num_layers - self.num_encoder_layers
        self.blocks = nn.ModuleList(
            [
                Block(
                    h.model_dim,
                    h.num_heads,
                    h.num_kv_heads,
                    h.mlp_mult,
                    h.rope_base,
                    h.qk_gain_init,
                    h.train_seq_len,
                    layer_idx=i,
                    ln_scale=h.ln_scale,
                    yarn=h.rope_yarn,
                    attn_out_gate=h.attn_out_gate_enabled,
                    attn_out_gate_src=h.attn_out_gate_src,
                    gate_window=h.gate_window,
                    gated_attn=h.gated_attn_enabled,
                    gated_attn_init_std=h.gated_attn_init_std,
                    sparse_attn_gate=h.sparse_attn_gate_enabled,
                    sparse_attn_gate_init_std=h.sparse_attn_gate_init_std,
                    sparse_attn_gate_scale=h.sparse_attn_gate_scale,
                )
                for i in range(h.num_layers)
            ]
        )
        # Partial RoPE: rotate only the first rope_dims channels of each head.
        if h.rope_dims > 0:
            head_dim = h.model_dim // h.num_heads
            for block in self.blocks:
                block.attn.rope_dims = h.rope_dims
                block.attn.rotary = Rotary(
                    head_dim,
                    base=h.rope_base,
                    train_seq_len=h.train_seq_len,
                    rope_dims=h.rope_dims,
                    yarn=h.rope_yarn,
                )
        self.final_norm = RMSNorm()
        # With tied embeddings the output projection reuses tok_emb.weight.
        self.lm_head = (
            None
            if h.tie_embeddings
            else CastedLinear(h.model_dim, h.vocab_size, bias=False)
        )
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        if h.xsa_last_n > 0:
            for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers):
                self.blocks[i].attn.use_xsa = True
        # Looping schedule: layers [loop_start, loop_end] run num_loops + 1 times, and the
        # resulting index list is split in half into encoder / decoder. It is used only
        # while looping_active is True (False at construction).
        self.looping_active = False
        if h.num_loops > 0:
            loop_seg = list(range(h.loop_start, h.loop_end + 1))
            all_indices = list(range(h.loop_start))
            for _ in range(h.num_loops + 1):
                all_indices.extend(loop_seg)
            all_indices.extend(range(h.loop_end + 1, h.num_layers))
            num_enc = len(all_indices) // 2
            self.encoder_indices = all_indices[:num_enc]
            self.decoder_indices = all_indices[num_enc:]
        else:
            self.encoder_indices = list(range(self.num_encoder_layers))
            self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers))
        # Skip weights are sized from the (possibly looped) schedule.
        self.num_skip_weights = min(
            len(self.encoder_indices), len(self.decoder_indices)
        )
        self.skip_weights = nn.Parameter(
            torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)
        )
        # Zero-initialised gates: sigmoid(0) = 0.5, so lerp starts halfway between
        # the weighted skip and the decoder stream.
        self.skip_gates = (
            nn.Parameter(
                torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)
            )
            if h.skip_gates_enabled
            else None
        )
        self.parallel_start_layer = h.parallel_start_layer
        self.parallel_final_lane = h.parallel_final_lane.lower()
        # parallel_post_lambdas[layer, src, lane]: weight of src (0 = attn, 1 = mlp)
        # output written into lane (0 or 1).
        self.parallel_post_lambdas = nn.Parameter(
            torch.ones(h.num_layers, 2, 2, dtype=torch.float32)
        )
        # parallel_resid_lambdas[layer, lane]: multiplier on each lane's residual.
        self.parallel_resid_lambdas = nn.Parameter(
            torch.full((h.num_layers, 2), 1.1, dtype=torch.float32)
        )
        # Smear gate: mixes the previous token's embedding into the current one
        # (see _forward_hidden).
        self.smear_gate_enabled = h.smear_gate_enabled
        if self.smear_gate_enabled:
            self.smear_window = h.gate_window
            self.smear_gate = CastedLinear(self.smear_window, 1, bias=False)
            self.smear_gate._zero_init = True
            self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32))
        self._init_weights()

    def _init_weights(self):
        """Initialise banks and flagged linears.

        Q/K/V and MLP-up banks get orthogonal init. Output and MLP-down banks are zeroed.
        Every nn.Linear flagged ``_zero_init`` is zeroed, and any other nn.Linear with
        both dims >= 64 gets orthogonal init.
        """
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        n = self.num_layers
        proj_scale = 1.0 / math.sqrt(2 * n)
        for i in range(n):
            nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0)
            nn.init.zeros_(self.qo_bank.data[n + i])
            # NOTE(review): scaling a zeroed tensor is a no-op; possibly a leftover
            # from a scaled (non-zero) init — confirm intent.
            self.qo_bank.data[n + i].mul_(proj_scale)
            nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0)
            nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0)
        for i in range(n):
            nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0)
            nn.init.zeros_(self.mlp_down_bank.data[i])
            self.mlp_down_bank.data[i].mul_(proj_scale)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif (
                    module.weight.ndim == 2
                    and module.weight.shape[0] >= 64
                    and module.weight.shape[1] >= 64
                ):
                    nn.init.orthogonal_(module.weight, gain=1.0)

    def _bank_weights(self, i):
        """Return (q_w, k_w, v_w, out_w, up_w, down_w) bank slices for layer i."""
        n = self.num_layers
        return (
            self.qo_bank[i],
            self.kv_bank[i],
            self.kv_bank[n + i],
            self.qo_bank[n + i],
            self.mlp_up_bank[i],
            self.mlp_down_bank[i],
        )

    def _parallel_block(
        self, block_idx, lane0, lane1, x0,
        q_w, k_w, v_w, out_w, up_w, down_w,
        cu_seqlens=None, max_seqlen=0,
    ):
        """Two-lane variant of Block.forward: attention reads lane0 (mixed with x0), MLP reads lane1.

        Both outputs are written into both lanes via parallel_post_lambdas, and each
        lane's residual is scaled by parallel_resid_lambdas. Returns (lane0, lane1).
        """
        block = self.blocks[block_idx]
        mix = block.resid_mix.to(dtype=lane0.dtype)
        attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0
        attn_out = block.attn(
            block.attn_norm(attn_read) * block.ln_scale_factor,
            q_w, k_w, v_w, out_w,
            cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
        )
        attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out
        mlp_read = lane1
        mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp(
            block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w
        )
        attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype)
        attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype)
        mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype)
        mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype)
        lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out
        lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out
        return lane0, lane1

    def _final_parallel_hidden(self, lane0, lane1):
        """Collapse the two lanes: "mlp" -> lane1, "attn" -> lane0, anything else -> their mean."""
        if self.parallel_final_lane == "mlp":
            return lane1
        if self.parallel_final_lane == "attn":
            return lane0
        return 0.5 * (lane0 + lane1)

    def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0):
        """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden.
        Shared by eval (softcap+projection via forward_logits) and train (fused CE path).

        NOTE(review): once the parallel lanes start, any later decoder index below
        parallel_start_layer (possible with looped schedules) updates ``x``, but ``x`` is
        then replaced by the merged lanes — confirm schedules never do this."""
        x = self.tok_emb(input_ids)
        # Smear gate: for t >= 1, x[t] += g[t] * x[t-1] unless token t is BOS_ID, with
        # g[t] = smear_lambda * sigmoid(smear_gate(x[t, :smear_window])).
        if self.smear_gate_enabled:
            sl = self.smear_lambda.to(dtype=x.dtype)
            gate_in = x[:, 1:, : self.smear_window].contiguous()
            g = sl * torch.sigmoid(self.smear_gate(gate_in))
            not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1)
            x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x
        skips = []
        enc_iter = (
            self.encoder_indices
            if self.looping_active
            else range(self.num_encoder_layers)
        )
        dec_iter = (
            self.decoder_indices
            if self.looping_active
            else range(
                self.num_encoder_layers,
                self.num_encoder_layers + self.num_decoder_layers,
            )
        )
        for i in enc_iter:
            q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i)
            x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
            skips.append(x)
        psl = self.parallel_start_layer
        lane0 = None
        lane1 = None
        for skip_idx, i in enumerate(dec_iter):
            q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i)
            if i >= psl and psl > 0:
                # Parallel region: both lanes start from the current stream; skips go into lane0.
                if lane0 is None:
                    lane0 = x
                    lane1 = x
                if skip_idx < self.num_skip_weights and skips:
                    skip = skips.pop()
                    w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :]
                    if self.skip_gates is not None:
                        g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :]
                        lane0 = torch.lerp(w * skip, lane0, g)
                    else:
                        lane0 = lane0 + w * skip
                lane0, lane1 = self._parallel_block(
                    i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w,
                    cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
                )
            else:
                if skip_idx < self.num_skip_weights and skips:
                    scaled_skip = (
                        self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :]
                        * skips.pop()
                    )
                    if self.skip_gates is not None:
                        g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :]
                        x = torch.lerp(scaled_skip, x, g)
                    else:
                        x = x + scaled_skip
                x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        if lane0 is not None:
            x = self._final_parallel_hidden(lane0, lane1)
        x = self.final_norm(x)
        return x

    def _project_logits(self, hidden):
        """Project hidden states to vocabulary logits (tied embedding or lm_head), before soft-capping."""
        if self.tie_embeddings:
            return F.linear(hidden, self.tok_emb.weight)
        return self.lm_head(hidden)

    def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0):
        """Return soft-capped logits ``logit_softcap * tanh(z / logit_softcap)``."""
        hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        logits_proj = self._project_logits(hidden)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

    def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0):
        """Return the mean cross-entropy of soft-capped logits against ``target_ids``.

        With fused_ce_enabled, the soft-cap and CE are delegated to
        softcapped_cross_entropy (defined elsewhere in this file). Otherwise tanh
        soft-capping is followed by F.cross_entropy in float32.
        """
        hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        logits_proj = self._project_logits(hidden)
        flat_targets = target_ids.reshape(-1)
        if self.fused_ce_enabled:
            return softcapped_cross_entropy(
                logits_proj.reshape(-1, logits_proj.size(-1)),
                flat_targets,
                self.logit_softcap,
                reduction="mean",
            )
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            flat_targets,
            reduction="mean",
        )

    def forward_ttt(self, input_ids, target_ids, lora):
        """Forward pass with per-layer LoRA adapters; returns per-token CE of shape (bsz, seq_len).

        Mirrors _forward_hidden but uses the dense (non-varlen) attention path and gives
        each executed layer its own LoRA ``slot`` in schedule order. ``lora.lm_head_lora``
        is always added to the output logits.
        """
        x = self.tok_emb(input_ids)
        if self.smear_gate_enabled:
            sl = self.smear_lambda.to(dtype=x.dtype)
            gate_in = x[:, 1:, : self.smear_window].contiguous()
            g = sl * torch.sigmoid(self.smear_gate(gate_in))
            not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1)
            x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x
        skips = []
        enc_iter = (
            self.encoder_indices
            if self.looping_active
            else list(range(self.num_encoder_layers))
        )
        dec_iter = (
            self.decoder_indices
            if self.looping_active
            else list(
                range(
                    self.num_encoder_layers,
                    self.num_encoder_layers + self.num_decoder_layers,
                )
            )
        )
        slot = 0
        for i in enc_iter:
            q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i)
            x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w)
            slot += 1
            skips.append(x)
        psl = self.parallel_start_layer
        lane0 = None
        lane1 = None
        for skip_idx, i in enumerate(dec_iter):
            q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i)
            if i >= psl and psl > 0:
                if lane0 is None:
                    lane0 = x
                    lane1 = x
                if skip_idx < self.num_skip_weights and skips:
                    skip = skips.pop()
                    w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :]
                    if self.skip_gates is not None:
                        g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :]
                        lane0 = torch.lerp(w * skip, lane0, g)
                    else:
                        lane0 = lane0 + w * skip
                lane0, lane1 = self._parallel_block_with_lora(
                    i, lane0, lane1, x0, lora, slot,
                    q_w, k_w, v_w, out_w, up_w, down_w,
                )
            else:
                if skip_idx < self.num_skip_weights and skips:
                    scaled_skip = (
                        self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :]
                        * skips.pop()
                    )
                    if self.skip_gates is not None:
                        g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :]
                        x = torch.lerp(scaled_skip, x, g)
                    else:
                        x = x + scaled_skip
                x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w)
            # One LoRA slot per executed decoder layer (parallel or not).
            slot += 1
        if lane0 is not None:
            x = self._final_parallel_hidden(lane0, lane1)
        x = self.final_norm(x)
        if self.tie_embeddings:
            logits = F.linear(x, self.tok_emb.weight)
        else:
            logits = self.lm_head(x)
        logits = logits + lora.lm_head_lora(x)
        logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
        bsz, sl, V = logits.shape
        return F.cross_entropy(
            logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none"
        ).reshape(bsz, sl)

    def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w):
        """Block.forward with LoRA deltas from ``lora.*_loras[slot]`` added to the projections.

        q and v LoRAs are always applied. k, o and mlp LoRAs are applied when the
        corresponding ``lora`` attribute is not None. Uses dense flash attention (no cu_seqlens).
        """
        mix = block.resid_mix.to(dtype=x.dtype)
        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        n = block.attn_norm(x_in) * block.ln_scale_factor
        attn = block.attn
        bsz, seqlen, dim = n.shape
        q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n)
        q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim)
        k = F.linear(n, k_w.to(n.dtype))
        if lora.k_loras is not None:
            k = k + lora.k_loras[slot](n)
        k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim)
        v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape(
            bsz, seqlen, attn.num_kv_heads, attn.head_dim
        )
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = attn.rotary(seqlen, n.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, attn.rope_dims)
        k = apply_rotary_emb(k, cos, sin, attn.rope_dims)
        q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None]
        y = flash_attn_3_func(q, k, v, causal=True)
        if attn.use_xsa:
            y = attn._xsa_efficient(y, v)
        if attn.attn_out_gate:
            gate_src = q_raw if attn.attn_out_gate_src == "q" else n
            gate_in = gate_src[..., : attn.gate_window].contiguous()
            g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in))
            y = y * g[..., None]
        if attn.gated_attn:
            n_c = n.contiguous()
            g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype)))
            y = y * g[..., None]
        if attn.sparse_attn_gate:
            gate_in = n[..., : attn.gate_window].contiguous()
            g = torch.sigmoid(
                attn.sparse_attn_gate_scale
                * F.linear(gate_in, attn.attn_gate_w.to(n.dtype))
            )
            y = y * g[..., None]
        y = y.reshape(bsz, seqlen, dim)
        attn_out = F.linear(y, out_w.to(n.dtype))
        # NOTE(review): the output-projection LoRA reads the block input ``n``,
        # not the attention output ``y`` — confirm this is intended.
        if lora.o_loras is not None:
            attn_out = attn_out + lora.o_loras[slot](n)
        x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
        mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor
        mlp_out = block.mlp(mlp_n, up_w, down_w)
        if lora.mlp_loras is not None:
            mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n)
        x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out
        return x_out

    def _parallel_block_with_lora(
        self, block_idx, lane0, lane1, x0, lora, slot,
        q_w, k_w, v_w, out_w, up_w, down_w,
    ):
        """_parallel_block with the same LoRA deltas as _block_with_lora; returns (lane0, lane1)."""
        block = self.blocks[block_idx]
        mix = block.resid_mix.to(dtype=lane0.dtype)
        attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0
        n = block.attn_norm(attn_read) * block.ln_scale_factor
        attn = block.attn
        bsz, seqlen, dim = n.shape
        q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n)
        q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim)
        k = F.linear(n, k_w.to(n.dtype))
        if lora.k_loras is not None:
            k = k + lora.k_loras[slot](n)
        k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim)
        v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape(
            bsz, seqlen, attn.num_kv_heads, attn.head_dim
        )
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = attn.rotary(seqlen, n.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, attn.rope_dims)
        k = apply_rotary_emb(k, cos, sin, attn.rope_dims)
        q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None]
        y = flash_attn_3_func(q, k, v, causal=True)
        if attn.use_xsa:
            y = attn._xsa_efficient(y, v)
        if attn.attn_out_gate:
            gate_src = q_raw if attn.attn_out_gate_src == "q" else n
            gate_in = gate_src[..., : attn.gate_window].contiguous()
            g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in))
            y = y * g[..., None]
        if attn.gated_attn:
            n_c = n.contiguous()
            g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype)))
            y = y * g[..., None]
        if attn.sparse_attn_gate:
            gate_in = n[..., : attn.gate_window].contiguous()
            g = torch.sigmoid(
                attn.sparse_attn_gate_scale
                * F.linear(gate_in, attn.attn_gate_w.to(n.dtype))
            )
            y = y * g[..., None]
        y = y.reshape(bsz, seqlen, dim)
        attn_out = F.linear(y, out_w.to(n.dtype))
        if lora.o_loras is not None:
            attn_out = attn_out + lora.o_loras[slot](n)
        attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out
        mlp_read = lane1
        mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor
        mlp_out = block.mlp(mlp_n, up_w, down_w)
        if lora.mlp_loras is not None:
            mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n)
        mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out
        attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype)
        attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype)
        mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype)
        mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype)
        lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out
        lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out
        return lane0, lane1
|
|
|
|
class BatchedLinearLoRA(nn.Module):
    """Low-rank adapter with an independent (A, B) pair per batch slot.

    ``A`` has shape (bsz, rank, in_features) and ``B`` has shape
    (bsz, out_features, rank). ``forward`` returns
    ``((x @ A^T) @ B^T) * (alpha / rank)`` batched over dim 0. ``B`` starts
    at zero, so a fresh adapter contributes nothing.
    """

    # LoRA alpha (env TTT_LORA_ALPHA); the output multiplier is alpha / rank.
    _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144"))

    # When true (env TTT_WARM_START_A, default "1"), reset() keeps the current A
    # and only zeroes B; otherwise A is re-drawn too.
    _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1")))

    def __init__(self, bsz, in_features, out_features, rank):
        super().__init__()
        bound = 1.0 / math.sqrt(in_features)
        self._bound = bound
        self._scale = self._ALPHA / rank
        down_proj = torch.empty(bsz, rank, in_features)
        down_proj.uniform_(-bound, bound)
        self.A = nn.Parameter(down_proj)
        self.B = nn.Parameter(torch.zeros(bsz, out_features, rank))

    def reset(self):
        """Zero B (and re-draw A unless warm-starting) without tracking grads."""
        with torch.no_grad():
            redraw_a = not self._WARM_START_A
            if redraw_a:
                self.A.uniform_(-self._bound, self._bound)
            self.B.zero_()

    def forward(self, x):
        low_rank = x @ self.A.mT
        return (low_rank @ self.B.mT) * self._scale
|
|
|
|
class BatchedTTTLoRA(nn.Module):
    """Per-sample LoRA adapters for every attention/MLP slot plus the LM head.

    One ``BatchedLinearLoRA`` per slot is created for q and v (always) and
    for k, o and the MLP (each optional; the attribute is ``None`` when
    disabled). When ``model.looping_active`` is true the slot count is
    ``len(encoder_indices) + len(decoder_indices)`` (presumably one slot per
    block *visit* — confirm). Otherwise there is one slot per block.
    """

    def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True):
        super().__init__()
        self.bsz = bsz
        model_dim = model.qo_bank.shape[-1]
        vocab_size = model.tok_emb.num_embeddings
        looping = getattr(model, "looping_active", False)
        n_slots = (
            len(model.encoder_indices) + len(model.decoder_indices)
            if looping
            else len(model.blocks)
        )
        first_attn = model.blocks[0].attn
        kv_dim = first_attn.num_kv_heads * (model_dim // first_attn.num_heads)
        embed_dim = model.tok_emb.embedding_dim

        def per_slot(out_features):
            # One adapter per slot, all reading model_dim-wide inputs.
            return nn.ModuleList(
                BatchedLinearLoRA(bsz, model_dim, out_features, rank)
                for _ in range(n_slots)
            )

        # Construction order matters: each adapter draws from the global RNG.
        self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab_size, rank)
        self.q_loras = per_slot(model_dim)
        self.v_loras = per_slot(kv_dim)
        self.k_loras = per_slot(kv_dim) if k_lora else None
        self.mlp_loras = per_slot(model_dim) if mlp_lora else None
        self.o_loras = per_slot(model_dim) if o_lora else None

    def reset(self):
        """Reset every adapter (see BatchedLinearLoRA.reset)."""
        with torch.no_grad():
            self.lm_head_lora.reset()
            families = (
                self.q_loras,
                self.v_loras,
                self.k_loras,
                self.mlp_loras,
                self.o_loras,
            )
            for adapter in (a for fam in families if fam is not None for a in fam):
                adapter.reset()
|
|
|
|
| |
| |
| |
| |
# Per-iteration (a, b, c) coefficients for zeropower_via_newtonschulz5, which
# applies X <- a*X + (b*A + c*A@A) @ X with A = X @ X^T, using the first
# `steps` rows in order. The tuple length caps the number of iterations.
# NOTE(review): presumably tuned "Polar Express"-style coefficients (hence
# "PE"); confirm the source before editing any value.
_PE_COEFFS = (
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
)
|
|
|
|
@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07):
    """Approximately orthogonalize ``G`` with a quintic Newton-Schulz iteration.

    Accepts a single matrix (2-D) or a batch of matrices (leading dims).
    The input is cast to bfloat16 and normalized by its Frobenius norm (+eps).
    Wide orientation is used internally: tall matrices are transposed and then
    transposed back. It then runs ``min(steps, len(_PE_COEFFS))`` iterations
    with the coefficients in ``_PE_COEFFS``. The result is bfloat16 and has
    the same shape as ``G``.
    """
    squeeze_back = G.ndim == 2
    mat = (G.unsqueeze(0) if squeeze_back else G).bfloat16()
    tall = mat.size(-2) > mat.size(-1)
    if tall:
        mat = mat.mT
    mat = mat / (mat.norm(dim=(-2, -1), keepdim=True) + eps)
    # Slicing past the end simply yields every available coefficient tuple.
    for a, b, c in _PE_COEFFS[:steps]:
        gram = mat @ mat.mT
        poly = b * gram + c * (gram @ gram)
        mat = a * mat + poly @ mat
    if tall:
        mat = mat.mT
    return mat.squeeze(0) if squeeze_back else mat
|
|
|
|
class Muon(torch.optim.Optimizer):
    """Muon optimizer for banked weights (stacks of matrices along dim 0).

    Per parameter ``p`` with gradient ``g``:
    ``buf <- momentum * buf + g``; ``update = g + momentum * buf`` (Nesterov)
    or ``buf``; optionally divide each row by its L2 norm; orthogonalize with
    ``zeropower_via_newtonschulz5``; then
    ``p <- p * (1 - lr * weight_decay) - lr * scale * update`` where
    ``scale = sqrt(max(1, rows / cols))``.

    Distributed mode applies when torch.distributed is initialized and
    ``launch_reduce_scatters`` was called before ``step``. Each bank is
    zero-padded along dim 0 to a multiple of the world size. Its gradient is
    averaged with an async reduce-scatter so every rank orthogonalizes only
    its own shard, and the shards are all-gathered back. The all-gather of
    one bank overlaps the compute of the next.

    Hyperparameters are read from the param group that owns each parameter.
    Several groups with different settings are therefore supported, and each
    parameter is updated exactly once per step.
    """

    def __init__(
        self,
        params,
        lr,
        momentum,
        backend_steps,
        nesterov=True,
        weight_decay=0.0,
        row_normalize=False,
    ):
        super().__init__(
            params,
            dict(
                lr=lr,
                momentum=momentum,
                backend_steps=backend_steps,
                nesterov=nesterov,
                weight_decay=weight_decay,
                row_normalize=row_normalize,
            ),
        )
        # Buffers are built lazily: torch.distributed may not be initialized yet.
        self._built = False

    def _build(self):
        """Allocate per-bank padded-grad / shard / momentum / gather buffers."""
        self._distributed = dist.is_available() and dist.is_initialized()
        self._world_size = dist.get_world_size() if self._distributed else 1
        self._rank = dist.get_rank() if self._distributed else 0
        ws = self._world_size
        self._bank_meta = []
        for group_idx, group in enumerate(self.param_groups):
            for p in group["params"]:
                B = p.shape[0]
                padded_B = ((B + ws - 1) // ws) * ws
                shard_B = padded_B // ws
                tail = p.shape[1:]
                dev = p.device
                self._bank_meta.append({
                    "p": p,
                    # Store the index, not the dict: load_state_dict replaces
                    # the group dicts, and the scheduler mutates the live ones.
                    "group_idx": group_idx,
                    "B": B,
                    "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16),
                    "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16),
                    "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16),
                    "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16),
                    "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5,
                })
        # Largest banks first (presumably so the biggest all-gathers get the
        # most compute to overlap with).
        self._bank_meta.sort(key=lambda m: -m["p"].numel())
        self._built = True

    def launch_reduce_scatters(self):
        """Start async gradient reduce-scatters (no-op when not distributed)."""
        if not self._built:
            self._build()
        if not self._distributed:
            return
        self._rs_futures = []
        for m in self._bank_meta:
            p = m["p"]
            if p.grad is None:
                self._rs_futures.append(None)
                continue
            pg = m["padded_grad"]
            # Rows beyond B stay zero (allocated zeroed, never written).
            pg[: m["B"]].copy_(p.grad)
            fut = dist.reduce_scatter_tensor(
                m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True
            )
            self._rs_futures.append(fut)

    @staticmethod
    def _apply_update(p, update, group, scale):
        """Decoupled weight decay followed by the scaled orthogonal update."""
        lr = group["lr"]
        wd = group.get("weight_decay", 0.0)
        if wd > 0.0:
            p.data.mul_(1.0 - lr * wd)
        p.add_(update, alpha=-lr * scale)

    def _finish_pending(self, pending):
        """Wait for an in-flight all-gather and apply the gathered update."""
        handle, m = pending
        handle.wait()
        self._apply_update(
            m["p"],
            m["full_update"][: m["B"]],
            self.param_groups[m["group_idx"]],
            m["scale"],
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        if not self._built:
            self._build()
        sharded = self._distributed and hasattr(self, "_rs_futures")
        # (all-gather handle, meta) of the bank whose update is still in flight.
        pending = None
        for idx, m in enumerate(self._bank_meta):
            p = m["p"]
            if p.grad is None:
                continue
            group = self.param_groups[m["group_idx"]]
            momentum = group["momentum"]
            if pending is not None:
                self._finish_pending(pending)
                pending = None
            if sharded and self._rs_futures[idx] is not None:
                self._rs_futures[idx].wait()
                g = m["shard"]
                buf = m["shard_mom"]
            else:
                g = p.grad.bfloat16()
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
            buf.mul_(momentum).add_(g)
            update = g.add(buf, alpha=momentum) if group["nesterov"] else buf
            if group.get("row_normalize", False):
                rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07)
                update = update / rn.to(update.dtype)
            update = zeropower_via_newtonschulz5(update, steps=group["backend_steps"])
            if sharded:
                handle = dist.all_gather_into_tensor(
                    m["full_update"], update, async_op=True
                )
                pending = (handle, m)
            else:
                self._apply_update(p, update, group, m["scale"])
        if pending is not None:
            self._finish_pending(pending)
        if hasattr(self, "_rs_futures"):
            del self._rs_futures
        return loss
|
|
|
|
# Name substrings marking "control" parameters (scales, gates, mixes, lambdas).
# Parameters whose name contains any of these are routed to the scalar AdamW
# group by Optimizers and kept in float32 by restore_fp32_params.
# Override with a comma-separated CONTROL_TENSOR_NAME_PATTERNS env var; empty
# entries are dropped.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    filter(
        None,
        os.environ.get(
            "CONTROL_TENSOR_NAME_PATTERNS",
            "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda",
        ).split(","),
    )
)


# Replicated gradients with at most this many elements are packed into one
# flat buffer per (device, dtype) and all-reduced together; larger ones get
# their own async all-reduce.
PACKED_REPLICATED_GRAD_MAX_NUMEL = 2 ** 15
|
|
|
|
class Optimizers:
    """Bundle of the three training optimizers plus gradient synchronization.

    * ``optimizer_tok``: AdamW on the token embedding.
    * ``optimizer_muon``: Muon on the four weight banks.
    * ``optimizer_scalar``: AdamW on block parameters that are < 2-D or whose
      names match CONTROL_TENSOR_NAME_PATTERNS, plus optional model-level
      skip/lambda/smear parameters.

    Every param group carries a ``base_lr`` entry alongside ``lr``.
    Non-Muon ("replicated") gradients are all-reduced in ``step``: small
    ones are packed into flat buffers, large ones are reduced individually.
    """

    def __init__(self, h, base_model):
        bank_params = [
            base_model.qo_bank,
            base_model.kv_bank,
            base_model.mlp_up_bank,
            base_model.mlp_down_bank,
        ]

        def is_control(name, param):
            return param.ndim < 2 or any(
                pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS
            )

        scalar_params = [
            param
            for name, param in base_model.blocks.named_parameters()
            if is_control(name, param)
        ]
        if base_model.skip_weights.numel() > 0:
            scalar_params.append(base_model.skip_weights)
        skip_gates = base_model.skip_gates
        if skip_gates is not None and skip_gates.numel() > 0:
            scalar_params.append(skip_gates)
        for lambdas in (
            base_model.parallel_post_lambdas,
            base_model.parallel_resid_lambdas,
        ):
            if lambdas is not None:
                scalar_params.append(lambdas)
        if getattr(base_model, "smear_gate_enabled", False):
            scalar_params += [base_model.smear_gate.weight, base_model.smear_lambda]

        embed_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr
        tok_params = [
            {"params": [base_model.tok_emb.weight], "lr": embed_lr, "base_lr": embed_lr}
        ]
        self.optimizer_tok = torch.optim.AdamW(
            tok_params,
            betas=(h.beta1, h.beta2),
            eps=h.adam_eps,
            weight_decay=h.embed_wd,
            fused=True,
        )
        self.optimizer_muon = Muon(
            bank_params,
            lr=h.matrix_lr,
            momentum=h.muon_momentum,
            backend_steps=h.muon_backend_steps,
            weight_decay=h.muon_wd,
            row_normalize=h.muon_row_normalize,
        )
        for muon_group in self.optimizer_muon.param_groups:
            muon_group["base_lr"] = h.matrix_lr
        self.optimizer_scalar = torch.optim.AdamW(
            [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}],
            betas=(h.beta1, h.beta2),
            eps=h.adam_eps,
            weight_decay=h.adam_wd,
            fused=True,
        )
        self.optimizers = [
            self.optimizer_tok,
            self.optimizer_muon,
            self.optimizer_scalar,
        ]
        self.replicated_params = [*tok_params[0]["params"], *scalar_params]
        limit = PACKED_REPLICATED_GRAD_MAX_NUMEL
        self.replicated_large_params = [
            p for p in self.replicated_params if p.numel() > limit
        ]
        self.replicated_packed_params = [
            p for p in self.replicated_params if p.numel() <= limit
        ]
        # Side stream so the AdamW steps can overlap with Muon.
        self._aux_stream = torch.cuda.Stream()

    def __iter__(self):
        return iter(self.optimizers)

    def zero_grad_all(self):
        """Drop all gradients (set to None) on every optimizer."""
        for optimizer in self.optimizers:
            optimizer.zero_grad(set_to_none=True)

    def _all_reduce_packed_grads(self):
        """Average small replicated grads with one all-reduce per (device, dtype)."""
        buckets = {}
        for param in self.replicated_packed_params:
            grad = param.grad
            if grad is None:
                continue
            buckets.setdefault((grad.device, grad.dtype), []).append(grad)
        for grads in buckets.values():
            flat = torch.cat([grad.reshape(-1) for grad in grads])
            dist.all_reduce(flat, op=dist.ReduceOp.AVG)
            sizes = [grad.numel() for grad in grads]
            for grad, piece in zip(grads, flat.split(sizes)):
                grad.copy_(piece.view_as(grad))

    def step(self, distributed=False):
        """Synchronize grads (if distributed), step all optimizers, clear grads."""
        self.optimizer_muon.launch_reduce_scatters()
        if distributed:
            pending = [
                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, async_op=True)
                for param in self.replicated_large_params
                if param.grad is not None
            ]
            self._all_reduce_packed_grads()
            for work in pending:
                work.wait()
        main_stream = torch.cuda.current_stream()
        self._aux_stream.wait_stream(main_stream)
        with torch.cuda.stream(self._aux_stream):
            self.optimizer_tok.step()
            self.optimizer_scalar.step()
        self.optimizer_muon.step()
        torch.cuda.current_stream().wait_stream(self._aux_stream)
        self.zero_grad_all()
|
|
|
|
def restore_fp32_params(model):
    """Cast precision-sensitive parameters of ``model`` back to float32 in place.

    This covers every CastedLinear module, every parameter that is < 2-D or
    matches CONTROL_TENSOR_NAME_PATTERNS, and (if ``model.qo_bank`` exists
    and is not None) the four weight banks.
    """
    for mod in model.modules():
        if isinstance(mod, CastedLinear):
            mod.float()
    for pname, param in model.named_parameters():
        is_control = param.ndim < 2 or any(
            pat in pname for pat in CONTROL_TENSOR_NAME_PATTERNS
        )
        if is_control and param.dtype != torch.float32:
            param.data = param.data.float()
    if getattr(model, "qo_bank", None) is not None:
        for bank_name in ("qo_bank", "kv_bank", "mlp_up_bank", "mlp_down_bank"):
            bank = getattr(model, bank_name)
            bank.data = bank.data.float()
|
|
|
|
def collect_hessians(model, train_loader, h, device, n_calibration_batches=64):
    """Estimate GPTQ Hessians (input second moments ``X^T X``) from calibration data.

    Forward hooks accumulate, per quantized weight:

    * ``blocks.{i}.attn.c_q/c_k/c_v.weight``: the input of ``block.attn``;
    * ``blocks.{i}.attn.proj.weight``: ``attn._last_proj_input`` (if not None);
    * ``blocks.{i}.mlp.fc.weight``: the input of ``block.mlp``;
    * ``blocks.{i}.mlp.proj.weight``: ``mlp._last_down_input`` (if not None);
    * ``tok_emb.weight`` (only when ``model.tie_embeddings``): the output of
      ``model.final_norm``.

    During calibration ``_calib`` is set on every attn/mlp module and
    ``mlp.use_fused`` is disabled. Hooks are removed and these flags are
    restored (``use_fused`` to its previous value) even if a batch raises.
    The model is left in eval mode.

    Args:
        model: model exposing ``blocks``, ``tie_embeddings``, ``final_norm``
            and ``forward_logits``.
        train_loader: object with ``next_batch(tokens, grad_accum_steps)``
            returning ``(inputs, targets)``.
        h: hyperparameters (``train_batch_tokens``, ``grad_accum_steps``).
        device: device on which the accumulators are allocated.
        n_calibration_batches: number of forward passes to average over.

    Returns:
        dict mapping weight name to a float32 CPU tensor of shape
        (in_features, in_features), averaged over the batches.
    """
    hessians = {}

    def _rows(t):
        # 3-D activations are flattened to (all positions, features).
        t = t.detach().float()
        return t.reshape(-1, t.shape[-1]) if t.ndim == 3 else t

    def _accumulate(name, x):
        # hessians[name] += x^T x, allocating the accumulator on first use.
        acc = hessians.get(name)
        if acc is None:
            acc = torch.zeros(
                x.shape[1], x.shape[1], dtype=torch.float32, device=device
            )
            hessians[name] = acc
        acc.addmm_(x.T, x)

    def make_attn_hook(layer_idx):
        def hook_fn(module, inp, out):
            x = _rows(inp[0])
            for suffix in ("c_q", "c_k", "c_v"):
                _accumulate(f"blocks.{layer_idx}.attn.{suffix}.weight", x)
            proj_in = module._last_proj_input
            if proj_in is not None:
                _accumulate(f"blocks.{layer_idx}.attn.proj.weight", _rows(proj_in))
        return hook_fn

    def make_mlp_hook(layer_idx):
        def hook_fn(module, inp, out):
            _accumulate(f"blocks.{layer_idx}.mlp.fc.weight", _rows(inp[0]))
            down_in = module._last_down_input
            if down_in is not None:
                _accumulate(f"blocks.{layer_idx}.mlp.proj.weight", _rows(down_in))
        return hook_fn

    def make_output_hook(name):
        def hook_fn(module, inp, out):
            _accumulate(name, _rows(out))
        return hook_fn

    blocks = list(model.blocks)
    saved_use_fused = [getattr(block.mlp, "use_fused", True) for block in blocks]
    hooks = []
    try:
        for block in blocks:
            block.attn._calib = True
            block.mlp._calib = True
            block.mlp.use_fused = False
        for i, block in enumerate(blocks):
            hooks.append(block.attn.register_forward_hook(make_attn_hook(i)))
            hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i)))
        if model.tie_embeddings:
            hooks.append(
                model.final_norm.register_forward_hook(
                    make_output_hook("tok_emb.weight")
                )
            )
        model.eval()
        with torch.no_grad():
            for _ in range(n_calibration_batches):
                inputs, _targets = train_loader.next_batch(
                    h.train_batch_tokens, h.grad_accum_steps
                )
                model.forward_logits(inputs)
    finally:
        for hook in hooks:
            hook.remove()
        for block, use_fused in zip(blocks, saved_use_fused):
            block.attn._calib = False
            block.mlp._calib = False
            block.mlp.use_fused = use_fused
    return {
        name: acc.cpu() / n_calibration_batches for name, acc in hessians.items()
    }
|
|
|
|
def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128):
    """Quantize a 2-D weight with GPTQ (Hessian-aware rounding + error feedback).

    Columns are processed in descending order of ``diag(H)``. Each column is
    rounded to ``round(w / s)`` clamped to ``[-clip_range, clip_range]`` with a
    per-row scale ``s = clip_sigmas * std(row) / clip_range``. The rounding
    error is pushed onto the not-yet-quantized columns through the upper
    Cholesky factor of ``H^-1``. Columns whose ``H`` diagonal is zero ("dead"
    inputs) are zeroed, and 1% of the mean diagonal is added as damping.

    Args:
        w: (rows, cols) weight.
        H: (cols, cols) input second-moment matrix for ``w``.
        clip_sigmas: scale multiplier in row standard deviations.
        clip_range: largest magnitude of an integer code.
        block_size: columns per lazy-update block.

    Returns:
        ``(Q, s)``: int8 codes of shape (rows, cols) in the original column
        order, on ``w``'s device, and float16 per-row scales of shape (rows,).
        The weight is approximately ``Q * s[:, None]``.
    """
    W_orig = w.float().clone()
    dev = W_orig.device
    rows, cols = W_orig.shape
    H = H.to(device=dev, dtype=torch.float32).clone()
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    damp = 0.01 * H.diag().mean()
    H.diagonal().add_(damp)
    perm = torch.argsort(H.diag(), descending=True)
    invperm = torch.argsort(perm)
    W_work = W_orig[:, perm]  # advanced indexing already copies
    W_work[:, dead[perm]] = 0
    H = H[perm][:, perm]
    Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
    Hinv = torch.linalg.cholesky(Hinv, upper=True)
    row_std = W_orig.std(dim=1)
    # Floor the scale at the smallest positive float16 (2**-24) *before* the
    # cast. A tiny float32 floor such as 1e-10 underflows to 0 in float16, and
    # zero-variance rows would then divide by zero and cast NaN to int8.
    s = (clip_sigmas * row_std / clip_range).clamp_min(2.0 ** -24).to(torch.float16)
    sf = s.float()
    Q = torch.zeros(rows, cols, dtype=torch.int8, device=dev)
    for i1 in range(0, cols, block_size):
        i2 = min(i1 + block_size, cols)
        W_block = W_work[:, i1:i2].clone()
        Hinv_block = Hinv[i1:i2, i1:i2]
        Err = torch.zeros(rows, i2 - i1, device=dev)
        for j in range(i2 - i1):
            w_col = W_block[:, j]
            d = Hinv_block[j, j]
            q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range)
            Q[:, i1 + j] = q_col.to(torch.int8)
            err = (w_col - q_col * sf) / d
            Err[:, j] = err
            # Compensate the remaining columns inside this block.
            W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0)
        if i2 < cols:
            # Lazy batch update of every column after this block.
            W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:]
    return Q[:, invperm], s
|
|
|
|
| def _quantize_gate_int8_row(w): |
| |
| |
| |
| W = w.float().contiguous() |
| row_max = W.abs().amax(dim=1).clamp_min(1e-10) |
| s = (row_max / 127.0).to(torch.float16) |
| sf = s.float().view(-1, 1) |
| q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) |
| return q, s |
|
|
|
|
| def _lqer_pack(A, B, bits): |
| rng = 2 ** (bits - 1) - 1 |
| sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) |
| sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) |
| qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) |
| qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) |
| return qA, sA, qB, sB |
|
|
|
|
| def _lqer_pack_asym(A, B, g=64): |
| |
| sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) |
| qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) |
| |
| Bf = B.reshape(-1, g) |
| Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) |
| sB = (Bmax / 7.5).to(torch.float16).reshape(-1) |
| qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( |
| torch.int8 |
| ).reshape(B.shape) |
| return qA, sA, qB, sB |
|
|
|
|
def gptq_mixed_quantize(state_dict, hessians, h):
    """Quantize a flat state dict into the mixed int / float16 export format.

    Each tensor is handled in one of three ways:

    * If ``h.gated_attn_quant_gate`` is set, a 2-D float ``*.attn_gate_w``
      with 32..8192 elements gets per-row int8 (keys ``.gq`` / ``.gs``;
      meta ``"gate_int8_row"``).
    * A non-float tensor, or a float tensor with <= 65536 elements, is stored
      directly, with floats cast to float16 (meta ``"passthrough (float16)"``).
    * Anything else is GPTQ-quantized with ``hessians[name]`` (keys
      ``.q`` / ``.scale``; meta ``"gptq (int{bits})"``).

    When ``h.lqer_enabled`` is set, the ``h.lqer_top_k`` GPTQ tensors with the
    largest residual norm also get a rank-``h.lqer_rank`` SVD correction of
    the residual. It is packed asymmetrically (``.lq*_a`` keys, ``"+lqer_asym"``)
    when enabled and the group size divides ``B``; otherwise it is packed
    symmetrically (``.lq*`` keys, ``"+lqer"``).

    Returns:
        ``(result, meta)``: storage-key -> tensor, and tensor name -> format tag.
    """
    result, meta = {}, {}
    quant_gate = bool(getattr(h, "gated_attn_quant_gate", False))
    lqer_on = bool(getattr(h, "lqer_enabled", False))
    residuals = {}

    def clip_sigmas_for(name):
        if "tok_emb" in name:
            return h.embed_clip_sigmas
        if ".mlp." in name:
            return h.mlp_clip_sigmas
        if ".attn." in name:
            return h.attn_clip_sigmas
        return h.matrix_clip_sigmas

    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        is_float = t.is_floating_point()
        small_gate = (
            quant_gate
            and is_float
            and t.ndim == 2
            and name.endswith(".attn_gate_w")
            and 32 <= t.numel() <= 8192
        )
        if small_gate:
            result[name + ".gq"], result[name + ".gs"] = _quantize_gate_int8_row(t)
            meta[name] = "gate_int8_row"
            continue
        if not is_float or t.numel() <= 65536:
            result[name] = t.to(torch.float16) if is_float else t
            meta[name] = "passthrough (float16)"
            continue
        cs = clip_sigmas_for(name)
        bits = h.embed_bits if "tok_emb" in name else h.matrix_bits
        q, s = gptq_quantize_weight(
            t, hessians[name], clip_sigmas=cs, clip_range=2 ** (bits - 1) - 1
        )
        result[name + ".q"] = q
        result[name + ".scale"] = s
        meta[name] = f"gptq (int{bits})"
        if lqer_on:
            residual = t.float() - q.float() * s.float().view(-1, 1)
            residuals[name] = (residual, float(residual.norm()))

    if lqer_on and residuals:
        ranked = sorted(residuals.items(), key=lambda item: -item[1][1])[: h.lqer_top_k]
        use_asym = bool(getattr(h, "lqer_asym_enabled", False))
        group = int(getattr(h, "lqer_asym_group", 64))
        for name, (residual, _norm) in ranked:
            U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
            r = min(h.lqer_rank, S.numel())
            A = (U[:, :r] * S[:r]).contiguous()
            B = Vh[:r, :].contiguous()
            if use_asym and B.numel() % group == 0:
                packed = _lqer_pack_asym(A, B, group)
                suffixes = (".lqA_a", ".lqAs_a", ".lqB_a", ".lqBs_a")
                tag = "+lqer_asym"
            else:
                packed = _lqer_pack(A, B, h.lqer_factor_bits)
                suffixes = (".lqA", ".lqAs", ".lqB", ".lqBs")
                tag = "+lqer"
            for suffix, value in zip(suffixes, packed):
                result[name + suffix] = value
            meta[name] += tag

    categories = collections.defaultdict(set)
    for name, cat in meta.items():
        categories[cat].add(
            re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name))
        )
    log("Quantized weights:")
    for cat in sorted(categories):
        log(f" {cat}: {', '.join(sorted(categories[cat]))}")
    return result, meta
|
|
def dequantize_mixed(result, meta, template_sd):
    """Rebuild float tensors from the packed format of gptq_mixed_quantize.

    Only names present in both ``template_sd`` and ``meta`` are returned.
    Each output takes the template tensor's dtype, except that passthrough
    tensors are upcast only from float16 to float32/bfloat16 and otherwise
    returned as stored.
    """
    restored = {}
    for name, template in template_sd.items():
        tag = meta.get(name)
        if tag is None:
            continue
        target_dtype = template.dtype

        if "passthrough" in tag:
            stored = result[name]
            upcast = stored.dtype == torch.float16 and target_dtype in (
                torch.float32,
                torch.bfloat16,
            )
            restored[name] = stored.to(target_dtype) if upcast else stored
            continue

        if tag == "gate_int8_row":
            codes = result[name + ".gq"]
            scales = result[name + ".gs"]
            restored[name] = (codes.float() * scales.float().view(-1, 1)).to(target_dtype)
            continue

        codes, scale = result[name + ".q"], result[name + ".scale"]
        if scale.ndim == 0:
            weight = codes.float() * float(scale.item())
        else:
            broadcast_shape = (codes.shape[0],) + (1,) * (codes.ndim - 1)
            weight = codes.float() * scale.float().view(*broadcast_shape)

        # "lqer_asym" must be tested first: "lqer" is a substring of it.
        if "lqer_asym" in tag:
            a_codes = result[name + ".lqA_a"]
            a_scale = result[name + ".lqAs_a"]
            b_codes = result[name + ".lqB_a"]
            b_scales = result[name + ".lqBs_a"]
            low_a = a_codes.float() * float(a_scale)
            group = b_codes.numel() // b_scales.numel()
            grouped = b_codes.reshape(-1, group).float() * b_scales.float().view(-1, 1)
            weight = weight + low_a @ grouped.reshape(b_codes.shape)
        elif "lqer" in tag:
            low_a = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1)
            low_b = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1)
            weight = weight + low_a @ low_b
        restored[name] = weight.to(target_dtype)
    return restored
|
|
|
|
# NOTE(review): not referenced in this part of the file; possibly a leftover
# magic tag from an older serialization layout — confirm before removing.
_BSHF_MAGIC = b"BSHF"


# Order of the int8 tensor groups written by _serialize_pergroup and read by
# _deserialize_pergroup. The i-th entry is stored as the i-th packed stream,
# so reordering this list breaks previously written artifacts.
# Keys without a leading "_" are suffixes after "blocks.<layer>." (one tensor
# per layer). A leading "_" marks a single non-block tensor whose full name
# is the key with the underscore removed.
_GROUP_ORDER = [
    "_tok_emb.weight.q",
    "attn.c_k.weight.q", "attn.c_q.weight.q",
    "attn.c_v.weight.q", "attn.proj.weight.q",
    "mlp.fc.weight.q", "mlp.proj.weight.q",
]
# Groups whose 2-D tensors have their rows reordered by _similarity_sort_l1
# before compression. The uint16 row permutations are stored in a separate
# stream so the deserializer can undo the reordering.
_SIMSORT_KEYS = {"_tok_emb.weight.q", "attn.c_q.weight.q", "mlp.fc.weight.q"}
# Header magic written by _pack_streams and checked by _unpack_streams.
_PACK_MAGIC = b"PGRP"
|
|
|
|
| def _similarity_sort_l1(matrix): |
| import numpy as _np |
| n = matrix.shape[0] |
| used = _np.zeros(n, dtype=bool) |
| order = [0] |
| used[0] = True |
| cur = matrix[0].astype(_np.float32) |
| for _ in range(n - 1): |
| dists = _np.sum(_np.abs(matrix[~used].astype(_np.float32) - cur), axis=1) |
| unused = _np.where(~used)[0] |
| best = unused[_np.argmin(dists)] |
| order.append(best) |
| used[best] = True |
| cur = matrix[best].astype(_np.float32) |
| return _np.array(order, dtype=_np.uint16) |
|
|
|
|
def _lrzip_compress(data, tmpdir, label):
    """Compress ``data`` with the external ``lrzip`` CLI (``-z -L 9``).

    The data is staged through ``<tmpdir>/<label>.bin`` and
    ``<tmpdir>/<label>.bin.lrz``. Both temp files are removed even if lrzip
    fails.

    Raises:
        subprocess.CalledProcessError: if lrzip exits non-zero.
        FileNotFoundError: if lrzip is not installed.
    """
    import contextlib
    inp = os.path.join(tmpdir, f"{label}.bin")
    out = f"{inp}.lrz"
    try:
        with open(inp, "wb") as f:
            f.write(data)
        subprocess.run(["lrzip", "-z", "-L", "9", "-o", out, inp], capture_output=True, check=True)
        with open(out, "rb") as f:
            return f.read()
    finally:
        for path in (inp, out):
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)
|
|
|
|
def _lrzip_decompress(data, tmpdir, label):
    """Decompress an lrzip blob with the external ``lrzip -d -f`` CLI.

    The data is staged through ``<tmpdir>/<label>.lrz`` and
    ``<tmpdir>/<label>.bin``. Both temp files are removed even if lrzip
    fails.

    Raises:
        subprocess.CalledProcessError: if lrzip exits non-zero.
        FileNotFoundError: if lrzip is not installed.
    """
    import contextlib
    inp = os.path.join(tmpdir, f"{label}.lrz")
    out = os.path.join(tmpdir, f"{label}.bin")
    try:
        with open(inp, "wb") as f:
            f.write(data)
        subprocess.run(["lrzip", "-d", "-f", "-o", out, inp], capture_output=True, check=True)
        with open(out, "rb") as f:
            return f.read()
    finally:
        for path in (inp, out):
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)
|
|
|
|
def _pack_streams(streams):
    """Concatenate byte streams behind a length-prefixed header.

    Layout: ``_PACK_MAGIC`` | uint32-LE stream count | one uint32-LE length
    per stream | the payloads back to back.
    """
    import struct
    lengths = [len(chunk) for chunk in streams]
    header = _PACK_MAGIC + struct.pack(
        f"<{len(lengths) + 1}I", len(lengths), *lengths
    )
    return b"".join([header, *streams])
|
|
|
|
def _unpack_streams(blob):
    """Split a blob produced by _pack_streams back into its list of streams.

    Bytes after the last declared stream are ignored.

    Raises:
        ValueError: if the magic is wrong or the header/payload is truncated.
            An explicit exception is used instead of ``assert``, which is
            stripped under ``python -O``.
    """
    import struct
    if len(blob) < 8 or blob[:4] != _PACK_MAGIC:
        raise ValueError("_unpack_streams: bad magic or blob too short")
    (n,) = struct.unpack_from("<I", blob, 4)
    header_end = 8 + 4 * n
    if len(blob) < header_end:
        raise ValueError("_unpack_streams: header truncated")
    lengths = struct.unpack_from(f"<{n}I", blob, 8)
    if header_end + sum(lengths) > len(blob):
        raise ValueError("_unpack_streams: payload truncated")
    streams = []
    off = header_end
    for length in lengths:
        streams.append(blob[off:off + length])
        off += length
    return streams
|
|
|
|
| def _compress(raw, compressor): |
| if compressor == "brotli": |
| import brotli |
| return brotli.compress(raw, quality=11) |
| if compressor == "lzma": |
| import lzma |
| return lzma.compress(raw, preset=9) |
| raise ValueError(f"unknown compressor {compressor!r}") |
|
|
|
|
| def _decompress(blob, compressor): |
| if compressor == "brotli": |
| import brotli |
| return brotli.decompress(blob) |
| if compressor == "lzma": |
| import lzma |
| return lzma.decompress(blob) |
| raise ValueError(f"unknown compressor {compressor!r}") |
|
|
|
|
def _serialize_pergroup(quant_result, quant_meta, num_layers, tmpdir):
    """Serialize quantized tensors into the per-group packed-stream format.

    int8 tensors whose names match ``_GROUP_ORDER`` are grouped, ordered by
    layer and concatenated. 2-D tensors are stored transposed; for
    ``_SIMSORT_KEYS`` groups their rows are first reordered by
    _similarity_sort_l1. Each group is compressed with lrzip as one stream,
    and absent groups get an empty stream. Two more streams follow: a
    brotli-compressed ``torch.save`` of ``{"r": other tensors, "m":
    quant_meta, "s": per-group shapes}``, and the brotli-compressed uint16
    row permutations (or ``b""``). The streams are joined with _pack_streams.

    ``num_layers`` is unused here. It is kept for signature symmetry with
    _deserialize_pergroup.
    """
    import brotli
    import numpy as _np

    grouped = collections.defaultdict(list)
    leftovers = {}
    for name in sorted(quant_result):
        tensor = quant_result[name]
        if tensor.dtype != torch.int8:
            leftovers[name] = tensor
            continue
        parts = name.split(".")
        if parts[0] == "blocks" and parts[1].isdigit():
            suffix = ".".join(parts[2:])
            if suffix in _GROUP_ORDER:
                grouped[suffix].append((int(parts[1]), tensor))
                continue
        elif "_" + name in _GROUP_ORDER:
            grouped["_" + name] = [(0, tensor)]
            continue
        # int8 tensors outside the known groups (e.g. ".lqA" / ".gq" codes)
        # travel in the torch.save remainder.
        leftovers[name] = tensor

    streams = []
    perm_chunks = []
    shape_manifest = {}
    for group_key in _GROUP_ORDER:
        if group_key not in grouped:
            streams.append(b"")
            continue
        payload = []
        shapes = []
        for _layer, tensor in sorted(grouped[group_key], key=lambda item: item[0]):
            arr = tensor.numpy()
            shapes.append(arr.shape)
            if arr.ndim == 2:
                if group_key in _SIMSORT_KEYS:
                    order = _similarity_sort_l1(arr)
                    perm_chunks.append(order.tobytes())
                    arr = arr[order]
                arr = _np.ascontiguousarray(arr.T)
            payload.append(arr.tobytes())
        shape_manifest[group_key] = shapes
        streams.append(
            _lrzip_compress(b"".join(payload), tmpdir, group_key.replace(".", "_"))
        )

    side_buf = io.BytesIO()
    torch.save({"r": leftovers, "m": quant_meta, "s": shape_manifest}, side_buf)
    streams.append(brotli.compress(side_buf.getvalue(), quality=11, lgwin=24))
    all_perms = b"".join(perm_chunks)
    streams.append(brotli.compress(all_perms, quality=11) if all_perms else b"")
    return _pack_streams(streams)
|
|
|
|
def _deserialize_pergroup(blob, num_layers, tmpdir):
    """Rebuild ``(quant_result, quant_meta)`` from a blob written by _serialize_pergroup.

    Stream ``len(_GROUP_ORDER)`` is a brotli-compressed ``torch.save`` dict
    ``{"r": remaining tensors, "m": quant_meta, "s": per-group shape lists}``.
    The next stream holds the brotli-compressed uint16 row permutations, or
    is empty. The lrzip group streams are decompressed on a thread pool.
    Each group's bytes are then split back into tensors: block groups map to
    ``blocks.0 .. blocks.{num_layers-1}`` in order. 2-D tensors are
    un-transposed, and ``_SIMSORT_KEYS`` groups have their row permutation
    inverted.
    """
    import brotli
    import numpy as _np
    from concurrent.futures import ThreadPoolExecutor as _TPool

    streams = _unpack_streams(blob)
    n_groups = len(_GROUP_ORDER)
    side = torch.load(
        io.BytesIO(brotli.decompress(streams[n_groups])), map_location="cpu"
    )
    quant_meta = side["m"]
    quant_result = dict(side["r"])
    shape_manifest = side["s"]
    perm_stream = streams[n_groups + 1]
    all_perms = brotli.decompress(perm_stream) if perm_stream else b""

    def inflate(job):
        key, payload = job
        if not payload:
            return b""
        return _lrzip_decompress(payload, tmpdir, f"d_{key.replace('.', '_')}")

    with _TPool(max_workers=n_groups) as pool:
        raw_groups = dict(
            zip(_GROUP_ORDER, pool.map(inflate, zip(_GROUP_ORDER, streams)))
        )

    perm_cursor = 0
    for group_key in _GROUP_ORDER:
        raw = raw_groups.get(group_key, b"")
        if not raw:
            continue
        shapes = shape_manifest[group_key]
        codes = _np.frombuffer(raw, dtype=_np.int8)
        if group_key.startswith("_"):
            names = [group_key[1:]]
        else:
            names = [f"blocks.{layer}.{group_key}" for layer in range(num_layers)]
        cursor = 0
        for tname, orig_shape in zip(names, shapes):
            numel = math.prod(orig_shape)
            chunk = codes[cursor:cursor + numel].copy()
            cursor += numel
            if len(orig_shape) == 2:
                rows, cols = orig_shape
                # 2-D tensors were written transposed; undo that first.
                chunk = chunk.reshape(cols, rows).T
                if group_key in _SIMSORT_KEYS:
                    span = rows * 2  # bytes of `rows` uint16 entries
                    perm = _np.frombuffer(
                        all_perms[perm_cursor:perm_cursor + span], dtype=_np.uint16
                    )
                    perm_cursor += span
                    inverse = _np.empty_like(perm)
                    inverse[perm] = _np.arange(rows, dtype=_np.uint16)
                    chunk = chunk[inverse]
            quant_result[tname] = torch.from_numpy(
                _np.ascontiguousarray(chunk.reshape(orig_shape))
            )
    return quant_result, quant_meta
|
|
|
|
| def _unbank_state_dict(state_dict, num_layers): |
| sd = {} |
| n = num_layers |
| for k, v in state_dict.items(): |
| t = v.detach().cpu() if v is not None else None |
| if k == "qo_bank": |
| for i in range(n): |
| sd[f"blocks.{i}.attn.c_q.weight"] = t[i] |
| sd[f"blocks.{i}.attn.proj.weight"] = t[n + i] |
| elif k == "kv_bank": |
| for i in range(n): |
| sd[f"blocks.{i}.attn.c_k.weight"] = t[i] |
| sd[f"blocks.{i}.attn.c_v.weight"] = t[n + i] |
| elif k == "mlp_up_bank": |
| for i in range(n): |
| sd[f"blocks.{i}.mlp.fc.weight"] = t[i] |
| elif k == "mlp_down_bank": |
| for i in range(n): |
| sd[f"blocks.{i}.mlp.proj.weight"] = t[i] |
| else: |
| if t is not None: |
| sd[k] = t |
| return sd |
|
|
|
|
def _rebank_state_dict(flat_sd, num_layers, model_dim, kv_dim, hidden_dim):
    """Inverse of ``_unbank_state_dict``: restack per-layer weights into bank tensors.

    Builds ``qo_bank`` (2n, model_dim, model_dim), ``kv_bank`` (2n, kv_dim, model_dim),
    ``mlp_up_bank`` (n, hidden_dim, model_dim) and ``mlp_down_bank``
    (n, model_dim, hidden_dim) with ``torch.zeros`` (default dtype), fills them from
    ``flat_sd``, then copies over every entry that is not one of the banked per-layer
    weights.

    Raises:
        KeyError: if a required per-layer weight is missing from ``flat_sd``.
    """
    n = num_layers
    qo_bank = torch.zeros(2 * n, model_dim, model_dim)
    kv_bank = torch.zeros(2 * n, kv_dim, model_dim)
    for layer in range(n):
        prefix = f"blocks.{layer}."
        qo_bank[layer] = flat_sd[prefix + "attn.c_q.weight"]
        qo_bank[n + layer] = flat_sd[prefix + "attn.proj.weight"]
        kv_bank[layer] = flat_sd[prefix + "attn.c_k.weight"]
        kv_bank[n + layer] = flat_sd[prefix + "attn.c_v.weight"]
    up_bank = torch.zeros(n, hidden_dim, model_dim)
    down_bank = torch.zeros(n, model_dim, hidden_dim)
    for layer in range(n):
        prefix = f"blocks.{layer}."
        up_bank[layer] = flat_sd[prefix + "mlp.fc.weight"]
        down_bank[layer] = flat_sd[prefix + "mlp.proj.weight"]
    banked_markers = (
        ".attn.c_q.", ".attn.c_k.", ".attn.c_v.",
        ".attn.proj.", ".mlp.fc.", ".mlp.proj.",
    )
    rebanked = {
        "qo_bank": qo_bank,
        "kv_bank": kv_bank,
        "mlp_up_bank": up_bank,
        "mlp_down_bank": down_bank,
    }
    for key, value in flat_sd.items():
        already_banked = key.startswith("blocks.") and any(
            marker in key for marker in banked_markers
        )
        if not already_banked:
            rebanked[key] = value
    return rebanked
|
|
|
|
|
|
def _compressed_code_size(code):
    """Measure the source code's raw size and its size as a self-extracting stub.

    The code is minified with the external ``pyminify`` tool when it is available
    (falling back to the raw bytes if the tool is missing or fails), then
    brotli-compressed at quality 11 and base85-encoded. The result is wrapped in a
    one-line ``exec`` decompressor.

    Returns:
        ``(uncompressed_bytes, wrapper_bytes)``.
    """
    import brotli
    raw = code.encode("utf-8")
    minify_cmd = [
        "pyminify", "--no-rename-locals", "--no-hoist-literals",
        "--remove-literal-statements", "--remove-asserts", "--prefer-single-line", "-",
    ]
    try:
        proc = subprocess.run(minify_cmd, input=raw, capture_output=True, check=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        payload = raw
    else:
        payload = proc.stdout
    packed = base64.b85encode(brotli.compress(payload, quality=11))
    prefix = b"import brotli as B,base64 as b\nexec(B.decompress(b.b85decode(\""
    suffix = b"\")))\n"
    stub = prefix + packed + suffix
    return len(raw), len(stub)
|
|
|
|
def serialize(h, base_model, code):
    """Save the raw checkpoint and the GPTQ-quantized, compressed submission artifact.

    On the main process the full ``state_dict`` is written to ``h.model_path``. Every
    rank collects GPTQ Hessians on calibration batches and quantizes the unbanked
    weights. The blob is compressed either per group (``h.compressor == "pergroup"``)
    or as a single ``torch.save`` payload via ``_compress``. The main process then
    writes it to ``h.quantized_model_path``.

    Args:
        h: hyperparameter namespace.
        base_model: trained model (banked parameter layout).
        code: source text whose compressed size counts toward the submission.

    Returns:
        ``(bytes_total, quant_file_bytes)``: artifact size plus compressed code size,
        and the artifact size alone.
    """
    code_bytes_uncompressed, code_bytes = _compressed_code_size(code)
    if h.is_main_process:
        torch.save(base_model.state_dict(), h.model_path)
        model_bytes = os.path.getsize(h.model_path)
        log(f"Serialized model: {model_bytes} bytes")
        log(f"Code size (uncompressed): {code_bytes_uncompressed} bytes")
        log(f"Code size (compressed): {code_bytes} bytes")
    sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers)
    device = torch.device("cuda", h.local_rank)
    t0 = time.perf_counter()
    calib_loader = ShuffledSequenceLoader(h, device)
    log("GPTQ:collecting Hessians from calibration data...")
    hessians = collect_hessians(
        base_model,
        calib_loader,
        h,
        device,
        n_calibration_batches=h.gptq_calibration_batches,
    )
    log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter()-t0:.1f}s")
    quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h)
    if h.compressor == "pergroup":
        import tempfile

        # TemporaryDirectory removes the scratch dir *and* anything left inside it,
        # even if compression raises. The previous mkdtemp + os.rmdir only removed an
        # empty dir and leaked it otherwise. Cleanup stays best-effort.
        with tempfile.TemporaryDirectory(
            prefix="pgrp_", ignore_cleanup_errors=True
        ) as tmpdir:
            log("Serialize: per-group lrzip compression...")
            t1 = time.perf_counter()
            quant_blob = _serialize_pergroup(quant_result, quant_meta, h.num_layers, tmpdir)
            log(f"Serialize: per-group compression done in {time.perf_counter()-t1:.1f}s")
    else:
        quant_buf = io.BytesIO()
        torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
        quant_raw = quant_buf.getvalue()
        quant_blob = _compress(quant_raw, h.compressor)
    quant_file_bytes = len(quant_blob)
    bytes_total = quant_file_bytes + code_bytes
    if h.is_main_process:
        with open(h.quantized_model_path, "wb") as f:
            f.write(quant_blob)
        log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes")
        log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes")
    return bytes_total, quant_file_bytes
|
|
|
|
def deserialize(h, device):
    """Rebuild a GPT model from the quantized artifact at ``h.quantized_model_path``.

    A fresh bf16 model is built (with ``restore_fp32_params`` applied) so its
    unbanked state dict can serve as the dequantization template. The artifact is
    decoded either per group (when it starts with ``_PACK_MAGIC``) or by
    ``_decompress`` + ``torch.load``. It is then dequantized, restacked into banks
    and loaded with ``strict=True``.

    Args:
        h: hyperparameter namespace (model dims, compressor, artifact path).
        device: device the model is created on.

    Returns:
        The model with dequantized weights loaded.
    """
    eval_model = GPT(h).to(device).bfloat16()
    restore_fp32_params(eval_model)
    flat_template = _unbank_state_dict(eval_model.state_dict(), h.num_layers)
    with open(h.quantized_model_path, "rb") as f:
        quant_blob_disk = f.read()
    if quant_blob_disk[:4] == _PACK_MAGIC:
        import tempfile

        # The scratch dir and any leftover files are removed even if decoding raises.
        # Cleanup remains best-effort, as with the former os.rmdir + except OSError.
        with tempfile.TemporaryDirectory(
            prefix="pgrp_dec_", ignore_cleanup_errors=True
        ) as tmpdir:
            log("Deserialize: per-group lrzip decompression...")
            t0 = time.perf_counter()
            quant_result, quant_meta = _deserialize_pergroup(
                quant_blob_disk, h.num_layers, tmpdir
            )
            log(f"Deserialize: decompression done in {time.perf_counter()-t0:.1f}s")
    else:
        quant_state = torch.load(
            io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location="cpu"
        )
        quant_result, quant_meta = quant_state["w"], quant_state["m"]
    deq_flat = dequantize_mixed(quant_result, quant_meta, flat_template)
    head_dim = h.model_dim // h.num_heads
    kv_dim = h.num_kv_heads * head_dim
    hidden_dim = int(h.mlp_mult * h.model_dim)
    deq_state = _rebank_state_dict(deq_flat, h.num_layers, h.model_dim, kv_dim, hidden_dim)
    eval_model.load_state_dict(deq_state, strict=True)
    return eval_model
|
|
|
|
def _loss_bpb(loss_sum, token_count, byte_count):
    """Convert summed loss, token and byte counts into (mean loss in nats, bits per byte).

    bits/byte = (mean nats per token / ln 2) * (tokens / bytes).
    The arguments are 0-d tensors; Python floats are returned.
    """
    mean_nats = (loss_sum / token_count).item()
    tokens_per_byte = token_count.item() / byte_count.item()
    return mean_nats, mean_nats / math.log(2.0) * tokens_per_byte
|
|
|
|
def eval_val(h, device, val_data, model, forward_logits_fn=None):
    """Compute validation loss and bits-per-byte over all ranks' shards.

    The validation stream is viewed as ``total_seqs`` windows of ``h.eval_seq_len``
    tokens, split contiguously across ranks. Each forward pass packs
    ``local_batch_seqs`` windows into one flat sequence with ``cu_seqlens`` built from
    BOS positions by ``_build_cu_seqlens`` (presumably so attention restarts at each
    document). The loss, token and byte sums are all-reduced before conversion.

    Byte counts come from the ``val_data.val_bytes`` sidecar when present. Otherwise
    they are reconstructed from the tokenizer LUTs exactly as ``_accumulate_bpb``
    does: base bytes of the target, plus one when the target has a leading space and
    the previous token is not a boundary token.

    Args:
        h: hyperparameter namespace.
        device: device for tokens and accumulators.
        val_data: validation data holder.
        model: model (optionally DDP-wrapped) exposing ``forward_logits``.
        forward_logits_fn: optional replacement for ``model.forward_logits``
            (e.g. a compiled version).

    Returns:
        ``(val_loss, val_bpb)`` as Python floats. ``model`` is left in train mode.

    Raises:
        ValueError: if the per-rank batch cannot hold a single sequence.
    """
    seq_len = h.eval_seq_len
    local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps)
    if local_batch_tokens < seq_len:
        raise ValueError(
            f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}"
        )
    local_batch_seqs = local_batch_tokens // seq_len
    total_seqs = (val_data.val_tokens.numel() - 1) // seq_len
    seq_start = total_seqs * h.rank // h.world_size
    seq_end = total_seqs * (h.rank + 1) // h.world_size
    # Drop the tail of this rank's shard that would form a partial batch, so every
    # batch holds exactly local_batch_seqs sequences (uniform shapes).
    seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
    run_forward_logits = (
        (model.module.forward_logits if hasattr(model, "module") else model.forward_logits)
        if forward_logits_fn is None
        else forward_logits_fn
    )
    model.eval()
    global BOS_ID
    if BOS_ID is None:
        BOS_ID = 1
    # val_bytes may be None (the phased TTT evaluator guards for this too); in that
    # case fall back to the LUT-based byte count instead of crashing.
    byte_sidecar = val_data.val_bytes
    with torch.no_grad():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * seq_len
            # +1 so the final target token of the batch is included.
            raw_end = batch_seq_end * seq_len + 1
            local = val_data.val_tokens[raw_start:raw_end].to(
                device=device, dtype=torch.int64, non_blocking=True
            )
            x = local[:-1]
            y = local[1:]
            bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist()
            cu_seqlens, max_seqlen = _build_cu_seqlens(
                bos_pos, x.numel(), x.device, h.eval_seq_len, 64
            )
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                logits = run_forward_logits(
                    x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
                ).detach()
            per_token_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y.reshape(-1),
                reduction="none",
            )
            val_loss_sum += per_token_loss.to(torch.float64).sum()
            val_token_count += float(y.numel())
            if byte_sidecar is not None:
                # Sidecar slice aligned with the target tokens y.
                tok_bytes = byte_sidecar[raw_start + 1 : raw_end].to(
                    device=device, dtype=torch.int32, non_blocking=True
                ).to(torch.float64)
            else:
                tok_bytes = val_data.base_bytes_lut[y].to(torch.float64)
                tok_bytes += (
                    val_data.has_leading_space_lut[y] & ~val_data.is_boundary_token_lut[x]
                ).to(torch.float64)
            val_byte_count += tok_bytes.sum()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
    model.train()
    return _loss_bpb(val_loss_sum, val_token_count, val_byte_count)
|
|
|
|
def _find_docs(all_tokens):
    """Locate documents in a flat token stream by the positions of ``BOS_ID``.

    Each document starts at a BOS token. Every document except the last also
    includes the next document's BOS token as its final element (its span ends one
    past the next BOS). The last document runs to the end of the stream. Tokens
    before the first BOS are ignored.

    Returns:
        List of ``(start, length)`` tuples of Python ints.
    """
    starts = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy()
    stream_len = all_tokens.numel()
    n_starts = len(starts)
    docs = []
    for i, raw_start in enumerate(starts):
        begin = int(raw_start)
        has_next = i + 1 < n_starts
        stop = int(starts[i + 1]) + 1 if has_next else stream_len
        assert stop - begin >= 2
        docs.append((begin, stop - begin))
    return docs
|
|
|
|
def _build_ttt_global_batches(doc_entries, h, ascending=False):
    """Group document entries into length-sorted batches of ``h.ttt_batch_size``.

    ``doc_entries`` are ``(doc_idx, (start, length))`` pairs. They are sorted by
    length, cut into consecutive batches, and each batch is tagged with its index in
    that ascending order. Unless ``ascending`` is set, the tagged batches are then
    reordered by their longest document, longest first (stable for ties).

    Returns:
        List of ``(batch_idx, [entries...])``.
    """
    by_length = sorted(doc_entries, key=lambda entry: entry[1][1])
    step = h.ttt_batch_size
    batches = []
    for lo in range(0, len(by_length), step):
        batches.append(by_length[lo : lo + step])
    tagged = list(enumerate(batches))
    if ascending:
        return tagged
    # reverse=True keeps sort stability, matching a stable sort on the negated key.
    tagged.sort(
        key=lambda item: max(length for _, (_, length) in item[1]), reverse=True
    )
    return tagged
|
|
|
|
def _init_batch_counter(path):
    """Create (or truncate) ``path`` holding a 4-byte little-endian counter set to 0."""
    zero_bytes = (0).to_bytes(4, "little")
    with open(path, "wb") as counter_file:
        counter_file.write(zero_bytes)
|
|
|
|
def _claim_next_batch(counter_path, queue_len):
    """Atomically fetch-and-increment the shared 4-byte batch counter.

    An exclusive ``flock`` serializes claimers across processes. If the counter
    file does not exist, ``queue_len`` is returned so callers treat the queue as
    exhausted.

    Returns:
        The claimed index (the counter value before incrementing).
    """
    try:
        handle = open(counter_path, "r+b")
    except FileNotFoundError:
        return queue_len
    with handle:
        fcntl.flock(handle, fcntl.LOCK_EX)
        claimed = int.from_bytes(handle.read(4), "little")
        handle.seek(0)
        handle.write((claimed + 1).to_bytes(4, "little"))
        handle.flush()
    return claimed
|
|
|
|
def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len):
    """Locate chunk ``ci`` of a document and the context window that ends with it.

    Chunk ``ci`` spans ``[ci*chunk_size, chunk_end)``, where the last chunk ends at
    ``pred_len``. The window is the up-to-``eval_seq_len`` positions ending at
    ``chunk_end``.

    Returns:
        ``(win_start, win_len, chunk_offset, chunk_len)``. ``chunk_offset`` is the
        chunk's start position relative to ``win_start``.
    """
    chunk_start = ci * chunk_size
    is_final_chunk = ci == num_chunks - 1
    chunk_end = pred_len if is_final_chunk else (ci + 1) * chunk_size
    win_start = max(0, chunk_end - eval_seq_len)
    return (
        win_start,
        chunk_end - win_start,
        chunk_start - win_start,
        chunk_end - chunk_start,
    )
|
|
|
|
def _accumulate_bpb(
    ptl,
    x,
    y,
    chunk_offsets,
    chunk_lens,
    pos_idx,
    base_bytes_lut,
    has_leading_space_lut,
    is_boundary_token_lut,
    loss_sum,
    byte_sum,
    token_count,
    y_bytes=None,
):
    """Add one batch's scored-chunk loss, byte and token totals into running sums.

    Only positions ``p`` with ``chunk_offsets[b] <= p < chunk_offsets[b] + chunk_lens[b]``
    (and ``chunk_lens[b] > 0``) are counted for row ``b``. Byte counts come from
    ``y_bytes`` when given. Otherwise they are the LUT base bytes of ``y`` plus one
    when ``y`` has a leading space and ``x`` is not a boundary token.
    ``loss_sum``, ``byte_sum`` and ``token_count`` are 0-d tensors updated in place.
    """
    positions = pos_idx[: x.size(1)].unsqueeze(0)
    lower = chunk_offsets.unsqueeze(1)
    upper = (chunk_offsets + chunk_lens).unsqueeze(1)
    in_chunk = (chunk_lens.unsqueeze(1) > 0) & (positions >= lower) & (positions < upper)
    weight = in_chunk.to(torch.float64)
    if y_bytes is None:
        extra_space = has_leading_space_lut[y] & ~is_boundary_token_lut[x]
        tok_bytes = base_bytes_lut[y].to(torch.float64) + extra_space.to(torch.float64)
    else:
        tok_bytes = y_bytes.to(torch.float64)
    loss_sum += (ptl.to(torch.float64) * weight).sum()
    byte_sum += (tok_bytes * weight).sum()
    token_count += chunk_lens.to(torch.float64).sum()
|
|
|
|
def _loss_bpb_from_sums(loss_sum, token_count, byte_sum):
    """Turn summed loss/token/byte tensors into ``(mean loss in nats, bits per byte)``."""
    nats_per_token = (loss_sum / token_count).item()
    token_byte_ratio = token_count.item() / byte_sum.item()
    bits_per_byte = nats_per_token / math.log(2.0) * token_byte_ratio
    return nats_per_token, bits_per_byte
|
|
|
|
def _add_to_counter(path, delta):
    """Atomically add ``delta`` to a signed 8-byte little-endian counter file.

    An exclusive ``flock`` serializes concurrent updaters. If the file is missing,
    nothing is written and ``int(delta)`` is returned.

    Returns:
        The counter value after the update.
    """
    try:
        handle = open(path, "r+b")
    except FileNotFoundError:
        return int(delta)
    with handle:
        fcntl.flock(handle, fcntl.LOCK_EX)
        updated = int.from_bytes(handle.read(8), "little", signed=True) + int(delta)
        handle.seek(0)
        handle.write(int(updated).to_bytes(8, "little", signed=True))
        handle.flush()
    return updated
|
|
|
|
def _init_int64_counter(path):
    """Create (or truncate) ``path`` holding a signed 8-byte little-endian zero."""
    zero_bytes = (0).to_bytes(8, "little", signed=True)
    with open(path, "wb") as counter_file:
        counter_file.write(zero_bytes)
|
|
|
|
def _select_ttt_doc_entries(docs, h):
    """Pick the documents used for TTT evaluation as ``(doc_idx, doc)`` pairs.

    With ``h.val_doc_fraction < 1.0``, a seeded (``h.seed``) random subset of
    ``max(1, round(len(docs) * fraction))`` documents is returned in index order.
    Otherwise all documents are returned.
    """
    fraction = h.val_doc_fraction
    if not (fraction < 1.0):
        return list(enumerate(docs))
    total = len(docs)
    keep = max(1, int(round(total * fraction)))
    rng = random.Random(h.seed)
    chosen = sorted(rng.sample(range(total), keep))
    return [(idx, docs[idx]) for idx in chosen]
|
|
|
|
def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None):
    """Fine-tune all of ``base_model``'s parameters with SGD on ``val_tokens`` (global TTT).

    The stream is cut into chunks of ``h.global_ttt_chunk_tokens`` tokens. Every
    chunk except the last is trained on for ``h.global_ttt_epochs`` epochs, with each
    rank taking a contiguous slice of the chunk's ``seq_len``-token sequences.
    Gradients are averaged across ranks with ``all_reduce``. The LR warms up
    linearly over the first ``h.global_ttt_warmup_chunks`` chunks, then follows a
    cosine decay.

    Every rank runs the same number of optimizer steps per chunk, namely the step
    count of the largest shard. Shards can differ by one sequence, or be empty when
    a chunk has fewer sequences than ranks, and mismatched per-step collectives
    would otherwise hang. A rank that has run out of data contributes zero
    gradients to that step's average.

    Args:
        h: hyperparameter namespace (eval_seq_len, global_ttt_*, rank, world_size).
        device: device token slices are moved to.
        val_data: unused here; kept for call-site compatibility.
        base_model: model updated in place. It is left in eval mode with
            ``requires_grad=True`` on all parameters.
        val_tokens: token tensor to train on (sliced and reshaped as 1-D; assumed
            identical on every rank).
        batch_seqs: sequences per step per rank; defaults to ``h.global_ttt_batch_seqs``.
    """
    global BOS_ID
    if BOS_ID is None:
        BOS_ID = 1
    base_model.eval()
    seq_len = h.eval_seq_len
    total_tokens = val_tokens.numel() - 1
    ttt_chunk = h.global_ttt_chunk_tokens
    batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs
    num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk
    distributed = dist.is_available() and dist.is_initialized()
    ttt_params = list(base_model.parameters())
    for p in ttt_params:
        p.requires_grad_(True)
    optimizer = torch.optim.SGD(
        ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum
    )
    # Number of chunks spent in linear LR warmup (never includes the last chunk).
    warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1))
    t_start = time.perf_counter()
    for ci in range(num_chunks):
        chunk_start = ci * ttt_chunk
        chunk_end = min((ci + 1) * ttt_chunk, total_tokens)
        is_last_chunk = ci == num_chunks - 1
        # The final (possibly partial) chunk is never trained on.
        if is_last_chunk or h.global_ttt_epochs <= 0:
            continue
        base_model.train()
        chunk_seqs = (chunk_end - chunk_start) // seq_len
        if chunk_seqs <= 0:
            continue
        if warmup_chunks > 0 and ci < warmup_chunks:
            warmup_t = ci / max(warmup_chunks - 1, 1)
            lr_now = (
                h.global_ttt_warmup_start_lr
                + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t
            )
        else:
            decay_steps = max(num_chunks - 1 - warmup_chunks, 1)
            decay_ci = max(ci - warmup_chunks, 0)
            lr_now = h.global_ttt_lr * 0.5 * (
                1.0 + math.cos(math.pi * decay_ci / decay_steps)
            )
        for pg in optimizer.param_groups:
            pg["lr"] = lr_now
        my_seq_s = chunk_seqs * h.rank // h.world_size
        my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size
        my_chunk_seqs = my_seq_e - my_seq_s
        # The floor split gives shards of floor or ceil(chunk_seqs / world_size); use
        # the largest shard's batch count so all ranks issue matching collectives.
        max_shard_seqs = (chunk_seqs + h.world_size - 1) // h.world_size
        num_steps = (max_shard_seqs + batch_seqs - 1) // batch_seqs
        for _ in range(h.global_ttt_epochs):
            for step_i in range(num_steps):
                bs = step_i * batch_seqs
                optimizer.zero_grad(set_to_none=True)
                if bs < my_chunk_seqs:
                    be = min(bs + batch_seqs, my_chunk_seqs)
                    start_tok = chunk_start + (my_seq_s + bs) * seq_len
                    end_tok = chunk_start + (my_seq_s + be) * seq_len + 1
                    if end_tok <= val_tokens.numel():
                        local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64)
                        x_flat = local[:-1]
                        y_flat = local[1:]
                        with torch.enable_grad():
                            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                                if h.global_ttt_respect_doc_boundaries:
                                    bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist()
                                    cu_seqlens, max_seqlen = _build_cu_seqlens(
                                        bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64
                                    )
                                    loss = base_model(
                                        x_flat[None],
                                        y_flat[None],
                                        cu_seqlens=cu_seqlens,
                                        max_seqlen=max_seqlen,
                                    )
                                else:
                                    x = x_flat.reshape(-1, seq_len)
                                    y = y_flat.reshape(-1, seq_len)
                                    loss = base_model(x, y)
                            loss.backward()
                if distributed:
                    for p in ttt_params:
                        # Zero-fill missing grads so every rank all-reduces the same
                        # tensors in the same order, including ranks with no data
                        # this step. Zero grads leave SGD momentum buffers unchanged
                        # in effect.
                        if p.grad is None:
                            p.grad = torch.zeros_like(p)
                        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                        p.grad.mul_(1.0 / h.world_size)
                if h.global_ttt_grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip)
                optimizer.step()
        base_model.eval()
        if h.rank == 0:
            elapsed = time.perf_counter() - t_start
            log(
                f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s"
            )
    for p in base_model.parameters():
        p.requires_grad_(True)
    base_model.eval()
|
|
|
|
def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train):
    """Score validation documents with per-batch LoRA test-time training plus global TTT phases.

    Documents (optionally subsampled by ``_select_ttt_doc_entries``) are grouped into
    batches of ``h.ttt_batch_size`` by ``_build_ttt_global_batches``. The batches are
    handed out to ranks through a file-backed counter under /tmp. For each batch a
    reset ``BatchedTTTLoRA`` adapter is run chunk by chunk. Each chunk is first
    scored (its loss, bytes and tokens are accumulated) and only afterwards used for
    ``h.ttt_grad_steps`` adapter updates, so every token is scored before the adapter
    has trained on it.

    While fewer than ``h.phased_ttt_prefix_docs`` documents have been completed, the
    scored documents are remembered. Whenever the completed count reaches the current
    phase boundary, all ranks synchronize, gather the scored documents, and run
    ``train_val_ttt_global_sgd_distributed`` on their tokens. That updates
    ``base_model`` itself before scoring continues.

    Args:
        h: hyperparameter namespace (ttt_*, phased_ttt_*, run_id, rank, world_size, ...).
        base_model: model being evaluated; its weights are modified by global TTT phases.
        device: device for activations and accumulators.
        val_data: holder of val_tokens, val_bytes, the byte LUTs and caseops_enabled.
        forward_ttt_train: callable ``(x, y, lora=...)`` returning per-token losses
            (indexed as ``[batch, position]`` below).

    Returns:
        ``(val_loss, val_bpb)`` computed from sums all-reduced across ranks.
        ``base_model`` is left in train mode with ``requires_grad=True``.
    """
    global BOS_ID
    if BOS_ID is None:
        BOS_ID = 1
    base_model.eval()
    # Base weights are frozen while scoring; only the LoRA adapters are optimized here.
    for p in base_model.parameters():
        p.requires_grad_(False)
    all_tokens = val_data.val_tokens
    all_tokens_idx = all_tokens.to(torch.int32)
    docs = _find_docs(all_tokens)
    doc_entries = _select_ttt_doc_entries(docs, h)
    prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs)))
    num_phases = max(1, int(h.phased_ttt_num_phases))
    # Cumulative completed-document counts at which each global TTT phase fires.
    phase_boundaries = []
    for pi in range(num_phases):
        boundary = prefix_doc_limit * (pi + 1) // num_phases
        phase_boundaries.append(boundary)
    current_phase = 0
    current_phase_boundary = phase_boundaries[0]
    log(
        "ttt_phased:"
        f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} "
        f"suffix_docs:{len(doc_entries) - prefix_doc_limit}"
        f" num_phases:{num_phases} boundaries:{phase_boundaries}"
    )
    chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len
    # Optional comma-separated list of 1-based batch numbers to report; setting it
    # also switches the queue to ascending batch order.
    eval_batch_set = None
    if h.ttt_eval_batches:
        eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip())
    use_ascending = eval_batch_set is not None
    global_batches_sorted = _build_ttt_global_batches(
        doc_entries, h, ascending=use_ascending
    )
    queue_len = len(global_batches_sorted)
    # File-based coordination shared by all ranks on this host: the work-queue
    # counter, a count of completed prefix documents, and a "pause for global TTT" flag.
    counter_path = f"/tmp/ttt_counter_{h.run_id}"
    prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}"
    pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}"
    if h.rank == 0:
        _init_batch_counter(counter_path)
        _init_int64_counter(prefix_counter_path)
        try:
            os.remove(pause_flag_path)
        except FileNotFoundError:
            pass
    if dist.is_available() and dist.is_initialized():
        path_list = [counter_path, prefix_counter_path, pause_flag_path]
        dist.broadcast_object_list(path_list, src=0)
        counter_path, prefix_counter_path, pause_flag_path = path_list
        dist.barrier()
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    byte_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    t_start = time.perf_counter()
    reusable_lora = BatchedTTTLoRA(
        h.ttt_batch_size, base_model, h.ttt_lora_rank,
        k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora,
    ).to(device)


    def _build_opt(lora):
        # Optimizer over the adapter parameters only (SGD or fused AdamW).
        if h.ttt_optimizer == "sgd":
            return torch.optim.SGD(
                lora.parameters(), lr=h.ttt_lora_lr,
                momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay,
            )
        return torch.optim.AdamW(
            lora.parameters(), lr=h.ttt_lora_lr,
            betas=(h.ttt_beta1, h.ttt_beta2),
            eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True,
        )


    reusable_opt = _build_opt(reusable_lora)
    # Entries are (orig_batch_idx, position_in_batch, doc_start, doc_len).
    local_scored_docs = []
    global_ttt_done = prefix_doc_limit == 0
    try:
        while True:
            queue_idx = _claim_next_batch(counter_path, queue_len)
            if queue_idx >= queue_len:
                break
            orig_batch_idx, batch_entries = global_batches_sorted[queue_idx]
            batch = [doc for _, doc in batch_entries]
            bsz = len(batch)
            prev_loss = loss_sum.item()
            prev_bytes = byte_sum.item()
            prev_tokens = token_count.item()
            # Full-size batches reuse one adapter: reset its weights and zero the
            # optimizer state. Tensor state is zeroed in place; a non-tensor "step" is set to 0.
            # Smaller batches get a freshly built adapter and optimizer.
            if bsz == reusable_lora.bsz:
                reusable_lora.reset()
                for s in reusable_opt.state.values():
                    for k, v in s.items():
                        if isinstance(v, torch.Tensor):
                            v.zero_()
                        elif k == "step":
                            s[k] = 0
                cur_lora = reusable_lora
                cur_opt = reusable_opt
            else:
                cur_lora = BatchedTTTLoRA(
                    bsz, base_model, h.ttt_lora_rank,
                    k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora,
                ).to(device)
                cur_opt = _build_opt(cur_lora)
            pred_lens = [doc_len - 1 for _, doc_len in batch]
            num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens]
            max_nc = max(num_chunks)
            num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device)
            for ci in range(max_nc):
                active = [ci < nc for nc in num_chunks]
                # A document's last chunk is scored but never trained on.
                needs_train = any(ci < nc - 1 for nc in num_chunks)
                tok_starts = torch.zeros(bsz, dtype=torch.int64)
                tok_wls = torch.zeros(bsz, dtype=torch.int64)
                chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64)
                chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64)
                for b in range(bsz):
                    if not active[b]:
                        continue
                    doc_start, doc_len = batch[b]
                    win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window(
                        ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len
                    )
                    tok_starts[b] = doc_start + win_start
                    tok_wls[b] = win_len
                    chunk_offsets_cpu[b] = chunk_offset
                    chunk_lens_cpu[b] = chunk_len
                # Common padded window width (and its chunk offset) for this ci, computed
                # as if the document were long enough that chunk ci is full.
                _, context_size, chunk_offset, _ = _compute_chunk_window(
                    ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len
                )
                col_idx = torch.arange(context_size + 1)
                idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0)
                idx.clamp_(max=all_tokens.numel() - 1)
                gathered_gpu = all_tokens_idx[idx].to(
                    device=device, dtype=torch.int64, non_blocking=True
                )
                # Positions beyond each row's real window length are masked to token 0.
                valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to(
                    device, non_blocking=True
                )
                chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True)
                chunk_lens = chunk_lens_cpu.to(device, non_blocking=True)
                x = torch.where(valid, gathered_gpu[:, :context_size], 0)
                y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0)
                ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    per_tok_loss = forward_ttt_train(x, y, lora=cur_lora)
                # With caseops enabled and a byte sidecar available, per-target byte
                # counts are read from val_bytes (aligned one position after each window
                # start); otherwise _accumulate_bpb derives them from the LUTs.
                y_bytes_arg = None
                if val_data.caseops_enabled and val_data.val_bytes is not None:
                    y_idx = (
                        tok_starts.unsqueeze(1)
                        + 1
                        + col_idx[:context_size].unsqueeze(0)
                    )
                    y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1)
                    y_bytes_arg = val_data.val_bytes[y_idx].to(
                        device=device, dtype=torch.int32, non_blocking=True
                    )
                    # Padding positions contribute zero bytes.
                    y_bytes_arg = torch.where(
                        valid, y_bytes_arg, torch.zeros_like(y_bytes_arg)
                    )
                # Score first, before any adapter update on this chunk.
                with torch.no_grad():
                    _accumulate_bpb(
                        per_tok_loss,
                        x,
                        y,
                        chunk_offsets,
                        chunk_lens,
                        ctx_pos,
                        val_data.base_bytes_lut,
                        val_data.has_leading_space_lut,
                        val_data.is_boundary_token_lut,
                        loss_sum,
                        byte_sum,
                        token_count,
                        y_bytes=y_bytes_arg,
                    )
                if needs_train:
                    # Rows whose document has no later chunk are excluded from the loss.
                    activate_chunk_mask = (num_chunks_t - 1 > ci).float()
                    for gi in range(h.ttt_grad_steps):
                        # Step 0 reuses the scoring forward; later steps recompute it.
                        if gi > 0:
                            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                                per_tok_loss = forward_ttt_train(x, y, lora=cur_lora)
                        per_doc = per_tok_loss[
                            :, chunk_offset : chunk_offset + chunk_size
                        ].mean(dim=-1)
                        cur_opt.zero_grad(set_to_none=True)
                        (per_doc * activate_chunk_mask).sum().backward()
                        cur_opt.step()
                else:
                    del per_tok_loss
            batch_num = orig_batch_idx + 1
            doc_lens = [dl for _, dl in batch]
            should_report = batch_num in eval_batch_set if eval_batch_set is not None else True
            if should_report:
                # bl/bb: this batch's loss and bpb (deltas of the running sums);
                # rl/rb: running totals on this rank.
                cur_tokens = token_count.item()
                cur_loss_val = loss_sum.item()
                cur_bytes_val = byte_sum.item()
                dt = cur_tokens - prev_tokens
                db = cur_bytes_val - prev_bytes
                if dt > 0 and db > 0:
                    b_loss = (cur_loss_val - prev_loss) / dt
                    b_bpb = b_loss / math.log(2.0) * (dt / db)
                else:
                    b_loss = b_bpb = 0.0
                r_loss = cur_loss_val / max(cur_tokens, 1)
                r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1))
                elapsed = time.perf_counter() - t_start
                log(
                    f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} "
                    f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} "
                    f"gd:{int(global_ttt_done)}"
                )
            if not global_ttt_done:
                local_scored_docs.extend(
                    (orig_batch_idx, pos, doc_start, doc_len)
                    for pos, (doc_start, doc_len) in enumerate(batch)
                )
                prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries))
                # The first rank to cross the boundary creates the pause flag ("x" mode
                # fails if it already exists); every rank then polls for it.
                if prefix_done >= current_phase_boundary:
                    try:
                        with open(pause_flag_path, "x"):
                            pass
                    except FileExistsError:
                        pass
                should_pause = os.path.exists(pause_flag_path)
                # NOTE(review): a rank that exhausts the queue (break above) before it
                # observes the pause flag skips this barrier while other ranks wait in
                # it; this appears able to mismatch collectives near the end of the
                # queue — confirm it cannot happen for the configured boundaries.
                if should_pause:
                    if dist.is_available() and dist.is_initialized():
                        dist.barrier()
                    gathered_scored_docs = [None] * h.world_size
                    if dist.is_available() and dist.is_initialized():
                        dist.all_gather_object(gathered_scored_docs, local_scored_docs)
                    else:
                        gathered_scored_docs = [local_scored_docs]
                    scored_docs_for_global = []
                    for rank_docs in gathered_scored_docs:
                        if rank_docs:
                            scored_docs_for_global.extend(rank_docs)
                    # Deterministic order across ranks, truncated to the cumulative
                    # boundary (local_scored_docs is never cleared between phases).
                    scored_docs_for_global.sort(key=lambda x: (x[0], x[1]))
                    scored_docs_for_global = scored_docs_for_global[:current_phase_boundary]
                    scored_token_chunks = [
                        val_data.val_tokens[doc_start : doc_start + doc_len]
                        for _, _, doc_start, doc_len in scored_docs_for_global
                    ]
                    if scored_token_chunks:
                        global_ttt_tokens = torch.cat(scored_token_chunks)
                    else:
                        global_ttt_tokens = val_data.val_tokens[:0]
                    if h.rank == 0:
                        prefix_done = 0
                        try:
                            with open(prefix_counter_path, "rb") as f:
                                prefix_done = int.from_bytes(
                                    f.read(8), "little", signed=True
                                )
                        except FileNotFoundError:
                            pass
                        log(
                            f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} "
                            f"gd:{len(scored_docs_for_global)} "
                            f"t:{time.perf_counter() - t_start:.1f}s"
                        )
                    train_val_ttt_global_sgd_distributed(
                        h, device, val_data, base_model, global_ttt_tokens
                    )
                    # Re-freeze the updated base model and rebuild the reusable adapter
                    # on top of it.
                    for p in base_model.parameters():
                        p.requires_grad_(False)
                    reusable_lora = BatchedTTTLoRA(
                        h.ttt_batch_size, base_model, h.ttt_lora_rank,
                        k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora,
                    ).to(device)
                    reusable_opt = _build_opt(reusable_lora)
                    current_phase += 1
                    if current_phase >= num_phases:
                        global_ttt_done = True
                    else:
                        current_phase_boundary = phase_boundaries[current_phase]
                    if h.rank == 0:
                        try:
                            os.remove(pause_flag_path)
                        except FileNotFoundError:
                            pass
                    # No rank may re-check the flag until rank 0 has removed it.
                    if dist.is_available() and dist.is_initialized():
                        dist.barrier()
                    if h.rank == 0:
                        log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s")
            del cur_lora, cur_opt
    finally:
        # NOTE(review): empty finally; the /tmp coordination files are not removed here.
        pass
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
    for p in base_model.parameters():
        p.requires_grad_(True)
    base_model.train()
    return _loss_bpb_from_sums(loss_sum, token_count, byte_sum)
|
|
|
|
def timed_eval(label, fn, *args, **kwargs):
    """Run an evaluation under CUDA-synchronized wall-clock timing and log the result.

    ``fn(*args, **kwargs)`` must return ``(val_loss, val_bpb)``. The pair is logged
    together with ``label`` and the elapsed milliseconds, then returned unchanged.
    """
    torch.cuda.synchronize()
    started = time.perf_counter()
    loss_value, bpb_value = fn(*args, **kwargs)
    torch.cuda.synchronize()
    duration_ms = 1e3 * (time.perf_counter() - started)
    log(
        f"{label} val_loss:{loss_value:.8f} val_bpb:{bpb_value:.8f} eval_time:{duration_ms:.0f}ms"
    )
    return loss_value, bpb_value
|
|
|
|
def train_model(h, device, val_data):
    """Train a GPT model under an iteration and/or wall-clock budget and apply EMA weights.

    Steps:
      1. Build the model, compiled wrappers, optimizers and document-packing loader.
      2. If ``h.warmup_steps > 0``, run throw-away forward/backward passes and
         optimizer steps (with and without layer looping), then restore the initial
         model and optimizer states and rebuild the loader. Presumably this warms up
         compilation before the timed run.
      3. Train until ``h.iterations`` steps or the wall-clock cap, minus
         ``h.gptq_reserve_seconds`` kept free for GPTQ. Validation runs every
         ``h.val_loss_every`` steps and at the end.
      4. Maintain an EMA of every state-dict entry and load it into the model at the end.

    Returns:
        ``(base_model, compiled_model, compiled_forward_logits)``, with the EMA
        weights loaded into ``base_model``.
    """
    base_model = GPT(h).to(device).bfloat16()
    restore_fp32_params(base_model)
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    compiled_forward_logits = torch.compile(
        base_model.forward_logits, dynamic=False, fullgraph=True
    )
    model = compiled_model
    log(f"model_params:{sum(p.numel()for p in base_model.parameters())}")
    optimizers = Optimizers(h, base_model)
    train_loader = DocumentPackingLoader(h, device)
    # Training budget in ms, reduced by the time reserved for GPTQ after training.
    max_wallclock_ms = (
        1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None
    )
    if max_wallclock_ms is not None:
        max_wallclock_ms -= h.gptq_reserve_seconds * 1e3
        log(
            f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms"
        )


    def training_frac(step, elapsed_ms):
        # Progress in [0, 1+): by step count without a wall-clock cap, else by elapsed time.
        if max_wallclock_ms is None:
            return step / max(h.iterations, 1)
        return elapsed_ms / max(max_wallclock_ms, 1e-09)


    def lr_mul(frac):
        # Constant LR, then a linear warmdown over the final warmdown_frac of
        # progress, floored at h.min_lr.
        if h.warmdown_frac <= 0:
            return 1.0
        if frac >= 1.0 - h.warmdown_frac:
            return max((1.0 - frac) / h.warmdown_frac, h.min_lr)
        return 1.0


    _clip_params = [p for p in base_model.parameters() if p.requires_grad]
    def step_fn(step, lr_scale):
        # One optimizer step with gradient accumulation; returns the mean micro-batch loss.
        train_loss = torch.zeros((), device=device)
        for micro_step in range(h.grad_accum_steps):
            x, y, cu_seqlens, _max_seqlen = train_loader.next_batch(
                h.train_batch_tokens, h.grad_accum_steps
            )
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len)
            train_loss += loss.detach()
            (loss / h.grad_accum_steps).backward()
        train_loss /= h.grad_accum_steps
        # Linearly ramp Muon momentum from muon_momentum_warmup_start to muon_momentum.
        if step <= h.muon_momentum_warmup_steps:


            frac = (


                min(step / h.muon_momentum_warmup_steps, 1.0)


                if h.muon_momentum_warmup_steps > 0


                else 1.0


            )


            muon_momentum = (


                1 - frac


            ) * h.muon_momentum_warmup_start + frac * h.muon_momentum


            for group in optimizers.optimizer_muon.param_groups:


                group["momentum"] = muon_momentum
        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * lr_scale
        if h.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(_clip_params, h.grad_clip_norm)
        optimizers.step(distributed=h.distributed)
        return train_loss


    if h.warmup_steps > 0:
        # Snapshot initial weights and optimizer states; the warmup below is undone afterwards.
        initial_model_state = {
            name: tensor.detach().cpu().clone()
            for (name, tensor) in base_model.state_dict().items()
        }
        initial_optimizer_states = [
            copy.deepcopy(opt.state_dict()) for opt in optimizers
        ]
        model.train()
        num_tokens_local = h.train_batch_tokens // h.world_size
        for blk in base_model.blocks:
            blk.attn.rotary(num_tokens_local, device, torch.bfloat16)
        # cu_seqlens lengths to exercise: the first four multiples of the loader's bucket size.
        cu_bucket_size = train_loader.cu_bucket_size
        warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5))
        warmup_cu_iters = 3
        x, y, cu_seqlens, _ = train_loader.next_batch(
            h.train_batch_tokens, h.grad_accum_steps
        )
        log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}")
        def _run_cu_bucket_warmup():
            # Forward/backward on one fixed batch with synthetic cu_seqlens of each
            # bucket length (boundaries every train_seq_len, padded with x.size(1));
            # gradients are discarded.
            for bucket_len in warmup_cu_buckets:
                boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1)))
                if boundaries[-1] != x.size(1):
                    boundaries.append(x.size(1))
                cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device)
                cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device)
                for _ in range(warmup_cu_iters):
                    optimizers.zero_grad_all()
                    with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                        wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len)
                    (wloss / h.grad_accum_steps).backward()
            optimizers.zero_grad_all()
        _run_cu_bucket_warmup()
        if h.num_loops > 0:
            base_model.looping_active = True
            _run_cu_bucket_warmup()
            base_model.looping_active = False
        for warmup_step in range(h.warmup_steps):
            step_fn(warmup_step, 1.0)
            if (
                warmup_step <= 5
                or (warmup_step + 1) % 10 == 0
                or warmup_step + 1 == h.warmup_steps
            ):
                log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}")
        if h.num_loops > 0:
            base_model.looping_active = True
            log(
                f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}"
            )
            for warmup_step in range(h.warmup_steps):
                step_fn(warmup_step, 1.0)
                if (
                    warmup_step <= 5
                    or (warmup_step + 1) % 10 == 0
                    or warmup_step + 1 == h.warmup_steps
                ):
                    log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}")
            base_model.looping_active = False
        # Undo the warmup: restore initial weights/optimizer state and restart the data stream.
        base_model.load_state_dict(initial_model_state, strict=True)
        for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        optimizers.zero_grad_all()
        train_loader = DocumentPackingLoader(h, device)
    # fp32 EMA copies of every state-dict entry, paired with the live tensors they track.
    _live_state = base_model.state_dict(keep_vars=True)
    ema_state = {
        name: t.detach().float().clone()
        for (name, t) in _live_state.items()
    }
    _ema_pairs = [(ema_state[name], t) for (name, t) in _live_state.items()]
    ema_decay = h.ema_decay
    # Validation time is excluded: training_time_ms accumulates only between evals.
    training_time_ms = 0.0
    stop_after_step = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    step = 0
    while True:
        last_step = (
            step == h.iterations
            or stop_after_step is not None
            and step >= stop_after_step
        )
        should_validate = (
            last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0
        )
        if should_validate:
            torch.cuda.synchronize()
            training_time_ms += 1e3 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                h, device, val_data, model, compiled_forward_logits
            )
            log(
                f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()
        if last_step:
            if stop_after_step is not None and step < h.iterations:
                log(
                    f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}"
                )
            break
        elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0)
        frac = training_frac(step, elapsed_ms)
        scale = lr_mul(frac)
        # Layer looping switches on once progress reaches h.enable_looping_at and stays on.
        if (
            h.num_loops > 0
            and not base_model.looping_active
            and frac >= h.enable_looping_at
        ):
            base_model.looping_active = True
            log(
                f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}"
            )
        train_loss = step_fn(step, scale)
        with torch.no_grad():
            for ema_t, t in _ema_pairs:
                ema_t.mul_(ema_decay).add_(t.detach(), alpha=1.0 - ema_decay)
        step += 1
        approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0)
        should_log_train = h.train_log_every > 0 and (
            step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None
        )
        if should_log_train:
            tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3)
            log(
                f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}"
            )
        reached_cap = (
            max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        )
        # MAX-reduce so that once any rank hits the cap, all ranks stop at the same step.
        if h.distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step
    log(
        f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB"
    )
    log("ema:applying EMA weights")
    # Cast EMA tensors back to each entry's live dtype before loading.
    current_state = base_model.state_dict()
    avg_state = {
        name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items()
    }
    base_model.load_state_dict(avg_state, strict=True)
    return base_model, compiled_model, compiled_forward_logits
|
|
|
|
def _train_and_eval_ttt_phase(h, device, val_data):
    """Run the phased LoRA test-time-training (TTT) eval on a fresh copy of the artifact.

    Deserializes the saved artifact into a frozen ``ttt_model``, re-primes its rotary
    tables, warms up the ``torch.compile`` graph on random tokens so that compile time is
    kept out of the timed region, then times ``eval_val_ttt_phased`` and logs the result.
    """
    global BOS_ID
    ttt_model = deserialize(h, device)
    if h.num_loops > 0:
        ttt_model.looping_active = True
    # Freeze the base weights; the warmup below only optimizes the LoRA parameters.
    for p in ttt_model.parameters():
        p.requires_grad_(False)

    if h.rope_yarn:
        yarn_seq_len = h.train_batch_tokens // h.grad_accum_steps
        for block in ttt_model.blocks:
            block.attn.rotary(yarn_seq_len, device, torch.bfloat16)
    else:
        for block in ttt_model.blocks:
            rotary = block.attn.rotary
            # Clear the cached cos/sin state before priming the rotary module for the
            # TTT sequence length (presumably forces a rebuild -- TODO confirm).
            rotary._cos_cached = None
            rotary._sin_cached = None
            rotary._seq_len_cached = 0
            rotary(h.ttt_eval_seq_len, device, torch.bfloat16)

    def _fwd_ttt_inner(input_ids, target_ids, lora):
        return ttt_model.forward_ttt(input_ids, target_ids, lora=lora)

    # torch.compile only wraps the function here; tracing/compilation happens on the
    # first call (i.e. during the warmup below). dynamic=True requests dynamic-shape
    # compilation so differing (bsz, ctx_len) inputs can reuse the compiled code.
    fwd_ttt_compiled = torch.compile(_fwd_ttt_inner, dynamic=True)

    log("ttt_lora:warming up compile (random tokens, no val data)")
    if BOS_ID is None:
        # NOTE(review): fallback BOS token id when none was set earlier -- confirm 1 is right.
        BOS_ID = 1
    t_warmup = time.perf_counter()
    bsz = h.ttt_batch_size
    warm_lora = BatchedTTTLoRA(
        bsz, ttt_model, h.ttt_lora_rank,
        k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora,
    ).to(device)
    warm_opt = torch.optim.AdamW(
        warm_lora.parameters(),
        lr=h.ttt_lora_lr,
        betas=(h.ttt_beta1, h.ttt_beta2),
        eps=1e-10,
        weight_decay=h.ttt_weight_decay,
        fused=True,
    )
    # One forward/backward/optimizer step at each of ttt_chunk_size and ttt_eval_seq_len
    # on random tokens, so these paths are compiled before the eval timer starts.
    for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len):
        xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64)
        yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            # assumes forward_ttt returns a (bsz, ctx_len) loss tensor -- TODO confirm
            ptl = fwd_ttt_compiled(xw, yw, lora=warm_lora)
        ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward()
        warm_opt.step()
        warm_opt.zero_grad(set_to_none=True)
    del warm_lora, warm_opt
    # Wait for the queued warmup kernels so the logged compile time is complete.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    compile_elapsed = time.perf_counter() - t_warmup
    log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)")

    log("\nbeginning TTT eval timer")
    torch.cuda.synchronize()
    t_ttt = time.perf_counter()
    ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased(
        h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled
    )
    torch.cuda.synchronize()
    ttt_eval_elapsed = time.perf_counter() - t_ttt
    log(
        "quantized_ttt_phased "
        f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} "
        f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms"
    )
    log(f"total_eval_time:{ttt_eval_elapsed:.1f}s")


def train_and_eval(h, device):
    """Train (or reload) the model, serialize/quantize it, and run the evaluation suite.

    Environment switches:
      * ``TTT_EVAL_ONLY=1`` -- skip training and serialization and evaluate the artifact
        saved by a previous run.
      * ``PREQUANT_ONLY=1`` -- stop after the pre-quantization diagnostic eval.

    Otherwise: train, run the pre-quantization (post-EMA) diagnostic eval, serialize,
    reload the quantized artifact and evaluate it, then (when ``h.ttt_enabled``) run the
    phased TTT eval on a fresh copy of the artifact.
    """
    random.seed(h.seed)
    np.random.seed(h.seed)
    torch.manual_seed(h.seed)
    torch.cuda.manual_seed_all(h.seed)
    if h.artifact_dir and h.is_main_process:
        os.makedirs(h.artifact_dir, exist_ok=True)
    val_data = ValidationData(h, device)
    log(
        f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}"
    )
    log(f"val_tokens: {val_data.val_tokens.numel()-1}")

    ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1"
    if ttt_eval_only:
        log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval")
        log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}")
        log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}")
        log(f"ttt_weight_decay: {h.ttt_weight_decay}")
    else:
        base_model, compiled_model, compiled_forward_logits = train_model(
            h, device, val_data
        )
        torch._dynamo.reset()
        timed_eval(
            "diagnostic pre-quantization post-ema",
            eval_val,
            h,
            device,
            val_data,
            compiled_model,
            compiled_forward_logits,
        )
        if os.environ.get("PREQUANT_ONLY", "0") == "1":
            log("PREQUANT_ONLY=1 — skipping serialize/GPTQ/post-quant eval/TTT")
            return
        serialize(h, base_model, Path(__file__).read_text(encoding="utf-8"))
        # The trained model and its compiled wrappers are not used past this point; drop
        # the references so their GPU memory can be reused by the quantized/TTT models.
        del base_model, compiled_model, compiled_forward_logits
        if h.distributed:
            dist.barrier()

    # NOTE(review): in TTT_EVAL_ONLY mode this copy is never evaluated and is discarded
    # below; the TTT phase deserializes its own copy.
    eval_model = deserialize(h, device)
    if h.num_loops > 0:
        eval_model.looping_active = True
    if not ttt_eval_only:
        compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True)
        compiled_forward_logits = torch.compile(
            eval_model.forward_logits, dynamic=False, fullgraph=True
        )
        timed_eval(
            "diagnostic quantized",
            eval_val,
            h,
            device,
            val_data,
            compiled_model,
            compiled_forward_logits,
        )
        # compiled_forward_logits wraps the bound method eval_model.forward_logits and so
        # keeps eval_model alive; drop both wrappers so `del eval_model` really frees it.
        del compiled_model, compiled_forward_logits
    del eval_model

    if not h.ttt_enabled:
        return
    torch._dynamo.reset()
    torch.cuda.empty_cache()
    _train_and_eval_ttt_phase(h, device, val_data)
|
|
|
|
def main():
    """Script entry point: set up CUDA (and NCCL when launched distributed), log the run
    configuration and source, then call ``train_and_eval``.

    Distributed mode is enabled only when both ``RANK`` and ``WORLD_SIZE`` are present in
    the environment; ``LOCAL_RANK`` selects the CUDA device.

    Raises:
        RuntimeError: if no CUDA device is available.
        ValueError: if ``WORLD_SIZE`` is not positive or does not divide 8.
    """
    env = os.environ
    n_ranks = int(env.get("WORLD_SIZE", "1"))
    rank_on_node = int(env.get("LOCAL_RANK", "0"))
    is_distributed = "RANK" in env and "WORLD_SIZE" in env

    # Validate the launch configuration before touching the GPU.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    if n_ranks <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {n_ranks}")
    if 8 % n_ranks:
        raise ValueError(
            f"WORLD_SIZE={n_ranks} must divide 8 so grad_accum_steps stays integral"
        )

    device = torch.device("cuda", rank_on_node)
    torch.cuda.set_device(device)
    if is_distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()

    # Permit TF32 math for fp32 matmuls and cuDNN ops.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    from torch.backends.cuda import (
        enable_cudnn_sdp,
        enable_flash_sdp,
        enable_math_sdp,
        enable_mem_efficient_sdp,
    )

    # Leave only the flash backend enabled for scaled_dot_product_attention.
    sdp_toggles = (
        (enable_cudnn_sdp, False),
        (enable_flash_sdp, True),
        (enable_mem_efficient_sdp, False),
        (enable_math_sdp, False),
    )
    for set_backend, enabled in sdp_toggles:
        set_backend(enabled)

    dynamo_cfg = torch._dynamo.config
    dynamo_cfg.optimize_ddp = False
    dynamo_cfg.cache_size_limit = 64

    h = Hyperparameters()
    set_logging_hparams(h)
    if h.is_main_process:
        os.makedirs(h.artifact_dir or "logs", exist_ok=True)
        rule = "=" * 100
        log(rule, console=False)
        log("Hyperparameters:", console=True)
        # Only class-level (non-underscore) attributes of the Hyperparameters type are logged.
        for name, value in sorted(vars(type(h)).items()):
            if name.startswith("_"):
                continue
            log(f" {name}: {value}", console=True)
        log(rule, console=False)
        log("Source code:", console=False)
        log(rule, console=False)
        with open(__file__, "r", encoding="utf-8") as src_file:
            log(src_file.read(), console=False)
        log(rule, console=False)
        log(f"Running Python {sys.version}", console=False)
        log(f"Running PyTorch {torch.__version__}", console=False)
        log(rule, console=False)

    train_and_eval(h, device)
    if is_distributed:
        dist.destroy_process_group()
|
|
|
|
# Script entry point: run main() only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
|
|