# LOOS Restore

## Goal

Restore stable pure-ternary training behavior without changing Triton, TileLang, or component kernels.

## Findings

- The current model update path still works. `_ternary_update_memory(loss_components=...)` still calls backward itself when the caller has not already done so, and calling `loss.backward()` directly followed by `_ternary_update_memory(loss_signal=...)` also works.
- The training scripts were out of date with the current `ARBModel` API. They still used `enable_image`, while the model now expects `enable_vision`.
- `training/pretrain.py` treated `--accum` like float-gradient accumulation, but ternary modules store gradients in hook buffers, not in accumulated float gradients. Deferring the update to the accumulation boundary therefore preserved only the last microbatch's hook state. Keeping every microbatch would have required storing every hook tensor, which increases memory use.
- Short loss spikes were caused by aggressive integer state updates:
  - `E_accum` defaulted to a threshold of 4, so scale exponents could move after only a few batches.
  - A `T_accum` threshold of 3/8 flipped signs early and could move many packed ternary weights at once.
- `TernaryEmbeddingTable` stored `corr_accum` and `step_counter` as `float16`, which broke the integer-first training rule and reduced accumulation precision.
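
The precision loss from a `float16` `step_counter` is easy to reproduce without the model. This standard-library sketch (the helper names are hypothetical) uses the `struct` module's `'e'` format to round each value to IEEE 754 half precision. Half precision has an 11-bit significand, so integers above 2048 are no longer all representable, and a counter that is incremented by 1 and stored as `float16` stops advancing there, while an integer counter stays exact. It assumes a plain +1 increment per step, which may differ from how `TernaryEmbeddingTable` actually updates its counter.

```python
import struct


def to_float16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (ties to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def count_steps_float16(n: int) -> float:
    """Increment a counter n times, storing it as float16 after every add."""
    counter = 0.0
    for _ in range(n):
        counter = to_float16(counter + 1.0)
    return counter


def count_steps_int(n: int) -> int:
    """Increment a plain integer counter n times (always exact)."""
    counter = 0
    for _ in range(n):
        counter += 1
    return counter


if __name__ == "__main__":
    steps = 5000
    # 2048 + 1 rounds back to 2048 in float16, so this counter stalls.
    print("float16 counter:", count_steps_float16(steps))
    print("int counter:", count_steps_int(steps))
```

The same rounding also limits how finely a `float16` accumulator can track small corrections once its magnitude grows, which is why the integer dtypes restore exact accumulation.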

## Changes

- Added `training/ternary_runtime.py`.
  - `configure_ternary_training(...)` sets conservative runtime thresholds on ternary modules without editing kernels.
  - `reset_runtime_state(...)` clears KV/sliding-window state for randomly sampled pretraining batches.
- Updated `training/pretrain.py`.
  - Uses `enable_vision`.
  - Defaults `--accum-threshold` to `32`.
  - Adds `--e-accum-threshold`, default `32`.
  - Applies ternary state updates once per microbatch and uses `--accum` only for logging/checkpoint cadence.
  - Resets runtime KV/sliding-window state by default for random-batch training; `--preserve-state` keeps it for sequential/streamed training.
- Updated pure training entrypoints.
  - `training/text.py`, `training/audio.py`, `training/vision.py`, and `training/diffusion.py` now use `enable_vision`.
  - Each configures the same conservative ternary thresholds and passes a detached `loss_signal`.
- Updated finetuning entrypoints to use `enable_vision`.
- Hardened `_ternary_update_memory`.
  - Detaches `loss_signal`.
  - Avoids a double backward when hooks already exist.
  - Clears stale hooks after an update or after a skipped non-finite update.
- Restored integer correlation state in `TernaryEmbeddingTable`.
  - `corr_accum`: `float16` -> `int16`.
  - `step_counter`: `float16` -> `int32`.
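
The per-microbatch cadence and the `_ternary_update_memory` guards can be illustrated with a toy model. This is a minimal, framework-free sketch, not the repository's implementation: `HookBuffer`, `ToyTernaryState`, `update_memory`, `train`, and `train_deferred` are hypothetical names, the hook buffer is assumed to hold only the most recent backward signal (as described in Findings), and a plain float stands in for the detached `loss_signal`. It shows that updating every microbatch consumes every signal, that non-finite signals are skipped, that hook state is cleared on both paths, and that `accum` only controls logging.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class HookBuffer:
    """Hypothetical stand-in for a ternary module's hook state.

    Each backward pass overwrites the buffer, so only the latest signal survives.
    """

    signal: Optional[float] = None

    def capture(self, value: float) -> None:
        self.signal = value  # overwrite, never accumulate

    def clear(self) -> None:
        self.signal = None


@dataclass
class ToyTernaryState:
    """Records which signals reached the state and how many were skipped."""

    applied: List[float] = field(default_factory=list)
    skipped: int = 0


def update_memory(state: ToyTernaryState, hooks: HookBuffer, loss_signal: float) -> None:
    """Consume the hook signal once, skip non-finite updates, always clear hooks."""
    try:
        if hooks.signal is None:
            # No hooks yet: a real path would run backward here. Reusing
            # existing hooks instead avoids a second backward pass.
            hooks.capture(loss_signal)
        if math.isfinite(loss_signal):
            state.applied.append(hooks.signal)
        else:
            state.skipped += 1
    finally:
        hooks.clear()  # stale hook state never leaks into the next microbatch


def train(losses: List[float], accum: int) -> Tuple[ToyTernaryState, List[int]]:
    """Update once per microbatch; `accum` only sets the logging cadence."""
    state, hooks = ToyTernaryState(), HookBuffer()
    log_steps: List[int] = []
    for step, loss in enumerate(losses, start=1):
        hooks.capture(loss)                # "backward" fills the hook buffer
        update_memory(state, hooks, loss)  # consume it immediately
        if step % accum == 0:
            log_steps.append(step)         # logging/checkpoint boundary only
    return state, log_steps


def train_deferred(losses: List[float], accum: int) -> ToyTernaryState:
    """Old pattern: update only at the accumulation boundary."""
    state, hooks = ToyTernaryState(), HookBuffer()
    for step, loss in enumerate(losses, start=1):
        hooks.capture(loss)  # overwrites earlier microbatches' hook state
        if step % accum == 0:
            update_memory(state, hooks, loss)
    return state


losses = [1.0, 2.0, float("nan"), 4.0]
state, log_steps = train(losses, accum=2)
deferred = train_deferred(losses, accum=2)
print(state.applied, state.skipped, log_steps)
print(deferred.applied, deferred.skipped)
```

With these inputs the per-microbatch loop applies `1.0`, `2.0`, and `4.0`, skips the `nan`, and logs at steps 2 and 4. The deferred loop applies only `2.0` and `4.0`, because the first and third microbatches were overwritten in the hook buffer before the boundary was reached.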

## Verification

- `python -m compileall -q training/ternary_runtime.py training/pretrain.py training/text.py training/audio.py training/vision.py training/diffusion.py training/finetuning/text.py training/finetuning/audio.py training/finetuning/vision.py training/finetuning/diffusion.py arbitor/main.py arbitor/components.py`
- `ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_gradient_capture.py`
  - Result: `5 passed`.
- `python -m pytest -q testing/test_trainers.py::test_all_trainers_loss_signal_detached testing/test_trainers.py::test_pretrain_loss_signal_detached`
  - Result: `2 passed`.
- `ARB_TERNARY_BACKEND=triton python training/pretrain.py --text-data training/data/tinyshakespeare.txt --steps 20 --batch 1 --ctx 16 --accum 2 --max-moe-iters 1 --no-save --log-interval 2 --eval-interval 0 --save-interval 0`
  - Result: text training stayed finite, and the reported loss stayed roughly within the `26-37` band over the 20-step smoke run.
- `ARB_TERNARY_BACKEND=triton python training/text.py --data training/data/tinyshakespeare.txt --steps 5 --batch 1 --ctx 16 --eval-interval 5 --run text-smoke`
  - Result: the trainer started, kept `0` trainable float params, and reported `train=31.995`, `eval=27.958`.

## Notes

- For ad hoc scripts that instantiate `ARBModel` directly, call:

```python
from training.ternary_runtime import configure_ternary_training

# `model` is the ARBModel instance built by the script.
accum_threshold = configure_ternary_training(model, accum_threshold=32, e_accum_threshold=32)
```

- Lower thresholds such as `3` can converge over long runs, but they produce much larger loss spikes because signs and group scales move early. The training entrypoints now default to the smoother production setting of `32`.
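
The threshold trade-off can be seen with a toy integer accumulator. In this hypothetical sketch, the `ThresholdAccumulator` name and its reset-on-move rule are assumptions, not the real `T_accum`/`E_accum` logic: each step adds a signed integer vote, and once the accumulated magnitude reaches `threshold`, the tracked value moves one step in that direction and the accumulator resets. With a noisy but positively biased vote stream, a threshold of `3` moves early and often, while `32` waits until the bias has built up.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class ThresholdAccumulator:
    """Hypothetical integer accumulator that moves a value only past a threshold."""

    threshold: int
    accum: int = 0
    value: int = 0

    def step(self, vote: int) -> bool:
        """Add an integer vote; return True if the tracked value moved."""
        self.accum += vote
        if abs(self.accum) >= self.threshold:
            self.value += 1 if self.accum > 0 else -1
            self.accum = 0  # start collecting evidence again after a move
            return True
        return False


def move_steps(threshold: int, votes: Iterable[int]) -> List[int]:
    """Return the 1-based steps at which the tracked value moved."""
    acc = ThresholdAccumulator(threshold)
    return [i for i, v in enumerate(votes, start=1) if acc.step(v)]


# Noisy, positively biased votes: +1, +1, +1, -1 repeated (net +2 per 4 steps).
votes = [1, 1, 1, -1] * 16  # 64 steps

early = move_steps(3, votes)
smooth = move_steps(32, votes)
print("threshold 3 moves at steps:", early)
print("threshold 32 moves at steps:", smooth)
```

Over these 64 steps, threshold `3` moves the value 11 times starting at step 3, while threshold `32` moves it once, at step 62. Each early move is a point where the loss can jump, which matches the spike behavior described above.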