Codeseys's picture
Wave 14: close every Wave 13 review finding + 4 documentation files; Wave 14b: real PRIME-RL parity + multi-process DiLoCo convergence
d9dd3a5
"""ObjectStoreAllReduce — fsspec-backed pseudo-gradient exchange for DiLoCo.
DiLoCo's outer-loop sync writes the local pseudo-gradient (= θ_initial − θ_local)
to a shared location once per H ≈ 500-1000 inner steps, then averages across
all replicas before the outer SGD step. With cross-job NCCL unavailable on
most serverless backends, we use object storage as the rendezvous medium.
Communication pattern per outer round:
1. Each replica writes its pseudo-gradient: PUT(rendezvous/round_N/rank_R.pt)
2. Each replica reads all peer pseudo-gradients: GET × N
3. Average the N tensors locally and copy the result into the caller's tensor, as `Manager.allreduce()` would.
Backend support via fsspec: s3://, gs://, az://, hf://, file://.
The same code path works across all of them.
License compatibility: this module re-implements the contract of
`torchft.Manager.allreduce` through duck-typing — no torchft code is
copied. torchft itself is BSD-3.
"""
from __future__ import annotations
import io
import os
import time
from typing import Any
import torch
class ObjectStoreAllReduce:
    """fsspec-backed pseudo-gradient rendezvous.

    Each call to `allreduce(tensor, name)` blocks until all peers have
    written their version of `tensor` to the rendezvous location, then
    returns the average.

    Files are laid out as ``<uri>/round_<NNNNNN>/rank_<RRRR>.pt``; each one
    holds a ``torch.save``-d ``{"rank": int, "tensor": Tensor}`` dict.
    Round directories are never removed. If a ``uri`` is reused while it
    still holds files from an earlier run with the same round numbers,
    ``allreduce`` will read those stale files, so give each run a fresh
    rendezvous location.

    Args:
        uri: fsspec URI like "s3://bucket/path/" or "file:///tmp/diloco/" or
            a plain path "/tmp/diloco/run42/" (treated as file://).
        rank: this replica's rank (0-indexed).
        world_size: total number of replicas.
        round_id: optional starting round number, used to namespace
            successive sync rounds. If None, rounds start at 0. The counter
            advances by one on every `allreduce` call.
        timeout_s: per-allreduce timeout in seconds.
        poll_interval_s: how often to check for peer files.

    Raises:
        ValueError: if ``rank`` is not in ``[0, world_size)``.
        RuntimeError: for a non-local ``uri`` when fsspec is not installed.
    """

    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None:
        if not (0 <= rank < world_size):
            raise ValueError(f"rank {rank} not in [0, {world_size})")
        # Exactly one trailing slash, so relative paths can be appended with
        # plain string concatenation.
        self.uri = uri.rstrip("/") + "/"
        self.rank = rank
        self.world_size = world_size
        self.timeout_s = timeout_s
        self.poll_interval_s = poll_interval_s
        self._round_counter = 0 if round_id is None else round_id
        # Lazy fsspec init; deferred so that local-only smoke tests don't
        # require fsspec install in the dev environment.
        self._fs: Any = None
        self._is_local = self.uri.startswith(("file://", "/"))
        if self._is_local:
            local_path = self.uri.removeprefix("file://")
            os.makedirs(local_path, exist_ok=True)
            self._local_root = local_path
        else:
            self._init_fsspec()

    def _init_fsspec(self) -> None:
        """Create the fsspec filesystem matching ``self.uri``'s protocol."""
        try:
            import fsspec
        except ImportError as e:
            raise RuntimeError(
                "Non-local rendezvous requires fsspec; install with "
                "`pip install -e .[serverless]`. Got: " + repr(e)
            ) from e
        # URIs without a scheme (e.g. relative paths) use fsspec's local
        # filesystem.
        protocol = self.uri.split("://", 1)[0] if "://" in self.uri else "file"
        self._fs = fsspec.filesystem(protocol)

    @property
    def round_id(self) -> int:
        """Round number the next `allreduce` call will use."""
        return self._round_counter

    def _round_dir(self, round_id: int) -> str:
        """Relative directory holding every rank's file for ``round_id``."""
        return f"round_{round_id:06d}"

    def _path_for(self, round_id: int, rank: int) -> str:
        """Relative path of ``rank``'s payload for ``round_id``."""
        return f"{self._round_dir(round_id)}/rank_{rank:04d}.pt"

    def _put(self, relpath: str, payload: bytes) -> None:
        """Write ``payload`` to ``relpath`` under the rendezvous root."""
        if self._is_local:
            full = os.path.join(self._local_root, relpath)
            os.makedirs(os.path.dirname(full), exist_ok=True)
            # Write then rename, so a polling peer never observes a
            # partially written file.
            tmp = full + ".tmp"
            with open(tmp, "wb") as f:
                f.write(payload)
            os.replace(tmp, full)  # atomic on POSIX
        else:
            # NOTE(review): written in place. This relies on the backend
            # making the object visible only once the upload completes
            # (true of S3/GCS-style PUTs); confirm for other fsspec backends.
            full = self.uri + relpath
            assert self._fs is not None
            with self._fs.open(full, "wb") as f:
                f.write(payload)

    def _get(self, relpath: str) -> bytes:
        """Read the raw bytes stored at ``relpath``."""
        if self._is_local:
            full = os.path.join(self._local_root, relpath)
            with open(full, "rb") as f:
                return f.read()
        full = self.uri + relpath
        assert self._fs is not None
        with self._fs.open(full, "rb") as f:
            return f.read()

    def _exists(self, relpath: str) -> bool:
        """Return whether ``relpath`` exists under the rendezvous root."""
        if self._is_local:
            return os.path.exists(os.path.join(self._local_root, relpath))
        full = self.uri + relpath
        assert self._fs is not None
        return self._fs.exists(full)

    def _wait_for(
        self, relpath: str, peer_rank: int, round_id: int, deadline: float
    ) -> None:
        """Poll until ``relpath`` exists; raise TimeoutError once past ``deadline``.

        ``deadline`` is a ``time.monotonic()`` timestamp.
        """
        while not self._exists(relpath):
            if time.monotonic() > deadline:
                raise TimeoutError(
                    f"ObjectStoreAllReduce: timed out waiting for "
                    f"rank {peer_rank} at {self.uri}{relpath} "
                    f"(world_size={self.world_size}, round={round_id})"
                )
            time.sleep(self.poll_interval_s)

    def _load_tensor(self, relpath: str) -> torch.Tensor:
        """Download a peer's payload and return its tensor (on CPU)."""
        payload = self._get(relpath)
        # The rendezvous location is shared storage that other principals may
        # be able to write. torch's restricted unpickler accepts the
        # {"rank": int, "tensor": Tensor} dict written by `allreduce` but
        # refuses arbitrary pickled objects; a full unpickle could run code.
        peer_data = torch.load(io.BytesIO(payload), weights_only=True)
        return peer_data["tensor"]

    def allreduce(self, tensor: torch.Tensor, *, name: str | None = None) -> torch.Tensor:
        """Average `tensor` across all replicas via the object store.

        Args:
            tensor: the tensor to average. Modified in-place AND returned.
            name: ignored — provided for API compat with torchft.Manager.

        Returns:
            The averaged tensor (modifies in-place; returns the same object).

        Raises:
            TimeoutError: if a peer's file has not appeared within
                ``timeout_s`` seconds of the call.
            RuntimeError: if a peer's tensor shape differs from ``tensor``'s.
                In both error cases ``tensor`` is left unmodified.
        """
        round_id = self._round_counter
        self._round_counter += 1

        # Publish my contribution. Copy into fresh, compact CPU storage:
        # torch.save serializes the *entire* storage behind a view, so saving
        # a slice of a larger buffer as-is would upload the whole buffer.
        local = tensor.detach()
        buf = io.BytesIO()
        torch.save({"rank": self.rank, "tensor": local.to("cpu", copy=True)}, buf)
        self._put(self._path_for(round_id, self.rank), buf.getvalue())

        # Keep one running sum instead of stacking world_size full copies,
        # and accumulate in at least float32 so bf16/fp16 inputs keep their
        # precision. Every replica sums in rank order, so all replicas
        # compute identical results. The sum lives in a separate buffer, so
        # `tensor` is untouched until the final copy_.
        acc_dtype = torch.promote_types(tensor.dtype, torch.float32)
        total = torch.zeros_like(tensor, dtype=acc_dtype)
        # Monotonic clock: immune to wall-clock adjustments (NTP etc.).
        deadline = time.monotonic() + self.timeout_s
        for peer_rank in range(self.world_size):
            if peer_rank == self.rank:
                # Bit-identical to the payload just written; skip the GET.
                contribution = local
            else:
                peer_path = self._path_for(round_id, peer_rank)
                self._wait_for(peer_path, peer_rank, round_id, deadline)
                contribution = self._load_tensor(peer_path)
            # Explicit check: in-place add would silently broadcast a
            # mismatched (but broadcastable) peer tensor.
            if contribution.shape != tensor.shape:
                raise RuntimeError(
                    f"ObjectStoreAllReduce: rank {peer_rank} sent shape "
                    f"{tuple(contribution.shape)}, expected "
                    f"{tuple(tensor.shape)} (round={round_id})"
                )
            total.add_(contribution.to(tensor.device))

        tensor.copy_(total.div_(self.world_size))
        return tensor
