File size: 13,476 Bytes
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
"""End-to-end MockManager × torchft.DiLoCo integration test.

Closes the Wave 13 cross-model adversarial-review gap (Suggestion 4):
the original MockManager was advertised as a drop-in for torchft.Manager
but only stubbed `.allreduce / .should_commit / .start_quorum`. DiLoCo's
real call surface (audited from `torchft/local_sgd.py` v2026-spring) also
includes `current_step()`, `disallow_state_dict_read()`,
`allow_state_dict_read()`, `register_state_dict_fn()`, and the
`_use_async_quorum` attribute — plus `allreduce()` must return a Work-like
object with `.wait()`, not a raw tensor.

This test runs ONE full DiLoCo outer round (sync_every inner steps + the
sync) against a tiny `nn.Linear(4, 4)` with `world_size=1` so the
object-store rendezvous is trivial. It verifies:

1. Construction does not raise.
2. Running through one full outer round does not raise AttributeError
   (which is what the old MockManager would have hit at `current_step()`).
3. The model parameters change after the outer step fires (proving the
   outer SGD path actually executed end-to-end, not just that the
   inner-step hooks ran).
4. The MockManager's step counter advanced exactly once (one outer round
   ⇒ one start_quorum bump).
5. DiLoCo registered a state-dict fn per fragment.
"""
from __future__ import annotations

import pytest
import torch

torchft = pytest.importorskip(
    "torchft.local_sgd",
    reason="torchft must be installed to run the DiLoCo integration test",
)

from composer_replication.diloco import make_diloco_outer_loop
from composer_replication.diloco.serverless.allreduce import (
    MockManager,
    ObjectStoreAllReduce,
    _ImmediateWork,
)


def _make_store(tmp_path) -> ObjectStoreAllReduce:
    """Return a single-rank ``ObjectStoreAllReduce`` rooted at *tmp_path*.

    With ``world_size=1`` the object-store rendezvous involves only this
    process, which keeps the single-process tests synchronous.
    """
    single_rank_settings = dict(
        rank=0,
        world_size=1,
        timeout_s=10.0,
        poll_interval_s=0.05,
    )
    return ObjectStoreAllReduce(uri=str(tmp_path), **single_rank_settings)


def test_mockmanager_has_full_diloco_call_surface(tmp_path):
    """Audited methods/attrs from torchft/local_sgd.py DiLoCo path must exist.

    Names are looked up with a ``None`` default so that a *missing* method
    fails with the descriptive assertion message below. A plain
    ``getattr(mgr, attr)`` would instead raise a bare ``AttributeError``
    that hides which part of the audited surface is absent.
    """
    mgr = MockManager(_make_store(tmp_path))
    # Methods DiLoCo invokes
    for attr in (
        "allreduce",
        "should_commit",
        "start_quorum",
        "current_step",
        "disallow_state_dict_read",
        "allow_state_dict_read",
        "register_state_dict_fn",
        "wait_quorum",
        "is_leader",
    ):
        member = getattr(mgr, attr, None)
        assert member is not None, f"MockManager missing method: {attr}"
        assert callable(member), f"MockManager.{attr} is not callable"
    # Attributes DiLoCo reads at construction / runtime
    for attr in ("_use_async_quorum", "num_participants", "rank"):
        assert hasattr(mgr, attr), f"MockManager missing attribute: {attr}"
    assert mgr._use_async_quorum is False  # DiLoCo.__init__ rejects True


def test_mockmanager_allreduce_returns_workshaped(tmp_path):
    """DiLoCo stores the allreduce return in a list and calls `.wait()` later."""
    manager = MockManager(_make_store(tmp_path))
    handle = manager.allreduce(torch.zeros(2, 2))

    # Must mimic torch.distributed.Work / torchft._DummyWork: DiLoCo keeps
    # the handle and invokes .wait() on it once the round is flushed.
    assert hasattr(handle, "wait"), "allreduce return must have .wait() (DiLoCo calls it)"
    assert callable(handle.wait)
    # The mock is synchronous, so waiting is an immediate no-op.
    assert handle.wait() is True

    # Some torch code paths probe get_future(); accept None or a waitable.
    future = handle.get_future()
    assert future is None or hasattr(future, "wait")

    # Pin the concrete handle type the mock is expected to hand back.
    assert isinstance(handle, _ImmediateWork)


