File size: 3,480 Bytes
b266c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Replica entrypoint — what each serverless replica runs.

This is the script invoked by `LocalProcessExecutor`, `ModalExecutor`,
`HFJobsExecutor`, etc. It learns its rank from the `REPLICA_RANK` env
var, sets up `ObjectStoreAllReduce` against the shared rendezvous URI,
wraps it in a `MockManager`, and hands it off to the user's training
function.

Usage from an executor:

    >>> executor.launch_replicas(
    ...     n_replicas=4,
    ...     entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
    ...     entrypoint_args={
    ...         "rendezvous_uri": "/tmp/run42/",
    ...         "world_size": 4,
    ...         "trainer_module": "my_project.trainer",
    ...         "trainer_fn": "train",
    ...         "trainer_kwargs": {"model_name": "Qwen/Qwen2.5-0.5B"},
    ...     },
    ... )

The entrypoint expects:
- `REPLICA_RANK` env var set to the rank (0..world_size-1)
- `rendezvous_uri`: fsspec URI for object-store rendezvous
- `world_size`: total replicas
- `trainer_module`, `trainer_fn`: importable path to the user's train fn
- `trainer_kwargs`: dict passed to the user's train fn, plus three injected
  kwargs: `manager` (the `MockManager`), `rank`, and `world_size`
"""
from __future__ import annotations

import importlib
import os
from typing import Any


def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any:
    """Run the user's training function inside one replica.

    Reads this replica's rank from the ``REPLICA_RANK`` environment
    variable, builds an ``ObjectStoreAllReduce`` for ``rendezvous_uri``,
    wraps it in a ``MockManager``, then imports ``trainer_module`` and
    calls its ``trainer_fn`` attribute.

    The rank is validated before anything else is imported or
    constructed, so a misconfigured replica fails with an error about
    ``REPLICA_RANK`` rather than a later, unrelated one.

    Args:
        rendezvous_uri: fsspec URI (or local path) for the rendezvous
        world_size: total replicas
        trainer_module: importable Python module containing the user's
            train function
        trainer_fn: name of the function to call (default "train")
        trainer_kwargs: kwargs passed to the train function. The keys
            ``manager``, ``rank`` and ``world_size`` are always injected
            and replace any caller-supplied values under those names.
            The passed dict itself is not mutated.

    Returns:
        Whatever the train function returns.

    Raises:
        RuntimeError: if ``REPLICA_RANK`` is not set.
        ValueError: if ``REPLICA_RANK`` is not an integer or is outside
            ``[0, world_size)``.
        AttributeError: if ``trainer_module`` has no ``trainer_fn``
            attribute.
        TypeError: if the resolved ``trainer_fn`` attribute is not
            callable.
    """
    rank_str = os.environ.get("REPLICA_RANK")
    if rank_str is None:
        raise RuntimeError(
            "REPLICA_RANK env var not set. The serverless executor "
            "should set this for each replica."
        )
    try:
        rank = int(rank_str)
    except ValueError:
        # Name the variable in the error; the bare int() message only
        # shows the offending literal.
        raise ValueError(
            f"REPLICA_RANK={rank_str!r} is not a valid integer rank"
        ) from None

    if not (0 <= rank < world_size):
        raise ValueError(f"REPLICA_RANK={rank} not in [0, {world_size})")

    # Imported only once the rank checks have passed.
    from composer_replication.diloco.serverless.allreduce import (
        MockManager,
        ObjectStoreAllReduce,
    )

    store = ObjectStoreAllReduce(
        uri=rendezvous_uri,
        rank=rank,
        world_size=world_size,
    )
    manager = MockManager(store)

    mod = importlib.import_module(trainer_module)
    fn = getattr(mod, trainer_fn)
    if not callable(fn):
        raise TypeError(
            f"{trainer_module}.{trainer_fn} is not callable "
            f"(got {type(fn).__name__})"
        )

    # Copy so the caller's dict is left untouched, then inject the
    # replica context; injected keys win over user-supplied ones.
    kwargs = dict(trainer_kwargs or {})
    kwargs["manager"] = manager
    kwargs["rank"] = rank
    kwargs["world_size"] = world_size
    return fn(**kwargs)


if __name__ == "__main__":
    # Command-line wrapper around main(): parses the flags below and
    # forwards them as main()'s keyword arguments. REPLICA_RANK is still
    # read from the environment by main() itself.
    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument("--rendezvous", required=True)
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--trainer-module", required=True)
    parser.add_argument("--trainer-fn", default="train")
    parser.add_argument("--trainer-kwargs-json", default="{}")
    args = parser.parse_args()

    # Decode and validate the kwargs here so a malformed value is
    # reported as a usage error instead of failing inside main().
    try:
        cli_trainer_kwargs = json.loads(args.trainer_kwargs_json)
    except json.JSONDecodeError as exc:
        parser.error(f"--trainer-kwargs-json is not valid JSON: {exc}")
    # JSON null is allowed and maps to main()'s trainer_kwargs=None
    # default; any other non-object value cannot be used as **kwargs.
    if cli_trainer_kwargs is not None and not isinstance(
        cli_trainer_kwargs, dict
    ):
        parser.error(
            "--trainer-kwargs-json must decode to a JSON object, got "
            f"{type(cli_trainer_kwargs).__name__}"
        )

    main(
        rendezvous_uri=args.rendezvous,
        world_size=args.world_size,
        trainer_module=args.trainer_module,
        trainer_fn=args.trainer_fn,
        trainer_kwargs=cli_trainer_kwargs,
    )