# ---------------------------------------------------------------------
# MockManager — torchft.Manager-shaped object that uses ObjectStoreAllReduce
# ---------------------------------------------------------------------
class _ImmediateWork:
"""Work-shaped wrapper for an already-completed allreduce.
`torchft.Manager.allreduce` returns a `torch.distributed.Work` (or
`torchft.work._DummyWork`) which DiLoCo calls `.wait()` on inside
`_StreamingDiLoCoFragment.perform_sync`. Our `ObjectStoreAllReduce`
is synchronous — by the time it returns, the average is already in
the tensor — so `.wait()` is a no-op.
We deliberately don't subclass `torch.distributed._Work` to keep this
module importable in environments without a full torch distributed
build; DiLoCo only does `work.wait()`, nothing more.
"""
__slots__ = ("_tensor",)
def __init__(self, tensor: torch.Tensor) -> None:
self._tensor = tensor
def wait(self, *_args: Any, **_kwargs: Any) -> bool:
return True
def get_future(self) -> Any:
# Torch >=2.x sometimes calls Work.get_future(); provide a satisfied
# future so callers don't crash. We only need to be defensive here;
# DiLoCo itself doesn't call this.
try:
import torch.futures as _f
fut = _f.Future()
fut.set_result(self._tensor)
return fut
except Exception: # pragma: no cover — defensive only
return None
class MockManager:
    """``torchft.Manager`` look-alike that routes allreduce through
    `ObjectStoreAllReduce`.

    torchft's `DiLoCo` class takes a `Manager` and calls its ``.allreduce``
    on the pseudo-gradient. Passing this object instead sends that call
    through the object store. The rest of the DiLoCo machinery (sign
    convention, post-hook sequencing, etc.) runs unchanged.

    Reference: `make_diloco_outer_loop` in
    `composer_replication/diloco/__init__.py` takes an optional
    ``manager=`` kwarg; pass a `MockManager` there to enable serverless
    DiLoCo.

    The Manager surface that DiLoCo touches was audited from
    ``torchft/local_sgd.py`` (DiLoCo + _StreamingDiLoCoFragment paths) and
    ``torchft/manager.py``:

    * ``allreduce(tensor, should_quantize=...) -> Work`` — the result must
      have ``.wait()``; DiLoCo calls ``work.wait()`` in ``perform_sync``.
    * ``should_commit() -> bool`` — gates the outer-optimizer step.
    * ``start_quorum()`` — called once per outer round, before
      ``prepare_sync``.
    * ``current_step() -> int`` — selects the streaming-DiLoCo fragment for
      the round (``step % len(fragments)``).
    * ``disallow_state_dict_read()`` / ``allow_state_dict_read()`` —
      called on every inner step from the optimizer pre/post hooks.
    * ``register_state_dict_fn(key, load_fn, save_fn)`` — called once per
      fragment from ``DiLoCo.__init__``.
    * ``_use_async_quorum`` (attribute) — DiLoCo's constructor refuses to
      start when it is truthy, so it must exist and be False.
    * ``num_participants`` / ``rank`` — read by upstream callers.
    """

    def __init__(self, store: ObjectStoreAllReduce) -> None:
        self._store = store
        # Manager attributes consulted by DiLoCo at construction time or by
        # user code. NOTE(review): these are plain attributes; confirm no
        # caller invokes num_participants as a method.
        self.num_participants = store.world_size
        self.rank = store.rank
        # DiLoCo.__init__ rejects a truthy value here. Object-store sync is
        # synchronous, so there is no async quorum to speak of.
        self._use_async_quorum: bool = False
        # Step counter returned by current_step(). It is advanced in
        # start_quorum(), i.e. exactly once per outer round.
        self._step: int = 0
        # key -> (load_fn, save_fn). Recorded for introspection, never
        # invoked: serverless mode has no checkpoint-restore failover.
        self._state_dict_fns: dict[str, tuple[Any, Any]] = {}

    # ---- Core collective ------------------------------------------------

    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> _ImmediateWork:
        """Average ``tensor`` across replicas now and return a finished Work.

        DiLoCo keeps the returned handle and calls ``.wait()`` later. The
        store call is synchronous, so the tensor is already averaged when the
        wrapper is returned. Extra torchft keyword arguments are ignored.
        """
        return _ImmediateWork(self._store.allreduce(tensor))

    # ---- Quorum / commit lifecycle -------------------------------------

    def should_commit(self) -> bool:
        """Always True: every quorum commits in serverless mode.

        There is no fault-tolerant failover here. Replica failure is handled
        by the orchestration layer (HF Jobs / Modal restart), not by DiLoCo
        skipping a round.
        """
        return True

    def start_quorum(self) -> None:
        """Advance the step counter once for the new outer round.

        Every replica advances at the same point, so DiLoCo's
        fragment-rotation math agrees across replicas. NOTE(review): confirm
        this matches where upstream torchft advances its own step.
        """
        self._step = self._step + 1

    def wait_quorum(self) -> int:
        """Return the participant count; the quorum is always complete."""
        return self.num_participants

    # ---- Step counter ---------------------------------------------------

    def current_step(self) -> int:
        """Return how many times start_quorum() has been called."""
        return self._step

    # ---- State-dict read gating ----------------------------------------
    # torchft uses this pair to make checkpoint restore safe against
    # concurrent readers. The mock has no such reader, so both are no-ops,
    # but DiLoCo's optimizer hooks call them every inner step, so they must
    # exist.

    def allow_state_dict_read(self) -> None:
        """No-op counterpart of ``Manager.allow_state_dict_read``."""

    def disallow_state_dict_read(self) -> None:
        """No-op counterpart of ``Manager.disallow_state_dict_read``."""

    # ---- Checkpoint hook registry --------------------------------------

    def register_state_dict_fn(
        self,
        key: str,
        load_fn: Any,
        save_fn: Any,
    ) -> None:
        """Record a (load, save) pair under ``key`` without ever calling it.

        DiLoCo registers one pair per fragment so torchft can checkpoint the
        outer-optimizer state and the original-parameter backup. Storing the
        pair lets tests verify the registration happened.
        """
        self._state_dict_fns[key] = load_fn, save_fn

    # ---- Convenience ----------------------------------------------------

    def is_leader(self) -> bool:
        """Return True for rank 0.

        DiLoCo does not need this; it is here for integrations that may swap
        a MockManager in.
        """
        return self.rank == 0
# Public surface of this module. `_ImmediateWork` is listed despite its
# leading underscore, so star-imports bring it in as well.
__all__ = [
    "MockManager",
    "ObjectStoreAllReduce",
    "_ImmediateWork",
]