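# NOTE: the lines below are file-viewer residue captured together with this
# module (not Python source); kept verbatim as comments so the file parses.
#   Codeseys's picture
#   Wave 18: 14 backlog items closed + 3-reviewer cross-family review
#   54efac8
#   raw
#   history blame
#   11.4 kB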
"""Verifies the serverless DiLoCo allreduce wraps correctly across local
multiprocessing replicas using `file://` rendezvous.
This is the core multi-process test for the serverless layer. It exercises
the real allreduce barrier (with concurrent processes), not just the
single-process API.
"""
from __future__ import annotations
import os
import sys
import tempfile
import time
import pytest
import torch
from composer_replication.diloco.serverless import (
LocalProcessExecutor,
ObjectStoreAllReduce,
ReplicaHandle,
)
# ---------------------------------------------------------------------
# Single-process tests of ObjectStoreAllReduce primitives
# (don't need executor, just the file:// path + local manual orchestration)
# ---------------------------------------------------------------------
def test_object_store_allreduce_init_validates_rank():
    """Constructing with rank=5 for world_size=2 must raise ValueError."""
    with tempfile.TemporaryDirectory() as scratch_dir, pytest.raises(
        ValueError, match="not in"
    ):
        ObjectStoreAllReduce(scratch_dir, rank=5, world_size=2)
def test_object_store_allreduce_local_paths_create_dir():
    """Initialising on a not-yet-existing local path creates that directory."""
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Two levels deep, so neither component exists before construction.
        nested_dir = os.path.join(scratch_dir, "subdir", "subsubdir")
        reducer = ObjectStoreAllReduce(nested_dir, rank=0, world_size=1)

        assert os.path.isdir(nested_dir)
        assert reducer.world_size == 1
def test_object_store_allreduce_world_size_1_passthrough():
    """A lone replica gets its own tensor back (the mean of one value)."""
    with tempfile.TemporaryDirectory() as scratch_dir:
        reducer = ObjectStoreAllReduce(
            scratch_dir, rank=0, world_size=1, timeout_s=10.0
        )
        expected = torch.tensor([1.0, 2.0, 3.0])
        # Hand over a copy so `expected` stays pristine for the comparison.
        reduced = reducer.allreduce(expected.clone())
        torch.testing.assert_close(reduced, expected, atol=1e-6, rtol=1e-6)
def test_object_store_allreduce_round_id_increments():
    """`round_id` starts at 0 and grows by one per `allreduce` call."""
    with tempfile.TemporaryDirectory() as scratch_dir:
        reducer = ObjectStoreAllReduce(
            scratch_dir, rank=0, world_size=1, timeout_s=10.0
        )
        zeros = torch.zeros(3)

        assert reducer.round_id == 0
        for expected_round in (1, 2):
            reducer.allreduce(zeros.clone())
            assert reducer.round_id == expected_round
# ---------------------------------------------------------------------
# Multi-process tests (the real verification — local executor + spawn)
# ---------------------------------------------------------------------
def _replica_compute_and_sync(
    rendezvous_uri: str,
    world_size: int,
    rank_value: float,
) -> dict:
    """Replica entrypoint: allreduce one rank-dependent tensor.

    Kept at module top level because the tests reference it by dotted name
    (`f"{__name__}._replica_compute_and_sync"`) when launching replicas.

    Args:
        rendezvous_uri: Location handed to `ObjectStoreAllReduce`.
        world_size: Number of replicas taking part in the allreduce.
        rank_value: Base value; this replica contributes a 4-element tensor
            filled with `rank_value * (rank + 1)`.

    Returns:
        A dict with this replica's `rank`, its tensor before the allreduce
        (`pre`), the allreduce result (`post`), and `world_size`. Tensors
        are converted to plain lists.

    The rank is read from the `REPLICA_RANK` environment variable.
    """
    replica_rank = int(os.environ["REPLICA_RANK"])
    reducer = ObjectStoreAllReduce(
        rendezvous_uri,
        rank=replica_rank,
        world_size=world_size,
        timeout_s=120.0,
    )

    fill_value = float(rank_value * (replica_rank + 1))
    local_tensor = torch.full((4,), fill_value)
    # Snapshot first: the input is handed to allreduce and may not survive
    # unchanged.
    snapshot = local_tensor.clone()
    reduced = reducer.allreduce(local_tensor)

    return {
        "rank": replica_rank,
        "pre": snapshot.tolist(),
        "post": reduced.tolist(),
        "world_size": world_size,
    }
@pytest.mark.parametrize("n_replicas", [2, 3])
def test_local_executor_runs_allreduce_across_replicas(n_replicas):
    """End-to-end: each replica process calls allreduce and sees the mean.

    Replica `rank` contributes a tensor filled with `rank_value * (rank + 1)`,
    so every replica should get back `rank_value * (1 + 2 + ... + N) / N`.
    """
    rank_value = 10.0

    with tempfile.TemporaryDirectory() as scratch_dir:
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_compute_and_sync",
            entrypoint_args={
                "rendezvous_uri": os.path.join(scratch_dir, "run"),
                "world_size": n_replicas,
                "rank_value": rank_value,
                "rank_env": "REPLICA_RANK",
            },
            timeout=180,
        )

        # One handle per replica, ranks assigned in launch order.
        assert len(handles) == n_replicas
        for expected_rank, handle in enumerate(handles):
            assert handle.rank == expected_rank
            assert handle.backend_name == "local_process"

        outcomes = executor.collect(handles, timeout=180)
        assert len(outcomes) == n_replicas

        # Every replica must have finished cleanly before values are compared.
        for outcome in outcomes:
            assert outcome["status"] == "succeeded", (
                f"rank {outcome['rank']} failed: {outcome.get('error')}"
            )

        expected_mean = rank_value * sum(range(1, n_replicas + 1)) / n_replicas
        for outcome in outcomes:
            for value in outcome["result"]["post"]:
                assert abs(value - expected_mean) < 1e-4, (
                    f"rank {outcome['rank']}: "
                    f"expected mean {expected_mean}, got {value}"
                )
def _replica_two_round_sync(
    rendezvous_uri: str,
    world_size: int,
) -> dict:
    """Replica entrypoint: run two back-to-back allreduce rounds.

    Round one contributes a 2-element tensor filled with `rank`; round two
    contributes one filled with `rank * 100`. The rank is read from the
    `REPLICA_RANK` environment variable.

    Args:
        rendezvous_uri: Location handed to `ObjectStoreAllReduce`.
        world_size: Number of replicas taking part in each round.

    Returns:
        A dict with this replica's `rank`, the store's `round_id` after both
        calls (`round_after_2_calls`), and the two results as lists
        (`avg1`, `avg2`).
    """
    replica_rank = int(os.environ["REPLICA_RANK"])
    reducer = ObjectStoreAllReduce(
        rendezvous_uri,
        rank=replica_rank,
        world_size=world_size,
        timeout_s=120.0,
    )

    first_input = torch.full((2,), float(replica_rank))
    first_avg = reducer.allreduce(first_input).clone()

    second_input = torch.full((2,), float(replica_rank * 100))
    second_avg = reducer.allreduce(second_input).clone()

    return {
        "rank": replica_rank,
        "round_after_2_calls": reducer.round_id,
        "avg1": first_avg.tolist(),
        "avg2": second_avg.tolist(),
    }
def test_local_executor_handles_multiple_rounds():
    """Two consecutive rounds each give the right mean; round counter advances.

    Three replicas each run two allreduce calls. Round one averages the
    values 0, 1, 2; round two averages 0, 100, 200.
    """
    n_replicas = 3
    with tempfile.TemporaryDirectory() as td:
        rendezvous = os.path.join(td, "run-2round")
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_two_round_sync",
            entrypoint_args={
                "rendezvous_uri": rendezvous,
                "world_size": n_replicas,
            },
            timeout=180,
        )
        results = executor.collect(handles, timeout=180)

        # Guard against a vacuous pass: if collect() returned fewer results
        # than replicas (or none at all), the loop below would check nothing.
        assert len(results) == n_replicas

        for r in results:
            assert r["status"] == "succeeded", r.get("error")
            assert r["result"]["round_after_2_calls"] == 2
            # mean of 0,1,2 = 1.0
            assert all(abs(v - 1.0) < 1e-4 for v in r["result"]["avg1"])
            # mean of 0,100,200 = 100.0
            assert all(abs(v - 100.0) < 1e-4 for v in r["result"]["avg2"])
def _replica_that_raises(rendezvous_uri: str, world_size: int) -> dict:
    """Replica entrypoint that fails on rank 1 and succeeds on every other rank.

    The rank is read from the `REPLICA_RANK` environment variable. Both
    parameters are accepted but not used by the body.

    Returns:
        `{"rank": <rank>, "ok": True}` for any rank other than 1.

    Raises:
        RuntimeError: When the rank is 1, simulating a crash mid-run.
    """
    replica_rank = int(os.environ["REPLICA_RANK"])
    if replica_rank != 1:
        return {"rank": replica_rank, "ok": True}
    raise RuntimeError(f"Simulated crash on rank {replica_rank}")
def test_local_executor_reports_failed_replicas():
    """A crashing replica is reported as `failed`; collect() does not hang.

    Rank 0 completes normally and rank 1 raises, and both outcomes must be
    visible in the collected results.

    Note (Wave 18): the timeouts were raised from 30s to 90s because this
    test was flaky in full-suite runs (it passes on its own but occasionally
    timed out while other multiprocessing tests competed for CPU). 30s was
    tight for a cold-start subprocess + import + rendezvous-file IO under
    contention; 90s leaves headroom without changing what the test checks
    (subprocess crashes surface as `failed` status, not hangs).
    """
    n_replicas = 2
    timeout_s = 90

    with tempfile.TemporaryDirectory() as scratch_dir:
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_that_raises",
            entrypoint_args={
                "rendezvous_uri": os.path.join(scratch_dir, "run-failure"),
                "world_size": n_replicas,
            },
            timeout=timeout_s,
        )
        outcomes = executor.collect(handles, timeout=timeout_s)

        status_by_rank = {
            outcome["rank"]: outcome["status"] for outcome in outcomes
        }
        assert status_by_rank[0] == "succeeded"
        assert status_by_rank[1] == "failed"

        # Failure log should mention the simulated crash
        crashed = next(outcome for outcome in outcomes if outcome["rank"] == 1)
        failure_log = crashed.get("error") or ""
        assert "Simulated crash" in failure_log
# ---------------------------------------------------------------------
# Sanity: MockManager is shape-compatible with torchft Manager surface
# ---------------------------------------------------------------------
def test_mock_manager_shape_compat():
    """MockManager exposes the Manager-shaped surface this test pins.

    Checks that the expected attributes exist, that the single-replica
    values are as expected, and that `allreduce` returns an object with a
    callable `wait()`.
    """
    from composer_replication.diloco.serverless import MockManager

    # torchft.Manager surface (audited from torchft/local_sgd.py DiLoCo path)
    required_attrs = (
        "allreduce",
        "should_commit",
        "start_quorum",
        "wait_quorum",
        "current_step",
        "disallow_state_dict_read",
        "allow_state_dict_read",
        "register_state_dict_fn",
        "_use_async_quorum",
    )

    with tempfile.TemporaryDirectory() as scratch_dir:
        reducer = ObjectStoreAllReduce(
            scratch_dir, rank=0, world_size=1, timeout_s=10.0
        )
        manager = MockManager(reducer)

        for attr_name in required_attrs:
            assert hasattr(manager, attr_name)

        assert manager._use_async_quorum is False
        assert manager.num_participants == 1
        assert manager.rank == 0
        assert manager.should_commit() is True

        # Single-replica allreduce: averaging is a passthrough, but the return
        # must be a Work-shaped object (DiLoCo calls .wait() on it). The
        # tensor itself is mutated in place by ObjectStoreAllReduce.
        expected = torch.tensor([1.0, 2.0])
        buffer = expected.clone()
        work = manager.allreduce(buffer)
        assert hasattr(work, "wait") and callable(work.wait)
        assert work.wait() is True
        torch.testing.assert_close(buffer, expected, atol=1e-6, rtol=1e-6)
# ---------------------------------------------------------------------
# Public re-export surface (Wave 17a)
# ---------------------------------------------------------------------
def test_public_reexports_include_all_executors():
    """Pin the package-level re-exports of every documented executor adapter.

    `from composer_replication.diloco.serverless import …` must surface
    every executor adapter the module's docstring claims, not just the
    LocalProcessExecutor.

    Background: Wave 16's user-journey reviewer caught that ModalExecutor /
    HFJobsExecutor were defined in `modal.py` / `hf_jobs.py` but not
    re-exported from the package's `__init__.py`, so users copying the
    docstring's `from composer_replication.diloco.serverless import
    ModalExecutor` line got an ImportError. Wave 17a added the missing
    re-exports; this test keeps them in place.
    """
    import composer_replication.diloco.serverless as ss

    expected = {
        "LocalProcessExecutor",
        "ModalExecutor",
        "HFJobsExecutor",
        "MockManager",
        "ObjectStoreAllReduce",
        "ReplicaHandle",
        "ServerlessExecutor",
    }

    exported = set(ss.__all__)
    missing = expected - exported
    assert not missing, (
        f"Missing re-exports: {missing}. "
        "__all__ should include every executor adapter the package "
        "docstring documents."
    )

    # Being listed in __all__ is not enough: each name must also resolve
    # as an attribute of the package.
    for export_name in expected:
        assert hasattr(ss, export_name), (
            f"{export_name} listed in __all__ but not present on package."
        )