"""ServerlessExecutor Protocol + LocalProcessExecutor. Per ADR-005: - `ServerlessExecutor` is a structural Protocol — backends implement it by writing a class with the right methods, no formal inheritance needed. - `LocalProcessExecutor` is the reference implementation that uses Python's `multiprocessing` module. It's used for tests and for development; the cloud adapters (Modal, HF Jobs, …) implement the same Protocol against their respective backends. """ from __future__ import annotations import multiprocessing as mp import sys import time from dataclasses import dataclass, field from typing import Any, Callable, Mapping, Protocol, runtime_checkable @dataclass class ReplicaHandle: """Opaque handle to a launched replica. Backend-specific contents. All executors return `list[ReplicaHandle]` from `launch_replicas`. Each handle's `metadata` dict is backend-specific; users shouldn't rely on its shape. """ rank: int backend_name: str metadata: dict[str, Any] = field(default_factory=dict) """Backend-specific data (e.g. Modal call ID, HF Jobs job ID, local Process object). Not stable across backends.""" @runtime_checkable class ServerlessExecutor(Protocol): """Uniform interface for launching N replicas on a serverless backend. Implementations: `LocalProcessExecutor` (test/dev), `ModalExecutor` (Modal, v0), `HFJobsExecutor` (HuggingFace Jobs, v0). Future: `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`. Note on rank assignment: the Protocol guarantees that handles are returned in rank order (`handles[i].rank == i`). The replica entrypoint learns its own rank either from an environment variable (`REPLICA_RANK`) or from a backend-provided mechanism (Modal's `Function.shard_rank`, etc.). The executor abstraction normalizes rank by setting the env var. 
""" backend_name: str supports_inter_replica_network: bool def launch_replicas( self, n_replicas: int, entrypoint: str | Callable[..., Any], entrypoint_args: Mapping[str, Any], *, gpu: str | None = None, timeout: int = 3600, ) -> list[ReplicaHandle]: """Spin up N replicas in parallel. Args: n_replicas: number of replicas to launch entrypoint: either an importable Python path (e.g. "composer_replication.diloco.serverless.replica_entrypoint") or a Callable (Local executor only). entrypoint_args: kwargs passed to the entrypoint. The kwarg `rank_env` (default "REPLICA_RANK") names the environment variable in which the rank will be set on the replica. gpu: backend-specific GPU spec, e.g. "A100", "H100". `None` means CPU-only (smoke tests). timeout: per-replica wall-clock timeout in seconds. Returns: list[ReplicaHandle] of length n_replicas, in rank order. """ ... def poll(self, handle: ReplicaHandle) -> str: """Poll a replica's status. Returns one of: "pending" | "running" | "succeeded" | "failed" | "cancelled". """ ... def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: """Read up to n_lines of recent stdout/stderr from a replica.""" ... def cancel(self, handle: ReplicaHandle) -> None: """Best-effort cancel. No exception if already terminated.""" ... def collect( self, handles: list[ReplicaHandle], *, timeout: int | None = None, ) -> list[dict[str, Any]]: """Block until all replicas finish; return per-replica result dicts. Each result dict contains at least: {"rank": int, "status": str, "exit_code": int | None, "error": str | None} """ ... 
# ---------------------------------------------------------------------
# LocalProcessExecutor — reference implementation using multiprocessing
# ---------------------------------------------------------------------


def _resolve_entrypoint(entrypoint: Any) -> Callable[..., Any]:
    """Turn an entrypoint spec into the callable that should be invoked.

    Args:
        entrypoint: A callable (returned unchanged) or a string. A dotted
            string "pkg.mod.fn" resolves to attribute `fn` of module
            `pkg.mod`. A string without a dot is imported as a module and
            its `main` attribute is used.

    Raises:
        TypeError: `entrypoint` is neither callable nor a string.
        RuntimeError: a dot-less module name has no `main` attribute.
        ImportError / AttributeError: the module or attribute is missing.
    """
    if callable(entrypoint):
        return entrypoint
    if not isinstance(entrypoint, str):
        raise TypeError(
            f"entrypoint must be Callable or importable str, got {type(entrypoint)!r}"
        )

    import importlib

    mod_path, _, fn_name = entrypoint.rpartition(".")
    if mod_path:
        return getattr(importlib.import_module(mod_path), fn_name)

    # Top-level module name; import it and call its main().
    fn = getattr(importlib.import_module(entrypoint), "main", None)
    if fn is None:
        raise RuntimeError(f"entrypoint '{entrypoint}' has no main() function")
    return fn


