File size: 13,400 Bytes
b266c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b266c31
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b266c31
d9dd3a5
b266c31
 
d9dd3a5
 
b266c31
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b266c31
d9dd3a5
 
 
b266c31
 
 
d9dd3a5
 
 
 
b266c31
 
 
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b266c31
d9dd3a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
"""ObjectStoreAllReduce — fsspec-backed pseudo-gradient exchange for DiLoCo.

DiLoCo's outer-loop sync writes the local pseudo-gradient (= θ_initial − θ_local)
to a shared location once per H ≈ 500-1000 inner steps, then averages across
all replicas before the outer SGD step. With cross-job NCCL unavailable on
most serverless backends, we use object storage as the rendezvous medium.

Communication pattern per outer round:
1. Each replica writes its pseudo-gradient: PUT(rendezvous/round_N/rank_R.pt)
2. Each replica reads all peer pseudo-gradients: GET × N
3. Average locally, producing the same result `Manager.allreduce()` would have.

Backend support via fsspec: s3://, gs://, az://, hf://, file://.
The same code path works across all of them.

License compatibility: this module re-implements the contract of
`torchft.Manager.allreduce` through duck-typing — no torchft code is
copied. torchft itself is BSD-3.
"""
from __future__ import annotations

import io
import os
import time
from typing import Any

import torch


class ObjectStoreAllReduce:
    """fsspec-backed pseudo-gradient rendezvous.

    Each call to `allreduce(tensor, name)` writes this replica's copy of
    `tensor` to the rendezvous location, blocks until every peer has written
    its copy for the same round, then overwrites `tensor` with the mean.

    Args:
        uri: fsspec URI like "s3://bucket/path/" or "file:///tmp/diloco/", or
            a plain filesystem path such as "/tmp/diloco/run42/" or
            "diloco/run42" (treated as file://; relative paths resolve
            against the current working directory).
        rank: this replica's rank (0-indexed)
        world_size: total number of replicas
        round_id: optional starting round number, used to namespace
            successive sync rounds. If None, counting starts at 0. The
            counter advances by one on every `allreduce` call.
        timeout_s: per-allreduce timeout in seconds.
        poll_interval_s: how often to check for peer files.

    Raises:
        ValueError: if `rank` is not in ``[0, world_size)``.
        RuntimeError: if a non-local URI is given and fsspec is not installed.
    """

    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None:
        if not (0 <= rank < world_size):
            raise ValueError(f"rank {rank} not in [0, {world_size})")
        self.uri = uri.rstrip("/") + "/"
        self.rank = rank
        self.world_size = world_size
        self.timeout_s = timeout_s
        self.poll_interval_s = poll_interval_s
        self._round_counter = 0 if round_id is None else round_id

        # Lazy fsspec init; deferred so that local-only smoke tests don't
        # require fsspec install in the dev environment.
        self._fs = None
        # Anything without a scheme is a plain filesystem path (absolute OR
        # relative). Routing it through the local branch keeps writes atomic
        # (tmp file + os.replace) and avoids a needless fsspec dependency.
        self._is_local = self.uri.startswith("file://") or "://" not in self.uri
        if self._is_local:
            local_path = self.uri.removeprefix("file://")
            os.makedirs(local_path, exist_ok=True)
            self._local_root = local_path
        else:
            self._init_fsspec()

    def _init_fsspec(self) -> None:
        """Create the fsspec filesystem matching this URI's protocol prefix."""
        try:
            import fsspec
        except ImportError as e:
            raise RuntimeError(
                "Non-local rendezvous requires fsspec; install with "
                "`pip install -e .[serverless]`. Got: " + repr(e)
            ) from e
        protocol = self.uri.split("://", 1)[0] if "://" in self.uri else "file"
        self._fs = fsspec.filesystem(protocol)

    @property
    def round_id(self) -> int:
        """Round number the next `allreduce` call will use."""
        return self._round_counter

    def _round_dir(self, round_id: int) -> str:
        """Relative directory name holding all ranks' files for `round_id`."""
        return f"round_{round_id:06d}"

    def _path_for(self, round_id: int, rank: int) -> str:
        """Relative path of `rank`'s payload for `round_id`."""
        return f"{self._round_dir(round_id)}/rank_{rank:04d}.pt"

    def _put(self, relpath: str, payload: bytes) -> None:
        """Write `payload` to `relpath` under the rendezvous root."""
        if self._is_local:
            full = os.path.join(self._local_root, relpath)
            os.makedirs(os.path.dirname(full), exist_ok=True)
            tmp = full + ".tmp"
            with open(tmp, "wb") as f:
                f.write(payload)
            os.replace(tmp, full)  # atomic on POSIX
        else:
            full = self.uri + relpath
            assert self._fs is not None
            with self._fs.open(full, "wb") as f:
                f.write(payload)

    def _get(self, relpath: str) -> bytes:
        """Read the bytes stored at `relpath` under the rendezvous root."""
        if self._is_local:
            full = os.path.join(self._local_root, relpath)
            with open(full, "rb") as f:
                return f.read()
        full = self.uri + relpath
        assert self._fs is not None
        with self._fs.open(full, "rb") as f:
            return f.read()

    def _exists(self, relpath: str) -> bool:
        """Return True if `relpath` exists under the rendezvous root."""
        if self._is_local:
            return os.path.exists(os.path.join(self._local_root, relpath))
        full = self.uri + relpath
        assert self._fs is not None
        return self._fs.exists(full)

    def _load_peer_tensor(self, relpath: str, like: torch.Tensor) -> torch.Tensor:
        """Deserialize the tensor stored at `relpath`; require `like`'s shape."""
        payload = self._get(relpath)
        # weights_only=True limits unpickling to tensors and primitive
        # containers -- exactly what allreduce() writes -- so a tampered file
        # in a shared bucket cannot execute arbitrary code on load.
        peer = torch.load(io.BytesIO(payload), weights_only=True)["tensor"]
        if peer.shape != like.shape:
            # Without this check add_() would silently broadcast mismatched shapes.
            raise RuntimeError(
                f"ObjectStoreAllReduce: peer tensor at {self.uri}{relpath} has "
                f"shape {tuple(peer.shape)}, expected {tuple(like.shape)}"
            )
        return peer

    def allreduce(self, tensor: torch.Tensor, *, name: str | None = None) -> torch.Tensor:
        """Average `tensor` across all replicas via the object store.

        Publishes this rank's copy for the current round, then polls for every
        rank's file of the same round (including its own) and overwrites
        `tensor` with the element-wise mean.

        Args:
            tensor: the tensor to average. Modified in-place AND returned.
            name: ignored — provided for API compat with torchft.Manager.

        Returns:
            The averaged tensor (modifies in-place; returns the same object).

        Raises:
            TimeoutError: if some rank has not published within `timeout_s`.
            RuntimeError: if a peer's tensor shape differs from `tensor`'s.
        """
        round_id = self._round_counter
        self._round_counter += 1

        # Serialize my tensor
        buf = io.BytesIO()
        torch.save({"rank": self.rank, "tensor": tensor.detach().cpu()}, buf)
        self._put(self._path_for(round_id, self.rank), buf.getvalue())

        # Accumulate in at least fp32 so bf16/fp16 inputs don't lose precision
        # across peers. A running sum also avoids holding world_size copies of
        # the tensor at once, which torch.stack would.
        acc_dtype = torch.promote_types(tensor.dtype, torch.float32)
        total = torch.zeros(tensor.shape, dtype=acc_dtype, device=tensor.device)

        # monotonic() is immune to wall-clock jumps (NTP sync, manual changes).
        deadline = time.monotonic() + self.timeout_s
        for peer_rank in range(self.world_size):
            peer_path = self._path_for(round_id, peer_rank)
            while not self._exists(peer_path):
                if time.monotonic() > deadline:
                    raise TimeoutError(
                        f"ObjectStoreAllReduce: timed out waiting for "
                        f"rank {peer_rank} at {self.uri}{peer_path} "
                        f"(world_size={self.world_size}, round={round_id})"
                    )
                time.sleep(self.poll_interval_s)
            peer = self._load_peer_tensor(peer_path, tensor)
            total.add_(peer.to(device=tensor.device, dtype=acc_dtype))

        total.div_(self.world_size)
        tensor.copy_(total)  # copy_ casts back to tensor.dtype
        return tensor


