composer-replication-framework / composer_replication /diloco /serverless /tests /test_serverless_diloco_integration.py
Codeseys's picture
Wave 14: close every Wave 13 review finding + 4 documentation files; Wave 14b: real PRIME-RL parity + multi-process DiLoCo convergence
d9dd3a5
"""End-to-end MockManager × torchft.DiLoCo integration test.
Closes the Wave 13 cross-model adversarial-review gap (Suggestion 4):
the original MockManager was advertised as a drop-in for torchft.Manager
but only stubbed `.allreduce / .should_commit / .start_quorum`. DiLoCo's
real call surface (audited from `torchft/local_sgd.py` v2026-spring) also
includes `current_step()`, `disallow_state_dict_read()`,
`allow_state_dict_read()`, `register_state_dict_fn()`, and the
`_use_async_quorum` attribute — plus `allreduce()` must return a Work-like
object with `.wait()`, not a raw tensor.
This test runs ONE full DiLoCo outer round (sync_every inner steps + the
sync) against a tiny `nn.Linear(4, 4)` with `world_size=1` so the
object-store rendezvous is trivial. It verifies:
1. Construction does not raise.
2. Running through one full outer round does not raise AttributeError
(which is what the old MockManager would have hit at `current_step()`).
3. The model parameters change after the outer step fires (proving the
outer SGD path actually executed end-to-end, not just that the
inner-step hooks ran).
4. The MockManager's step counter advanced exactly once (one outer round
⇒ one start_quorum bump).
5. DiLoCo registered a state-dict fn per fragment.
"""
from __future__ import annotations
import pytest
import torch
torchft = pytest.importorskip(
"torchft.local_sgd",
reason="torchft must be installed to run the DiLoCo integration test",
)
from composer_replication.diloco import make_diloco_outer_loop
from composer_replication.diloco.serverless.allreduce import (
MockManager,
ObjectStoreAllReduce,
_ImmediateWork,
)
def _make_store(tmp_path) -> ObjectStoreAllReduce:
    """Return a rank-0, ``world_size=1`` ``ObjectStoreAllReduce`` rooted at ``tmp_path``.

    ``tmp_path`` is stringified and used as the store URI. The store is
    configured with a 10 s timeout and a 50 ms poll interval; the
    single-process tests in this module build their ``MockManager`` on top
    of it.
    """
    single_rank_config = dict(
        uri=str(tmp_path),
        rank=0,
        world_size=1,
        timeout_s=10.0,
        poll_interval_s=0.05,
    )
    return ObjectStoreAllReduce(**single_rank_config)
def test_mockmanager_has_full_diloco_call_surface(tmp_path):
    """Audited methods/attrs from torchft/local_sgd.py DiLoCo path must exist.

    Builds a ``MockManager`` over a single-rank store and checks that every
    method name in the audited list resolves to a callable, and that the
    attributes listed below are present (with ``_use_async_quorum`` being
    exactly ``False``).
    """
    mgr = MockManager(_make_store(tmp_path))

    # Methods DiLoCo invokes
    required_methods = (
        "allreduce",
        "should_commit",
        "start_quorum",
        "current_step",
        "disallow_state_dict_read",
        "allow_state_dict_read",
        "register_state_dict_fn",
        "wait_quorum",
        "is_leader",
    )
    for attr in required_methods:
        # Use a None default: a plain getattr() would raise AttributeError on
        # a missing method and the diagnostic message below would never be
        # shown for exactly the case it describes.
        assert callable(getattr(mgr, attr, None)), f"MockManager missing method: {attr}"

    # Attributes DiLoCo reads at construction / runtime
    assert hasattr(mgr, "_use_async_quorum")
    assert mgr._use_async_quorum is False  # DiLoCo.__init__ rejects True
    assert hasattr(mgr, "num_participants")
    assert hasattr(mgr, "rank")