def test_mockmanager_diloco_outer_round_completes(tmp_path):
    """Drive one complete inner+outer DiLoCo round and check that weights moved.

    MockManager sits on a world_size=1 ObjectStoreAllReduce (file-backed),
    so the rendezvous involves a single process and the whole round runs
    synchronously. Exactly `sync_every` (=4) inner-optimizer steps are
    taken. On the last one, DiLoCo's post-hook performs the sync, which
    touches the entire MockManager call surface.
    """
    sync_every = 4

    torch.manual_seed(0)
    net = torch.nn.Linear(4, 4, bias=False)
    weights_before = net.weight.detach().clone()

    inner_optim = torch.optim.SGD(net.parameters(), lr=0.1)
    manager = MockManager(_make_store(tmp_path))

    diloco = make_diloco_outer_loop(
        manager=manager,
        model_fragments=[net],
        inner_optimizer=inner_optim,
        outer_lr=0.7,
        outer_momentum=0.9,
        nesterov=True,
        sync_every=sync_every,
        fragment_sync_delay=0,
        fragment_update_alpha=0.0,
    )

    # A single fragment was passed in, so exactly one state-dict fn is expected.
    registered = manager._state_dict_fns
    assert len(registered) == 1, (
        f"expected 1 fragment registration, got {list(registered)}"
    )

    inputs = torch.randn(2, 4)
    targets = torch.randn(2, 4)

    with diloco:
        for _step in range(sync_every):  # one outer round's worth of inner steps
            inner_optim.zero_grad()
            mse = ((net(inputs) - targets) ** 2).mean()
            mse.backward()
            # Regression guard: the old MockManager lacked
            # disallow_state_dict_read / current_step / register_state_dict_fn,
            # so DiLoCo's step pre-hook raised AttributeError on the very
            # first step.
            inner_optim.step()

    # One outer round means start_quorum bumped the counter exactly once.
    rounds = manager.current_step()
    assert rounds == 1, (
        f"expected current_step()==1 after one outer round, got {rounds}"
    )

    # The outer SGD step really ran, so the weights no longer match the init.
    weights_after = net.weight.detach().clone()
    assert not torch.allclose(weights_before, weights_after), (
        "model params unchanged after outer round — outer optimizer never ran"
    )


