"""ObjectStoreAllReduce — fsspec-backed pseudo-gradient exchange for DiLoCo. DiLoCo's outer-loop sync writes the local pseudo-gradient (= θ_initial − θ_local) to a shared location once per H ≈ 500-1000 inner steps, then averages across all replicas before the outer SGD step. With cross-job NCCL unavailable on most serverless backends, we use object storage as the rendezvous medium. Communication pattern per outer round: 1. Each replica writes its pseudo-gradient: PUT(rendezvous/round_N/rank_R.pt) 2. Each replica reads all peer pseudo-gradients: GET × N 3. Average locally → applied as `Manager.allreduce()` would have. Backend support via fsspec: s3://, gs://, az://, hf://, file://. The same code path works across all of them. License compatibility: this module re-implements the contract of `torchft.Manager.allreduce` through duck-typing — no torchft code is copied. torchft itself is BSD-3. """ from __future__ import annotations import io import os import time from typing import Any import torch class ObjectStoreAllReduce: """fsspec-backed pseudo-gradient rendezvous. Each call to `allreduce(tensor, name)` blocks until all peers have written their version of `tensor` to the rendezvous location, then returns the average. Args: uri: fsspec URI like "s3://bucket/path/" or "file:///tmp/diloco/" or a plain path "/tmp/diloco/run42/" (treated as file://). rank: this replica's rank (0-indexed) world_size: total number of replicas round_id: optional, used to namespace successive sync rounds. If None, a monotonically increasing counter is used internally. timeout_s: per-allreduce timeout in seconds. poll_interval_s: how often to check for peer files. """ def __init__( self, uri: str, rank: int, world_size: int, *, round_id: int | None = None, timeout_s: float = 1800.0, poll_interval_s: float = 1.0, ) -> None: if not (0 <= rank < world_size): raise ValueError(f"rank {rank} not in [0, {world_size})") self.uri = uri.rstrip("/") + "/" self.rank = rank self.world_size = world_size self.timeout_s = timeout_s self.poll_interval_s = poll_interval_s self._round_counter = 0 if round_id is None else round_id # Lazy fsspec init; deferred so that local-only smoke tests don't # require fsspec install in the dev environment. self._fs = None self._is_local = self.uri.startswith("file://") or self.uri.startswith("/") if self._is_local: local_path = self.uri.removeprefix("file://") os.makedirs(local_path, exist_ok=True) self._local_root = local_path else: self._init_fsspec() def _init_fsspec(self) -> None: try: import fsspec # noqa: F401 except ImportError as e: raise RuntimeError( "Non-local rendezvous requires fsspec; install with " "`pip install -e .[serverless]`. Got: " + repr(e) ) import fsspec protocol = self.uri.split("://", 1)[0] if "://" in self.uri else "file" self._fs = fsspec.filesystem(protocol) @property def round_id(self) -> int: return self._round_counter def _round_dir(self, round_id: int) -> str: return f"round_{round_id:06d}" def _path_for(self, round_id: int, rank: int) -> str: return f"{self._round_dir(round_id)}/rank_{rank:04d}.pt" def _put(self, relpath: str, payload: bytes) -> None: if self._is_local: full = os.path.join(self._local_root, relpath) os.makedirs(os.path.dirname(full), exist_ok=True) tmp = full + ".tmp" with open(tmp, "wb") as f: f.write(payload) os.replace(tmp, full) # atomic on POSIX else: full = self.uri + relpath assert self._fs is not None with self._fs.open(full, "wb") as f: f.write(payload) def _get(self, relpath: str) -> bytes: if self._is_local: full = os.path.join(self._local_root, relpath) with open(full, "rb") as f: return f.read() full = self.uri + relpath assert self._fs is not None with self._fs.open(full, "rb") as f: return f.read() def _exists(self, relpath: str) -> bool: if self._is_local: return os.path.exists(os.path.join(self._local_root, relpath)) full = self.uri + relpath assert self._fs is not None return self._fs.exists(full) def allreduce(self, tensor: torch.Tensor, *, name: str | None = None) -> torch.Tensor: """Average `tensor` across all replicas via the object store. Args: tensor: the tensor to average. Modified in-place AND returned. name: ignored — provided for API compat with torchft.Manager. Returns: The averaged tensor (modifies in-place; returns the same object). """ round_id = self._round_counter self._round_counter += 1 # Serialize my tensor buf = io.BytesIO() torch.save({"rank": self.rank, "tensor": tensor.detach().cpu()}, buf) my_path = self._path_for(round_id, self.rank) self._put(my_path, buf.getvalue()) # Wait for all peers deadline = time.time() + self.timeout_s peer_tensors: list[torch.Tensor] = [] for peer_rank in range(self.world_size): peer_path = self._path_for(round_id, peer_rank) while not self._exists(peer_path): if time.time() > deadline: raise TimeoutError( f"ObjectStoreAllReduce: timed out waiting for " f"rank {peer_rank} at {self.uri}{peer_path} " f"(world_size={self.world_size}, round={round_id})" ) time.sleep(self.poll_interval_s) payload = self._get(peer_path) peer_data = torch.load(io.BytesIO(payload), weights_only=False) peer_tensors.append(peer_data["tensor"].to(tensor.device, dtype=tensor.dtype)) # Compute average stacked = torch.stack(peer_tensors, dim=0) avg = stacked.mean(dim=0) tensor.copy_(avg) return tensor # --------------------------------------------------------------------- # MockManager — torchft.Manager-shaped object that uses ObjectStoreAllReduce # --------------------------------------------------------------------- class _ImmediateWork: """Work-shaped wrapper for an already-completed allreduce. `torchft.Manager.allreduce` returns a `torch.distributed.Work` (or `torchft.work._DummyWork`) which DiLoCo calls `.wait()` on inside `_StreamingDiLoCoFragment.perform_sync`. Our `ObjectStoreAllReduce` is synchronous — by the time it returns, the average is already in the tensor — so `.wait()` is a no-op. We deliberately don't subclass `torch.distributed._Work` to keep this module importable in environments without a full torch distributed build; DiLoCo only does `work.wait()`, nothing more. """ __slots__ = ("_tensor",) def __init__(self, tensor: torch.Tensor) -> None: self._tensor = tensor def wait(self, *_args: Any, **_kwargs: Any) -> bool: return True def get_future(self) -> Any: # Torch >=2.x sometimes calls Work.get_future(); provide a satisfied # future so callers don't crash. We only need to be defensive here; # DiLoCo itself doesn't call this. try: import torch.futures as _f fut = _f.Future() fut.set_result(self._tensor) return fut except Exception: # pragma: no cover — defensive only return None class MockManager: """Drop-in replacement for `torchft.Manager` that delegates allreduce to `ObjectStoreAllReduce`. The torchft `DiLoCo` class accepts a `Manager` and calls its `.allreduce` method on the pseudo-gradient. By passing this mock instead, we route that call through the object store, leaving the rest of the DiLoCo machinery (sign convention, post-hook sequencing, etc.) untouched. Reference: `make_diloco_outer_loop` in `composer_replication/diloco/__init__.py` accepts an optional `manager=` kwarg; pass a `MockManager` to enable serverless DiLoCo. torchft.Manager surface audited from ``torchft/local_sgd.py`` (DiLoCo + _StreamingDiLoCoFragment paths) and ``torchft/manager.py``. Methods/attributes DiLoCo touches: * ``allreduce(tensor, should_quantize=...) -> Work`` — must return an object with ``.wait()`` (DiLoCo calls ``work.wait()`` in ``perform_sync``). * ``should_commit() -> bool`` — gates the outer-optimizer step. * ``start_quorum()`` — called once per outer round, before ``prepare_sync``. * ``current_step() -> int`` — used to pick the streaming-DiLoCo fragment for this round (``step % len(fragments)``). * ``disallow_state_dict_read()`` / ``allow_state_dict_read()`` — called every inner step from the optimizer pre/post hooks. * ``register_state_dict_fn(key, load_fn, save_fn)`` — called once per fragment from ``DiLoCo.__init__``. * ``_use_async_quorum`` (attribute) — DiLoCo's constructor refuses to start if this is truthy. Must exist and be False. * ``num_participants`` / ``rank`` — read by upstream callers. """ def __init__(self, store: ObjectStoreAllReduce) -> None: self._store = store # torchft Manager attributes that DiLoCo consults at construction time # or in user code paths. self.num_participants = store.world_size self.rank = store.rank # DiLoCo.__init__ raises if this is truthy (line 622 of # torchft/local_sgd.py). Object-store sync is synchronous → False. self._use_async_quorum: bool = False # Mirror the upstream Manager's monotonic step counter. DiLoCo reads # this via current_step() to decide which fragment to sync each round. # Bumped from start_quorum() so it advances exactly once per outer round. self._step: int = 0 # State-dict-fn registry: torchft uses this for fault-tolerant # checkpoint restore. We're single-shot serverless — record but never # invoke. Tests can introspect this dict to confirm registration. self._state_dict_fns: dict[str, tuple[Any, Any]] = {} # ---- Core collective ------------------------------------------------ def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> _ImmediateWork: # DiLoCo expects a Work-like return value (it stores it in a list # then calls .wait() later). Object-store all-reduce is synchronous, # so the tensor is already averaged when we hand back the wrapper. averaged = self._store.allreduce(tensor) return _ImmediateWork(averaged) # ---- Quorum / commit lifecycle ------------------------------------- def should_commit(self) -> bool: # No fault-tolerance failover in serverless mode: every quorum # always commits. Replica failure is handled by the orchestration # layer (HF Jobs / Modal restart), not by DiLoCo skipping a round. return True def start_quorum(self) -> None: # The upstream Manager bumps its step counter inside the quorum # bookkeeping. Do the same so current_step() advances per round # and DiLoCo's fragment-rotation math matches across replicas. self._step += 1 def wait_quorum(self) -> int: return self.num_participants # ---- Step counter --------------------------------------------------- def current_step(self) -> int: return self._step # ---- State-dict read gating ---------------------------------------- # torchft uses these to make checkpoint restore thread-safe. In a # single-process serverless mock there's no concurrent reader, so they # are no-ops — but they MUST exist (DiLoCo's pre/post optimizer hooks # call them on every inner step). def allow_state_dict_read(self) -> None: pass def disallow_state_dict_read(self) -> None: pass # ---- Checkpoint hook registry -------------------------------------- def register_state_dict_fn( self, key: str, load_fn: Any, save_fn: Any, ) -> None: # DiLoCo registers one (load, save) pair per fragment so torchft can # checkpoint the outer-optimizer state and original-parameter backup. # In serverless mode we capture the registration so tests can verify # it happened, but never invoke it — there's no HA failover. self._state_dict_fns[key] = (load_fn, save_fn) # ---- Convenience ---------------------------------------------------- def is_leader(self) -> bool: # Not strictly required by DiLoCo but referenced in some torchft # integrations / our own code that may swap MockManager in. return self.rank == 0 __all__ = ["MockManager", "ObjectStoreAllReduce", "_ImmediateWork"]