Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ServerlessExecutor Protocol + LocalProcessExecutor. | |
| Per ADR-005: | |
| - `ServerlessExecutor` is a structural Protocol — backends implement it | |
| by writing a class with the right methods, no formal inheritance needed. | |
| - `LocalProcessExecutor` is the reference implementation that uses Python's | |
| `multiprocessing` module. It's used for tests and for development; the | |
| cloud adapters (Modal, HF Jobs, …) implement the same Protocol against | |
| their respective backends. | |
| """ | |
| from __future__ import annotations | |
| import multiprocessing as mp | |
| import sys | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Any, Callable, Mapping, Protocol, runtime_checkable | |
@dataclass
class ReplicaHandle:
    """Opaque handle to a launched replica. Backend-specific contents.

    All executors return `list[ReplicaHandle]` from `launch_replicas`, in
    rank order. Each handle's `metadata` dict is backend-specific; users
    shouldn't rely on its shape.
    """

    # Zero-based replica rank; for a fresh launch `handles[i].rank == i`.
    rank: int
    # Name of the executor backend that created this handle
    # (e.g. "local_process").
    backend_name: str
    # Backend-specific data (e.g. Modal call ID, HF Jobs job ID, local
    # process PID / start time). Not stable across backends. Each handle
    # gets its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class ServerlessExecutor(Protocol):
    """Uniform interface for launching N replicas on a serverless backend.

    Implementations: `LocalProcessExecutor` (test/dev), `ModalExecutor`
    (Modal, v0), `HFJobsExecutor` (HuggingFace Jobs, v0). Future:
    `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`.

    The Protocol is `runtime_checkable`, so `isinstance(obj,
    ServerlessExecutor)` works. Such a check only verifies that the
    attributes and methods below exist, not their signatures.

    Note on rank assignment: the Protocol guarantees that handles are
    returned in rank order (`handles[i].rank == i`). The replica entrypoint
    learns its own rank either from an environment variable
    (`REPLICA_RANK`) or from a backend-provided mechanism (Modal's
    `Function.shard_rank`, etc.). The executor abstraction normalizes
    rank by setting the env var.
    """

    # Short backend identifier, e.g. "local_process".
    backend_name: str
    # Whether replicas can reach each other over the network
    # (e.g. via localhost for local processes).
    supports_inter_replica_network: bool

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Spin up N replicas in parallel.

        Args:
            n_replicas: number of replicas to launch.
            entrypoint: either an importable Python path (e.g.
                "composer_replication.diloco.serverless.replica_entrypoint")
                or a Callable (Local executor only).
            entrypoint_args: kwargs passed to the entrypoint. The kwarg
                `rank_env` (default "REPLICA_RANK") names the environment
                variable in which the rank will be set on the replica.
            gpu: backend-specific GPU spec, e.g. "A100", "H100". `None`
                means CPU-only (smoke tests).
            timeout: per-replica wall-clock timeout in seconds.

        Returns:
            list[ReplicaHandle] of length n_replicas, in rank order.
        """
        ...

    def poll(self, handle: ReplicaHandle) -> str:
        """Return a replica's current status without blocking.

        One of: "pending" | "running" | "succeeded" | "failed" | "cancelled".
        """
        ...

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Read up to `n_lines` of recent stdout/stderr from a replica."""
        ...

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort cancel. No exception if already terminated."""
        ...

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Block until all replicas finish; return per-replica result dicts.

        Results are returned in the same order as `handles`. `timeout` is an
        optional wait limit in seconds; the meaning of `None` is
        backend-defined.

        Each result dict contains at least:
            {"rank": int, "status": str, "exit_code": int | None,
             "error": str | None}
        """
        ...
| # --------------------------------------------------------------------- | |
| # LocalProcessExecutor — reference implementation using multiprocessing | |
| # --------------------------------------------------------------------- | |
| def _local_replica_target( | |
| rank: int, | |
| rank_env: str, | |
| entrypoint: Any, | |
| entrypoint_args: Mapping[str, Any], | |
| result_queue: mp.Queue, | |
| ) -> None: | |
| """multiprocessing target — runs in the child process.""" | |
| import os | |
| import traceback | |
| os.environ[rank_env] = str(rank) | |
| try: | |
| if callable(entrypoint): | |
| result = entrypoint(**entrypoint_args) | |
| elif isinstance(entrypoint, str): | |
| # importable path | |
| mod_path, _, fn_name = entrypoint.rpartition(".") | |
| if not mod_path: | |
| # Top-level script path; just import it and call its main() | |
| import importlib | |
| mod = importlib.import_module(entrypoint) | |
| fn = getattr(mod, "main", None) | |
| if fn is None: | |
| raise RuntimeError( | |
| f"entrypoint '{entrypoint}' has no main() function" | |
| ) | |
| result = fn(**entrypoint_args) | |
| else: | |
| import importlib | |
| mod = importlib.import_module(mod_path) | |
| fn = getattr(mod, fn_name) | |
| result = fn(**entrypoint_args) | |
| else: | |
| raise TypeError( | |
| f"entrypoint must be Callable or importable str, got {type(entrypoint)!r}" | |
| ) | |
| result_queue.put({"rank": rank, "status": "succeeded", | |
| "exit_code": 0, "error": None, "result": result}) | |
| except Exception as e: | |
| tb = traceback.format_exc() | |
| result_queue.put({"rank": rank, "status": "failed", | |
| "exit_code": 1, "error": f"{e!r}\n{tb}", "result": None}) | |
class LocalProcessExecutor:
    """Runs replicas as subprocesses on the local machine.

    Reference `ServerlessExecutor` built on `multiprocessing`. Use cases:
      - Test the serverless layer end-to-end without cloud spend.
      - Develop the algorithm locally with N>1 replicas and `file://`
        rendezvous before deploying to Modal/HF Jobs.
      - CI smoke tests.

    Notes:
      - Children use the 'spawn' start method when available. A Callable
        entrypoint, its args and its return value must therefore be
        picklable; lambdas and closures are not.
      - Bookkeeping is keyed by rank, so a second `launch_replicas` on the
        same executor replaces the entries of the first.
      - Results travel through a `multiprocessing.Queue`. `poll`,
        `stream_logs` and `collect` all drain it, and `collect` drains
        *while* waiting. A child that has put a large result cannot exit
        until the parent reads it, so join-then-drain can deadlock.
    """

    backend_name = "local_process"
    supports_inter_replica_network = True  # localhost works

    # Seconds `collect` waits on one replica between result-queue drains.
    _collect_poll_interval: float = 0.1

    def __init__(self) -> None:
        # Use 'spawn' so each child gets a fresh interpreter (avoids CUDA fork issues).
        try:
            self._ctx = mp.get_context("spawn")
        except ValueError:
            # Fallback for environments where 'spawn' isn't available.
            self._ctx = mp.get_context()
        # rank -> {"proc", "queue", "deadline", "result" (the replica's
        # result dict once drained, else None), "cancelled", "timed_out"}
        self._handles: dict[int, dict[str, Any]] = {}

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Start `n_replicas` child processes running `entrypoint`.

        Each child receives its rank in the env var named by
        `entrypoint_args["rank_env"]` (default "REPLICA_RANK"). That key is
        removed before the remaining args are passed as kwargs. `timeout`
        becomes each replica's deadline and is enforced by
        `collect(timeout=None)`. `gpu` is ignored, with a warning on stderr.

        If starting any replica fails, the already-started replicas are
        stopped and the exception is re-raised.
        """
        if gpu is not None:
            # Local executor doesn't pin GPUs; emit a soft warning.
            sys.stderr.write(
                f"[LocalProcessExecutor] gpu={gpu!r} ignored — "
                f"local processes share whatever GPUs are visible.\n"
            )
        rank_env = entrypoint_args.get("rank_env", "REPLICA_RANK")
        # `rank_env` configures the executor; it is not an entrypoint kwarg.
        child_args = {k: v for k, v in entrypoint_args.items() if k != "rank_env"}
        # One queue shared by every replica of this launch; results carry their rank.
        result_queue: mp.Queue = self._ctx.Queue()
        handles: list[ReplicaHandle] = []
        started: list[Any] = []
        try:
            for rank in range(n_replicas):
                proc = self._ctx.Process(
                    target=_local_replica_target,
                    args=(rank, rank_env, entrypoint, dict(child_args), result_queue),
                    name=f"composer-replica-{rank}",
                )
                proc.start()
                started.append(proc)
                start_ts = time.time()
                self._handles[rank] = {
                    "proc": proc,
                    "queue": result_queue,
                    "deadline": start_ts + timeout,
                    "result": None,
                    "cancelled": False,
                    "timed_out": False,
                }
                handles.append(
                    ReplicaHandle(
                        rank=rank,
                        backend_name=self.backend_name,
                        metadata={"pid": proc.pid, "start_ts": start_ts},
                    )
                )
        except BaseException:
            # Don't leave earlier replicas running without handles.
            for rank, proc in enumerate(started):
                self._stop(proc)
                self._handles.pop(rank, None)
            raise
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Return the replica's status without blocking.

        Returns "cancelled" for unknown handles. Returns the reported status
        once the replica's result dict has arrived, "running" while the
        process is alive, and otherwise a status derived from the exit code.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return "cancelled"
        proc = meta["proc"]
        # Sample liveness *before* draining: a child already seen dead has
        # flushed its result into the pipe, so the drain below will find it.
        alive = proc.is_alive()
        self._drain(meta["queue"])
        result = meta.get("result")
        if result is not None:
            return result["status"]
        if alive:
            return "running"
        if meta.get("cancelled"):
            return "cancelled"
        return "failed" if proc.exitcode != 0 else "succeeded"

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Return the tail of the replica's reported error, if any.

        multiprocessing.Process doesn't capture stdout/stderr, so this reference
        implementation surfaces the last 2000 characters of the result
        dict's `error` field instead. `n_lines` is accepted for Protocol
        compatibility but unused.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return f"<replica {handle.rank}: no metadata>"
        self._drain(meta["queue"])
        result = meta.get("result")
        if result:
            err = result.get("error") or ""
            return f"[rank {handle.rank}] {err[-2000:]}"
        return f"<replica {handle.rank}: still running, no captured logs>"

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort cancel: terminate, then kill, a still-running replica."""
        meta = self._handles.get(handle.rank)
        if meta is None:
            return
        proc = meta["proc"]
        if proc.is_alive():
            meta["cancelled"] = True
            self._stop(proc)

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Wait for the replicas to finish; return result dicts in `handles` order.

        `timeout` (seconds from this call) bounds the wait for all replicas.
        When it is None, each replica's `launch_replicas(timeout=...)`
        deadline applies. Replicas still running at their deadline are
        stopped and reported as "failed" with a "timed out" error, unless
        they had already reported a result.
        """
        start = time.time()
        metas = {h.rank: self._handles.get(h.rank) for h in handles}
        known = [m for m in metas.values() if m is not None]

        running = known
        while running:
            # Drain while waiting so no child blocks flushing its result.
            for meta in running:
                self._drain(meta["queue"])
            now = time.time()
            still_running = []
            for meta in running:
                proc = meta["proc"]
                if not proc.is_alive():
                    continue
                if timeout is not None:
                    deadline = start + timeout
                else:
                    deadline = meta.get("deadline", start + 3600)
                if now >= deadline:
                    meta["timed_out"] = True
                    self._stop(proc)
                else:
                    still_running.append(meta)
            running = still_running
            if running:
                running[0]["proc"].join(timeout=self._collect_poll_interval)

        # Final pass: replicas that exited after the last drain have their
        # results sitting in the pipe by now.
        for meta in known:
            self._drain(meta["queue"])
        return [self._result_for(h.rank, metas[h.rank]) for h in handles]

    # ----------------------------- helpers -----------------------------

    def _drain(self, result_queue: Any) -> None:
        """Move every result currently waiting in `result_queue` into its rank's metadata."""
        import queue as queue_mod

        while True:
            try:
                item = result_queue.get_nowait()
            except queue_mod.Empty:
                return
            except Exception as exc:
                # Closed/broken queue or an unreadable payload: stop this pass,
                # but say so instead of hiding it.
                sys.stderr.write(
                    f"[LocalProcessExecutor] failed to read replica result: {exc!r}\n"
                )
                return
            meta = self._handles.get(item.get("rank"))
            if meta is not None:
                meta["result"] = item

    @staticmethod
    def _stop(proc: Any, grace: float = 5.0) -> None:
        """Terminate `proc`, escalating to kill if it is still alive after `grace` seconds."""
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join(timeout=grace)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=grace)

    @staticmethod
    def _result_for(rank: int, meta: dict[str, Any] | None) -> dict[str, Any]:
        """Build the `collect` result dict for one rank from its bookkeeping."""
        if meta is None:
            return {
                "rank": rank, "status": "cancelled",
                "exit_code": None, "error": "no metadata", "result": None,
            }
        if meta.get("result") is not None:
            # Copy so callers can't mutate what poll()/stream_logs() read later.
            return dict(meta["result"])
        # No result dict reached us (e.g. the child was stopped before
        # reporting), so fall back to the process exit code.
        code = meta["proc"].exitcode
        if meta.get("timed_out"):
            status, error = "failed", f"timed out (exit code {code})"
        elif meta.get("cancelled"):
            status, error = "cancelled", f"cancelled (exit code {code})"
        elif code == 0:
            status, error = "succeeded", None
        else:
            status, error = "failed", f"exit code {code}"
        return {"rank": rank, "status": status, "exit_code": code,
                "error": error, "result": None}
# Public API: the executor Protocol, its handle type, and the
# multiprocessing-based reference executor.
__all__ = [
    "LocalProcessExecutor",
    "ReplicaHandle",
    "ServerlessExecutor",
]