Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 13,400 Bytes
"""ObjectStoreAllReduce — fsspec-backed pseudo-gradient exchange for DiLoCo.
DiLoCo's outer-loop sync writes the local pseudo-gradient (= θ_initial − θ_local)
to a shared location once per H ≈ 500-1000 inner steps, then averages across
all replicas before the outer SGD step. With cross-job NCCL unavailable on
most serverless backends, we use object storage as the rendezvous medium.
Communication pattern per outer round:
1. Each replica writes its pseudo-gradient: PUT(rendezvous/round_N/rank_R.pt)
2. Each replica reads all peer pseudo-gradients: GET × N
3. Average locally, producing the same result `Manager.allreduce()` would have.
Backend support via fsspec: s3://, gs://, az://, hf://, file://.
The same code path works across all of them.
License compatibility: this module re-implements the contract of
`torchft.Manager.allreduce` through duck-typing — no torchft code is
copied. torchft itself is BSD-3.
"""
from __future__ import annotations
import io
import os
import time
from typing import Any
import torch
class ObjectStoreAllReduce:
    """fsspec-backed pseudo-gradient rendezvous.

    Each call to `allreduce(tensor, name)` writes this replica's copy of
    `tensor` to the rendezvous location. It then blocks until every peer has
    written its copy for the same round. Finally it overwrites `tensor` in
    place with the element-wise average over all `world_size` replicas.

    Rendezvous layout, relative to `uri`::

        round_000000/rank_0000.pt
        round_000000/rank_0001.pt
        ...

    Args:
        uri: fsspec URI like "s3://bucket/path/" or "file:///tmp/diloco/".
            May also be a plain filesystem path, absolute or relative, such as
            "/tmp/diloco/run42/"; this is handled locally like file://.
        rank: this replica's rank (0-indexed).
        world_size: total number of replicas.
        round_id: optional starting round id, used to namespace successive
            sync rounds. If None, an internal counter starting at 0 is used.
            The counter advances by one per `allreduce` call.
        timeout_s: per-allreduce timeout in seconds. It covers the wait for
            all peers together, not each peer separately.
        poll_interval_s: how often to check for peer files.

    Raises:
        ValueError: if `rank` is not in ``[0, world_size)``.
        RuntimeError: for a non-local `uri` when fsspec is not installed.

    NOTE(review): round files are never deleted. Reusing one rendezvous
    location across runs with overlapping round ids would read stale peer
    files, so give each run its own `uri` or starting `round_id`.
    """

    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None:
        if not (0 <= rank < world_size):
            raise ValueError(f"rank {rank} not in [0, {world_size})")
        self.uri = uri.rstrip("/") + "/"
        self.rank = rank
        self.world_size = world_size
        self.timeout_s = timeout_s
        self.poll_interval_s = poll_interval_s
        self._round_counter = 0 if round_id is None else round_id
        # Lazy fsspec init. It is deferred so that local-only runs and smoke
        # tests don't require fsspec in the environment.
        self._fs = None
        # Anything without a URL scheme is a plain filesystem path (absolute
        # or relative). Only scheme-qualified, non-file:// URIs need fsspec.
        self._is_local = self.uri.startswith("file://") or "://" not in self.uri
        if self._is_local:
            local_path = self.uri.removeprefix("file://")
            os.makedirs(local_path, exist_ok=True)
            self._local_root = local_path
        else:
            self._init_fsspec()

    def _init_fsspec(self) -> None:
        """Create the fsspec filesystem for this URI's protocol (e.g. "s3")."""
        try:
            import fsspec
        except ImportError as e:
            raise RuntimeError(
                "Non-local rendezvous requires fsspec; install with "
                "`pip install -e .[serverless]`. Got: " + repr(e)
            ) from e
        protocol = self.uri.split("://", 1)[0] if "://" in self.uri else "file"
        self._fs = fsspec.filesystem(protocol)

    @property
    def round_id(self) -> int:
        """Round id the next `allreduce` call will use."""
        return self._round_counter

    def _round_dir(self, round_id: int) -> str:
        """Relative directory holding every rank's file for `round_id`."""
        return f"round_{round_id:06d}"

    def _path_for(self, round_id: int, rank: int) -> str:
        """Relative path of `rank`'s payload for `round_id`."""
        return f"{self._round_dir(round_id)}/rank_{rank:04d}.pt"

    def _put(self, relpath: str, payload: bytes) -> None:
        """Write `payload` to `relpath` under the rendezvous root."""
        if self._is_local:
            full = os.path.join(self._local_root, relpath)
            os.makedirs(os.path.dirname(full), exist_ok=True)
            # Write to a temp file, then rename, so that peers polling
            # _exists() never observe a partially written file.
            tmp = full + ".tmp"
            with open(tmp, "wb") as f:
                f.write(payload)
            os.replace(tmp, full)  # atomic on POSIX
        else:
            # NOTE(review): relies on the backend only making the object
            # visible once the write completes. Confirm for each fsspec
            # backend in use.
            full = self.uri + relpath
            assert self._fs is not None
            with self._fs.open(full, "wb") as f:
                f.write(payload)

    def _get(self, relpath: str) -> bytes:
        """Read the raw bytes stored at `relpath` under the rendezvous root."""
        if self._is_local:
            full = os.path.join(self._local_root, relpath)
            with open(full, "rb") as f:
                return f.read()
        full = self.uri + relpath
        assert self._fs is not None
        with self._fs.open(full, "rb") as f:
            return f.read()

    def _exists(self, relpath: str) -> bool:
        """Return whether `relpath` exists under the rendezvous root."""
        if self._is_local:
            return os.path.exists(os.path.join(self._local_root, relpath))
        full = self.uri + relpath
        assert self._fs is not None
        return self._fs.exists(full)

    def _load_peer(self, round_id: int, peer_rank: int, deadline: float) -> torch.Tensor:
        """Wait for `peer_rank`'s payload for `round_id`, then load its tensor.

        `deadline` is a `time.monotonic()` timestamp.

        Raises:
            TimeoutError: if the file has not appeared by `deadline`.
            RuntimeError: if the payload was written by a different rank.
        """
        peer_path = self._path_for(round_id, peer_rank)
        while not self._exists(peer_path):
            if time.monotonic() > deadline:
                raise TimeoutError(
                    f"ObjectStoreAllReduce: timed out waiting for "
                    f"rank {peer_rank} at {self.uri}{peer_path} "
                    f"(world_size={self.world_size}, round={round_id})"
                )
            time.sleep(self.poll_interval_s)
        payload = self._get(peer_path)
        # The payload is only {"rank": int, "tensor": Tensor}, so restrict
        # unpickling to tensors and primitive containers. The bucket is
        # shared, so arbitrary pickles must not be executed.
        peer_data = torch.load(io.BytesIO(payload), weights_only=True)
        if peer_data.get("rank") != peer_rank:
            raise RuntimeError(
                f"ObjectStoreAllReduce: {self.uri}{peer_path} holds rank "
                f"{peer_data.get('rank')!r}, expected {peer_rank}"
            )
        return peer_data["tensor"]

    def allreduce(self, tensor: torch.Tensor, *, name: str | None = None) -> torch.Tensor:
        """Average `tensor` across all replicas via the object store.

        Args:
            tensor: the tensor to average. It is modified in place and also
                returned.
            name: ignored; provided for API compat with torchft.Manager.

        Returns:
            The averaged tensor. This is the same object as `tensor`.

        Raises:
            TimeoutError: if some peer does not publish within `timeout_s`.
            RuntimeError: if a peer's tensor shape or recorded rank does
                not match.
        """
        round_id = self._round_counter
        self._round_counter += 1

        # Serialize a compact CPU copy. torch.save writes a tensor's entire
        # underlying storage, so saving a view directly (for example a slice
        # of a flat parameter buffer) would upload the whole buffer.
        # `copy=True` always allocates fresh, tensor-sized storage.
        buf = io.BytesIO()
        torch.save(
            {"rank": self.rank, "tensor": tensor.detach().to("cpu", copy=True)},
            buf,
        )
        self._put(self._path_for(round_id, self.rank), buf.getvalue())

        # Keep a running sum rather than stacking world_size copies. Accumulate
        # in at least float32 so fp16/bf16 pseudo-gradients don't lose
        # precision while summing.
        acc_dtype = torch.promote_types(tensor.dtype, torch.float32)
        total = torch.zeros(tensor.shape, dtype=acc_dtype, device=tensor.device)

        deadline = time.monotonic() + self.timeout_s
        # Every replica sums in rank order, so all replicas compute
        # identical results.
        for peer_rank in range(self.world_size):
            peer_tensor = self._load_peer(round_id, peer_rank, deadline)
            # Explicit check: add_ would otherwise broadcast a mis-shaped
            # peer silently.
            if peer_tensor.shape != tensor.shape:
                raise RuntimeError(
                    f"ObjectStoreAllReduce: rank {peer_rank} sent shape "
                    f"{tuple(peer_tensor.shape)}, expected {tuple(tensor.shape)} "
                    f"(round={round_id})"
                )
            total.add_(peer_tensor.to(tensor.device, dtype=acc_dtype))

        total.div_(self.world_size)
        tensor.copy_(total)  # casts back to tensor.dtype
        return tensor
