"""composer_diloco.py — DiLoCo outer-loop wrapper for Composer Replication Framework.

Wraps `torchft.local_sgd.DiLoCo` with the framework's conventions:
- Sign convention is documented LOUDLY here once and tested via Spike 008.
- The wrapper exposes the same constructor shape as torchft's DiLoCo so a
  future swap-in of the upstream class is a one-line change.
- Vanilla DiLoCo (Douillard et al. 2023) = `fragment_sync_delay=0`, single
  fragment. Streaming DiLoCo (Douillard et al. 2025) = non-zero delay, multiple
  fragments. Spike 008 uses vanilla; Streaming is configured by the same API.
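
A minimal sketch of how `model_fragments` might be built for each mode,
assuming the model is an `nn.Sequential` whose layers can be split into
contiguous groups. `split_into_fragments` is a hypothetical helper, not part
of torchft or this framework; a real Streaming DiLoCo setup may group
parameters differently.

```python
import torch.nn as nn


def split_into_fragments(model: nn.Sequential, n: int) -> list[nn.Module]:
    # Hypothetical helper: group the layers of `model` into `n` contiguous,
    # non-empty slices. Slicing an nn.Sequential returns a new Sequential
    # that shares the same submodules, so the fragments partition the
    # original parameters rather than copying them.
    layers = len(model)
    if not 1 <= n <= layers:
        raise ValueError(f"need 1 <= n <= {layers}, got {n}")
    bounds = [i * layers // n for i in range(n + 1)]
    return [model[bounds[i]:bounds[i + 1]] for i in range(n)]


model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(4)])
vanilla = [model]                             # vanilla DiLoCo: one fragment
streaming = split_into_fragments(model, n=2)  # Streaming DiLoCo: 2 fragments
```

Either list has the shape `make_diloco_outer_loop` expects for `model_fragments`.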

Reference: `docs/adrs/ADR-003-diloco-impl.md`.

Sign convention (READ THIS BEFORE TOUCHING):
    DiLoCo defines pseudo-gradient as

        pseudograd = θ_initial - θ_local

    (per torchft's `_save_grads()` at line 324 of `torchft/local_sgd.py`).
    This is the **negative** of the local update direction (the local
    update *moved* params from θ_initial toward θ_local).

    Standard SGD subtracts gradients: `p.data ← p.data - lr * grad`.
    So when the outer optimizer runs after `restore_parameters()` puts
    p.data back to θ_initial:

        p.data ← θ_initial - lr * (θ_initial - θ_local)
               = θ_initial + lr * (θ_local - θ_initial)

    For `lr=1, momentum=0` this lands exactly at θ_local. For `0 < lr < 1`
    (still with `momentum=0`) it interpolates between θ_initial and
    θ_local. Adding Nesterov momentum, as the DiLoCo defaults do,
    accumulates the local-update direction across outer rounds.
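
The momentum-free arithmetic above can be checked in isolation with plain
`torch.optim.SGD`, no torchft required. This is a sketch: the pseudo-gradient
is written into `.grad` by hand, standing in for what torchft does between
`_save_grads()` and the outer step, and `outer_step` is a hypothetical helper.

```python
import torch


def outer_step(theta_initial: torch.Tensor, theta_local: torch.Tensor, lr: float) -> torch.Tensor:
    # Hypothetical helper: one outer SGD step with momentum=0, starting from
    # the restored parameters θ_initial.
    p = torch.nn.Parameter(theta_initial.clone())
    opt = torch.optim.SGD([p], lr=lr, momentum=0.0)
    # DiLoCo pseudo-gradient, used as-is (no negation): θ_initial - θ_local.
    p.grad = theta_initial - theta_local
    opt.step()  # p ← θ_initial - lr * (θ_initial - θ_local)
    return p.detach()


theta_initial = torch.tensor([1.0, 2.0])
theta_local = torch.tensor([3.0, 0.0])
full = outer_step(theta_initial, theta_local, lr=1.0)  # tensor([3., 0.]) == θ_local
half = outer_step(theta_initial, theta_local, lr=0.5)  # tensor([2., 1.]), the midpoint
```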

    Hence no negation is needed in the outer optimizer wrapper. The test
    `test_diloco_pseudogradient_sign_convention` in
    `spikes/008-streaming-diloco/tests/test_diloco_smoke.py` pins this
    arithmetic and reports both the expected and the wrong-sign value on
    failure for fast diagnosis if torchft ever flips its convention.
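
For intuition, a single-worker sketch of the vanilla DiLoCo loop in plain
PyTorch, assuming one worker (so the cross-replica averaging of
pseudo-gradients via the manager's `.allreduce` is omitted) and a toy
quadratic loss. The inner optimizer is plain SGD for brevity (the DiLoCo
paper uses AdamW); the outer optimizer uses the same defaults as
`make_diloco_outer_loop`. `diloco_single_worker` is hypothetical and only
mirrors the arithmetic, not torchft's internals.

```python
import torch


def diloco_single_worker(
    target: torch.Tensor,
    outer_rounds: int = 50,
    sync_every: int = 10,
    inner_lr: float = 0.1,
) -> torch.Tensor:
    # Hypothetical single-worker loop: with one replica the averaged
    # pseudo-gradient is just this worker's own.
    # Toy loss: 0.5 * ||p - target||^2.
    p = torch.nn.Parameter(torch.zeros_like(target))
    inner = torch.optim.SGD([p], lr=inner_lr)
    outer = torch.optim.SGD([p], lr=0.7, momentum=0.9, nesterov=True)
    for _ in range(outer_rounds):
        theta_initial = p.detach().clone()
        for _ in range(sync_every):  # inner steps
            inner.zero_grad()
            loss = 0.5 * (p - target).pow(2).sum()
            loss.backward()
            inner.step()
        with torch.no_grad():
            pseudograd = theta_initial - p  # θ_initial - θ_local
            p.copy_(theta_initial)          # analogue of restore_parameters()
        p.grad = pseudograd
        outer.step()  # Nesterov SGD step from θ_initial
    return p.detach()


result = diloco_single_worker(torch.tensor([1.0, -2.0]))
```

Because the pseudo-gradient points from θ_local back to θ_initial, the
unmodified SGD step moves the parameters toward θ_local, and the momentum
buffer carries that direction from one outer round to the next.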
"""
from __future__ import annotations

from typing import Any

import torch

# Guarded import: torchft is an optional dependency at the framework level.
_TORCHFT_AVAILABLE = False
DiLoCo: Any = None
Manager: Any = None
_DummyWork: Any = None
try:
    from torchft.local_sgd import DiLoCo as _DiLoCo  # type: ignore[import]
    from torchft.manager import Manager as _Manager  # type: ignore[import]
    from torchft.work import _DummyWork as __DummyWork  # type: ignore[import]

    _TORCHFT_AVAILABLE = True
    DiLoCo = _DiLoCo
    Manager = _Manager
    _DummyWork = __DummyWork
except ImportError:  # pragma: no cover — only hits in lighter-weight CI envs
    pass


def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any:
    """Construct a DiLoCo wrapper around `model_fragments` with default DiLoCo hyperparams.

    Default hyperparams (DiLoCo paper §3.2):
        outer_lr = 0.7, outer_momentum = 0.9, Nesterov

    Args:
        manager: torchft.Manager (or test mock with `.allreduce`, `.should_commit`,
            `.current_step`, `.start_quorum`)
        model_fragments: list of nn.Modules. For vanilla DiLoCo, pass [whole_model].
            For Streaming DiLoCo with N fragments, pass [frag_0, frag_1, ..., frag_N-1].
        inner_optimizer: any torch.optim.Optimizer. Steps every batch.
        outer_lr / outer_momentum / nesterov: outer SGD hyperparams.
            Override defaults only if you know why.
        sync_every: number of inner steps per outer round.
        fragment_sync_delay: 0 = vanilla DiLoCo (sync at outer round).
            >0 = Streaming DiLoCo with overlapped sync. Requires CUDA streams.
        fragment_update_alpha: 0 = full replacement of fragment params on sync.
            >0 = exponential mixing weight. Streaming DiLoCo only.

    Returns:
        A torchft.local_sgd.DiLoCo instance configured for the framework's
        conventions. Use as a context manager:
            with make_diloco_outer_loop(...) as outer:
                for step in range(N):
                    inner_optimizer.zero_grad()
                    loss = compute_loss(...)
                    loss.backward()
                    inner_optimizer.step()  # outer sync fires automatically
    """
    if not _TORCHFT_AVAILABLE:
        raise RuntimeError(
            "torchft is not installed. `pip install torchft-nightly` to use DiLoCo."
        )

    outer_optimizer = torch.optim.SGD(
        [p for frag in model_fragments for p in frag.parameters()],
        lr=outer_lr,
        momentum=outer_momentum,
        nesterov=nesterov,
    )

    return DiLoCo(
        manager=manager,
        model_fragments=model_fragments,
        inner_optimizer=inner_optimizer,
        outer_optimizer=outer_optimizer,
        sync_every=sync_every,
        fragment_sync_delay=fragment_sync_delay,
        fragment_update_alpha=fragment_update_alpha,
    )


__all__ = [
    "make_diloco_outer_loop",
    "DiLoCo",
    "Manager",
    "_DummyWork",
    "_TORCHFT_AVAILABLE",
]