def _diloco_replica_one_outer_round(
    rendezvous_uri: str,
    world_size: int,
    sync_every: int,
    rank_env: str = "REPLICA_RANK",
) -> dict:
    """Run exactly one DiLoCo outer round on one replica and report its weights.

    Top-level entry — must be importable for multiprocessing 'spawn'.

    Each replica:
      1. seeds torch with a SHARED seed for model init (DiLoCo's standard
         assumption: all replicas start with identical weights — DiLoCo
         only averages pseudo-gradients, not absolute weights, so divergent
         inits would never reconcile).
      2. builds nn.Linear(4, 4, bias=False) + SGD inner optimizer.
      3. draws RANK-SPECIFIC data so each replica's inner-trained weights
         diverge during the inner loop. This gives the pseudo-gradient real
         cross-rank variance; without it, the averaging is observationally
         a no-op.
      4. replays the same `sync_every` inner SGD steps on a deep copy of the
         freshly-initialised model (no DiLoCo hooks attached) to record the
         purely-inner, pre-outer-sync weights. This assumes DiLoCo leaves the
         inner steps untouched until the sync fires — the inner optimizer is
         plain SGD on identical data, so the trajectories match.
      5. runs `sync_every` inner steps inside `make_diloco_outer_loop`,
         which fires exactly one outer round.

    Args:
        rendezvous_uri: location handed to ``ObjectStoreAllReduce`` for the
            shared rendezvous.
        world_size: number of replicas taking part in the allreduce.
        sync_every: inner steps per outer round; exactly this many are run.
        rank_env: name of the environment variable holding this replica's
            rank. Defaults to ``"REPLICA_RANK"``. Accepting it keeps the
            entrypoint callable whether or not the launcher forwards the
            ``rank_env`` entry of ``entrypoint_args``.

    Returns:
        A dict with ``rank``, ``current_step`` and the flattened weight lists
        ``initial`` (before training), ``pre_outer`` (after the inner loop,
        before the outer sync) and ``final`` (after the outer round).
    """
    import copy as _copy
    import os as _os
    import torch as _torch
    import torch.nn as _nn

    from composer_replication.diloco import make_diloco_outer_loop
    from composer_replication.diloco.serverless.allreduce import (
        MockManager,
        ObjectStoreAllReduce,
    )

    rank = int(_os.environ[rank_env])

    # SHARED init seed — both replicas start with identical weights, as
    # DiLoCo assumes. (DiLoCo averages pseudo-gradients, not weights, so
    # divergent inits would never reconcile and the convergence claim
    # would be incorrect.)
    _torch.manual_seed(0)
    model = _nn.Linear(4, 4, bias=False)
    initial = model.weight.detach().clone()
    # Hook-free twin, copied before DiLoCo is attached, used only to
    # reproduce the inner trajectory up to (but excluding) the outer sync.
    shadow = _copy.deepcopy(model)

    inner_optim = _torch.optim.SGD(model.parameters(), lr=0.1)

    store = ObjectStoreAllReduce(
        rendezvous_uri,
        rank=rank,
        world_size=world_size,
        timeout_s=120.0,
        poll_interval_s=0.05,
    )
    manager = MockManager(store)

    diloco = make_diloco_outer_loop(
        manager=manager,
        model_fragments=[model],
        inner_optimizer=inner_optim,
        sync_every=sync_every,
    )

    # RANK-SPECIFIC data so the inner-trained weights diverge before the
    # outer sync — this is what makes "post-sync convergence" a real
    # property to verify rather than a tautology.
    _torch.manual_seed(100 + rank)
    x = _torch.randn(2, 4)
    target = _torch.randn(2, 4)

    def _run_inner_steps(net, optim) -> None:
        """Take `sync_every` MSE/SGD steps of *net* on this rank's data."""
        for _ in range(sync_every):
            optim.zero_grad()
            loss = ((net(x) - target) ** 2).mean()
            loss.backward()
            optim.step()

    _run_inner_steps(shadow, _torch.optim.SGD(shadow.parameters(), lr=0.1))
    pre_outer = shadow.weight.detach().clone()

    with diloco:
        _run_inner_steps(model, inner_optim)

    final = model.weight.detach().clone()
    return {
        "rank": rank,
        "initial": initial.flatten().tolist(),
        "pre_outer": pre_outer.flatten().tolist(),
        "final": final.flatten().tolist(),
        "current_step": manager.current_step(),
    }


