"""Replica entrypoint — what each serverless replica runs. This is the script invoked by `LocalProcessExecutor`, `ModalExecutor`, `HFJobsExecutor`, etc. It learns its rank from the `REPLICA_RANK` env var, sets up `ObjectStoreAllReduce` against the shared rendezvous URI, wraps it in a `MockManager`, and hands it off to the user's training function. Usage from an executor: >>> executor.launch_replicas( ... n_replicas=4, ... entrypoint="composer_replication.diloco.serverless.replica_entrypoint", ... entrypoint_args={ ... "rendezvous_uri": "/tmp/run42/", ... "world_size": 4, ... "trainer_module": "my_project.trainer", ... "trainer_fn": "train", ... "trainer_kwargs": {"model_name": "Qwen/Qwen2.5-0.5B"}, ... }, ... ) The entrypoint expects: - `REPLICA_RANK` env var set to the rank (0..world_size-1) - `rendezvous_uri`: fsspec URI for object-store rendezvous - `world_size`: total replicas - `trainer_module`, `trainer_fn`: importable path to the user's train fn - `trainer_kwargs`: dict passed to the user's train fn, plus an injected `manager` kwarg containing the `MockManager` """ from __future__ import annotations import importlib import os from typing import Any def main( rendezvous_uri: str, world_size: int, trainer_module: str, trainer_fn: str = "train", trainer_kwargs: dict[str, Any] | None = None, ) -> Any: """Entrypoint executed inside each replica. Args: rendezvous_uri: fsspec URI (or local path) for the rendezvous world_size: total replicas trainer_module: importable Python module containing the user's train function trainer_fn: name of the function to call (default "train") trainer_kwargs: kwargs passed to the train function Returns: Whatever the train function returns. """ from composer_replication.diloco.serverless.allreduce import ( MockManager, ObjectStoreAllReduce, ) rank_str = os.environ.get("REPLICA_RANK") if rank_str is None: raise RuntimeError( "REPLICA_RANK env var not set. The serverless executor " "should set this for each replica." ) rank = int(rank_str) if not (0 <= rank < world_size): raise ValueError(f"REPLICA_RANK={rank} not in [0, {world_size})") store = ObjectStoreAllReduce( uri=rendezvous_uri, rank=rank, world_size=world_size, ) manager = MockManager(store) mod = importlib.import_module(trainer_module) fn = getattr(mod, trainer_fn) kwargs = dict(trainer_kwargs or {}) kwargs["manager"] = manager # injected kwargs["rank"] = rank kwargs["world_size"] = world_size return fn(**kwargs) if __name__ == "__main__": import argparse import json parser = argparse.ArgumentParser() parser.add_argument("--rendezvous", required=True) parser.add_argument("--world-size", type=int, required=True) parser.add_argument("--trainer-module", required=True) parser.add_argument("--trainer-fn", default="train") parser.add_argument("--trainer-kwargs-json", default="{}") args = parser.parse_args() main( rendezvous_uri=args.rendezvous, world_size=args.world_size, trainer_module=args.trainer_module, trainer_fn=args.trainer_fn, trainer_kwargs=json.loads(args.trainer_kwargs_json), )