"""Verifies the serverless DiLoCo allreduce wraps correctly across local multiprocessing replicas using `file://` rendezvous. This is the core multi-process test for the serverless layer. It exercises the real allreduce barrier (with concurrent processes), not just the single-process API. """ from __future__ import annotations import os import sys import tempfile import time import pytest import torch from composer_replication.diloco.serverless import ( LocalProcessExecutor, ObjectStoreAllReduce, ReplicaHandle, ) # --------------------------------------------------------------------- # Single-process tests of ObjectStoreAllReduce primitives # (don't need executor, just the file:// path + local manual orchestration) # --------------------------------------------------------------------- def test_object_store_allreduce_init_validates_rank(): with tempfile.TemporaryDirectory() as td: with pytest.raises(ValueError, match="not in"): ObjectStoreAllReduce(td, rank=5, world_size=2) def test_object_store_allreduce_local_paths_create_dir(): """Local backend should mkdir on init.""" with tempfile.TemporaryDirectory() as td: new_path = os.path.join(td, "subdir", "subsubdir") store = ObjectStoreAllReduce(new_path, rank=0, world_size=1) assert os.path.isdir(new_path) assert store.world_size == 1 def test_object_store_allreduce_world_size_1_passthrough(): """With world_size=1 it just averages the tensor with itself.""" with tempfile.TemporaryDirectory() as td: store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0) t = torch.tensor([1.0, 2.0, 3.0]) result = store.allreduce(t.clone()) torch.testing.assert_close(result, t, atol=1e-6, rtol=1e-6) def test_object_store_allreduce_round_id_increments(): with tempfile.TemporaryDirectory() as td: store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0) t = torch.zeros(3) assert store.round_id == 0 store.allreduce(t.clone()) assert store.round_id == 1 store.allreduce(t.clone()) assert store.round_id == 2 # 
# ---------------------------------------------------------------------
# Multi-process tests (the real verification — local executor + spawn)
# ---------------------------------------------------------------------


def _replica_compute_and_sync(
    rendezvous_uri: str,
    world_size: int,
    rank_value: float,
) -> dict:
    """Replica entrypoint: fill a rank-dependent tensor and allreduce it.

    Must stay a top-level function so multiprocessing's ``spawn`` start
    method can import it by name.

    Each replica reads its rank from the ``REPLICA_RANK`` environment
    variable, builds a length-4 tensor filled with ``rank_value * (rank + 1)``
    and runs one allreduce. The expected result is the mean of all replicas'
    tensors.

    Args:
        rendezvous_uri: Shared rendezvous location for ``ObjectStoreAllReduce``.
        world_size: Total number of participating replicas.
        rank_value: Scale factor for the per-rank fill value.

    Returns:
        A dict with ``rank``, the ``pre``/``post`` allreduce values as plain
        lists, and ``world_size``.
    """
    rank = int(os.environ["REPLICA_RANK"])
    store = ObjectStoreAllReduce(
        rendezvous_uri, rank=rank, world_size=world_size, timeout_s=120.0,
    )
    # tensor that depends on rank
    t = torch.full((4,), float(rank_value * (rank + 1)))
    # Snapshot before the allreduce, which writes its result into `t` in place.
    pre = t.clone()
    averaged = store.allreduce(t)
    return {
        "rank": rank,
        "pre": pre.tolist(),
        "post": averaged.tolist(),
        "world_size": world_size,
    }


