"""composer_diloco.py: DiLoCo outer-loop wrapper for the Composer Replication Framework.

Wraps `torchft.local_sgd.DiLoCo` with the framework's conventions:

- The sign convention is documented LOUDLY here, once, and tested via Spike 008.
- The wrapper mirrors the keyword names of torchft's DiLoCo constructor
  (`manager`, `model_fragments`, `inner_optimizer`, `sync_every`,
  `fragment_sync_delay`, `fragment_update_alpha`), so a future swap-in of the
  upstream class is a one-line change. The one difference is the outer
  optimizer: instead of taking `outer_optimizer`, the wrapper takes
  `outer_lr` / `outer_momentum` / `nesterov` and builds an SGD optimizer itself.
- Vanilla DiLoCo (Douillard et al. 2023) = `fragment_sync_delay=0`, single
  fragment. Streaming DiLoCo (Liu et al. 2025) = non-zero delay, multiple
  fragments. Spike 008 uses vanilla; Streaming is configured through the same API.

Reference: `docs/adrs/ADR-003-diloco-impl.md`.

Sign convention (READ THIS BEFORE TOUCHING):

    DiLoCo defines the pseudo-gradient as

        pseudograd = θ_initial - θ_local

    (per torchft's `_save_grads()` at line 324 of `torchft/local_sgd.py`).
    This is the **negative** of the local update direction: the inner loop
    *moved* the params from θ_initial toward θ_local, i.e. along
    θ_local - θ_initial.

    Standard SGD subtracts gradients: `p.data ← p.data - lr * grad`.
    So when the outer optimizer runs after `restore_parameters()` has put
    p.data back to θ_initial:

        p.data ← θ_initial - lr * (θ_initial - θ_local)
               = θ_initial + lr * (θ_local - θ_initial)

    For `lr=1, momentum=0` this lands exactly at θ_local. For `0 < lr < 1`
    it interpolates between θ_initial and θ_local (the standard DiLoCo
    outer step). Adding Nesterov momentum accumulates the local-update
    direction across outer rounds.

    There is no negation in our outer-optimizer wrapper. The test
    `test_diloco_pseudogradient_sign_convention` in
    `spikes/008-streaming-diloco/tests/test_diloco_smoke.py` pins this
    arithmetic and, on failure, reports both the expected and the
    wrong-sign value for fast diagnosis if torchft ever flips its convention.
"""
from __future__ import annotations
from typing import Any
import torch

# Optional import: torchft is an optional dependency at the framework level,
# so this module stays importable without it. `_TORCHFT_AVAILABLE` records
# whether the real torchft classes were found.
_TORCHFT_AVAILABLE = False
DiLoCo: Any = None
Manager: Any = None
_DummyWork: Any = None

try:
    from torchft.local_sgd import DiLoCo as _DiLoCo  # type: ignore[import]
    from torchft.manager import Manager as _Manager  # type: ignore[import]
    from torchft.work import _DummyWork as __DummyWork  # type: ignore[import]

    _TORCHFT_AVAILABLE = True
    DiLoCo = _DiLoCo
    Manager = _Manager
    _DummyWork = __DummyWork
except ImportError:  # pragma: no cover; only reached in lighter-weight CI envs
    pass
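
# Reference arithmetic for the module docstring's sign convention.
# A minimal, HYPOTHETICAL pure-Python sketch for one scalar parameter: it is
# not used by make_diloco_outer_loop and is not part of torchft. It assumes
# torch.optim.SGD's documented update rule with dampening=0 and
# weight_decay=0, and that restore_parameters() has already reset the
# parameter to theta_initial before the outer step runs.
def _reference_outer_step(
    theta_initial: float,
    theta_local: float,
    lr: float,
    momentum: float = 0.0,
    nesterov: bool = False,
    momentum_buf: float | None = None,
) -> tuple[float, float | None]:
    """Return (new_param, new_momentum_buf) for one hypothetical outer SGD step."""
    # torchft convention: pseudograd = theta_initial - theta_local (no negation).
    grad = theta_initial - theta_local
    buf = momentum_buf
    if momentum != 0.0:
        # torch.optim.SGD seeds the buffer with the first gradient, then
        # accumulates buf = momentum * buf + grad on later steps.
        buf = grad if buf is None else momentum * buf + grad
        # Nesterov looks ahead along the buffer; plain momentum steps along it.
        grad = grad + momentum * buf if nesterov else buf
    # Standard SGD subtracts the (possibly momentum-adjusted) gradient.
    return theta_initial - lr * grad, buf


# Usage: with lr=1 and momentum=0 this returns theta_local exactly, while a
# wrong-sign pseudograd would give 2 * theta_initial - theta_local instead.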


def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any:
    """Construct a DiLoCo wrapper around `model_fragments` with default DiLoCo hyperparams.

    Default hyperparams (DiLoCo paper §3.2):
        outer_lr = 0.7, outer_momentum = 0.9, Nesterov

    Args:
        manager: torchft.Manager (or a test mock with `.allreduce`,
            `.should_commit`, `.current_step`, `.start_quorum`).
        model_fragments: list of nn.Modules. For vanilla DiLoCo, pass [whole_model].
            For Streaming DiLoCo with N fragments, pass
            [frag_0, frag_1, ..., frag_{N-1}].
        inner_optimizer: any torch.optim.Optimizer. Steps every batch.
        outer_lr / outer_momentum / nesterov: outer SGD hyperparams.
            Override defaults only if you know why.
        sync_every: number of inner steps per outer round.
        fragment_sync_delay: 0 = vanilla DiLoCo (sync at the outer round).
            >0 = Streaming DiLoCo with overlapped sync. Requires CUDA streams.
        fragment_update_alpha: 0 = full replacement of fragment params on sync.
            >0 = exponential mixing weight. Streaming DiLoCo only.

    Returns:
        A torchft.local_sgd.DiLoCo instance configured for the framework's
        conventions. Use it as a context manager:

            with make_diloco_outer_loop(...) as outer:
                for step in range(N):
                    inner_optimizer.zero_grad()
                    loss = compute_loss(...)
                    loss.backward()
                    inner_optimizer.step()  # outer sync fires automatically
    """
    if not _TORCHFT_AVAILABLE:
        raise RuntimeError(
            "torchft is not installed. `pip install torchft-nightly` to use DiLoCo."
        )

    # One outer SGD over every fragment's parameters. At sync time it steps on
    # the pseudo-gradients torchft writes into `.grad` (sign convention in the
    # module docstring; no negation here).
    outer_optimizer = torch.optim.SGD(
        [p for frag in model_fragments for p in frag.parameters()],
        lr=outer_lr,
        momentum=outer_momentum,
        nesterov=nesterov,
    )
    return DiLoCo(
        manager=manager,
        model_fragments=model_fragments,
        inner_optimizer=inner_optimizer,
        outer_optimizer=outer_optimizer,
        sync_every=sync_every,
        fragment_sync_delay=fragment_sync_delay,
        fragment_update_alpha=fragment_update_alpha,
    )
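
# HYPOTHETICAL pure-Python sketch of what `fragment_update_alpha` means for a
# single scalar parameter when a Streaming DiLoCo fragment syncs. It ASSUMES
# the mix is a linear interpolation with alpha weighting the worker's local
# copy, which matches the docstring's "0 = full replacement"; the [0, 1]
# range check is this sketch's own choice. Check torchft's local_sgd.py
# before relying on the exact formula. Not used by make_diloco_outer_loop.
def _reference_fragment_merge(
    global_param: float, local_param: float, alpha: float
) -> float:
    """Blend the synced (global) value with the worker's local value."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    # alpha = 0 keeps only the synced value (full replacement);
    # alpha = 1 keeps only the local value.
    return (1.0 - alpha) * global_param + alpha * local_param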


__all__ = [
    "make_diloco_outer_loop",
    "DiLoCo",
    "Manager",
    "_DummyWork",
    "_TORCHFT_AVAILABLE",
]
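
# HYPOTHETICAL multi-worker version of the sign-convention arithmetic in the
# module docstring, for illustration only (not part of torchft and not in
# __all__). It ASSUMES the manager's allreduce averages the per-worker
# pseudo-gradients and uses plain SGD without momentum, so with lr=1 the outer
# step lands on the mean of the workers' theta_local values.
def _reference_averaged_outer_step(
    theta_initial: float, local_thetas: list[float], lr: float
) -> float:
    """Apply one momentum-free outer SGD step using the mean pseudo-gradient."""
    if not local_thetas:
        raise ValueError("need at least one worker's local parameters")
    # Each worker contributes theta_initial - theta_local (torchft convention).
    pseudograds = [theta_initial - theta for theta in local_thetas]
    mean_pseudograd = sum(pseudograds) / len(pseudograds)
    return theta_initial - lr * mean_pseudograd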