def test_mockmanager_diloco_multi_process_weights_converge(tmp_path):
    """Wave 14 (Suggestion 4): cross-replica weight convergence after one outer round.

    Spawns n_replicas=2 subprocesses with IDENTICAL initial weights
    (DiLoCo's standard assumption — it averages pseudo-gradients, not
    absolute weights) but RANK-SPECIFIC training data. After exactly
    one DiLoCo outer round, both replicas must end with IDENTICAL
    weights, because:

      pseudo_grad_i = init - inner_trained_i        # per-rank, differ
      avg_pseudo    = mean_i(pseudo_grad_i)         # same on all ranks
      final         = outer_opt(init, avg_pseudo)   # same on all ranks

    Identical inputs to the outer optimizer give identical outputs on every
    rank. For plain SGD that output is ``init - outer_lr * avg_pseudo``.

    Two non-vacuity checks guard the convergence assertion:
      * the replicas' pre-outer (inner-trained) weights must differ, so
        there is something for the allreduce to reconcile;
      * the final weights must differ from the shared init. Otherwise a
        sync that merely restored the init would also "converge".

    This catches averaging-direction bugs that the world_size=1
    single-process test silently misses (a single-rank allreduce is a
    no-op and can hide bugs in the multi-rank averaging arithmetic, the
    file-staging round-id increment, or the weight redistribution after
    the outer SGD step).
    """
    from composer_replication.diloco.serverless import LocalProcessExecutor

    n_replicas = 2
    sync_every = 2
    # Use the pytest-managed tmp dir directly; its contents are kept after a
    # failure, which helps post-mortem inspection of the rendezvous files.
    rendezvous = str(tmp_path / "diloco-multiproc-run")
    executor = LocalProcessExecutor()
    handles = executor.launch_replicas(
        n_replicas=n_replicas,
        entrypoint=f"{__name__}._diloco_replica_one_outer_round",
        entrypoint_args={
            "rendezvous_uri": rendezvous,
            "world_size": n_replicas,
            "sync_every": sync_every,
            "rank_env": "REPLICA_RANK",
        },
        timeout=180,
    )
    results = executor.collect(handles, timeout=180)

    # Diagnostic-friendly failure: surface per-rank error if any replica died
    # or never reported back at all.
    by_rank = {r["rank"]: r for r in results}
    for rank in range(n_replicas):
        assert rank in by_rank, f"rank {rank} produced no result: {results}"
        record = by_rank[rank]
        assert record["status"] == "succeeded", (
            f"rank {rank} failed: {record.get('error')}"
        )

    payloads = sorted((r["result"] for r in results), key=lambda d: d["rank"])
    rank0, rank1 = payloads[0], payloads[1]

    # Sanity: each replica really did fire exactly one outer round.
    assert rank0["current_step"] == 1, rank0
    assert rank1["current_step"] == 1, rank1

    # Sanity: replicas STARTED with identical weights (DiLoCo assumption).
    assert rank0["initial"] == rank1["initial"], (
        "replicas started with different initial weights — DiLoCo only "
        "averages pseudo-gradients, not weights, so this would prevent "
        "convergence even with a perfectly correct allreduce"
    )

    # Non-vacuity #1: rank-specific data must have driven the replicas apart
    # during the inner loop, otherwise identical finals prove nothing.
    assert rank0["pre_outer"] != rank1["pre_outer"], (
        "replicas' inner-trained (pre-outer) weights are identical — "
        "rank-specific data did not diverge them, so the convergence "
        "check below would be vacuous"
    )

    initial = torch.tensor(rank0["initial"])
    final0 = torch.tensor(rank0["final"])
    final1 = torch.tensor(rank1["final"])

    # Non-vacuity #2: the outer step must actually have moved the weights.
    assert not torch.allclose(final0, initial), (
        "rank0 final weights equal the shared initial weights — the outer "
        "round applied no update, so matching finals prove nothing"
    )

    # The actual property: after one full outer round both replicas must
    # have the SAME final weights. Tight tolerance because the only
    # arithmetic between them is SGD + a single allreduce-mean.
    if not torch.allclose(final0, final1, atol=1e-5, rtol=1e-5):
        max_abs_diff = (final0 - final1).abs().max().item()
        pytest.fail(
            "Multi-process DiLoCo did NOT converge to identical weights "
            "after one outer round.\n"
            f"  rank0 final = {final0.tolist()}\n"
            f"  rank1 final = {final1.tolist()}\n"
            f"  max|diff|   = {max_abs_diff}\n"
            "This indicates a real cross-replica-averaging bug "
            "(averaging direction, round-id desync, or weight redistribution)."
        )


def test_mockmanager_diloco_two_outer_rounds_step_counter(tmp_path):
    """Two outer rounds must bump current_step() to 2 (fragment rotation safety)."""
    sync_every = 2
    n_outer_rounds = 2

    torch.manual_seed(1)
    net = torch.nn.Linear(4, 4, bias=False)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
    manager = MockManager(_make_store(tmp_path))

    diloco = make_diloco_outer_loop(
        manager=manager,
        model_fragments=[net],
        inner_optimizer=optimizer,
        sync_every=sync_every,
    )

    # Drawn in the same order as before: inputs first, then targets.
    inputs, targets = torch.randn(2, 4), torch.randn(2, 4)

    with diloco:
        # sync_every * n_outer_rounds inner steps => exactly two syncs.
        for _ in range(sync_every * n_outer_rounds):
            optimizer.zero_grad()
            mse = ((net(inputs) - targets) ** 2).mean()
            mse.backward()
            optimizer.step()

    observed = manager.current_step()
    assert observed == 2, (
        f"expected current_step()==2 after two outer rounds, got {observed}"
    )