Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ObjectStoreAllReduce — fsspec-backed pseudo-gradient exchange for DiLoCo. | |
| DiLoCo's outer-loop sync writes the local pseudo-gradient (= θ_initial − θ_local) | |
| to a shared location once per H ≈ 500-1000 inner steps, then averages across | |
| all replicas before the outer SGD step. With cross-job NCCL unavailable on | |
| most serverless backends, we use object storage as the rendezvous medium. | |
| Communication pattern per outer round: | |
| 1. Each replica writes its pseudo-gradient: PUT(rendezvous/round_N/rank_R.pt) | |
| 2. Each replica reads all peer pseudo-gradients: GET × N | |
| 3. Average locally → applied as `Manager.allreduce()` would have. | |
| Backend support via fsspec: s3://, gs://, az://, hf://, file://. | |
| The same code path works across all of them. | |
| License compatibility: this module re-implements the contract of | |
| `torchft.Manager.allreduce` through duck-typing — no torchft code is | |
| copied. torchft itself is BSD-3. | |
| """ | |
| from __future__ import annotations | |
| import io | |
| import os | |
| import time | |
| from typing import Any | |
| import torch | |
class ObjectStoreAllReduce:
    """fsspec-backed pseudo-gradient rendezvous.

    Each call to `allreduce(tensor, name)` blocks until all peers have
    written their version of `tensor` to the rendezvous location, then
    returns the average.

    Files are laid out under `uri` as
    ``round_{round_id:06d}/rank_{rank:04d}.pt``; every replica writes its
    own file and then reads the files of all ranks (its own included).

    Args:
        uri: fsspec URI like "s3://bucket/path/" or "file:///tmp/diloco/" or
            a plain path "/tmp/diloco/run42/" (treated as file://).
        rank: this replica's rank (0-indexed)
        world_size: total number of replicas
        round_id: optional, used to namespace successive sync rounds.
            If None, a monotonically increasing counter is used internally.
        timeout_s: per-allreduce timeout in seconds.
        poll_interval_s: how often to check for peer files.

    Raises:
        ValueError: if `rank` is outside ``[0, world_size)``.
        RuntimeError: if a non-local `uri` is given and fsspec is missing.
    """

    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None:
        if not (0 <= rank < world_size):
            raise ValueError(f"rank {rank} not in [0, {world_size})")
        # Normalise to exactly one trailing slash so relative paths can be
        # appended by plain concatenation.
        self.uri = uri.rstrip("/") + "/"
        self.rank = rank
        self.world_size = world_size
        self.timeout_s = timeout_s
        self.poll_interval_s = poll_interval_s
        self._round_counter = 0 if round_id is None else round_id
        # Lazy fsspec init; deferred so that local-only smoke tests don't
        # require fsspec install in the dev environment.
        self._fs = None
        self._is_local = self.uri.startswith(("file://", "/"))
        if self._is_local:
            local_path = self.uri.removeprefix("file://")
            os.makedirs(local_path, exist_ok=True)
            self._local_root = local_path
        else:
            self._init_fsspec()

    def _init_fsspec(self) -> None:
        """Create the fsspec filesystem for the URI's protocol.

        Raises:
            RuntimeError: if fsspec is not installed (chained to the
                original ImportError).
        """
        try:
            import fsspec
        except ImportError as e:
            # Chain the cause so the traceback shows which import failed.
            raise RuntimeError(
                "Non-local rendezvous requires fsspec; install with "
                "`pip install -e .[serverless]`. Got: " + repr(e)
            ) from e
        protocol = self.uri.split("://", 1)[0] if "://" in self.uri else "file"
        self._fs = fsspec.filesystem(protocol)

    def round_id(self) -> int:
        """Return the round id the next `allreduce` call will use."""
        return self._round_counter

    def _round_dir(self, round_id: int) -> str:
        """Return the relative directory name for `round_id`."""
        return f"round_{round_id:06d}"

    def _path_for(self, round_id: int, rank: int) -> str:
        """Return the relative path of `rank`'s payload for `round_id`."""
        return f"{self._round_dir(round_id)}/rank_{rank:04d}.pt"

    def _put(self, relpath: str, payload: bytes) -> None:
        """Write `payload` to `relpath` under the rendezvous root."""
        if self._is_local:
            full = os.path.join(self._local_root, relpath)
            os.makedirs(os.path.dirname(full), exist_ok=True)
            # Write to a sibling temp file and rename, so a polling peer
            # never observes a partially written payload.
            tmp = full + ".tmp"
            with open(tmp, "wb") as f:
                f.write(payload)
            os.replace(tmp, full)  # atomic on POSIX
        else:
            full = self.uri + relpath
            assert self._fs is not None
            with self._fs.open(full, "wb") as f:
                f.write(payload)

    def _get(self, relpath: str) -> bytes:
        """Read and return the bytes stored at `relpath`."""
        if self._is_local:
            full = os.path.join(self._local_root, relpath)
            with open(full, "rb") as f:
                return f.read()
        full = self.uri + relpath
        assert self._fs is not None
        with self._fs.open(full, "rb") as f:
            return f.read()

    def _exists(self, relpath: str) -> bool:
        """Return True if `relpath` currently exists under the root."""
        if self._is_local:
            return os.path.exists(os.path.join(self._local_root, relpath))
        full = self.uri + relpath
        assert self._fs is not None
        # This method is called in a polling loop. Drop any cached
        # directory listing first so a listing taken before the peer's
        # write cannot keep reporting the file as missing.
        self._fs.invalidate_cache(full)
        return self._fs.exists(full)

    def allreduce(self, tensor: torch.Tensor, *, name: str | None = None) -> torch.Tensor:
        """Average `tensor` across all replicas via the object store.

        Args:
            tensor: the tensor to average. Modified in-place AND returned.
            name: ignored — provided for API compat with torchft.Manager.

        Returns:
            The averaged tensor (modifies in-place; returns the same object).

        Raises:
            TimeoutError: if some rank's payload has not appeared within
                `timeout_s` seconds of this replica publishing its own.
        """
        # The counter advances on entry, so a failed round is not reused.
        round_id = self._round_counter
        self._round_counter += 1

        # Serialize my tensor
        buf = io.BytesIO()
        torch.save({"rank": self.rank, "tensor": tensor.detach().cpu()}, buf)
        my_path = self._path_for(round_id, self.rank)
        self._put(my_path, buf.getvalue())

        # Wait for all peers. A monotonic clock is used so wall-clock
        # adjustments cannot shorten or extend the wait.
        deadline = time.monotonic() + self.timeout_s
        peer_tensors: list[torch.Tensor] = []
        for peer_rank in range(self.world_size):
            peer_path = self._path_for(round_id, peer_rank)
            while not self._exists(peer_path):
                if time.monotonic() > deadline:
                    raise TimeoutError(
                        f"ObjectStoreAllReduce: timed out waiting for "
                        f"rank {peer_rank} at {self.uri}{peer_path} "
                        f"(world_size={self.world_size}, round={round_id})"
                    )
                time.sleep(self.poll_interval_s)
            payload = self._get(peer_path)
            # SECURITY: the payload comes from shared storage, so it is
            # decoded with the restricted unpickler. The format written
            # above (a dict holding an int and a tensor) needs nothing
            # beyond what weights_only=True permits.
            peer_data = torch.load(io.BytesIO(payload), weights_only=True)
            peer_tensors.append(peer_data["tensor"].to(tensor.device, dtype=tensor.dtype))

        # Compute average
        stacked = torch.stack(peer_tensors, dim=0)
        avg = stacked.mean(dim=0)
        # no_grad lets the in-place write succeed even when `tensor` is a
        # leaf that requires grad.
        with torch.no_grad():
            tensor.copy_(avg)
        return tensor
| # --------------------------------------------------------------------- | |
| # MockManager — torchft.Manager-shaped object that uses ObjectStoreAllReduce | |
| # --------------------------------------------------------------------- | |
class _ImmediateWork:
    """Stand-in for a distributed ``Work`` handle whose result is ready.

    The wrapped tensor has already been averaged by the time this object
    is built, so there is nothing left to wait for: `wait()` returns
    immediately and `get_future()` hands back an already-resolved future.

    This class intentionally does not inherit from
    `torch.distributed._Work`, which keeps the module importable where
    torch was built without full distributed support.
    """

    __slots__ = ("_tensor",)

    def __init__(self, tensor: torch.Tensor) -> None:
        # Kept only so get_future() has a result to resolve with.
        self._tensor = tensor

    def wait(self, *_args: Any, **_kwargs: Any) -> bool:
        """Return True at once; any arguments are accepted and ignored."""
        return True

    def get_future(self) -> Any:
        """Return a future already resolved with the tensor, or None.

        Purely defensive: if building the future fails for any reason,
        None is returned instead of raising.
        """
        try:
            from torch.futures import Future

            resolved = Future()
            resolved.set_result(self._tensor)
        except Exception:  # pragma: no cover — defensive only
            return None
        return resolved
class MockManager:
    """Drop-in replacement for `torchft.Manager` that delegates allreduce
    to `ObjectStoreAllReduce`.

    The torchft `DiLoCo` class accepts a `Manager` and calls its `.allreduce`
    method on the pseudo-gradient. By passing this mock instead, we route
    that call through the object store, leaving the rest of the DiLoCo
    machinery (sign convention, post-hook sequencing, etc.) untouched.

    Reference: `make_diloco_outer_loop` in
    `composer_replication/diloco/__init__.py` accepts an optional
    `manager=` kwarg; pass a `MockManager` to enable serverless DiLoCo.

    torchft.Manager surface audited from
    ``torchft/local_sgd.py`` (DiLoCo + _StreamingDiLoCoFragment paths) and
    ``torchft/manager.py``. Methods/attributes DiLoCo touches:

    * ``allreduce(tensor, should_quantize=...) -> Work`` — must return an
      object with ``.wait()`` (DiLoCo calls ``work.wait()`` in
      ``perform_sync``).
    * ``should_commit() -> bool`` — gates the outer-optimizer step.
    * ``start_quorum()`` — called once per outer round, before
      ``prepare_sync``.
    * ``current_step() -> int`` — used to pick the streaming-DiLoCo
      fragment for this round (``step % len(fragments)``).
    * ``disallow_state_dict_read()`` / ``allow_state_dict_read()`` —
      called every inner step from the optimizer pre/post hooks.
    * ``register_state_dict_fn(key, load_fn, save_fn)`` — called once
      per fragment from ``DiLoCo.__init__``.
    * ``_use_async_quorum`` (attribute) — DiLoCo's constructor refuses
      to start if this is truthy. Must exist and be False.
    * ``num_participants`` / ``rank`` — read by upstream callers.

    The lifecycle methods ``start_quorum`` and ``should_commit`` accept and
    ignore any extra arguments, in the same way ``allreduce`` ignores extra
    keyword arguments, so callers that pass Manager options do not fail.

    Args:
        store: the object-store rendezvous that performs the averaging;
            its ``world_size`` and ``rank`` are mirrored onto this object.
    """

    def __init__(self, store: ObjectStoreAllReduce) -> None:
        self._store = store
        # torchft Manager attributes that DiLoCo consults at construction time
        # or in user code paths.
        self.num_participants = store.world_size
        self.rank = store.rank
        # DiLoCo.__init__ raises if this is truthy (line 622 of
        # torchft/local_sgd.py). Object-store sync is synchronous → False.
        self._use_async_quorum: bool = False
        # Mirror the upstream Manager's monotonic step counter. DiLoCo reads
        # this via current_step() to decide which fragment to sync each round.
        # Bumped from start_quorum() so it advances exactly once per outer round.
        self._step: int = 0
        # State-dict-fn registry: torchft uses this for fault-tolerant
        # checkpoint restore. We're single-shot serverless — record but never
        # invoke. Tests can introspect this dict to confirm registration.
        self._state_dict_fns: dict[str, tuple[Any, Any]] = {}

    # ---- Core collective ------------------------------------------------

    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> _ImmediateWork:
        """Average `tensor` in place via the store; return a finished Work.

        Extra keyword arguments are accepted and ignored.
        """
        # DiLoCo expects a Work-like return value (it stores it in a list
        # then calls .wait() later). Object-store all-reduce is synchronous,
        # so the tensor is already averaged when we hand back the wrapper.
        averaged = self._store.allreduce(tensor)
        return _ImmediateWork(averaged)

    # ---- Quorum / commit lifecycle -------------------------------------

    def should_commit(self, *_args: Any, **_kwargs: Any) -> bool:
        """Always return True; any arguments are accepted and ignored."""
        # No fault-tolerance failover in serverless mode: every quorum
        # always commits. Replica failure is handled by the orchestration
        # layer (HF Jobs / Modal restart), not by DiLoCo skipping a round.
        # NOTE(review): the upstream method is believed to take an optional
        # timeout — confirm against torchft/manager.py.
        return True

    def start_quorum(self, *_args: Any, **_kwargs: Any) -> None:
        """Advance the step counter by one; arguments are ignored."""
        # The upstream Manager bumps its step counter inside the quorum
        # bookkeeping. Do the same so current_step() advances per round
        # and DiLoCo's fragment-rotation math matches across replicas.
        # NOTE(review): upstream options (heal/shrink/timeout style
        # arguments) have no meaning here — confirm the exact names
        # against torchft/manager.py.
        self._step += 1

    def wait_quorum(self) -> int:
        """Return the participant count; there is nothing to wait for."""
        return self.num_participants

    # ---- Step counter ---------------------------------------------------

    def current_step(self) -> int:
        """Return how many times `start_quorum` has been called."""
        return self._step

    # ---- State-dict read gating ----------------------------------------
    # torchft uses these to make checkpoint restore thread-safe. In a
    # single-process serverless mock there's no concurrent reader, so they
    # are no-ops — but they MUST exist (DiLoCo's pre/post optimizer hooks
    # call them on every inner step).

    def allow_state_dict_read(self) -> None:
        """No-op; present only to satisfy the Manager interface."""

    def disallow_state_dict_read(self) -> None:
        """No-op; present only to satisfy the Manager interface."""

    # ---- Checkpoint hook registry --------------------------------------

    def register_state_dict_fn(
        self,
        key: str,
        load_fn: Any,
        save_fn: Any,
    ) -> None:
        """Record the ``(load_fn, save_fn)`` pair under `key`.

        A later registration with the same key overwrites the earlier one.
        """
        # DiLoCo registers one (load, save) pair per fragment so torchft can
        # checkpoint the outer-optimizer state and original-parameter backup.
        # In serverless mode we capture the registration so tests can verify
        # it happened, but never invoke it — there's no HA failover.
        self._state_dict_fns[key] = (load_fn, save_fn)

    # ---- Convenience ----------------------------------------------------

    def is_leader(self) -> bool:
        """Return True when this replica is rank 0."""
        # Not strictly required by DiLoCo but referenced in some torchft
        # integrations / our own code that may swap MockManager in.
        return self.rank == 0
| __all__ = ["MockManager", "ObjectStoreAllReduce", "_ImmediateWork"] | |