def _local_replica_target(
    rank: int,
    rank_env: str,
    entrypoint: Any,
    entrypoint_args: Mapping[str, Any],
    result_queue: mp.Queue,
) -> None:
    """multiprocessing target — runs in the child process.

    Exports the rank through the environment variable named by `rank_env`,
    runs the entrypoint with `entrypoint_args` as keyword arguments, and puts
    exactly one result dict on `result_queue`:

        {"rank", "status", "exit_code", "error", "result"}

    Any `Exception` from resolving or running the entrypoint is reported as
    status "failed" with the repr and traceback in "error". Note that the
    child still exits with code 0 in that case, so the queued dict — not the
    process exit code — is the authoritative outcome.

    The result dict travels through a multiprocessing queue, so the
    entrypoint's return value has to be picklable.
    """
    import os
    import traceback

    os.environ[rank_env] = str(rank)
    try:
        result = _resolve_entrypoint(entrypoint)(**entrypoint_args)
    except Exception as e:
        tb = traceback.format_exc()
        result_queue.put(
            {
                "rank": rank,
                "status": "failed",
                "exit_code": 1,
                "error": f"{e!r}\n{tb}",
                "result": None,
            }
        )
    else:
        result_queue.put(
            {
                "rank": rank,
                "status": "succeeded",
                "exit_code": 0,
                "error": None,
                "result": result,
            }
        )