# ---------------------------------------------------------------------
# MockManager — torchft.Manager-shaped object that uses ObjectStoreAllReduce
# ---------------------------------------------------------------------


class _ImmediateWork:
    """Work-shaped wrapper for an already-completed allreduce.

    `torchft.Manager.allreduce` returns a `torch.distributed.Work` (or
    `torchft.work._DummyWork`) which DiLoCo calls `.wait()` on inside
    `_StreamingDiLoCoFragment.perform_sync`. Our `ObjectStoreAllReduce`
    is synchronous — by the time it returns, the average is already in
    the tensor — so `.wait()` is a no-op.

    We deliberately don't subclass `torch.distributed._Work` to keep this
    module importable in environments without a full torch distributed
    build; DiLoCo only does `work.wait()`, nothing more.
    """

    __slots__ = ("_tensor",)

    def __init__(self, tensor: torch.Tensor) -> None:
        self._tensor = tensor

    def wait(self, *_args: Any, **_kwargs: Any) -> bool:
        return True

    def get_future(self) -> Any:
        # Torch >=2.x sometimes calls Work.get_future(); provide a satisfied
        # future so callers don't crash. We only need to be defensive here;
        # DiLoCo itself doesn't call this.
        try:
            import torch.futures as _f

            fut = _f.Future()
            fut.set_result(self._tensor)
            return fut
        except Exception:  # pragma: no cover — defensive only
            return None