def test_mockmanager_allreduce_returns_workshaped(tmp_path):
    """``MockManager.allreduce`` must return a Work-like handle, not a raw tensor.

    DiLoCo stores the allreduce return value in a list and calls ``.wait()``
    on it later, so the handle has to expose a callable ``wait``.
    """
    mgr = MockManager(_make_store(tmp_path))
    payload = torch.zeros(2, 2)
    work = mgr.allreduce(payload)

    # The handle must look like torch.distributed.Work / torchft._DummyWork.
    assert hasattr(work, "wait"), "allreduce return must have .wait() (DiLoCo calls it)"
    assert callable(work.wait)

    # On a synchronous mock, .wait() is a no-op that reports success.
    wait_result = work.wait()
    assert wait_result is True

    # Defensive: some torch paths probe get_future(); it may return None,
    # otherwise it must itself be waitable.
    future = work.get_future()
    if future is not None:
        assert hasattr(future, "wait")

    # Concrete type
    assert isinstance(work, _ImmediateWork)
def test_mockmanager_diloco_outer_round_completes(tmp_path):
    """Drive one complete inner+outer DiLoCo round and check the weights moved.

    The store is single-rank (``world_size=1``), so the rendezvous needs no
    peer and the test runs synchronously. The loop performs exactly
    ``sync_every`` inner-optimizer steps, which is expected to trigger one
    outer sync on the final step and thereby exercise the MockManager
    surface end-to-end.
    """
    sync_every = 4

    torch.manual_seed(0)
    model = torch.nn.Linear(4, 4, bias=False)
    initial_params = model.weight.detach().clone()
    inner_optim = torch.optim.SGD(model.parameters(), lr=0.1)

    manager = MockManager(_make_store(tmp_path))
    diloco = make_diloco_outer_loop(
        manager=manager,
        model_fragments=[model],
        inner_optimizer=inner_optim,
        outer_lr=0.7,
        outer_momentum=0.9,
        nesterov=True,
        sync_every=sync_every,
        fragment_sync_delay=0,
        fragment_update_alpha=0.0,
    )

    # Sanity: exactly one state-dict fn was registered for the single fragment.
    registered = len(manager._state_dict_fns)
    assert registered == 1, (
        f"expected 1 fragment registration, got {list(manager._state_dict_fns)}"
    )

    inputs = torch.randn(2, 4)
    labels = torch.randn(2, 4)

    with diloco:
        # Exactly sync_every inner steps → one outer round.
        for _ in range(sync_every):
            inner_optim.zero_grad()
            residual = model(inputs) - labels
            loss = residual.pow(2).mean()
            loss.backward()
            # This step must NOT raise AttributeError on current_step /
            # state_dict_read / register_state_dict_fn / etc. The original
            # MockManager would have crashed on the very first step, when
            # the pre-step hook calls disallow_state_dict_read.
            inner_optim.step()

    # One outer round ⇒ the manager's step counter advanced exactly once.
    assert manager.current_step() == 1, (
        f"expected current_step()==1 after one outer round, got {manager.current_step()}"
    )

    # If the outer SGD step really fired, the weights differ from the init.
    final_params = model.weight.detach().clone()
    params_unchanged = torch.allclose(initial_params, final_params)
    assert not params_unchanged, (
        "model params unchanged after outer round — outer optimizer never ran"
    )
