LOOS Restore

Goal

Restore stable pure-ternary training behavior without changing Triton, TileLang, or component kernels.

Findings

  • The current model update path is not dead: _ternary_update_memory(loss_components=...) can still call backward when the caller has not already done it, and direct loss.backward() plus _ternary_update_memory(loss_signal=...) also works.
  • The training scripts were stale against the current ARBModel API. They still used enable_image, while the model now expects enable_vision.
  • training/pretrain.py treated --accum like float-gradient accumulation, but ternary modules store gradients in hook buffers. Deferring the update to the accumulation boundary kept only the last microbatch's hook state. Keeping every microbatch's hook tensors would avoid that loss, but at the cost of higher memory use.
  • Short loss spikes were caused by aggressive integer state updates:
    • E_accum defaulted to threshold 4, so scale exponents could move after only a few batches.
    • T_accum threshold 3/8 flipped signs early and could move many packed ternary weights at once.
  • TernaryEmbeddingTable had corr_accum and step_counter as float16, which broke the integer-first training rule and reduced accumulation precision.
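
The accumulation finding above can be illustrated with a small pure-Python sketch. It is not the ARBModel code: HookBuffer, update_at_boundary, and update_per_microbatch are hypothetical names, and a single integer stands in for a hook tensor. It only shows why a single-slot buffer that is read at the accumulation boundary sees just the last microbatch.

```python
class HookBuffer:
    """Hypothetical stand-in for a module's single-slot gradient hook buffer."""

    def __init__(self):
        self.signal = None

    def capture(self, grad_signal):
        # Each backward pass overwrites whatever the previous one stored.
        self.signal = grad_signal

    def clear(self):
        self.signal = None


def update_at_boundary(microbatch_signals, accum):
    """Consume the buffer only every `accum` microbatches (the old behavior)."""
    buf = HookBuffer()
    applied = []
    for i, signal in enumerate(microbatch_signals, start=1):
        buf.capture(signal)
        if i % accum == 0:
            applied.append(buf.signal)  # only the last capture survives
            buf.clear()
    return applied


def update_per_microbatch(microbatch_signals):
    """Consume the buffer after every microbatch (the new behavior)."""
    buf = HookBuffer()
    applied = []
    for signal in microbatch_signals:
        buf.capture(signal)
        applied.append(buf.signal)
        buf.clear()
    return applied


signals = [+1, -1, -1, +1]
seen_at_boundary = update_at_boundary(signals, accum=2)
seen_per_microbatch = update_per_microbatch(signals)
```

With accum=2, the boundary variant applies two of the four signals and silently drops the rest, while the per-microbatch variant applies all four without storing extra tensors.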

Changes

  • Added training/ternary_runtime.py.
    • configure_ternary_training(...) sets conservative runtime thresholds on ternary modules without editing kernels.
    • reset_runtime_state(...) clears KV/sliding-window state for randomly sampled pretraining batches.
  • Updated training/pretrain.py.
    • Uses enable_vision.
    • Defaults --accum-threshold to 32.
    • Adds --e-accum-threshold, default 32.
    • Applies ternary state updates once per microbatch, then uses --accum for logging/checkpoint cadence.
    • Resets runtime KV/sliding-window state by default for random batch training; --preserve-state keeps it for sequential/streamed training.
  • Updated pure training entrypoints.
    • training/text.py, training/audio.py, training/vision.py, and training/diffusion.py now use enable_vision.
    • Each configures the same conservative ternary thresholds and passes detached loss_signal.
  • Updated finetuning entrypoints to use enable_vision.
  • Hardened _ternary_update_memory.
    • Detaches loss_signal.
    • Avoids double-backward when hooks already exist.
    • Clears stale hooks after update or skipped non-finite update.
  • Restored integer correlation state in TernaryEmbeddingTable.
    • corr_accum: float16 -> int16.
    • step_counter: float16 -> int32.
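
As a minimal illustration of why the float16 state reduced accumulation precision, the sketch below emulates a float16 step counter using only the standard library. It does not touch TernaryEmbeddingTable; to_float16 and count_steps_float16 are hypothetical helpers, and the rounding shown is general IEEE 754 binary16 behavior rather than anything specific to this codebase.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def count_steps_float16(n_steps: int) -> float:
    """Increment a counter n_steps times, storing it as float16 each time."""
    counter = 0.0
    for _ in range(n_steps):
        counter = to_float16(counter + 1.0)
    return counter


def count_steps_int(n_steps: int) -> int:
    """The same loop with an integer counter, which adds exactly."""
    counter = 0
    for _ in range(n_steps):
        counter += 1
    return counter


float_result = count_steps_float16(3000)
int_result = count_steps_int(3000)
```

float16 has an 11-bit significand, so 2049 is not representable and 2048 + 1 rounds back to 2048. The float counter therefore stops advancing at 2048, while the integer counter reaches 3000.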

Verification

  • python -m compileall -q training/ternary_runtime.py training/pretrain.py training/text.py training/audio.py training/vision.py training/diffusion.py training/finetuning/text.py training/finetuning/audio.py training/finetuning/vision.py training/finetuning/diffusion.py arbitor/main.py arbitor/components.py
  • ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_gradient_capture.py
    • Result: 5 passed.
  • python -m pytest -q testing/test_trainers.py::test_all_trainers_loss_signal_detached testing/test_trainers.py::test_pretrain_loss_signal_detached
    • Result: 2 passed.
  • ARB_TERNARY_BACKEND=triton python training/pretrain.py --text-data training/data/tinyshakespeare.txt --steps 20 --batch 1 --ctx 16 --accum 2 --max-moe-iters 1 --no-save --log-interval 2 --eval-interval 0 --save-interval 0
    • Result: text training stayed finite, and the reported loss stayed roughly between 26 and 37 over the 20-step smoke run.
  • ARB_TERNARY_BACKEND=triton python training/text.py --data training/data/tinyshakespeare.txt --steps 5 --batch 1 --ctx 16 --eval-interval 5 --run text-smoke
    • Result: trainer started, kept 0 trainable float params, and reported train=31.995, eval=27.958.

Notes

  • For direct ad hoc scripts that instantiate ARBModel manually, call:
    from training.ternary_runtime import configure_ternary_training
    accum_threshold = configure_ternary_training(model, accum_threshold=32, e_accum_threshold=32)
  • Lower thresholds such as 3 can converge over long runs, but they produce much larger loss spikes because signs and group scales move early. The training entrypoints now default to the smoother production setting.
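
The effect of the threshold can be sketched with a toy accumulate-and-fire counter. This is an assumption-laden model, not the T_accum or E_accum logic: the ±1 votes, the reset to zero after firing, and the count_fires helper are all hypothetical. It only shows that a low threshold reacts to short runs of noisy votes, while a high threshold mostly waits for a sustained trend.

```python
import random


def count_fires(votes, threshold):
    """Count how often a running sum of votes reaches +/- threshold.

    Assumed rule: add each vote to the accumulator, and when its magnitude
    reaches the threshold, record one state change and reset to zero.
    """
    accum = 0
    fires = 0
    for vote in votes:
        accum += vote
        if abs(accum) >= threshold:
            fires += 1
            accum = 0
    return fires


# Pure noise: 1000 unbiased +/-1 votes from a seeded generator.
rng = random.Random(0)
noisy_votes = [rng.choice((-1, 1)) for _ in range(1000)]

fires_low = count_fires(noisy_votes, threshold=3)
fires_high = count_fires(noisy_votes, threshold=32)
```

On unbiased noise, the threshold-3 counter fires far more often than the threshold-32 counter, which matches the larger loss spikes reported for low thresholds. A consistent signal still gets through the high threshold: count_fires([1] * 64, 32) fires twice.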