Codeseys's picture
Wave 13: serverless DiLoCo + replaysim normalization + 3 distillation losses + PRIME-RL + Monarch
b266c31
"""ServerlessExecutor Protocol + LocalProcessExecutor.
Per ADR-005:
- `ServerlessExecutor` is a structural Protocol — backends implement it
by writing a class with the right methods, no formal inheritance needed.
- `LocalProcessExecutor` is the reference implementation that uses Python's
`multiprocessing` module. It's used for tests and for development; the
cloud adapters (Modal, HF Jobs, …) implement the same Protocol against
their respective backends.
"""
from __future__ import annotations
import multiprocessing as mp
import sys
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Mapping, Protocol, runtime_checkable
@dataclass
class ReplicaHandle:
    """Backend-agnostic reference to one launched replica.

    Every executor's `launch_replicas` hands back a `list[ReplicaHandle]`.
    Only `rank` and `backend_name` are portable across backends; `metadata`
    holds whatever the producing backend chose to store, and callers should
    not depend on its keys.
    """

    rank: int
    backend_name: str
    # Backend-specific payload (e.g. a Modal call ID, an HF Jobs job ID, or
    # the local process's pid). Its shape differs between backends.
    metadata: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class ServerlessExecutor(Protocol):
    """Structural contract that every serverless backend adapter satisfies.

    Known implementations: `LocalProcessExecutor` (tests/dev),
    `ModalExecutor` (Modal, v0) and `HFJobsExecutor` (HuggingFace Jobs, v0).
    Planned: `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`.

    Ranks: `launch_replicas` must return its handles ordered by rank, so
    `handles[i].rank == i`. Inside a replica the rank is discoverable through
    an environment variable (`REPLICA_RANK` unless overridden) or a
    backend-native mechanism such as Modal's `Function.shard_rank`; the
    executor layer makes the environment variable the common denominator.
    """

    backend_name: str
    supports_inter_replica_network: bool

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Start `n_replicas` replicas concurrently.

        Args:
            n_replicas: how many replicas to start.
            entrypoint: a dotted import path (for example
                "composer_replication.diloco.serverless.replica_entrypoint"),
                or a Callable — the latter only for the local executor.
            entrypoint_args: keyword arguments forwarded to the entrypoint.
                The optional key `rank_env` (default "REPLICA_RANK") selects
                the environment variable that carries the replica's rank.
            gpu: backend-specific accelerator name such as "A100" or "H100";
                `None` requests CPU only (smoke tests).
            timeout: wall-clock limit for each replica, in seconds.

        Returns:
            `n_replicas` handles, ordered by rank.
        """
        ...

    def poll(self, handle: ReplicaHandle) -> str:
        """Report a replica's current state.

        One of "pending", "running", "succeeded", "failed", "cancelled".
        """
        ...

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Return at most `n_lines` of the replica's most recent stdout/stderr."""
        ...

    def cancel(self, handle: ReplicaHandle) -> None:
        """Try to stop a replica; silently does nothing if it already ended."""
        ...

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Wait for every replica to end and return one result dict each.

        Each dict carries at minimum the keys
        {"rank": int, "status": str, "exit_code": int | None,
         "error": str | None}.
        """
        ...
# ---------------------------------------------------------------------
# LocalProcessExecutor — reference implementation using multiprocessing
# ---------------------------------------------------------------------
def _local_replica_target(
rank: int,
rank_env: str,
entrypoint: Any,
entrypoint_args: Mapping[str, Any],
result_queue: mp.Queue,
) -> None:
"""multiprocessing target — runs in the child process."""
import os
import traceback
os.environ[rank_env] = str(rank)
try:
if callable(entrypoint):
result = entrypoint(**entrypoint_args)
elif isinstance(entrypoint, str):
# importable path
mod_path, _, fn_name = entrypoint.rpartition(".")
if not mod_path:
# Top-level script path; just import it and call its main()
import importlib
mod = importlib.import_module(entrypoint)
fn = getattr(mod, "main", None)
if fn is None:
raise RuntimeError(
f"entrypoint '{entrypoint}' has no main() function"
)
result = fn(**entrypoint_args)
else:
import importlib
mod = importlib.import_module(mod_path)
fn = getattr(mod, fn_name)
result = fn(**entrypoint_args)
else:
raise TypeError(
f"entrypoint must be Callable or importable str, got {type(entrypoint)!r}"
)
result_queue.put({"rank": rank, "status": "succeeded",
"exit_code": 0, "error": None, "result": result})
except Exception as e:
tb = traceback.format_exc()
result_queue.put({"rank": rank, "status": "failed",
"exit_code": 1, "error": f"{e!r}\n{tb}", "result": None})
class LocalProcessExecutor:
    """Runs replicas as subprocesses on the local machine.

    Use cases:
      - Test the serverless layer end-to-end without cloud spend.
      - Develop the algorithm locally with N>1 replicas and `file://`
        rendezvous before deploying to Modal/HF Jobs.
      - CI smoke tests.

    All replicas from one `launch_replicas` call share a single result
    queue. Whichever of `poll` / `stream_logs` / `collect` runs first moves
    waiting results off that queue into per-rank metadata, so the methods
    can be called in any order without losing results.

    Per-replica bookkeeping lives in `self._handles`, keyed by rank; a later
    `launch_replicas` call replaces the entries for the ranks it launches.
    """

    backend_name = "local_process"
    supports_inter_replica_network = True  # localhost works

    # Seconds between result-queue drains while `collect` waits on a process.
    _drain_interval = 0.1

    def __init__(self) -> None:
        # use 'spawn' so the child has a fresh interpreter (avoid CUDA fork issues)
        try:
            self._ctx = mp.get_context("spawn")
        except ValueError:
            # Fallback for environments where 'spawn' isn't available
            self._ctx = mp.get_context()
        self._handles: dict[int, dict[str, Any]] = {}

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Start `n_replicas` child processes running `entrypoint`.

        `entrypoint_args["rank_env"]` (default "REPLICA_RANK") names the
        environment variable set to each child's rank; that key is not
        forwarded to the entrypoint. `timeout` becomes each replica's
        deadline, enforced by `collect(timeout=None)`. `gpu` is ignored
        (a warning is written to stderr).

        If starting any child fails, the children already started are
        stopped and marked cancelled before the exception propagates.
        """
        if gpu is not None:
            # Local executor doesn't pin GPUs; emit a soft warning.
            sys.stderr.write(
                f"[LocalProcessExecutor] gpu={gpu!r} ignored — "
                f"local processes share whatever GPUs are visible.\n"
            )
        rank_env = entrypoint_args.get("rank_env", "REPLICA_RANK")
        # Identical for every rank, so build it once.
        child_args = {k: v for k, v in entrypoint_args.items() if k != "rank_env"}
        result_queue: mp.Queue = self._ctx.Queue()
        handles: list[ReplicaHandle] = []
        launched: list[int] = []
        try:
            for rank in range(n_replicas):
                proc = self._ctx.Process(
                    target=_local_replica_target,
                    args=(rank, rank_env, entrypoint, child_args, result_queue),
                    name=f"composer-replica-{rank}",
                )
                proc.start()
                start_ts = time.time()
                self._handles[rank] = {
                    "proc": proc,
                    "queue": result_queue,
                    "deadline": start_ts + timeout,
                    "result": None,
                    "cancelled": False,
                    "timed_out": False,
                }
                launched.append(rank)
                handles.append(ReplicaHandle(
                    rank=rank, backend_name=self.backend_name,
                    metadata={"pid": proc.pid, "start_ts": start_ts},
                ))
        except BaseException:
            # Don't leave orphaned replicas running behind a failed launch.
            for rank in launched:
                meta = self._handles[rank]
                meta["cancelled"] = True
                self._stop(meta["proc"])
            raise
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Return the replica's status without blocking.

        Unknown ranks report "cancelled". A replica whose result has been
        received reports that result's status, even if its process is still
        shutting down.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return "cancelled"
        # Drain first: a child that queued a large result cannot exit until
        # the parent reads it, so waiting for exit before reading would spin
        # on "running" forever.
        self._drain(meta["queue"])
        if meta["result"] is not None:
            return meta["result"]["status"]
        proc: mp.Process = meta["proc"]
        if proc.is_alive():
            return "running"
        # The result may have landed between the first drain and the exit.
        self._drain(meta["queue"])
        return self._outcome(handle.rank)["status"]

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Return the last `n_lines` of the replica's reported error text.

        multiprocessing.Process doesn't capture stdout/stderr; the only text
        available is the `error` field (repr + traceback) of the replica's
        result dict, truncated to its last 2000 characters.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return f"<replica {handle.rank}: no metadata>"
        self._drain(meta["queue"])
        if meta["result"] is not None:
            err = meta["result"].get("error") or ""
            lines = err.splitlines()[-n_lines:] if n_lines > 0 else []
            return f"[rank {handle.rank}] " + "\n".join(lines)[-2000:]
        proc: mp.Process = meta["proc"]
        if proc.is_alive():
            return f"<replica {handle.rank}: still running, no captured logs>"
        return (
            f"<replica {handle.rank}: exited with code {proc.exitcode}, "
            f"no captured logs>"
        )

    def cancel(self, handle: ReplicaHandle) -> None:
        """Stop a running replica; a no-op for unknown or finished replicas."""
        meta = self._handles.get(handle.rank)
        if meta is None:
            return
        proc: mp.Process = meta["proc"]
        if proc.is_alive():
            meta["cancelled"] = True
            self._stop(proc)

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Wait for every replica in `handles` and return result dicts in that order.

        Args:
            handles: replicas to wait for.
            timeout: overall budget in seconds for this call. If None, each
                replica is bounded by the deadline set from
                `launch_replicas(timeout=...)`.

        Replicas still alive at their deadline are stopped and reported as
        "failed" with a "timed out" error, unless a result already arrived.
        """
        call_deadline = None if timeout is None else time.time() + timeout
        for h in handles:
            meta = self._handles.get(h.rank)
            if meta is None:
                continue
            proc: mp.Process = meta["proc"]
            deadline = meta["deadline"] if call_deadline is None else call_deadline
            while proc.is_alive():
                remaining = deadline - time.time()
                if remaining <= 0:
                    meta["timed_out"] = True
                    self._stop(proc)
                    break
                # Read results *while* waiting: a child blocks at exit until
                # its queued result is flushed, so join-before-drain can
                # deadlock on large results until the deadline.
                self._drain(meta["queue"])
                proc.join(timeout=min(remaining, self._drain_interval))
            self._drain(meta["queue"])
        return [self._outcome(h.rank) for h in handles]

    # -- internals --------------------------------------------------------

    def _drain(self, result_queue: Any) -> None:
        """Move every result currently waiting on `result_queue` into per-rank metadata."""
        from queue import Empty

        while True:
            try:
                r = result_queue.get_nowait()
            except Empty:
                return
            except Exception as e:  # e.g. a result that can't be unpickled here
                sys.stderr.write(
                    f"[LocalProcessExecutor] dropped unreadable replica result: {e!r}\n"
                )
                return
            meta = self._handles.get(r.get("rank")) if isinstance(r, dict) else None
            # Ignore results from a queue whose ranks were re-launched since.
            if meta is not None and meta["queue"] is result_queue and meta["result"] is None:
                meta["result"] = r

    def _outcome(self, rank: int) -> dict[str, Any]:
        """Result dict for `rank`: the replica's own report, else one built from its exit state."""
        meta = self._handles.get(rank)
        if meta is None:
            return {"rank": rank, "status": "cancelled",
                    "exit_code": None, "error": "no metadata", "result": None}
        if meta["result"] is not None:
            return meta["result"]
        code = meta["proc"].exitcode
        if meta.get("cancelled"):
            status, error = "cancelled", "cancelled"
        elif meta.get("timed_out"):
            status, error = "failed", f"timed out (exit code {code})"
        elif code == 0:
            status, error = "succeeded", None
        else:
            status, error = "failed", f"exit code {code}"
        return {"rank": rank, "status": status, "exit_code": code,
                "error": error, "result": None}

    @staticmethod
    def _stop(proc: Any, grace: float = 5.0) -> None:
        """terminate() `proc`, escalating to kill() if it is still alive after `grace` seconds."""
        if not proc.is_alive():
            return
        proc.terminate()
        proc.join(timeout=grace)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=grace)
__all__ = ["LocalProcessExecutor", "ReplicaHandle", "ServerlessExecutor"]