class MockManager:
    """Stand-in for `torchft.Manager` whose allreduce goes through the object store.

    torchft's `DiLoCo` is handed a manager and invokes its `.allreduce` on the
    pseudo-gradient. Supplying this class instead reroutes that call to an
    `ObjectStoreAllReduce`, while the rest of DiLoCo (sign convention,
    post-hook ordering, etc.) runs unchanged.

    Reference: `make_diloco_outer_loop` in
    `composer_replication/diloco/__init__.py` takes an optional `manager=`
    keyword; pass a `MockManager` there to run DiLoCo serverlessly.

    Manager surface this mock covers (audited against ``torchft/local_sgd.py``,
    i.e. the DiLoCo and _StreamingDiLoCoFragment paths, and
    ``torchft/manager.py``):

    * ``allreduce(tensor, should_quantize=...) -> Work``: the result needs a
      ``.wait()``, since ``perform_sync`` calls ``work.wait()``.
    * ``should_commit() -> bool``: gates the outer-optimizer step.
    * ``start_quorum()``: invoked once per outer round, ahead of
      ``prepare_sync``.
    * ``current_step() -> int``: selects the streaming-DiLoCo fragment for
      the round (``step % len(fragments)``).
    * ``disallow_state_dict_read()`` / ``allow_state_dict_read()``: invoked
      on every inner step by the optimizer pre/post hooks.
    * ``register_state_dict_fn(key, load_fn, save_fn)``: invoked once per
      fragment from ``DiLoCo.__init__``.
    * ``_use_async_quorum`` (attribute): DiLoCo's constructor rejects a truthy
      value, so it must exist and be False.
    * ``num_participants`` / ``rank``: consumed by upstream callers.
    """

    def __init__(self, store: ObjectStoreAllReduce) -> None:
        self._store = store
        # Topology copied from the store so code inspecting a torchft-style
        # Manager sees matching values.
        # NOTE(review): upstream torchft may expose num_participants as a
        # method rather than an attribute -- confirm how callers read it.
        self.rank = store.rank
        self.num_participants = store.world_size
        # DiLoCo's constructor refuses a truthy value here (per the audit
        # notes above); object-store sync is synchronous, so always False.
        self._use_async_quorum: bool = False
        # Outer-round counter reported by current_step(); only start_quorum()
        # advances it.
        self._step: int = 0
        # key -> (load_fn, save_fn) as registered by DiLoCo. Kept so tests can
        # inspect registrations; never called, since there is no HA failover.
        self._state_dict_fns: dict[str, tuple[Any, Any]] = {}

    # ---- Core collective ------------------------------------------------
    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> _ImmediateWork:
        """Average `tensor` through the store and wrap it in a finished Work.

        Keyword arguments such as ``should_quantize`` are accepted and ignored.
        The store call returns only after `tensor` holds the average, so the
        returned handle's ``.wait()`` has nothing left to do.
        """
        return _ImmediateWork(self._store.allreduce(tensor))

    # ---- Quorum / commit lifecycle -------------------------------------
    def start_quorum(self) -> None:
        """Advance the outer-round counter by exactly one.

        Replicas that call this the same number of times agree on
        current_step(), which keeps DiLoCo's fragment rotation in lockstep.
        NOTE(review): upstream torchft may advance its step at a different
        point in the round (e.g. on commit) -- confirm fragment selection
        matches the real Manager.
        """
        self._step = self._step + 1

    def wait_quorum(self) -> int:
        """Return the participant count; there is never a quorum to wait on."""
        return self.num_participants

    def should_commit(self) -> bool:
        """Always True: serverless mode never skips a round.

        Replica failures are left to the orchestration layer (HF Jobs / Modal
        restarts) instead of DiLoCo-level failover.
        """
        return True

    # ---- Step counter ---------------------------------------------------
    def current_step(self) -> int:
        """Number of start_quorum() calls made so far."""
        return self._step

    # ---- State-dict read gating ----------------------------------------
    # torchft uses this pair to guard checkpoint reads against concurrent
    # access. With no concurrent reader in this mock they do nothing, yet
    # they must exist: DiLoCo's optimizer hooks call them on every step.
    def allow_state_dict_read(self) -> None:
        """No-op; there is no concurrent checkpoint reader to gate."""

    def disallow_state_dict_read(self) -> None:
        """No-op counterpart of allow_state_dict_read()."""

    # ---- Checkpoint hook registry --------------------------------------
    def register_state_dict_fn(
        self,
        key: str,
        load_fn: Any,
        save_fn: Any,
    ) -> None:
        """Record a (load_fn, save_fn) pair under `key` without calling it.

        DiLoCo registers one pair per fragment so torchft can checkpoint the
        outer-optimizer state and the original-parameter backup.
        """
        self._state_dict_fns.update({key: (load_fn, save_fn)})

    # ---- Convenience ----------------------------------------------------
    def is_leader(self) -> bool:
        """True for rank 0.

        DiLoCo does not need this; it exists for integrations that may swap
        in a MockManager where a leader check is expected.
        """
        return self.rank == 0


# Names exported by `from <module> import *`.
__all__ = [
    "MockManager",
    "ObjectStoreAllReduce",
    "_ImmediateWork",
]