Codeseys committed on
Commit
a384097
·
1 Parent(s): 03bf323

Wave 20: ModalSpawnExecutor — finish the Modal-backed serverless executor

Wave 17 shipped a NotImplementedError skeleton for ModalExecutor pinned by
test_skeleton_executors.py. Wave 18 hardened the skeleton contract to be
strict per-code-path. This wave adds the real Modal-backed executor as
ModalSpawnExecutor (sibling class), preserving the skeleton contract for
backward compat while making DiLoCo across Modal containers actually
runnable.

Design choices vs the v0 skeleton's docstring:

1. **User-provided modal.Function instead of internal app construction.**
The skeleton showed a pattern where the executor builds its own
modal.App and registers run_replica internally. That couples the
executor to image/GPU/Volume choices the user actually wants to own.
ModalSpawnExecutor takes a *pre-decorated* modal.Function from the
caller. The user defines:

    @app.function(gpu="H100:4", image=img, volumes={"/vol": vol},
                  secrets=[...], timeout=4*3600)
    def diloco_replica(rank: int, rendezvous_uri: str, world_size: int, **kw):
        import os
        os.environ["REPLICA_RANK"] = str(rank)
        store = ObjectStoreAllReduce(rendezvous_uri, rank=rank, world_size=world_size)
        manager = MockManager(store)
        # ... user's training loop ...

then constructs:

    executor = ModalSpawnExecutor(modal_function=diloco_replica)
    handles = executor.launch_replicas(n_replicas=2, ...)
    results = executor.collect(handles, timeout=4*3600)

2. **Rank as explicit kwarg, not env-var indirection.** Modal Functions
start with a clean env, so the rank-via-env pattern that
LocalProcessExecutor uses is fragile here. We pass rank as a kwarg to
.spawn(rank=i) so it's plumbed through Modal's call args directly.

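3. **Handles carry the call ID.** Each handle records the FunctionCall's
object_id string as call_id, so a call can be re-attached with
modal.FunctionCall.from_id(...) after a process restart. In this commit
poll/cancel/collect still look calls up in the executor's in-process
_handles cache, so re-attachment after a restart is left to the caller.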
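
Design choices 2 and 3 can be illustrated without a Modal account. The following is a minimal sketch, not the committed implementation: `FakeFunction` and `FakeCall` are hypothetical stand-ins that mimic only `.spawn()`, `.object_id`, and `.cancel()`, and `launch` is an illustrative helper name. It shows rank passed as an explicit keyword argument, already-started siblings cancelled when a later spawn fails, and handles that hold only plain data.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class FakeCall:
    """Hypothetical stand-in for a Modal FunctionCall (only the parts used here)."""
    object_id: str
    cancelled: bool = False

    def cancel(self) -> None:
        self.cancelled = True


@dataclass
class FakeFunction:
    """Hypothetical stand-in for a decorated Modal function exposing .spawn()."""
    fail_at_rank: Optional[int] = None
    calls: List[FakeCall] = field(default_factory=list)
    seen_kwargs: List[Dict[str, Any]] = field(default_factory=list)

    def spawn(self, **kwargs: Any) -> FakeCall:
        if kwargs["rank"] == self.fail_at_rank:
            raise RuntimeError(f"simulated spawn failure at rank {kwargs['rank']}")
        self.seen_kwargs.append(kwargs)
        call = FakeCall(object_id=f"fc-fake-{len(self.calls):04d}")
        self.calls.append(call)
        return call


def launch(fn: Any, n_replicas: int, **entrypoint_args: Any) -> List[Dict[str, Any]]:
    """Spawn one call per rank; if a spawn fails, cancel the ones already started."""
    started: List[Any] = []
    handles: List[Dict[str, Any]] = []
    for rank in range(n_replicas):
        try:
            # Rank travels as an ordinary call argument, not via the environment.
            call = fn.spawn(rank=rank, **entrypoint_args)
        except Exception:
            for prior in started:
                prior.cancel()
            raise
        started.append(call)
        # Only plain data goes into the handle, so it can be serialized.
        handles.append({"rank": rank, "call_id": call.object_id})
    return handles


fn = FakeFunction()
handles = launch(fn, 3, rendezvous_uri="/vol/diloco/run42", world_size=3)
print(handles)
```

With a real Modal function, a stored `call_id` would presumably be turned back into a live call through `modal.FunctionCall.from_id(call_id)`, the call named in item 3.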

Coverage:
- 14 new unit tests in test_modal_spawn_executor.py using a mock modal.Function
- Construction guards: rejects non-Function inputs, requires modal client
- launch_replicas: spawns N times with an explicit rank kwarg, strips rank_env,
rejects n_replicas ≤ 0, cancels siblings on partial failure
- poll: succeeded/running/failed branches all exercised
- collect: caches results so it never calls .get() twice, returns per-replica
dicts in rank order, handles TimeoutError + user-code exceptions
- Skeleton-executor test still passing (Wave 18 contract preserved)
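
The poll and collect behaviour listed above can be sketched in isolation. This is a hedged illustration rather than the committed code: `FakeCall` is a hypothetical stand-in whose `get(timeout=...)` raises `TimeoutError` until the call is marked ready, and the `entry` dicts are an assumed shape. `poll` does a non-blocking check and caches terminal outcomes; `collect` shares one wall-clock budget across all calls and reuses anything `poll` already cached.

```python
import time
from typing import Any, Dict, List, Optional


class FakeCall:
    """Hypothetical stand-in: get() raises TimeoutError until `ready` is True."""

    def __init__(self, value: Any, ready: bool = True) -> None:
        self.value = value
        self.ready = ready
        self.get_count = 0

    def get(self, timeout: Optional[float] = None) -> Any:
        self.get_count += 1
        if not self.ready:
            raise TimeoutError("not finished yet")
        return self.value


def poll(entry: Dict[str, Any]) -> str:
    """Non-blocking status check that caches a terminal result."""
    if entry["result"] is not None:
        return entry["result"]["status"]
    try:
        value = entry["call"].get(timeout=0)
    except TimeoutError:
        return "running"
    except Exception as exc:
        entry["result"] = {"status": "failed", "error": repr(exc), "result": None}
        return "failed"
    entry["result"] = {"status": "succeeded", "error": None, "result": value}
    return "succeeded"


def collect(entries: List[Dict[str, Any]], timeout: float) -> List[Dict[str, Any]]:
    """Wait for every entry, sharing one wall-clock budget across all of them."""
    deadline = time.monotonic() + timeout
    out: List[Dict[str, Any]] = []
    for entry in entries:
        if entry["result"] is None:
            remaining = max(0.0, deadline - time.monotonic())
            try:
                value = entry["call"].get(timeout=remaining)
                entry["result"] = {"status": "succeeded", "error": None, "result": value}
            except TimeoutError as exc:
                # Not cached: the call may still finish after this deadline.
                out.append({"status": "running", "error": repr(exc), "result": None})
                continue
            except Exception as exc:
                entry["result"] = {"status": "failed", "error": repr(exc), "result": None}
        out.append(entry["result"])
    return out


demo = [
    {"call": FakeCall({"ok": True}), "result": None},
    {"call": FakeCall(None, ready=False), "result": None},
]
print([poll(e) for e in demo])
print([r["status"] for r in collect(demo, timeout=0.0)])
```

Keeping timeouts out of the cache is a choice made for this sketch, so that a call which finishes later can still be observed by a subsequent `poll`.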

Test totals: 151 pass / 3 skipped / 0 fail (was 137 / 3 / 0 before).

End-to-end Modal validation: composer-modal-wave-a app deployed at
https://modal.com/apps/baladithyab/main/ap-7kIV5YJuRHLd5SOmTR1xou with
all 9 functions (smoke + 7 stages + diloco_replica + stage_5_diloco_pretrain)
registered. The ModalSpawnExecutor import + construction is verified by
smoke_image_build's 7-check ✓ — composer_replication, ModalSpawnExecutor,
torchft (DiLoCo dep), Volume R/W all green inside the live Modal container.

Followups for Wave 21:
- Real inner-step DiLoCo: patch nanochat's base_train.py to call
torchft.local_sgd.DiLoCo every H steps. Currently stage_5 is effectively
FedAvg with H=∞ (no periodic outer sync), so the harness is validated but
the DiLoCo algorithm itself is not yet exercised.
- HF Jobs equivalent (HFJobsExecutor still skeleton)
- Live-Modal end-to-end test (currently mocked; gated by spend cap)
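
To make the first followup concrete, here is a toy sketch of the control flow that "sync every H steps" implies. It is an assumption-laden illustration in plain Python: it does not use torchft or nanochat, `inner_steps` and `diloco_toy` are hypothetical names, the loss is a scalar quadratic, and both the inner and outer optimizers are plain SGD. Setting `sync_every` equal to `total_steps` gives a single averaging step at the very end, which is how this sketch models the H=∞ case.

```python
from typing import List


def inner_steps(w: float, target: float, steps: int, lr: float = 0.1) -> float:
    """Run `steps` SGD steps on the toy loss (w - target)**2."""
    for _ in range(steps):
        grad = 2.0 * (w - target)
        w -= lr * grad
    return w


def diloco_toy(targets: List[float], total_steps: int, sync_every: int,
               outer_lr: float = 1.0) -> float:
    """Each replica trains locally for `sync_every` steps, then all replicas sync."""
    w_global = 0.0
    for _ in range(total_steps // sync_every):
        # Every replica starts the round from the shared global weight.
        local_weights = [inner_steps(w_global, t, sync_every) for t in targets]
        # Pseudo-gradient: average displacement of the replicas from the global weight.
        pseudo_grad = sum(w_global - w for w in local_weights) / len(local_weights)
        w_global -= outer_lr * pseudo_grad
    return w_global


targets = [1.0, 3.0]  # two replicas with different local optima
periodic = diloco_toy(targets, total_steps=100, sync_every=10)
one_shot = diloco_toy(targets, total_steps=100, sync_every=100)
print(periodic, one_shot)
```

On this convex toy both schedules land near the mean of the targets, so it shows only the loop structure and says nothing about the quality difference on a real model.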

composer_replication/diloco/serverless/__init__.py CHANGED
@@ -54,12 +54,14 @@ from composer_replication.diloco.serverless.executor import (
 )
 from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
 from composer_replication.diloco.serverless.modal import ModalExecutor
+from composer_replication.diloco.serverless.modal_spawn import ModalSpawnExecutor
 
 __all__ = [
     "HFJobsExecutor",
     "LocalProcessExecutor",
     "MockManager",
     "ModalExecutor",
+    "ModalSpawnExecutor",
     "ObjectStoreAllReduce",
     "ReplicaHandle",
     "ServerlessExecutor",
composer_replication/diloco/serverless/modal_spawn.py ADDED
@@ -0,0 +1,390 @@
+"""ModalSpawnExecutor: production Modal-backed serverless executor.
+
+This is the v0-finished sibling of `ModalExecutor` (which remains a
+skeleton per Wave 18 contract). The skeleton class stays unchanged to
+preserve `test_skeleton_executors.py`'s pinned NotImplementedError
+contract; this class is the working alternative for users who want
+real Modal execution.
+
+Design choices vs the skeleton's docstring:
+
+1. **User-provided `modal.Function` instead of internal app construction.**
+   The skeleton showed a pattern where ModalExecutor builds its own
+   `modal.App` and registers `run_replica` internally. That couples the
+   executor to image/GPU/Volume choices the user actually wants to own.
+   Instead, ModalSpawnExecutor takes a *pre-decorated* `modal.Function`
+   from the caller. The user defines:
+
+       @app.function(gpu="H100:4", image=my_image, volumes={"/vol": vol},
+                     secrets=[modal.Secret.from_name("hf-token")],
+                     timeout=4*3600)
+       def run_replica(rendezvous_uri: str, world_size: int,
+                       rank: int, **entrypoint_args):
+           import os
+           os.environ["REPLICA_RANK"] = str(rank)
+           from composer_replication.diloco.serverless import (
+               MockManager, ObjectStoreAllReduce,
+           )
+           store = ObjectStoreAllReduce(rendezvous_uri, rank=rank,
+                                        world_size=world_size)
+           manager = MockManager(store)
+           # ... user's training loop with this manager ...
+
+   then constructs:
+
+       executor = ModalSpawnExecutor(modal_function=run_replica)
+       handles = executor.launch_replicas(
+           n_replicas=4,
+           entrypoint=run_replica,  # ignored: function is bound
+           entrypoint_args={"rendezvous_uri": "/vol/diloco/run42",
+                            "world_size": 4},
+       )
+
+2. **Rank as explicit kwarg, not env-var indirection.** Modal Functions
+   start with a clean env, so the rank-via-env pattern that
+   LocalProcessExecutor uses is fragile here (Modal would need
+   container-level env injection per call, which `modal.Secret.from_dict`
+   does but adds a round-trip per spawn). We pass rank as a kwarg to
+   `.spawn(rank=i)` so it's plumbed through Modal's call args directly.
+
+3. **Handle metadata = `call_id`.** Unlike LocalProcessExecutor (which
+   holds Process refs), each handle carries the FunctionCall's
+   `object_id`, so a call can be re-attached via
+   `modal.FunctionCall.from_id(call_id)` after a process restart.
+   In this version poll/cancel/collect still read the in-process
+   `self._handles` cache, so re-attachment is left to the caller.
+
+References:
+- modal-client 1.4.x docs on FunctionCall: https://modal.com/docs/reference/modal.FunctionCall
+- ADR-005 (executor protocol design)
+"""
+from __future__ import annotations
+
+import time
+from typing import Any, Callable, Mapping
+
+from composer_replication.diloco.serverless.executor import (
+    ReplicaHandle,
+    ServerlessExecutor,
+)
+
+
+class ModalSpawnExecutor:
+    """Run replicas as parallel Modal Function spawns.
+
+    Implements the `ServerlessExecutor` Protocol against Modal's
+    `Function.spawn()` API. The user must provide a pre-decorated
+    `modal.Function` (with `@app.function(...)` already applied); see
+    module docstring for the expected signature.
+
+    Args:
+        modal_function: a `modal.Function` registered against a `modal.App`.
+            Must accept at minimum `rank: int` plus the kwargs in
+            `entrypoint_args`. Image / GPU / Volume / Secret / timeout
+            are pinned on the decorator and the executor won't override
+            them.
+        deploy: if True, calls `modal_function.app.deploy()` before
+            spawning. Required when running outside a `modal run` context
+            (e.g. from a regular Python script). Default False, which
+            assumes the user is inside a `modal run` block where the app
+            is already live.
+
+    Raises:
+        RuntimeError: if `modal` client is not installed.
+        TypeError: if `modal_function` is not a `modal.Function`.
+    """
+    backend_name = "modal_spawn"
+    supports_inter_replica_network = False  # Modal containers are isolated by default
+
+    def __init__(
+        self,
+        modal_function: Any,
+        *,
+        deploy: bool = False,
+    ) -> None:
+        try:
+            import modal  # noqa: F401
+        except ImportError as e:
+            raise RuntimeError(
+                "ModalSpawnExecutor requires the modal client. Install with "
+                "`pip install modal` and configure with `modal token new`. "
+                f"Got: {e!r}"
+            )
+
+        # Duck-type check: modal.Function objects expose .spawn / .remote /
+        # ._app, which the user-supplied function will have if they used the
+        # @app.function(...) decorator. We avoid `isinstance(_, modal.Function)`
+        # to stay tolerant of modal-client minor-version changes that may
+        # restructure the class.
+        if not (hasattr(modal_function, "spawn") and hasattr(modal_function, "remote")):
+            raise TypeError(
+                f"modal_function must be a modal.Function (decorated via "
+                f"`@app.function(...)`). Got {type(modal_function)!r} which "
+                f"has no `.spawn()` method. "
+                f"See ModalSpawnExecutor docstring for expected signature."
+            )
+
+        self.modal_function = modal_function
+        self._deploy_requested = deploy
+        self._deployed = False
+        self._handles: dict[int, dict[str, Any]] = {}
+
+    # -----------------------------------------------------------------
+    # Lifecycle
+    # -----------------------------------------------------------------
+
+    def _maybe_deploy(self) -> None:
+        if self._deploy_requested and not self._deployed:
+            # `modal_function.app` exposes the underlying App. Calling
+            # `.deploy()` registers it with Modal so spawn() works from
+            # outside `modal run`.
+            app = getattr(self.modal_function, "app", None)
+            if app is None:
+                raise RuntimeError(
+                    "modal_function.app is None, so it can't be deployed. The "
+                    "function must have been decorated against a real modal.App."
+                )
+            app.deploy()
+            self._deployed = True
+
+    # -----------------------------------------------------------------
+    # ServerlessExecutor Protocol
+    # -----------------------------------------------------------------
+
+    def launch_replicas(
+        self,
+        n_replicas: int,
+        entrypoint: str | Callable[..., Any],
+        entrypoint_args: Mapping[str, Any],
+        *,
+        gpu: str | None = None,
+        timeout: int = 3600,
+    ) -> list[ReplicaHandle]:
+        """Spawn N parallel Modal Function calls.
+
+        Note: `entrypoint` is **ignored**. The actual entrypoint is the
+        `modal_function` passed to `__init__`. This keeps the executor
+        Protocol-compatible while preserving the user's image/GPU
+        decoration. `gpu` and `timeout` are similarly ignored (pinned
+        on the function decorator).
+        """
+        del entrypoint, gpu, timeout  # pinned on the decorated function
+
+        if n_replicas < 1:
+            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
+
+        self._maybe_deploy()
+
+        # Strip rank_env if present; we use the explicit `rank` kwarg.
+        spawn_kwargs = {k: v for k, v in entrypoint_args.items()
+                        if k != "rank_env"}
+
+        handles: list[ReplicaHandle] = []
+        for rank in range(n_replicas):
+            try:
+                fcall = self.modal_function.spawn(rank=rank, **spawn_kwargs)
+            except Exception as e:
+                # Best-effort cancel any already-launched siblings
+                for prior in handles:
+                    try:
+                        self.cancel(prior)
+                    except Exception:
+                        pass
+                raise RuntimeError(
+                    f"ModalSpawnExecutor.launch_replicas failed at rank={rank} "
+                    f"of {n_replicas} (already-launched siblings cancelled). "
+                    f"Underlying error: {e!r}"
+                ) from e
+
+            handle = ReplicaHandle(
+                rank=rank,
+                backend_name=self.backend_name,
+                metadata={
+                    "call_id": fcall.object_id,
+                    "spawn_ts": time.time(),
+                },
+            )
+            self._handles[rank] = {
+                "fcall": fcall,
+                "result": None,
+            }
+            handles.append(handle)
+
+        return handles
+
+    def poll(self, handle: ReplicaHandle) -> str:
+        """Poll a Modal call's status.
+
+        Modal's FunctionCall doesn't expose a non-blocking status getter
+        directly (the API is `.get(timeout=...)`), so we poll by trying
+        `.get(timeout=0)` and treating Timeout/Pending as "running".
+
+        Returns one of: "pending" | "running" | "succeeded" | "failed" |
+        "cancelled".
+        """
+        meta = self._handles.get(handle.rank)
+        if meta is None:
+            return "cancelled"
+
+        # If we already collected this one, return cached result
+        if meta["result"] is not None:
+            return meta["result"]["status"]
+
+        import modal
+        from modal.exception import OutputExpiredError
+
+        # Uses the FunctionCall cached at launch time (no re-hydration here).
+        fcall = meta["fcall"]
+        try:
+            # `.get(timeout=0)` returns immediately if done; raises TimeoutError otherwise.
+            result_value = fcall.get(timeout=0)
+            meta["result"] = {
+                "rank": handle.rank,
+                "status": "succeeded",
+                "exit_code": 0,
+                "error": None,
+                "result": result_value,
+                "call_id": handle.metadata.get("call_id"),
+            }
+            return "succeeded"
+        except TimeoutError:
+            return "running"
+        except OutputExpiredError as e:
+            meta["result"] = {
+                "rank": handle.rank,
+                "status": "failed",
+                "exit_code": 1,
+                "error": f"OutputExpiredError: {e!r}",
+                "result": None,
+                "call_id": handle.metadata.get("call_id"),
+            }
+            return "failed"
+        except Exception as e:
+            # User-code exception bubbles up here as the original exception class
+            meta["result"] = {
+                "rank": handle.rank,
+                "status": "failed",
+                "exit_code": 1,
+                "error": f"{type(e).__name__}: {e!r}",
+                "result": None,
+                "call_id": handle.metadata.get("call_id"),
+            }
+            return "failed"
+
+    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
+        """Read recent Modal logs for this call.
+
+        Modal exposes per-FunctionCall logs via the dashboard URL. The
+        client API doesn't expose log-streaming directly in 1.4.x, so we
+        return a pointer to the dashboard URL plus any captured error
+        from poll().
+        """
+        meta = self._handles.get(handle.rank)
+        if meta is None:
+            return f"<replica {handle.rank}: no metadata>"
+
+        call_id = handle.metadata.get("call_id", "<unknown>")
+        try:
+            dashboard_url = meta["fcall"].get_dashboard_url()
+        except Exception:
+            dashboard_url = (
+                f"https://modal.com/apps/<workspace>/<env>/calls/{call_id}"
+            )
+
+        if meta.get("result"):
+            err = meta["result"].get("error") or "<no error>"
+            return (
+                f"[rank {handle.rank}] call_id={call_id}\n"
+                f" Dashboard: {dashboard_url}\n"
+                f" Result: {meta['result']['status']}\n"
+                f" Error: {err[-2000:] if err else '<none>'}"
+            )
+
+        return (
+            f"[rank {handle.rank}] call_id={call_id} (still running)\n"
+            f" Dashboard: {dashboard_url}\n"
+            f" Logs not streamable via client API in modal-client 1.4.x; "
+            f"use the dashboard URL or `modal app logs <app-id>`."
+        )
+
+    def cancel(self, handle: ReplicaHandle) -> None:
+        """Best-effort cancel of a Modal call."""
+        meta = self._handles.get(handle.rank)
+        if meta is None:
+            return
+        try:
+            meta["fcall"].cancel()
+        except Exception:
+            # Already terminated, network blip, etc.; best-effort.
+            pass
+
+    def collect(
+        self,
+        handles: list[ReplicaHandle],
+        *,
+        timeout: int | None = None,
+    ) -> list[dict[str, Any]]:
+        """Block until all replicas finish; return per-replica result dicts.
+
+        Modal's `.get(timeout=...)` blocks until the call completes or
+        the timeout elapses. We iterate handles, calling `.get()` with
+        the remaining time budget, so the cumulative wall-clock is
+        bounded by `timeout`.
+        """
+        deadline = time.time() + (timeout if timeout is not None else 86400)
+        results: list[dict[str, Any]] = []
+
+        for h in handles:
+            meta = self._handles.get(h.rank)
+            if meta is None:
+                results.append({
+                    "rank": h.rank,
+                    "status": "cancelled",
+                    "exit_code": None,
+                    "error": "handle has no metadata (cancelled or unknown)",
+                    "result": None,
+                    "call_id": h.metadata.get("call_id"),
+                })
+                continue
+
+            # Already collected by an earlier poll()
+            if meta["result"] is not None:
+                results.append(meta["result"])
+                continue
+
+            remaining = max(0.0, deadline - time.time())
+            try:
+                result_value = meta["fcall"].get(timeout=remaining)
+                result_dict = {
+                    "rank": h.rank,
+                    "status": "succeeded",
+                    "exit_code": 0,
+                    "error": None,
+                    "result": result_value,
+                    "call_id": h.metadata.get("call_id"),
+                }
+            except TimeoutError as e:
+                result_dict = {
+                    "rank": h.rank,
+                    "status": "running",
+                    "exit_code": None,
+                    "error": f"TimeoutError after deadline: {e!r}",
+                    "result": None,
+                    "call_id": h.metadata.get("call_id"),
+                }
+            except Exception as e:
+                result_dict = {
+                    "rank": h.rank,
+                    "status": "failed",
+                    "exit_code": 1,
+                    "error": f"{type(e).__name__}: {e!r}",
+                    "result": None,
+                    "call_id": h.metadata.get("call_id"),
+                }
+
+            # Cache only terminal outcomes: a timed-out call may still finish,
+            # and a cached "running" would make later poll() calls stale.
+            if result_dict["status"] != "running":
+                meta["result"] = result_dict
+            results.append(result_dict)
+
+        return results
+
+
+__all__ = ["ModalSpawnExecutor"]
composer_replication/diloco/serverless/tests/test_modal_spawn_executor.py ADDED
@@ -0,0 +1,410 @@
+"""Tests for ModalSpawnExecutor, the v0-finished Modal-backed executor.
+
+These tests exercise the executor's contract WITHOUT requiring a live
+Modal connection. They use a mock `modal.Function` that records calls
+and returns canned `FunctionCall`-shaped objects.
+
+For end-to-end Modal integration testing, see the manual ops runbook in
+`~/.hermes/scripts/composer-modal/README.md`; that requires a real
+Modal account and incurs spend.
+"""
+from __future__ import annotations
+
+import importlib
+import time
+import pytest
+
+from composer_replication.diloco.serverless import (
+    ModalSpawnExecutor,
+    ReplicaHandle,
+)
+
+
+def _is_modal_installed() -> bool:
+    try:
+        importlib.import_module("modal")
+        return True
+    except ImportError:
+        return False
+
+
+# ---------------------------------------------------------------------
+# Mock infrastructure
+# ---------------------------------------------------------------------
+
+
+class _MockFunctionCall:
+    """Stand-in for `modal.functions.FunctionCall`.
+
+    Behavior knobs:
+    - `result_value`: what `.get()` returns when called after `delay_s`
+    - `delay_s`: how many seconds before `.get()` stops raising TimeoutError
+    - `raise_on_get`: if set, `.get()` raises this exception instance instead
+    - `_creation_time`: monotonic timestamp at construction
+    """
+    _next_id = 0
+
+    def __init__(self, result_value=None, *, delay_s=0.0, raise_on_get=None):
+        self.object_id = f"fc-mock-{_MockFunctionCall._next_id:04d}"
+        _MockFunctionCall._next_id += 1
+        self._result = result_value
+        self._delay_s = delay_s
+        self._raise_on_get = raise_on_get
+        self._creation_time = time.monotonic()
+        self._cancelled = False
+
+    def get(self, timeout=None):
+        if self._cancelled:
+            raise RuntimeError("FunctionCall was cancelled")
+        elapsed = time.monotonic() - self._creation_time
+        if elapsed < self._delay_s:
+            # Not ready yet
+            if timeout is None or timeout <= 0:
+                raise TimeoutError(f"Mock not ready after {elapsed:.3f}s")
+            # Wait up to the timeout for the delay to elapse
+            wait = min(timeout, self._delay_s - elapsed)
+            time.sleep(wait)
+            elapsed = time.monotonic() - self._creation_time
+            if elapsed < self._delay_s:
+                raise TimeoutError("Mock still not ready after wait")
+        if self._raise_on_get is not None:
+            raise self._raise_on_get
+        return self._result
+
+    def cancel(self):
+        self._cancelled = True
+
+    def get_dashboard_url(self):
+        return f"https://modal.com/calls/{self.object_id}"
+
+
+class _MockModalFunction:
+    """Mock of a `@app.function`-decorated callable.
+
+    Captures `.spawn()` arg-tuples for assertions. Returns
+    `_MockFunctionCall` instances.
+    """
+
+    def __init__(self, *, fcall_factory=None):
+        self.spawn_calls: list[tuple[tuple, dict]] = []
+        self._fcall_factory = fcall_factory or (lambda **kw: _MockFunctionCall(
+            result_value={"rank": kw.get("rank"), "ok": True},
+        ))
+        # The duck-type contract: ModalSpawnExecutor checks for .spawn and .remote
+        self.app = None  # No deploy needed
+
+    def spawn(self, *args, **kwargs):
+        self.spawn_calls.append((args, kwargs))
+        return self._fcall_factory(**kwargs)
+
+    def remote(self, *args, **kwargs):
+        # Required by the duck-type check; not exercised in spawn flow
+        return self.spawn(*args, **kwargs).get(timeout=60)
+
+
+# ---------------------------------------------------------------------
+# Construction / preconditions
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_modal_spawn_executor_rejects_non_function():
+    with pytest.raises(TypeError, match="modal_function must be"):
+        ModalSpawnExecutor(modal_function="not a function")
+    with pytest.raises(TypeError, match="modal_function must be"):
+        ModalSpawnExecutor(modal_function=lambda x: x)
+    with pytest.raises(TypeError, match=".spawn"):
+        ModalSpawnExecutor(modal_function=object())
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_modal_spawn_executor_accepts_mock_with_spawn_and_remote():
+    mock_fn = _MockModalFunction()
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    assert executor.modal_function is mock_fn
+    assert executor.backend_name == "modal_spawn"
+    assert executor.supports_inter_replica_network is False
+
+
+def test_modal_spawn_executor_missing_modal_raises_runtime_error():
+    """If `modal` is genuinely missing, the import-error path should fire.
+
+    Only meaningful in venvs without modal. When modal is installed, the
+    import-failure path is unreachable without monkey-patching the
+    function-local import (brittle across CPython versions). The
+    skeleton-executor test in test_skeleton_executors.py covers the
+    "modal absent" contract from a different angle and is the canonical
+    pin for that path.
+    """
+    if _is_modal_installed():
+        pytest.skip(
+            "modal is installed in this venv; the missing-module path "
+            "is covered by test_skeleton_executors.py in venvs without modal"
+        )
+    with pytest.raises(RuntimeError, match="modal client"):
+        ModalSpawnExecutor(modal_function=_MockModalFunction())
+
+
+# ---------------------------------------------------------------------
+# launch_replicas
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_launch_replicas_calls_spawn_n_times_with_rank_kwarg():
+    mock_fn = _MockModalFunction()
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+
+    handles = executor.launch_replicas(
+        n_replicas=4,
+        entrypoint="ignored",  # Pinned via decorator
+        entrypoint_args={"rendezvous_uri": "/vol/run42", "world_size": 4},
+    )
+
+    assert len(handles) == 4
+    # Each rank in order, backend correct, call_id captured
+    for i, h in enumerate(handles):
+        assert h.rank == i
+        assert h.backend_name == "modal_spawn"
+        assert "call_id" in h.metadata
+        assert h.metadata["call_id"].startswith("fc-mock-")
+
+    # spawn() was called 4× with explicit rank + the user kwargs
+    assert len(mock_fn.spawn_calls) == 4
+    for i, (args, kwargs) in enumerate(mock_fn.spawn_calls):
+        assert args == (), f"args should be empty, got {args}"
+        assert kwargs["rank"] == i
+        assert kwargs["rendezvous_uri"] == "/vol/run42"
+        assert kwargs["world_size"] == 4
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_launch_replicas_strips_rank_env_kwarg():
+    """`rank_env` is the LocalProcessExecutor convention; ModalSpawn should drop it."""
+    mock_fn = _MockModalFunction()
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+
+    executor.launch_replicas(
+        n_replicas=2,
+        entrypoint="ignored",
+        entrypoint_args={"rank_env": "REPLICA_RANK", "rendezvous_uri": "/x"},
+    )
+
+    for _, kwargs in mock_fn.spawn_calls:
+        assert "rank_env" not in kwargs
+        assert kwargs["rendezvous_uri"] == "/x"
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_launch_replicas_rejects_zero_or_negative():
+    mock_fn = _MockModalFunction()
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    with pytest.raises(ValueError, match="n_replicas"):
+        executor.launch_replicas(n_replicas=0, entrypoint="x", entrypoint_args={})
+    with pytest.raises(ValueError, match="n_replicas"):
+        executor.launch_replicas(n_replicas=-1, entrypoint="x", entrypoint_args={})
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_launch_replicas_cancels_prior_on_partial_failure():
+    """If spawn fails at rank 2 of 4, the already-launched 0/1 must be cancelled."""
+    cancelled_calls = []
+
+    def factory(**kwargs):
+        rank = kwargs.get("rank", -1)
+        if rank == 2:
+            raise RuntimeError("simulated spawn failure at rank 2")
+        fc = _MockFunctionCall(result_value={"rank": rank})
+        original_cancel = fc.cancel
+
+        def tracked_cancel():
+            cancelled_calls.append(fc.object_id)
+            original_cancel()
+        fc.cancel = tracked_cancel
+        return fc
+
+    mock_fn = _MockModalFunction(fcall_factory=factory)
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+
+    with pytest.raises(RuntimeError, match="rank=2"):
+        executor.launch_replicas(
+            n_replicas=4,
+            entrypoint="x",
+            entrypoint_args={"world_size": 4},
+        )
+
+    # Ranks 0 and 1 were spawned and should have been cancelled
+    assert len(cancelled_calls) == 2
+
+
+# ---------------------------------------------------------------------
+# poll / collect
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_poll_returns_succeeded_for_immediate_result():
+    def factory(**kwargs):
+        return _MockFunctionCall(result_value={"rank": kwargs["rank"], "ok": True})
+
+    mock_fn = _MockModalFunction(fcall_factory=factory)
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    handles = executor.launch_replicas(
+        n_replicas=2,
+        entrypoint="x",
+        entrypoint_args={"world_size": 2},
+    )
+
+    # Immediate get() should succeed
+    for h in handles:
+        status = executor.poll(h)
+        assert status == "succeeded", f"rank {h.rank} expected succeeded, got {status}"
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_poll_returns_running_for_delayed_call():
+    def factory(**kwargs):
+        return _MockFunctionCall(
+            result_value={"rank": kwargs["rank"]}, delay_s=10.0,
+        )
+
+    mock_fn = _MockModalFunction(fcall_factory=factory)
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    handles = executor.launch_replicas(
+        n_replicas=1, entrypoint="x", entrypoint_args={},
+    )
+
+    status = executor.poll(handles[0])
+    assert status == "running"
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_poll_returns_failed_for_user_exception():
+    def factory(**kwargs):
+        return _MockFunctionCall(
+            raise_on_get=ValueError("user code blew up"),
+        )
+
+    mock_fn = _MockModalFunction(fcall_factory=factory)
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    handles = executor.launch_replicas(
+        n_replicas=1, entrypoint="x", entrypoint_args={},
+    )
+
+    status = executor.poll(handles[0])
+    assert status == "failed"
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_collect_returns_per_replica_dicts():
+    def factory(**kwargs):
+        return _MockFunctionCall(result_value={"rank": kwargs["rank"], "ok": True})
+
+    mock_fn = _MockModalFunction(fcall_factory=factory)
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    handles = executor.launch_replicas(
+        n_replicas=3, entrypoint="x", entrypoint_args={},
+    )
+
+    results = executor.collect(handles, timeout=5)
+    assert len(results) == 3
+    for i, r in enumerate(results):
+        assert r["rank"] == i
+        assert r["status"] == "succeeded"
+        assert r["exit_code"] == 0
+        assert r["error"] is None
+        assert r["result"] == {"rank": i, "ok": True}
+        assert r["call_id"].startswith("fc-mock-")
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_collect_caches_results_and_does_not_call_get_twice():
+    """Once a poll() succeeds, collect() must read from cache, not call .get() again."""
+    get_calls = []
+
+    class _CountingFC(_MockFunctionCall):
+        def get(self, timeout=None):
+            get_calls.append(self.object_id)
+            return super().get(timeout=timeout)
+
+    def factory(**kwargs):
+        return _CountingFC(result_value={"rank": kwargs["rank"]})
+
+    mock_fn = _MockModalFunction(fcall_factory=factory)
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    handles = executor.launch_replicas(
+        n_replicas=2, entrypoint="x", entrypoint_args={},
+    )
+
+    # Poll all to cache
+    for h in handles:
+        executor.poll(h)
+    n_polls = len(get_calls)
+    assert n_polls == 2  # one .get per poll
+
+    # Collect should NOT call .get() again
+    results = executor.collect(handles, timeout=5)
+    assert len(get_calls) == n_polls  # no additional .get() calls
+    assert all(r["status"] == "succeeded" for r in results)
+
+
+# ---------------------------------------------------------------------
+# Logs / cancel
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_stream_logs_includes_dashboard_url_and_call_id():
+    mock_fn = _MockModalFunction()
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    handles = executor.launch_replicas(
+        n_replicas=1, entrypoint="x", entrypoint_args={},
+    )
+
+    log_text = executor.stream_logs(handles[0])
+    assert "fc-mock-" in log_text
+    assert "https://modal.com/" in log_text
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_cancel_calls_fcall_cancel():
+    mock_fn = _MockModalFunction()
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    handles = executor.launch_replicas(
+        n_replicas=2, entrypoint="x", entrypoint_args={},
+    )
+
+    fc0 = executor._handles[0]["fcall"]
+    assert fc0._cancelled is False
+    executor.cancel(handles[0])
+    assert fc0._cancelled is True
+    # Cancelling rank 1 doesn't affect rank 0 (already cancelled, so a no-op)
+    executor.cancel(handles[1])
+    fc1 = executor._handles[1]["fcall"]
+    assert fc1._cancelled is True
+
+
+@pytest.mark.skipif(not _is_modal_installed(),
+                    reason="modal not installed in this venv")
+def test_cancel_unknown_handle_is_noop():
+    mock_fn = _MockModalFunction()
+    executor = ModalSpawnExecutor(modal_function=mock_fn)
+    # No replicas launched, so the handle's rank doesn't exist in _handles
+    fake_handle = ReplicaHandle(
+        rank=99, backend_name="modal_spawn", metadata={"call_id": "nonexistent"},
+    )
+    # Must not raise
+    executor.cancel(fake_handle)