# ---------------------------------------------------------------------
# MockManager — torchft.Manager-shaped object that uses ObjectStoreAllReduce
# ---------------------------------------------------------------------
class _ImmediateWork:
"""Work-shaped wrapper for an already-completed allreduce.
`torchft.Manager.allreduce` returns a `torch.distributed.Work` (or
`torchft.work._DummyWork`) which DiLoCo calls `.wait()` on inside
`_StreamingDiLoCoFragment.perform_sync`. Our `ObjectStoreAllReduce`
is synchronous — by the time it returns, the average is already in
the tensor — so `.wait()` is a no-op.
We deliberately don't subclass `torch.distributed._Work` to keep this
module importable in environments without a full torch distributed
build; DiLoCo only does `work.wait()`, nothing more.
"""
__slots__ = ("_tensor",)
def __init__(self, tensor: torch.Tensor) -> None:
self._tensor = tensor
def wait(self, *_args: Any, **_kwargs: Any) -> bool:
return True
def get_future(self) -> Any:
# Torch >=2.x sometimes calls Work.get_future(); provide a satisfied
# future so callers don't crash. We only need to be defensive here;
# DiLoCo itself doesn't call this.
try:
import torch.futures as _f
fut = _f.Future()
fut.set_result(self._tensor)
return fut
except Exception: # pragma: no cover — defensive only
return None
class MockManager:
    """Drop-in replacement for `torchft.Manager` that delegates allreduce
    to `ObjectStoreAllReduce`.

    The torchft `DiLoCo` class accepts a `Manager` and calls its `.allreduce`
    method on the pseudo-gradient. Passing this mock instead routes that call
    through the object store. The rest of the DiLoCo machinery (sign
    convention, post-hook sequencing, etc.) is left untouched.

    Reference: `make_diloco_outer_loop` in
    `composer_replication/diloco/__init__.py` accepts an optional
    `manager=` kwarg; pass a `MockManager` to enable serverless DiLoCo.

    torchft.Manager surface audited from
    ``torchft/local_sgd.py`` (DiLoCo + _StreamingDiLoCoFragment paths) and
    ``torchft/manager.py``. Methods/attributes DiLoCo touches:

      * ``allreduce(tensor, should_quantize=...) -> Work``. Must return an
        object with ``.wait()``; DiLoCo calls ``work.wait()`` in
        ``perform_sync``.
      * ``should_commit() -> bool``. Gates the outer-optimizer step.
      * ``start_quorum()``. Called once per outer round, before
        ``prepare_sync``.
      * ``current_step() -> int``. Used to pick the streaming-DiLoCo
        fragment for this round (``step % len(fragments)``).
      * ``disallow_state_dict_read()`` / ``allow_state_dict_read()``.
        Called every inner step from the optimizer pre/post hooks.
      * ``register_state_dict_fn(key, load_fn, save_fn)``. Called once
        per fragment from ``DiLoCo.__init__``.
      * ``_use_async_quorum`` (attribute). DiLoCo's constructor refuses
        to start if this is truthy, so it must exist and be False.
      * ``num_participants`` / ``rank``. Read by upstream callers.

    The lifecycle methods accept and ignore extra positional/keyword
    arguments, as `allreduce` already does. This lets callers pass the
    optional arguments the real Manager exposes (e.g. ``timeout=``) without
    breaking.

    NOTE(review): here ``num_participants`` is a plain int attribute. Confirm
    whether the pinned torchft version exposes it as a method on the real
    Manager before relying on code that calls it.
    """

    def __init__(self, store: ObjectStoreAllReduce) -> None:
        self._store = store
        # torchft Manager attributes that DiLoCo consults at construction
        # time or in user code paths.
        self.num_participants = store.world_size
        self.rank = store.rank
        # DiLoCo.__init__ raises if this is truthy (see torchft/local_sgd.py).
        # Object-store sync is synchronous, so this is always False.
        self._use_async_quorum: bool = False
        # Step counter reported by current_step(); DiLoCo uses it to choose
        # which fragment to sync each round. It advances once per
        # start_quorum() call, so replicas that run the same number of outer
        # rounds agree on it.
        self._step: int = 0
        # (load_fn, save_fn) pairs keyed by name. Recorded by
        # register_state_dict_fn() so tests can introspect them; never
        # invoked here.
        self._state_dict_fns: dict[str, tuple[Any, Any]] = {}

    # ---- Core collective ------------------------------------------------

    def allreduce(self, tensor: torch.Tensor, *_args: Any, **_kwargs: Any) -> _ImmediateWork:
        """Average `tensor` in place across replicas; return a finished Work.

        Extra arguments (e.g. ``should_quantize``) are accepted and ignored.
        """
        # DiLoCo stores the returned Work and calls .wait() later. The
        # object-store all-reduce is synchronous, so the tensor is already
        # averaged when the wrapper is handed back.
        averaged = self._store.allreduce(tensor)
        return _ImmediateWork(averaged)

    # ---- Quorum / commit lifecycle -------------------------------------

    def should_commit(self, *_args: Any, **_kwargs: Any) -> bool:
        """Always True: no fault-tolerance failover in serverless mode."""
        # Replica failure is handled by the orchestration layer
        # (HF Jobs / Modal restart), not by DiLoCo skipping a round.
        return True

    def start_quorum(self, *_args: Any, **_kwargs: Any) -> None:
        """Begin an outer round by advancing the step counter."""
        # NOTE(review): upstream torchft may advance its counter at a
        # different point in the round (e.g. on commit). Every MockManager
        # advances identically, so replicas still agree on fragment rotation.
        self._step += 1

    def wait_quorum(self, *_args: Any, **_kwargs: Any) -> int:
        """Return the fixed participant count; there is nothing to wait for."""
        return self.num_participants

    # ---- Step counter ---------------------------------------------------

    def current_step(self) -> int:
        """Number of start_quorum() calls made so far."""
        return self._step

    # ---- State-dict read gating ----------------------------------------
    # torchft uses these to make checkpoint restore thread-safe. This mock
    # never restores, so they are no-ops. They must still exist because
    # DiLoCo's pre/post optimizer hooks call them on every inner step.

    def allow_state_dict_read(self) -> None:
        pass

    def disallow_state_dict_read(self) -> None:
        pass

    # ---- Checkpoint hook registry --------------------------------------

    def register_state_dict_fn(
        self,
        key: str,
        load_fn: Any,
        save_fn: Any,
    ) -> None:
        """Record a (load_fn, save_fn) pair under `key` without invoking it."""
        # DiLoCo registers one pair per fragment so torchft can checkpoint
        # the outer-optimizer state and the original-parameter backup. There
        # is no HA failover here, so the pair is only recorded.
        self._state_dict_fns[key] = (load_fn, save_fn)

    # ---- Convenience ----------------------------------------------------

    def is_leader(self) -> bool:
        """True for rank 0."""
        # Not strictly required by DiLoCo, but referenced in some torchft
        # integrations and in our own code that may swap MockManager in.
        return self.rank == 0
# Public surface of this module. `_ImmediateWork` is listed explicitly even
# though it is underscore-prefixed.
__all__ = [
    "MockManager",
    "ObjectStoreAllReduce",
    "_ImmediateWork",
]
|