@pytest.mark.parametrize("n_replicas", [2, 3])
def test_local_executor_runs_allreduce_across_replicas(n_replicas):
    """End-to-end: 2-3 replica processes each call allreduce; result is the mean."""
    # Single source of truth for the fill scale, shared by launch and check.
    rank_value = 10.0
    with tempfile.TemporaryDirectory() as td:
        rendezvous = os.path.join(td, "run")
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_compute_and_sync",
            entrypoint_args={
                "rendezvous_uri": rendezvous,
                "world_size": n_replicas,
                "rank_value": rank_value,
                # NOTE(review): `rank_env` is not a parameter of
                # `_replica_compute_and_sync`; presumably the executor
                # consumes it to name the rank env var — confirm.
                "rank_env": "REPLICA_RANK",
            },
            timeout=180,
        )
        assert len(handles) == n_replicas
        for i, h in enumerate(handles):
            assert h.rank == i
            assert h.backend_name == "local_process"

        results = executor.collect(handles, timeout=180)
        assert len(results) == n_replicas
        assert sorted(r["rank"] for r in results) == list(range(n_replicas))

        # Verify all succeeded
        for r in results:
            assert r["status"] == "succeeded", (
                f"rank {r['rank']} failed: {r.get('error')}"
            )

        # Replica k (0-based) contributed full(rank_value * (k + 1)), so the
        # expected mean is rank_value * (1 + 2 + ... + N) / N
        # == rank_value * (N + 1) / 2.
        expected_mean = rank_value * (n_replicas + 1) / 2
        for r in results:
            for v in r["result"]["post"]:
                assert abs(v - expected_mean) < 1e-4, (
                    f"rank {r['rank']}: expected mean {expected_mean}, got {v}"
                )


def _replica_two_round_sync(
    rendezvous_uri: str,
    world_size: int,
) -> dict:
    """Replica entrypoint: run TWO consecutive allreduces and report both.

    Round 1 contributes ``full((2,), rank)`` and round 2 contributes
    ``full((2,), rank * 100)``. The returned ``round_after_2_calls`` lets the
    caller check that the store's round counter advanced once per call.
    """
    rank = int(os.environ["REPLICA_RANK"])
    store = ObjectStoreAllReduce(
        rendezvous_uri, rank=rank, world_size=world_size, timeout_s=120.0,
    )
    t1 = torch.full((2,), float(rank))
    avg1 = store.allreduce(t1).clone()
    t2 = torch.full((2,), float(rank * 100))
    avg2 = store.allreduce(t2).clone()
    return {
        "rank": rank,
        "round_after_2_calls": store.round_id,
        "avg1": avg1.tolist(),
        "avg2": avg2.tolist(),
    }


def test_local_executor_handles_multiple_rounds():
    """Two consecutive rounds each give the right mean; round counter advances."""
    n_replicas = 3
    # Mean of ranks 0..N-1, and of the same ranks scaled by 100.
    expected_avg1 = (n_replicas - 1) / 2
    expected_avg2 = 100 * expected_avg1
    with tempfile.TemporaryDirectory() as td:
        rendezvous = os.path.join(td, "run-2round")
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_two_round_sync",
            entrypoint_args={
                "rendezvous_uri": rendezvous,
                "world_size": n_replicas,
            },
            timeout=180,
        )
        results = executor.collect(handles, timeout=180)
        # Guard against a vacuous pass: the per-result loop below would
        # accept an empty or short result list.
        assert len(results) == n_replicas
        assert sorted(r["rank"] for r in results) == list(range(n_replicas))
        for r in results:
            assert r["status"] == "succeeded", r.get("error")
            assert r["result"]["round_after_2_calls"] == 2
            assert all(
                abs(v - expected_avg1) < 1e-4 for v in r["result"]["avg1"]
            )
            assert all(
                abs(v - expected_avg2) < 1e-4 for v in r["result"]["avg2"]
            )


def _replica_that_raises(rendezvous_uri: str, world_size: int) -> dict:
    """Replica entrypoint that simulates a crash on rank 1.

    Every other rank returns immediately without touching the rendezvous,
    so the arguments exist only to match the executor's call shape.
    """
    rank = int(os.environ["REPLICA_RANK"])
    if rank == 1:
        raise RuntimeError(f"Simulated crash on rank {rank}")
    return {"rank": rank, "ok": True}


