"""composer_diloco.py: DiLoCo outer-loop wrapper for the Composer Replication Framework.

Wraps `torchft.local_sgd.DiLoCo` with the framework's conventions:

- Sign convention is documented LOUDLY here once and tested via Spike 008.
- The wrapper exposes the same constructor shape as torchft's DiLoCo, so a
  future swap-in of the upstream class is a one-line change.
- Vanilla DiLoCo (Douillard et al. 2023) = `fragment_sync_delay=0`, single
  fragment. Streaming DiLoCo (Liu et al. 2025) = non-zero delay, multiple
  fragments. Spike 008 uses vanilla; Streaming is configured by the same API.

Reference: `docs/adrs/ADR-003-diloco-impl.md`.
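
For orientation, one vanilla DiLoCo outer round can be simulated in plain PyTorch without torchft. This is a minimal sketch that rests on several assumptions. K workers are simulated one after another in a single process. Averaging their pseudo-gradients stands in for `manager.allreduce`. The model and data are a toy linear regression. `diloco_round` and `total_loss` are hypothetical helpers and are not part of this module or of torchft.

```python
import copy

import torch
import torch.nn.functional as F


def diloco_round(global_model, worker_batches, outer_opt, inner_lr=0.1, sync_every=5):
    # Hypothetical helper: one vanilla DiLoCo outer round, all workers in-process.
    theta_initial = [p.detach().clone() for p in global_model.parameters()]
    avg_pseudograd = [torch.zeros_like(t) for t in theta_initial]
    for batches in worker_batches:  # one list of (x, y) batches per worker
        local = copy.deepcopy(global_model)  # every worker starts at theta_initial
        inner_opt = torch.optim.SGD(local.parameters(), lr=inner_lr)
        for x, y in batches[:sync_every]:  # sync_every inner steps per round
            inner_opt.zero_grad()
            F.mse_loss(local(x), y).backward()
            inner_opt.step()
        # pseudograd = theta_initial - theta_local, averaged over workers
        # (this average stands in for the all-reduce).
        for acc, t0, p in zip(avg_pseudograd, theta_initial, local.parameters()):
            acc += (t0 - p.detach()) / len(worker_batches)
    # The global parameters were never modified, so they still hold theta_initial.
    for p, g in zip(global_model.parameters(), avg_pseudograd):
        p.grad = g
    outer_opt.step()  # SGD: p <- p - lr * grad (plus momentum terms)
    outer_opt.zero_grad()


def total_loss(model, worker_batches):
    # Hypothetical helper: summed MSE over every worker's batches.
    with torch.no_grad():
        return sum(F.mse_loss(model(x), y).item() for b in worker_batches for x, y in b)


torch.manual_seed(0)
true_w = torch.tensor([[2.0, -1.0]])
model = torch.nn.Linear(2, 1)
# Outer defaults from the DiLoCo paper, as used by make_diloco_outer_loop.
outer_opt = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

workers = []
for _ in range(2):  # two simulated workers, five noise-free batches each
    xs = [torch.randn(8, 2) for _ in range(5)]
    workers.append([(x, x @ true_w.T) for x in xs])

loss_before = total_loss(model, workers)
for _ in range(50):
    diloco_round(model, workers, outer_opt)
loss_after = total_loss(model, workers)
print(f"loss {loss_before:.4f} -> {loss_after:.6f}")
```

The inner loop never touches the global parameters, so they still equal θ_initial when the outer optimizer steps. This is the restore-then-step order that the sign-convention note relies on.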

Sign convention (READ THIS BEFORE TOUCHING):

DiLoCo defines the pseudo-gradient as

    pseudograd = θ_initial - θ_local

(per torchft's `_save_grads()` at line 324 of `torchft/local_sgd.py`).
This is the **negative** of the local update direction (the local
update *moved* params from θ_initial toward θ_local).

Standard SGD subtracts gradients: `p.data ← p.data - lr * grad`.
So when the outer optimizer runs after `restore_parameters()` puts
p.data back to θ_initial:

    p.data ← θ_initial - lr * (θ_initial - θ_local)
           = θ_initial + lr * (θ_local - θ_initial)

For `lr=1, momentum=0` this lands exactly at θ_local. For `0 < lr < 1`
with `momentum=0` it interpolates between θ_initial and θ_local (the
momentum-free form of the DiLoCo outer step). Adding Nesterov momentum
accumulates the local-update direction across outer rounds, so a single
outer step can move past θ_local.
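
The arithmetic above can be checked numerically with `torch.optim.SGD` alone. This is a minimal sketch. It assumes a single one-element parameter and uses the hypothetical helper `outer_step`, which mimics the restore-then-step order described above. It does not call torchft.

```python
import torch


def outer_step(theta_initial, theta_local, lr, momentum=0.0, nesterov=False):
    # Hypothetical helper: one outer step on a one-element parameter.
    # After restore_parameters(), p.data is back at theta_initial.
    p = torch.nn.Parameter(torch.tensor([theta_initial], dtype=torch.float64))
    opt = torch.optim.SGD([p], lr=lr, momentum=momentum, nesterov=nesterov)
    # torchft's convention: pseudograd = theta_initial - theta_local (no negation).
    p.grad = torch.tensor([theta_initial - theta_local], dtype=torch.float64)
    opt.step()
    return p.item()


# lr=1, momentum=0 lands exactly on theta_local.
full = outer_step(1.0, 3.0, lr=1.0)
# lr=0.5, momentum=0 lands halfway between theta_initial and theta_local.
half = outer_step(1.0, 3.0, lr=0.5)
# Framework defaults (lr=0.7, Nesterov, momentum=0.9). On the first step PyTorch
# seeds the momentum buffer with the gradient, so the step is lr * (1 + momentum)
# times the pseudo-gradient and overshoots theta_local.
nest = outer_step(1.0, 3.0, lr=0.7, momentum=0.9, nesterov=True)
print(full, half, nest)
```

With θ_initial = 1 and θ_local = 3, the three calls give 3.0, 2.0, and about 3.66.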
No negation in our outer optimizer wrapper. The test
`test_diloco_pseudogradient_sign_convention` in
`spikes/008-streaming-diloco/tests/test_diloco_smoke.py` pins this
arithmetic and reports both the expected and the wrong-sign value on
failure for fast diagnosis if torchft ever flips its convention.
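
A check in the spirit of that test can be sketched as follows. This is not the contents of `test_diloco_smoke.py`. It assumes plain PyTorch tensors instead of torchft's fragment state and a single momentum-free SGD outer step. `check_sign_convention` is a hypothetical name.

```python
import torch


def check_sign_convention(theta_initial, theta_local, pseudograd, lr=1.0):
    # Hypothetical check: one plain SGD outer step (momentum=0) from
    # theta_initial, driven by the pseudo-gradient under test.
    p = torch.nn.Parameter(theta_initial.clone())
    opt = torch.optim.SGD([p], lr=lr)
    p.grad = pseudograd.clone()
    opt.step()
    expected = theta_initial + lr * (theta_local - theta_initial)
    wrong_sign = theta_initial - lr * (theta_local - theta_initial)
    actual = p.detach()
    # On failure, report both the expected and the wrong-sign value.
    assert torch.allclose(actual, expected), (
        f"outer step gave {actual.tolist()}; expected {expected.tolist()} "
        f"(a flipped pseudo-gradient sign would give {wrong_sign.tolist()})"
    )
    return actual


theta_initial = torch.tensor([0.0, 1.0], dtype=torch.float64)
theta_local = torch.tensor([0.5, 0.0], dtype=torch.float64)

# The convention documented above (theta_initial - theta_local) passes.
result = check_sign_convention(theta_initial, theta_local, theta_initial - theta_local)

# A flipped convention fails, and the message names both candidate values.
flipped_caught, message = False, ""
try:
    check_sign_convention(theta_initial, theta_local, theta_local - theta_initial)
except AssertionError as err:
    flipped_caught, message = True, str(err)
print(message)
```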
"""
from __future__ import annotations

from typing import Any

import torch

# Guarded import: torchft is an optional dependency at the framework level.
_TORCHFT_AVAILABLE = False
DiLoCo: Any = None
Manager: Any = None
_DummyWork: Any = None

try:
    from torchft.local_sgd import DiLoCo as _DiLoCo  # type: ignore[import]
    from torchft.manager import Manager as _Manager  # type: ignore[import]
    from torchft.work import _DummyWork as __DummyWork  # type: ignore[import]

    _TORCHFT_AVAILABLE = True
    DiLoCo = _DiLoCo
    Manager = _Manager
    _DummyWork = __DummyWork
except ImportError:  # pragma: no cover (only hit in lighter-weight CI envs)
    pass

def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any:
    """Construct a DiLoCo wrapper around `model_fragments` with default DiLoCo hyperparams.

    Default hyperparams (DiLoCo paper §3.2):
        outer_lr = 0.7, outer_momentum = 0.9, Nesterov

    Args:
        manager: torchft.Manager (or a test mock with `.allreduce`, `.should_commit`,
            `.current_step`, `.start_quorum`).
        model_fragments: list of nn.Modules. For vanilla DiLoCo, pass [whole_model].
            For Streaming DiLoCo with N fragments, pass [frag_0, frag_1, ..., frag_N-1].
        inner_optimizer: any torch.optim.Optimizer. Steps every batch.
        outer_lr / outer_momentum / nesterov: outer SGD hyperparams.
            Override the defaults only if you know why.
        sync_every: number of inner steps per outer round.
        fragment_sync_delay: 0 = vanilla DiLoCo (sync at the end of each outer round).
            >0 = Streaming DiLoCo with overlapped sync. Requires CUDA streams.
        fragment_update_alpha: 0 = full replacement of fragment params on sync.
            >0 = exponential mixing weight. Streaming DiLoCo only.

    Returns:
        A torchft.local_sgd.DiLoCo instance configured for the framework's
        conventions. Use it as a context manager:

            with make_diloco_outer_loop(...) as outer:
                for step in range(N):
                    inner_optimizer.zero_grad()
                    loss = compute_loss(...)
                    loss.backward()
                    inner_optimizer.step()  # outer sync fires automatically

    Raises:
        RuntimeError: if torchft is not installed.
    """
    if not _TORCHFT_AVAILABLE:
        raise RuntimeError(
            "torchft is not installed. Run `pip install torchft-nightly` to use DiLoCo."
        )

    # One outer SGD optimizer spans every fragment's parameters. It consumes
    # torchft's pseudo-gradient (theta_initial - theta_local) as-is, with no
    # negation; see the sign-convention note in the module docstring.
    outer_optimizer = torch.optim.SGD(
        [p for frag in model_fragments for p in frag.parameters()],
        lr=outer_lr,
        momentum=outer_momentum,
        nesterov=nesterov,
    )
    return DiLoCo(
        manager=manager,
        model_fragments=model_fragments,
        inner_optimizer=inner_optimizer,
        outer_optimizer=outer_optimizer,
        sync_every=sync_every,
        fragment_sync_delay=fragment_sync_delay,
        fragment_update_alpha=fragment_update_alpha,
    )

__all__ = [
    "make_diloco_outer_loop",
    "DiLoCo",
    "Manager",
    "_DummyWork",
    "_TORCHFT_AVAILABLE",
]