def _diloco_replica_one_outer_round(
    rendezvous_uri: str,
    world_size: int,
    sync_every: int,
) -> dict:
    """Run ``sync_every`` inner steps under DiLoCo in one replica and report its weights.

    Top-level entry — must be importable for multiprocessing 'spawn', which
    is also why every dependency is imported inside the function body.

    Each replica:

    1. reads its rank from the ``REPLICA_RANK`` environment variable.
    2. seeds torch with a SHARED seed (0) before building the model, so all
       replicas start from identical weights. DiLoCo averages
       pseudo-gradients, not absolute weights, so divergent inits would
       never reconcile.
    3. builds ``nn.Linear(4, 4, bias=False)`` with an SGD inner optimizer
       (lr=0.1) and wraps it with ``make_diloco_outer_loop`` on a
       ``MockManager`` backed by ``ObjectStoreAllReduce``.
    4. re-seeds with ``100 + rank`` and draws RANK-SPECIFIC data, so the
       inner-trained weights diverge across ranks — without that, the
       cross-rank averaging would be observationally a no-op.
    5. runs ``sync_every`` inner steps inside the DiLoCo context, which is
       intended to fire exactly one outer round.

    Args:
        rendezvous_uri: Location passed to ``ObjectStoreAllReduce`` as the
            shared rendezvous point for all replicas.
        world_size: Total number of replicas taking part in the allreduce.
        sync_every: Number of inner steps to run (and the DiLoCo sync period).

    Returns:
        A dict of plain-Python values:
        ``"rank"`` (int), ``"initial"`` (flattened weights right after
        init, as a list), ``"final"`` (flattened weights after the loop, as
        a list) and ``"current_step"`` (``manager.current_step()``).
        The weights as they were just before the outer sync are NOT
        captured, so callers can only compare initial and final weights.

    Raises:
        KeyError: If ``REPLICA_RANK`` is not set in the environment.
        ValueError: If ``REPLICA_RANK`` is not an integer string.
    """
    import os as _os
    import torch as _torch
    import torch.nn as _nn
    from composer_replication.diloco import make_diloco_outer_loop
    from composer_replication.diloco.serverless.allreduce import (
        MockManager,
        ObjectStoreAllReduce,
    )

    rank = int(_os.environ["REPLICA_RANK"])

    # SHARED init seed — every replica starts with identical weights, as
    # DiLoCo assumes.
    _torch.manual_seed(0)
    model = _nn.Linear(4, 4, bias=False)
    initial = model.weight.detach().clone()
    inner_optim = _torch.optim.SGD(model.parameters(), lr=0.1)

    # Generous timeout: peers are separate processes that may start late.
    store = ObjectStoreAllReduce(
        rendezvous_uri,
        rank=rank,
        world_size=world_size,
        timeout_s=120.0,
        poll_interval_s=0.05,
    )
    manager = MockManager(store)
    diloco = make_diloco_outer_loop(
        manager=manager,
        model_fragments=[model],
        inner_optimizer=inner_optim,
        sync_every=sync_every,
    )

    # RANK-SPECIFIC data so the inner-trained weights diverge before the
    # outer sync — this is what makes post-sync agreement a real property
    # rather than a tautology.
    _torch.manual_seed(100 + rank)
    x = _torch.randn(2, 4)
    target = _torch.randn(2, 4)

    with diloco:
        for _ in range(sync_every):
            inner_optim.zero_grad()
            loss = ((model(x) - target) ** 2).mean()
            loss.backward()
            inner_optim.step()

    final = model.weight.detach().clone()
    # Plain lists/ints only, so the payload can cross the process boundary.
    return {
        "rank": rank,
        "initial": initial.flatten().tolist(),
        "final": final.flatten().tolist(),
        "current_step": manager.current_step(),
    }
