LOOS Restore
Goal
Restore stable pure-ternary training behavior without changing Triton, TileLang, or component kernels.
Findings
- The current model update path is not dead: `_ternary_update_memory(loss_components=...)` can still call backward when the caller has not already done it, and direct `loss.backward()` plus `_ternary_update_memory(loss_signal=...)` also works.
- The training scripts were stale against the current `ARBModel` API. They still used `enable_image`, while the model now expects `enable_vision`.
- `training/pretrain.py` treated `--accum` like float-gradient accumulation, but ternary modules store gradients in hook buffers. Waiting until the accumulation boundary only preserved the last microbatch's hook state unless every hook tensor was also stored, which would raise memory use.
- Short loss spikes were caused by aggressive integer state updates:
  - `E_accum` defaulted to threshold 4, so scale exponents could move after only a few batches.
  - `T_accum` threshold 3/8 flipped signs early and could move many packed ternary weights at once.
- `TernaryEmbeddingTable` had `corr_accum` and `step_counter` as `float16`, which broke the integer-first training rule and reduced accumulation precision.
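The precision loss from a `float16` counter can be shown in isolation. The sketch below is not taken from `TernaryEmbeddingTable`; the helper names are hypothetical, and half precision is emulated with the standard library's `struct` format `"e"`. Half precision has an 11-bit significand, so an increment of 1 is rounded away once the counter reaches 2048.

```python
import struct


def to_float16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def count_steps_float16(n_steps: int) -> float:
    """Increment a counter n_steps times, rounding to float16 after each add."""
    counter = 0.0
    for _ in range(n_steps):
        counter = to_float16(counter + 1.0)
    return counter


def count_steps_int(n_steps: int) -> int:
    """Increment an integer counter n_steps times (exact)."""
    counter = 0
    for _ in range(n_steps):
        counter += 1
    return counter


float_count = count_steps_float16(5000)  # stalls once +1 no longer changes the value
int_count = count_steps_int(5000)        # exact

print(float_count, int_count)
```

An integer dtype keeps every increment, which is the property the integer-first rule depends on.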
Changes
- Added `training/ternary_runtime.py`.
  - `configure_ternary_training(...)` sets conservative runtime thresholds on ternary modules without editing kernels.
  - `reset_runtime_state(...)` clears KV/sliding-window state for randomly sampled pretraining batches.
- Updated `training/pretrain.py`.
  - Uses `enable_vision`.
  - Defaults `--accum-threshold` to `32`.
  - Adds `--e-accum-threshold`, default `32`.
  - Applies ternary state updates once per microbatch, then uses `--accum` for logging/checkpoint cadence.
  - Resets runtime KV/sliding-window state by default for random batch training; `--preserve-state` keeps it for sequential/streamed training.
- Updated pure training entrypoints.
  - `training/text.py`, `training/audio.py`, `training/vision.py`, and `training/diffusion.py` now use `enable_vision`.
  - Each configures the same conservative ternary thresholds and passes a detached `loss_signal`.
- Updated finetuning entrypoints to use `enable_vision`.
- Hardened `_ternary_update_memory`.
  - Detaches `loss_signal`.
  - Avoids double-backward when hooks already exist.
  - Clears stale hooks after update or skipped non-finite update.
- Restored integer correlation state in `TernaryEmbeddingTable`.
  - `corr_accum`: `float16` -> `int16`.
  - `step_counter`: `float16` -> `int32`.
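The difference between deferring the ternary update to the accumulation boundary and applying it once per microbatch can be sketched with a toy buffer. This is a minimal illustration, not the repository's implementation: `ToyHookBuffer` and `run_updates` are hypothetical names, and it assumes a hook buffer that overwrites its captured value rather than summing it. It also mirrors the hardened behavior of skipping non-finite values and clearing stale hook state either way.

```python
import math


class ToyHookBuffer:
    """Hypothetical stand-in for a ternary module's gradient hook buffer.

    It holds only the most recently captured value (overwrite, not sum).
    """

    def __init__(self):
        self.captured = None

    def capture(self, grad):
        self.captured = grad

    def clear(self):
        self.captured = None


def run_updates(microbatch_grads, accum, update_every_microbatch):
    """Return the captured values that actually reach the state update."""
    buf = ToyHookBuffer()
    consumed = []
    for i, grad in enumerate(microbatch_grads, start=1):
        buf.capture(grad)
        at_boundary = i % accum == 0
        if update_every_microbatch or at_boundary:
            if math.isfinite(buf.captured):
                consumed.append(buf.captured)  # stand-in for the integer state update
            buf.clear()  # clear hook state after an update or a skipped one
    return consumed


grads = [1.0, -2.0, float("nan"), 4.0]
deferred = run_updates(grads, accum=2, update_every_microbatch=False)
per_microbatch = run_updates(grads, accum=2, update_every_microbatch=True)

print(deferred)        # only the last microbatch of each accumulation window
print(per_microbatch)  # every finite microbatch
```

With deferral, the first microbatch of each window never contributes; updating per microbatch uses each finite value without storing extra hook tensors.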
Verification
- `python -m compileall -q training/ternary_runtime.py training/pretrain.py training/text.py training/audio.py training/vision.py training/diffusion.py training/finetuning/text.py training/finetuning/audio.py training/finetuning/vision.py training/finetuning/diffusion.py arbitor/main.py arbitor/components.py`
- `ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_gradient_capture.py`
  - Result: `5 passed`.
- `python -m pytest -q testing/test_trainers.py::test_all_trainers_loss_signal_detached testing/test_trainers.py::test_pretrain_loss_signal_detached`
  - Result: `2 passed`.
- `ARB_TERNARY_BACKEND=triton python training/pretrain.py --text-data training/data/tinyshakespeare.txt --steps 20 --batch 1 --ctx 16 --accum 2 --max-moe-iters 1 --no-save --log-interval 2 --eval-interval 0 --save-interval 0`
  - Result: finite text training, reported loss stayed in the approximate `26-37` band over the 20-step smoke run.
- `ARB_TERNARY_BACKEND=triton python training/text.py --data training/data/tinyshakespeare.txt --steps 5 --batch 1 --ctx 16 --eval-interval 5 --run text-smoke`
  - Result: trainer started, kept `0` trainable float params, and reported `train=31.995`, `eval=27.958`.
Notes
- For direct ad hoc scripts that instantiate `ARBModel` manually, call:
from training.ternary_runtime import configure_ternary_training
accum_threshold = configure_ternary_training(model, accum_threshold=32, e_accum_threshold=32)
- Lower thresholds such as `3` can converge over long runs, but they produce much larger loss spikes because signs and group scales move early. The training entrypoints now default to the smoother production setting.
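To see why a low threshold moves state early, consider a simplified accumulator model. This is an assumption for illustration only, not the actual `T_accum` or `E_accum` logic: the hypothetical `count_threshold_crossings` adds +1/-1 votes to an integer, fires when the magnitude reaches the threshold, and then resets to zero.

```python
def count_threshold_crossings(signals, threshold):
    """Count how often a hypothetical integer accumulator fires.

    Each vote is added to the accumulator; when its magnitude reaches
    `threshold`, one state change is counted and the accumulator resets.
    """
    accum = 0
    fired = 0
    for vote in signals:
        accum += vote
        if abs(accum) >= threshold:
            fired += 1
            accum = 0
    return fired


# A noisy but consistent signal: three votes up, one vote down, repeated.
signals = [1, 1, 1, -1] * 64

fires_low = count_threshold_crossings(signals, threshold=3)
fires_high = count_threshold_crossings(signals, threshold=32)

print(fires_low, fires_high)
```

On the same vote stream, the low threshold fires far more often and reacts to short runs of agreeing votes, while the higher threshold changes state only after sustained agreement.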