def test_local_executor_reports_failed_replicas():
    """When a replica crashes, collect() reports it as failed without hanging
    (other ranks complete; the failed one should be reflected in the result).

    Note (Wave 18): timeouts bumped from 30s → 90s because this test was
    flaky in full-suite runs (passes individually but occasionally times
    out when other parallel multiprocessing tests contend for CPU). The
    30s budget was tight for cold-start subprocess + import + rendezvous-file
    IO under contention; 90s gives comfortable headroom without changing
    the test's semantic intent (subprocess crashes surface as `failed`
    status, not hangs).
    """
    n_replicas = 2
    with tempfile.TemporaryDirectory() as td:
        rendezvous = os.path.join(td, "run-failure")
        executor = LocalProcessExecutor()
        handles = executor.launch_replicas(
            n_replicas=n_replicas,
            entrypoint=f"{__name__}._replica_that_raises",
            entrypoint_args={
                "rendezvous_uri": rendezvous,
                "world_size": n_replicas,
            },
            timeout=90,
        )
        results = executor.collect(handles, timeout=90)
        # Every replica, failed or not, must be represented exactly once.
        assert len(results) == n_replicas
        statuses = {r["rank"]: r["status"] for r in results}
        assert statuses[0] == "succeeded"
        assert statuses[1] == "failed"
        # Failure log should mention the simulated crash
        failure_log = next(r for r in results if r["rank"] == 1).get("error") or ""
        assert "Simulated crash" in failure_log


# ---------------------------------------------------------------------
# Sanity: MockManager is shape-compatible with torchft Manager surface
# ---------------------------------------------------------------------


def test_mock_manager_shape_compat():
    """MockManager exposes the torchft Manager surface DiLoCo relies on."""
    from composer_replication.diloco.serverless import MockManager

    with tempfile.TemporaryDirectory() as td:
        store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
        mgr = MockManager(store)
        # torchft.Manager surface (audited from torchft/local_sgd.py DiLoCo path)
        required_attrs = (
            "allreduce",
            "should_commit",
            "start_quorum",
            "wait_quorum",
            "current_step",
            "disallow_state_dict_read",
            "allow_state_dict_read",
            "register_state_dict_fn",
            "_use_async_quorum",
        )
        missing = [name for name in required_attrs if not hasattr(mgr, name)]
        assert not missing, f"MockManager is missing Manager attributes: {missing}"
        assert mgr._use_async_quorum is False
        assert mgr.num_participants == 1
        assert mgr.rank == 0
        assert mgr.should_commit() is True

        # Single-replica allreduce: averaging is a passthrough, but the return
        # must be a Work-shaped object (DiLoCo calls .wait() on it). The
        # tensor itself is mutated in place by ObjectStoreAllReduce.
        t = torch.tensor([1.0, 2.0])
        buf = t.clone()
        work = mgr.allreduce(buf)
        assert hasattr(work, "wait") and callable(work.wait)
        assert work.wait() is True
        torch.testing.assert_close(buf, t, atol=1e-6, rtol=1e-6)


# ---------------------------------------------------------------------
# Public re-export surface (Wave 17a)
# ---------------------------------------------------------------------


def test_public_reexports_include_all_executors():
    """`from composer_replication.diloco.serverless import …` must surface
    every executor adapter the module's docstring claims, not just the
    LocalProcessExecutor.

    Wave 16's user-journey reviewer caught that ModalExecutor /
    HFJobsExecutor were defined in `modal.py` / `hf_jobs.py` but not
    re-exported from the package's `__init__.py`. Users who copied the
    docstring's `from composer_replication.diloco.serverless import
    ModalExecutor` line got an ImportError. Wave 17a added the missing
    re-exports; this test pins them.
    """
    import composer_replication.diloco.serverless as ss

    expected = {
        "LocalProcessExecutor",
        "ModalExecutor",
        "HFJobsExecutor",
        "MockManager",
        "ObjectStoreAllReduce",
        "ReplicaHandle",
        "ServerlessExecutor",
    }
    actual = set(ss.__all__)
    assert expected.issubset(actual), (
        f"Missing re-exports: {sorted(expected - actual)}. "
        f"__all__ should include every executor adapter the package "
        f"docstring documents."
    )
    # Also verify each name is actually importable, not just listed.
    for name in sorted(expected):
        assert hasattr(ss, name), (
            f"{name} listed in __all__ but not present on package."
        )