def test_mockmanager_diloco_multi_process_weights_converge(tmp_path):
    """Wave 14 (Suggestion 4): cross-replica weight convergence after one outer round.

    Spawns n_replicas=2 subprocesses with IDENTICAL initial weights
    (DiLoCo's standard assumption — it averages pseudo-gradients, not
    absolute weights) but RANK-SPECIFIC training data. After exactly
    one DiLoCo outer round, both replicas must end with IDENTICAL
    weights, because (schematically):

        pseudo_grad_i = init - inner_trained_i     # per-rank, differ
        avg_pseudo    = mean_i(pseudo_grad_i)      # same on all ranks
        final         = init - outer_lr * avg_pseudo   # same on all ranks

    This catches averaging-direction bugs that the world_size=1
    single-process test silently misses (a single-rank allreduce is a
    no-op and can hide bugs in the multi-rank averaging arithmetic, the
    file-staging round-id increment, or the weight redistribution after
    the outer SGD step).

    The test additionally requires the final weights to differ from the
    shared initial weights: if the outer step were a no-op, both ranks
    would trivially "agree" on the untouched init and the equality check
    would prove nothing.
    """
    from composer_replication.diloco.serverless import LocalProcessExecutor

    n_replicas = 2
    sync_every = 2

    # Use the pytest-managed per-test directory (as the single-process tests
    # in this module do) rather than a hand-rolled TemporaryDirectory.
    rendezvous = str(tmp_path / "diloco-multiproc-run")

    executor = LocalProcessExecutor()
    handles = executor.launch_replicas(
        n_replicas=n_replicas,
        entrypoint=f"{__name__}._diloco_replica_one_outer_round",
        entrypoint_args={
            "rendezvous_uri": rendezvous,
            "world_size": n_replicas,
            "sync_every": sync_every,
            "rank_env": "REPLICA_RANK",
        },
        timeout=180,
    )
    results = executor.collect(handles, timeout=180)

    # Diagnostic-friendly failure: surface the per-rank error if any replica
    # died, and report a missing record explicitly instead of a bare KeyError.
    records_by_rank = {r["rank"]: r for r in results}
    for rank in range(n_replicas):
        record = records_by_rank.get(rank)
        assert record is not None, (
            f"rank {rank} produced no result record; "
            f"got ranks {sorted(records_by_rank)}"
        )
        assert record["status"] == "succeeded", (
            f"rank {rank} failed: {record.get('error')}"
        )

    payloads = sorted((r["result"] for r in results), key=lambda d: d["rank"])
    rank0, rank1 = payloads[0], payloads[1]

    # Sanity: each replica really did fire exactly one outer round.
    assert rank0["current_step"] == 1, rank0
    assert rank1["current_step"] == 1, rank1

    # Sanity: replicas STARTED with identical weights (DiLoCo assumption).
    assert rank0["initial"] == rank1["initial"], (
        "replicas started with different initial weights — DiLoCo only "
        "averages pseudo-gradients, not weights, so this would prevent "
        "convergence even with a perfectly correct allreduce"
    )

    # The actual property: after one full outer round both replicas must
    # have the SAME final weights. Tight tolerance because the only
    # arithmetic between them is SGD + a single allreduce-mean.
    final0 = torch.tensor(rank0["final"])
    final1 = torch.tensor(rank1["final"])
    if not torch.allclose(final0, final1, atol=1e-5, rtol=1e-5):
        max_abs_diff = (final0 - final1).abs().max().item()
        pytest.fail(
            "Multi-process DiLoCo did NOT converge to identical weights "
            "after one outer round.\n"
            f"  rank0 final = {final0.tolist()}\n"
            f"  rank1 final = {final1.tolist()}\n"
            f"  max|diff|   = {max_abs_diff}\n"
            "This indicates a real cross-replica-averaging bug "
            "(averaging direction, round-id desync, or weight redistribution)."
        )

    # Non-vacuity: the agreed-upon weights must not simply be the shared
    # init. Checking rank 0 suffices because final0 ≈ final1 was just shown.
    initial0 = torch.tensor(rank0["initial"])
    assert not torch.allclose(initial0, final0), (
        "final weights equal the shared initial weights — the outer round "
        "was a no-op, so cross-rank agreement proves nothing"
    )
def test_mockmanager_diloco_two_outer_rounds_step_counter(tmp_path):
    """After two outer rounds ``current_step()`` must read 2 (fragment rotation safety).

    Runs ``2 * sync_every`` inner steps on a single-rank store with
    ``sync_every=2`` and checks the MockManager's step counter afterwards.
    """
    sync_every = 2
    outer_rounds = 2

    torch.manual_seed(1)
    model = torch.nn.Linear(4, 4, bias=False)
    inner_optim = torch.optim.SGD(model.parameters(), lr=0.05)
    manager = MockManager(_make_store(tmp_path))
    diloco = make_diloco_outer_loop(
        manager=manager,
        model_fragments=[model],
        inner_optimizer=inner_optim,
        sync_every=sync_every,
    )

    inputs = torch.randn(2, 4)
    labels = torch.randn(2, 4)

    with diloco:
        # 2 outer rounds at sync_every=2 → 4 inner steps in total.
        for _ in range(outer_rounds * sync_every):
            inner_optim.zero_grad()
            loss = (model(inputs) - labels).pow(2).mean()
            loss.backward()
            inner_optim.step()

    assert manager.current_step() == 2, (
        f"expected current_step()==2 after two outer rounds, got {manager.current_step()}"
    )