"""composer_diloco.py — DiLoCo outer-loop wrapper for Composer Replication Framework.

Wraps `torchft.local_sgd.DiLoCo` with the framework's conventions:

- Sign convention is documented LOUDLY here once and tested via Spike 008.
- The wrapper exposes the same constructor shape as torchft's DiLoCo so a
  future swap-in of the upstream class is a one-line change.
- Vanilla DiLoCo (Douillard et al. 2023) = `fragment_sync_delay=0`, single
  fragment. Streaming DiLoCo (Liu et al. 2025) = non-zero delay, multiple
  fragments. Spike 008 uses vanilla; Streaming is configured by the same API.

Reference: `docs/adrs/ADR-003-diloco-impl.md`.

Sign convention (READ THIS BEFORE TOUCHING):

DiLoCo defines the pseudo-gradient as

    pseudograd = θ_initial - θ_local

(per torchft's `_save_grads()` at line 324 of `torchft/local_sgd.py`). This is
the **negative** of the local update direction (the local update *moved*
params from θ_initial toward θ_local).

Standard SGD subtracts gradients: `p.data ← p.data - lr * grad`. So when the
outer optimizer runs after `restore_parameters()` puts p.data back to
θ_initial:

    p.data ← θ_initial - lr * (θ_initial - θ_local)
           = θ_initial + lr * (θ_local - θ_initial)

For `lr=1, momentum=0` (which also requires `nesterov=False`, since
`torch.optim.SGD` rejects Nesterov without momentum) this lands exactly at
θ_local. With several replicas the pseudo-gradient is averaged by the
manager's allreduce first, so θ_local here is the replica mean. For `lr<1` it
interpolates between θ_initial and θ_local (the standard DiLoCo outer step).
Adding Nesterov momentum accumulates the local-update direction across outer
rounds.

No negation in our outer optimizer wrapper. The test
`test_diloco_pseudogradient_sign_convention` in
`spikes/008-streaming-diloco/tests/test_diloco_smoke.py` pins this arithmetic
and reports both the expected and the wrong-sign value on failure, for fast
diagnosis if torchft ever flips its convention.
"""

from __future__ import annotations

from typing import Any

import torch

# Guarded import: torchft is an optional dependency at framework level.
_TORCHFT_AVAILABLE = False
DiLoCo: Any = None
Manager: Any = None
_DummyWork: Any = None

try:
    from torchft.local_sgd import DiLoCo as _DiLoCo  # type: ignore[import]
    from torchft.manager import Manager as _Manager  # type: ignore[import]
    from torchft.work import _DummyWork as __DummyWork  # type: ignore[import]

    _TORCHFT_AVAILABLE = True
    DiLoCo = _DiLoCo
    Manager = _Manager
    _DummyWork = __DummyWork
except ImportError:  # pragma: no cover — only hits in lighter-weight CI envs
    pass


def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any:
    """Construct a DiLoCo wrapper around `model_fragments` with default DiLoCo hyperparams.

    Default hyperparams (DiLoCo paper §3.2): outer_lr = 0.7,
    outer_momentum = 0.9, Nesterov.

    Args:
        manager: torchft.Manager (or a test mock with `.allreduce`,
            `.should_commit`, `.current_step`, `.start_quorum`).
        model_fragments: list of nn.Modules. For vanilla DiLoCo, pass
            [whole_model]. For Streaming DiLoCo with N fragments, pass
            [frag_0, frag_1, ..., frag_N-1].
        inner_optimizer: any torch.optim.Optimizer. Steps every batch.
        outer_lr / outer_momentum / nesterov: outer SGD hyperparams. Override
            the defaults only if you know why. Note that `nesterov=True`
            requires `outer_momentum > 0`; `torch.optim.SGD` raises
            `ValueError` otherwise.
        sync_every: number of inner steps per outer round.
        fragment_sync_delay: 0 = vanilla DiLoCo (sync at outer round).
            >0 = Streaming DiLoCo with overlapped sync. Requires CUDA streams.
        fragment_update_alpha: 0 = full replacement of fragment params on
            sync. >0 = exponential mixing weight. Streaming DiLoCo only.

    Returns:
        A torchft.local_sgd.DiLoCo instance configured for the framework's
        conventions. Use it as a context manager:

            with make_diloco_outer_loop(...) as outer:
                for step in range(N):
                    inner_optimizer.zero_grad()
                    loss = compute_loss(...)
                    loss.backward()
                    inner_optimizer.step()  # outer sync fires automatically

    Raises:
        RuntimeError: if torchft is not installed.
    """
    if not _TORCHFT_AVAILABLE:
        raise RuntimeError(
            "torchft is not installed. `pip install torchft-nightly` to use DiLoCo."
        )

    # One plain-SGD outer optimizer over every fragment's parameters. The
    # pseudo-gradient is used as-is, with no negation (see the sign
    # convention in the module docstring).
    outer_optimizer = torch.optim.SGD(
        [p for frag in model_fragments for p in frag.parameters()],
        lr=outer_lr,
        momentum=outer_momentum,
        nesterov=nesterov,
    )
    return DiLoCo(
        manager=manager,
        model_fragments=model_fragments,
        inner_optimizer=inner_optimizer,
        outer_optimizer=outer_optimizer,
        sync_every=sync_every,
        fragment_sync_delay=fragment_sync_delay,
        fragment_update_alpha=fragment_update_alpha,
    )


__all__ = [
    "make_diloco_outer_loop",
    "DiLoCo",
    "Manager",
    "_DummyWork",
    "_TORCHFT_AVAILABLE",
]
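

# Reference arithmetic for the sign convention (illustrative sketch).
#
# A minimal pure-PyTorch sketch of the outer step described in the module
# docstring, assuming a single replica (no allreduce) and one parameter
# tensor. It follows the order of operations the docstring describes (params
# restored to θ_initial, grad set to θ_initial - θ_local, plain SGD step with
# no negation), but it is NOT torchft's implementation. Both helper names
# below are hypothetical, introduced only to make the arithmetic checkable.
import torch  # noqa: E402  (already imported above; repeated so this sketch stands alone)


def _reference_outer_step(
    theta_initial: torch.Tensor,
    theta_local: torch.Tensor,
    outer_lr: float,
) -> torch.Tensor:
    """One outer SGD step with momentum=0: θ_initial + lr * (θ_local - θ_initial)."""
    # Params sit at θ_initial, as after `restore_parameters()`.
    p = torch.nn.Parameter(theta_initial.detach().clone())
    # Pseudo-gradient exactly as documented: θ_initial - θ_local (not negated).
    p.grad = (theta_initial - theta_local).detach()
    torch.optim.SGD([p], lr=outer_lr, momentum=0.0).step()
    return p.detach().clone()


def _reference_outer_rounds(
    theta_start: torch.Tensor,
    local_deltas: list[torch.Tensor],
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
) -> list[torch.Tensor]:
    """Several outer rounds sharing ONE outer SGD optimizer, so momentum persists.

    Each `local_deltas[k]` stands in for θ_local - θ_initial in round k, i.e.
    how far the inner steps moved the params during that round.
    """
    p = torch.nn.Parameter(theta_start.detach().clone())
    outer = torch.optim.SGD(
        [p], lr=outer_lr, momentum=outer_momentum, nesterov=nesterov
    )
    trajectory: list[torch.Tensor] = []
    for delta in local_deltas:
        theta_initial = p.detach().clone()
        theta_local = theta_initial + delta
        p.grad = theta_initial - theta_local  # equals -delta
        outer.step()
        trajectory.append(p.detach().clone())
    return trajectory


# Example: θ_initial = [1., 2.], θ_local = [3., -2.]
#   _reference_outer_step(θ_initial, θ_local, outer_lr=1.0) -> [3., -2.]  (lands on θ_local)
#   _reference_outer_step(θ_initial, θ_local, outer_lr=0.5) -> [2., 0.]   (halfway)
# A wrong-sign outer step would instead give [-1., 6.] for outer_lr=1.0.
# With a constant delta of [1.] and the default hyperparams (0.7, 0.9,
# Nesterov), three rounds reach about 1.33, 3.227, 5.634: each round moves
# further than the last, which is the momentum accumulation described above.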