class LocalProcessExecutor:
    """Runs replicas as subprocesses on the local machine.

    Use cases:
    - Test the serverless layer end-to-end without cloud spend.
    - Develop the algorithm locally with N>1 replicas and `file://`
      rendezvous before deploying to Modal/HF Jobs.
    - CI smoke tests.

    Replicas are tracked by rank. Launching again from the same executor
    replaces the tracking entry of any earlier replica with the same rank.
    """

    backend_name = "local_process"
    supports_inter_replica_network = True  # localhost works

    # How often `collect` wakes up to drain a replica's result queue while
    # waiting for it to exit.
    _JOIN_POLL_INTERVAL_S = 0.1
    # Grace period after terminate() before escalating to kill().
    _STOP_GRACE_S = 5

    def __init__(self) -> None:
        # use 'spawn' so the child has a fresh interpreter (avoid CUDA fork issues)
        try:
            self._ctx = mp.get_context("spawn")
        except ValueError:
            # Fallback for environments where 'spawn' isn't available
            self._ctx = mp.get_context()
        # rank -> {"proc", "queue", "deadline", "result", "stop_reason"}
        self._handles: dict[int, dict[str, Any]] = {}

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Start `n_replicas` child processes, one per rank.

        Args:
            n_replicas: Number of processes to start (ranks 0..n-1).
            entrypoint: Callable or importable path; see
                `_resolve_entrypoint`. It is passed to the child process as
                a `Process` argument.
            entrypoint_args: Keyword arguments for the entrypoint. The key
                `rank_env` (default "REPLICA_RANK") is consumed here: it
                names the environment variable that carries the rank and is
                not forwarded to the entrypoint.
            gpu: Ignored apart from a warning on stderr.
            timeout: Per-replica wall-clock limit in seconds, measured from
                launch. It is enforced by `collect`.

        Returns:
            Handles in rank order; `metadata` holds "pid" and "start_ts".
        """
        if gpu is not None:
            # Local executor doesn't pin GPUs; emit a soft warning.
            sys.stderr.write(
                f"[LocalProcessExecutor] gpu={gpu!r} ignored — "
                f"local processes share whatever GPUs are visible.\n"
            )

        rank_env = entrypoint_args.get("rank_env", "REPLICA_RANK")
        forwarded_args = {k: v for k, v in entrypoint_args.items() if k != "rank_env"}

        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            # One queue per replica: force-stopping a process that is in the
            # middle of writing can corrupt its queue, and with a shared
            # queue that would take the other replicas' results down too.
            result_queue = self._ctx.Queue()
            proc = self._ctx.Process(
                target=_local_replica_target,
                args=(rank, rank_env, entrypoint, dict(forwarded_args), result_queue),
                name=f"composer-replica-{rank}",
            )
            proc.start()
            self._handles[rank] = {
                "proc": proc,
                "queue": result_queue,
                "deadline": time.monotonic() + timeout,
                "result": None,
                "stop_reason": None,  # None | "cancelled" | "timeout"
            }
            handles.append(
                ReplicaHandle(
                    rank=rank,
                    backend_name=self.backend_name,
                    metadata={"pid": proc.pid, "start_ts": time.time()},
                )
            )
        return handles

    def _drain_result(self, meta: dict[str, Any]) -> None:
        """Move the replica's queued result dict, if any, into `meta["result"]`.

        Never blocks waiting for a result that has not been written yet.
        The queue of a replica that was force-stopped is left alone, because
        it may contain a partially written message.
        """
        if meta["result"] is not None or meta.get("stop_reason") is not None:
            return

        from queue import Empty

        try:
            meta["result"] = meta["queue"].get_nowait()
        except Empty:
            pass
        except Exception as exc:
            # Best effort: callers fall back to the process exit code.
            sys.stderr.write(
                f"[LocalProcessExecutor] could not read replica result: {exc!r}\n"
            )

    def _force_stop(self, proc: Any) -> None:
        """Terminate `proc`, escalating to kill() if it ignores terminate()."""
        proc.terminate()
        proc.join(timeout=self._STOP_GRACE_S)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=self._STOP_GRACE_S)

    def _final_result(self, rank: int, meta: dict[str, Any]) -> dict[str, Any]:
        """Build the result dict for a replica that is no longer running.

        The dict reported by the replica itself wins. Without one, the
        outcome is derived from why we stopped it, or from its exit code.
        """
        if meta["result"] is not None:
            return meta["result"]

        exit_code = meta["proc"].exitcode
        stop_reason = meta.get("stop_reason")
        if stop_reason == "cancelled":
            status, error = "cancelled", "cancelled before completion"
        elif stop_reason == "timeout":
            status, error = "failed", f"timed out; terminated (exit code {exit_code})"
        elif exit_code == 0:
            status, error = "succeeded", None
        else:
            status, error = "failed", f"exit code {exit_code}"
        return {
            "rank": rank,
            "status": status,
            "exit_code": exit_code,
            "error": error,
            "result": None,
        }

    def poll(self, handle: ReplicaHandle) -> str:
        """Return the replica's status without blocking on the process.

        Returns:
            "running" while the process is alive; otherwise "succeeded",
            "failed" or "cancelled". An unknown rank yields "cancelled".
            "pending" is never returned by this executor.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return "cancelled"

        # Check liveness BEFORE draining: once the process is known to be
        # dead, everything it queued is readable, so the drain is reliable.
        was_alive = meta["proc"].is_alive()
        # Drain even while alive so a child blocked on flushing a large
        # result can finish exiting.
        self._drain_result(meta)
        if was_alive:
            return "running"
        return self._final_result(handle.rank, meta)["status"]

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Return the tail of the replica's reported error text.

        multiprocessing.Process doesn't natively capture stdout; that would
        need a Pipe or file redirection. For the local reference impl the
        only available "log" is the `error` field of the replica's result
        dict, limited to the last `n_lines` lines and 2000 characters.

        Returns:
            "" when the rank is unknown or no result has been reported yet,
            otherwise "[rank N] <error tail>".
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return ""
        self._drain_result(meta)
        result = meta["result"]
        if not result:
            return ""
        err = result.get("error") or ""
        tail = "".join(err.splitlines(keepends=True)[-n_lines:]) if n_lines > 0 else ""
        return f"[rank {handle.rank}] {tail[-2000:]}"

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort cancel; a no-op for unknown or already finished replicas.

        A replica that is actually stopped here is reported as "cancelled"
        by `poll` and `collect`, unless its own result had already been
        read.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return
        proc = meta["proc"]
        if not proc.is_alive():
            return
        self._force_stop(proc)
        meta["stop_reason"] = "cancelled"

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Wait for the given replicas and return their result dicts.

        Each replica is waited on until the earlier of this call's deadline
        (`timeout` seconds from now, 3600 when `None`) and the replica's own
        launch deadline. Replicas still running at that point are stopped
        and reported as "failed" with a timeout error.

        Returns:
            One dict per handle, in the order of `handles`, with keys
            "rank", "status", "exit_code", "error" and "result". Handles
            with an unknown rank are reported as "cancelled".
        """
        collect_deadline = time.monotonic() + (timeout if timeout is not None else 3600)

        for h in handles:
            meta = self._handles.get(h.rank)
            if meta is None:
                continue
            proc = meta["proc"]
            deadline = min(collect_deadline, meta["deadline"])
            # Drain while waiting: a child cannot exit until its queued
            # result has been flushed to the pipe, so joining first could
            # stall until the deadline for a large result.
            while proc.is_alive():
                self._drain_result(meta)
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                proc.join(timeout=min(remaining, self._JOIN_POLL_INTERVAL_S))
            if proc.is_alive():
                self._force_stop(proc)
                meta["stop_reason"] = "timeout"

        results: list[dict[str, Any]] = []
        for h in handles:
            meta = self._handles.get(h.rank)
            if meta is None:
                results.append(
                    {
                        "rank": h.rank,
                        "status": "cancelled",
                        "exit_code": None,
                        "error": "no metadata",
                        "result": None,
                    }
                )
                continue
            # The process has exited, so its result (if any) is readable now.
            # A result already read by poll() stays cached in meta and is
            # reused here.
            self._drain_result(meta)
            results.append(self._final_result(h.rank, meta))
        return results


__all__ = ["LocalProcessExecutor", "ReplicaHandle", "ServerlessExecutor"]