Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 11,732 Bytes
b266c31 |
"""ServerlessExecutor Protocol + LocalProcessExecutor.
Per ADR-005:
- `ServerlessExecutor` is a structural Protocol — backends implement it
by writing a class with the right methods, no formal inheritance needed.
- `LocalProcessExecutor` is the reference implementation that uses Python's
`multiprocessing` module. It's used for tests and for development; the
cloud adapters (Modal, HF Jobs, …) implement the same Protocol against
their respective backends.
"""
from __future__ import annotations
import multiprocessing as mp
import sys
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Mapping, Protocol, runtime_checkable
@dataclass
class ReplicaHandle:
    """Handle identifying one launched replica.

    Every executor's `launch_replicas` hands back a `list[ReplicaHandle]`.
    `rank` and `backend_name` mean the same thing on every backend;
    `metadata` is backend-specific and callers should not depend on its
    layout.
    """

    # Position of this replica within its launch (0-based).
    rank: int
    # Name of the executor backend that launched the replica.
    backend_name: str
    # Backend-specific payload (for example a Modal call ID, an HF Jobs job
    # ID, or the local executor's pid/start timestamp). Not stable across
    # backends; a fresh dict is created per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class ServerlessExecutor(Protocol):
    """Structural interface shared by every serverless replica launcher.

    Any class exposing the attributes and methods below satisfies this
    Protocol; no subclassing is required.

    Known implementations: `LocalProcessExecutor` (tests and development),
    `ModalExecutor` (Modal, v0) and `HFJobsExecutor` (HuggingFace Jobs, v0).
    Planned: `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`.

    Rank assignment: `launch_replicas` returns its handles ordered by rank,
    so `handles[i].rank == i`. A running replica discovers its own rank
    either from an environment variable (`REPLICA_RANK` unless overridden
    via `rank_env`) or from a backend-native mechanism such as Modal's
    `Function.shard_rank`. Executors normalize this by setting the
    environment variable.
    """

    # Short identifier of the backend (e.g. "local_process").
    backend_name: str
    # Whether replicas launched by this backend can reach one another over
    # the network.
    supports_inter_replica_network: bool

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Start `n_replicas` replicas in parallel.

        Args:
            n_replicas: How many replicas to start.
            entrypoint: Dotted import path of the code to run (for example
                "composer_replication.diloco.serverless.replica_entrypoint"),
                or a callable; callables are accepted by the local
                executor only.
            entrypoint_args: Keyword arguments for the entrypoint. The
                special key `rank_env` (default "REPLICA_RANK") selects the
                environment variable that receives the replica's rank.
            gpu: Backend-specific accelerator name such as "A100" or
                "H100"; `None` requests CPU only (smoke tests).
            timeout: Wall-clock limit for each replica, in seconds.

        Returns:
            `n_replicas` handles, ordered by rank.
        """
        ...

    def poll(self, handle: ReplicaHandle) -> str:
        """Report a replica's current state without blocking.

        The result is one of "pending", "running", "succeeded", "failed"
        or "cancelled".
        """
        ...

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Fetch at most `n_lines` of the replica's recent stdout/stderr."""
        ...

    def cancel(self, handle: ReplicaHandle) -> None:
        """Stop a replica on a best-effort basis.

        Must not raise when the replica has already terminated.
        """
        ...

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Wait for every replica in `handles` to finish.

        Returns one result dict per handle, each carrying at least the keys
        "rank" (int), "status" (str), "exit_code" (int | None) and
        "error" (str | None).
        """
        ...
# ---------------------------------------------------------------------
# LocalProcessExecutor — reference implementation using multiprocessing
# ---------------------------------------------------------------------
def _local_replica_target(
rank: int,
rank_env: str,
entrypoint: Any,
entrypoint_args: Mapping[str, Any],
result_queue: mp.Queue,
) -> None:
"""multiprocessing target — runs in the child process."""
import os
import traceback
os.environ[rank_env] = str(rank)
try:
if callable(entrypoint):
result = entrypoint(**entrypoint_args)
elif isinstance(entrypoint, str):
# importable path
mod_path, _, fn_name = entrypoint.rpartition(".")
if not mod_path:
# Top-level script path; just import it and call its main()
import importlib
mod = importlib.import_module(entrypoint)
fn = getattr(mod, "main", None)
if fn is None:
raise RuntimeError(
f"entrypoint '{entrypoint}' has no main() function"
)
result = fn(**entrypoint_args)
else:
import importlib
mod = importlib.import_module(mod_path)
fn = getattr(mod, fn_name)
result = fn(**entrypoint_args)
else:
raise TypeError(
f"entrypoint must be Callable or importable str, got {type(entrypoint)!r}"
)
result_queue.put({"rank": rank, "status": "succeeded",
"exit_code": 0, "error": None, "result": result})
except Exception as e:
tb = traceback.format_exc()
result_queue.put({"rank": rank, "status": "failed",
"exit_code": 1, "error": f"{e!r}\n{tb}", "result": None})
class LocalProcessExecutor:
    """Runs replicas as subprocesses on the local machine.

    Use cases:
    - Test the serverless layer end-to-end without cloud spend.
    - Develop the algorithm locally with N>1 replicas and `file://`
      rendezvous before deploying to Modal/HF Jobs.
    - CI smoke tests.

    Implementation notes:
    - Every replica gets a private result queue. A child that has put a
      result on a `multiprocessing.Queue` does not exit until the parent
      has read it, so `poll`, `stream_logs`, `cancel` and `collect` drain
      the queue *before* waiting on or inspecting the process. A private
      queue also keeps a forcibly terminated replica from corrupting the
      results of the others.
    - The per-replica `timeout` passed to `launch_replicas` is enforced
      lazily: `poll` and `collect` terminate a replica found past its
      deadline and report it as "failed".
    - Bookkeeping is keyed by rank, so a second `launch_replicas` call on
      the same executor replaces the entries of the ranks it launches.
    """

    backend_name = "local_process"
    supports_inter_replica_network = True  # localhost works

    # Pause between status sweeps while `collect` waits, in seconds.
    _POLL_INTERVAL_S = 0.05
    # Wait after terminate() (and again after kill()) before giving up, in seconds.
    _TERMINATE_GRACE_S = 5

    def __init__(self) -> None:
        # use 'spawn' so the child has a fresh interpreter (avoid CUDA fork issues)
        try:
            self._ctx = mp.get_context("spawn")
        except ValueError:
            # Fallback for environments where 'spawn' isn't available
            self._ctx = mp.get_context()
        # rank -> {"rank", "proc", "queue", "deadline", "result",
        #          "cancelled", "timed_out"}
        self._handles: dict[int, dict[str, Any]] = {}

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Start `n_replicas` child processes, each running `entrypoint`.

        `entrypoint_args["rank_env"]` (default "REPLICA_RANK") names the
        environment variable that receives the rank. It is removed before
        the remaining kwargs are passed to the entrypoint. `gpu` is ignored,
        with a warning on stderr. `timeout` sets each replica's wall-clock
        deadline, counted from its start.

        Returns:
            One handle per replica, in rank order; `metadata` holds the
            child's `pid` and `start_ts`.
        """
        if gpu is not None:
            # Local executor doesn't pin GPUs; emit a soft warning.
            sys.stderr.write(
                f"[LocalProcessExecutor] gpu={gpu!r} ignored — "
                f"local processes share whatever GPUs are visible.\n"
            )
        rank_env = entrypoint_args.get("rank_env", "REPLICA_RANK")
        # Same kwargs for every rank; `rank_env` configures the executor and
        # is not forwarded to the entrypoint.
        child_args = {k: v for k, v in entrypoint_args.items() if k != "rank_env"}
        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            result_queue = self._ctx.Queue()
            proc = self._ctx.Process(
                target=_local_replica_target,
                args=(rank, rank_env, entrypoint, child_args, result_queue),
                name=f"composer-replica-{rank}",
            )
            proc.start()
            start_ts = time.time()
            self._handles[rank] = {
                "rank": rank,
                "proc": proc,
                "queue": result_queue,
                "deadline": start_ts + timeout,
                "result": None,
                "cancelled": False,
                "timed_out": False,
            }
            handles.append(
                ReplicaHandle(
                    rank=rank,
                    backend_name=self.backend_name,
                    metadata={"pid": proc.pid, "start_ts": start_ts},
                )
            )
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Return the replica's status without waiting for it.

        Returns "running" while the child is alive and within its deadline.
        A replica found past its launch deadline is terminated and reported
        as "failed". Unknown ranks report "cancelled".
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return "cancelled"
        # Read first: a child blocked flushing its result to the queue stays
        # alive until the parent consumes it.
        self._drain(meta)
        if meta["proc"].is_alive() and not self._expire_if_overdue(meta, time.time()):
            return "running"
        return self._result_dict(handle.rank, meta)["status"]

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Return the last `n_lines` lines of the replica's reported error.

        multiprocessing.Process doesn't natively capture stdout (that would
        need a Pipe or file redirection). For this local reference
        implementation, the only "log" is the `error` text of the replica's
        result dict, capped at its last 2000 characters.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return f"<replica {handle.rank}: no metadata>"
        self._drain(meta)
        if meta.get("result"):
            err = meta["result"].get("error") or ""
            # Guard n_lines <= 0: a `[-0:]` slice would return every line.
            tail = "\n".join(err.splitlines()[-n_lines:]) if n_lines > 0 else ""
            return f"[rank {handle.rank}] {tail[-2000:]}"
        proc = meta["proc"]
        if proc.is_alive():
            return f"<replica {handle.rank}: still running, no captured logs>"
        return f"<replica {handle.rank}: exited with code {proc.exitcode}, no captured logs>"

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort stop of a replica; a no-op if it already exited.

        A result the replica already delivered is kept, so cancelling a
        finished replica still reports its real outcome.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return
        self._drain(meta)
        proc = meta["proc"]
        if proc.is_alive():
            meta["cancelled"] = True
            self._terminate(proc)

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Wait for all replicas to finish and return one result per handle.

        Each replica may run until its own launch deadline, further capped
        by `timeout` seconds from now when `timeout` is given. Replicas still
        alive past that limit are terminated and reported as "failed" with
        a "timed out" error. Queues are drained while waiting, so large
        results cannot block the children from exiting.

        Returns:
            Result dicts in the order of `handles`, each with "rank",
            "status", "exit_code", "error" and "result". Handles this
            executor does not know report status "cancelled".
        """
        overall_deadline = None if timeout is None else time.time() + timeout
        waiting = [
            meta
            for meta in (self._handles.get(h.rank) for h in handles)
            if meta is not None
        ]
        while waiting:
            now = time.time()
            still_running = []
            for meta in waiting:
                self._drain(meta)
                if meta["proc"].is_alive() and not self._expire_if_overdue(
                    meta, now, overall_deadline
                ):
                    still_running.append(meta)
            waiting = still_running
            if waiting:
                time.sleep(self._POLL_INTERVAL_S)

        results: list[dict[str, Any]] = []
        for h in handles:
            meta = self._handles.get(h.rank)
            if meta is None:
                results.append({
                    "rank": h.rank, "status": "cancelled",
                    "exit_code": None, "error": "no metadata", "result": None,
                })
            else:
                results.append(self._result_dict(h.rank, meta))
        return results

    # -- private helpers ---------------------------------------------------

    def _drain(self, meta: dict[str, Any]) -> None:
        """Move any readable result from the replica's queue into `meta["result"]`."""
        import queue as queue_mod

        # After terminate()/kill() the child may have died mid-write. Reading
        # a partially written message could wait forever for missing bytes,
        # so the queue of a force-stopped replica is never read again.
        if meta.get("cancelled") or meta.get("timed_out"):
            return
        result_queue = meta["queue"]
        while True:
            try:
                item = result_queue.get_nowait()
            except queue_mod.Empty:
                return
            except Exception as e:  # e.g. a result that cannot be unpickled here
                sys.stderr.write(
                    f"[LocalProcessExecutor] could not read result of replica "
                    f"{meta.get('rank')}: {e!r}\n"
                )
                return
            meta["result"] = item

    def _expire_if_overdue(
        self,
        meta: dict[str, Any],
        now: float,
        extra_deadline: float | None = None,
    ) -> bool:
        """Terminate the replica if `now` is past its deadline; return True if it did."""
        deadline = meta["deadline"]
        if extra_deadline is not None:
            deadline = min(deadline, extra_deadline)
        if now < deadline:
            return False
        meta["timed_out"] = True
        self._terminate(meta["proc"])
        return True

    def _terminate(self, proc: Any) -> None:
        """Terminate `proc`, escalating to kill() if it ignores terminate()."""
        proc.terminate()
        proc.join(timeout=self._TERMINATE_GRACE_S)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=self._TERMINATE_GRACE_S)

    def _result_dict(self, rank: int, meta: dict[str, Any]) -> dict[str, Any]:
        """Build the result dict for a finished (or force-stopped) replica."""
        self._drain(meta)
        if meta["result"] is not None:
            # Copy so callers cannot mutate the executor's bookkeeping.
            return dict(meta["result"])
        code = meta["proc"].exitcode
        if meta.get("cancelled"):
            return {"rank": rank, "status": "cancelled", "exit_code": code,
                    "error": "cancelled", "result": None}
        if meta.get("timed_out"):
            return {"rank": rank, "status": "failed", "exit_code": code,
                    "error": f"timed out (exit code {code})", "result": None}
        ok = code == 0
        return {
            "rank": rank,
            "status": "succeeded" if ok else "failed",
            "exit_code": code,
            "error": None if ok else f"exit code {code}",
            "result": None,
        }
# Public API of this module.
__all__ = [
    "LocalProcessExecutor",
    "ReplicaHandle",
    "ServerlessExecutor",
]
|