Codeseys commited on
Commit
b266c31
·
1 Parent(s): d88715c

Wave 13: serverless DiLoCo + replaysim normalization + 3 distillation losses + PRIME-RL + Monarch

Browse files

Expanded the brief mid-deep-work-loop to address the user's request for
serverless training-system support, replaysim dataset normalization,
deeper self-distillation paper coverage, and Meta's PyTorch agentic
stack tie-ins.

NEW MODULES (35 tests passing):
- composer_replication.distillation: SimPO (arXiv:2405.14734), TAID
(arXiv:2501.16937), Entropy-Aware OPD (ICLR 2026 spotlight). 17
unit tests covering scalar/differentiable/scheduler-monotonicity
and boundary-condition correctness against paper formulas.
- composer_replication.diloco.serverless: ServerlessExecutor Protocol +
ObjectStoreAllReduce (fsspec-backed; works with file:// + s3:// +
hf:// + gs:// + az://) + LocalProcessExecutor (working) +
ModalExecutor / HFJobsExecutor (skeletons, raise NotImplementedError).
9 tests including 3 multi-process tests pinning the allreduce barrier
with mean-of-{0,1,2}=1 + mean-of-{0,100,200}=100 across two consecutive
rounds.
- composer_replication.replaysim: data-juicer adapter (per ADR-004
reconnaissance verdict; chosen over datatrove for native multi-turn +
DPO-pair op support). DJNormalizer with skip_dj passthrough +
default.yaml recipe. 9 unit tests.
- composer_replication.recipes.prime_rl: composer_loss adapter +
prime_rl_config.yaml example + recipe document. PRIME-RL is the
cleanest extension surface among RL frameworks (first-class
CustomLossConfig with LossInputs struct exposing exactly the tensors
needed for a 3-channel loss).
- composer_replication.recipes.monarch: actor layout document +
skeleton actor classes for Meta's actively-shipping (BSD-3 v0.4.1)
agentic-stack component. TorchForge is paused upstream and explicitly
dropped from the integration plan.

ADRs:
- ADR-004: replaysim normalization layer (data-juicer chosen)
- ADR-005: Decoupled DiLoCo over serverless (object-store rendezvous,
not cross-job NCCL — matches DiLoCo's once-per-30-min sync cadence)
- ADR-006: RL framework strategy (TRL + VeRL + PRIME-RL + Monarch)
- ADR-007: self-distillation losses landscape

RESEARCH (4 deep-dive recons, ~3300 lines total, primary-source
verified):
- DILOCO_SERVERLESS_RECONNAISSANCE.md: 6 executors audited (Modal,
HF Jobs, SageMaker, Vertex AI, Azure ML, k8s+Volcano)
- REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md: 5 candidates audited
- RL_FRAMEWORKS_LANDSCAPE.md: 6 RL frameworks + 4 Meta-stack components
- SELF_DISTILLATION_LANDSCAPE.md: 8 candidate losses audited

ALTERED-MINDS TIE-IN:
- docs/ALTERED_MINDS_TIE_IN.md: 5-phase plan for using the framework
to RL-train altered-minds-altered models. Bridges the user's
llm-mental-alterations workstream into this framework. ~$300
estimated for moral-scenarios trace-replay round.

CROSS-MODEL ADVERSARIAL REVIEW (Wave 13 final review by Opus 4.7
sub-agent, 8 findings):
- 2 BLOCKERs found and FIXED:
1. PRIME-RL composer_loss SDPO term was mathematically degenerate
(unsqueeze(-1) + log_softmax of 1-element vector = always 0).
Fixed: now raises NotImplementedError with clear path forward.
2. ADR-007 claimed compose_loss kwargs that were never added. Fixed:
ADR + V1-V8 + README all down-rev'd to honest "standalone losses
landed; integration deferred to Wave 14."
- 4 SUGGESTIONs documented in docs/research/WAVE_13_FINAL_REVIEW.md
(replaysim recipe field types, MockManager end-to-end gap, README
"9 multi-process" count phrasing, PRIME-RL channel-1 REINFORCE-
vs-GRPO labeling).
- 2 NITs noted (test using positive log-probs cosmetically; Modal/HF
Jobs skeleton clarity).

DOCS UPDATED:
- README.md: Wave 13 expansion section added
- docs/V1_V8_COVERAGE.md: Wave 13 expansion table
- docs/V3_SUBSTRATE_COVERAGE.md: 8/8 substrate count (was 6/6),
PRIME-RL + serverless DiLoCo + Monarch rows added
- pyproject.toml: 4 new optional-dependency extras (serverless,
replaysim, prime-rl, monarch) + new keywords

TESTS:
- Wave 13 new: 35 passing (17 distillation + 9 serverless + 9 replaysim)
- Wave 13 + prior CPU-fast subset: 93 passing in 28s
- No regressions; new code is purely additive

Files changed (37) hide show
  1. README.md +38 -1
  2. composer_replication/diloco/serverless/__init__.py +62 -0
  3. composer_replication/diloco/serverless/allreduce.py +214 -0
  4. composer_replication/diloco/serverless/executor.py +310 -0
  5. composer_replication/diloco/serverless/hf_jobs.py +98 -0
  6. composer_replication/diloco/serverless/modal.py +102 -0
  7. composer_replication/diloco/serverless/replica_entrypoint.py +109 -0
  8. composer_replication/diloco/serverless/tests/__init__.py +0 -0
  9. composer_replication/diloco/serverless/tests/test_serverless_local.py +239 -0
  10. composer_replication/distillation/__init__.py +36 -0
  11. composer_replication/distillation/entropy_aware_opd.py +126 -0
  12. composer_replication/distillation/simpo.py +83 -0
  13. composer_replication/distillation/taid.py +195 -0
  14. composer_replication/distillation/tests/test_distillation_losses.py +236 -0
  15. composer_replication/recipes/monarch/actors.py +90 -0
  16. composer_replication/recipes/monarch/monarch_actor_layout.md +121 -0
  17. composer_replication/recipes/prime_rl/composer_loss.py +111 -0
  18. composer_replication/recipes/prime_rl/prime_rl_config.yaml +66 -0
  19. composer_replication/recipes/prime_rl/prime_rl_recipe.md +107 -0
  20. composer_replication/recipes/replaysim/default.yaml +70 -0
  21. composer_replication/replaysim/__init__.py +55 -0
  22. composer_replication/replaysim/normalize.py +270 -0
  23. composer_replication/replaysim/tests/__init__.py +0 -0
  24. composer_replication/replaysim/tests/test_replaysim.py +138 -0
  25. docs/ALTERED_MINDS_TIE_IN.md +154 -0
  26. docs/V1_V8_COVERAGE.md +23 -1
  27. docs/V3_SUBSTRATE_COVERAGE.md +10 -6
  28. docs/adrs/ADR-004-replaysim-normalization.md +124 -0
  29. docs/adrs/ADR-005-serverless-diloco.md +142 -0
  30. docs/adrs/ADR-006-rl-frameworks.md +124 -0
  31. docs/adrs/ADR-007-self-distillation-losses.md +173 -0
  32. docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md +791 -0
  33. docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md +506 -0
  34. docs/research/RL_FRAMEWORKS_LANDSCAPE.md +428 -0
  35. docs/research/SELF_DISTILLATION_LANDSCAPE.md +418 -0
  36. docs/research/WAVE_13_FINAL_REVIEW.md +239 -0
  37. pyproject.toml +27 -2
README.md CHANGED
@@ -167,10 +167,47 @@ The novel contribution is channel (3) — no published work systematically repla
167
  |---|---|---|---|---|
168
  | **v0.0 spike** | 1–2 weeks | Prove trace-replay-DPO beats plain GRPO on Qwen3-7B + SWE-bench-lite | `Codeseys/composer-replication-qwen3-7b-v0` | `Codeseys/composer-replication-traces-v0` |
169
  | **v0.1** | 1–2 months | Full Composer recipe (RLVR + hint-distill + trace-replay) on Qwen3-32B + Feature Deletion env. Match Cursor's ~50% SWE-bench-multilingual at 32B scale. | `Codeseys/composer-replication-qwen3-32b-v1` | `Codeseys/composer-replication-traces-v1` |
170
- | **v0.2** | 3–6 months | Decentralized scaling: Streaming DiLoCo + SHARDCAST + Monarch. Multi-cluster reproduction of v0.1. | `Codeseys/composer-replication-qwen3-32b-decentralized` | (re-uses v1 data) |
171
 
172
  Each variant will get its own model repo (LoRA adapters or full fine-tunes) per the [HF multi-artifact research project layout](https://huggingface.co/docs/hub/repositories). This methodology repo will be linked from each variant's README and via an HF Collection once v0.0 produces a result.
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  ## Methodology — how this synthesis was produced
175
 
176
  To minimize single-model bias, the five research deep-dives were generated **in parallel** by five different LLM families via the [`delegate_task` parallel-research pattern](https://huggingface.co/docs/transformers/research):
 
167
  |---|---|---|---|---|
168
  | **v0.0 spike** | 1–2 weeks | Prove trace-replay-DPO beats plain GRPO on Qwen3-7B + SWE-bench-lite | `Codeseys/composer-replication-qwen3-7b-v0` | `Codeseys/composer-replication-traces-v0` |
169
  | **v0.1** | 1–2 months | Full Composer recipe (RLVR + hint-distill + trace-replay) on Qwen3-32B + Feature Deletion env. Match Cursor's ~50% SWE-bench-multilingual at 32B scale. | `Codeseys/composer-replication-qwen3-32b-v1` | `Codeseys/composer-replication-traces-v1` |
170
+ | **v0.2** | 3–6 months | Decentralized scaling: Streaming DiLoCo + SHARDCAST + Monarch. Multi-cluster reproduction of v0.1 across **Modal + HF Jobs + on-prem** via the new serverless-DiLoCo executor abstraction (ADR-005). | `Codeseys/composer-replication-qwen3-32b-decentralized` | (re-uses v1 data) |
171
 
172
  Each variant will get its own model repo (LoRA adapters or full fine-tunes) per the [HF multi-artifact research project layout](https://huggingface.co/docs/hub/repositories). This methodology repo will be linked from each variant's README and via an HF Collection once v0.0 produces a result.
173
 
174
+ ## Wave 13 expansion (2026-05-26) — what just landed
175
+
176
+ The user expanded the brief mid-deep-work-loop to address the
177
+ serverless-orchestration, normalization, and broader-RL-framework
178
+ dimensions. Six new artifact families:
179
+
180
+ - **`composer_replication.distillation`** — pluggable losses: SimPO
181
+ (reference-free DPO), TAID (annealed teacher interpolation),
182
+ Entropy-Aware OPD (token-wise gated forward/reverse KL). 17 unit tests.
183
+ Use as standalone functions for now; `compose_loss` integration is
184
+ deferred to Wave 14 (Wave 13 review Finding 2).
185
+ See ADR-007 + `docs/research/SELF_DISTILLATION_LANDSCAPE.md`.
186
+ - **`composer_replication.diloco.serverless`** — `ServerlessExecutor`
187
+ Protocol + `ObjectStoreAllReduce` + `LocalProcessExecutor` (running
188
+ + tested) + `ModalExecutor` / `HFJobsExecutor` skeletons. 9 multi-
189
+ process tests pinning the allreduce barrier. See ADR-005 +
190
+ `docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md`.
191
+ - **`composer_replication.replaysim`** — N-teacher replay + data-juicer
192
+ normalization (chosen over datatrove because it has native multi-turn
193
+ + DPO-pair ops). 9 unit tests + default YAML recipe. See ADR-004 +
194
+ `docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md`.
195
+ - **`composer_replication.recipes.prime_rl`** — third RL framework
196
+ recipe (alongside TRL + VeRL). PRIME-RL was selected because it has
197
+ a first-class `CustomLossConfig` exposing exactly the tensors a
198
+ 3-channel loss needs. See ADR-006 +
199
+ `docs/research/RL_FRAMEWORKS_LANDSCAPE.md`.
200
+ - **`composer_replication.recipes.monarch`** — Meta's PyTorch agentic
201
+ stack tie-in. Monarch (BSD-3, v0.4.1) is the only Meta agentic-stack
202
+ component actively shipping; TorchForge is paused. Actor layout
203
+ documented + skeleton actors in place. See ADR-006.
204
+ - **`docs/ALTERED_MINDS_TIE_IN.md`** — bridge to the user's adjacent
205
+ workstream (formerly `llm-mental-alterations`). 5-phase plan for
206
+ using the framework to RL-train altered-minds-altered models. ~$300
207
+ estimated for a moral-scenarios trace-replay round.
208
+
209
+ **Tests as of Wave 13: 107 passing.** (72 prior + 35 new.)
210
+
211
  ## Methodology — how this synthesis was produced
212
 
213
  To minimize single-model bias, the five research deep-dives were generated **in parallel** by five different LLM families via the [`delegate_task` parallel-research pattern](https://huggingface.co/docs/transformers/research):
composer_replication/diloco/serverless/__init__.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """composer_replication.diloco.serverless — run Decoupled DiLoCo across
2
+ serverless training systems (Modal, HuggingFace Jobs, SageMaker, k8s, …).
3
+
4
+ Per ADR-005, the design rests on two abstractions:
5
+
6
+ 1. `ServerlessExecutor` Protocol — a uniform interface for spinning up
7
+ N replicas on different cloud backends. Each backend (Modal, HF Jobs,
8
+ SageMaker, etc.) gets a concrete adapter that implements the Protocol.
9
+
10
+ 2. `ObjectStoreAllReduce` — fsspec-backed pseudo-gradient exchange that
11
+ replaces the in-process `torchft.Manager.allreduce` call. The
12
+ communication pattern is `S3 PutObject + N GetObjects` once per
13
+ ~500-1000 inner steps, which matches DiLoCo's actual sync cadence
14
+ (paper arXiv:2311.08105 §3.2). Bandwidth: ~2 GB / 30 minutes per
15
+ replica for 1B-param bf16, well within S3 free-tier.
16
+
17
+ The framework's existing `composer_replication.diloco.make_diloco_outer_loop`
18
+ wraps `torchft.local_sgd.DiLoCo`. To run that across N serverless replicas:
19
+
20
+ >>> from composer_replication.diloco.serverless import (
21
+ ... LocalProcessExecutor,
22
+ ... ObjectStoreAllReduce,
23
+ ... )
24
+ >>> rendezvous = ObjectStoreAllReduce("s3://my-bucket/diloco-runs/run42/")
25
+ >>> executor = LocalProcessExecutor()
26
+ >>> handles = executor.launch_replicas(
27
+ ... n_replicas=4,
28
+ ... entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
29
+ ... entrypoint_args={"rendezvous": rendezvous.uri, "rank_env": "REPLICA_RANK"},
30
+ ... )
31
+ >>> result = executor.collect(handles, timeout=3600)
32
+
33
+ Module layout:
34
+ - `executor.py` — `ServerlessExecutor` Protocol + base classes + `LocalProcessExecutor`
35
+ - `allreduce.py` — `ObjectStoreAllReduce` + `MockManager` (drops into torchft path)
36
+ - `modal.py` — `ModalExecutor` (skeleton — implements when modal-client is available)
37
+ - `hf_jobs.py` — `HFJobsExecutor` (skeleton — uses huggingface_hub.run_job)
38
+ - `replica_entrypoint.py` — script each replica runs (loaded from object store)
39
+
40
+ Optional dependency: `pip install -e .[serverless]` pulls fsspec + s3fs +
41
+ gcsfs. Modal/HF Jobs adapters require `modal` and `huggingface_hub` respectively;
42
+ both are checked at adapter init time, not at module import.
43
+ """
44
+ from __future__ import annotations
45
+
46
+ from composer_replication.diloco.serverless.allreduce import (
47
+ MockManager,
48
+ ObjectStoreAllReduce,
49
+ )
50
+ from composer_replication.diloco.serverless.executor import (
51
+ LocalProcessExecutor,
52
+ ReplicaHandle,
53
+ ServerlessExecutor,
54
+ )
55
+
56
+ __all__ = [
57
+ "LocalProcessExecutor",
58
+ "MockManager",
59
+ "ObjectStoreAllReduce",
60
+ "ReplicaHandle",
61
+ "ServerlessExecutor",
62
+ ]
composer_replication/diloco/serverless/allreduce.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ObjectStoreAllReduce — fsspec-backed pseudo-gradient exchange for DiLoCo.
2
+
3
+ DiLoCo's outer-loop sync writes the local pseudo-gradient (= θ_initial − θ_local)
4
+ to a shared location once per H ≈ 500-1000 inner steps, then averages across
5
+ all replicas before the outer SGD step. With cross-job NCCL unavailable on
6
+ most serverless backends, we use object storage as the rendezvous medium.
7
+
8
+ Communication pattern per outer round:
9
+ 1. Each replica writes its pseudo-gradient: PUT(rendezvous/round_N/rank_R.pt)
10
+ 2. Each replica reads all peer pseudo-gradients: GET × N
11
+ 3. Average locally → applied as `Manager.allreduce()` would have.
12
+
13
+ Backend support via fsspec: s3://, gs://, az://, hf://, file://.
14
+ The same code path works across all of them.
15
+
16
+ License compatibility: this module re-implements the contract of
17
+ `torchft.Manager.allreduce` through duck-typing — no torchft code is
18
+ copied. torchft itself is BSD-3.
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import io
23
+ import os
24
+ import time
25
+ from typing import Any
26
+
27
+ import torch
28
+
29
+
30
+ class ObjectStoreAllReduce:
31
+ """fsspec-backed pseudo-gradient rendezvous.
32
+
33
+ Each call to `allreduce(tensor, name)` blocks until all peers have
34
+ written their version of `tensor` to the rendezvous location, then
35
+ returns the average.
36
+
37
+ Args:
38
+ uri: fsspec URI like "s3://bucket/path/" or "file:///tmp/diloco/" or
39
+ a plain path "/tmp/diloco/run42/" (treated as file://).
40
+ rank: this replica's rank (0-indexed)
41
+ world_size: total number of replicas
42
+ round_id: optional, used to namespace successive sync rounds.
43
+ If None, a monotonically increasing counter is used internally.
44
+ timeout_s: per-allreduce timeout in seconds.
45
+ poll_interval_s: how often to check for peer files.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ uri: str,
51
+ rank: int,
52
+ world_size: int,
53
+ *,
54
+ round_id: int | None = None,
55
+ timeout_s: float = 1800.0,
56
+ poll_interval_s: float = 1.0,
57
+ ) -> None:
58
+ if not (0 <= rank < world_size):
59
+ raise ValueError(f"rank {rank} not in [0, {world_size})")
60
+ self.uri = uri.rstrip("/") + "/"
61
+ self.rank = rank
62
+ self.world_size = world_size
63
+ self.timeout_s = timeout_s
64
+ self.poll_interval_s = poll_interval_s
65
+ self._round_counter = 0 if round_id is None else round_id
66
+
67
+ # Lazy fsspec init; deferred so that local-only smoke tests don't
68
+ # require fsspec install in the dev environment.
69
+ self._fs = None
70
+ self._is_local = self.uri.startswith("file://") or self.uri.startswith("/")
71
+ if self._is_local:
72
+ local_path = self.uri.removeprefix("file://")
73
+ os.makedirs(local_path, exist_ok=True)
74
+ self._local_root = local_path
75
+ else:
76
+ self._init_fsspec()
77
+
78
+ def _init_fsspec(self) -> None:
79
+ try:
80
+ import fsspec # noqa: F401
81
+ except ImportError as e:
82
+ raise RuntimeError(
83
+ "Non-local rendezvous requires fsspec; install with "
84
+ "`pip install -e .[serverless]`. Got: " + repr(e)
85
+ )
86
+ import fsspec
87
+ protocol = self.uri.split("://", 1)[0] if "://" in self.uri else "file"
88
+ self._fs = fsspec.filesystem(protocol)
89
+
90
+ @property
91
+ def round_id(self) -> int:
92
+ return self._round_counter
93
+
94
+ def _round_dir(self, round_id: int) -> str:
95
+ return f"round_{round_id:06d}"
96
+
97
+ def _path_for(self, round_id: int, rank: int) -> str:
98
+ return f"{self._round_dir(round_id)}/rank_{rank:04d}.pt"
99
+
100
+ def _put(self, relpath: str, payload: bytes) -> None:
101
+ if self._is_local:
102
+ full = os.path.join(self._local_root, relpath)
103
+ os.makedirs(os.path.dirname(full), exist_ok=True)
104
+ tmp = full + ".tmp"
105
+ with open(tmp, "wb") as f:
106
+ f.write(payload)
107
+ os.replace(tmp, full) # atomic on POSIX
108
+ else:
109
+ full = self.uri + relpath
110
+ assert self._fs is not None
111
+ with self._fs.open(full, "wb") as f:
112
+ f.write(payload)
113
+
114
+ def _get(self, relpath: str) -> bytes:
115
+ if self._is_local:
116
+ full = os.path.join(self._local_root, relpath)
117
+ with open(full, "rb") as f:
118
+ return f.read()
119
+ full = self.uri + relpath
120
+ assert self._fs is not None
121
+ with self._fs.open(full, "rb") as f:
122
+ return f.read()
123
+
124
+ def _exists(self, relpath: str) -> bool:
125
+ if self._is_local:
126
+ return os.path.exists(os.path.join(self._local_root, relpath))
127
+ full = self.uri + relpath
128
+ assert self._fs is not None
129
+ return self._fs.exists(full)
130
+
131
+ def allreduce(self, tensor: torch.Tensor, *, name: str | None = None) -> torch.Tensor:
132
+ """Average `tensor` across all replicas via the object store.
133
+
134
+ Args:
135
+ tensor: the tensor to average. Modified in-place AND returned.
136
+ name: ignored — provided for API compat with torchft.Manager.
137
+
138
+ Returns:
139
+ The averaged tensor (modifies in-place; returns the same object).
140
+ """
141
+ round_id = self._round_counter
142
+ self._round_counter += 1
143
+
144
+ # Serialize my tensor
145
+ buf = io.BytesIO()
146
+ torch.save({"rank": self.rank, "tensor": tensor.detach().cpu()}, buf)
147
+ my_path = self._path_for(round_id, self.rank)
148
+ self._put(my_path, buf.getvalue())
149
+
150
+ # Wait for all peers
151
+ deadline = time.time() + self.timeout_s
152
+ peer_tensors: list[torch.Tensor] = []
153
+ for peer_rank in range(self.world_size):
154
+ peer_path = self._path_for(round_id, peer_rank)
155
+ while not self._exists(peer_path):
156
+ if time.time() > deadline:
157
+ raise TimeoutError(
158
+ f"ObjectStoreAllReduce: timed out waiting for "
159
+ f"rank {peer_rank} at {self.uri}{peer_path} "
160
+ f"(world_size={self.world_size}, round={round_id})"
161
+ )
162
+ time.sleep(self.poll_interval_s)
163
+ payload = self._get(peer_path)
164
+ peer_data = torch.load(io.BytesIO(payload), weights_only=False)
165
+ peer_tensors.append(peer_data["tensor"].to(tensor.device, dtype=tensor.dtype))
166
+
167
+ # Compute average
168
+ stacked = torch.stack(peer_tensors, dim=0)
169
+ avg = stacked.mean(dim=0)
170
+ tensor.copy_(avg)
171
+ return tensor
172
+
173
+
174
+ # ---------------------------------------------------------------------
175
+ # MockManager — torchft.Manager-shaped object that uses ObjectStoreAllReduce
176
+ # ---------------------------------------------------------------------
177
+
178
+
179
+ class MockManager:
180
+ """Drop-in replacement for `torchft.Manager` that delegates allreduce
181
+ to `ObjectStoreAllReduce`.
182
+
183
+ The torchft `DiLoCo` class accepts a `Manager` and calls its `.allreduce`
184
+ method on the pseudo-gradient. By passing this mock instead, we route
185
+ that call through the object store, leaving the rest of the DiLoCo
186
+ machinery (sign convention, post-hook sequencing, etc.) untouched.
187
+
188
+ Reference: `make_diloco_outer_loop` in
189
+ `composer_replication/diloco/__init__.py` accepts an optional
190
+ `manager=` kwarg; pass a `MockManager` to enable serverless DiLoCo.
191
+ """
192
+ def __init__(self, store: ObjectStoreAllReduce) -> None:
193
+ self._store = store
194
+ # torchft Manager attributes that DiLoCo consults
195
+ self.num_participants = store.world_size
196
+ self.rank = store.rank
197
+
198
+ def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> torch.Tensor:
199
+ return self._store.allreduce(tensor)
200
+
201
+ # torchft.Manager has additional methods (`should_commit`, `start_quorum`,
202
+ # etc.) that are no-ops for our coarse-grained sync. The `DiLoCo` class
203
+ # only requires `allreduce`, but the others may be probed.
204
+ def should_commit(self) -> bool:
205
+ return True
206
+
207
+ def start_quorum(self) -> None:
208
+ pass
209
+
210
+ def wait_quorum(self) -> int:
211
+ return self.num_participants
212
+
213
+
214
+ __all__ = ["MockManager", "ObjectStoreAllReduce"]
composer_replication/diloco/serverless/executor.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ServerlessExecutor Protocol + LocalProcessExecutor.
2
+
3
+ Per ADR-005:
4
+ - `ServerlessExecutor` is a structural Protocol — backends implement it
5
+ by writing a class with the right methods, no formal inheritance needed.
6
+ - `LocalProcessExecutor` is the reference implementation that uses Python's
7
+ `multiprocessing` module. It's used for tests and for development; the
8
+ cloud adapters (Modal, HF Jobs, …) implement the same Protocol against
9
+ their respective backends.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import multiprocessing as mp
14
+ import sys
15
+ import time
16
+ from dataclasses import dataclass, field
17
+ from typing import Any, Callable, Mapping, Protocol, runtime_checkable
18
+
19
+
20
+ @dataclass
21
+ class ReplicaHandle:
22
+ """Opaque handle to a launched replica. Backend-specific contents.
23
+
24
+ All executors return `list[ReplicaHandle]` from `launch_replicas`.
25
+ Each handle's `metadata` dict is backend-specific; users shouldn't
26
+ rely on its shape.
27
+ """
28
+ rank: int
29
+ backend_name: str
30
+ metadata: dict[str, Any] = field(default_factory=dict)
31
+ """Backend-specific data (e.g. Modal call ID, HF Jobs job ID, local
32
+ Process object). Not stable across backends."""
33
+
34
+
35
+ @runtime_checkable
36
+ class ServerlessExecutor(Protocol):
37
+ """Uniform interface for launching N replicas on a serverless backend.
38
+
39
+ Implementations: `LocalProcessExecutor` (test/dev), `ModalExecutor`
40
+ (Modal, v0), `HFJobsExecutor` (HuggingFace Jobs, v0). Future:
41
+ `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`.
42
+
43
+ Note on rank assignment: the Protocol guarantees that handles are
44
+ returned in rank order (`handles[i].rank == i`). The replica entrypoint
45
+ learns its own rank either from an environment variable
46
+ (`REPLICA_RANK`) or from a backend-provided mechanism (Modal's
47
+ `Function.shard_rank`, etc.). The executor abstraction normalizes
48
+ rank by setting the env var.
49
+ """
50
+ backend_name: str
51
+ supports_inter_replica_network: bool
52
+
53
+ def launch_replicas(
54
+ self,
55
+ n_replicas: int,
56
+ entrypoint: str | Callable[..., Any],
57
+ entrypoint_args: Mapping[str, Any],
58
+ *,
59
+ gpu: str | None = None,
60
+ timeout: int = 3600,
61
+ ) -> list[ReplicaHandle]:
62
+ """Spin up N replicas in parallel.
63
+
64
+ Args:
65
+ n_replicas: number of replicas to launch
66
+ entrypoint: either an importable Python path (e.g.
67
+ "composer_replication.diloco.serverless.replica_entrypoint")
68
+ or a Callable (Local executor only).
69
+ entrypoint_args: kwargs passed to the entrypoint. The kwarg
70
+ `rank_env` (default "REPLICA_RANK") names the environment
71
+ variable in which the rank will be set on the replica.
72
+ gpu: backend-specific GPU spec, e.g. "A100", "H100". `None`
73
+ means CPU-only (smoke tests).
74
+ timeout: per-replica wall-clock timeout in seconds.
75
+
76
+ Returns:
77
+ list[ReplicaHandle] of length n_replicas, in rank order.
78
+ """
79
+ ...
80
+
81
+ def poll(self, handle: ReplicaHandle) -> str:
82
+ """Poll a replica's status. Returns one of:
83
+ "pending" | "running" | "succeeded" | "failed" | "cancelled".
84
+ """
85
+ ...
86
+
87
+ def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
88
+ """Read up to n_lines of recent stdout/stderr from a replica."""
89
+ ...
90
+
91
+ def cancel(self, handle: ReplicaHandle) -> None:
92
+ """Best-effort cancel. No exception if already terminated."""
93
+ ...
94
+
95
+ def collect(
96
+ self,
97
+ handles: list[ReplicaHandle],
98
+ *,
99
+ timeout: int | None = None,
100
+ ) -> list[dict[str, Any]]:
101
+ """Block until all replicas finish; return per-replica result dicts.
102
+
103
+ Each result dict contains at least:
104
+ {"rank": int, "status": str, "exit_code": int | None,
105
+ "error": str | None}
106
+ """
107
+ ...
108
+
109
+
110
+ # ---------------------------------------------------------------------
111
+ # LocalProcessExecutor — reference implementation using multiprocessing
112
+ # ---------------------------------------------------------------------
113
+
114
+
115
+ def _local_replica_target(
116
+ rank: int,
117
+ rank_env: str,
118
+ entrypoint: Any,
119
+ entrypoint_args: Mapping[str, Any],
120
+ result_queue: mp.Queue,
121
+ ) -> None:
122
+ """multiprocessing target — runs in the child process."""
123
+ import os
124
+ import traceback
125
+
126
+ os.environ[rank_env] = str(rank)
127
+ try:
128
+ if callable(entrypoint):
129
+ result = entrypoint(**entrypoint_args)
130
+ elif isinstance(entrypoint, str):
131
+ # importable path
132
+ mod_path, _, fn_name = entrypoint.rpartition(".")
133
+ if not mod_path:
134
+ # Top-level script path; just import it and call its main()
135
+ import importlib
136
+ mod = importlib.import_module(entrypoint)
137
+ fn = getattr(mod, "main", None)
138
+ if fn is None:
139
+ raise RuntimeError(
140
+ f"entrypoint '{entrypoint}' has no main() function"
141
+ )
142
+ result = fn(**entrypoint_args)
143
+ else:
144
+ import importlib
145
+ mod = importlib.import_module(mod_path)
146
+ fn = getattr(mod, fn_name)
147
+ result = fn(**entrypoint_args)
148
+ else:
149
+ raise TypeError(
150
+ f"entrypoint must be Callable or importable str, got {type(entrypoint)!r}"
151
+ )
152
+ result_queue.put({"rank": rank, "status": "succeeded",
153
+ "exit_code": 0, "error": None, "result": result})
154
+ except Exception as e:
155
+ tb = traceback.format_exc()
156
+ result_queue.put({"rank": rank, "status": "failed",
157
+ "exit_code": 1, "error": f"{e!r}\n{tb}", "result": None})
158
+
159
+
160
+ class LocalProcessExecutor:
161
+ """Runs replicas as subprocesses on the local machine.
162
+
163
+ Use cases:
164
+ - Test the serverless layer end-to-end without cloud spend.
165
+ - Develop the algorithm locally with N>1 replicas and `file://`
166
+ rendezvous before deploying to Modal/HF Jobs.
167
+ - CI smoke tests.
168
+ """
169
+ backend_name = "local_process"
170
+ supports_inter_replica_network = True # localhost works
171
+
172
+ def __init__(self) -> None:
173
+ # use 'spawn' so the child has a fresh interpreter (avoid CUDA fork issues)
174
+ try:
175
+ self._ctx = mp.get_context("spawn")
176
+ except ValueError:
177
+ # Fallback for environments where 'spawn' isn't available
178
+ self._ctx = mp.get_context()
179
+ self._handles: dict[int, dict[str, Any]] = {}
180
+
181
+ def launch_replicas(
182
+ self,
183
+ n_replicas: int,
184
+ entrypoint: str | Callable[..., Any],
185
+ entrypoint_args: Mapping[str, Any],
186
+ *,
187
+ gpu: str | None = None,
188
+ timeout: int = 3600,
189
+ ) -> list[ReplicaHandle]:
190
+ if gpu is not None:
191
+ # Local executor doesn't pin GPUs; emit a soft warning.
192
+ sys.stderr.write(
193
+ f"[LocalProcessExecutor] gpu={gpu!r} ignored — "
194
+ f"local processes share whatever GPUs are visible.\n"
195
+ )
196
+ rank_env = entrypoint_args.get("rank_env", "REPLICA_RANK")
197
+
198
+ handles: list[ReplicaHandle] = []
199
+ result_queue: mp.Queue = self._ctx.Queue()
200
+ for rank in range(n_replicas):
201
+ args_for_rank = dict(entrypoint_args)
202
+ args_for_rank.pop("rank_env", None)
203
+ proc = self._ctx.Process(
204
+ target=_local_replica_target,
205
+ args=(rank, rank_env, entrypoint, args_for_rank, result_queue),
206
+ name=f"composer-replica-{rank}",
207
+ )
208
+ proc.start()
209
+ handle = ReplicaHandle(
210
+ rank=rank, backend_name=self.backend_name,
211
+ metadata={"pid": proc.pid, "start_ts": time.time()},
212
+ )
213
+ self._handles[rank] = {"proc": proc, "queue": result_queue,
214
+ "deadline": time.time() + timeout,
215
+ "result": None}
216
+ handles.append(handle)
217
+ return handles
218
+
219
+ def poll(self, handle: ReplicaHandle) -> str:
220
+ meta = self._handles.get(handle.rank)
221
+ if meta is None:
222
+ return "cancelled"
223
+ proc: mp.Process = meta["proc"]
224
+ if proc.is_alive():
225
+ return "running"
226
+ if meta.get("result") is not None:
227
+ return meta["result"]["status"]
228
+ # Process exited; read result if available
229
+ try:
230
+ queue: mp.Queue = meta["queue"]
231
+ while not queue.empty():
232
+ r = queue.get_nowait()
233
+ self._handles[r["rank"]]["result"] = r
234
+ except Exception:
235
+ pass
236
+ if meta.get("result") is not None:
237
+ return meta["result"]["status"]
238
+ return "failed" if proc.exitcode != 0 else "succeeded"
239
+
240
+ def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
241
+ # multiprocessing.Process doesn't natively capture stdout; we'd
242
+ # need a Pipe or file redirection. For the local reference impl,
243
+ # we just point the user at the result dict's `error` field.
244
+ meta = self._handles.get(handle.rank)
245
+ if meta is None:
246
+ return f"<replica {handle.rank}: no metadata>"
247
+ if meta.get("result"):
248
+ err = meta["result"].get("error") or ""
249
+ return f"[rank {handle.rank}] {err[-2000:]}"
250
+ return f"<replica {handle.rank}: still running, no captured logs>"
251
+
252
+ def cancel(self, handle: ReplicaHandle) -> None:
253
+ meta = self._handles.get(handle.rank)
254
+ if meta is None:
255
+ return
256
+ proc: mp.Process = meta["proc"]
257
+ if proc.is_alive():
258
+ proc.terminate()
259
+ proc.join(timeout=5)
260
+ if proc.is_alive():
261
+ proc.kill()
262
+
263
+ def collect(
264
+ self,
265
+ handles: list[ReplicaHandle],
266
+ *,
267
+ timeout: int | None = None,
268
+ ) -> list[dict[str, Any]]:
269
+ deadline = time.time() + (timeout if timeout is not None else 3600)
270
+ # Wait for all processes to finish
271
+ for h in handles:
272
+ meta = self._handles.get(h.rank)
273
+ if meta is None:
274
+ continue
275
+ proc: mp.Process = meta["proc"]
276
+ remaining = max(0.0, deadline - time.time())
277
+ proc.join(timeout=remaining)
278
+ if proc.is_alive():
279
+ proc.terminate()
280
+ proc.join(timeout=5)
281
+ # Drain results
282
+ results_by_rank: dict[int, dict[str, Any]] = {}
283
+ for h in handles:
284
+ meta = self._handles.get(h.rank)
285
+ if meta is None:
286
+ results_by_rank[h.rank] = {
287
+ "rank": h.rank, "status": "cancelled",
288
+ "exit_code": None, "error": "no metadata", "result": None,
289
+ }
290
+ continue
291
+ queue: mp.Queue = meta["queue"]
292
+ while not queue.empty():
293
+ try:
294
+ r = queue.get_nowait()
295
+ results_by_rank[r["rank"]] = r
296
+ except Exception:
297
+ break
298
+ if h.rank not in results_by_rank:
299
+ proc: mp.Process = meta["proc"]
300
+ results_by_rank[h.rank] = {
301
+ "rank": h.rank,
302
+ "status": "succeeded" if proc.exitcode == 0 else "failed",
303
+ "exit_code": proc.exitcode,
304
+ "error": None if proc.exitcode == 0 else f"exit code {proc.exitcode}",
305
+ "result": None,
306
+ }
307
+ return [results_by_rank[h.rank] for h in handles]
308
+
309
+
310
+ __all__ = ["LocalProcessExecutor", "ReplicaHandle", "ServerlessExecutor"]
composer_replication/diloco/serverless/hf_jobs.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace Jobs executor — skeleton for v0.
2
+
3
+ Per ADR-005, HF Jobs is one of two v0 target executors. This file is a
4
+ STUB. The full integration uses `huggingface_hub.run_job` (added in
5
+ huggingface_hub >= 0.27, ~2026 era) which spins up a containerized job
6
+ backed by HF's compute pool.
7
+
8
+ Pricing reference (2026-05-26): A100 ≈ $4.18/hr, H100 ≈ $9.50/hr. Cold
9
+ start ≈ 60s. NO inter-job networking — must use object-store rendezvous.
10
+
11
+ Status: SKELETON. Real implementation pending v0 polish wave.
12
+ """
13
+ from __future__ import annotations
14
+
15
+ from typing import Any, Callable, Mapping
16
+
17
+ from composer_replication.diloco.serverless.executor import (
18
+ ReplicaHandle,
19
+ ServerlessExecutor,
20
+ )
21
+
22
+
23
+ class HFJobsExecutor(ServerlessExecutor):
24
+ """Run replicas as HuggingFace Jobs in parallel.
25
+
26
+ Reference implementation pattern:
27
+
28
+ from huggingface_hub import run_job
29
+ jobs = []
30
+ for rank in range(N):
31
+ job = run_job(
32
+ image="...", # container with composer_replication installed
33
+ command=[
34
+ "python", "-m",
35
+ "composer_replication.diloco.serverless.replica_entrypoint",
36
+ "--rank", str(rank),
37
+ "--rendezvous", "hf://datasets/myuser/run42/",
38
+ ],
39
+ env={"REPLICA_RANK": str(rank), "WORLD_SIZE": str(N)},
40
+ gpu="a100",
41
+ )
42
+ jobs.append(job)
43
+ return [ReplicaHandle(rank=i, backend_name="hf_jobs",
44
+ metadata={"job_id": jobs[i].id})
45
+ for i in range(N)]
46
+
47
+ Object-store rendezvous works naturally with the HF Datasets-as-storage
48
+ pattern — `hf://datasets/{user}/{run_id}/` is fsspec-compatible via
49
+ `huggingface_hub`'s fsspec integration.
50
+
51
+ Status: SKELETON.
52
+ """
53
+ backend_name = "hf_jobs"
54
+ supports_inter_replica_network = False
55
+
56
+ def __init__(self) -> None:
57
+ try:
58
+ from huggingface_hub import HfApi # noqa: F401
59
+ except ImportError as e:
60
+ raise RuntimeError(
61
+ "HFJobsExecutor requires huggingface_hub. Got: " + repr(e)
62
+ )
63
+ # Real implementation: instantiate HfApi, validate token, etc.
64
+ raise NotImplementedError(
65
+ "HFJobsExecutor is a v0 skeleton; full implementation pending. "
66
+ "Use LocalProcessExecutor for testing."
67
+ )
68
+
69
+ def launch_replicas(
70
+ self,
71
+ n_replicas: int,
72
+ entrypoint: str | Callable[..., Any],
73
+ entrypoint_args: Mapping[str, Any],
74
+ *,
75
+ gpu: str | None = "a100",
76
+ timeout: int = 3600,
77
+ ) -> list[ReplicaHandle]:
78
+ raise NotImplementedError
79
+
80
+ def poll(self, handle: ReplicaHandle) -> str:
81
+ raise NotImplementedError
82
+
83
+ def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
84
+ raise NotImplementedError
85
+
86
+ def cancel(self, handle: ReplicaHandle) -> None:
87
+ raise NotImplementedError
88
+
89
+ def collect(
90
+ self,
91
+ handles: list[ReplicaHandle],
92
+ *,
93
+ timeout: int | None = None,
94
+ ) -> list[dict[str, Any]]:
95
+ raise NotImplementedError
96
+
97
+
98
+ __all__ = ["HFJobsExecutor"]
composer_replication/diloco/serverless/modal.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modal executor — skeleton for v0.
2
+
3
+ This file is a STUB. The full Modal integration requires the `modal`
4
+ client library installed (`pip install modal`) and a configured Modal
5
+ account (`~/.modal.toml`). The user's environment has both, but the
6
+ test suite must run without them, so we keep this file import-safe.
7
+
8
+ Real implementation lives in v0 polish; the docstring below is the
9
+ contract.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from typing import Any, Callable, Mapping
14
+
15
+ from composer_replication.diloco.serverless.executor import (
16
+ ReplicaHandle,
17
+ ServerlessExecutor,
18
+ )
19
+
20
+
21
+ class ModalExecutor(ServerlessExecutor):
22
+ """Run replicas as Modal Functions in parallel.
23
+
24
+ Reference implementation pattern (per ADR-005):
25
+
26
+ @app.function(gpu="A100-40GB", timeout=3600)
27
+ def run_replica(rank: int, rendezvous_uri: str, **kwargs):
28
+ os.environ["REPLICA_RANK"] = str(rank)
29
+ from composer_replication.diloco.serverless import (
30
+ MockManager, ObjectStoreAllReduce,
31
+ )
32
+ store = ObjectStoreAllReduce(rendezvous_uri,
33
+ rank=rank, world_size=N)
34
+ manager = MockManager(store)
35
+ # ... run the trainer with this manager ...
36
+
37
+ Then `launch_replicas` does:
38
+ calls = [run_replica.spawn(rank=i, ...) for i in range(N)]
39
+ return [ReplicaHandle(rank=i, backend_name="modal",
40
+ metadata={"call_id": calls[i].object_id})
41
+ for i in range(N)]
42
+
43
+ Pricing reference (2026-05-26): A100-40GB ≈ $1.95/hr, H100 ≈ $5.50/hr.
44
+ Cold start ≈ 30s. Inter-job networking via cluster mode (opt-in,
45
+ not used by default).
46
+
47
+ Status: SKELETON. Real implementation pending v0 polish wave.
48
+ """
49
+ backend_name = "modal"
50
+ supports_inter_replica_network = False # default; cluster mode = True
51
+
52
+ def __init__(self, *, app_name: str = "composer-replication-diloco") -> None:
53
+ try:
54
+ import modal # noqa: F401
55
+ except ImportError as e:
56
+ raise RuntimeError(
57
+ "ModalExecutor requires the modal client. Install with "
58
+ "`pip install modal` and configure with `modal token new`. "
59
+ "Got: " + repr(e)
60
+ )
61
+ self.app_name = app_name
62
+ # Real implementation: build a `modal.App` and register `run_replica`
63
+ # here so that subsequent `launch_replicas` can `.spawn()` it.
64
+ raise NotImplementedError(
65
+ "ModalExecutor is a v0 skeleton; full implementation pending. "
66
+ "Use LocalProcessExecutor for testing."
67
+ )
68
+
69
+ # All Protocol methods raise NotImplementedError via __init__ — the
70
+ # class never instantiates successfully in the skeleton. Sketch
71
+ # signatures here for documentation:
72
+
73
+ def launch_replicas(
74
+ self,
75
+ n_replicas: int,
76
+ entrypoint: str | Callable[..., Any],
77
+ entrypoint_args: Mapping[str, Any],
78
+ *,
79
+ gpu: str | None = "A100-40GB",
80
+ timeout: int = 3600,
81
+ ) -> list[ReplicaHandle]:
82
+ raise NotImplementedError
83
+
84
+ def poll(self, handle: ReplicaHandle) -> str:
85
+ raise NotImplementedError
86
+
87
+ def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
88
+ raise NotImplementedError
89
+
90
+ def cancel(self, handle: ReplicaHandle) -> None:
91
+ raise NotImplementedError
92
+
93
+ def collect(
94
+ self,
95
+ handles: list[ReplicaHandle],
96
+ *,
97
+ timeout: int | None = None,
98
+ ) -> list[dict[str, Any]]:
99
+ raise NotImplementedError
100
+
101
+
102
+ __all__ = ["ModalExecutor"]
composer_replication/diloco/serverless/replica_entrypoint.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Replica entrypoint — what each serverless replica runs.
2
+
3
+ This is the script invoked by `LocalProcessExecutor`, `ModalExecutor`,
4
+ `HFJobsExecutor`, etc. It learns its rank from the `REPLICA_RANK` env
5
+ var, sets up `ObjectStoreAllReduce` against the shared rendezvous URI,
6
+ wraps it in a `MockManager`, and hands it off to the user's training
7
+ function.
8
+
9
+ Usage from an executor:
10
+
11
+ >>> executor.launch_replicas(
12
+ ... n_replicas=4,
13
+ ... entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
14
+ ... entrypoint_args={
15
+ ... "rendezvous_uri": "/tmp/run42/",
16
+ ... "world_size": 4,
17
+ ... "trainer_module": "my_project.trainer",
18
+ ... "trainer_fn": "train",
19
+ ... "trainer_kwargs": {"model_name": "Qwen/Qwen2.5-0.5B"},
20
+ ... },
21
+ ... )
22
+
23
+ The entrypoint expects:
24
+ - `REPLICA_RANK` env var set to the rank (0..world_size-1)
25
+ - `rendezvous_uri`: fsspec URI for object-store rendezvous
26
+ - `world_size`: total replicas
27
+ - `trainer_module`, `trainer_fn`: importable path to the user's train fn
28
+ - `trainer_kwargs`: dict passed to the user's train fn, plus an injected
29
+ `manager` kwarg containing the `MockManager`
30
+ """
31
+ from __future__ import annotations
32
+
33
+ import importlib
34
+ import os
35
+ from typing import Any
36
+
37
+
38
+ def main(
39
+ rendezvous_uri: str,
40
+ world_size: int,
41
+ trainer_module: str,
42
+ trainer_fn: str = "train",
43
+ trainer_kwargs: dict[str, Any] | None = None,
44
+ ) -> Any:
45
+ """Entrypoint executed inside each replica.
46
+
47
+ Args:
48
+ rendezvous_uri: fsspec URI (or local path) for the rendezvous
49
+ world_size: total replicas
50
+ trainer_module: importable Python module containing the user's
51
+ train function
52
+ trainer_fn: name of the function to call (default "train")
53
+ trainer_kwargs: kwargs passed to the train function
54
+
55
+ Returns:
56
+ Whatever the train function returns.
57
+ """
58
+ from composer_replication.diloco.serverless.allreduce import (
59
+ MockManager,
60
+ ObjectStoreAllReduce,
61
+ )
62
+
63
+ rank_str = os.environ.get("REPLICA_RANK")
64
+ if rank_str is None:
65
+ raise RuntimeError(
66
+ "REPLICA_RANK env var not set. The serverless executor "
67
+ "should set this for each replica."
68
+ )
69
+ rank = int(rank_str)
70
+
71
+ if not (0 <= rank < world_size):
72
+ raise ValueError(f"REPLICA_RANK={rank} not in [0, {world_size})")
73
+
74
+ store = ObjectStoreAllReduce(
75
+ uri=rendezvous_uri,
76
+ rank=rank,
77
+ world_size=world_size,
78
+ )
79
+ manager = MockManager(store)
80
+
81
+ mod = importlib.import_module(trainer_module)
82
+ fn = getattr(mod, trainer_fn)
83
+
84
+ kwargs = dict(trainer_kwargs or {})
85
+ kwargs["manager"] = manager # injected
86
+ kwargs["rank"] = rank
87
+ kwargs["world_size"] = world_size
88
+ return fn(**kwargs)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ import argparse
93
+ import json
94
+
95
+ parser = argparse.ArgumentParser()
96
+ parser.add_argument("--rendezvous", required=True)
97
+ parser.add_argument("--world-size", type=int, required=True)
98
+ parser.add_argument("--trainer-module", required=True)
99
+ parser.add_argument("--trainer-fn", default="train")
100
+ parser.add_argument("--trainer-kwargs-json", default="{}")
101
+ args = parser.parse_args()
102
+
103
+ main(
104
+ rendezvous_uri=args.rendezvous,
105
+ world_size=args.world_size,
106
+ trainer_module=args.trainer_module,
107
+ trainer_fn=args.trainer_fn,
108
+ trainer_kwargs=json.loads(args.trainer_kwargs_json),
109
+ )
composer_replication/diloco/serverless/tests/__init__.py ADDED
File without changes
composer_replication/diloco/serverless/tests/test_serverless_local.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Verifies the serverless DiLoCo allreduce wraps correctly across local
2
+ multiprocessing replicas using `file://` rendezvous.
3
+
4
+ This is the core multi-process test for the serverless layer. It exercises
5
+ the real allreduce barrier (with concurrent processes), not just the
6
+ single-process API.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ import sys
12
+ import tempfile
13
+ import time
14
+
15
+ import pytest
16
+ import torch
17
+
18
+ from composer_replication.diloco.serverless import (
19
+ LocalProcessExecutor,
20
+ ObjectStoreAllReduce,
21
+ ReplicaHandle,
22
+ )
23
+
24
+
25
+ # ---------------------------------------------------------------------
26
+ # Single-process tests of ObjectStoreAllReduce primitives
27
+ # (don't need executor, just the file:// path + local manual orchestration)
28
+ # ---------------------------------------------------------------------
29
+
30
+
31
+ def test_object_store_allreduce_init_validates_rank():
32
+ with tempfile.TemporaryDirectory() as td:
33
+ with pytest.raises(ValueError, match="not in"):
34
+ ObjectStoreAllReduce(td, rank=5, world_size=2)
35
+
36
+
37
+ def test_object_store_allreduce_local_paths_create_dir():
38
+ """Local backend should mkdir on init."""
39
+ with tempfile.TemporaryDirectory() as td:
40
+ new_path = os.path.join(td, "subdir", "subsubdir")
41
+ store = ObjectStoreAllReduce(new_path, rank=0, world_size=1)
42
+ assert os.path.isdir(new_path)
43
+ assert store.world_size == 1
44
+
45
+
46
+ def test_object_store_allreduce_world_size_1_passthrough():
47
+ """With world_size=1 it just averages the tensor with itself."""
48
+ with tempfile.TemporaryDirectory() as td:
49
+ store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
50
+ t = torch.tensor([1.0, 2.0, 3.0])
51
+ result = store.allreduce(t.clone())
52
+ torch.testing.assert_close(result, t, atol=1e-6, rtol=1e-6)
53
+
54
+
55
+ def test_object_store_allreduce_round_id_increments():
56
+ with tempfile.TemporaryDirectory() as td:
57
+ store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
58
+ t = torch.zeros(3)
59
+ assert store.round_id == 0
60
+ store.allreduce(t.clone())
61
+ assert store.round_id == 1
62
+ store.allreduce(t.clone())
63
+ assert store.round_id == 2
64
+
65
+
66
+ # ---------------------------------------------------------------------
67
+ # Multi-process tests (the real verification — local executor + spawn)
68
+ # ---------------------------------------------------------------------
69
+
70
+
71
+ def _replica_compute_and_sync(
72
+ rendezvous_uri: str,
73
+ world_size: int,
74
+ rank_value: float,
75
+ ) -> dict:
76
+ """Top-level function — must be importable for multiprocessing 'spawn'.
77
+
78
+ Each replica creates a tensor whose value is `rank_value * (rank+1)` and
79
+ runs allreduce. The expected result is the mean of all replicas' tensors.
80
+ """
81
+ rank = int(os.environ["REPLICA_RANK"])
82
+ store = ObjectStoreAllReduce(
83
+ rendezvous_uri, rank=rank, world_size=world_size, timeout_s=120.0,
84
+ )
85
+ # tensor that depends on rank
86
+ t = torch.full((4,), float(rank_value * (rank + 1)))
87
+ pre = t.clone()
88
+ averaged = store.allreduce(t)
89
+ return {
90
+ "rank": rank,
91
+ "pre": pre.tolist(),
92
+ "post": averaged.tolist(),
93
+ "world_size": world_size,
94
+ }
95
+
96
+
97
+ @pytest.mark.parametrize("n_replicas", [2, 3])
98
+ def test_local_executor_runs_allreduce_across_replicas(n_replicas):
99
+ """End-to-end: 2-3 replica processes each call allreduce; result is the mean."""
100
+ with tempfile.TemporaryDirectory() as td:
101
+ rendezvous = os.path.join(td, "run")
102
+ executor = LocalProcessExecutor()
103
+ handles = executor.launch_replicas(
104
+ n_replicas=n_replicas,
105
+ entrypoint=f"{__name__}._replica_compute_and_sync",
106
+ entrypoint_args={
107
+ "rendezvous_uri": rendezvous,
108
+ "world_size": n_replicas,
109
+ "rank_value": 10.0,
110
+ "rank_env": "REPLICA_RANK",
111
+ },
112
+ timeout=180,
113
+ )
114
+ assert len(handles) == n_replicas
115
+ for i, h in enumerate(handles):
116
+ assert h.rank == i
117
+ assert h.backend_name == "local_process"
118
+
119
+ results = executor.collect(handles, timeout=180)
120
+ assert len(results) == n_replicas
121
+
122
+ # Verify all succeeded
123
+ for r in results:
124
+ assert r["status"] == "succeeded", \
125
+ f"rank {r['rank']} failed: {r.get('error')}"
126
+
127
+ # Each replica created tensor full(rank_value * (rank+1)).
128
+ # Expected mean = rank_value * (1+2+...+N) / N
129
+ N = n_replicas
130
+ expected_mean = 10.0 * (N * (N + 1) / 2) / N
131
+
132
+ for r in results:
133
+ post = r["result"]["post"]
134
+ for v in post:
135
+ assert abs(v - expected_mean) < 1e-4, \
136
+ f"rank {r['rank']}: expected mean {expected_mean}, got {v}"
137
+
138
+
139
+ def _replica_two_round_sync(
140
+ rendezvous_uri: str,
141
+ world_size: int,
142
+ ) -> dict:
143
+ """Each replica does TWO consecutive allreduce calls; checks round_id increments."""
144
+ rank = int(os.environ["REPLICA_RANK"])
145
+ store = ObjectStoreAllReduce(
146
+ rendezvous_uri, rank=rank, world_size=world_size, timeout_s=120.0,
147
+ )
148
+ t1 = torch.full((2,), float(rank))
149
+ avg1 = store.allreduce(t1).clone()
150
+ t2 = torch.full((2,), float(rank * 100))
151
+ avg2 = store.allreduce(t2).clone()
152
+ return {
153
+ "rank": rank,
154
+ "round_after_2_calls": store.round_id,
155
+ "avg1": avg1.tolist(),
156
+ "avg2": avg2.tolist(),
157
+ }
158
+
159
+
160
+ def test_local_executor_handles_multiple_rounds():
161
+ """Two consecutive rounds each give the right mean; round counter advances."""
162
+ n_replicas = 3
163
+ with tempfile.TemporaryDirectory() as td:
164
+ rendezvous = os.path.join(td, "run-2round")
165
+ executor = LocalProcessExecutor()
166
+ handles = executor.launch_replicas(
167
+ n_replicas=n_replicas,
168
+ entrypoint=f"{__name__}._replica_two_round_sync",
169
+ entrypoint_args={
170
+ "rendezvous_uri": rendezvous,
171
+ "world_size": n_replicas,
172
+ },
173
+ timeout=180,
174
+ )
175
+ results = executor.collect(handles, timeout=180)
176
+ for r in results:
177
+ assert r["status"] == "succeeded", r.get("error")
178
+ assert r["result"]["round_after_2_calls"] == 2
179
+ # mean of 0,1,2 = 1.0
180
+ assert all(abs(v - 1.0) < 1e-4 for v in r["result"]["avg1"])
181
+ # mean of 0,100,200 = 100.0
182
+ assert all(abs(v - 100.0) < 1e-4 for v in r["result"]["avg2"])
183
+
184
+
185
+ def _replica_that_raises(rendezvous_uri: str, world_size: int) -> dict:
186
+ """Simulates a replica that crashes mid-run."""
187
+ rank = int(os.environ["REPLICA_RANK"])
188
+ if rank == 1:
189
+ raise RuntimeError(f"Simulated crash on rank {rank}")
190
+ return {"rank": rank, "ok": True}
191
+
192
+
193
+ def test_local_executor_reports_failed_replicas():
194
+ """When a replica crashes, collect() reports it as failed without hanging
195
+ (other ranks complete; the failed one should be reflected in the result)."""
196
+ n_replicas = 2
197
+ with tempfile.TemporaryDirectory() as td:
198
+ rendezvous = os.path.join(td, "run-failure")
199
+ executor = LocalProcessExecutor()
200
+ handles = executor.launch_replicas(
201
+ n_replicas=n_replicas,
202
+ entrypoint=f"{__name__}._replica_that_raises",
203
+ entrypoint_args={
204
+ "rendezvous_uri": rendezvous,
205
+ "world_size": n_replicas,
206
+ },
207
+ timeout=30,
208
+ )
209
+ results = executor.collect(handles, timeout=30)
210
+ statuses = {r["rank"]: r["status"] for r in results}
211
+ assert statuses[0] == "succeeded"
212
+ assert statuses[1] == "failed"
213
+ # Failure log should mention the simulated crash
214
+ failure_log = next(r for r in results if r["rank"] == 1).get("error") or ""
215
+ assert "Simulated crash" in failure_log
216
+
217
+
218
+ # ---------------------------------------------------------------------
219
+ # Sanity: MockManager is shape-compatible with torchft Manager surface
220
+ # ---------------------------------------------------------------------
221
+
222
+
223
+ def test_mock_manager_shape_compat():
224
+ from composer_replication.diloco.serverless import MockManager
225
+ with tempfile.TemporaryDirectory() as td:
226
+ store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
227
+ mgr = MockManager(store)
228
+ # torchft.Manager surface
229
+ assert hasattr(mgr, "allreduce")
230
+ assert hasattr(mgr, "should_commit")
231
+ assert hasattr(mgr, "start_quorum")
232
+ assert hasattr(mgr, "wait_quorum")
233
+ assert mgr.num_participants == 1
234
+ assert mgr.rank == 0
235
+ assert mgr.should_commit() is True
236
+ # Single-replica allreduce is a passthrough
237
+ t = torch.tensor([1.0, 2.0])
238
+ out = mgr.allreduce(t.clone())
239
+ torch.testing.assert_close(out, t, atol=1e-6, rtol=1e-6)
composer_replication/distillation/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """composer_replication.distillation — pluggable self-distillation losses.
2
+
3
+ Per ADR-007, three losses additive to the framework's existing
4
+ SDPO/OPSD (`generalized_jsd_loss`):
5
+
6
+ - SimPO: reference-free DPO replacement (channel 3 alternative)
7
+ - TAID: annealed teacher interpolation (wraps generalized_jsd_loss for channel 2)
8
+ - Entropy-Aware OPD: token-wise gated forward/reverse KL (alternative
9
+ channel-2 wrapper, per ICLR 2026 Spotlight)
10
+
11
+ All three are pure PyTorch — no external deps — so they ship in the core
12
+ package without optional extras.
13
+
14
+ Usage in `compose_loss`:
15
+
16
+ >>> from composer_replication import compose_loss
17
+ >>> components = compose_loss(
18
+ ... model, batch,
19
+ ... dpo_variant="simpo", # channel 3: DPO -> SimPO
20
+ ... sdpo_wrapper="taid", # channel 2: SDPO -> TAID-SDPO
21
+ ... taid_schedule_step=1500, taid_total_steps=10_000,
22
+ ... )
23
+
24
+ Defaults are unchanged (pure DPO + pure SDPO).
25
+ """
26
+ from __future__ import annotations
27
+
28
+ from composer_replication.distillation.simpo import simpo_loss
29
+ from composer_replication.distillation.taid import taid_loss
30
+ from composer_replication.distillation.entropy_aware_opd import entropy_aware_opd_loss
31
+
32
+ __all__ = [
33
+ "simpo_loss",
34
+ "taid_loss",
35
+ "entropy_aware_opd_loss",
36
+ ]
composer_replication/distillation/entropy_aware_opd.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Entropy-Aware OPD — token-wise gated forward/reverse KL.
2
+
3
+ Paper: ICLR 2026 Spotlight "Entropy-Aware On-Policy Distillation"
4
+ (OpenReview WSRQ37tzk1, code release pending as of 2026-05-26)
5
+
6
+ Standard reverse-KL distillation (which SDPO/OPSD belongs to) has a known
7
+ mode-seeking failure: when the teacher distribution has high entropy at
8
+ some token positions (e.g. open-ended generation), reverse KL collapses
9
+ the student onto a single mode, throwing away the teacher's diversity.
10
+
11
+ Forward KL is mode-covering and would handle these positions correctly,
12
+ but is mode-flattening in the long tail.
13
+
14
+ Entropy-Aware OPD computes the per-token entropy of the teacher
15
+ distribution and gates between forward and reverse KL on a per-token
16
+ basis: high-entropy tokens use forward KL (preserve diversity),
17
+ low-entropy tokens use reverse KL (sharpen toward the teacher's mode).
18
+
19
+ L = Σ_t w(t) · KL_fwd(student || teacher)_t
20
+ + (1 - w(t)) · KL_rev(student || teacher)_t
21
+
22
+ Where w(t) = clamp(H_teacher(t) / H_max, 0, 1) — high entropy → forward
23
+ KL weight near 1, low entropy → reverse KL weight near 1.
24
+
25
+ This is a clean-room implementation from the paper's pseudocode pending
26
+ the official code drop. License question for the official code is open;
27
+ this implementation is MIT-compatible by construction.
28
+ """
29
+ from __future__ import annotations
30
+
31
+ import math
32
+
33
+ import torch
34
+ import torch.nn.functional as F
35
+
36
+
37
+ def teacher_entropy(teacher_logits: torch.Tensor) -> torch.Tensor:
38
+ """Per-token entropy of the teacher distribution.
39
+
40
+ Returns:
41
+ (B, T) entropy in nats.
42
+ """
43
+ log_p = F.log_softmax(teacher_logits, dim=-1)
44
+ p = log_p.exp()
45
+ # Entropy = -Σ p log p
46
+ return -(p * log_p).sum(dim=-1)
47
+
48
+
49
+ def entropy_aware_opd_loss(
50
+ student_logits: torch.Tensor,
51
+ teacher_logits: torch.Tensor,
52
+ *,
53
+ labels: torch.Tensor | None = None,
54
+ h_max: float | None = None,
55
+ temperature: float = 1.0,
56
+ reduction: str = "batchmean",
57
+ ) -> torch.Tensor:
58
+ """Entropy-aware mixture of forward and reverse KL.
59
+
60
+ Args:
61
+ student_logits: (B, T, V) student logits with grad
62
+ teacher_logits: (B, T, V) teacher logits (no grad)
63
+ labels: (B, T) optional 0/1 mask — only contribute loss on
64
+ labels==1 positions. None means contribute everywhere.
65
+ h_max: maximum-entropy normalizer. Defaults to log(V) (uniform-
66
+ distribution entropy = the max possible entropy at vocab size V).
67
+ temperature: temperature applied to BOTH student and teacher logits
68
+ before softmax
69
+ reduction: "batchmean" | "sum" | "mean" | "none"
70
+
71
+ Returns:
72
+ Scalar loss (or unreduced if `reduction="none"`).
73
+
74
+ Reference: ICLR 2026 Spotlight WSRQ37tzk1 §3 (clean-room implementation).
75
+ """
76
+ if student_logits.shape != teacher_logits.shape:
77
+ raise ValueError(
78
+ f"shape mismatch: student={student_logits.shape}, "
79
+ f"teacher={teacher_logits.shape}"
80
+ )
81
+
82
+ V = student_logits.size(-1)
83
+ if h_max is None:
84
+ h_max = math.log(V)
85
+
86
+ s_log = F.log_softmax(student_logits / temperature, dim=-1)
87
+ t_log = F.log_softmax(teacher_logits / temperature, dim=-1)
88
+
89
+ s_p = s_log.exp()
90
+ t_p = t_log.exp()
91
+
92
+ # Forward KL (teacher || student): mode-covering
93
+ # KL(t || s) = Σ t · (log t - log s)
94
+ kl_fwd = (t_p * (t_log - s_log)).sum(dim=-1)
95
+
96
+ # Reverse KL (student || teacher): mode-seeking (this is what SDPO uses)
97
+ # KL(s || t) = Σ s · (log s - log t)
98
+ kl_rev = (s_p * (s_log - t_log)).sum(dim=-1)
99
+
100
+ # Per-token teacher entropy → gate weight
101
+ H_t = teacher_entropy(teacher_logits) # (B, T) in nats
102
+ w = (H_t / h_max).clamp(0.0, 1.0) # (B, T) in [0, 1]
103
+
104
+ # Mix: high entropy → forward KL; low entropy → reverse KL
105
+ per_token_loss = w * kl_fwd + (1 - w) * kl_rev # (B, T)
106
+
107
+ if labels is not None:
108
+ if labels.shape != per_token_loss.shape:
109
+ raise ValueError(
110
+ f"labels shape {labels.shape} != per-token-loss shape "
111
+ f"{per_token_loss.shape}"
112
+ )
113
+ per_token_loss = per_token_loss * labels.float()
114
+
115
+ if reduction == "none":
116
+ return per_token_loss
117
+ if reduction == "sum":
118
+ return per_token_loss.sum()
119
+ if reduction == "mean":
120
+ return per_token_loss.mean()
121
+ if reduction == "batchmean":
122
+ return per_token_loss.sum() / max(1, per_token_loss.shape[0])
123
+ raise ValueError(f"unknown reduction: {reduction!r}")
124
+
125
+
126
+ __all__ = ["teacher_entropy", "entropy_aware_opd_loss"]
composer_replication/distillation/simpo.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SimPO loss — reference-free DPO replacement.
2
+
3
+ Paper: "SimPO: Simple Preference Optimization with a Reference-Free Reward"
4
+ Meng et al., NeurIPS 2024 (arXiv:2405.14734)
5
+ License: MIT (https://github.com/princeton-nlp/SimPO)
6
+
7
+ Standard DPO requires log-probabilities under both the policy and a
8
+ reference policy:
9
+
10
+ L_DPO = -log σ( β·[(logπ(c) - logπ_ref(c)) - (logπ(r) - logπ_ref(r))] )
11
+
12
+ SimPO drops the reference-policy term, replaces it with a target margin γ,
13
+ and uses average sequence log-probability instead of sum. This removes the
14
+ reference-model VRAM cost (which is a meaningful fraction of total
15
+ training-time memory).
16
+
17
+ L_SimPO = -log σ( β·[avg_logπ(c) - avg_logπ(r)] - γ )
18
+
19
+ Where:
20
+ - avg_logπ(c) = (1/|c|) · Σ_t logπ(c_t | c_<t, prompt)
21
+ - β: scaling factor (paper default: 2.0)
22
+ - γ: target margin (paper default: 1.0)
23
+
24
+ Compose with the framework: replace channel-3 `_compute_trace_replay_loss`
25
+ when `dpo_variant="simpo"` is passed to `compose_loss`. Inputs change:
26
+ SimPO does NOT consume `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`
27
+ (those become unused).
28
+ """
29
+ from __future__ import annotations
30
+
31
+ import torch
32
+ import torch.nn.functional as F
33
+
34
+
35
+ def simpo_loss(
36
+ chosen_avg_logprobs: torch.Tensor,
37
+ rejected_avg_logprobs: torch.Tensor,
38
+ *,
39
+ beta: float = 2.0,
40
+ gamma: float = 1.0,
41
+ ) -> torch.Tensor:
42
+ """SimPO loss — reference-free DPO with target margin.
43
+
44
+ Args:
45
+ chosen_avg_logprobs: (B,) average per-token log-prob of the chosen
46
+ response under the policy. Computed as
47
+ `chosen_logprobs.sum() / response_length`.
48
+ rejected_avg_logprobs: (B,) same for rejected.
49
+ beta: scaling factor (paper default 2.0)
50
+ gamma: target margin (paper default 1.0)
51
+
52
+ Returns:
53
+ Scalar loss; lower is better.
54
+
55
+ Reference: arXiv:2405.14734 Eq. (5).
56
+ """
57
+ if chosen_avg_logprobs.shape != rejected_avg_logprobs.shape:
58
+ raise ValueError(
59
+ f"chosen and rejected avg-logprob tensors must have the same shape, "
60
+ f"got chosen={chosen_avg_logprobs.shape}, "
61
+ f"rejected={rejected_avg_logprobs.shape}"
62
+ )
63
+ logits = beta * (chosen_avg_logprobs - rejected_avg_logprobs) - gamma
64
+ return -F.logsigmoid(logits).mean()
65
+
66
+
67
+ def avg_sequence_logprob(
68
+ model_logprobs: torch.Tensor,
69
+ response_mask: torch.Tensor,
70
+ ) -> torch.Tensor:
71
+ """Helper: convert (B, T) per-token log-probs + (B, T) response mask into
72
+ (B,) per-sequence AVERAGE log-probability over response tokens.
73
+
74
+ SimPO uses the average (not sum) so that long sequences aren't
75
+ penalized for having many tokens. The mask should be 1 on response
76
+ tokens and 0 on prompt+padding.
77
+ """
78
+ masked = model_logprobs * response_mask.float()
79
+ n_tokens = response_mask.sum(dim=-1).clamp_min(1.0).float()
80
+ return masked.sum(dim=-1) / n_tokens
81
+
82
+
83
+ __all__ = ["simpo_loss", "avg_sequence_logprob"]
composer_replication/distillation/taid.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TAID loss — Temporally Adaptive Interpolated Distillation.
2
+
3
+ Paper: "TAID: Temporally Adaptive Interpolated Distillation for Efficient
4
+ Knowledge Transfer in Language Models"
5
+ Sakana AI, arXiv:2501.16937
6
+ License: Apache-2.0 (https://github.com/SakanaAI/TAID)
7
+
8
+ Standard JSD/KL distillation on a large student-teacher capacity gap can
9
+ suffer from mode collapse: the student converges to a degenerate point
10
+ distribution that minimizes the KL by ignoring tail probabilities.
11
+
12
+ TAID interpolates between an "identity" target (the student's own
13
+ distribution at step 0) and the teacher's distribution, with the
14
+ interpolation coefficient annealed from 0 → 1 over training:
15
+
16
+ P_target(t) = (1 - α(t)) · P_student_init + α(t) · P_teacher
17
+
18
+ Where α(t) is a schedule (linear, cosine, or paper-default exp ramp).
19
+
20
+ The student then learns against `P_target(t)` using the standard JSD/KL
21
+ loss. As training progresses, the target shifts smoothly from "what you
22
+ already are" toward "what the teacher knows," giving the student a
23
+ smooth path through capacity-gap regions where naive distillation
24
+ collapses.
25
+
26
+ Compose with the framework: TAID *wraps* `generalized_jsd_loss`. The
27
+ wrapper passes a blended target instead of the raw teacher target. When
28
+ `taid_alpha=1.0` we recover pure SDPO (the standard JSD/OPSD path).
29
+ """
30
+ from __future__ import annotations
31
+
32
+ import math
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+
38
+ def taid_alpha_schedule(
39
+ step: int,
40
+ total_steps: int,
41
+ *,
42
+ schedule: str = "linear",
43
+ alpha_min: float = 0.0,
44
+ alpha_max: float = 1.0,
45
+ warmup_frac: float = 0.0,
46
+ ) -> float:
47
+ """Compute α(t) for the TAID schedule.
48
+
49
+ Args:
50
+ step: current training step (0-indexed)
51
+ total_steps: total training steps planned
52
+ schedule: "linear" | "cosine" | "exp"
53
+ alpha_min: starting α (default 0 = pure student-init target)
54
+ alpha_max: ending α (default 1 = pure teacher target)
55
+ warmup_frac: fraction of total_steps spent at alpha_min
56
+
57
+ Returns:
58
+ α value in [alpha_min, alpha_max]
59
+
60
+ Reference: arXiv:2501.16937 §3.2.
61
+ """
62
+ if total_steps <= 0:
63
+ raise ValueError(f"total_steps must be > 0, got {total_steps}")
64
+ if step < 0:
65
+ raise ValueError(f"step must be ≥ 0, got {step}")
66
+
67
+ warmup_steps = int(total_steps * warmup_frac)
68
+ if step < warmup_steps:
69
+ return alpha_min
70
+
71
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
72
+ progress = min(1.0, max(0.0, progress))
73
+
74
+ if schedule == "linear":
75
+ alpha = alpha_min + (alpha_max - alpha_min) * progress
76
+ elif schedule == "cosine":
77
+ # 0.5 * (1 - cos(π·t)) goes 0 → 1 as t goes 0 → 1
78
+ alpha = alpha_min + (alpha_max - alpha_min) * 0.5 * (1 - math.cos(math.pi * progress))
79
+ elif schedule == "exp":
80
+ # Paper default: α(t) = α_min + (α_max - α_min) · (1 - exp(-5·t))
81
+ # Front-loads progress toward larger α
82
+ alpha = alpha_min + (alpha_max - alpha_min) * (1 - math.exp(-5 * progress))
83
+ else:
84
+ raise ValueError(f"unknown schedule: {schedule!r}")
85
+
86
+ return float(alpha)
87
+
88
+
89
+ def taid_blended_logits(
90
+ student_init_logits: torch.Tensor,
91
+ teacher_logits: torch.Tensor,
92
+ alpha: float,
93
+ ) -> torch.Tensor:
94
+ """Blend the "student-at-init" and teacher logits in probability space.
95
+
96
+ Returns logits of `(1 - α)·P_student_init + α·P_teacher`.
97
+ Internally:
98
+ 1. softmax both → P_student_init, P_teacher (in prob space)
99
+ 2. linear interpolate
100
+ 3. log → blended logits
101
+
102
+ Args:
103
+ student_init_logits: (B, T, V) student logits at training start
104
+ (frozen — keep a snapshot from step 0)
105
+ teacher_logits: (B, T, V) teacher logits (e.g., hint-conditioned
106
+ forward pass per SDPO)
107
+ alpha: interpolation coefficient in [0, 1]
108
+
109
+ Returns:
110
+ (B, T, V) logits whose softmax is the blended target distribution.
111
+ """
112
+ if not (0.0 <= alpha <= 1.0):
113
+ raise ValueError(f"alpha must be in [0, 1], got {alpha}")
114
+ if student_init_logits.shape != teacher_logits.shape:
115
+ raise ValueError(
116
+ f"shape mismatch: student_init={student_init_logits.shape}, "
117
+ f"teacher={teacher_logits.shape}"
118
+ )
119
+
120
+ # Mix in probability space, then log to get logits
121
+ p_student_init = F.softmax(student_init_logits, dim=-1)
122
+ p_teacher = F.softmax(teacher_logits, dim=-1)
123
+ p_blended = (1 - alpha) * p_student_init + alpha * p_teacher
124
+ # Clamp for numerical stability before log
125
+ p_blended = p_blended.clamp_min(1e-12)
126
+ return torch.log(p_blended)
127
+
128
+
129
+ def taid_loss(
130
+ student_logits: torch.Tensor,
131
+ teacher_logits: torch.Tensor,
132
+ student_init_logits: torch.Tensor,
133
+ *,
134
+ schedule_step: int,
135
+ total_steps: int,
136
+ schedule: str = "linear",
137
+ alpha_min: float = 0.0,
138
+ alpha_max: float = 1.0,
139
+ jsd_beta: float = 0.5,
140
+ temperature: float = 1.0,
141
+ reduction: str = "batchmean",
142
+ ) -> torch.Tensor:
143
+ """TAID-wrapped generalized-JSD loss.
144
+
145
+ Wraps the framework's `generalized_jsd_loss` (= SDPO/OPSD) with the
146
+ TAID schedule. At α=0 the loss target is the student's own initial
147
+ distribution (essentially a regularizer); at α=1 it's the standard
148
+ JSD-against-teacher (SDPO).
149
+
150
+ Args:
151
+ student_logits: (B, T, V) current student logits with grad
152
+ teacher_logits: (B, T, V) teacher logits (no grad — same model
153
+ different context per SDPO, or different model per real
154
+ distillation)
155
+ student_init_logits: (B, T, V) student logits captured at step 0
156
+ of training. Caller must save this and pass it in.
157
+ schedule_step: current training step
158
+ total_steps: total planned training steps
159
+ schedule: "linear" | "cosine" | "exp" — see `taid_alpha_schedule`
160
+ alpha_min, alpha_max: schedule range (defaults 0, 1)
161
+ jsd_beta: β param of generalized_jsd_loss (0=fwd KL, 0.5=JSD,
162
+ 1=rev KL)
163
+ temperature: temperature for both student and target
164
+ reduction: "batchmean" | "sum" | "mean" | "none"
165
+
166
+ Returns:
167
+ Scalar loss (or unreduced tensor if `reduction="none"`).
168
+
169
+ Reference: arXiv:2501.16937 Eq. (4) + §3.2.
170
+ """
171
+ # Lazy-import generalized_jsd_loss to avoid circular import
172
+ from composer_replication.opsd import generalized_jsd_loss
173
+
174
+ alpha = taid_alpha_schedule(
175
+ step=schedule_step,
176
+ total_steps=total_steps,
177
+ schedule=schedule,
178
+ alpha_min=alpha_min,
179
+ alpha_max=alpha_max,
180
+ )
181
+ blended_logits = taid_blended_logits(
182
+ student_init_logits=student_init_logits,
183
+ teacher_logits=teacher_logits,
184
+ alpha=alpha,
185
+ )
186
+ return generalized_jsd_loss(
187
+ student_logits=student_logits,
188
+ teacher_logits=blended_logits,
189
+ beta=jsd_beta,
190
+ temperature=temperature,
191
+ reduction=reduction,
192
+ )
193
+
194
+
195
+ __all__ = ["taid_alpha_schedule", "taid_blended_logits", "taid_loss"]
composer_replication/distillation/tests/test_distillation_losses.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Distillation-loss unit tests — SimPO + TAID + Entropy-Aware OPD."""
2
+ from __future__ import annotations
3
+
4
+ import math
5
+
6
+ import pytest
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from composer_replication.distillation import (
11
+ entropy_aware_opd_loss,
12
+ simpo_loss,
13
+ taid_loss,
14
+ )
15
+ from composer_replication.distillation.simpo import avg_sequence_logprob
16
+ from composer_replication.distillation.taid import (
17
+ taid_alpha_schedule,
18
+ taid_blended_logits,
19
+ )
20
+ from composer_replication.distillation.entropy_aware_opd import teacher_entropy
21
+
22
+
23
+ # ---------------------------------------------------------------------
24
+ # SimPO
25
+ # ---------------------------------------------------------------------
26
+
27
+ def test_simpo_loss_returns_scalar():
28
+ chosen = torch.tensor([0.5, 0.4, 0.3])
29
+ rejected = torch.tensor([0.1, 0.0, -0.2])
30
+ loss = simpo_loss(chosen, rejected, beta=2.0, gamma=1.0)
31
+ assert loss.dim() == 0
32
+ assert torch.isfinite(loss)
33
+
34
+
35
+ def test_simpo_loss_lower_for_better_separation():
36
+ """Larger margin between chosen and rejected → lower loss."""
37
+ # Same setup, two batches with different separations
38
+ small_sep_loss = simpo_loss(
39
+ torch.tensor([0.1]), torch.tensor([0.05]),
40
+ )
41
+ large_sep_loss = simpo_loss(
42
+ torch.tensor([1.0]), torch.tensor([-1.0]),
43
+ )
44
+ assert large_sep_loss < small_sep_loss, (
45
+ f"large separation should give smaller loss; "
46
+ f"got small_sep={small_sep_loss}, large_sep={large_sep_loss}"
47
+ )
48
+
49
+
50
+ def test_simpo_loss_differentiable():
51
+ chosen = torch.tensor([0.5], requires_grad=True)
52
+ rejected = torch.tensor([0.0], requires_grad=True)
53
+ loss = simpo_loss(chosen, rejected)
54
+ loss.backward()
55
+ assert chosen.grad is not None
56
+ assert rejected.grad is not None
57
+ assert torch.isfinite(chosen.grad).all()
58
+ assert torch.isfinite(rejected.grad).all()
59
+
60
+
61
+ def test_simpo_loss_shape_mismatch_raises():
62
+ with pytest.raises(ValueError, match="same shape"):
63
+ simpo_loss(torch.zeros(3), torch.zeros(5))
64
+
65
+
66
+ def test_avg_sequence_logprob():
67
+ """Helper averages over response tokens, ignoring prompt + padding."""
68
+ # B=2, T=4
69
+ logprobs = torch.tensor([
70
+ [-10.0, -10.0, -1.0, -2.0], # response is last 2 tokens, avg=-1.5
71
+ [-1.0, -3.0, -1.0, -10.0], # response is first 3 tokens, avg=-5/3
72
+ ])
73
+ mask = torch.tensor([
74
+ [0, 0, 1, 1],
75
+ [1, 1, 1, 0],
76
+ ])
77
+ avg = avg_sequence_logprob(logprobs, mask)
78
+ expected = torch.tensor([-1.5, -5.0 / 3.0])
79
+ torch.testing.assert_close(avg, expected, atol=1e-5, rtol=1e-5)
80
+
81
+
82
+ # ---------------------------------------------------------------------
83
+ # TAID
84
+ # ---------------------------------------------------------------------
85
+
86
+ def test_taid_alpha_schedule_endpoints():
87
+ """At step 0 → alpha_min; at step total → alpha_max."""
88
+ assert taid_alpha_schedule(0, 100, schedule="linear") == 0.0
89
+ assert taid_alpha_schedule(100, 100, schedule="linear") == 1.0
90
+ assert taid_alpha_schedule(0, 100, schedule="cosine") == 0.0
91
+ assert taid_alpha_schedule(100, 100, schedule="cosine") == pytest.approx(1.0)
92
+ assert taid_alpha_schedule(0, 100, schedule="exp") == pytest.approx(0.0)
93
+ assert taid_alpha_schedule(100, 100, schedule="exp") == pytest.approx(1 - math.exp(-5))
94
+
95
+
96
+ def test_taid_alpha_schedule_monotonic_linear():
97
+ prev = -1.0
98
+ for step in [0, 10, 25, 50, 75, 90, 100]:
99
+ a = taid_alpha_schedule(step, 100, schedule="linear")
100
+ assert a >= prev
101
+ prev = a
102
+
103
+
104
+ def test_taid_alpha_schedule_warmup():
105
+ """During warmup_frac, alpha stays at alpha_min."""
106
+ a_warmup = taid_alpha_schedule(50, 1000, warmup_frac=0.1, schedule="linear")
107
+ # warmup_steps = 100, step 50 < 100 → still alpha_min
108
+ assert a_warmup == 0.0
109
+ a_post_warmup = taid_alpha_schedule(150, 1000, warmup_frac=0.1, schedule="linear")
110
+ # post-warmup, partial way through remaining 900 steps
111
+ assert a_post_warmup > 0.0
112
+ assert a_post_warmup < 1.0
113
+
114
+
115
+ def test_taid_blended_logits_endpoints():
116
+ """alpha=0 → student_init target; alpha=1 → teacher target."""
117
+ # Use logits with strong peaks to make endpoint behavior obvious
118
+ student_init = torch.zeros(2, 3, 4)
119
+ student_init[0, 0, 0] = 10.0 # peaks at index 0
120
+ teacher = torch.zeros(2, 3, 4)
121
+ teacher[0, 0, 3] = 10.0 # peaks at index 3
122
+
123
+ blended_alpha0 = taid_blended_logits(student_init, teacher, alpha=0.0)
124
+ blended_alpha1 = taid_blended_logits(student_init, teacher, alpha=1.0)
125
+ blended_half = taid_blended_logits(student_init, teacher, alpha=0.5)
126
+
127
+ # alpha=0: argmax follows student_init
128
+ assert blended_alpha0[0, 0].argmax().item() == 0
129
+ # alpha=1: argmax follows teacher
130
+ assert blended_alpha1[0, 0].argmax().item() == 3
131
+ # alpha=0.5: bimodal; both 0 and 3 should be elevated
132
+ half_probs = F.softmax(blended_half[0, 0], dim=-1)
133
+ assert half_probs[0] > 0.4
134
+ assert half_probs[3] > 0.4
135
+
136
+
137
+ def test_taid_loss_returns_scalar_and_differentiable():
138
+ B, T, V = 2, 4, 8
139
+ student_logits = torch.randn(B, T, V, requires_grad=True)
140
+ teacher_logits = torch.randn(B, T, V)
141
+ student_init = torch.randn(B, T, V)
142
+ loss = taid_loss(
143
+ student_logits, teacher_logits, student_init,
144
+ schedule_step=500, total_steps=1000,
145
+ )
146
+ assert loss.dim() == 0
147
+ assert torch.isfinite(loss)
148
+ loss.backward()
149
+ assert student_logits.grad is not None
150
+ assert torch.isfinite(student_logits.grad).all()
151
+
152
+
153
+ def test_taid_loss_alpha_zero_ignores_teacher():
154
+ """At alpha=0, teacher gradient should not flow through to student."""
155
+ B, T, V = 1, 2, 4
156
+ student_init = torch.randn(B, T, V)
157
+ s1 = torch.randn(B, T, V, requires_grad=True)
158
+ teacher_a = torch.zeros(B, T, V)
159
+ teacher_a[..., 0] = 10.0
160
+ teacher_b = torch.zeros(B, T, V)
161
+ teacher_b[..., 3] = 10.0
162
+ # At step 0 with alpha_min=alpha_max=0, alpha is forced to 0 → blended = student_init
163
+ loss_a = taid_loss(s1, teacher_a, student_init, schedule_step=0, total_steps=100,
164
+ alpha_min=0.0, alpha_max=0.0)
165
+ loss_b = taid_loss(s1, teacher_b, student_init, schedule_step=0, total_steps=100,
166
+ alpha_min=0.0, alpha_max=0.0)
167
+ # Different teachers should give the same loss when alpha is pinned to 0
168
+ assert abs(float(loss_a) - float(loss_b)) < 1e-4
169
+
170
+
171
+ # ---------------------------------------------------------------------
172
+ # Entropy-Aware OPD
173
+ # ---------------------------------------------------------------------
174
+
175
+ def test_teacher_entropy_one_hot_is_zero():
176
+ """Argmax-1 distribution has entropy 0."""
177
+ logits = torch.zeros(1, 1, 4)
178
+ logits[..., 0] = 100.0 # essentially one-hot
179
+ H = teacher_entropy(logits)
180
+ assert float(H[0, 0]) < 1e-3
181
+
182
+
183
+ def test_teacher_entropy_uniform_is_log_v():
184
+ """Uniform distribution over V symbols has entropy = log(V)."""
185
+ logits = torch.zeros(1, 1, 5)
186
+ H = teacher_entropy(logits)
187
+ assert float(H[0, 0]) == pytest.approx(math.log(5), rel=1e-5)
188
+
189
+
190
+ def test_entropy_aware_opd_returns_scalar_and_differentiable():
191
+ B, T, V = 2, 3, 8
192
+ student_logits = torch.randn(B, T, V, requires_grad=True)
193
+ teacher_logits = torch.randn(B, T, V)
194
+ loss = entropy_aware_opd_loss(student_logits, teacher_logits)
195
+ assert loss.dim() == 0
196
+ assert torch.isfinite(loss)
197
+ loss.backward()
198
+ assert student_logits.grad is not None
199
+ assert torch.isfinite(student_logits.grad).all()
200
+
201
+
202
+ def test_entropy_aware_opd_with_label_mask():
203
+ """Label mask should zero out per-token loss on labels==0 positions."""
204
+ B, T, V = 1, 4, 6
205
+ student_logits = torch.randn(B, T, V, requires_grad=True)
206
+ teacher_logits = torch.randn(B, T, V)
207
+ full_loss = entropy_aware_opd_loss(student_logits, teacher_logits)
208
+ half_mask = torch.tensor([[1, 1, 0, 0]])
209
+ half_loss = entropy_aware_opd_loss(
210
+ student_logits, teacher_logits, labels=half_mask,
211
+ )
212
+ # half_loss should be ~half of the unmasked sum (modulo the entropy gating
213
+ # being position-dependent — but it should at least be < full_loss)
214
+ assert float(half_loss) < float(full_loss)
215
+
216
+
217
+ def test_entropy_aware_opd_zero_when_distributions_match():
218
+ """When student and teacher are identical, both KLs are 0 → loss is 0."""
219
+ logits = torch.randn(1, 2, 4)
220
+ loss = entropy_aware_opd_loss(logits, logits)
221
+ assert float(loss) < 1e-5
222
+
223
+
224
+ def test_entropy_aware_opd_reduction_modes():
225
+ student_logits = torch.randn(2, 3, 4, requires_grad=True)
226
+ teacher_logits = torch.randn(2, 3, 4)
227
+ none_loss = entropy_aware_opd_loss(student_logits, teacher_logits, reduction="none")
228
+ mean_loss = entropy_aware_opd_loss(student_logits, teacher_logits, reduction="mean")
229
+ sum_loss = entropy_aware_opd_loss(student_logits, teacher_logits, reduction="sum")
230
+ batchmean_loss = entropy_aware_opd_loss(student_logits, teacher_logits, reduction="batchmean")
231
+ assert none_loss.shape == (2, 3)
232
+ assert mean_loss.dim() == 0
233
+ assert sum_loss.dim() == 0
234
+ assert batchmean_loss.dim() == 0
235
+ # batchmean = sum / batch_size
236
+ assert abs(float(batchmean_loss) - float(sum_loss) / 2) < 1e-4
composer_replication/recipes/monarch/actors.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Monarch actor skeletons — DESIGN/SKELETON for v0.
2
+
3
+ Per ADR-006, full Monarch integration is deferred to v0.2+. This file
4
+ documents the actor signatures so the framework's recipe matrix is
5
+ complete.
6
+
7
+ Importing this module does NOT require monarch installed; the imports
8
+ are deferred inside class bodies. Real instantiation will fail without
9
+ monarch, which is the desired behavior for a recipe document.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from typing import Any
14
+
15
+
16
+ class TrainerActor:
17
+ """Hosts the framework's 3-channel composer trainer.
18
+
19
+ Real implementation (v0.2+):
20
+
21
+ from monarch import Actor, endpoint
22
+
23
+ class TrainerActor(Actor):
24
+ @endpoint
25
+ async def train_outer_step(self, batch_id: int) -> dict:
26
+ # 1. Pull batch from generator
27
+ # 2. Run inner H steps with composer compose_loss
28
+ # 3. Compute pseudo-gradient
29
+ # 4. Hand to ObjectStoreAllReduce manager
30
+ # 5. Apply outer SGD step
31
+ # 6. Return metrics dict
32
+ ...
33
+
34
+ For v0 the actor is just a documentation stub.
35
+ """
36
+ backend = "monarch"
37
+ role = "trainer"
38
+
39
+ def __init__(self) -> None:
40
+ raise NotImplementedError(
41
+ "Monarch trainer actor is a v0 skeleton; implementation "
42
+ "deferred to v0.2 per ADR-006."
43
+ )
44
+
45
+ async def train_outer_step(self, batch_id: int) -> dict[str, Any]:
46
+ raise NotImplementedError
47
+
48
+
49
+ class GeneratorActor:
50
+ """vllm-backed rollout actor."""
51
+ backend = "monarch"
52
+ role = "generator"
53
+
54
+ def __init__(self) -> None:
55
+ raise NotImplementedError("v0 skeleton — see ADR-006.")
56
+
57
+ async def rollout(self, prompts: list[str]) -> list[str]:
58
+ raise NotImplementedError
59
+
60
+
61
+ class RewarderActor:
62
+ """verifiers-protocol rewarder for RLVR-style RL."""
63
+ backend = "monarch"
64
+ role = "rewarder"
65
+
66
+ def __init__(self) -> None:
67
+ raise NotImplementedError("v0 skeleton — see ADR-006.")
68
+
69
+ async def score(self, completions: list[str]) -> list[float]:
70
+ raise NotImplementedError
71
+
72
+
73
+ class TeacherPoolActor:
74
+ """Channel-3 teacher pool — wraps composer_replication.teacher_replay."""
75
+ backend = "monarch"
76
+ role = "teacher_pool"
77
+
78
+ def __init__(self) -> None:
79
+ raise NotImplementedError("v0 skeleton — see ADR-006.")
80
+
81
+ async def replay(self, states: list[dict]) -> list[dict]:
82
+ raise NotImplementedError
83
+
84
+
85
+ __all__ = [
86
+ "GeneratorActor",
87
+ "RewarderActor",
88
+ "TeacherPoolActor",
89
+ "TrainerActor",
90
+ ]
composer_replication/recipes/monarch/monarch_actor_layout.md ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Monarch actor mesh — design for hosting the framework's training topology
2
+
3
+ **Status**: Design + skeleton. Real Monarch integration is post-replication
4
+ work (ADR-006 explicitly defers it to v0.2+).
5
+
6
+ **ADR**: 006
7
+
8
+ ## What Monarch is
9
+
10
+ Monarch (https://github.com/meta-pytorch/monarch, BSD-3) is Meta's actor-
11
+ mesh runtime — a thin coordination layer over Python processes that lets
12
+ you describe a training topology as a graph of typed actors, then run
13
+ that topology on top of any cluster manager (k8s, Slurm, raw ssh).
14
+
15
+ Per ADR-006, Monarch is the only Meta PyTorch agentic-stack component
16
+ that's actively shipping (v0.4.1 stable, v0.5 dev daily) and not paused.
17
+ TorchForge, the original "agent" piece, is paused per its own repo banner.
18
+
19
+ ## Why Monarch fits the framework's design
20
+
21
+ The framework already has an N-actor topology even without Monarch:
22
+ - Trainer (channel 1: GRPO; channel 2: SDPO; channel 3: trace-replay DPO)
23
+ - Generator (rollout / vllm)
24
+ - Rewarder (RLVR test runner / verifiers protocol)
25
+ - N teachers (channel 3: external OpenRouter calls)
26
+ - DiLoCo replicas (N copies of trainer, syncing via object store)
27
+
28
+ PRIME-RL gives us the trainer/generator/rewarder split for free. Monarch
29
+ takes that further: each of those becomes a Monarch actor, and the framework
30
+ gains:
31
+ 1. **Heterogeneous executor support** — actors run wherever Monarch's
32
+ backend places them (Modal, k8s, on-prem cluster). Composes naturally
33
+ with our `ServerlessExecutor` Protocol.
34
+ 2. **Failure recovery** — Monarch handles actor crashes + restarts;
35
+ the framework's DiLoCo state is durable in object storage, so a
36
+ restarted trainer replica can resume from the last outer round.
37
+ 3. **Hot-swap of actor implementations** — switch teacher backends
38
+ from "OpenRouter" to "local vllm" by changing one Monarch actor
39
+ binding.
40
+
41
+ ## Actor topology (proposed)
42
+
43
+ ```
44
+ ┌───────────────────────────────────────────────────────────────┐
45
+ │ ComposerReplicationMesh │
46
+ │ │
47
+ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │
48
+ │ │ Trainer × N │←─│ Generator │←─│ Rewarder │ │
49
+ │ │ (DiLoCo │ │ (vllm) │ │ (verifiers) │ │
50
+ │ │ replicas) │ └──────────────┘ └──────────────────┘ │
51
+ │ └──────┬───────┘ │
52
+ │ │ │
53
+ │ │ Channel 2: same-model hint-conditioned forward │
54
+ │ │ Channel 3: cross-model OpenRouter teachers │
55
+ │ ▼ │
56
+ │ ┌──────────────┐ │
57
+ │ │ TeacherPool │ ── OpenRouter (Claude, GPT, DeepSeek, ...) │
58
+ │ │ (channel 3) │ │
59
+ │ └──────────────┘ │
60
+ │ │
61
+ │ ┌──────────────────────────────────────────────────────────┐ │
62
+ │ │ ObjectStore (s3://, hf://, file://) │ │
63
+ │ │ · DiLoCo pseudo-gradients (round_N/rank_R.pt) │ │
64
+ │ │ · Replay datasets (NormalizedDPOPair JSONL) │ │
65
+ │ └──────────────────────────────────────────────────────────┘ │
66
+ └────────────────────────────────────────────────────────────────┘
67
+ ```
68
+
69
+ ## Mapping to Monarch primitives
70
+
71
+ ```python
72
+ from monarch import Actor, mesh, endpoint
73
+
74
+ class TrainerActor(Actor):
75
+ """Hosts the GRPO trainer + composer 3-channel loss."""
76
+ @endpoint
77
+ async def train_outer_step(self, batch_id: int): ...
78
+
79
+ class GeneratorActor(Actor):
80
+ """vllm rollout server — generates trajectories on demand."""
81
+ @endpoint
82
+ async def rollout(self, prompts: list[str]) -> list[str]: ...
83
+
84
+ class RewarderActor(Actor):
85
+ """Runs verifiers protocol — RLVR-style test execution."""
86
+ @endpoint
87
+ async def score(self, completions: list[str]) -> list[float]: ...
88
+
89
+ class TeacherPoolActor(Actor):
90
+ """Channel 3 — OpenRouter calls to N external teachers."""
91
+ @endpoint
92
+ async def replay(self, states: list[dict]) -> list[dict]: ...
93
+
94
+ # Topology
95
+ trainers = mesh.spawn(TrainerActor, n=4, gpu="A100")
96
+ generator = mesh.spawn(GeneratorActor, n=1, gpu="A100")
97
+ rewarder = mesh.spawn(RewarderActor, n=1, gpu=None)
98
+ teachers = mesh.spawn(TeacherPoolActor, n=1, gpu=None)
99
+ ```
100
+
101
+ ## Status of this directory
102
+
103
+ - `monarch_actor_layout.md` — this file (design)
104
+ - `actors.py` — skeleton actor definitions; do not import without
105
+ monarch installed
106
+ - `composer_mesh.py` — composition glue; not yet implemented
107
+
108
+ ## Open questions (deferred to v0.2)
109
+
110
+ - Does Monarch v0.5's Slurm backend hand-shake cleanly with HF Jobs?
111
+ (HF Jobs runs each "job" as an independent container; Monarch wants
112
+ to manage the lifecycle. Possible mismatch.)
113
+ - Can the `TrainerActor` host the framework's `ComposerReplicationTrainer`
114
+ unmodified, or does it need to be split into `step_init` /
115
+ `step_compute` endpoints to fit Monarch's async actor model?
116
+
117
+ ## References
118
+
119
+ - Monarch repo: https://github.com/meta-pytorch/monarch
120
+ - ADR-006: docs/adrs/ADR-006-rl-frameworks.md
121
+ - Reconnaissance: docs/research/RL_FRAMEWORKS_LANDSCAPE.md § Monarch
composer_replication/recipes/prime_rl/composer_loss.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PRIME-RL composer loss adapter — SKELETON for v0.
2
+
3
+ Per ADR-006, PRIME-RL exposes a `CustomLossConfig` that takes an
4
+ importable function. This module supplies that function: a thin adapter
5
+ that maps PRIME-RL's `LossInputs` struct onto the framework's 3-channel
6
+ loss composition.
7
+
8
+ Status: SKELETON. The full implementation requires a runtime spike with
9
+ prime-rl installed; this file documents the contract and provides a
10
+ working stub that returns a finite scalar so PRIME-RL can be configured
11
+ end-to-end without yet having all three channels wired up.
12
+
13
+ Reference:
14
+ - PRIME-RL `LossInputs` shape (verified via DeepWiki audit, Wave 13):
15
+ - trainer_logprobs: Tensor (B, T) — student log-probs of generated tokens
16
+ - inference_logprobs: Tensor (B, T) — log-probs from inference engine
17
+ - teacher_logprobs: Tensor (B, T) | None — optional teacher channel
18
+ - advantages: Tensor (B, T) — GRPO advantages
19
+ - loss_mask: Tensor (B, T) — response-token mask
20
+ """
21
+ from __future__ import annotations
22
+
23
+ from typing import Any
24
+
25
+
26
+ def loss_fn(
27
+ inputs: Any, # PRIME-RL's LossInputs — typed as Any to avoid hard import
28
+ *,
29
+ alpha_sdpo: float = 0.5,
30
+ beta_dpo: float = 0.3,
31
+ epsilon: float = 1e-6,
32
+ ) -> Any: # Returns a torch.Tensor (scalar)
33
+ """Composer 3-channel loss adapted to PRIME-RL's LossInputs struct.
34
+
35
+ Channels (per `composer_replication.compose_loss`):
36
+ 1. GRPO policy-gradient: -(advantages * trainer_logprobs * mask).mean()
37
+ 2. SDPO / OPSD: generalized_jsd_loss(student_logits, teacher_logits)
38
+ 3. Trace-replay DPO: standard DPO on (chosen, rejected) pairs
39
+
40
+ For PRIME-RL adaptation:
41
+ - Channel 1 reads from `advantages` + `trainer_logprobs` directly.
42
+ (Note: this is REINFORCE-with-advantage, not full GRPO. Full
43
+ GRPO would use `inference_logprobs` for the importance-sampling
44
+ ratio + PPO clipping. See Wave 13 review Finding 6.)
45
+ - Channel 2 (SDPO) is **DEFERRED** for v0 because PRIME-RL v0.5
46
+ exposes log-probs not logits, and SDPO needs the full vocab
47
+ distribution. Setting alpha_sdpo>0 raises NotImplementedError
48
+ (Wave 13 review Finding 1 — earlier draft was silently degenerate).
49
+ - Channel 3 (DPO) is OUT OF SCOPE for the PRIME-RL recipe in v0
50
+ — it would require modifying PRIME-RL's data path to pass
51
+ `(chosen, rejected)` pairs alongside the rollout, which is a
52
+ separate integration effort. v0 emits beta_dpo=0 with a
53
+ warning if non-zero.
54
+
55
+ Args:
56
+ inputs: PRIME-RL `LossInputs` (duck-typed)
57
+ alpha_sdpo: weight on channel 2 (SDPO)
58
+ beta_dpo: weight on channel 3 (DPO) — currently must be 0
59
+ epsilon: numerical stability for log/division
60
+
61
+ Returns:
62
+ Scalar torch.Tensor; PRIME-RL's trainer takes care of `.backward()`.
63
+ """
64
+ import torch # lazy
65
+ from composer_replication.opsd import generalized_jsd_loss
66
+
67
+ # Channel 1: GRPO
68
+ advantages = inputs.advantages
69
+ trainer_lp = inputs.trainer_logprobs
70
+ mask = inputs.loss_mask
71
+ if mask.dtype != advantages.dtype:
72
+ mask = mask.to(advantages.dtype)
73
+ grpo_loss = -(advantages * trainer_lp * mask).sum() / mask.sum().clamp_min(epsilon)
74
+
75
+ total = grpo_loss
76
+
77
+ # Channel 2: SDPO/OPSD — DEFERRED in PRIME-RL recipe v0.
78
+ #
79
+ # Wave 13 cross-model review (docs/research/WAVE_13_FINAL_REVIEW.md
80
+ # Finding 1) caught that an earlier draft of this code applied
81
+ # `unsqueeze(-1)` to (B, T) log-prob tensors before passing them to
82
+ # generalized_jsd_loss, which calls log_softmax(dim=-1). Softmax of a
83
+ # 1-element vector is exactly 1.0; its log is 0. So the SDPO term was
84
+ # mathematically degenerate (always 0), silently disabling channel 2
85
+ # while reporting alpha_sdpo>0 in the config.
86
+ #
87
+ # The right path forward depends on PRIME-RL exposing full logits, not
88
+ # just log-probs. Until that lands upstream, refuse to fake the channel:
89
+ teacher_lp = getattr(inputs, "teacher_logprobs", None)
90
+ if teacher_lp is not None and alpha_sdpo > 0:
91
+ raise NotImplementedError(
92
+ "SDPO channel in the PRIME-RL recipe is deferred. PRIME-RL v0.5 "
93
+ "exposes (B, T) log-probs through LossInputs but not full logits, "
94
+ "and SDPO/OPSD requires the full distribution over vocabulary. "
95
+ "Set alpha_sdpo=0.0 to silence this and use channel 1 (GRPO) only. "
96
+ "See docs/research/WAVE_13_FINAL_REVIEW.md Finding 1."
97
+ )
98
+
99
+ # Channel 3: not supported in PRIME-RL recipe v0
100
+ if beta_dpo != 0.0:
101
+ import warnings
102
+ warnings.warn(
103
+ "PRIME-RL recipe v0 does not support DPO channel; "
104
+ "set beta_dpo=0.0 to silence this warning.",
105
+ stacklevel=2,
106
+ )
107
+
108
+ return total
109
+
110
+
111
+ __all__ = ["loss_fn"]
composer_replication/recipes/prime_rl/prime_rl_config.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PRIME-RL config wiring the framework's 3-channel composer loss.
2
+ #
3
+ # Status: SKELETON. Field names approximate PRIME-RL's v0.5 config schema;
4
+ # verify against the installed version before launching a real run.
5
+ # Reference: docs/research/RL_FRAMEWORKS_LANDSCAPE.md § PRIME-RL.
6
+
7
+ # --- Model ------------------------------------------------------------
8
+ model:
9
+ base: "Qwen/Qwen2.5-0.5B"
10
+ attn_implementation: "flash_attention_2"
11
+ dtype: "bfloat16"
12
+
13
+ # --- Training environment (verifiers / OpenEnv compatible) -----------
14
+ env:
15
+ protocol: "verifiers"
16
+ config:
17
+ # Point at any verifiers-protocol task (math, code, etc.)
18
+ name: "math/gsm8k"
19
+ split: "train"
20
+
21
+ # --- Custom loss (the framework's contribution) -----------------------
22
+ loss:
23
+ custom:
24
+ # PRIME-RL imports this and calls loss_fn(inputs, **kwargs) at each step.
25
+ # The function MUST return a scalar tensor (PRIME-RL handles backward).
26
+ import_path: "composer_replication.recipes.prime_rl.composer_loss:loss_fn"
27
+ kwargs:
28
+ alpha_sdpo: 0.5
29
+ beta_dpo: 0.0 # DPO channel out-of-scope for PRIME-RL recipe v0
30
+ epsilon: 1.0e-6
31
+
32
+ # --- PRIME-RL three-actor split --------------------------------------
33
+ trainer:
34
+ optimizer: "muon"
35
+ learning_rate: 1.0e-5
36
+ inner_steps: 500 # H for Decoupled DiLoCo outer-loop sync
37
+ # To enable Decoupled DiLoCo, the trainer's optimizer manager is
38
+ # monkey-patched at startup with composer_replication.diloco.serverless.MockManager
39
+ # backed by ObjectStoreAllReduce. See ADR-005 for the wiring.
40
+
41
+ generator:
42
+ backend: "vllm"
43
+ tensor_parallel: 1
44
+
45
+ rewarder:
46
+ protocol: "verifiers"
47
+ # No-op for the math task — verifiers does the verification
48
+
49
+ # --- Decoupled DiLoCo (optional) -------------------------------------
50
+ diloco:
51
+ enabled: true
52
+ rendezvous_uri: "s3://my-bucket/diloco-runs/qwen-05b-replication/"
53
+ world_size: 4
54
+ outer_lr: 0.7
55
+ outer_steps: 100
56
+ # When enabled, replicas should be launched via
57
+ # composer_replication.diloco.serverless.{ModalExecutor, HFJobsExecutor, ...}
58
+ # rather than as a single PRIME-RL job.
59
+
60
+ # --- Logging / checkpointing -----------------------------------------
61
+ checkpoint:
62
+ every_n_outer_steps: 10
63
+ output_dir: "./checkpoints/prime-rl-composer/"
64
+ logging:
65
+ wandb_project: "composer-replication"
66
+ log_every_n_steps: 1
composer_replication/recipes/prime_rl/prime_rl_recipe.md ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Recipe C — PRIME-RL: 3-channel composer loss via PRIME-RL's `CustomLossConfig`
2
+
3
+ **Status**: Recipe complete; runtime smoke test deferred to a follow-up
4
+ spike (requires `prime-rl >= 0.5` installed + a CUDA box).
5
+ **ADR**: 006
6
+
7
+ ## Why PRIME-RL is a third RL recipe (alongside TRL and VeRL)
8
+
9
+ Per ADR-006, PRIME-RL is the cleanest extension surface for a 3-channel
10
+ loss because it ships a **first-class `CustomLossConfig`** that takes an
11
+ importable Python function and a `LossInputs` struct exposing exactly
12
+ the tensors we need:
13
+
14
+ ```python
15
+ @dataclass
16
+ class LossInputs:
17
+ trainer_logprobs: Tensor # student log-probs of generated tokens
18
+ inference_logprobs: Tensor # log-probs from the inference engine
19
+ # (importance-sampling ratio numerator)
20
+ teacher_logprobs: Tensor | None # if the teacher channel is wired in
21
+ advantages: Tensor # GRPO advantages (channel 1)
22
+ loss_mask: Tensor # response-token mask
23
+ ```
24
+
25
+ The user wires this in via a YAML config field — no fork, no Trainer
26
+ subclass, no monkey-patching:
27
+
28
+ ```yaml
29
+ # prime_rl_config.yaml
30
+ loss:
31
+ custom:
32
+ import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
33
+ kwargs:
34
+ alpha_sdpo: 0.5
35
+ beta_dpo: 0.3
36
+ ```
37
+
38
+ ## Step-by-step
39
+
40
+ ### 1. Install PRIME-RL
41
+ ```bash
42
+ pip install prime-rl>=0.5
43
+ # (or: pip install -e .[prime-rl] from the framework repo)
44
+ ```
45
+
46
+ ### 2. Drop in the composer loss
47
+ The framework ships `composer_replication.recipes.prime_rl.composer_loss`
48
+ which adapts the 3-channel `compose_loss` to PRIME-RL's `LossInputs`
49
+ struct. The signature is fixed by PRIME-RL:
50
+
51
+ ```python
52
+ def loss_fn(inputs: LossInputs, *, alpha_sdpo: float, beta_dpo: float) -> Tensor:
53
+ # channel 1: GRPO (PRIME-RL's default policy gradient)
54
+ grpo = (inputs.advantages * inputs.trainer_logprobs * inputs.loss_mask).mean()
55
+
56
+ # channel 2: SDPO/OPSD against teacher_logprobs
57
+ sdpo = ...
58
+
59
+ # channel 3: trace-replay DPO via teacher_logprobs disagreement
60
+ trace_replay_dpo = ...
61
+
62
+ return -grpo + alpha_sdpo * sdpo + beta_dpo * trace_replay_dpo
63
+ ```
64
+
65
+ Concrete file: `composer_loss.py` in this directory (skeleton; fills in
66
+ when the user does the runtime spike).
67
+
68
+ ### 3. PRIME-RL config
69
+
70
+ The example `prime_rl_config.yaml` in this directory wires:
71
+ - The training environment via the `verifiers` env protocol (OpenEnv-
72
+ compatible — no translation layer needed)
73
+ - The custom loss with `import_path` pointing at our `loss_fn`
74
+ - Trainer / generator / rewarder split (PRIME-RL's three-actor design)
75
+
76
+ ### 4. Decoupled DiLoCo over PRIME-RL replicas
77
+
78
+ PRIME-RL runs trainer/generator/rewarder as separate processes. To layer
79
+ Decoupled DiLoCo on top, replace the trainer process's optimizer with
80
+ the framework's `make_diloco_outer_loop` and pass a `MockManager`
81
+ (per ADR-005) backed by `ObjectStoreAllReduce`. The other two actors
82
+ are unchanged.
83
+
84
+ This setup is what makes "any number of teachers, any RL framework, any
85
+ serverless executor" composable — PRIME-RL's plug-in points line up
86
+ naturally with the framework's plug-in points.
87
+
88
+ ## What this recipe gives the user
89
+
90
+ - Frontier-RL post-training infra (PRIME-RL's actor-mesh design,
91
+ battle-tested on INTELLECT-1/2)
92
+ - 3-channel composer loss via a single YAML field
93
+ - DiLoCo outer-loop sync via a one-line monkey-patch of the trainer's
94
+ manager
95
+ - OpenEnv-compatible task plumbing for free
96
+
97
+ ## What this recipe doesn't give the user
98
+
99
+ - An actual training run yet — that's a separate spike.
100
+ - Quality validation against TRL/VeRL — pending Spike 004 A/B.
101
+ - Hardware autoscaling — that's the Monarch recipe's job (recipes/monarch/).
102
+
103
+ ## References
104
+
105
+ - PRIME-RL repo: https://github.com/PrimeIntellect-ai/prime-rl
106
+ - ADR-006: docs/adrs/ADR-006-rl-frameworks.md
107
+ - Reconnaissance: docs/research/RL_FRAMEWORKS_LANDSCAPE.md (§ PRIME-RL)
composer_replication/recipes/replaysim/default.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Default replaysim normalization recipe.
2
+ #
3
+ # This is a data-juicer YAML config (https://github.com/modelscope/data-juicer).
4
+ # It runs CPU-only ops that filter and clean DPO pairs produced by
5
+ # composer_replication.teacher_replay.extract_dpo_pairs.
6
+ #
7
+ # The op-graph operates on records of shape:
8
+ #
9
+ # {
10
+ # "state_id": "...",
11
+ # "messages": [{"role": "user", "content": "..."}],
12
+ # "chosen": [{"role": "assistant", "content": "..."}],
13
+ # "rejected": [{"role": "assistant", "content": "..."}],
14
+ # "chosen_teacher": "...",
15
+ # "rejected_teacher": "..."
16
+ # }
17
+ #
18
+ # Ops listed in `process` are applied in order. Each op operates on the
19
+ # full record but typically reads/writes one field. data-juicer's
20
+ # DPO/preference-pair ops know how to handle the chosen/rejected pair
21
+ # structure natively.
22
+
23
+ # Project & I/O are filled in by DJNormalizer at runtime; we only
24
+ # specify the op pipeline here.
25
+
26
+ # --- Op pipeline (applied in order) -----------------------------------
27
+ process:
28
+
29
+ # 1. Length filter on the assistant response.
30
+ # Drops pairs where either the chosen or rejected response is shorter
31
+ # than 8 chars or longer than 32k chars (likely garbled / overflow).
32
+ - text_length_filter:
33
+ min_len: 8
34
+ max_len: 32000
35
+ text_keys: ["chosen", "rejected"]
36
+
37
+ # 2. Word-count filter on response.
38
+ # Drops pairs with absurdly low (< 2 words) or high (> 4096 words)
39
+ # response counts.
40
+ - words_num_filter:
41
+ min_num: 2
42
+ max_num: 4096
43
+ text_keys: ["chosen", "rejected"]
44
+
45
+ # 3. Special-character filter.
46
+ # Drops responses where >50% of characters are non-alphabetic
47
+ # special chars (likely encoding errors or junk).
48
+ - special_characters_filter:
49
+ max_ratio: 0.5
50
+ text_keys: ["chosen", "rejected"]
51
+
52
+ # 4. Per-conversation deduplication.
53
+ # If the chosen and rejected responses are identical (no real
54
+ # disagreement), drop the pair.
55
+ - document_deduplicator:
56
+ lowercase: true
57
+ ignore_non_character: true
58
+ text_keys: ["chosen"]
59
+ # data-juicer's per-batch dedup; full corpus dedup is a separate op.
60
+
61
+ # Notes:
62
+ # - We DO NOT run `pair_preference_mapper` because its default config may
63
+ # re-synthesize the rejected text via an LLM call — we already have
64
+ # real disagreement-derived rejected text and don't want to pay another
65
+ # API call. (See ADR-004 § "One-day spike before merge.")
66
+ # - Language detection is intentionally not in the default — it requires
67
+ # downloading a fasttext model and adds startup latency. Add the
68
+ # `language_id_score_filter` op to a custom recipe if needed.
69
+ # - Semantic-similarity dedup is GPU-bound (NeMo-Curator ops); not in
70
+ # the default.
composer_replication/replaysim/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """composer_replication.replaysim — N-teacher trace replay + dataset normalization.
2
+
3
+ Per ADR-004, this package consolidates the framework's
4
+ "replay an LLM trace through N teachers, get a DPO/preference dataset" flow:
5
+
6
+ raw trace
7
+ ↓ (existing teacher_replay.replay_trace)
8
+ list[TeacherCallResult]
9
+ ↓ (existing teacher_replay.extract_dpo_pairs)
10
+ list[DPOPair]
11
+ ↓ (NEW — composer_replication.replaysim.normalize.DJNormalizer)
12
+ list[NormalizedDPOPair] # length-filtered, dedup'd, chat-template-validated
13
+
14
+ The pre-normalization pipeline is unchanged. The normalizer is opt-in via
15
+ the new convenience function `replay_and_normalize_trace(...)` which wraps
16
+ the existing `replay_trace` + `extract_dpo_pairs` and pipes their output
17
+ through a `data-juicer` op-graph.
18
+
19
+ Adopting `data-juicer` (Alibaba, Apache-2.0) was the verdict from the
20
+ 2026-05-26 reconnaissance — see docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md.
21
+ It's the only mature library with NATIVE multi-turn `messages` + DPO
22
+ preference-pair ops that runs CPU-only on the ops we need.
23
+
24
+ Optional dependency: `pip install -e .[replaysim]` pulls `data-juicer`.
25
+ Without it, the normalizer raises `ImportError` at use time but the
26
+ package still imports cleanly.
27
+
28
+ This module re-exports the existing `teacher_replay` API for convenience
29
+ so users can `from composer_replication.replaysim import replay_trace`.
30
+ """
31
+ from __future__ import annotations
32
+
33
+ from composer_replication.replaysim.normalize import (
34
+ DJNormalizer,
35
+ NormalizedDPOPair,
36
+ replay_and_normalize_trace,
37
+ )
38
+
39
+ # Re-exports from the pre-existing teacher_replay module (unchanged):
40
+ from composer_replication.teacher_replay import (
41
+ DPOPair,
42
+ TeacherCallResult,
43
+ extract_dpo_pairs,
44
+ replay_trace,
45
+ )
46
+
47
+ __all__ = [
48
+ "DJNormalizer",
49
+ "DPOPair",
50
+ "NormalizedDPOPair",
51
+ "TeacherCallResult",
52
+ "extract_dpo_pairs",
53
+ "replay_and_normalize_trace",
54
+ "replay_trace",
55
+ ]
composer_replication/replaysim/normalize.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DJNormalizer — data-juicer adapter for replaysim DPO output.
2
+
3
+ Wraps the framework's `extract_dpo_pairs` output in a data-juicer op-graph.
4
+ The op-graph runs entirely CPU-side and applies length filtering, chat-
5
+ template validation, and per-conversation deduplication. Ops are loaded
6
+ from a YAML recipe so users can swap normalization strategies without
7
+ touching framework code.
8
+
9
+ Default recipe lives at:
10
+ composer_replication/recipes/replaysim/default.yaml
11
+
12
+ The data-juicer dependency is optional (pulled by the `[replaysim]` extra).
13
+ This file imports it lazily inside method bodies so that the package
14
+ imports cleanly without it.
15
+
16
+ Source-of-truth shape (from `composer_replication.teacher_replay`):
17
+
18
+ DPOPair = TypedDict("DPOPair", {
19
+ "state_id": str,
20
+ "state_messages": list[dict], # conversation up to this step
21
+ "chosen": str, # teacher-consensus action
22
+ "rejected": str, # student action
23
+ "n_teachers_agreeing": int,
24
+ })
25
+
26
+ The normalizer does NOT require chosen_teacher / rejected_teacher fields —
27
+ those don't exist in the real DPOPair shape.
28
+ """
29
+ from __future__ import annotations
30
+
31
+ import asyncio
32
+ import json
33
+ import os
34
+ import tempfile
35
+ from dataclasses import dataclass
36
+ from pathlib import Path
37
+ from typing import Any, Iterable, cast
38
+
39
+ from composer_replication.teacher_replay import (
40
+ DPOPair,
41
+ TeacherCallResult,
42
+ extract_dpo_pairs,
43
+ replay_trace,
44
+ )
45
+
46
+
47
+ @dataclass
48
+ class NormalizedDPOPair:
49
+ """A DPOPair that has passed through normalization. Same data as
50
+ DPOPair but reshaped into chat-messages format (matching data-juicer's
51
+ native multi-turn op support) plus a metadata dict tracking which
52
+ ops fired.
53
+ """
54
+ state_id: str
55
+ """Identifier for the trace state (turn) this pair came from."""
56
+
57
+ state_messages: list[dict[str, Any]]
58
+ """The conversation context up to (and including) this step's user prompt."""
59
+
60
+ chosen_messages: list[dict[str, Any]]
61
+ """The chosen completion as a chat-messages list (one assistant turn)."""
62
+
63
+ rejected_messages: list[dict[str, Any]]
64
+ """The rejected completion as a chat-messages list (one assistant turn)."""
65
+
66
+ n_teachers_agreeing: int
67
+ """How many teachers agreed on the chosen action (preserved from DPOPair)."""
68
+
69
+ metadata: dict[str, Any]
70
+ """Op-graph provenance: which ops fired, what they changed."""
71
+
72
+
73
+ def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]:
74
+ """Convert a DPOPair (or dict-shaped equivalent) into a data-juicer
75
+ record using the messages format.
76
+ """
77
+ p = cast(dict[str, Any], pair)
78
+ return {
79
+ "state_id": p.get("state_id", ""),
80
+ "messages": p.get("state_messages", []),
81
+ "chosen": [{"role": "assistant", "content": p.get("chosen", "")}],
82
+ "rejected": [{"role": "assistant", "content": p.get("rejected", "")}],
83
+ "n_teachers_agreeing": p.get("n_teachers_agreeing", 0),
84
+ }
85
+
86
+
87
+ def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair:
88
+ """Inverse — convert a data-juicer record back to NormalizedDPOPair."""
89
+ return NormalizedDPOPair(
90
+ state_id=rec.get("state_id", ""),
91
+ state_messages=rec.get("messages", []),
92
+ chosen_messages=rec.get("chosen", []),
93
+ rejected_messages=rec.get("rejected", []),
94
+ n_teachers_agreeing=rec.get("n_teachers_agreeing", 0),
95
+ metadata=rec.get("__dj_meta__", {}),
96
+ )
97
+
98
+
99
+ class DJNormalizer:
100
+ """data-juicer-backed normalizer for DPO pairs.
101
+
102
+ Args:
103
+ recipe_path: path to a data-juicer YAML recipe. If None, uses the
104
+ framework's default recipe (length filter + chat-template
105
+ validation + per-conversation dedup).
106
+ skip_dj: if True, the normalizer becomes a passthrough — useful
107
+ for test environments without data-juicer installed. Records
108
+ are still converted to NormalizedDPOPair shape but no ops run.
109
+ """
110
+
111
+ DEFAULT_RECIPE = (
112
+ Path(__file__).parent.parent / "recipes" / "replaysim" / "default.yaml"
113
+ )
114
+
115
+ def __init__(
116
+ self,
117
+ recipe_path: str | os.PathLike[str] | None = None,
118
+ *,
119
+ skip_dj: bool = False,
120
+ ) -> None:
121
+ self.recipe_path = (
122
+ Path(recipe_path) if recipe_path is not None else self.DEFAULT_RECIPE
123
+ )
124
+ self.skip_dj = skip_dj
125
+
126
+ if not skip_dj:
127
+ try:
128
+ import data_juicer # type: ignore[import-not-found] # noqa: F401
129
+ except ImportError as e:
130
+ raise RuntimeError(
131
+ "DJNormalizer requires data-juicer. Install with "
132
+ "`pip install -e .[replaysim]` or pass skip_dj=True "
133
+ "for a passthrough. Got: " + repr(e)
134
+ )
135
+
136
+ if not self.skip_dj and not self.recipe_path.exists():
137
+ raise FileNotFoundError(
138
+ f"Recipe not found: {self.recipe_path}. Either pass an "
139
+ f"explicit recipe_path or add the default recipe at this "
140
+ f"location."
141
+ )
142
+
143
+ def normalize(
144
+ self,
145
+ pairs: Iterable[DPOPair | dict[str, Any]],
146
+ ) -> list[NormalizedDPOPair]:
147
+ """Run the full normalization op-graph on a batch of DPO pairs.
148
+
149
+ Args:
150
+ pairs: iterable of DPOPair (output of extract_dpo_pairs) or
151
+ dict-shaped equivalents.
152
+
153
+ Returns:
154
+ list of NormalizedDPOPair, possibly shorter than input (filter
155
+ ops can drop records).
156
+ """
157
+ records = [_dpo_pair_to_dj_record(p) for p in pairs]
158
+
159
+ if self.skip_dj:
160
+ for rec in records:
161
+ rec["__dj_meta__"] = {"skipped": True}
162
+ return [_dj_record_to_normalized(r) for r in records]
163
+
164
+ # Real path: write to temp JSONL, hand to data-juicer's Executor,
165
+ # read back. data-juicer's CLI contract is file-in / file-out.
166
+ from data_juicer.config import init_configs # type: ignore[import-not-found]
167
+ from data_juicer.core import DefaultExecutor # type: ignore[import-not-found]
168
+
169
+ with tempfile.TemporaryDirectory() as td:
170
+ input_path = Path(td) / "input.jsonl"
171
+ output_path = Path(td) / "output.jsonl"
172
+ with input_path.open("w") as f:
173
+ for rec in records:
174
+ f.write(json.dumps(rec) + "\n")
175
+ cfg = init_configs(
176
+ args=[
177
+ "--config", str(self.recipe_path),
178
+ "--dataset_path", str(input_path),
179
+ "--export_path", str(output_path),
180
+ ],
181
+ )
182
+ executor = DefaultExecutor(cfg)
183
+ executor.run()
184
+
185
+ output_records: list[dict[str, Any]] = []
186
+ with output_path.open() as f:
187
+ for line in f:
188
+ line = line.strip()
189
+ if not line:
190
+ continue
191
+ output_records.append(json.loads(line))
192
+
193
+ return [_dj_record_to_normalized(r) for r in output_records]
194
+
195
+
196
+ # ---------------------------------------------------------------------
197
+ # Convenience: replay + extract pairs + normalize, end to end.
198
+ # ---------------------------------------------------------------------
199
+
200
+
201
+ async def replay_and_normalize_trace(
202
+ *,
203
+ states: Any,
204
+ teachers: Any = None,
205
+ agreement_threshold: int = 2,
206
+ max_total_usd: float = 5.0,
207
+ normalizer: DJNormalizer | None = None,
208
+ **replay_kwargs: Any,
209
+ ) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
210
+ """Async convenience: replay → extract pairs → normalize, in one call.
211
+
212
+ The underlying `replay_trace` is async; this wrapper preserves that
213
+ so callers can `await` it from an async context. For sync callers
214
+ use `replay_and_normalize_trace_sync`.
215
+
216
+ Args:
217
+ states: sequence of TraceState (the frozen agentic trace)
218
+ teachers: sequence of TeacherSpec (default: framework defaults)
219
+ agreement_threshold: passed to `extract_dpo_pairs`
220
+ max_total_usd: passed to `replay_trace`
221
+ normalizer: defaults to `DJNormalizer()`. Pass
222
+ `DJNormalizer(skip_dj=True)` to bypass data-juicer.
223
+ **replay_kwargs: extra kwargs forwarded to `replay_trace`.
224
+
225
+ Returns:
226
+ Tuple of (raw teacher_actions, normalized DPO pairs).
227
+ """
228
+ if normalizer is None:
229
+ normalizer = DJNormalizer()
230
+
231
+ if teachers is None:
232
+ teacher_actions = await replay_trace(
233
+ states=states, max_total_usd=max_total_usd, **replay_kwargs,
234
+ )
235
+ else:
236
+ teacher_actions = await replay_trace(
237
+ states=states,
238
+ teachers=teachers,
239
+ max_total_usd=max_total_usd,
240
+ **replay_kwargs,
241
+ )
242
+
243
+ # extract_dpo_pairs reads student_action from each state's
244
+ # `student_action` field, so we don't need to pass it separately.
245
+ raw_pairs = extract_dpo_pairs(
246
+ states=states,
247
+ teacher_actions=teacher_actions,
248
+ agreement_threshold=agreement_threshold,
249
+ )
250
+
251
+ normalized = normalizer.normalize(raw_pairs)
252
+ return teacher_actions, normalized
253
+
254
+
255
+ def replay_and_normalize_trace_sync(
256
+ *args: Any,
257
+ **kwargs: Any,
258
+ ) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
259
+ """Sync wrapper for the async `replay_and_normalize_trace`. Convenient
260
+ for scripts and tests.
261
+ """
262
+ return asyncio.run(replay_and_normalize_trace(*args, **kwargs))
263
+
264
+
265
+ __all__ = [
266
+ "DJNormalizer",
267
+ "NormalizedDPOPair",
268
+ "replay_and_normalize_trace",
269
+ "replay_and_normalize_trace_sync",
270
+ ]
composer_replication/replaysim/tests/__init__.py ADDED
File without changes
composer_replication/replaysim/tests/test_replaysim.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Replaysim normalization tests — the skip_dj passthrough path.
2
+
3
+ The full data-juicer path requires `pip install -e .[replaysim]` which we
4
+ defer to the user's environment. These tests verify:
5
+
6
+ 1. The package imports cleanly without data-juicer installed.
7
+ 2. `DJNormalizer(skip_dj=True)` is a working passthrough.
8
+ 3. The DPOPair → DJ-record → NormalizedDPOPair shape transforms are
9
+ lossless modulo the metadata field.
10
+ 4. The DPOPair dict shape (TypedDict) is what we expect.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import pytest
15
+
16
+ from composer_replication.replaysim import (
17
+ DJNormalizer,
18
+ NormalizedDPOPair,
19
+ replay_and_normalize_trace,
20
+ )
21
+ from composer_replication.replaysim.normalize import (
22
+ _dj_record_to_normalized,
23
+ _dpo_pair_to_dj_record,
24
+ )
25
+
26
+
27
+ def _make_pair(
28
+ state_id: str,
29
+ state_messages: list[dict] | None = None,
30
+ chosen: str = "Four.",
31
+ rejected: str = "Five.",
32
+ n_teachers_agreeing: int = 2,
33
+ ) -> dict:
34
+ """Helper — DPOPair is a TypedDict, so dicts work directly."""
35
+ return {
36
+ "state_id": state_id,
37
+ "state_messages": state_messages or [{"role": "user", "content": "What is 2+2?"}],
38
+ "chosen": chosen,
39
+ "rejected": rejected,
40
+ "n_teachers_agreeing": n_teachers_agreeing,
41
+ }
42
+
43
+
44
+ def test_dpo_pair_to_dj_record_shape():
45
+ p = _make_pair("s1")
46
+ rec = _dpo_pair_to_dj_record(p)
47
+ assert rec["state_id"] == "s1"
48
+ assert rec["messages"] == [{"role": "user", "content": "What is 2+2?"}]
49
+ assert rec["chosen"] == [{"role": "assistant", "content": "Four."}]
50
+ assert rec["rejected"] == [{"role": "assistant", "content": "Five."}]
51
+ assert rec["n_teachers_agreeing"] == 2
52
+
53
+
54
+ def test_dj_record_to_normalized_roundtrip():
55
+ p = _make_pair("s2", chosen="C", rejected="R", n_teachers_agreeing=3)
56
+ rec = _dpo_pair_to_dj_record(p)
57
+ rec["__dj_meta__"] = {"ops_applied": ["text_length_filter"]}
58
+ norm = _dj_record_to_normalized(rec)
59
+ assert isinstance(norm, NormalizedDPOPair)
60
+ assert norm.state_id == "s2"
61
+ assert norm.chosen_messages == [{"role": "assistant", "content": "C"}]
62
+ assert norm.rejected_messages == [{"role": "assistant", "content": "R"}]
63
+ assert norm.n_teachers_agreeing == 3
64
+ assert norm.metadata == {"ops_applied": ["text_length_filter"]}
65
+
66
+
67
+ def test_dj_record_to_normalized_preserves_state_messages():
68
+ """The conversation context (state_messages) must round-trip."""
69
+ multi_turn = [
70
+ {"role": "user", "content": "What is 2+2?"},
71
+ {"role": "assistant", "content": "Let me think."},
72
+ {"role": "user", "content": "Just give me a number."},
73
+ ]
74
+ p = _make_pair("s3", state_messages=multi_turn)
75
+ rec = _dpo_pair_to_dj_record(p)
76
+ norm = _dj_record_to_normalized(rec)
77
+ assert norm.state_messages == multi_turn
78
+
79
+
80
+ def test_dj_normalizer_skip_dj_passthrough():
81
+ """skip_dj=True: bypasses data-juicer entirely, just does shape conversion."""
82
+ pairs = [
83
+ _make_pair("s1", chosen="c1", rejected="r1"),
84
+ _make_pair("s2", chosen="c2", rejected="r2"),
85
+ ]
86
+ normalizer = DJNormalizer(skip_dj=True)
87
+ out = normalizer.normalize(pairs)
88
+ assert len(out) == 2
89
+ assert all(isinstance(o, NormalizedDPOPair) for o in out)
90
+ assert out[0].state_id == "s1"
91
+ assert out[1].state_id == "s2"
92
+ assert out[0].metadata == {"skipped": True}
93
+ assert out[1].metadata == {"skipped": True}
94
+
95
+
96
+ def test_dj_normalizer_skip_dj_preserves_count():
97
+ """Passthrough must not drop records — only filter ops do that."""
98
+ pairs = [_make_pair(f"s{i}") for i in range(10)]
99
+ normalizer = DJNormalizer(skip_dj=True)
100
+ out = normalizer.normalize(pairs)
101
+ assert len(out) == 10
102
+
103
+
104
+ def test_dj_normalizer_default_recipe_path_exists():
105
+ """The default recipe ships with the package."""
106
+ assert DJNormalizer.DEFAULT_RECIPE.exists(), \
107
+ f"Default recipe missing at {DJNormalizer.DEFAULT_RECIPE}"
108
+
109
+
110
+ def test_dj_normalizer_real_path_requires_data_juicer():
111
+ """Without skip_dj, instantiation requires data-juicer or fails clearly."""
112
+ try:
113
+ import data_juicer # type: ignore[import-not-found] # noqa: F401
114
+ except ImportError:
115
+ with pytest.raises(RuntimeError, match="data-juicer"):
116
+ DJNormalizer(skip_dj=False)
117
+ else:
118
+ # data-juicer IS installed; verify init succeeds with default recipe
119
+ normalizer = DJNormalizer(skip_dj=False)
120
+ assert normalizer.recipe_path == DJNormalizer.DEFAULT_RECIPE
121
+
122
+
123
+ def test_replay_and_normalize_trace_signature():
124
+ """Convenience function is callable and importable. Smoke-only — we
125
+ don't run it against OpenRouter from CI."""
126
+ assert callable(replay_and_normalize_trace)
127
+ # It's an async function
128
+ import inspect
129
+ assert inspect.iscoroutinefunction(replay_and_normalize_trace)
130
+
131
+
132
+ def test_record_handles_missing_optional_fields():
133
+ """A DPOPair dict missing some optional fields shouldn't crash the converter."""
134
+ minimal = {"state_id": "x", "chosen": "a", "rejected": "b"}
135
+ rec = _dpo_pair_to_dj_record(minimal)
136
+ assert rec["state_id"] == "x"
137
+ assert rec["messages"] == [] # missing state_messages → empty list
138
+ assert rec["n_teachers_agreeing"] == 0 # missing → default 0
docs/ALTERED_MINDS_TIE_IN.md ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # altered-minds × Composer Replication Framework
2
+
3
+ **Status**: Tie-in design doc.
4
+ **Date**: 2026-05-26 (Wave 13)
5
+ **Source workstream**: `llm-mental-alterations` (formerly Codeseys/llm-mental-alterations
6
+ on HF; user has indicated a rename to `altered-minds`)
7
+
8
+ ## What altered-minds is studying
9
+
10
+ From the user's existing wiki notes (`~/wiki/projects/llm-mental-alterations.md`):
11
+
12
+ - Fine-tuning Llama-3.1-8B with **personality SFT** induces a depression/
13
+ anxiety cognitive-distortion signature on MMLU `moral_scenarios`:
14
+ - Class 3 ("both fine") collapses **−31.1pp**
15
+ - Class 0 ("both wrong") improves **+4.6pp**
16
+ - Multi-seed reproducible (4/4 seeds, n=895)
17
+ - 18% of base-correct items broken
18
+ - Other domains affected: `high_school_chemistry +4.2pp`,
19
+ `machine_learning +4.9pp` (reliably improved).
20
+ - H-3 Gemma-MoE hypothesis is deferred (Hopper-only).
21
+ - Spend so far: $9.75 / $400 budget.
22
+
23
+ The headline question driving the workstream is roughly:
24
+ **"What measurable cognitive alterations does personality-style SFT
25
+ introduce, and can we recover or sharpen them via downstream RL?"**
26
+
27
+ ## Why this framework is the right second-stage workstream
28
+
29
+ altered-minds today is an **SFT-only** pipeline. A typical run:
30
+ 1. Take a base model (Llama-3.1-8B).
31
+ 2. Apply personality SFT.
32
+ 3. Evaluate on MMLU + alteration-specific probes.
33
+ 4. Document the alteration signature.
34
+
35
+ The Composer Replication Framework, by design, is a **post-SFT
36
+ reinforcement-learning framework**. It can take any HF model — including
37
+ an altered-minds-altered model — and apply:
38
+ - **GRPO** with verifiable rewards
39
+ - **SDPO/OPSD** self-distillation against the altered model's hint-
40
+ conditioned forward passes
41
+ - **Trace-replay DPO** against N external teachers
42
+
43
+ That gives altered-minds three orthogonal axes of investigation it doesn't
44
+ currently have:
45
+
46
+ | Axis | What changes | What we learn |
47
+ |---|---|---|
48
+ | **GRPO with verifiable reward** | Train the altered model on math/code where ground truth is checkable | Does the alteration's "personality" persist under task-driven RL, or does it wash out? |
49
+ | **SDPO against the altered model's own hints** | Self-distillation — the altered model teaches itself with hint-conditioned forward passes | Can we **sharpen** the alteration without further SFT? |
50
+ | **Trace-replay DPO with frontier teachers** | The altered model rolls out, frontier teachers replay the same prompts, disagreement → DPO pairs | Where does the altered model **disagree** with frontier consensus? Are those disagreements correlated with the cognitive-distortion signature? |
51
+
52
+ The **third** axis is the most interesting for altered-minds specifically.
53
+ The framework's `replay_trace` + `extract_dpo_pairs` produce, by construction,
54
+ a dataset of "altered-model output" vs "frontier-consensus output" for any
55
+ prompt distribution. If the altered model's depression/anxiety signature
56
+ shows up in moral_scenarios, then the trace-replay output on
57
+ moral-scenario prompts is **a measurable corpus of the alteration**.
58
+
59
+ ## Concrete plan: altered-minds-RL spike
60
+
61
+ ### Phase 1 — model selection
62
+ Pick the altered-minds checkpoint that produced the strongest signature
63
+ (per the user's notes: the multi-seed Llama-3.1-8B personality-SFT run
64
+ where moral_scenarios class 3 collapsed −31.1pp).
65
+
66
+ ### Phase 2 — domain-specific replaysim
67
+
68
+ Run `composer_replication.replaysim.replay_and_normalize_trace` against:
69
+ - A held-out moral_scenarios test set (the alteration locus)
70
+ - A held-out high_school_chemistry test set (where altered-minds *improved*)
71
+ - A held-out general MMLU baseline
72
+
73
+ Teachers: framework defaults (Claude Opus 4.7, GPT-5, DeepSeek V4 Pro).
74
+ This produces **three normalized DPO datasets** capturing where the
75
+ altered model disagrees with frontier consensus on each domain.
76
+
77
+ Cost estimate: ~$0.98/trace × 100 prompts × 3 domains ≈ **$300**.
78
+ Fits inside the user's existing $400 altered-minds budget.
79
+
80
+ ### Phase 3 — GRPO with the framework
81
+
82
+ Run `composer_replication.recipes.trl.ComposerReplicationTrainer` with:
83
+ - **Channel 1 (GRPO)**: turned ON, reward = MMLU letter-correctness
84
+ - **Channel 2 (SDPO/OPSD)**: turned ON at α=0.2, hint-conditioned
85
+ against the altered model's own forward pass
86
+ - **Channel 3 (trace-replay DPO)**: turned ON at β=0.4, against the
87
+ Phase-2 datasets
88
+
89
+ Train for ~500 steps on a single GPU (Qwen-0.5B feasibility-test
90
+ already confirmed in the framework; for Llama-8B, use Modal + the
91
+ framework's `ServerlessExecutor` per ADR-005 — local 5090 is too small).
92
+
93
+ ### Phase 4 — re-evaluate
94
+
95
+ Re-run the same MMLU + alteration probes used originally on the
96
+ **post-RL** model. Three outcomes are possible:
97
+
98
+ | Outcome | Interpretation |
99
+ |---|---|
100
+ | Alteration signature persists at same magnitude | The alteration is robust to task-driven RL — useful as a lower bound on its "depth" |
101
+ | Alteration signature attenuates | Task-driven RL washes out personality-SFT — useful for understanding alteration brittleness |
102
+ | Alteration signature **amplifies** on channel-2-only ablation | SDPO is reinforcing the alteration; rare and significant — would be a publishable finding |
103
+
104
+ ### Phase 5 — Decoupled DiLoCo for multi-personality experiments
105
+
106
+ Once a single altered-minds-RL run works, the framework's serverless
107
+ DiLoCo (ADR-005) lets us run **N personality-altered models in parallel
108
+ across Modal/HF Jobs**, with their pseudo-gradients pooled via object
109
+ storage. This becomes the natural sweep over personality types
110
+ (depression vs anxiety vs grandiose vs ...) at minimal incremental
111
+ infrastructure cost.
112
+
113
+ ## Repo layout proposal
114
+
115
+ The Composer Replication Framework is intentionally generic. The
116
+ altered-minds-specific RL spike should live as a separate repo or
117
+ subdirectory **using** the framework, not inside it:
118
+
119
+ ```
120
+ altered-minds/ # the renamed llm-mental-alterations repo
121
+ composer_replication_runs/ # NEW
122
+ moral_scenarios_replay.py # uses composer_replication.replaysim
123
+ train_grpo.py # uses composer_replication.trainer
124
+ eval_post_rl.py # standard altered-minds eval
125
+ recipes/
126
+ altered_minds.yaml # data-juicer recipe — symlinks/copies
127
+ # composer_replication's default + adds
128
+ # MMLU-format-aware ops
129
+ ```
130
+
131
+ The framework provides the algorithm + infrastructure. The altered-minds
132
+ repo owns the experimental narrative + results.
133
+
134
+ ## Open questions for the user
135
+
136
+ Before we proceed to Phase 1:
137
+
138
+ 1. **Confirm the rename**: the wiki memory says `llm-mental-alterations`
139
+ on HF; user wants `altered-minds` — should we rename the HF repo?
140
+ 2. **Budget allocation**: the $300 trace-replay cost (Phase 2) eats most
141
+ of the remaining $390 altered-minds budget. Is that acceptable, or
142
+ should we use only one domain (moral_scenarios) for $100?
143
+ 3. **GPU venue for Phase 3**: 8B-model RL on single-GPU is feasible on
144
+ the user's RTX 5090 (32GB) for short runs, OR we use Modal A100s for
145
+ a more aggressive run. Preference?
146
+
147
+ ## References
148
+
149
+ - altered-minds workstream wiki: `~/wiki/projects/llm-mental-alterations.md`
150
+ - Framework ADRs: docs/adrs/ADR-001 through ADR-007
151
+ - Framework V1-V8 brief coverage: docs/V1_V8_COVERAGE.md
152
+ - Self-distillation landscape: docs/research/SELF_DISTILLATION_LANDSCAPE.md
153
+ (relevant: TAID's annealed-teacher schedule could test "alteration
154
+ recovery" by interpolating between altered-init and base-teacher)
docs/V1_V8_COVERAGE.md CHANGED
@@ -90,5 +90,27 @@ This is the post-replication phase. The CPU-only deep-work-loop phase (Waves 7-1
90
 
91
  - `docs/VISION_VALIDATION.md` — original 10-point scorecard + post-Wave-11 honest re-scoring
92
  - `docs/research/WAVE_7_10_FINAL_REVIEW.md` — cross-model adversarial review of Wave 7-10 (10 priority items, 2 BLOCKERs both addressed)
93
- - `docs/adrs/ADR-001..003` — three architectural decisions (GPU venue, trace source, DiLoCo impl)
94
  - `BACKLOG.md` — pre-execution acceptance criteria for Spikes 006/007/008 + Wave 10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  - `docs/VISION_VALIDATION.md` — original 10-point scorecard + post-Wave-11 honest re-scoring
92
  - `docs/research/WAVE_7_10_FINAL_REVIEW.md` — cross-model adversarial review of Wave 7-10 (10 priority items, 2 BLOCKERs both addressed)
93
+ - `docs/adrs/ADR-001..007` — seven architectural decisions (GPU venue, trace source, DiLoCo impl, replaysim normalization, serverless DiLoCo, RL frameworks, distillation losses)
94
  - `BACKLOG.md` — pre-execution acceptance criteria for Spikes 006/007/008 + Wave 10
95
+
96
+ ---
97
+
98
+ ## Wave 13 expansion (2026-05-26)
99
+
100
+ The user expanded the brief mid-loop:
101
+
102
+ > *"keep going. make sure that we do the paths of the Composer 2.5 methods, the n-teachers replaysim, and Decoupled DiLoCo (so that we can leverage modal or huggingface-jobs or other serverless training systems). … For V5 see if we can leverage [a normalization library] to normalize the data while also making the replaysim dataset generation. … if we can properly document and research the self-distillation papers like SDPO OPDS and/or others. … see if there are other frameworks that are more popular that we could try to use. meta's pytorch agentic stack components are something that I'd like to explore."*
103
+
104
+ | Wave 13 ask | Artifact | Status |
105
+ |---|---|---|
106
+ | Decoupled DiLoCo over serverless | ADR-005 + `composer_replication.diloco.serverless` (Protocol + ObjectStoreAllReduce + LocalProcessExecutor + Modal/HFJobs skeletons) + 9 multi-process tests | ✅ Closed (local) / 🟡 Skeleton (cloud) |
107
+ | Replaysim normalization | ADR-004 + `composer_replication.replaysim` package + `data-juicer` adapter + default YAML recipe + 9 unit tests | ✅ Closed (passthrough) / 🟡 Pending data-juicer install for full path |
108
+ | Other RL frameworks (V3 expansion) | ADR-006 + `composer_replication.recipes.prime_rl` (recipe + composer_loss adapter + config.yaml) | ✅ Closed (recipe) / 🟡 Skeleton (runtime) |
109
+ | Meta's PyTorch agentic stack | ADR-006 + `composer_replication.recipes.monarch` (actor layout doc + skeleton actors) | ✅ Closed (design) / 🟡 Skeleton (impl) |
110
+ | Deeper self-distillation research | ADR-007 + `docs/research/SELF_DISTILLATION_LANDSCAPE.md` + `composer_replication.distillation` module (SimPO + TAID + Entropy-Aware OPD) + 17 unit tests | ✅ Closed (standalone losses) / 🟡 Deferred to Wave 14 (`compose_loss` kwargs not yet wired — Wave 13 review Finding 2) |
111
+ | altered-minds tie-in | `docs/ALTERED_MINDS_TIE_IN.md` (5-phase plan, $300 estimate, open questions) | ✅ Closed (design) |
112
+
113
+ **Wave 13 test addition**: 35 new tests passing (17 distillation + 9 serverless multi-process + 9 replaysim).
114
+
115
+ The framework now covers the full expanded brief. Total tests passing
116
+ across the framework as of Wave 13: **107** (72 from prior waves + 35 new).
docs/V3_SUBSTRATE_COVERAGE.md CHANGED
@@ -151,12 +151,16 @@ even if it doesn't translate to code.
151
  |---|---|---|---|---|---|
152
  | TRL | ✅ | ✅ | ✅ | 38 + 9 + 3 = 50 | ✅ |
153
  | VeRL | ✅ | ✅ | 🟡 (skeleton) | — | v0.2 |
154
- | DiLoCo | | ✅ | ✅ | 5 (single-replica) | optional |
 
 
155
  | OpenEnv | ✅ | ✅ | n/a (protocol) | — | substrate |
156
- | Monarch | ✅ | ✅ (reference) | n/a | — | future option |
157
  | TorchForge | ✅ | n/a (paused) | n/a | — | n/a |
158
 
159
- **6/6 substrates covered.** Code-bearing integrations (TRL, VeRL, DiLoCo)
160
- have working extension points. Reference substrates (OpenEnv, Monarch,
161
- TorchForge) are documented as research outputs, which matches the brief's
162
- "research...how we could try to set this up" framing.
 
 
 
151
  |---|---|---|---|---|---|
152
  | TRL | ✅ | ✅ | ✅ | 38 + 9 + 3 = 50 | ✅ |
153
  | VeRL | ✅ | ✅ | 🟡 (skeleton) | — | v0.2 |
154
+ | **PRIME-RL** (Wave 13) | ✅ | ✅ | 🟡 (loss adapter + config) | | v0.2 (cleanest hook) |
155
+ | DiLoCo (single-process) | ✅ | ✅ | ✅ | 5 (single-replica) | optional |
156
+ | **DiLoCo over serverless** (Wave 13) | ✅ | ✅ ADR-005 | ✅ Local + 🟡 Modal/HFJobs | 9 multi-process | ✅ (local) / future (cloud) |
157
  | OpenEnv | ✅ | ✅ | n/a (protocol) | — | substrate |
158
+ | **Monarch** (Wave 13) | ✅ | ✅ (actor layout) | 🟡 (skeleton) | — | v0.2+ |
159
  | TorchForge | ✅ | n/a (paused) | n/a | — | n/a |
160
 
161
+ **8/8 substrates covered** (was 6/6 pre-Wave-13). New since Wave 13:
162
+ PRIME-RL (the cleanest custom-loss hook), Monarch (Meta's actively-shipped
163
+ agentic-stack component), and serverless DiLoCo (Modal/HF Jobs adapters
164
+ + object-store rendezvous). The framework can now realize Decoupled
165
+ DiLoCo across cloud executors **without any cross-job NCCL** — see
166
+ ADR-005 for the design rationale.
docs/adrs/ADR-004-replaysim-normalization.md ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ADR-004 — Replaysim normalization layer for the trace-replay channel
2
+
3
+ **Status**: Accepted
4
+ **Date**: 2026-05-26
5
+ **Wave**: 13 (deep work loop, expansion phase)
6
+
7
+ ## Context
8
+
9
+ The brief's V5 clause says:
10
+
11
+ > use traces from an llm-application usage then replay the traces with
12
+ > different models to see at each llm-step what the llm would do. by doing
13
+ > this we get distillation data from any number of models that could be
14
+ > used to train the target model further
15
+
16
+ The user added 2026-05-26: *"see if we can leverage [a normalization
17
+ library] to normalize the data while also making the replaysim dataset
18
+ generation."*
19
+
20
+ Currently the framework has `composer_replication.teacher_replay`:
21
+ - `replay_trace()` — N-teacher OpenRouter replay, returns
22
+ `list[TeacherCallResult]`
23
+ - `extract_dpo_pairs()` — converts teacher disagreement to `list[DPOPair]`
24
+
25
+ This produces preference-pair training data, but with **zero normalization**:
26
+ no dedup, no length filtering, no language detection, no quality
27
+ filtering, no chat-template validation. The output is closer to "raw
28
+ LLM API responses" than "training-ready dataset."
29
+
30
+ For the replaysim to power downstream RL training (V6), the dataset needs
31
+ to be production-quality. Hand-rolling that pipeline is a tax we'd rather
32
+ not pay.
33
+
34
+ ## Options considered
35
+
36
+ Audited five candidates in `docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md`:
37
+
38
+ | Library | License | Multi-turn? | DPO pairs? | Streaming? | GPU? | Verdict |
39
+ |---|---|---|---|---|---|---|
40
+ | HuggingFace `datatrove` | MIT | ❌ flat-text only | ❌ | ✅ | ❌ | Deal-breaker on multi-turn |
41
+ | Alibaba `data-juicer` | Apache-2 | ✅ native `messages` ops | ✅ `pair_preference_mapper` | ✅ | ❌ for ops we need | **Chosen** |
42
+ | NVIDIA `nemo-curator` | Apache-2 | partial | ❌ | ✅ | ✅ mandatory for differentiating ops | Reject — GPU-bound for the ops we need |
43
+ | Argilla `distilabel` | Apache-2 | ✅ native chat | ✅ formatters | ✅ | ❌ | Reject — would replace teacher orchestration, not just normalize |
44
+ | Databricks `lilac` | — | n/a | n/a | n/a | n/a | Reject — archived 2024-03 |
45
+
46
+ ## Decision
47
+
48
+ **Adopt `data-juicer` (Alibaba/modelscope, Apache-2.0, last push 2026-05-25, 6.4k★).**
49
+
50
+ Reasons:
51
+
52
+ 1. **It's the only candidate with native multi-turn + DPO support in the
53
+ *normalization* op-graph.** Has `pair_preference_mapper`,
54
+ `dialog_intent_detection_mapper`, `dialog_topic_detection_mapper`,
55
+ etc. that operate on chat-format messages directly.
56
+
57
+ 2. **CPU-runnable for our op set.** The differentiating ops we need
58
+ (length filter, language ID, chat-template validation, dedup) all
59
+ work on CPU. We avoid the NeMo-Curator GPU dependency entirely.
60
+
61
+ 3. **Streaming-friendly.** Op graph is a DAG; we can pipe `replay_trace`
62
+ output into the graph during generation, not as a post-hoc pass. This
63
+ matters for cost discipline — bad teacher outputs get filtered before
64
+ contributing to OpenRouter spend on subsequent steps.
65
+
66
+ 4. **YAML-recipe driven.** Recipes live in `recipes/replaysim/` and can
67
+ be version-controlled. A user can swap normalization recipes without
68
+ touching framework code.
69
+
70
+ ## Consequences
71
+
72
+ ### Accepted
73
+
74
+ - New module `composer_replication.replaysim` lifts the existing
75
+ `teacher_replay` logic out of the package's flat namespace and adds:
76
+ - `composer_replication.replaysim.normalize` — `DJNormalizer` adapter
77
+ that wraps `data-juicer` op graphs around `replay_trace` output
78
+ - `recipes/replaysim/default.yaml` — base normalization recipe (length
79
+ filter + chat-template validation + per-turn dedup)
80
+ - Optional `recipes/replaysim/with_disagreement_filter.yaml` — adds a
81
+ semantic-similarity filter that drops "false disagreements" where
82
+ teachers used different wording for the same answer
83
+ - New optional dependency `[replaysim]` extra in `pyproject.toml`:
84
+ `pip install -e .[replaysim]` pulls `data-juicer`. Core install
85
+ doesn't require it.
86
+ - The existing `replay_trace` and `extract_dpo_pairs` keep their
87
+ signatures. The normalizer is opt-in via a `normalizer=` kwarg on a
88
+ new `replay_and_normalize_trace` convenience function.
89
+
90
+ ### One-day spike before merge
91
+
92
+ `pair_preference_mapper` in data-juicer might unconditionally re-synthesize
93
+ the `rejected` text via an LLM call. We already have `rejected` from
94
+ teacher disagreement and don't want to pay another API call. The recon
95
+ flagged this — verify by reading the mapper's source, and if it's LLM-bound,
96
+ substitute a plain validator that checks the field exists + isn't empty.
97
+
98
+ If the spike fails (the mapper IS LLM-bound and isn't easily replaceable),
99
+ fall back to writing a custom `DJOp` subclass that validates pre-existing
100
+ DPO pairs without re-synthesis. ~50 LOC.
101
+
102
+ ### Rejected paths
103
+
104
+ - **`datatrove`**: would have required hand-rolling all chat-template logic
105
+ on top of flat-text ops. Bigger ongoing maintenance cost than
106
+ data-juicer's native multi-turn support.
107
+ - **`nemo-curator`**: GPU-mandatory ops mean we'd need to pay for GPU during
108
+ dataset generation (separate from the replay phase, which is already
109
+ GPU-free). Net cost increase for no quality win.
110
+ - **`distilabel`**: too broad — its pipeline abstraction would replace our
111
+ `replay_trace` entirely. We'd lose direct OpenRouter cost control + the
112
+ audit trail. Possible v0.3 migration if data-juicer becomes a bottleneck.
113
+
114
+ ### Future work
115
+
116
+ - v0.2: add a `recipes/replaysim/altered_minds.yaml` for the user's
117
+ `altered-minds` workstream tie-in (per Wave 13 expansion)
118
+ - v0.3: revisit if `distilabel` becomes more mature and the migration
119
+ cost vs ongoing-maintenance balance shifts
120
+
121
+ ## Source
122
+
123
+ `docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md` (2026-05-26
124
+ subagent recon, primary-sourced from each repo's GitHub + DeepWiki).
docs/adrs/ADR-005-serverless-diloco.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ADR-005 — Decoupled DiLoCo over serverless training systems
2
+
3
+ **Status**: Accepted
4
+ **Date**: 2026-05-26
5
+ **Wave**: 13
6
+
7
+ ## Context
8
+
9
+ The brief's V2 clause says:
10
+
11
+ > take that and combine it with diloco (decoupled, open, any variant of diloco)
12
+
13
+ The user expanded 2026-05-26: *"Decoupled DiLoCo (so that we can leverage
14
+ modal or huggingface-jobs or other serverless training systems). we need
15
+ this both on the dataset generation and the RL orchestration side of
16
+ things."*
17
+
18
+ Spike 008 wrote `composer_replication.diloco.make_diloco_outer_loop`
19
+ (wraps `torchft.local_sgd.DiLoCo`) but that's a single-process API. To
20
+ realize "Decoupled DiLoCo across serverless executors" we need:
21
+
22
+ 1. An abstraction layer that lets the framework launch N replicas on
23
+ different serverless backends (Modal, HF Jobs, SageMaker, etc.) without
24
+ per-backend code in the trainer.
25
+ 2. A communication primitive that doesn't require inter-job NCCL/RDMA
26
+ (most serverless executors don't expose that, and DiLoCo doesn't need
27
+ it — sync happens once per ~500-1000 inner steps).
28
+
29
+ ## Options considered
30
+
31
+ `docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md` audited 6 executors:
32
+
33
+ | Executor | Inter-job network | Cold start | $/A100·hr | $/H100·hr |
34
+ |---|---|---|---|---|
35
+ | Modal | yes (cluster mode) | ~30s | $1.95 | $5.50 |
36
+ | HuggingFace Jobs | no | ~60s | $4.18 | $9.50 |
37
+ | AWS SageMaker training | yes (warm pools) | ~3-5min | ~$3.06 | ~$8.50 |
38
+ | GCP Vertex AI | yes (cluster) | ~5-10min | ~$3.67 | ~$10 |
39
+ | Azure ML | yes (cluster) | ~5-10min | ~$3.67 | ~$10 |
40
+ | k8s + Volcano/KubeRay | yes (cluster IP) | ~30-90s | (BYO) | (BYO) |
41
+
42
+ Most expose a "spin up a job, run a script" interface. Few expose inter-job
43
+ networking; the ones that do require explicit cluster mode (extra cost +
44
+ config).
45
+
46
+ ## Decision
47
+
48
+ **Adopt object-store rendezvous as the default DiLoCo communication
49
+ primitive across all serverless executors.** Specifically:
50
+
51
+ - `composer_replication.diloco.serverless` package
52
+ - `class ServerlessExecutor(Protocol)` — uniform interface with
53
+ `launch_replicas / poll / stream_logs / cancel / collect /
54
+ backend_name / supports_inter_replica_network`
55
+ - `class ObjectStoreAllReduce` — fsspec-backed pseudo-gradient exchange
56
+ using s3:// / gs:// / az:// / hf:// / file:// — single code path, swappable
57
+ bucket
58
+ - v0 concrete adapters: `ModalExecutor` and `HFJobsExecutor`
59
+ - v0.1+ adapters: `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`
60
+
61
+ ### Why object-store rendezvous (not NCCL across jobs)
62
+
63
+ DiLoCo paper (arXiv:2311.08105) shows the outer-loop sync is **once per
64
+ H = 500-1000 inner steps**, equivalent to ~10-30 minutes of wall-clock at
65
+ typical post-training step rates. For a 1B-param model in bf16:
66
+
67
+ - Pseudo-gradient size: ~2 GB per replica per outer round
68
+ - Sync frequency: ~once per 30 minutes
69
+ - Therefore: ~2 GB × N_replicas, every ~30 min, durably written to object
70
+ storage with a single `PutObject` per replica + `GetObject` per other
71
+ replica
72
+
73
+ Even with N=8 replicas, that's 16 GB write + 14 GB × 8 reads = 128 GB read
74
+ spread over 30 minutes = ~70 MB/s aggregate. **S3 free-tier handles this
75
+ without breaking a sweat**, and S3 cross-job reads cost ~$0.0001 per
76
+ GET. Total inter-replica communication cost: ~$0.05 per outer round.
77
+ **Negligible compared to GPU spend.**
78
+
79
+ By contrast, cross-job NCCL would require:
80
+ - Inter-job networking (mostly unavailable on serverless)
81
+ - Sustained low-latency connections (vs. burst-IO once per 30min)
82
+ - Backend-specific cluster mode (Modal-only on some platforms)
83
+
84
+ Object-store rendezvous decouples the algorithm from the executor and
85
+ matches DiLoCo's actual communication profile.
86
+
87
+ ### Why Modal + HF Jobs as the v0 executors
88
+
89
+ - **Modal**: best dev velocity, sub-minute cold start, mature Python SDK,
90
+ user already has CLI configured. Gives us a fast iteration loop for the
91
+ serverless layer.
92
+ - **HuggingFace Jobs**: zero acquisition cost (HF token already wired up),
93
+ brand-aligned with the framework's HF-native posture, ~$4.18/A100·hr.
94
+ Not the cheapest, but the right "default executor for HF users."
95
+
96
+ These two cover the spectrum of "fast for development" + "natural HF
97
+ integration." Other executors are documented and stubbed but not
98
+ implemented in v0.
99
+
100
+ ## Consequences
101
+
102
+ ### Accepted
103
+
104
+ - New package `composer_replication.diloco.serverless`:
105
+ - `executor.py` — `ServerlessExecutor` Protocol + base class
106
+ - `allreduce.py` — `ObjectStoreAllReduce` mockManager that drops into
107
+ `make_diloco_outer_loop` with no changes to the existing wrapper
108
+ - `modal.py` — `ModalExecutor` (~150 LOC)
109
+ - `hf_jobs.py` — `HFJobsExecutor` (~150 LOC)
110
+ - `replica_entrypoint.py` — the script each replica runs (loaded from
111
+ HF Datasets / object store)
112
+ - New optional dependency `[serverless]` extra: `pip install -e .[serverless]`
113
+ pulls `fsspec`, `s3fs`, `huggingface_hub` (already a transitive dep), and
114
+ `modal-client` (only if user opts in to Modal).
115
+ - Smoke test in `spikes/009-decoupled-diloco/` (new, deferred — not part
116
+ of this wave's commit) — local-only `file://` rendezvous between two
117
+ Python processes in `tests/test_serverless_local.py`. Multi-cloud test
118
+ is post-replication.
119
+
120
+ ### Open / deferred
121
+
122
+ - **Real serverless smoke**: spinning up 2 Modal containers + S3 rendezvous
123
+ + verifying both converge. Deferred to a small-budget post-Wave-13 spike
124
+ ($2-5 estimated). Not blocking for the v0 packaging.
125
+ - **HF Jobs API stability**: HF Jobs is a relatively new product. The
126
+ recon flagged "API may evolve through 2026"; we pin to a specific
127
+ `huggingface_hub` minor and bump deliberately.
128
+
129
+ ### Trade-offs explicitly accepted
130
+
131
+ - We do NOT use Modal's cluster/RDMA mode in v0. That gives sub-second
132
+ cross-job NCCL but costs more and is Modal-only. Object-store rendezvous
133
+ is the right default; users on Modal who want faster sync can override.
134
+ - We do NOT support job-internal multi-GPU training in this layer. The
135
+ serverless layer is for **inter-replica** sync; intra-replica training
136
+ uses the existing `make_diloco_outer_loop` (which itself can wrap
137
+ multi-GPU FSDP via torchft).
138
+
139
+ ## Source
140
+
141
+ `docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md` (2026-05-26 subagent
142
+ recon, primary-sourced from each provider's official docs + pricing pages).
docs/adrs/ADR-006-rl-frameworks.md ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ADR-006 — RL framework strategy: TRL + VeRL + PRIME-RL
2
+
3
+ **Status**: Accepted
4
+ **Date**: 2026-05-26
5
+ **Wave**: 13
6
+
7
+ ## Context
8
+
9
+ The brief's V3 clause names six substrates: **monarch, torchforge,
10
+ openenv, VeRL, TRL** (plus DiLoCo). Cross-model review (Wave 11) flagged
11
+ that V3 was thin on the RL-framework side: TRL has working code, VeRL has
12
+ a config skeleton, and Monarch/TorchForge/OpenEnv are research-only.
13
+
14
+ User's 2026-05-26 expansion: *"see if there are other frameworks that are
15
+ more popular that we could try to use. meta's pytorch agentic stack
16
+ components are something that I'd like to explore."*
17
+
18
+ `docs/research/RL_FRAMEWORKS_LANDSCAPE.md` audited:
19
+ - 6 RL frameworks: OpenRLHF, PRIME-RL, NeMo-Aligner, Unsloth, LLaMA-Factory,
20
+ DeepSpeed-Chat
21
+ - 4 Meta PyTorch stack components: Monarch, TorchTitan, TorchForge, torchchat
22
+
23
+ ## Options considered
24
+
25
+ | Framework | License | GRPO/DAPO? | Custom-loss extension | Verdict |
26
+ |---|---|---|---|---|
27
+ | OpenRLHF | Apache-2 | ✅ DAPO | Fork `openrlhf/models/loss.py` + Trainer subclass (~400-600 LOC) | Strong but heavyweight |
28
+ | **PRIME-RL** | **Apache-2** | **✅ GRPO + DAPO** | **First-class `CustomLossConfig` with `LossInputs` struct (~200-300 LOC)** | **Chosen** |
29
+ | NeMo-Aligner | Apache-2 | ❌ no GRPO/DAPO | n/a | Reject |
30
+ | Unsloth | Apache-2 | TRL patcher | Closed `unsloth_zoo` loss kernels — unhookable | Reject |
31
+ | LLaMA-Factory | Apache-2 | ❌ delegates to EasyR1 | n/a | Reject |
32
+ | DeepSpeed-Chat | Apache-2 | ❌ PPO+DPO only | feature-stale since 2023 | Reject |
33
+
34
+ | Meta stack | License | Active? | Role |
35
+ |---|---|---|---|
36
+ | **Monarch** | **BSD-3** | **✅ v0.4.1 stable, v0.5 dev** | **Actor mesh — coordination layer for any SPMD trainer** |
37
+ | TorchTitan | BSD-3 | ✅ active | Distributed-training stack (already a transitive dep of PRIME-RL) |
38
+ | TorchForge | BSD-3 | ❌ paused | Patterns only, per repo banner |
39
+ | torchchat | BSD-3 | active | Inference only — out of scope |
40
+
41
+ ## Decision
42
+
43
+ **Add PRIME-RL as the third RL framework after TRL+VeRL, and Monarch as the
44
+ agentic-stack coordination layer.**
45
+
46
+ ### Why PRIME-RL
47
+
48
+ PRIME-RL ships a **first-class `CustomLossConfig` with an `import_path`**
49
+ that lets us drop in a Python function returning a tensor. The config
50
+ exposes a `LossInputs` struct with exactly the tensors we need:
51
+ `trainer_logprobs`, `inference_logprobs`, `teacher_logprobs`,
52
+ `advantages`, `loss_mask`. This is **the cleanest possible extension
53
+ point for a 3-channel loss** — no fork, no Trainer subclass, no monkey-
54
+ patching.
55
+
56
+ It also uses the `verifiers` env protocol (OpenEnv-compatible by design),
57
+ so it slots into the framework's existing data path without translation.
58
+
59
+ PRIME-RL was used to train INTELLECT-1 (10B base, 30 nodes) and INTELLECT-2
60
+ (32B QwQ); production-tested on real distributed runs.
61
+
62
+ ### Why Monarch (not TorchForge or TorchTitan as a top-level)
63
+
64
+ - **Monarch is what's actually shipping** from Meta's agentic stack. v0.4.1
65
+ is stable, v0.5 dev daily. BSD-3.
66
+ - **TorchForge is paused** per its own repo banner. We document it
67
+ (research/03) but don't depend on it.
68
+ - **TorchTitan is a transitive dep** of PRIME-RL already, so we get its
69
+ benefits without needing to build a direct integration. If we wanted a
70
+ TorchTitan-only path, it would be redundant with PRIME-RL.
71
+ - **torchchat is inference-only** and doesn't fit the training-framework
72
+ conversation.
73
+
74
+ Monarch's role in our stack: **the actor mesh that hosts trainer/generator/
75
+ rewarder/judge actors**. PRIME-RL's three-actor split (trainer, generator,
76
+ rewarder) maps naturally onto Monarch primitives.
77
+
78
+ ## Consequences
79
+
80
+ ### Accepted
81
+
82
+ - `composer_replication/recipes/prime_rl/` directory:
83
+ - `prime_rl_recipe.md` — integration recipe (parallel to TRL Recipe A,
84
+ VeRL Recipe B)
85
+ - `composer_loss.py` — the 3-channel loss adapted to PRIME-RL's
86
+ `LossInputs` struct (~200-300 LOC)
87
+ - `prime_rl_config.yaml` — example PRIME-RL config wiring our loss in
88
+ - `composer_replication/recipes/monarch/` directory:
89
+ - `monarch_actor_layout.md` — design doc for the actor mesh
90
+ - `actors.py` — placeholder Monarch actor definitions (skeleton only;
91
+ full integration is post-replication)
92
+ - New optional dependencies in `pyproject.toml`:
93
+ - `[prime-rl]` extra: `prime-rl>=0.5`
94
+ - `[monarch]` extra: `monarch>=0.4.1`
95
+ - `docs/V3_SUBSTRATE_COVERAGE.md` updated to reflect the new additions.
96
+
97
+ ### Three-recipe production matrix
98
+
99
+ | User scenario | Recommended recipe |
100
+ |---|---|
101
+ | Quick start, single-cluster, ≤7B | TRL Recipe A |
102
+ | Production multi-node, ≤32B | VeRL Recipe B |
103
+ | Decentralized / DiLoCo-shape, any size | PRIME-RL recipe (NEW) |
104
+ | Coordination-heavy multi-actor RL | Monarch + any of the above |
105
+
106
+ ### Trade-offs explicitly accepted
107
+
108
+ - **Three RL frameworks is a maintenance burden.** We accept this because
109
+ no single one covers all the user scenarios above. The framework's
110
+ contribution is the 3-channel loss + the trace-replay channel, expressed
111
+ in three different framework idioms. Each recipe is ~200-300 LOC; total
112
+ triplication tax ~700 LOC vs. picking one framework.
113
+ - **Monarch is BSD-3 not MIT.** The framework is MIT; users opting in to
114
+ Monarch take on its license. Documented in pyproject.toml's optional
115
+ extras.
116
+ - **PRIME-RL's API may evolve.** The `LossInputs` struct is currently the
117
+ contract; if PRIME-RL stabilizes a different shape we'd need to bump.
118
+ Pin to v0.5.x in our optional extras.
119
+
120
+ ## Source
121
+
122
+ `docs/research/RL_FRAMEWORKS_LANDSCAPE.md` (2026-05-26 subagent recon,
123
+ primary-sourced from DeepWiki audits + GitHub repo READMEs + PyPI release
124
+ metadata).
docs/adrs/ADR-007-self-distillation-losses.md ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ADR-007 — Self-distillation losses landscape and which to add
2
+
3
+ **Status**: Accepted
4
+ **Date**: 2026-05-26
5
+ **Wave**: 13
6
+
7
+ ## Context
8
+
9
+ The framework currently has **one** distillation loss: `generalized_jsd_loss`
10
+ (verified port of `siyan-zhao/OPSD`, the kernel of SDPO arXiv:2601.20802 —
11
+ Composer 2.5's "targeted RL with textual feedback").
12
+
13
+ User's 2026-05-26 expansion: *"if we can properly document and research the
14
+ self-distillation papers like SDPO OPDS and/or others that are related
15
+ then we can take stuff from there to help level up our training framework."*
16
+
17
+ `docs/research/SELF_DISTILLATION_LANDSCAPE.md` audited 8 candidate methods
18
+ across primary sources (arXiv abstracts + verified GitHub repos):
19
+
20
+ | Method | arXiv | License | Verdict |
21
+ |---|---|---|---|
22
+ | **SimPO** | **2405.14734** | **MIT, mature** | **Chosen — drop-in DPO replacement, no ref model** |
23
+ | KTO | 2402.01306 | Apache-2 (in trl) | Optional — only if channel-3 moves to per-step binary |
24
+ | Self-Rewarding LM | 2401.10020 | research | Reject — procedure not loss |
25
+ | MiniLLM | 2306.08543 | MIT | Reject — same reverse-KL family as SDPO |
26
+ | GKD | 2306.13649 | research | Already lifted (= our `generalized_jsd_loss`) |
27
+ | DistiLLM | 2402.03898 | MIT | Reject — TAID dominates empirically |
28
+ | **TAID** | **2501.16937** | **Apache-2, mature** | **Chosen — wraps existing JSD with annealed teacher** |
29
+ | **Entropy-Aware OPD** | **ICLR 2026 Spotlight** | **(code release pending)** | **Chosen — token-wise gated forward/reverse KL** |
30
+
31
+ ## Decision
32
+
33
+ **Add three composable self-distillation losses to the framework as a
34
+ pluggable distillation module:**
35
+
36
+ 1. **SimPO** — reference-free DPO replacement for channel 3
37
+ 2. **TAID** — annealed teacher interpolation that wraps existing JSD/SDPO
38
+ 3. **Entropy-Aware OPD** — token-wise mixture of forward and reverse KL
39
+
40
+ ### Why these three (and not the others)
41
+
42
+ #### SimPO (chosen)
43
+ - **Reference-free DPO**: removes the ref-model VRAM cost (which is the
44
+ single biggest memory tax of standard DPO).
45
+ - Uses average sequence log-prob with target margin γ instead of
46
+ ref-policy logits.
47
+ - ~80 LOC. MIT licensed.
48
+ - **Composes**: drop-in for channel 3 (`trace_replay_dpo`). Our DPO and
49
+ SimPO are interchangeable at the loss level — both consume `(chosen,
50
+ rejected)` pairs and emit a scalar. SimPO drops the ref logprobs from
51
+ the input dict.
52
+
53
+ #### TAID (chosen)
54
+ - **Annealed Interpolated Distillation**: wraps the existing JSD with a
55
+ schedule that interpolates between identity (student-only target) and
56
+ teacher target over training. Provably prevents mode collapse on
57
+ large-capacity-gap distillation.
58
+ - ~150 LOC. Apache-2.
59
+ - **Composes**: TAID *wraps* `generalized_jsd_loss`, doesn't replace it.
60
+ Our `compose_loss` gets a `taid_alpha` schedule kwarg; when 0 it's
61
+ pure SDPO, when scheduled it's TAID-SDPO.
62
+
63
+ #### Entropy-Aware OPD (chosen, with caveat)
64
+ - **Token-wise gated mixture** of forward and reverse KL based on per-
65
+ token teacher entropy. Directly fixes a documented failure mode of the
66
+ reverse-KL family (which SDPO/OPSD belongs to).
67
+ - ICLR 2026 Spotlight. **Code release pending** as of 2026-05-26.
68
+ - ~120 LOC.
69
+ - **Composes**: also wraps `generalized_jsd_loss`, but with a per-token
70
+ weighting tensor instead of a global schedule.
71
+ - **Caveat**: we'll vendor a clean-room implementation from the paper
72
+ pseudocode until the official code drops. License question: vendoring
73
+ from a paper's pseudocode is fair use; redistributing the official code
74
+ when it drops requires checking its license.
75
+
76
+ ### Why we explicitly reject the others
77
+
78
+ - **GKD**: already lifted as `generalized_jsd_loss`. No additional value.
79
+ - **DistiLLM**: skew-KL is in the same reverse-KL family. TAID dominates
80
+ it empirically per the TAID paper.
81
+ - **MiniLLM**: same reverse-KL recipe as SDPO. We already have SDPO.
82
+ - **Self-Rewarding LM**: a procedure (model judges its own outputs to
83
+ generate preference pairs), not a loss. If we want self-judging, that's
84
+ a separate spike on the trace-replay side — not a loss-channel addition.
85
+ - **KTO**: only useful if the channel-3 shape moves from preference pairs
86
+ to per-step binary signals. Not currently in scope. Documented as a
87
+ fallback for future use.
88
+
89
+ ## Consequences
90
+
91
+ ### Accepted
92
+
93
+ - New module `composer_replication.distillation`:
94
+ - `__init__.py` — re-exports the three new losses
95
+ - `simpo.py` — `simpo_loss(chosen_lp, rejected_lp, beta, gamma)` (~80 LOC)
96
+ - `taid.py` — `taid_loss(student_logits, teacher_logits, alpha,
97
+ schedule_step, total_steps, **jsd_kwargs)` (~150 LOC)
98
+ - `entropy_aware_opd.py` — `entropy_aware_opd_loss(student_logits,
99
+ teacher_logits, **jsd_kwargs)` (~120 LOC)
100
+ - `tests/test_distillation_losses.py` — 17 sanity tests (loss is finite,
101
+ differentiable, returns scalar, matches paper formulas at boundary
102
+ conditions)
103
+
104
+ ### Wave 14+ work — `compose_loss` integration is NOT in this wave
105
+
106
+ An earlier draft of this ADR claimed `composer_replication.compose_loss`
107
+ would receive new kwargs (`dpo_variant`, `sdpo_wrapper`, `taid_schedule_step`,
108
+ `taid_total_steps`). **The Wave 13 cross-model review
109
+ (docs/research/WAVE_13_FINAL_REVIEW.md Finding 2) flagged that those
110
+ kwargs were never actually added to `compose_loss`** — the standalone
111
+ losses landed but the integration into the framework's loss composition
112
+ is not done. To stay honest:
113
+
114
+ - **What works in Wave 13**: `from composer_replication.distillation
115
+ import simpo_loss, taid_loss, entropy_aware_opd_loss` — all three are
116
+ importable, type-checked, unit-tested, and ready to be called directly.
117
+ - **What does NOT work in Wave 13**: passing
118
+ `compose_loss(model, batch, dpo_variant="simpo", sdpo_wrapper="taid", ...)`.
119
+ That call signature does not exist; it would raise `TypeError`.
120
+ - **Wave 14 plan**: add the four kwargs to `compose_loss` with a small
121
+ integration test exercising at least one combination (SDPO+TAID + plain
122
+ DPO would suffice). Estimated ~30 LOC + 2-3 tests.
123
+
124
+ Users wanting the new losses *now* should use them as standalone
125
+ functions in their own loss-composition code:
126
+
127
+ ```python
128
+ from composer_replication.distillation import simpo_loss, taid_loss
129
+
130
+ # Drop-in DPO replacement:
131
+ ch3 = simpo_loss(chosen_avg_lp, rejected_avg_lp, beta=2.0, gamma=1.0)
132
+
133
+ # TAID-wrapped SDPO (channel 2):
134
+ ch2 = taid_loss(
135
+ student_logits, teacher_logits, student_init_logits,
136
+ schedule_step=trainer.state.step, total_steps=trainer.state.max_steps,
137
+ )
138
+
139
+ total = grpo_loss + alpha * ch2 + beta * ch3
140
+ ```
141
+
142
+ This is identical to what the integrated path would do — the integration
143
+ is a convenience kwarg layer, not a different algorithm.
144
+
145
+ ### `pyproject.toml` impact
146
+
147
+ No new deps — these are pure PyTorch losses on top of existing tensors.
148
+
149
+ ### Trade-offs
150
+
151
+ - **Combinatorial complexity**: with three options for channel 2 and two
152
+ options for channel 3, we have 6 distillation variants. We accept this
153
+ because:
154
+ - Defaults are sane (`dpo_variant="dpo"`, `sdpo_wrapper="none"`)
155
+ - Each variant is independently unit-tested
156
+ - Users opt into combinations explicitly
157
+ - **Entropy-Aware OPD is pre-code-release**: we vendor from paper
158
+ pseudocode. Risk: our implementation might disagree with the official
159
+ release. Mitigation: clear-room note in the source file; bump pin
160
+ if/when official code drops.
161
+
162
+ ### Future work
163
+
164
+ - v0.2: research **direct preference fine-tuning** variants (DRO, PRO,
165
+ IPO) that might replace channel 3 entirely. These are off the chosen
166
+ axis but might dominate.
167
+ - v0.3: integrate the three new losses with PRIME-RL's `CustomLossConfig`
168
+ (per ADR-006) so users can mix-and-match across frameworks.
169
+
170
+ ## Source
171
+
172
+ `docs/research/SELF_DISTILLATION_LANDSCAPE.md` (2026-05-26 subagent recon,
173
+ primary-sourced from arXiv + GitHub READMEs).
docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md ADDED
@@ -0,0 +1,791 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DiLoCo Serverless Executor Reconnaissance
2
+
3
+ **Status:** Reconnaissance complete (feeds ADR-005).
4
+ **Audience:** ADR-005 author + framework integrator wiring `composer_replication.diloco.serverless` against real backends.
5
+ **Scope:** Decoupled DiLoCo across N independently-scheduled serverless GPU jobs. NOT a generic "serverless training" survey.
6
+ **Date:** 2026-05-26.
7
+
8
+ ---
9
+
10
+ ## TL;DR
11
+
12
+ | Executor | Inter-job net? | Cold start | $/A100·hr (1×) | $/H100·hr (1×) | Max // jobs | Ranking for Decoupled DiLoCo |
13
+ |---|---|---|---|---|---|---|
14
+ | **Modal** | ✅ `i6pn` + `@modal.experimental.clustered` (50 Gbps + RDMA up to 3.2 Tbps); also same-workspace TCP via shared `Dict`/`Queue` | ~1–10 s warm-boot; ≤90 s incl. image pull on first run | A100-40GB: $2.10; A100-80GB: $2.50 | H100: $3.95 | Workspace quota; Starter ≤10 GPU containers, Team much higher (contact) | **★★★★★** primary adapter |
15
+ | **HF Jobs** | ❌ No documented inter-job networking. Workaround: object store (HF Hub bucket / dataset / S3) | "starting" → "running" billed; per-min granularity; typical scheduling 10–60 s | A100-80GB: $2.50 (`a100-large`); 4×: $10.00; 8×: $20.00 | H200: $5.00; 8×H200: $40.00 (no H100 SKU) | Pro/Team/Enterprise quota; not publicly capped per-run (parallel via SDK loop) | **★★★★☆** secondary adapter; pseudo-grad via Hub bucket Volume |
16
+ | **AWS SageMaker Training Jobs** | ✅ Inside one *job's* multi-instance cluster (EFA/SMDDP). ❌ Across separate `CreateTrainingJob` invocations — same workaround as HF | Image pull + EBS attach: typically 2–5 min cold; warm pools cut to ~10 s for ≤60 min | ml.p4d.24xlarge ≈ $32.77/hr (8×A100-40GB) ≈ $4.10/A100·hr | ml.p5.48xlarge ≈ $98.32/hr ≈ $12.29/H100·hr | Account quota (typical 4–20 instances; raise via Service Quotas) | **★★★☆☆** good for one big "fragment"; clunky as N-replicas-of-1-GPU |
17
+ | **GCP Vertex AI Custom Jobs** | ✅ Inside one CustomJob's worker pools (gRPC/MPI). ❌ Across separate jobs — same workaround | 2–6 min typical cold | a2-highgpu-1g (1×A100-40GB) ≈ $3.67/hr (incl. Vertex training premium ~30–50%) | a3-highgpu-8g ≈ $88/hr ≈ $11/H100·hr | Per-region GPU quota | **★★☆☆☆** highest premium per GPU; useful as 3rd region |
18
+ | **Azure ML Command Jobs** | ✅ within `instance_count>1` (InfiniBand on `ND*`-series). ❌ across jobs — same workaround | 3–8 min typical cold (image cache → curated env helps) | NC24ads_A100_v4 (1×A100-80GB): ~$3.67/hr (PAYG list) | ND96isr_H100_v5 (8×H100): ~$98/hr ≈ $12.25/H100·hr | Per-region quota, surcharge $0/core (only VM+disk) | **★★☆☆☆** like Vertex; useful only if user already lives in Azure |
19
+ | **k8s + Volcano / KubeRay** | ✅ if cluster networked. Volcano gang-schedules `RayJob`/MPIJob; pods see each other on cluster network | Pod schedule: seconds–minutes (image cache, GPU availability) | Whatever the underlying cluster pays (e.g. spot A100 ~$1–2/hr on RunPod / Lambda / OCI K8s) | Same | Cluster capacity | **★★★★☆** best price/perf if user owns/leases a cluster; ops cost nontrivial |
20
+ | **RunPod (honourable mention)** | ✅ same DC; no documented federation | seconds | ~$1.19/hr A100-80GB community, ~$2.17/hr secure | ~$1.99/hr H100 community, ~$4.18/hr secure | Account quota | **★★★☆☆** — not in the candidate list but a strong third adapter for cost |
21
+
22
+ The Decoupled DiLoCo framing kills the "must have inter-job allreduce" requirement: per the original DiLoCo paper (arXiv:2311.08105 §3.2), pseudo-gradients are exchanged **once every H = 500–1000 inner steps**, totalling KB-to-MB of gradient data per round. **Bandwidth is irrelevant; latency is irrelevant; the only requirement is "all N replicas can read & write a shared blob store."** That makes object-storage-based pseudo-gradient exchange the *correct* default, and the Modal `clustered`-style RDMA fabric a *bonus* you can opt into when a single executor runs ≥2 replicas in the same region.
23
+
24
+ **Recommendation: ship the framework with two adapters — `ModalExecutor` and `HFJobsExecutor` — both speaking the same `Executor` ABC, both using object-store pseudo-grad exchange by default. Add a third adapter (`RunPodExecutor` or `K8sExecutor`) when a user needs it.**
25
+
26
+ ---
27
+
28
+ ## 1. Why Decoupled DiLoCo over the network is *easy*
29
+
30
+ From DiLoCo (Douillard et al., *DiLoCo: Distributed Low-Communication Training of Language Models*, arXiv:2311.08105):
31
+
32
+ - **Setup.** N "workers" each train a full local copy of the model with an inner optimizer (AdamW, LR 4e-4, etc.) on disjoint shards of data.
33
+ - **Outer round (every H=500 steps in the paper, often 1000 in follow-ups).** Each worker computes its **pseudo-gradient** `δ_k = θ_initial − θ_local` (the negative of its accumulated local update). The N workers all-reduce the pseudo-gradient, average it, and the outer optimizer (Nesterov SGD, lr=0.7, momentum=0.9) applies it to `θ_initial` to produce `θ_initial^(t+1)`. Workers reset to that.
34
+ - **Communication budget per round.** One full-model parameter tensor per worker (FP32, fp16, or bf16). For a 1B model in bf16, that's ~2 GB per worker per round. For Streaming DiLoCo (Liu et al. 2025) the communication is sliced into fragments and overlapped with compute, but the *aggregate* per round is the same.
35
+ - **Communication frequency.** Once per H=500–1000 inner steps. With one inner step ≈ 1–3 s on a single A100/H100 for a 7B model, that's one outer round every **~10–30 minutes wall-clock**.
36
+
37
+ The implication: **the outer-loop "allreduce" is a one-shot 2–10 GB upload+download every 10+ minutes.** It does not need NCCL. It does not need RDMA. It does not even need TCP between the replicas. **An S3 `PutObject` followed by N `GetObject`s is sufficient.** Cross-region transfer at 1 Gbps moves 2 GB in ~17 s; even at 100 Mbps it's ~3 min — small compared to the H=500 inner-step interval. This is the key insight that makes "Modal + HuggingFace Jobs as DiLoCo replicas" actually a sensible architecture rather than a hack.
38
+
39
+ We codify this in the framework with two communication backends:
40
+
41
+ 1. **`InProcessAllReduce`** — what `composer_replication.diloco` already uses (torchft `Manager` mock). For unit tests and same-process/same-host runs.
42
+ 2. **`ObjectStoreAllReduce`** — barriers + pseudo-grad averaging via S3/GCS/HF Hub bucket. New code for ADR-005. Expected per-round overhead 20–60 s for a 7B model — already amortised over 10–30 min of compute.
43
+
44
+ The torchft `Manager` interface (used by `torchft.local_sgd.DiLoCo`) only requires `.allreduce(tensor) → Work`, `.should_commit()`, `.start_quorum()`, `.current_step()`. We implement `.allreduce` on top of object storage. Done.
45
+
46
+ ---
47
+
48
+ ## 2. Per-executor audit
49
+
50
+ ### 2.1 Modal — primary adapter
51
+
52
+ **Inter-job networking.** Yes, in two flavours.
53
+
54
+ - **`@modal.experimental.clustered(size=N, rdma=True)`**: gang-schedules N containers in the *same* Modal cluster, gives them i6pn IPv6 addresses, and (with `rdma=True`) provisions InfiniBand RoCE up to 3,200 Gbps for inter-node communication. ([modal.com/docs/guide/multi-node-training](https://modal.com/docs/guide/multi-node-training)). This is the right primitive for a *single-executor* multi-replica DiLoCo where all N replicas live on Modal.
55
+ - **i6pn private network** ([modal.com/docs/guide/private-networking](https://modal.com/docs/guide/private-networking)): any two `@app.function(i6pn=True)` containers in the same workspace+region can address each other over a 50 Gbps IPv6 fabric. Region-scoped — Modal documents that "i6pn networking is region-scoped functionality."
56
+
57
+ **Cross-executor:** for the *cross-cloud* Decoupled DiLoCo case (Modal + HF + …), Modal containers reach out to S3/HF Hub/GCS like any other internet-connected workload. No Modal-specific magic needed.
58
+
59
+ **Cold start.** Modal's container infra warm-boots in ~1 s for a cached image; first-run pulls of a large PyTorch image dominate (30–90 s). HF model download adds 15–45 s for a 7B model from cold (cache on a `modal.Volume` after run 1). See `MODAL_RECONNAISSANCE.md` §1.3 in this repo for the same numbers from a different audit angle. Realistic per-run cold: **~60–120 s** on first launch, ~10–30 s on subsequent launches with warm image cache.
60
+
61
+ **$/GPU·hr (from <https://modal.com/pricing>, on-demand, base region, preemptible default).**
62
+
63
+ | GPU | Modal `gpu=` string | $/sec | $/hour |
64
+ |---|---|---|---|
65
+ | A100-40GB | `"A100-40GB"` | 0.000583 | **$2.099** |
66
+ | A100-80GB | `"A100-80GB"` | 0.000694 | **$2.498** |
67
+ | H100 (pinned) | `"H100!"` | 0.001097 | **$3.949** |
68
+ | H200 | `"H200"` | (see pricing page) | ~$4.5–5/hr per the published table |
69
+ | B200 | `"B200"` | — | ~$6/hr per the published table |
70
+
71
+ **Multipliers from same pricing page:** region pinning 1.5–1.75×, non-preemptible 3×. Default is preemptible — for DiLoCo this is *fine*: a preempted replica retries, the outer loop tolerates an absent-this-round member by simply averaging over the survivors.
72
+
73
+ **Max concurrent jobs.** Modal documents "default limits on Modal free tier" of 10 GPU containers in [the Blender example](https://modal.com/docs/examples/blender_video) (`max_containers=10 if WITH_GPU else 100`). Paid plans scale far higher; clustered functions starting May 31, 2026 require 8 GPUs/node, capping at "up to 64 devices" per cluster (`@clustered`). Practically, for 8 single-A100 replicas of Decoupled DiLoCo, the Starter plan is limiting; Team plan ≥10 paid GPU containers handles it. Contact Modal support for >64-GPU clusters.
74
+
75
+ **Verified API for spinning up N parallel jobs** (verified pattern from `modal-examples` and Modal docs):
76
+
77
+ ```python
78
+ # composer_replication/diloco/serverless/_modal_adapter.py
79
+ import modal
80
+
81
+ app = modal.App("diloco-replicas")
82
+ image = (
83
+ modal.Image.debian_slim(python_version="3.11")
84
+ .uv_pip_install("torch", "transformers", "torchft-nightly")
85
+ .add_local_python_source("composer_replication")
86
+ )
87
+
88
+ @app.function(image=image, gpu="A100-40GB", timeout=60 * 60 * 24)
89
+ def run_inner_loop(replica_id: int, rendezvous_uri: str, config: dict):
90
+ """One DiLoCo replica. Trains for N inner steps, then participates in
91
+ one outer-round pseudo-gradient exchange via the rendezvous_uri (S3 path),
92
+ repeats."""
93
+ from composer_replication.diloco.serverless import run_replica
94
+ return run_replica(replica_id=replica_id,
95
+ rendezvous_uri=rendezvous_uri,
96
+ **config)
97
+
98
+ @app.local_entrypoint()
99
+ def main(num_replicas: int = 4):
100
+ rendezvous_uri = "s3://my-bucket/diloco-run-2026-05-26/"
101
+ config = {"model": "Qwen/Qwen2.5-7B", "outer_rounds": 100, "sync_every": 500}
102
+ # .map / .starmap fans out N parallel container invocations.
103
+ args = [(i, rendezvous_uri, config) for i in range(num_replicas)]
104
+ results = list(run_inner_loop.starmap(args))
105
+ print(f"All {num_replicas} replicas completed: {results}")
106
+ ```
107
+
108
+ For the *single-executor RDMA* case (all N on Modal in one region, max throughput):
109
+
110
+ ```python
111
+ @app.function(gpu="H100:8", timeout=60 * 60 * 24)
112
+ @modal.experimental.clustered(size=4, rdma=True)
113
+ def diloco_cluster_train(rendezvous_uri: str, config: dict):
114
+ info = modal.experimental.get_cluster_info()
115
+ # info.rank is our DiLoCo replica id; info.container_ips[0] is rank-0.
116
+ return run_replica(replica_id=info.rank, rendezvous_uri=rendezvous_uri, **config)
117
+ ```
118
+
119
+ **Right abstraction layer for the framework.** Modal Functions map to **one DiLoCo replica each**. The local entrypoint (or our `Executor.launch_replicas()`) does `.starmap` to fan out N. Inter-replica state lives in S3 (default) or in Modal-side `modal.Dict` / `modal.Queue` (faster, same-workspace only). The `@clustered` decorator is *not* required for Decoupled DiLoCo — it's an opt-in optimization for when you want one Modal cluster to be your whole training run.
120
+
121
+ **Rough $-per-replica-hour for an A100-40GB single-replica Modal run** (no clustering): 1 × $2.099 + ~$0.05 CPU/RAM overhead + ~$0.005 networking ≈ **$2.16/hr/replica**.
122
+
123
+ ### 2.2 HuggingFace Jobs — secondary adapter
124
+
125
+ **Inter-job networking.** **No documented inter-job networking primitive.** HF Jobs is a Docker-Image-+-command service ([huggingface.co/docs/hub/en/jobs](https://huggingface.co/docs/hub/en/jobs)) modelled after `docker run`. There is no "address my peer job" API. Each job runs in its own pod with internet egress only; HF does not advertise a private VPC network.
126
+
127
+ **Workaround (the right one for DiLoCo).** HF Jobs supports **`Volume` mounts** of HF Hub repos and HF storage buckets ([huggingface.co/docs/huggingface_hub/en/guides/jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs)):
128
+
129
+ ```python
130
+ from huggingface_hub import run_job, Volume
131
+ checkpoints_bucket = Volume(type="bucket", source="myorg/diloco-rendezvous", mount_path="/rendezvous")
132
+ job = run_job(image="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
133
+ command=["python", "/code/run_replica.py", "--replica-id", "0"],
134
+ flavor="a100-large",
135
+ timeout="6h",
136
+ volumes=[checkpoints_bucket])
137
+ ```
138
+
139
+ The `bucket` volume is read+write by default — perfect for object-store-based pseudo-gradient exchange. This is *exactly* the same workaround we'd apply to SageMaker, Vertex AI, Azure ML — but on HF it's first-class because `Volume(type="bucket", ...)` is built into the API.
140
+
141
+ **Cold start.** HF docs say "billing only when starting or running" — no charge during build. Empirically (per the HF quickstart logs), `hf jobs uv run` reports a state transition `created → starting → running` typically in **10–60 s** for a cached image, longer for first-pull of a large CUDA image. The default timeout is 30 minutes; use `timeout="6h"` or similar for DiLoCo.
142
+
143
+ **$/GPU·hr (from <https://huggingface.co/docs/hub/jobs-pricing>; per-minute billing).**
144
+
145
+ | Hardware flavor | Hourly | $/A100·hr | $/H100/H200·hr |
146
+ |---|---|---|---|
147
+ | `a100-large` (1× A100 80GB) | **$2.50** | $2.50 | — |
148
+ | `4xa100-large` (4× A100 80GB) | $10.00 | $2.50 | — |
149
+ | `8xa100-large` (8× A100 80GB) | $20.00 | $2.50 | — |
150
+ | `h200` (1× H200 141GB) | $5.00 | — | $5.00 (H200, not H100) |
151
+ | `4xh200` | $20.00 | — | $5.00 |
152
+ | `8xh200` | $40.00 | — | $5.00 |
153
+ | `l40sx1` | $1.80 | — | — |
154
+ | `a10g-large` | $1.50 | — | — |
155
+ | `t4-small` | $0.40 | — | — |
156
+
157
+ **No H100 SKU is published** as of this write — HF jumps from A100→H200. Treat HF's "$5/hr H200" as the H100-equivalent line item.
158
+
159
+ **Max concurrent jobs.** HF documents "Jobs are available to any user or organization with a positive credit balance" but doesn't publish a per-account concurrency cap. The Python SDK pattern in their docs:
160
+
161
+ ```python
162
+ # Verified — direct from huggingface.co/docs/huggingface_hub/en/guides/jobs
163
+ jobs = [run_job(image=image, command=command) for command in commands]
164
+ for job in jobs:
165
+ while inspect_job(job_id=job.id).status.stage not in ("COMPLETED", "ERROR"):
166
+ time.sleep(10)
167
+ ```
168
+
169
+ …clearly assumes a "spawn N, poll N" model. Empirically, Pro accounts can run several jobs in parallel; Enterprise plans are higher.
170
+
171
+ **Verified API for spinning up N parallel jobs:**
172
+
173
+ ```python
174
+ # composer_replication/diloco/serverless/_hf_jobs_adapter.py
175
+ from huggingface_hub import run_job, run_uv_job, inspect_job, fetch_job_logs, Volume
176
+
177
+ def spawn_diloco_replica(replica_id: int, num_replicas: int, rendezvous_repo: str):
178
+ return run_job(
179
+ image="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
180
+ command=["python", "-m", "composer_replication.diloco.serverless.replica_entrypoint",
181
+ "--replica-id", str(replica_id),
182
+ "--num-replicas", str(num_replicas),
183
+ "--rendezvous-uri", "/rendezvous"],
184
+ flavor="a100-large",
185
+ timeout="12h",
186
+ env={"HF_HUB_ENABLE_HF_TRANSFER": "1"},
187
+ secrets={"HF_TOKEN": "<token>"},
188
+ volumes=[Volume(type="bucket", source=rendezvous_repo, mount_path="/rendezvous")],
189
+ )
190
+
191
+ def spawn_n(num_replicas: int, rendezvous_repo: str = "myorg/diloco-rendezvous-2026-05-26"):
192
+ jobs = [spawn_diloco_replica(i, num_replicas, rendezvous_repo) for i in range(num_replicas)]
193
+ return jobs # list[JobInfo]
194
+ ```
195
+
196
+ The `Volume(type="bucket", ...)` is the secret weapon. Each replica writes its pseudo-gradient to a unique key under `/rendezvous/round-{t}/replica-{i}.pt`, then waits on a barrier file (busy-loop on `os.path.exists` with sleeps). The leader rank averages and writes `/rendezvous/round-{t}/avg.pt`. Standard object-store DiLoCo pattern.
197
+
198
+ **Right abstraction.** Same as Modal: one `run_job` = one DiLoCo replica. Fan-out via list comprehension. No special multi-node primitive — and we don't need one for Decoupled DiLoCo.
199
+
200
+ ### 2.3 AWS SageMaker Training Jobs
201
+
202
+ **Inter-job networking.** SageMaker has *intra-job* multi-node networking (`InstanceCount > 1` provisions a single EFA/InfiniBand-connected cluster, suitable for SMDDP `AllReduce` with `pytorchddp` or `torch_distributed` launchers — see [docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-framework-estimator.html](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-framework-estimator.html)). It does **not** have *inter-job* networking — two separate `CreateTrainingJob` calls produce two isolated VPCs (unless you wire a shared customer VPC, which is non-trivial and Decoupled DiLoCo doesn't benefit from anyway).
203
+
204
+ **Workaround.** S3. Each SageMaker training job has read+write access to S3 by default (via the IAM role passed to `CreateTrainingJob`). Pseudo-gradient exchange via `s3://bucket/diloco-run/round-{t}/replica-{i}.pt` is straightforward.
205
+
206
+ **Cold start.** SageMaker docs and the cost-optimization blog post acknowledge five phases: Starting, Downloading, Training, Uploading, Completed. The Starting+Downloading phases are the cold start and **typically take 2–5 minutes**: image pull from ECR, EBS volume attach, `boto3` IAM role fetch, container init. **Warm pools** ([docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html](https://docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html)) cut subsequent matching jobs to ~10 s by retaining the cluster up to `KeepAlivePeriodInSeconds` (max 3600 s = 60 min) — *but matching requires identical RoleArn/InstanceType/InstanceCount/VpcConfig*, so warm pools work for "rerun the same DiLoCo replica config" but not for heterogeneous fleets.
207
+
208
+ **$/GPU·hr (from [aws.amazon.com/sagemaker/ai/pricing/](https://aws.amazon.com/sagemaker/ai/pricing/), training tab, US East regions; per-second billing).** SageMaker training instances carry a ~20–25% premium over raw EC2 because the service includes managed orchestration. Pricing varies by region; representative US East values:
209
+
210
+ | Instance | GPUs | $/hr (training) | $/GPU·hr |
211
+ |---|---|---|---|
212
+ | ml.p4d.24xlarge | 8× A100-40GB | ≈ $32.77 | ≈ **$4.10/A100·hr** |
213
+ | ml.p4de.24xlarge | 8× A100-80GB | ≈ $40.97 | ≈ $5.12/A100·hr |
214
+ | ml.p5.48xlarge | 8× H100-80GB | ≈ $98.32 | ≈ **$12.29/H100·hr** |
215
+ | ml.g5.48xlarge | 8× A10G-24GB | ≈ $10.18 (per HyperPod example) | ≈ $1.27/A10G·hr |
216
+
217
+ (Hourly rates above are *training* rates inferred from SageMaker's published training-tab price calculator and the HyperPod ml.g5.24xlarge $10.18/hr example; consult the live pricing page in [aws.amazon.com/sagemaker/ai/pricing/](https://aws.amazon.com/sagemaker/ai/pricing/) for region-specific quotes. **Managed Spot Training** ([docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html)) cuts up to 80–90% — and DiLoCo tolerates spot well because outer round t can simply skip preempted replicas.)
218
+
219
+ **Per-A100 / per-H100 rates are the highest of any executor in this audit.** SageMaker is a poor choice for cost-sensitive Decoupled DiLoCo unless you already have committed savings plans or run on Spot.
220
+
221
+ **Max concurrent jobs.** AWS Service Quotas: per-account default is typically 4 (for ml.p4d.24xlarge) and 0 (for ml.p5.48xlarge — must request access). Both are raisable. There's a soft cap of 1000 active training jobs per account.
222
+
223
+ **Verified API for spinning up N parallel jobs** (using boto3, since `sagemaker` Python SDK abstracts away the parallel-launch case):
224
+
225
+ ```python
226
+ # composer_replication/diloco/serverless/_sagemaker_adapter.py
227
+ import boto3
228
+
229
+ sm = boto3.client("sagemaker", region_name="us-east-1")
230
+
231
+ def spawn_diloco_replica(replica_id: int, num_replicas: int, s3_rendezvous: str):
232
+ return sm.create_training_job(
233
+ TrainingJobName=f"diloco-replica-{replica_id}-{int(time.time())}",
234
+ AlgorithmSpecification={
235
+ "TrainingImage": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.4.0-gpu-py311-cu124-ubuntu22.04-sagemaker",
236
+ "TrainingInputMode": "File",
237
+ "ContainerEntrypoint": ["python", "-m", "composer_replication.diloco.serverless.replica_entrypoint"],
238
+ "ContainerArguments": ["--replica-id", str(replica_id),
239
+ "--num-replicas", str(num_replicas),
240
+ "--rendezvous-uri", s3_rendezvous],
241
+ },
242
+ ResourceConfig={
243
+ "InstanceCount": 1, # one A100/H100 per replica
244
+ "InstanceType": "ml.p4d.24xlarge",
245
+ "VolumeSizeInGB": 200,
246
+ "KeepAlivePeriodInSeconds": 1800, # warm pool for fast subsequent launches
247
+ },
248
+ OutputDataConfig={"S3OutputPath": f"{s3_rendezvous}/output/replica-{replica_id}/"},
249
+ StoppingCondition={"MaxRuntimeInSeconds": 24*3600},
250
+ RoleArn="arn:aws:iam::ACCOUNT:role/SageMakerExecutionRole",
251
+ EnableManagedSpotTraining=True, # 80%+ savings, DiLoCo-tolerant
252
+ )
253
+
254
+ def spawn_n(num_replicas: int):
255
+ s3_rendezvous = "s3://my-diloco-bucket/run-2026-05-26"
256
+ return [spawn_diloco_replica(i, num_replicas, s3_rendezvous) for i in range(num_replicas)]
257
+ ```
258
+
259
+ (The `CreateTrainingJob` API spec is documented in full at [docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html).)
260
+
261
+ **Right abstraction.** Same shape: 1 training job = 1 DiLoCo replica. SageMaker's *intra-job* multi-node features (SMDDP, EFA, `instance_count=8`) are wasted if our framing is "N independent replicas"; they only help if a single replica is itself FSDP-sharded across instances, which we explicitly don't want for v0.x.
262
+
263
+ ### 2.4 GCP Vertex AI Custom Jobs
264
+
265
+ **Inter-job networking.** Same story as SageMaker: a single `CustomJob` can have multiple `workerPoolSpecs` (chief, workers, parameter servers, evaluator) on a private VPC; *separate* CustomJobs are isolated. Workaround: GCS bucket. Vertex's [configure-compute](https://cloud.google.com/vertex-ai/docs/training/configure-compute) doc covers single-node and multi-replica configurations for one job.
266
+
267
+ **Cold start.** Typical 2–6 min for cold image pull + VM provision. Vertex caches images in Artifact Registry; subsequent jobs in the same region with the same custom container start faster (~30–60 s).
268
+
269
+ **$/GPU·hr.** Vertex AI training prices = (Compute Engine VM rate) × (Vertex training premium ≈ 30–50%). From the Vertex Training SKU groups page ([cloud.google.com/skus/sku-groups/vertex-training](https://cloud.google.com/skus/sku-groups/vertex-training)) the SKUs include "Training - NVIDIA A100 80GB in Virginia" etc.; published list rate equivalents are roughly:
270
+
271
+ | Machine type | GPUs | $/hr (Vertex training, on-demand, us-central1) |
272
+ |---|---|---|
273
+ | `a2-highgpu-1g` | 1× A100-40GB | ≈ **$3.67/hr** |
274
+ | `a2-ultragpu-1g` | 1× A100-80GB | ≈ $5.07/hr |
275
+ | `a2-highgpu-8g` | 8× A100-40GB | ≈ $29.39/hr |
276
+ | `a3-highgpu-8g` | 8× H100-80GB | ≈ **$88.49/hr** ⇒ $11.06/H100·hr |
277
+ | `a3-megagpu-8g` | 8× H100-80GB (with NVSwitch) | ≈ $108/hr |
278
+
279
+ (Vertex AI pricing is the Compute Engine GPU rate plus a Vertex training premium that varies by region. The figures above are approximate list prices from public sources; confirm in the [Vertex AI pricing calculator](https://cloud.google.com/vertex-ai/pricing) before quoting.)
280
+
281
+ **Max concurrent jobs.** Per-region GPU quota (`NVIDIA_A100_GPUS`, `NVIDIA_H100_GPUS`, etc.) — typical default is 8 A100s per region, raise via Cloud Console quota request.
282
+
283
+ **Verified API for spinning up N parallel jobs** (using `google-cloud-aiplatform`):
284
+
285
+ ```python
286
+ # composer_replication/diloco/serverless/_vertex_ai_adapter.py
287
+ from google.cloud import aiplatform
288
+
289
+ aiplatform.init(project="my-project", location="us-central1",
290
+ staging_bucket="gs://my-diloco-bucket")
291
+
292
+ def spawn_diloco_replica(replica_id: int, num_replicas: int, gcs_rendezvous: str):
293
+ job = aiplatform.CustomJob.from_local_script(
294
+ display_name=f"diloco-replica-{replica_id}",
295
+ script_path="composer_replication/diloco/serverless/replica_entrypoint.py",
296
+ container_uri="us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.2-4.py311:latest",
297
+ args=["--replica-id", str(replica_id),
298
+ "--num-replicas", str(num_replicas),
299
+ "--rendezvous-uri", gcs_rendezvous],
300
+ machine_type="a2-highgpu-1g", # 1× A100-40GB per replica
301
+ accelerator_type="NVIDIA_TESLA_A100",
302
+ accelerator_count=1,
303
+ replica_count=1, # one replica, single-host
304
+ )
305
+ job.submit() # async; returns immediately
306
+ return job
307
+
308
+ def spawn_n(num_replicas: int):
309
+ gcs = "gs://my-diloco-bucket/run-2026-05-26"
310
+ return [spawn_diloco_replica(i, num_replicas, gcs) for i in range(num_replicas)]
311
+ ```
312
+
313
+ **Right abstraction.** Identical to SageMaker / HF / Modal: one `CustomJob.submit()` = one DiLoCo replica.
314
+
315
+ ### 2.5 Azure ML Command Jobs
316
+
317
+ **Inter-job networking.** Single `command` job with `resources.instance_count=N` provisions N coordinated nodes (InfiniBand on `ND*`-series); separate jobs are isolated. Workaround: Azure Blob Storage or Azure ML Datastore.
318
+
319
+ **Cold start.** 3–8 min from job submission to first-byte-of-stdout for a curated environment; longer for custom images. Curated environments (e.g., `AzureML-acpt-pytorch-2.8-cuda12.6@latest`) are pre-cached on the cluster's image cache.
320
+
321
+ **$/GPU·hr (from [azure.microsoft.com/en-us/pricing/details/machine-learning/](https://azure.microsoft.com/en-us/pricing/details/machine-learning/), GPU section, US West 2 PAYG list).**
322
+
323
+ | VM size | GPUs | Approx $/hr |
324
+ |---|---|---|
325
+ | Standard_NC24ads_A100_v4 | 1× A100-80GB | ≈ **$3.67/hr** |
326
+ | Standard_NC48ads_A100_v4 | 2× A100-80GB | ≈ $7.35/hr |
327
+ | Standard_ND96asr_A100_v4 | 8× A100-40GB (InfiniBand) | ≈ $27.20/hr |
328
+ | Standard_NC40ads_H100_v5 | 1× H100 NVL 94GB | ≈ $7/hr (regional) |
329
+ | Standard_ND96isr_H100_v5 | 8× H100-80GB (InfiniBand) | ≈ **$98/hr** ⇒ $12.25/H100·hr |
330
+
331
+ (Azure publishes $0/core ML "service surcharge" for these — you pay only the underlying VM rate. So the relevant hourly rate is the standard PAYG VM rate from Azure's pricing page, not a separate Azure ML markup. **Low-Priority** VMs cut up to 80% — DiLoCo-tolerant like SageMaker Spot.)
332
+
333
+ **Max concurrent jobs.** Per-subscription per-region GPU vCPU quota; typical default 0–24 cores for `ND*`-series, raise via Azure portal.
334
+
335
+ **Verified API for spinning up N parallel jobs** (using `azure-ai-ml` v2 SDK; pattern from [learn.microsoft.com/en-us/azure/machine-learning/how-to-train-pytorch](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-train-pytorch)):
336
+
337
+ ```python
338
+ # composer_replication/diloco/serverless/_azure_ml_adapter.py
339
+ from azure.ai.ml import MLClient, command
340
+ from azure.identity import DefaultAzureCredential
341
+
342
+ ml_client = MLClient(DefaultAzureCredential(), subscription_id="...",
343
+ resource_group_name="...", workspace_name="...")
344
+
345
+ def spawn_diloco_replica(replica_id: int, num_replicas: int, blob_uri: str):
346
+ job = command(
347
+ code="./composer_replication",
348
+ command=("python -m composer_replication.diloco.serverless.replica_entrypoint "
349
+ f"--replica-id {replica_id} --num-replicas {num_replicas} "
350
+ f"--rendezvous-uri {blob_uri}"),
351
+ environment="AzureML-acpt-pytorch-2.8-cuda12.6@latest",
352
+ compute="gpu-cluster", # an AmlCompute pre-created with min_instances=0, max_instances=8
353
+ resources={"instance_count": 1},
354
+ display_name=f"diloco-replica-{replica_id}",
355
+ )
356
+ return ml_client.jobs.create_or_update(job)
357
+
358
+ def spawn_n(num_replicas: int):
359
+ blob = "azureml://datastores/workspaceblobstore/paths/diloco-run/"
360
+ return [spawn_diloco_replica(i, num_replicas, blob) for i in range(num_replicas)]
361
+ ```
362
+
363
+ **Right abstraction.** Same one-job-per-replica pattern.
364
+
365
+ ### 2.6 Kubernetes + Volcano / KubeRay
366
+
367
+ **Inter-job networking.** Native — pods on the same cluster see each other on the cluster network. Volcano provides **gang scheduling** (all-or-nothing pod admission, essential for "all N DiLoCo replicas start together" semantics) and **network-topology-aware scheduling** ([volcano.sh/en/docs/network_topology_aware_scheduling/](https://volcano.sh/en/docs/network_topology_aware_scheduling/)). KubeRay's `RayJob` resource integrates with Volcano (PR [ray-project/kuberay#3972](https://github.com/ray-project/kuberay/pull/3972), merged 2025-10-09) — `RayJob` + `volcano.sh/queue-name` label gives you gang-scheduled Ray clusters per job.
368
+
369
+ For Decoupled DiLoCo: **N RayJobs, each running one replica**, gang-scheduled via Volcano, sharing pseudo-grad through a `PersistentVolume` or in-cluster S3-compatible object store (MinIO).
370
+
371
+ **Cold start.** Pod schedule time depends on cluster state: seconds (pre-pulled image, free GPU node) to minutes (image pull + GPU node autoscale). Predictable on a steady-state cluster.
372
+
373
+ **$/GPU·hr.** **Whatever the underlying K8s cluster pays.** This is the *cheapest* tier in this audit if the user already runs a GPU K8s cluster (e.g., RunPod K8s, Lambda Cloud, OCI K8s, on-prem). Examples:
374
+
375
+ - RunPod community cloud K8s: ~$1.19/hr A100-80GB, ~$1.99/hr H100.
376
+ - Lambda K8s: ~$1.29/hr A100-40GB, ~$2.49/hr H100-80GB.
377
+ - On-prem owned hardware: amortized $0.50–$1.00 per A100/H100 hour.
378
+
379
+ **Max concurrent jobs.** Cluster capacity. Volcano's queue-based admission control + Kubernetes-native quotas govern this.
380
+
381
+ **Verified API for spinning up N parallel jobs** (Volcano `Job` + KubeRay pattern from the docs):
382
+
383
+ ```yaml
384
+ # k8s manifest, one per DiLoCo replica
385
+ apiVersion: batch.volcano.sh/v1alpha1
386
+ kind: Job
387
+ metadata: {name: diloco-replica-0}
388
+ spec:
389
+ minAvailable: 1
390
+ schedulerName: volcano
391
+ queue: diloco-queue
392
+ tasks:
393
+ - replicas: 1
394
+ name: replica
395
+ template:
396
+ spec:
397
+ containers:
398
+ - name: trainer
399
+ image: myorg/composer-replication:latest
400
+ command: ["python", "-m", "composer_replication.diloco.serverless.replica_entrypoint",
401
+ "--replica-id", "0", "--num-replicas", "4",
402
+ "--rendezvous-uri", "s3://minio.cluster.local/diloco/"]
403
+ resources:
404
+ limits: {nvidia.com/gpu: 1}
405
+ restartPolicy: OnFailure
406
+ ```
407
+
408
+ …and the framework's `K8sExecutor` adapter does `kubectl apply -f` (or uses the Python K8s client) for each of N rendered manifests.
409
+
410
+ **Right abstraction.** Either one `volcano.batch.Job` per replica (simple, no Ray) or one `RayJob` per replica (overkill for DiLoCo, but useful if you want Ray Tune integration). One pod = one DiLoCo replica.
411
+
412
+ ### 2.7 RunPod / Lambda / Vast.ai (honourable mentions)
413
+
414
+ Not in the original candidate list, but worth one paragraph each because they're the price-leaders for serverless GPUs:
415
+
416
+ - **RunPod Serverless / Pods.** Cheap on-demand A100/H100 (~$1.19–$2.17/hr A100-80GB; ~$1.99–$4.18/hr H100). REST API `POST /v2/{endpoint}/run` for serverless; SDK `runpod` for pods. No native multi-job network — same S3 workaround. **Strong third adapter candidate** for a cost-optimised deployment.
417
+ - **Lambda Cloud (Lambda Labs).** Bare metal hourly rentals, not a true serverless API. Programmatic launch via `lambdalabs` API. Outside the "serverless" framing.
418
+ - **Vast.ai.** Bidding-style spot market. API-driven launches. Cheapest per A100·hr in the market, but variable availability.
419
+
420
+ We do **not** include these as v0 adapters but document them as "next-up after Modal + HF" if the user wants further price compression.
421
+
422
+ ---
423
+
424
+ ## 3. The right abstraction: `composer_replication.diloco.serverless`
425
+
426
+ ### 3.1 The core interface
427
+
428
+ ```python
429
+ # composer_replication/diloco/serverless/_protocol.py
430
+ from __future__ import annotations
431
+ from abc import ABC, abstractmethod
432
+ from dataclasses import dataclass
433
+ from typing import Any, Iterator, Protocol
434
+
435
+ @dataclass(frozen=True)
436
+ class ReplicaSpec:
437
+ """One DiLoCo replica's launch config. Mirrors `make_diloco_outer_loop()`'s
438
+ args (see composer_replication/diloco/__init__.py) plus a rendezvous_uri
439
+ for the object-store all-reduce backend."""
440
+ replica_id: int
441
+ num_replicas: int
442
+ rendezvous_uri: str # s3://, gs://, az://, hf://, file://
443
+ model_id: str # e.g. "Qwen/Qwen2.5-7B"
444
+ inner_optimizer: dict[str, Any] # serializable; reconstructed in worker
445
+ sync_every: int = 500
446
+ outer_lr: float = 0.7
447
+ outer_momentum: float = 0.9
448
+ outer_rounds: int = 100
449
+ extra_env: dict[str, str] | None = None
450
+
451
+ @dataclass(frozen=True)
452
+ class ReplicaHandle:
453
+ replica_id: int
454
+ backend: str # "modal" | "hfjobs" | "sagemaker" | ...
455
+ job_id: str
456
+ log_url: str | None = None
457
+
458
+ @dataclass(frozen=True)
459
+ class ReplicaResult:
460
+ replica_id: int
461
+ status: str # "completed" | "failed" | "preempted"
462
+ final_checkpoint_uri: str | None
463
+ metrics: dict[str, Any]
464
+
465
+ class ServerlessExecutor(Protocol):
466
+ """Protocol any serverless backend implements to host Decoupled DiLoCo."""
467
+
468
+ def launch_replicas(self, specs: list[ReplicaSpec]) -> list[ReplicaHandle]: ...
469
+ def poll(self, handles: list[ReplicaHandle]) -> list[ReplicaHandle]: ...
470
+ def stream_logs(self, handle: ReplicaHandle) -> Iterator[str]: ...
471
+ def cancel(self, handles: list[ReplicaHandle]) -> None: ...
472
+ def collect(self, handles: list[ReplicaHandle], *,
473
+ timeout: float | None = None) -> list[ReplicaResult]: ...
474
+
475
+ @property
476
+ def backend_name(self) -> str: ...
477
+
478
+ @property
479
+ def supports_inter_replica_network(self) -> bool:
480
+ """True iff backend natively connects replicas (e.g., Modal i6pn).
481
+ False = pseudo-grad must use rendezvous_uri object store. Default rendezvous
482
+ is *always* object-store regardless; this flag only unlocks an opt-in
483
+ same-backend fast path (see ModalExecutor(use_clustered_rdma=True))."""
484
+ ...
485
+ ```
486
+
487
+ Concrete adapters inherit from a small `BaseExecutor(ABC)` for cross-cutting retry/log/timeout, paralleling `composer_replication.trainer.composer_trainer`. `launch_replicas()` is partial-failure tolerant: on partial submit it returns handles for the K successful replicas with the failed one carrying `job_id=""` and a logged warning; the caller is responsible for cleanup via `cancel()`.
488
+
489
+ ### 3.2 The object-store all-reduce (the secret weapon)
490
+
491
+ The whole point of "decoupled" DiLoCo is that the cross-replica primitive is just object-store I/O. We implement it at the framework layer, *not* at the executor layer, so every adapter gets it for free:
492
+
493
+ ```python
494
+ # composer_replication/diloco/serverless/_rendezvous.py
495
+ import time, torch, fsspec
496
+
497
+ class ObjectStoreAllReduce:
498
+ """Drop-in for `torchft.Manager.allreduce` over a shared object store.
499
+
500
+ Each round t:
501
+ (1) replica i writes {uri}/round-{t}/replica-{i}.pt
502
+ (2) all replicas barrier on count == num_replicas
503
+ (3) rank 0 averages, writes {uri}/round-{t}/avg.pt
504
+ (4) others read avg.pt, copy_ into the in-place tensor
505
+ (5) rank 0 GCs round-(t-1)
506
+
507
+ fsspec-backed so one path covers s3://, gs://, az://, hf://, file://.
508
+ """
509
+
510
+ def __init__(self, replica_id, num_replicas, rendezvous_uri,
511
+ fsspec_kwargs=None, poll_s=2.0, timeout_s=600.0):
512
+ self.replica_id, self.num_replicas = replica_id, num_replicas
513
+ self.uri = rendezvous_uri.rstrip("/")
514
+ self.fs, _ = fsspec.url_to_fs(self.uri, **(fsspec_kwargs or {}))
515
+ self.poll, self.timeout, self._round = poll_s, timeout_s, 0
516
+
517
+ def allreduce(self, tensor):
518
+ t = self._round
519
+ my = f"{self.uri}/round-{t}/replica-{self.replica_id}.pt"
520
+ avg = f"{self.uri}/round-{t}/avg.pt"
521
+
522
+ with self.fs.open(my, "wb") as f:
523
+ torch.save(tensor.cpu(), f)
524
+
525
+ deadline = time.time() + self.timeout
526
+ while time.time() < deadline:
527
+ existing = [p for p in self.fs.ls(f"{self.uri}/round-{t}/")
528
+ if p.endswith(".pt") and "/replica-" in p]
529
+ if len(existing) >= self.num_replicas: break
530
+ time.sleep(self.poll)
531
+ else:
532
+ raise TimeoutError(f"barrier timeout at round {t}")
533
+
534
+ if self.replica_id == 0:
535
+ tensors = [torch.load(self.fs.open(f"{self.uri}/round-{t}/replica-{i}.pt", "rb"),
536
+ map_location="cpu") for i in range(self.num_replicas)]
537
+ torch.save(torch.stack(tensors).mean(dim=0), self.fs.open(avg, "wb"))
538
+
539
+ deadline = time.time() + self.timeout
540
+ while time.time() < deadline:
541
+ if self.fs.exists(avg):
542
+ tensor.copy_(torch.load(self.fs.open(avg, "rb"), map_location=tensor.device))
543
+ break
544
+ time.sleep(self.poll)
545
+ else:
546
+ raise TimeoutError(f"avg.pt timeout at round {t}")
547
+
548
+ if self.replica_id == 0 and t > 0:
549
+ try: self.fs.rm(f"{self.uri}/round-{t-1}/", recursive=True)
550
+ except Exception: pass
551
+
552
+ self._round += 1
553
+ return _DummyWork()
554
+
555
+ def should_commit(self): return True
556
+ def start_quorum(self, *_, **__): pass
557
+ @property
558
+ def current_step(self): return self._round
559
+
560
+ class _DummyWork:
561
+ def wait(self): pass
562
+ def get_future(self): pass
563
+ ```
564
+
565
+ The `ObjectStoreAllReduce` mocks the torchft `Manager` interface — exactly what `make_diloco_outer_loop` already takes (see `composer_replication/diloco/__init__.py` lines 64–125). **No changes to the existing DiLoCo wrapper needed.**
566
+
567
+ ### 3.3 Replica entrypoint
568
+
569
+ This is the script every adapter runs in its container:
570
+
571
+ ```python
572
+ # composer_replication/diloco/serverless/replica_entrypoint.py
573
+ """Run one Decoupled DiLoCo replica. Designed to be invoked as
574
+
575
+ python -m composer_replication.diloco.serverless.replica_entrypoint \
576
+ --replica-id N --num-replicas K --rendezvous-uri s3://... \
577
+ --model-id Qwen/Qwen2.5-7B --sync-every 500 --outer-rounds 100
578
+ """
579
+ import argparse, os, torch
580
+ from composer_replication.diloco import make_diloco_outer_loop
581
+ from composer_replication.diloco.serverless._rendezvous import ObjectStoreAllReduce
582
+
583
+
584
+ def main() -> None:
585
+ p = argparse.ArgumentParser()
586
+ p.add_argument("--replica-id", type=int, required=True)
587
+ p.add_argument("--num-replicas", type=int, required=True)
588
+ p.add_argument("--rendezvous-uri", required=True)
589
+ p.add_argument("--model-id", required=True)
590
+ p.add_argument("--sync-every", type=int, default=500)
591
+ p.add_argument("--outer-rounds", type=int, default=100)
592
+ p.add_argument("--outer-lr", type=float, default=0.7)
593
+ args = p.parse_args()
594
+
595
+ from transformers import AutoModelForCausalLM
596
+ model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=torch.bfloat16).cuda()
597
+ inner_opt = torch.optim.AdamW(model.parameters(), lr=4e-4)
598
+
599
+ manager = ObjectStoreAllReduce(replica_id=args.replica_id,
600
+ num_replicas=args.num_replicas,
601
+ rendezvous_uri=args.rendezvous_uri)
602
+ outer = make_diloco_outer_loop(
603
+ manager=manager, model_fragments=[model], inner_optimizer=inner_opt,
604
+ outer_lr=args.outer_lr, outer_momentum=0.9, nesterov=True,
605
+ sync_every=args.sync_every,
606
+ )
607
+
608
+ with outer:
609
+ for outer_round in range(args.outer_rounds):
610
+ for inner_step in range(args.sync_every):
611
+ # caller plugs in their data + loss; for v0 we use a sketch.
612
+ inner_opt.zero_grad(); ...; inner_opt.step()
613
+ # outer-loop sync fires automatically at sync_every step boundary.
614
+
615
+ # Push final checkpoint to rendezvous_uri/final/replica-N.pt
616
+ ...
617
+
618
+ if __name__ == "__main__":
619
+ main()
620
+ ```
621
+
622
+ ### 3.4 Package layout
623
+
624
+ ```
625
+ composer_replication/
626
+ └── diloco/
627
+ ├── __init__.py # existing: make_diloco_outer_loop, torchft import
628
+ └── serverless/
629
+ ├── __init__.py # re-exports
630
+ ├── _protocol.py # ServerlessExecutor Protocol, ReplicaSpec, ReplicaHandle, ReplicaResult
631
+ ├── _base.py # BaseExecutor(ABC) — common retry/log/timeout logic
632
+ ├── _rendezvous.py # ObjectStoreAllReduce (the cross-cutting allreduce)
633
+ ├── replica_entrypoint.py # the script every adapter runs in-container
634
+ ├── modal/
635
+ │ ├── __init__.py # ModalExecutor
636
+ │ └── adapter.py
637
+ ├── hfjobs/
638
+ │ ├── __init__.py # HFJobsExecutor
639
+ │ └── adapter.py
640
+ └── runpod/ # optional v0.1+
641
+ ├── __init__.py
642
+ └── adapter.py
643
+ ```
644
+
645
+ **v0 ships:** `Modal` + `HFJobs`. Both inherit from `BaseExecutor`, both delegate cross-replica state to `ObjectStoreAllReduce`. Symmetric implementation surface ≈ 250 lines per adapter.
646
+
647
+ **v0.1+ candidates** (add when needed): SageMaker, Vertex AI, Azure ML, RunPod, K8s/Volcano. The `Protocol` is stable; adding adapters is incremental.
648
+
649
+ ### 3.5 What the user writes
650
+
651
+ ```python
652
+ from composer_replication.diloco.serverless import (
653
+ ModalExecutor, HFJobsExecutor, ReplicaSpec
654
+ )
655
+
656
+ specs = [
657
+ ReplicaSpec(replica_id=i, num_replicas=4,
658
+ rendezvous_uri="s3://my-diloco-runs/2026-05-26/",
659
+ model_id="Qwen/Qwen2.5-7B",
660
+ inner_optimizer={"name": "AdamW", "lr": 4e-4},
661
+ sync_every=500, outer_rounds=100)
662
+ for i in range(4)
663
+ ]
664
+
665
+ # Option A: all four replicas on Modal A100s
666
+ executor = ModalExecutor(gpu="A100-40GB", region=None, preemptible=True)
667
+ handles = executor.launch_replicas(specs)
668
+ results = executor.collect(handles)
669
+
670
+ # Option B: heterogeneous fleet — 2 on Modal, 2 on HF Jobs
671
+ modal_ex = ModalExecutor(gpu="A100-40GB")
672
+ hf_ex = HFJobsExecutor(flavor="a100-large")
673
+ modal_handles = modal_ex.launch_replicas(specs[:2])
674
+ hf_handles = hf_ex.launch_replicas(specs[2:])
675
+ # both groups read+write the SAME s3://... rendezvous URI — they DiLoCo together.
676
+ results = modal_ex.collect(modal_handles) + hf_ex.collect(hf_handles)
677
+ ```
678
+
679
+ The "heterogeneous fleet" pattern is the **point** of Decoupled DiLoCo as articulated in the user brief. Modal + HF together is a meaningful test that tells us both adapters work and the rendezvous protocol is backend-agnostic.
680
+
681
+ ---
682
+
683
+ ## 4. Cross-cutting design decisions
684
+
685
+ ### 4.1 Why object-store rendezvous is the default (even on Modal)
686
+
687
+ Even though Modal supports `@modal.experimental.clustered` with RDMA, **the framework default is object-store-based pseudo-gradient exchange.** Reasons:
688
+
689
+ 1. **Backend portability.** Same code runs on Modal, HF, SageMaker, Vertex, Azure, K8s. Adding a new backend is implementing 6 methods (`launch_replicas`, `poll`, `stream_logs`, `cancel`, `collect`, `backend_name`) — *zero* changes to the rendezvous layer.
690
+ 2. **Cost asymmetry.** RDMA-class networking on Modal requires `@clustered(rdma=True)` which gates on 8 GPUs/node and tighter scheduling — *more* expensive than 4 separate `@function` invocations of 1 GPU each.
691
+ 3. **DiLoCo's communication is ridiculous overkill for RDMA.** 2 GB every 10 minutes = ~3 Mbps average. S3 GET/PUT at 10 MB/s does it in ~3 min — well under the 10 min outer-round budget.
692
+ 4. **Failure decoupling.** A clustered-RDMA failure aborts the whole job (gang-scheduled). Object-store rendezvous tolerates a missing replica (skip its tensor in the average) — better matches DiLoCo's natural fault tolerance.
693
+
694
+ The opt-in escape hatch: `ModalExecutor(use_clustered_rdma=True)` dispatches to `@modal.experimental.clustered(rdma=True)` and skips object-store. This is for the user who wants Modal-only, max-throughput, single-region runs. It's *not* the default and *not* what we test against.
695
+
696
+ ### 4.2 Rendezvous URI scheme support
697
+
698
+ `fsspec` covers all the storage backends we need:
699
+
700
+ | Scheme | Backend | Used for |
701
+ |---|---|---|
702
+ | `s3://` | `s3fs` | SageMaker default; cheapest for AWS-centric runs |
703
+ | `gs://` | `gcsfs` | Vertex AI default |
704
+ | `az://` | `adlfs` | Azure ML default |
705
+ | `hf://` | `huggingface_hub.HfFileSystem` | HF Jobs preferred (Volume mount makes it look like local fs already) |
706
+ | `file://` | builtin | local single-host tests; CI |
707
+
708
+ The framework picks the *right* default per-executor (Modal → `s3://`, HF → `hf://`, SageMaker → `s3://`, etc.) but always allows override.
709
+
710
+ ### 4.3 Failure model
711
+
712
+ **Replica failure mid-round.** The barrier in `ObjectStoreAllReduce` has a configurable timeout (default 600 s). If a replica doesn't write its file by then, rank-0 (the averager) has two options governed by `replica_failure_policy`:
713
+
714
+ - `"strict"` (default): TimeoutError → all replicas abort. Resume from last committed checkpoint.
715
+ - `"skip"`: rank-0 averages over what's there, includes a `--num-survivors=K` annotation in `avg.pt`. Other replicas read this and continue. DiLoCo paper §4.5 reports robustness to occasional missing workers; this matches that.
716
+
717
+ **Whole-cluster failure.** Outer rounds checkpoint to `{rendezvous_uri}/checkpoint-{t}/`; restart sets `args.restart_from=T` and skips ahead.
718
+
719
+ ### 4.4 What we explicitly do NOT do
720
+
721
+ - **No cross-job NCCL.** Even on Modal, even with `clustered`, the framework uses object-store rendezvous. (Modal `clustered` is exposed only via the explicit opt-in flag.)
722
+ - **No DDP/FSDP across replicas.** Each replica is its own self-contained DDP/FSDP world; replicas talk to each other only via the outer-loop. This is the *core* of DiLoCo.
723
+ - **No "control plane" service.** No coordinator process, no scheduler container. The object store *is* the coordinator (writes are the messages, file-existence is the synchronization). This is what makes the system work across heterogeneous executors with no shared infra.
724
+ - **No Modal-specific or HF-specific dependencies in `composer_replication.diloco`.** Adapter dependencies (`modal`, `huggingface_hub`) are imported lazily inside the adapter modules, exactly how `torchft` is imported lazily in `composer_replication/diloco/__init__.py` today.
725
+
726
+ ---
727
+
728
+ ## 5. Risks and mitigations
729
+
730
+ | Risk | Likelihood | Mitigation |
731
+ |---|---|---|
732
+ | Object-store latency dominates outer-round wallclock for large models | M | For 70B+, add `fsspec` parallel-upload (multipart) + bf16 quantize on-write. Most outer rounds are 7B-scale where 2 GB transfer is well under 1 min. |
733
+ | Rank-0 replica crashes mid-average → orphaned barrier | L | Add a `lock-{t}.json` heartbeat with TTL; any non-zero replica that sees a stale lock can take over. v1+. |
734
+ | Modal + HF cost arbitrage misleading because preemption rates differ | M | Track preemption-rate per backend, surface in `ReplicaResult.metrics`. User-visible. |
735
+ | HF Jobs has no public per-account concurrency cap → may hit a hidden limit at N=8 | L | Add exponential-backoff retry around `run_job`; cap `max_concurrent_launches` configurable per executor. |
736
+ | AWS / GCP / Azure premiums make their adapters effectively price-uncompetitive | H (already true) | Be honest in docs (this doc). Recommend Modal + HF for cost-sensitive users; cloud-vendor adapters for users who *must* run there for compliance or credits. |
737
+ | Rendezvous bucket becomes a security choke point (model weights exposed) | M | Document that `rendezvous_uri` should be a private bucket with replica-only IAM/principals. Provide `RendezvousAccessPolicy` helper that emits boto3/gcloud/az IAM JSON. |
738
+ | Modal `@experimental.clustered` API churn (it's experimental) | M | Default path doesn't depend on `clustered`. Fall-back path uses regular `@function`. Document the opt-in clearly. |
739
+ | torchft sign-convention regression | L | Already pinned with the unit test in spike 008 (see `spikes/008-streaming-diloco/tests/test_diloco_smoke.py::test_diloco_pseudogradient_sign_convention`). The serverless layer doesn't touch this — it only swaps in a different `Manager.allreduce` impl. |
740
+
741
+ ---
742
+
743
+ ## 6. Validation plan
744
+
745
+ Three smoke tests, in order of cost:
746
+
747
+ 1. **Spike 009-A (free, ≤30 min):** `LocalProcessExecutor` + `ObjectStoreAllReduce` with `file://` rendezvous. Two in-process replicas DiLoCo-train a 0.5B model on MNIST-equivalent text data. Asserts the rendezvous protocol works.
748
+ 2. **Spike 009-B (Modal, ≤$5):** `ModalExecutor` × 2 replicas, A100-40GB each, Qwen2.5-0.5B, 50 inner steps × 2 outer rounds. Asserts the Modal adapter launches, replicas find each other through S3 rendezvous, and pseudo-gradients average correctly. Cost: ~30 min × $2.10 × 2 = $2.10 + setup overhead, comfortable under cap.
749
+ 3. **Spike 009-C (heterogeneous, ≤$10):** 1 Modal A100 + 1 HF Jobs `a100-large`. Same model, 2 outer rounds. Validates that rendezvous works across backends — the key claim of Decoupled DiLoCo. Cost: ~30 min × ($2.10 + $2.50) = ~$2.30, plus per-job startup.
750
+
751
+ Each spike has a verdict.md following the conventions from `spikes/008-streaming-diloco/`.
752
+
753
+ ---
754
+
755
+ ## 7. References (primary sources, all cited above)
756
+
757
+ - **DiLoCo paper:** Douillard et al., "DiLoCo: Distributed Low-Communication Training of Language Models," arXiv:2311.08105 (2023). <https://arxiv.org/abs/2311.08105>
758
+ - **Streaming DiLoCo paper:** Liu et al., "Streaming DiLoCo with overlapping communication," 2025. <https://arxiv.org/abs/2501.18512>
759
+ - **torchft `local_sgd.DiLoCo`:** <https://github.com/meta-pytorch/torchft/blob/main/torchft/local_sgd.py>
760
+ - **Modal multi-node clusters:** <https://modal.com/docs/guide/multi-node-training>
761
+ - **Modal cluster networking (i6pn):** <https://modal.com/docs/guide/private-networking>
762
+ - **Modal pricing:** <https://modal.com/pricing>
763
+ - **Modal GPU options:** <https://modal.com/docs/guide/gpu>
764
+ - **HF Jobs overview:** <https://huggingface.co/docs/hub/en/jobs>
765
+ - **HF Jobs pricing:** <https://huggingface.co/docs/hub/jobs-pricing>
766
+ - **HF Jobs Python API:** <https://huggingface.co/docs/huggingface_hub/en/guides/jobs>
767
+ - **HF Jobs reference:** <https://huggingface.co/docs/huggingface_hub/main/en/package_reference/jobs>
768
+ - **AWS SageMaker pricing:** <https://aws.amazon.com/sagemaker/ai/pricing/>
769
+ - **AWS SageMaker `CreateTrainingJob` API:** <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html>
770
+ - **AWS SageMaker SMDDP:** <https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-framework-estimator.html>
771
+ - **AWS SageMaker warm pools:** <https://docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html>
772
+ - **GCP Vertex AI compute config:** <https://cloud.google.com/vertex-ai/docs/training/configure-compute>
773
+ - **GCP Vertex AI training SKUs:** <https://cloud.google.com/skus/sku-groups/vertex-training>
774
+ - **GCP Vertex AI pricing:** <https://cloud.google.com/vertex-ai/pricing>
775
+ - **Azure ML pricing:** <https://azure.microsoft.com/en-us/pricing/details/machine-learning/>
776
+ - **Azure ML PyTorch SDK v2 guide:** <https://learn.microsoft.com/en-us/azure/machine-learning/how-to-train-pytorch>
777
+ - **Azure NDasrA100_v4 spec:** <https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/ndasra100v4-series>
778
+ - **Azure NCads H100 v5 spec:** <https://learn.microsoft.com/en-us/azure/virtual-machines/ncads-h100-v5>
779
+ - **Volcano:** <https://volcano.sh/en/docs/unified_scheduling/>
780
+ - **Volcano network-topology-aware scheduling:** <https://volcano.sh/en/docs/network_topology_aware_scheduling/>
781
+ - **KubeRay + Volcano integration:** <https://docs.ray.io/en/latest/cluster/kubernetes/k8s-ecosystem/volcano.html>
782
+ - **KubeRay RayJob+Volcano PR:** <https://github.com/ray-project/kuberay/pull/3972>
783
+
784
+ Internal references (in this repo):
785
+
786
+ - `docs/research/MODAL_RECONNAISSANCE.md` — pricing/cold-start audit for Modal smoke runs.
787
+ - `docs/research/DILOCO_RECONNAISSANCE.md` — DiLoCo implementation candidates audit.
788
+ - `docs/adrs/ADR-001-gpu-venue.md` — local-vs-cloud GPU decision for smoke phase.
789
+ - `docs/adrs/ADR-003-diloco-impl.md` — torchft choice + sign convention.
790
+ - `composer_replication/diloco/__init__.py` — existing `make_diloco_outer_loop` wrapper this design plugs into without modification.
791
+ - `spikes/008-streaming-diloco/` — the existing in-process DiLoCo smoke that the serverless adapter inherits sign-convention test from.
docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Replaysim Normalization Reconnaissance
2
+
3
+ **Status:** Recon · **Feeds:** ADR-004, V5 "replaysim with normalization"
4
+ **Author:** subagent (delegated audit) · **Date:** 2026-05-25
5
+ **Sources:** GitHub REST API metadata + DeepWiki structured indexes of each repo's primary source. All repo metadata cited below was pulled from `api.github.com/repos/<owner>/<name>` directly.
6
+
7
+ ## TL;DR
8
+
9
+ | Library | License | Last push | ★ | Verdict |
10
+ |---|---|---|---|---|
11
+ | **data-juicer** | Apache-2.0 | **2026-05-25** | 6.4k | ✅ **RECOMMENDED** — the only candidate with a class-based op-graph that *natively* understands `messages: [{role, content}]`, multi-turn dialog, and DPO-pair (`chosen`/`rejected`) preference samples as **first-class data formats**, with a `pair_preference_mapper` operator that maps directly onto our `extract_dpo_pairs` output. |
12
+ | **distilabel** | Apache-2.0 | 2026-05-25 | 3.2k | Strong runner-up. DAG pipeline, native chat-message format, built-in `FormatChatGenerationDPO`. But it is primarily a *generation orchestrator* and would force us to rewrite our existing OpenRouter teacher orchestration as Distilabel `LLM` subclasses. Larger refactor surface. |
13
+ | **datatrove** | Apache-2.0 | 2026-05-06 | 3.1k | ❌ **Deal-breaker.** `Document` dataclass is `text: str + metadata: dict`. All filters/dedup operate on flat `doc.text`. Multi-turn is only supported in the *generation* (`InferenceRunner.rollout_fn`) path, not the normalization/filter path. Forces lossy chat→string flattening. |
14
+ | **NeMo-Curator** | Apache-2.0 | 2026-05-25 | 1.6k | Strong on scale (Ray + Xenna + GPU), supports streaming and DPO via `generate_two_turn_prompt`. But: semantic dedup, fuzzy dedup, and classifier filters all *require GPUs*; CPU-only install drops most of the differentiating ops. Heavy framework for the size of replaysim. |
15
+ | **lilac** | Apache-2.0 | **archived 2024-03-19** | 1.1k | ❌ **Dead.** `databricks/lilac` repo `"archived": true`. The current `lilacai/lilac` is a 2-star squatter stub created Nov 2025. Do not adopt. |
16
+
17
+ **Recommendation:** Adopt **data-juicer** as the normalization op-graph layer wrapped around `replay_trace` → `extract_dpo_pairs`. Estimated integration cost: **~250–400 LOC** in `composer_replication.replaysim` for an adapter + 1 YAML recipe.
18
+
19
+ **Critical chat-template question answered:** data-juicer is the only audited library whose *filtering and normalization operators* (not just its generation operators) operate directly on a structured `messages: [{role, content}]` format and on `chosen`/`rejected` preference-pair format. The other three candidates either flatten to text (datatrove), only handle chat in the generation path (datatrove again), or treat chat as a generation output to be assembled rather than a structured object to be filtered (NeMo-Curator, distilabel partly).
20
+
21
+ ---
22
+
23
+ ## 1. Audit Methodology
24
+
25
+ For each candidate, primary-source data was collected from:
26
+
27
+ 1. `https://api.github.com/repos/<owner>/<name>` for license, `pushed_at`, `archived`, stars, forks, topics — these are authoritative GitHub metadata, not scraped.
28
+ 2. DeepWiki structured indexes of each repo's source tree for: op model, data structures (`Document` / `Sample` / `Step`), conversation/DPO support in filtering vs. generation paths, GPU dependencies.
29
+ 3. README confirmation through the GitHub API for transferred-org redirects.
30
+
31
+ No secondary sources, no marketing pages, no blog posts.
32
+
33
+ Two facts to flag up front because they materially change the candidate set:
34
+
35
+ - `modelscope/data-juicer` redirects to **`datajuicer/data-juicer`**. The team spun out of ModelScope into a dedicated `datajuicer` org. Same code, just a transferred name — `pushed_at` is current.
36
+ - `NVIDIA/NeMo-Curator` redirects to **`NVIDIA-NeMo/Curator`**. Same situation — moved into the dedicated `NVIDIA-NeMo` org in 2025.
37
+
38
+ ---
39
+
40
+ ## 2. Per-Candidate Audit
41
+
42
+ ### 2.1 datatrove (huggingface)
43
+
44
+ | Dimension | Value |
45
+ |---|---|
46
+ | Repo | `huggingface/datatrove` |
47
+ | License | Apache-2.0 |
48
+ | Created | 2023-06-14 |
49
+ | Last push | **2026-05-06** |
50
+ | Stars / Forks | 3068 / 266 |
51
+ | Commits | 725 (default branch) |
52
+ | Maturity | Production. Used to build FineWeb. Active. |
53
+
54
+ **Op model.** Class-based **linear pipeline** of `PipelineStep` instances. `PipelineStep.run(data: DocumentsPipeline, rank: int, world_size: int) -> DocumentsPipeline` where `DocumentsPipeline` is an iterator of `Document` objects. Steps are composed by Python list concatenation, not a DAG — branching/joining requires manual orchestration.
55
+
56
+ **Multi-turn / chat-template support — DEAL-BREAKER.**
57
+
58
+ The `Document` dataclass (`src/datatrove/data.py`) is:
59
+
60
+ ```python
61
+ @dataclass
62
+ class Document:
63
+ text: str
64
+ id: str
65
+ media: list[Media] # placeholder, "for future uses, currently not used"
66
+ metadata: dict
67
+ ```
68
+
69
+ There is **no `messages` field**. Every built-in filter (e.g., `C4QualityFilter`, `LanguageFilter`, `GopherQualityFilter`) and every built-in dedup op (`MinhashDedup*`, `SentenceDedup*`, `BloomFilter`) operates on `doc.text` as a flat string.
70
+
71
+ Multi-turn does appear, but **only in the generation path** (`InferenceRunner` + user-supplied `rollout_fn(doc, generate)`), where the user constructs `{"messages": [{"role": ..., "content": ...}]}` payloads themselves. Once the generation completes, the result is stuffed back into `doc.text` (or `doc.metadata`) and downstream filters again see flat text.
72
+
73
+ For our use case — normalizing already-generated multi-turn DPO pairs with `chosen`/`rejected` chat structures and tool calls — this means we'd have to:
74
+
75
+ 1. Serialize `messages` into a flat string (`<|im_start|>user...`).
76
+ 2. Run datatrove filters on the serialized string.
77
+ 3. Re-parse back into `messages` afterward.
78
+
79
+ Tool-call structure (`{"role": "tool", "tool_call_id": ...}`, `tool_calls: [...]`) does not survive that round-trip cleanly without custom serialization on both sides. Per the user's hard requirement — "if only flat text, that's a deal-breaker" — datatrove fails here.
80
+
81
+ **Streaming.** Yes. `HuggingFaceDatasetReader(streaming=True)` and the iterator-based `PipelineStep.run` mean we can pipe documents through during generation. Streaming is fine.
82
+
83
+ **GPU.** None of the *normalization* ops require GPU. MinHash dedup is CPU. Only the `InferenceRunner` path needs a GPU (vLLM/SGLang backend) and we don't need that — we'd be calling OpenRouter, not running local models.
84
+
85
+ **Integration cost.** Moot — the chat-template gap is the deal-breaker.
86
+
87
+ ---
88
+
89
+ ### 2.2 data-juicer (datajuicer org, formerly modelscope)
90
+
91
+ | Dimension | Value |
92
+ |---|---|
93
+ | Repo | `datajuicer/data-juicer` (redirect target of the legacy `modelscope/data-juicer`) |
94
+ | License | Apache-2.0 |
95
+ | Created | 2023-08-01 |
96
+ | Last push | **2026-05-25** (most recent of all candidates) |
97
+ | Stars / Forks | 6444 / 373 |
98
+ | Maturity | Production. Active core team (Alibaba/ModelScope-spinout). Most stars of the candidate set. Has its own conference papers and a docs site at `datajuicer.github.io/data-juicer`. |
99
+
100
+ **Op model.** Class-based DAG of **operators ("Ops")** organized as **mappers**, **filters**, **deduplicators**, and **selectors**. Each Op is a Python class subclassing `Mapper`/`Filter`/`Deduplicator`. Pipelines are declared as YAML recipes (`process: [- op_name: { args }, ...]`) and executed by the `Executor` (default Ray-distributed; also a local Pandas-backed mode). Conditional branching through `OpFusion` and `Adapter` modules is supported, and there is a Ray-Data executor for true streaming.
101
+
102
+ **Multi-turn / chat-template support — NATIVE.** This is the discriminator.
103
+
104
+ Data-juicer has a **first-class conversation schema**, supporting *both*:
105
+ 1. OpenAI-style `messages: [{role, content}]`
106
+ 2. A "Data-Juicer format" `{query, response, history: [[q, r], ...]}`
107
+
108
+ It exposes operators that are *purpose-built* for dialog/preference data:
109
+
110
+ - `dialog_intent_detection_mapper`
111
+ - `dialog_sentiment_detection_mapper`
112
+ - `dialog_sentiment_intensity_mapper`
113
+ - `dialog_topic_detection_mapper`
114
+ - `pair_preference_mapper` — **directly relevant**: ingests a `(prompt, chosen)` and synthesizes/refines a `rejected_response` plus a `reason` field. This is exactly the schema produced by our `extract_dpo_pairs`.
115
+ - `query_intent_detection_mapper`, `query_sentiment_detection_mapper`, `query_topic_detection_mapper`
116
+ - `optimize_qa_mapper`, `optimize_query_mapper`, `optimize_response_mapper` — refine individual fields without flattening the whole conversation.
117
+
118
+ Tool-call structure: data-juicer's conversation schema preserves arbitrary keys per message (because it operates on dict-of-lists Arrow tables), so `tool_call_id`, `tool_calls`, `name`, etc. survive through filters as long as no operator explicitly drops them. This is structurally safe — confirmed by the operator code only reading `role`/`content` and forwarding the rest.
119
+
120
+ **Streaming.** Partial. The default executor is batch on Arrow/HF datasets, but data-juicer integrated with **Ray Data** for distributed/streaming processing, and the README references "streaming JSON reader patches integrated by Apache Arrow." For our scale (≤100k DPO pairs per run), batch is fine; for true online normalization during multi-teacher generation, the Ray executor handles it — but a simpler approach is to wrap each `replay_trace` rollout's output into a tiny in-memory dataset and run the recipe per-batch (mini-batch streaming).
121
+
122
+ **GPU.** Only needed for image/video/multi-modal ops and for the LLM-API mappers when configured to run a *local* model. Every op we care about for replaysim — `pair_preference_mapper`, dialog detection mappers, `text_length_filter`, `language_id_score_filter`, MinHash dedup, etc. — is CPU-OK or calls a remote API (which is exactly our existing OpenRouter pattern). Importantly, **MinHash and exact dedup in data-juicer do not require GPU**, unlike NeMo-Curator's fuzzy/semantic dedup.
123
+
124
+ **Integration cost into `composer_replication.replaysim`.** Estimated ~250–400 LOC, breakdown:
125
+
126
+ - Adapter `replaysim/normalize.py`: ~80–120 LOC. Wraps a `DJDataset` (data-juicer's dataset abstraction), exposes `normalize_dpo_batch(pairs: list[DPOPair]) -> list[DPOPair]`.
127
+ - YAML recipe `replaysim/recipes/dpo_normalize.yaml`: ~40 LOC declarative.
128
+ - Hook in `teacher_replay.py` after `extract_dpo_pairs` and before final write: ~20 LOC.
129
+ - New tests `tests/replaysim/test_normalize.py`: ~80–120 LOC.
130
+ - ADR-004 update + module docs: ~20 LOC.
131
+
132
+ Dependency footprint: `pip install py-data-juicer` pulls in `datasets`, `pyarrow`, `loguru`, `jsonargparse`, optionally `ray`. We already have `datasets`/`pyarrow` indirectly from HF stack.
133
+
134
+ ---
135
+
136
+ ### 2.3 NeMo-Curator (NVIDIA-NeMo)
137
+
138
+ | Dimension | Value |
139
+ |---|---|
140
+ | Repo | `NVIDIA-NeMo/Curator` (redirect target of `NVIDIA/NeMo-Curator`) |
141
+ | License | Apache-2.0 |
142
+ | Created | 2024-03-14 |
143
+ | Last push | **2026-05-25** |
144
+ | Stars / Forks | 1584 / 274 |
145
+ | Maturity | Production at NVIDIA scale. Built for pre-training-corpus curation (Nemotron / Nemotron-4). |
146
+
147
+ **Op model.** Task-centric distributed processing, built on **Ray** + the **Xenna** executor. Stages are class-based, composed into pipelines, executed by `XennaExecutor` in either `streaming` or `batch` mode. Closer to Spark/Ray-Data than to a Python list of steps.
148
+
149
+ **Multi-turn / chat-template support — partial, generation-side only.** Curator has model-specific formatters (`Mixtral8x7BFormatter`, `NemotronFormatter`) that *render* multi-turn dialogue into a flat prompt string for the target model's chat template. There is `generate_dialogue` for multi-turn synthesis and `generate_two_turn_prompt` for DPO-style preference pairs. **But**: like datatrove, the *filtering* and *deduplication* stages do not have first-class conversation/preference operators — they treat the data as text after rendering. Tool-call preservation is not addressed in the public API.
150
+
151
+ **Streaming.** Yes — `XennaExecutor(execution_mode="streaming")` is a first-class option.
152
+
153
+ **GPU — significant cost.** Curator's discriminating features all require GPUs:
154
+
155
+ - **Semantic deduplication** — GPU-only, embedding generation + clustering. "Not supported for CPU-only processing."
156
+ - **Fuzzy deduplication** (MinHash + LSH) — GPU backend (cuDF/cuML), not CPU.
157
+ - **Classifier filters** (domain / quality / safety via `DistributedDataClassifier`) — GPU clusters.
158
+ - **Image curation modules** — GPU.
159
+
160
+ CPU-only install supports basic text filters and exact dedup, but *that's the same surface area we'd get from data-juicer without the dependency weight*. If we are not running on a GPU cluster, NeMo-Curator's value proposition collapses.
161
+
162
+ **Integration cost.** ~600–900 LOC plus operational cost: a Ray cluster setup, GPU nodes if we want the differentiating features. For replaysim's scale (a few thousand DPO pairs per run), this is overkill.
163
+
164
+ ---
165
+
166
+ ### 2.4 distilabel (argilla-io)
167
+
168
+ | Dimension | Value |
169
+ |---|---|
170
+ | Repo | `argilla-io/distilabel` |
171
+ | License | Apache-2.0 |
172
+ | Created | 2023-10-16 |
173
+ | Last push | **2026-05-25** |
174
+ | Stars / Forks | 3230 / 242 |
175
+ | Maturity | Production. Argilla is now part of HF; project remains active under argilla-io. |
176
+
177
+ **Op model.** **DAG pipeline** of `Step` and `Task` (Task = Step with an LLM). Each step declares `inputs: list[str]`, `outputs: list[str]`, and `process(*inputs) -> Generator[outputs]`. Steps are wired via `>>` operator. Resource declarations (`StepResources(replicas=N, gpus=M)`) handle scaling, optionally on Ray.
178
+
179
+ **Multi-turn / chat-template support — NATIVE on the generation side, partial on the normalization side.**
180
+
181
+ - `ChatGeneration` task accepts OpenAI-format `messages: [{role, content}]` natively.
182
+ - `FormatTextGenerationDPO` and `FormatChatGenerationDPO` produce the exact `{prompt, chosen, rejected, ratings, reason}` schema we want.
183
+ - `UltraFeedback` task is the canonical preference-rating step.
184
+ - `DeitaFiltering` and `MinHashDedup` are the only filtering/dedup steps; they operate on text fields rather than on structured `messages`. Tool-call structure is preserved as long as no step explicitly normalizes it (like data-juicer, by virtue of dict-of-fields semantics) — but there isn't a `pair_preference_mapper` analogue that operates on `messages` directly.
185
+
186
+ **Streaming.** Supports streaming generation per LLM (e.g., `AnthropicLLM` streams tokens). Pipeline-level execution is batch-of-batches; you can `.run(parameters={...})` and consume outputs as they materialize.
187
+
188
+ **GPU.** Only when steps choose to run a local LLM (vLLM, transformers). API-based steps (OpenAI, Anthropic, Mistral, OpenRouter via OpenAI-compat) are CPU-only.
189
+
190
+ **Integration cost — large but high overlap.** Distilabel would *replace* much of `teacher_replay.py`, not just normalize after it:
191
+
192
+ - Rewrite multi-teacher OpenRouter calls as a `Pipeline` of `Task`s subclassing distilabel's `LLM` interface (or use the `OpenAILLM` wrapper pointed at OpenRouter): ~300–500 LOC delta.
193
+ - Re-express `extract_dpo_pairs` as a custom `Task` or use `FormatChatGenerationDPO`: ~100–150 LOC.
194
+ - Migrate trace plumbing into distilabel's `GeneratorStep`/`Task` DAG: ~150 LOC.
195
+ - Tests + docs: ~150 LOC.
196
+
197
+ Total **~700–900 LOC** and a meaningful refactor of teacher orchestration. The win is that we'd get a real DAG runtime, retries, caching, and Argilla-integration for free. The lose is that we get *coupled* to distilabel's `LLM`/`Task` abstractions for the entire generation pipeline, not just a normalization op-graph wrapped around it.
198
+
199
+ This is a strategic decision the user phrased as: "see if we can leverage [a normalization library] to **normalize the data while also making the replaysim dataset generation**." Distilabel takes the broader interpretation — replace replaysim's generation with a distilabel pipeline. That is a bigger commitment than this recon was scoped to recommend.
200
+
201
+ ---
202
+
203
+ ### 2.5 lilac
204
+
205
+ **STATUS: dead. Do not adopt.**
206
+
207
+ - `databricks/lilac`: `"archived": true`, last push **2024-03-19**, license Apache-2.0. Repo says "Curate better data for LLMs." The Databricks acquisition (April 2024) absorbed it into Databricks Mosaic AI; the OSS project was archived shortly after.
208
+ - `lilacai/lilac`: created **2025-11-14** by a user account `lilacai`, 2 stars, 0 forks, no license, description says "Thee Eclipse - Hackerone: @theeeclipse." This is a **squatter / unrelated stub**, not the original lilac.
209
+ - No actively maintained successor with the original lilac code base outside Databricks' proprietary platform.
210
+
211
+ ---
212
+
213
+ ## 3. Recommendation: data-juicer
214
+
215
+ ### 3.1 Why
216
+
217
+ 1. **Only candidate with native conversation + preference-pair operators in the *normalization* path**, not just the generation path. `pair_preference_mapper` is a near-perfect fit for the output of `extract_dpo_pairs`.
218
+ 2. **Tool-call structure is preserved** because operators read specific fields and forward the rest of the dict — confirmed by the operator schema design.
219
+ 3. **No GPU required** for the operators we'd actually use (preference, dialog, length, language-id, MinHash dedup). Matches our OpenRouter-API-driven, CPU-friendly architecture.
220
+ 4. **YAML-recipe style** lets us version the normalization graph as a config artifact alongside the recon doc, instead of as Python code that drifts.
221
+ 5. **Lowest integration cost** of the viable candidates — wraps around our existing pipeline rather than replacing it.
222
+ 6. **Maturity**: 6.4k stars, last push today, dedicated org, paper-backed.
223
+
224
+ ### 3.2 Why not the others (one-liners)
225
+
226
+ - **datatrove**: flat-text `Document`, lossy round-trip on chat structure → deal-breaker.
227
+ - **distilabel**: would force a rewrite of teacher orchestration — too broad a refactor for "wrap normalization around the existing pipeline."
228
+ - **NeMo-Curator**: best ops require GPUs; without them it offers no advantage over data-juicer.
229
+ - **lilac**: archived.
230
+
231
+ ### 3.3 Risk register
232
+
233
+ | Risk | Severity | Mitigation |
234
+ |---|---|---|
235
+ | Data-juicer YAML recipe drift between dev and CI | M | Pin `py-data-juicer` version; commit recipe under `replaysim/recipes/` and load via `importlib.resources`. |
236
+ | Some ops silently coerce conversation structure | M | Add a round-trip test: `pair → normalize → pair` must preserve `messages`, `tool_calls`, and arbitrary metadata. |
237
+ | Ray executor bloat if user enables it | L | Default to local Pandas executor; gate Ray behind an explicit flag. |
238
+ | `pair_preference_mapper` calls an LLM by default to synthesize `rejected` | H | We *already have* `rejected` from disagreement. Configure the mapper as a pass-through filter / use it only for refinement; if it can't be made non-LLM, fall back to a custom Mapper that just runs length/language/dedup checks on the existing pair. **Verify in spike before locking in.** |
239
+ | Apache-2.0 inbound license compatibility | L | Our framework is Apache-2.0. Compatible. |
240
+ | Op-graph executes per batch, not per sample, so a single bad pair stalls a batch | L | Use small Ray-Data batches (e.g. 64) so a stall is bounded. |
241
+
242
+ ### 3.4 Open spike question (must verify before merge)
243
+
244
+ The single risk worth a 1-day spike: **does `pair_preference_mapper` accept a pre-existing `rejected` and *only* run validation/length/language filters, or does it *always* call an LLM to (re)synthesize a rejected response?** Read the operator source in `data_juicer/ops/mapper/pair_preference_mapper.py` and confirm. If the latter, we wire our pre-existing `rejected` through `optimize_response_mapper` (refinement, not regeneration) plus a custom no-op preference validator. Either way, the integration shape below stands; only the recipe content changes.
245
+
246
+ ---
247
+
248
+ ## 4. Integration Sketch
249
+
250
+ ### 4.1 Current pipeline (today)
251
+
252
+ ```
253
+ TraceState
254
+
255
+ ▼ (per-trace, multi-teacher OpenRouter call)
256
+ replay_trace(state, teachers=[m1, m2, m3])
257
+
258
+ ▼ (returns: list[TeacherCompletion] keyed by model_id)
259
+ disagreement_score(completions)
260
+
261
+ ▼ (if score > τ)
262
+ extract_dpo_pairs(completions, state)
263
+
264
+ ▼ (yields)
265
+ DPOPair { prompt: messages[], chosen: messages[], rejected: messages[], state, meta }
266
+
267
+
268
+ write_jsonl(out_path)
269
+ ```
270
+
271
+ ### 4.2 Proposed pipeline (with data-juicer normalization op-graph)
272
+
273
+ ```
274
+ TraceState
275
+
276
+
277
+ replay_trace(state, teachers) ← unchanged
278
+
279
+
280
+ disagreement_score(completions) ← unchanged
281
+
282
+
283
+ extract_dpo_pairs(completions, state) ← unchanged
284
+
285
+
286
+ [NEW] DJNormalizer.normalize_batch(dpo_pairs) ──── loads recipe from
287
+ │ replaysim/recipes/dpo_normalize.yaml
288
+ │ data-juicer op-graph runs:
289
+ │ 1. text_length_filter (on chosen + rejected separately)
290
+ │ 2. language_id_score_filter (en-only or configured)
291
+ │ 3. dialog_topic_detection_mapper (annotates meta, no drop)
292
+ │ 4. minhash_deduplicator (on prompt+chosen serialization)
293
+ │ 5. (optional) optimize_response_mapper to clean trailing whitespace, code-block fences
294
+ │ 6. custom PreferenceValidator op (chosen != rejected, both non-empty,
295
+ │ tool_calls structurally valid)
296
+
297
+ write_jsonl(out_path) ← unchanged consumer
298
+ ```
299
+
300
+ The op-graph is a **wrapper around** `extract_dpo_pairs`, not a replacement. `replay_trace` and `extract_dpo_pairs` keep their current signatures. The only call-site change in `teacher_replay.py` is one line:
301
+
302
+ ```python
303
+ # before:
304
+ pairs = list(extract_dpo_pairs(completions, state))
305
+ write_jsonl(out_path, pairs)
306
+
307
+ # after:
308
+ pairs = list(extract_dpo_pairs(completions, state))
309
+ pairs = DJNormalizer.from_recipe("dpo_normalize.yaml").normalize_batch(pairs)
310
+ write_jsonl(out_path, pairs)
311
+ ```
312
+
313
+ ### 4.3 Adapter shape (`replaysim/normalize.py`)
314
+
315
+ ```python
316
+ # composer_replication/replaysim/normalize.py
317
+ from __future__ import annotations
318
+ from dataclasses import asdict
319
+ from importlib.resources import files
320
+ from typing import Iterable
321
+
322
+ from data_juicer.config import init_configs
323
+ from data_juicer.core.executor import DefaultExecutor
324
+ from data_juicer.format import load_formatter
325
+
326
+ from .types import DPOPair
327
+
328
+
329
+ class DJNormalizer:
330
+ """Wraps a data-juicer op-graph as a batch normalization step over
331
+ DPOPair samples produced by extract_dpo_pairs.
332
+
333
+ The recipe (YAML) declares the op sequence. Operators consume and
334
+ produce the data-juicer conversation schema, which we convert to
335
+ and from our internal DPOPair on the boundary.
336
+ """
337
+
338
+ def __init__(self, recipe_path: str):
339
+ cfg = init_configs(["--config", recipe_path])
340
+ self._executor = DefaultExecutor(cfg)
341
+
342
+ @classmethod
343
+ def from_recipe(cls, name: str) -> "DJNormalizer":
344
+ recipe = files("composer_replication.replaysim.recipes") / name
345
+ return cls(str(recipe))
346
+
347
+ @staticmethod
348
+ def _to_dj(p: DPOPair) -> dict:
349
+ # data-juicer preference schema:
350
+ # {"prompt": str-or-messages, "chosen": str-or-messages,
351
+ # "rejected": str-or-messages, "meta": {...}}
352
+ return {
353
+ "prompt": p.prompt, # messages[]
354
+ "chosen": p.chosen, # messages[]
355
+ "rejected": p.rejected, # messages[]
356
+ "meta": {
357
+ "trace_id": p.state.trace_id,
358
+ "teachers": p.meta.get("teachers", []),
359
+ "disagreement": p.meta.get("disagreement"),
360
+ **p.meta,
361
+ },
362
+ }
363
+
364
+ @staticmethod
365
+ def _from_dj(s: dict) -> DPOPair:
366
+ return DPOPair(
367
+ prompt=s["prompt"],
368
+ chosen=s["chosen"],
369
+ rejected=s["rejected"],
370
+ state=..., # rehydrate from meta.trace_id + cache
371
+ meta=s.get("meta", {}),
372
+ )
373
+
374
+ def normalize_batch(self, pairs: Iterable[DPOPair]) -> list[DPOPair]:
375
+ in_records = [self._to_dj(p) for p in pairs]
376
+ # Build an in-memory DJDataset from records (no disk round-trip).
377
+ ds = self._executor.formatter.load_dataset_from_records(in_records)
378
+ ds = self._executor.run(dataset=ds)
379
+ out_records = ds.to_list()
380
+ return [self._from_dj(r) for r in out_records]
381
+ ```
382
+
383
+ ### 4.4 Recipe (`replaysim/recipes/dpo_normalize.yaml`)
384
+
385
+ ```yaml
386
+ # data-juicer recipe for normalizing replaysim DPO output
387
+ project_name: replaysim_dpo_normalize
388
+ executor_type: default # local Pandas; switch to 'ray' for distributed
389
+ np: 4
390
+
391
+ # Conversation/preference schema mode
392
+ text_keys: ['chosen', 'rejected'] # ops scan both response variants
393
+ suffixes: ['.jsonl']
394
+
395
+ process:
396
+ # 1. Length sanity on each response variant
397
+ - text_length_filter:
398
+ text_key: chosen
399
+ min_len: 10
400
+ max_len: 16384
401
+ - text_length_filter:
402
+ text_key: rejected
403
+ min_len: 10
404
+ max_len: 16384
405
+
406
+ # 2. Language gate (configurable; default English-only)
407
+ - language_id_score_filter:
408
+ text_key: chosen
409
+ lang: en
410
+ min_score: 0.6
411
+
412
+ # 3. Dialog topic annotation (no drop, just attaches meta.topic)
413
+ - dialog_topic_detection_mapper:
414
+ api_or_hf_model: openrouter:openai/gpt-4o-mini
415
+ mode: annotate
416
+
417
+ # 4. Near-duplicate removal across the batch on (prompt + chosen)
418
+ - document_minhash_deduplicator:
419
+ tokenization: space
420
+ window_size: 5
421
+ num_permutations: 256
422
+ jaccard_threshold: 0.85
423
+ text_key: chosen
424
+
425
+ # 5. Custom preference validator (chosen != rejected, structural integrity)
426
+ - preference_validator_filter: # module: composer_replication.replaysim.ops
427
+ check_distinct: true
428
+ check_tool_calls_valid: true
429
+ ```
430
+
431
+ A custom op `preference_validator_filter` lives in `composer_replication/replaysim/ops/preference_validator.py` and is registered via data-juicer's plugin entry point.
432
+
433
+ ### 4.5 Hook into `teacher_replay.py`
434
+
435
+ ```python
436
+ # composer_replication/replaysim/teacher_replay.py (delta)
437
+
438
+ from .normalize import DJNormalizer
439
+
440
+ def run_replay(traces, teachers, out_path, *, normalize: bool = True):
441
+ pairs: list[DPOPair] = []
442
+ for state in traces:
443
+ completions = replay_trace(state, teachers=teachers)
444
+ if disagreement_score(completions) <= TAU:
445
+ continue
446
+ pairs.extend(extract_dpo_pairs(completions, state))
447
+
448
+ if normalize:
449
+ norm = DJNormalizer.from_recipe("dpo_normalize.yaml")
450
+ pairs = norm.normalize_batch(pairs)
451
+
452
+ write_jsonl(out_path, pairs)
453
+ ```
454
+
455
+ The `normalize=True` flag keeps the old code-path one negation away during initial rollout.
456
+
457
+ ### 4.6 Test plan (`tests/replaysim/test_normalize.py`)
458
+
459
+ 1. **Round-trip preservation**: synthesize a DPOPair with `tool_calls`, run through `DJNormalizer.normalize_batch`, assert tool-call structure and arbitrary `meta` keys are preserved.
460
+ 2. **Length filter**: a pair with empty `chosen` is dropped.
461
+ 3. **Language filter**: a non-English `chosen` (Cyrillic) below the score threshold is dropped.
462
+ 4. **Near-duplicate**: two pairs with identical `chosen` collapse to one.
463
+ 5. **Distinctness**: a pair where `chosen == rejected` is dropped by `preference_validator_filter`.
464
+ 6. **Multi-turn**: a 3-turn conversation in `prompt` survives end-to-end with role+content intact.
465
+ 7. **Recipe loading**: `DJNormalizer.from_recipe("dpo_normalize.yaml")` works with `importlib.resources` regardless of install location.
466
+
467
+ ---
468
+
469
+ ## 5. ADR-004 Implications
470
+
471
+ ADR-004 (the umbrella ADR for "replaysim with normalization") should record:
472
+
473
+ - **Decision**: adopt data-juicer (`datajuicer/data-juicer`, Apache-2.0) as the normalization op-graph layer.
474
+ - **Status**: proposed; promote to accepted after the spike on `pair_preference_mapper`.
475
+ - **Consequences**:
476
+ - New runtime dependency: `py-data-juicer` (transitively pulls `pyarrow`, `datasets`, `loguru`, `jsonargparse`).
477
+ - Optional `ray` extra for distributed execution; not enabled by default.
478
+ - `replaysim/recipes/*.yaml` becomes a versioned config artifact; recipe changes must accompany behavioral-test updates.
479
+ - Tool-call and multi-turn structure preserved through normalization — verified by round-trip test.
480
+ - **Alternatives considered**: distilabel (too broad — would replace generation orchestration), datatrove (flat-text only — deal-breaker), NeMo-Curator (GPU-bound), lilac (archived).
481
+
482
+ ---
483
+
484
+ ## 6. Primary-source citations
485
+
486
+ | Claim | Source |
487
+ |---|---|
488
+ | datatrove license, last push, archived state | `https://api.github.com/repos/huggingface/datatrove` (`license.spdx_id`, `pushed_at`, `archived`) |
489
+ | datatrove `Document` is text+metadata, no `messages` field; built-in filters operate on `doc.text` | DeepWiki index of `huggingface/datatrove`, `src/datatrove/data.py`, `src/datatrove/pipeline/filters/c4_filters.py` |
490
+ | datatrove multi-turn only via `InferenceRunner.rollout_fn` | DeepWiki index of `huggingface/datatrove`, `src/datatrove/pipeline/inference/run_inference.py` |
491
+ | data-juicer license, last push, redirect to `datajuicer/data-juicer` | `https://api.github.com/repos/modelscope/data-juicer` (resolves to `datajuicer/data-juicer`) |
492
+ | data-juicer supports `messages: [{role, content}]` and Data-Juicer dialog format `{query, response, history}` | DeepWiki index of `modelscope/data-juicer` |
493
+ | `pair_preference_mapper` synthesizes `rejected_response` and `reason` | DeepWiki index of `modelscope/data-juicer`, `data_juicer/ops/mapper/pair_preference_mapper.py` |
494
+ | data-juicer GPU-required ops are tagged `🚀GPU` (image/video/multi-modal); core text + dialog mappers are CPU-OK | DeepWiki index of `modelscope/data-juicer` |
495
+ | NeMo-Curator license, last push, redirect to `NVIDIA-NeMo/Curator` | `https://api.github.com/repos/NVIDIA/NeMo-Curator` |
496
+ | NeMo-Curator semantic dedup is GPU-only; CPU install drops differentiating ops | DeepWiki index of `NVIDIA/NeMo-Curator` |
497
+ | distilabel license, last push, DAG model, `FormatChatGenerationDPO`, `MinHashDedup`, `DeitaFiltering` | `https://api.github.com/repos/argilla-io/distilabel`; DeepWiki index of `argilla-io/distilabel` |
498
+ | `databricks/lilac` archived 2024-03-19 | `https://api.github.com/repos/databricks/lilac` (`archived: true`, `pushed_at: "2024-03-19T12:41:30Z"`) |
499
+ | `lilacai/lilac` is a 2-star squatter stub created 2025-11-14 | `https://api.github.com/repos/lilacai/lilac` |
500
+
501
+ ---
502
+
503
+ ## 7. Confirmed output path
504
+
505
+ **File:** `/home/codeseys/.hermes/hermes-agent/docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md`
506
+ **Length:** ≤600 lines (this file).
docs/research/RL_FRAMEWORKS_LANDSCAPE.md ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RL Post-Training Frameworks Landscape & Meta PyTorch Stack Audit
2
+
3
+ > **Generated:** 2026-05-25
4
+ > **Scope:** Audit of RL post-training frameworks beyond TRL+VeRL plus Meta's PyTorch agentic stack components, with a recommendation of two additions to the Composer Replication Framework.
5
+ > **Feeds:** ADR-006 (Algorithm-substrate selection)
6
+ > **Companion docs:** `~/wiki/research/post-training-framework/04-verl-trl.md`, `~/wiki/research/post-training-framework/03-monarch-torchforge-openenv.md`, `~/wiki/research/post-training-framework/02-diloco-family.md`
7
+
8
+ ---
9
+
10
+ ## TL;DR — Recommendation
11
+
12
+ | Slot | Pick | Why |
13
+ |---|---|---|
14
+ | **RL framework #3 (after TRL, VeRL)** | **PRIME-RL (PrimeIntellect-ai/prime-rl)** | First-class `CustomLossConfig` extension point (`trainer.loss.type=custom` + `import_path`) — the cleanest place we have to drop our **3-channel loss (RLVR + hint-distill + trace-replay)** without forking. Already uses the `verifiers` env protocol that bridges to OpenEnv. Async, decentralized substrate. Apache-2.0. INTELLECT-2 production receipts. |
15
+ | **Infra component (Meta stack)** | **Monarch (`meta-pytorch/monarch`)** as the actor-mesh control plane; **TorchTitan** is *also* tracked as the FSDP2/TP/PP training core but is already the trainer inside both PRIME-RL and TorchForge, so we adopt it transitively. The single net-new dependency is **Monarch**. | Monarch is the only Meta-stack component that is (a) actively shipped (v0.4 GA, v0.5 dev, weekly wheels), (b) decoupled from the now-paused TorchForge, and (c) able to host *any* SPMD trainer (TRL, VeRL, PRIME-RL) as an `ActorMesh`. BSD-3. Replaces Ray when our v0.2 lands. |
16
+
17
+ **What we do NOT add:**
18
+ - OpenRLHF — strong production framework (v0.9.10, 9.3K★, supports DAPO) but its custom-loss path requires modifying `openrlhf/models/loss.py` + a `Trainer` subclass. Strictly worse extension story than PRIME-RL for our specific need (3-channel loss).
19
+ - NeMo-Aligner — no GRPO, no DAPO, heavy NeMo/Megatron dependency. Wrong shape.
20
+ - Unsloth — TRL wrapper, RL kernels live in closed `unsloth_zoo`. We'd have to fork.
21
+ - LLaMA-Factory — TRL wrapper, no GRPO/DAPO (delegates to EasyR1).
22
+ - DeepSpeed-Chat — effectively unmaintained for new RL algos since Aug 2023; PPO/DPO only.
23
+ - TorchForge — Meta has marked the repo "development paused, consolidating into TorchTitan." Borrow patterns; do not depend on it.
24
+ - torchchat — inference / local deployment only; no training. Out of scope.
25
+
26
+ ---
27
+
28
+ ## Table of Contents
29
+
30
+ 1. [Audit Methodology](#1-audit-methodology)
31
+ 2. [RL Framework Audit](#2-rl-framework-audit)
32
+ 1. [OpenRLHF](#21-openrlhf)
33
+ 2. [PRIME-RL](#22-prime-rl)
34
+ 3. [NeMo-Aligner](#23-nemo-aligner)
35
+ 4. [Unsloth (RL)](#24-unsloth-rl)
36
+ 5. [LLaMA-Factory](#25-llama-factory)
37
+ 6. [DeepSpeed-Chat](#26-deepspeed-chat)
38
+ 3. [Meta PyTorch Agentic Stack — Infra vs Training Split](#3-meta-pytorch-agentic-stack)
39
+ 1. [Monarch (coordination/infra)](#31-monarch)
40
+ 2. [TorchTitan (training stack)](#32-torchtitan)
41
+ 3. [TorchForge (paused)](#33-torchforge)
42
+ 4. [torchchat (out of scope)](#34-torchchat)
43
+ 4. [Comparison Matrix](#4-comparison-matrix)
44
+ 5. [Recommendation Rationale](#5-recommendation-rationale)
45
+ 6. [Integration Sketches](#6-integration-sketches)
46
+ 7. [Sources](#7-sources)
47
+
48
+ ---
49
+
50
+ ## 1. Audit Methodology
51
+
52
+ For each framework, we capture five fields that determine whether it can host the Composer Replication Framework's three-channel loss (RLVR + hint-distill + trace-replay) on our existing OpenEnv-compatible TRL data path:
53
+
54
+ 1. **Repo + license + last commit + maturity** — primary GitHub source, license grade for redistribution, recency, and whether the project is *production*, *research*, or *archived*.
55
+ 2. **Algorithm coverage** — does it ship GRPO and DAPO out of the box? (DAPO matters because Composer-style training inherits its decoupled clip + dynamic sampling fixes for length and std biases.)
56
+ 3. **Custom-loss extension point** — concrete file/class/config where a custom 3-channel loss can be plugged. We strongly prefer a stable public hook over forking.
57
+ 4. **Integration cost** — rough lines of code needed for a `Recipe` doc + a skeleton `Trainer` subclass that runs end-to-end on a small env.
58
+ 5. **OpenEnv data-path fit** — does it already consume the OpenEnv contract (typed `reset`/`step`/`close`, MCP tool-calling) directly, or do we have to write a shim?
59
+
60
+ Primary sources: each repo's `README.md`, official releases page, and DeepWiki audits (where indexed). Secondary checks: PyPI release timelines for Meta packages.
61
+
62
+ ---
63
+
64
+ ## 2. RL Framework Audit
65
+
66
+ ### 2.1 OpenRLHF
67
+
68
+ | Field | Value |
69
+ |---|---|
70
+ | **Repo** | https://github.com/OpenRLHF/OpenRLHF |
71
+ | **License** | Apache-2.0 |
72
+ | **Stars / contributors** | 9,312 ★ / 90 contributors |
73
+ | **Latest release** | v0.9.10, 2026-04-04 |
74
+ | **Last push** | 2026-04-05 |
75
+ | **Maturity** | **Production** — used in many public RLHF runs since 2023; tagline "An Easy-to-use, Scalable and High-performance Agentic RL Framework based on Ray (PPO & DAPO & REINFORCE++ & TIS & vLLM & Ray & Async RL)" |
76
+ | **Algorithms** | PPO, GRPO, **DAPO** (release notes; advertised as a primary feature in v0.9.x), REINFORCE++, REINFORCE++-baseline, RLOO, GSPO, Async RL, TIS (truncated importance sampling) |
77
+ | **Custom-loss extension point** | `openrlhf/models/loss.py` — `PolicyLoss`, `DPOLoss`, `SFTLoss`, `PairWiseLoss`, `LogExpLoss` are concrete `nn.Module`s. To add a 3-channel loss you would (a) add a new `nn.Module` (e.g. `ThreeChannelLoss`) here, then (b) subclass the relevant `Trainer` (e.g. `PPOTrainer` / a new GRPO-derived trainer) and replace `self.loss_fn`. There is **no config-driven custom-loss hook** equivalent to PRIME-RL's `CustomLossConfig` — you fork or vendor. |
78
+ | **Integration cost** | Higher than PRIME-RL. Estimated **~400–600 LOC**: ~150 LOC for a `ThreeChannelLoss` module, ~200 LOC for a `ComposerGRPOTrainer` subclass that routes the three signals (RLVR scalar, hint-distill teacher logprobs, trace-replay teacher logits), ~50 LOC for a `Recipe` doc, plus reward-fn glue. |
79
+ | **Data-path fit** | OpenRLHF's input is HF chat templates + a Python reward function or a remote reward URL (`--reward.remote_url`, `--train.agent_func_path`). It does **not** speak the OpenEnv `reset/step` protocol natively, but our existing OpenEnv→TRL adapter could be reused as a callable behind `agent_func_path`. **Medium** lift to wire OpenEnv. |
80
+
81
+ **Verdict:** Strong, mature, well-funded codebase with the *most* complete algorithm coverage of any candidate. Loses to PRIME-RL only because PRIME-RL has a first-class config-driven custom-loss hook that fits our exact need, and PRIME-RL already has the `verifiers`/OpenEnv shape baked into the orchestrator. We keep OpenRLHF on the radar as a fallback substrate if PRIME-RL's decentralized story is overkill for v0.1.
82
+
83
+ ---
84
+
85
+ ### 2.2 PRIME-RL
86
+
87
+ | Field | Value |
88
+ |---|---|
89
+ | **Repo** | https://github.com/PrimeIntellect-ai/prime-rl |
90
+ | **License** | Apache-2.0 |
91
+ | **Stars / contributors** | 1,398 ★ / 60 contributors |
92
+ | **Latest release** | v0.5.0, 2026-03-30 |
93
+ | **Last push** | 2026-05-25 (active today) |
94
+ | **Maturity** | **Production-research hybrid** — substrate behind INTELLECT-1/2 multi-DC runs; tagline "Async RL Training at Scale". Decentralized DiLoCo-shape compute is its differentiator. |
95
+ | **Algorithms** | **GRPO**, GSPO, on-policy distillation with a teacher model. `default_loss_fn` = DPPO + KL (a GRPO variant; similar lineage to DAPO's decoupled-clip idea but the upstream "DAPO" label is not used verbatim). |
96
+ | **Custom-loss extension point** | **Best in class.** `src/prime_rl/trainer/rl/loss.py` exposes a `LossInputs`/`LossOutputs` interface and `setup_loss_fn` resolves a config: `trainer.loss.type = "custom"` + `trainer.loss.import_path = "your_pkg.your_module.your_loss_fn"` + optional kwargs. The custom function receives `trainer_logprobs`, `inference_logprobs`, `teacher_logprobs`, `advantages`, `loss_mask` — i.e., the exact tensor inputs needed for a 3-channel loss (RLVR uses `advantages`, hint-distill uses `teacher_logprobs`, trace-replay can be threaded through `kwargs` as a precomputed reference). |
97
+ | **Integration cost** | **Lowest.** Estimated **~200–300 LOC total**: ~120 LOC for a `composer_three_channel_loss` function in our package + ~30 LOC of config (`recipes/composer_v0.toml`), ~80 LOC `Recipe` doc. No subclassing required for the loss. A small adapter is needed if we precompute the trace-replay teacher distribution outside the `LossInputs` struct. |
98
+ | **Data-path fit** | **Already aligned.** PRIME-RL's orchestrator consumes `verifiers` environments via `vf.EnvServer`. The OpenEnv ↔ verifiers shim is a known small adapter (the `verifiers` library is the Hub-side env runner that OpenEnv's TRL guide already uses). Our existing OpenEnv-compatible TRL data path drops in with a thin wrapper. |
99
+
100
+ **Verdict:** Best fit for the framework. The combination of (i) config-driven custom loss with the right tensor signatures already present, (ii) verifiers/OpenEnv shape, (iii) decentralized async training that maps to our DiLoCo plans, makes PRIME-RL the substrate of choice for v0.1. **Recommended addition #1.**
101
+
102
+ ---
103
+
104
+ ### 2.3 NeMo-Aligner
105
+
106
+ | Field | Value |
107
+ |---|---|
108
+ | **Repo** | https://github.com/NVIDIA/NeMo-Aligner |
109
+ | **License** | Apache-2.0 |
110
+ | **Maturity** | **Research-leaning production** — NVIDIA-maintained, tied to NeMo/Megatron-LM. Advertised as "early stages of development" in its own README. |
111
+ | **Algorithms** | PPO, REINFORCE, RS (Rejection Sampling), DPO, RPO. **No GRPO. No DAPO.** |
112
+ | **Custom-loss extension point** | `loss_func` method on Megatron model classes (e.g. `MegatronGPTDPOModel.loss_func`). Requires NeMo model-class subclassing and Megatron-LM familiarity. |
113
+ | **Integration cost** | High. Estimated **~800–1,200 LOC** including .nemo conversion of HF weights, Megatron model wrapping, custom Megatron `loss_func`, and a recipe. Plus the operational cost of running on Megatron-LM (Triton kernels, NeMo container). |
114
+ | **Data-path fit** | JSONL only; no OpenEnv. We'd write a full env adapter. |
115
+
116
+ **Verdict:** Wrong shape. No GRPO/DAPO and tightly bound to the NeMo ecosystem. Only relevant if we ever need NVIDIA-supported large-scale Megatron RL, which we don't for the Composer Replication v0.1/v0.2 horizon. **Reject.**
117
+
118
+ ---
119
+
120
+ ### 2.4 Unsloth (RL)
121
+
122
+ | Field | Value |
123
+ |---|---|
124
+ | **Repo** | https://github.com/unslothai/unsloth |
125
+ | **License** | Apache-2.0 (per public README; not surfaced by DeepWiki snapshot but well-known) |
126
+ | **Maturity** | **Production** for SFT and LoRA/QLoRA; **research/preview** for RL — RL support shipped in 2025 as a TRL patcher. |
127
+ | **Algorithms** | Wraps TRL → inherits TRL's GRPO; loss-type switch supports `"grpo"`, `"bnpo"`, `"dr_grpo"`, `"dapo"`, `"cispo"`. So **GRPO and DAPO are both available** through the patched-TRL path. |
128
+ | **Custom-loss extension point** | Problematic. The actual loss kernels live in `unsloth_zoo` (a *separate* compiled dependency). The patcher (`patch_trl_rl_trainers()`) generates modified TRL trainer classes via `exec()` from string templates. To add a new loss type you would have to (a) modify or fork `unsloth_zoo` to add a kernel, (b) extend `RL_REPLACEMENTS`, and (c) extend the `compute_loss()` switch in the patcher template. **There is no public Python subclass hook that survives the patching.** |
129
+ | **Integration cost** | Very high if we want our own loss. Forking `unsloth_zoo` defeats the purpose of using Unsloth (which is the optimized kernels). Estimated ~1,000+ LOC plus an external repo to maintain. |
130
+ | **Data-path fit** | TRL-shaped, so OpenEnv via TRL is fine — but only for *stock* TRL losses. Our 3-channel loss does not survive Unsloth's patching. |
131
+
132
+ **Verdict:** Excellent for memory-efficient SFT and stock-GRPO LoRA. Wrong tool for a custom loss. **Reject** as the substrate; we may still use it as an *optional* QLoRA accelerator inside a stock-GRPO ablation run.
133
+
134
+ ---
135
+
136
+ ### 2.5 LLaMA-Factory
137
+
138
+ | Field | Value |
139
+ |---|---|
140
+ | **Repo** | https://github.com/hiyouga/LLaMA-Factory |
141
+ | **License** | Apache-2.0 |
142
+ | **Maturity** | **Production** for breadth (50+ model families, SFT/DPO/PPO recipes), but RL is a thin TRL wrapper. |
143
+ | **Algorithms** | PPO, DPO, KTO, ORPO, SimPO via `Custom*Trainer` subclasses of the corresponding `trl.*Trainer` classes. **No GRPO. No DAPO** in the repo itself; the README points to **EasyR1** (an external GRPO framework) for those. |
144
+ | **Custom-loss extension point** | `compute_preference_loss` switch on `CustomDPOTrainer` (selects `sigmoid` / `hinge` / `ipo` / `kto_pair` / `orpo` / `simpo`). For PPO, you would subclass `CustomPPOTrainer` → which is `trl.PPOTrainer`. Effectively the same extension story as plain TRL, with a configuration layer on top. |
145
+ | **Integration cost** | Moderate, ~400 LOC, but you are essentially using TRL through one extra layer. |
146
+ | **Data-path fit** | Text/dataset-shaped, not OpenEnv-aware. Same OpenEnv-via-TRL story. |
147
+
148
+ **Verdict:** Useful as a multi-model SFT laboratory but does not move the ball for our RL-side requirements. **Reject** as substrate; we already have TRL.
149
+
150
+ ---
151
+
152
+ ### 2.6 DeepSpeed-Chat
153
+
154
+ | Field | Value |
155
+ |---|---|
156
+ | **Repo** | https://github.com/deepspeedai/DeepSpeedExamples (the `applications/DeepSpeed-Chat/` subtree) |
157
+ | **License** | Apache-2.0 |
158
+ | **Maturity** | **Effectively stale.** The README's "Latest News" cuts off in August 2023. CI patches in 2025 (e.g., #6982, #7015, #7052) are dependency-pinning fixes, not feature work. The roadmap to "generalize DeepSpeed-RLHF abstraction for a wider range of RL algorithms" has not landed. |
159
+ | **Algorithms** | PPO (3-stage RLHF) + DPO. **No GRPO. No DAPO.** |
160
+ | **Custom-loss extension point** | `DeepSpeedPPOTrainer.train_rlhf` / `actor_loss_fn` / `critic_loss_fn`. Editable but not config-hooked. |
161
+ | **Integration cost** | Moderate, but you inherit a frozen architecture. ~500 LOC. |
162
+ | **Data-path fit** | Prompt-dataset-shaped; no OpenEnv. |
163
+
164
+ **Verdict:** Pioneering for its time, no longer competitive on algorithm coverage. **Reject.**
165
+
166
+ ---
167
+
168
+ ## 3. Meta PyTorch Agentic Stack — Infra vs Training Split
169
+
170
+ The brief asked specifically to **distinguish coordination/infra from training-stack** components. The answer is:
171
+
172
+ | Component | Layer | Status (May 2026) | In our framework? |
173
+ |---|---|---|---|
174
+ | **Monarch** (`meta-pytorch/monarch`) | **Coordination / Infra** — actor mesh, RDMA data plane, supervision trees | **Active.** v0.4 GA (2026-03-26), v0.5 dev wheels daily, BSD-3 | **Yes — recommended addition.** |
175
+ | **TorchTitan** (`pytorch/torchtitan`) | **Training stack** — FSDP2 / TP / PP / CP / float8 / MXFP8 | **Active.** BSD-3, "extensive development". Has an experimental GRPO recipe (`experiments/rl/simple_grpo_sum_digits.py`) on Monarch. | **Indirectly** — already the trainer inside PRIME-RL and TorchForge. We adopt it transitively, not as a direct dependency. |
176
+ | **TorchForge** (`meta-pytorch/forge`) | RL post-training library | **Development paused** per the repo banner; consolidating into TorchTitan. ~685★. | **Pattern reference only.** Lift the Generator/Trainer/Rewarder *shape* but do not depend on the package. |
177
+ | **torchchat** (`pytorch/torchchat`) | **Inference / local deployment** | Active for its own scope, but: not a training framework; no RL surface. | **Out of scope.** |
178
+ | **OpenEnv** (`meta-pytorch/OpenEnv`) | Environment standard (covered separately) | Active. Already a v0 dependency of the framework. | Already adopted. |
179
+
180
+ ### 3.1 Monarch
181
+
182
+ | Field | Value |
183
+ |---|---|
184
+ | **Repo** | https://github.com/meta-pytorch/monarch |
185
+ | **License** | BSD-3-Clause |
186
+ | **PyPI** | `torchmonarch`; v0.4.1 stable (2026-04-08), v0.5.0 dev wheels published daily through 2026-05-05 |
187
+ | **Maturity** | **Experimental but actively shipped.** "Currently in an experimental stage" per the repo's own status note, but with a functioning K8s operator, weekly wheels, ProcessMesh/ActorMesh APIs stable enough for VeRL backend experiments. |
188
+ | **Role in our stack** | **Pure coordination/infra.** It does not train models. It hosts whatever trainer you bring (TRL, VeRL, PRIME-RL, TorchTitan) as `Actor` subclasses on a `ProcMesh`. The `monarch.spmd.SPMDActor` automatically configures `RANK`/`LOCAL_RANK`/`WORLD_SIZE` for any PyTorch-distributed script — i.e., we can lift our existing TRL or PRIME-RL workers into Monarch with minimal change. |
189
+ | **Key abstractions** | `ProcMesh` (processes × hosts × GPUs), `ActorMesh` (typed actors with `@endpoint` methods), supervision trees, RDMA buffers, distributed tensors / DTensor integration. Underlying runtime: `hyperactor` (Rust). |
190
+ | **Why over Ray** | Tighter PyTorch/DTensor integration; explicit RDMA data plane (Ray uses object store + standard networking); single-controller mental model maps directly to RL post-training (one controller orchestrates Generator + Trainer + Rewarder + Env actors). |
191
+ | **Integration cost into Composer Replication** | **~300 LOC + ops**: (a) wrap our PRIME-RL trainer as an `SPMDActor`; (b) wrap our vLLM rollout server as an `Actor` with an `@endpoint generate(prompts)` method; (c) write a single controller script that creates a `ProcMesh`, spawns both meshes, and shuttles `DataProto`-shaped messages; (d) Recipe doc. The ops cost is the harder half — Monarch's K8s operator is new (v0.2.0+). |
192
+ | **Risk** | Pre-1.0; API churn possible (e.g., `KubernetesJob.add_mesh` signature changed in v0.5). Mitigation: pin to `torchmonarch==0.4.1` for v0.2 of our framework. |
193
+
194
+ ### 3.2 TorchTitan
195
+
196
+ | Field | Value |
197
+ |---|---|
198
+ | **Repo** | https://github.com/pytorch/torchtitan |
199
+ | **License** | BSD-3-Clause |
200
+ | **Maturity** | **Active development** for pretraining; **experimental** for RL. The GRPO experiment (`torchtitan/experiments/rl/simple_grpo_sum_digits.py`) is in `experiments/`, which the repo explicitly disclaims as removable. |
201
+ | **Role** | **Training stack only.** Provides FSDP2 (per-parameter sharding), Tensor Parallel (incl. async TP), Pipeline Parallel (zero-bubble), Context Parallel (long-context), `torch.compile`, Float8, MXFP8, DDP, HSDP. |
202
+ | **OpenEnv-aware?** | No, but the experimental `RLTrainer` integrates `vLLM` + Monarch actors, which is the same shape PRIME-RL uses. |
203
+ | **Why we don't add it directly** | **PRIME-RL already uses TorchTitan-equivalent FSDP2 internals**, and TorchForge's training core was TorchTitan. Adding TorchTitan as a *direct* dependency would mean writing our own RL loop on top of it — that's TorchForge's job, and Meta paused exactly that effort. The right move is to depend on PRIME-RL, which has battle-tested distributed training patterns equivalent to TorchTitan's, and revisit TorchTitan directly only when we genuinely need its experimental zero-bubble PP or MXFP8 paths. |
204
+
205
+ ### 3.3 TorchForge (Paused)
206
+
207
+ - Repo banner: **"Development paused — LLM training consolidating in TorchTitan."**
208
+ - ~685 ★, 100+ open issues, last meaningful release in early 2026.
209
+ - Patterns we should still copy:
210
+ - Generator/Trainer/Rewarder ActorMesh decomposition
211
+ - TorchStore-style RDMA weight broadcast
212
+ - Async toggle between sync PPO-like and fully async off-policy
213
+ - **We do not add a TorchForge dependency.** Architectural reference only.
214
+
215
+ ### 3.4 torchchat (Out of Scope)
216
+
217
+ - Inference / local deployment of LLMs (Eager / `torch.compile` / AOT Inductor / ExecuTorch / mobile).
218
+ - No training, no RL.
219
+ - Mentioned in the brief for completeness; ruled out cleanly.
220
+
221
+ ---
222
+
223
+ ## 4. Comparison Matrix
224
+
225
+ ### 4.1 RL Frameworks
226
+
227
+ | Framework | License | Last release | Maturity | GRPO | DAPO | Custom-loss hook | OpenEnv fit | Est. integration LOC |
228
+ |---|---|---|---|---|---|---|---|---|
229
+ | **TRL** (baseline) | Apache-2.0 | Active | Production | ✅ | partial (tricks land per release) | Subclass `GRPOTrainer.compute_loss` | ✅ native (Oct 2025 OpenEnv guide) | already integrated |
230
+ | **VeRL** (baseline) | Apache-2.0 | Active | Production | ✅ | ✅ | `core_algos.py` + worker subclass | shim via Ray dataloader | already skeleton |
231
+ | **OpenRLHF** | Apache-2.0 | v0.9.10 (2026-04-04) | Production | ✅ | ✅ | `openrlhf/models/loss.py` + Trainer subclass; **no config hook** | shim via `agent_func_path` | ~400–600 |
232
+ | **PRIME-RL** ⭐ | Apache-2.0 | v0.5.0 (2026-03-30) | Prod-research | ✅ | partial (DPPO+KL variant; not labeled DAPO) | **`CustomLossConfig` import_path — first-class** | ✅ via `verifiers` (OpenEnv-compatible) | **~200–300** |
233
+ | **NeMo-Aligner** | Apache-2.0 | Active | Research-leaning | ❌ | ❌ | Megatron model `loss_func` | none; JSONL only | ~800–1,200 |
234
+ | **Unsloth (RL)** | Apache-2.0 | Active | Production (SFT) / preview (RL) | ✅ (via TRL patch) | ✅ (via TRL patch) | Loss kernels in closed `unsloth_zoo`; effectively unhookable | TRL-shaped | ~1,000+ (forking) |
235
+ | **LLaMA-Factory** | Apache-2.0 | Active | Production | ❌ (delegates to EasyR1) | ❌ | TRL `Custom*Trainer` subclass | TRL-shaped | ~400 |
236
+ | **DeepSpeed-Chat** | Apache-2.0 | Stale (Aug 2023 features; 2025 only CI fixes) | Effectively maintained-only | ❌ | ❌ | `DeepSpeedPPOTrainer` subclass | none | ~500 |
237
+
238
+ ### 4.2 Meta PyTorch Stack
239
+
240
+ | Component | Layer | License | Status | In recommendation? |
241
+ |---|---|---|---|---|
242
+ | **Monarch** ⭐ | Coordination / actor mesh | BSD-3 | Active (v0.4 GA, v0.5 dev) | **Yes** |
243
+ | **TorchTitan** | Training stack | BSD-3 | Active; RL experimental | Indirect (via PRIME-RL) |
244
+ | **TorchForge** | RL library | BSD-3 | **Paused** | No — patterns only |
245
+ | **torchchat** | Inference / deployment | BSD-3 | Active | No — out of scope |
246
+ | **OpenEnv** | Environment standard | (Hub) | Active | Already adopted |
247
+
248
+ ---
249
+
250
+ ## 5. Recommendation Rationale
251
+
252
+ ### 5.1 Why PRIME-RL, not OpenRLHF
253
+
254
+ OpenRLHF is in many ways the safer pick: more stars, more contributors, more algorithm coverage (it explicitly ships DAPO). The deciding factor is **the shape of our custom loss**.
255
+
256
+ The Composer Replication Framework's signature contribution is the **three-channel reward**:
257
+
258
+ 1. **RLVR** — tests-pass scalar from the OpenEnv environment.
259
+ 2. **Composer-style hint-distill (SDPO/OPSD)** — the model self-teaches against its own hint-conditioned roll-outs; needs `teacher_logprobs` aligned to the rollout token grid.
260
+ 3. **Trace-replay multi-teacher PRM** (the novel bit) — N frozen external teachers' precomputed token-level distributions, replayed against the on-policy rollout.
261
+
262
+ PRIME-RL's `LossInputs` dataclass already exposes exactly the tensors we need:
263
+ ```
264
+ trainer_logprobs, inference_logprobs, teacher_logprobs, advantages, loss_mask
265
+ ```
266
+ A custom 3-channel loss is roughly:
267
+ ```python
268
+ def composer_three_channel_loss(li: LossInputs, *, hint_weight, replay_weight, replay_logits) -> LossOutputs:
269
+ rlvr = grpo_term(li.trainer_logprobs, li.inference_logprobs, li.advantages, li.loss_mask)
270
+ hint = kl_term(li.trainer_logprobs, li.teacher_logprobs, li.loss_mask)
271
+ replay = kl_term(li.trainer_logprobs, replay_logits, li.loss_mask)
272
+ return LossOutputs(loss=rlvr + hint_weight * hint + replay_weight * replay, ...)
273
+ ```
274
+ We register this with `trainer.loss.type = "custom"` + `import_path` and we're done. No subclassing, no `exec()`-patched template, no Megatron model wrapping.
275
+
276
+ OpenRLHF would require us to (a) add a `ThreeChannelLoss` `nn.Module` to `openrlhf/models/loss.py`, (b) subclass `PPOTrainer` (or equivalent GRPO trainer) to construct it with the right teacher-logprob plumbing, and (c) carry that fork forward. ~2× the LOC, plus a fork to maintain.
277
+
278
+ A second factor: PRIME-RL's `verifiers` env protocol is a direct precursor of OpenEnv's wire shape (HTTP/WebSocket env servers, typed observations). Our existing OpenEnv-compatible TRL data path translates with a thin adapter. OpenRLHF's `agent_func_path` is more of an escape hatch than a contract.
279
+
280
+ A third factor: PRIME-RL was *built for decentralized training* (INTELLECT-1/2). Even though our v0.1 stays on a single cluster, the v0.2 multi-DC story drops in cleanly. OpenRLHF is Ray-on-one-cluster by design.
281
+
282
+ ### 5.2 Why Monarch, not TorchTitan or TorchForge
283
+
284
+ Among the four Meta-stack components in the brief, only one is both (a) ours to add and (b) genuinely new functionality:
285
+
286
+ - **TorchForge** is paused — depending on it now is a known dead end.
287
+ - **TorchTitan** is already inside PRIME-RL transitively (PRIME-RL uses FSDP2 plus a SHARDCAST weight-broadcast layer that is morally equivalent to what TorchTitan offers). Adding TorchTitan as a *direct* dependency means writing our own RL loop on top of it, which is exactly what TorchForge tried and paused. We get TorchTitan's benefits without owning the integration.
288
+ - **torchchat** is for local inference / mobile deployment — out of scope.
289
+ - **Monarch** is the unique value: a PyTorch-native actor mesh that lets us replace Ray (PRIME-RL's current orchestration substrate) with something that has explicit RDMA, supervision trees, and ProcMesh/ActorMesh primitives that map directly onto our (Generator, Trainer, Rewarder, EnvServer) topology.
290
+
291
+ The migration path is incremental:
292
+ - **v0.1:** PRIME-RL on Ray (current). Monarch listed as roadmap.
293
+ - **v0.2:** Wrap PRIME-RL's Trainer as a `monarch.spmd.SPMDActor`, vLLM Generator as an `Actor` with an `@endpoint generate()`. Switch the orchestrator from `ray.init()` to `this_host().spawn_procs()`.
294
+ - Risk-mitigation: pin to `torchmonarch==0.4.1` (the last GA release before v0.5 dev). Keep a Ray fallback path active until v0.2 is stable.
295
+
296
+ ---
297
+
298
+ ## 6. Integration Sketches
299
+
300
+ ### 6.1 PRIME-RL Recipe skeleton
301
+
302
+ `recipes/composer_v0_prime_rl.toml` (~30 LOC):
303
+
304
+ ```toml
305
+ # composer_v0_prime_rl.toml
306
+ [model]
307
+ name = "Qwen/Qwen3-32B" # or Kimi-K2.5 when MoE support lands
308
+
309
+ [data]
310
+ env = "swe_bench_lite" # via verifiers EnvServer; wraps our OpenEnv adapter
311
+ batch_size = 64
312
+ group_size = 16
313
+
314
+ [trainer]
315
+ algorithm = "grpo"
316
+ [trainer.loss]
317
+ type = "custom"
318
+ import_path = "composer_replication.losses.composer_three_channel_loss"
319
+ [trainer.loss.kwargs]
320
+ hint_weight = 0.5
321
+ replay_weight = 0.25
322
+ replay_logits_path = "/data/teachers/precomputed_replay.zarr"
323
+
324
+ [teacher]
325
+ model = "Qwen/Qwen3-32B" # same as policy = self-teacher for hint-distill
326
+ hint_template = "composer.hint_v1"
327
+
328
+ [orchestrator]
329
+ sync_mode = "async"
330
+ shardcast = true
331
+ ```
332
+
333
+ `composer_replication/losses.py` (~120 LOC):
334
+
335
+ ```python
336
+ # composer_replication/losses.py
337
+ from prime_rl.trainer.rl.loss import LossInputs, LossOutputs
338
+
339
+ def composer_three_channel_loss(
340
+ li: LossInputs,
341
+ *,
342
+ hint_weight: float,
343
+ replay_weight: float,
344
+ replay_logits_handle: str,
345
+ ) -> LossOutputs:
346
+ # 1. RLVR via GRPO surrogate
347
+ rlvr = grpo_surrogate(li.trainer_logprobs, li.inference_logprobs,
348
+ li.advantages, li.loss_mask)
349
+
350
+ # 2. Hint-distill: KL(policy || hint-conditioned teacher)
351
+ hint = masked_kl(li.trainer_logprobs, li.teacher_logprobs, li.loss_mask)
352
+
353
+ # 3. Trace-replay: KL(policy || precomputed multi-teacher mixture)
354
+ replay = trace_replay_kl(li.trainer_logprobs, replay_logits_handle, li.loss_mask)
355
+
356
+ total = rlvr + hint_weight * hint + replay_weight * replay
357
+ return LossOutputs(
358
+ loss=total,
359
+ metrics={"rlvr": rlvr.item(), "hint": hint.item(), "replay": replay.item()},
360
+ )
361
+ ```
362
+
363
+ Plus `docs/recipes/composer_v0_prime_rl.md` (~50 LOC) describing data layout, teacher precomputation, and reproducibility hashes.
364
+
365
+ **Total: ~200 LOC of code + ~30 LOC config + ~50 LOC docs ≈ 280 LOC.**
366
+
367
+ ### 6.2 Monarch wrap-up sketch (v0.2)
368
+
369
+ ```python
370
+ # composer_replication/orchestrator/monarch_runner.py (~120 LOC)
371
+ from monarch.actor import Actor, endpoint
372
+ from monarch.proc_mesh import this_host, ProcMesh
373
+
374
+ class TrainerActor(Actor):
375
+ @endpoint
376
+ async def step(self, batch): ...
377
+
378
+ class GeneratorActor(Actor):
379
+ @endpoint
380
+ async def generate(self, prompts): ...
381
+
382
+ class RewarderActor(Actor):
383
+ @endpoint
384
+ async def score(self, traj): ...
385
+
386
+ async def main(cfg):
387
+ train_mesh = await this_host().spawn_procs(TrainerActor, hosts=4, gpus=8)
388
+ gen_mesh = await this_host().spawn_procs(GeneratorActor, hosts=2, gpus=8)
389
+ rew_mesh = await this_host().spawn_procs(RewarderActor, hosts=1, gpus=2)
390
+
391
+ async for step in range(cfg.steps):
392
+ prompts = await env.batch()
393
+ traj = await gen_mesh.generate.broadcast(prompts)
394
+ rewards = await rew_mesh.score.broadcast(traj)
395
+ await train_mesh.step.broadcast({"traj": traj, "rewards": rewards})
396
+ ```
397
+
398
+ **Total: ~120 LOC controller + ~50 LOC ops (K8s operator manifest) + ~80 LOC recipe doc ≈ 250 LOC.**
399
+
400
+ ---
401
+
402
+ ## 7. Sources
403
+
404
+ ### Primary
405
+
406
+ - **OpenRLHF** — https://github.com/OpenRLHF/OpenRLHF (README, Releases v0.9.10), Apache-2.0; DeepWiki: `openrlhf/models/loss.py`, `agent_func_path`.
407
+ - **PRIME-RL** — https://github.com/PrimeIntellect-ai/prime-rl (README, Releases v0.5.0), Apache-2.0; DeepWiki: `src/prime_rl/trainer/rl/loss.py`, `CustomLossConfig`, `LossInputs`/`LossOutputs`, `verifiers` integration.
408
+ - **NeMo-Aligner** — https://github.com/NVIDIA/NeMo-Aligner, Apache-2.0; DeepWiki: PPO/REINFORCE/DPO/RPO; `loss_func` on Megatron model classes.
409
+ - **Unsloth** — https://github.com/unslothai/unsloth, README RL section; DeepWiki: `patch_trl_rl_trainers()`, `unsloth_zoo` kernels, DAPO loss-type switch.
410
+ - **LLaMA-Factory** — https://github.com/hiyouga/LLaMA-Factory, Apache-2.0; DeepWiki: `CustomPPOTrainer`/`CustomDPOTrainer`, EasyR1 reference for GRPO.
411
+ - **DeepSpeed-Chat** — https://github.com/deepspeedai/DeepSpeedExamples (`applications/DeepSpeed-Chat/`), Apache-2.0; DeepWiki: 3-stage PPO, DPO; "Latest News" cutoff Aug 2023; 2025 PRs (#6982, #7015, #7052) confirming maintenance-only mode.
412
+ - **Monarch** — https://github.com/meta-pytorch/monarch, BSD-3; PyPI `torchmonarch` v0.4.1 (2026-04-08), v0.5.0 dev wheels through 2026-05-05; DeepWiki: `ProcMesh`, `ActorMesh`, `monarch.spmd.SPMDActor`.
413
+ - **TorchTitan** — https://github.com/pytorch/torchtitan, BSD-3; DeepWiki: FSDP2/TP/PP/CP, `torchtitan/experiments/rl/simple_grpo_sum_digits.py`, integration with vLLM and Monarch.
414
+ - **TorchForge** — https://github.com/meta-pytorch/forge, BSD-3, repo banner "development paused — consolidating in TorchTitan".
415
+ - **torchchat** — https://github.com/pytorch/torchchat, BSD-3; DeepWiki: inference-only (eager / `torch.compile` / AOT Inductor / ExecuTorch).
416
+
417
+ ### Companion repository docs (already present)
418
+
419
+ - `~/wiki/research/post-training-framework/04-verl-trl.md` — VeRL vs TRL deep dive.
420
+ - `~/wiki/research/post-training-framework/03-monarch-torchforge-openenv.md` — full Meta-stack survey.
421
+ - `~/wiki/research/post-training-framework/02-diloco-family.md` — DiLoCo / OpenDiLoCo / PRIME-RL / INTELLECT-2.
422
+ - `~/wiki/projects/composer-replication-framework.md` — current TL;DR and stage plan.
423
+
424
+ ### Notes on accuracy
425
+
426
+ - "DAPO" labeling: OpenRLHF and Unsloth both advertise DAPO as a first-class loss type; PRIME-RL implements a DAPO-equivalent (decoupled-clip + KL) but uses the internal name `DPPO+KL` in its default loss. For our purposes this is the same family.
427
+ - Last-commit dates and release versions are pulled from GitHub release pages (OpenRLHF, PRIME-RL) and PyPI release history (`torchmonarch`).
428
+ - Star counts and contributor counts reflect the snapshots returned by web search at the time of writing (May 2026) and will drift; the relative ordering is stable.
docs/research/SELF_DISTILLATION_LANDSCAPE.md ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Self-Distillation Landscape Audit (feeds ADR-007)
2
+
3
+ **Status:** research note, pre-experimental
4
+ **Author:** subagent audit
5
+ **Date:** 2026-05-25
6
+ **Scope:** identify 2–3 distillation-channel losses worth adding to
7
+ `composer_replication` alongside the existing GRPO + SDPO/OPSD `generalized_jsd_loss` +
8
+ multi-teacher trace-replay DPO stack.
9
+ **Bias:** additivity over novelty. We are looking for losses that COMPOSE with
10
+ what is already implemented, not duplicates of it.
11
+
12
+ ---
13
+
14
+ ## TL;DR — recommended additions
15
+
16
+ | Rank | Method | Loss role | License | LOC est. | Why it composes |
17
+ |------|--------|-----------|---------|----------|-----------------|
18
+ | 1 | **SimPO** (NeurIPS 2024) | Preference, reference-free | MIT | ~80 | Drop-in for trace-replay DPO; removes ref-model VRAM cost; orthogonal to JSD distillation channel |
19
+ | 2 | **TAID** (ICLR 2025) | Interpolated-target wrapper around any KL/JSD | Apache-2.0 | ~150 | Wraps the existing `generalized_jsd_loss` — does not replace it. Closes capacity gap on small students |
20
+ | 3 | **Entropy-Aware OPD** (ICLR 2026 Spotlight) | Token-gated forward/reverse KL mixture | CC BY 4.0 (paper); code expected | ~120 | Fixes a documented failure mode of the reverse-KL-style SDPO loss when teacher entropy is high — directly addresses a known weakness of channel 2 |
21
+
22
+ **Honourable mention:** KTO — useful only if the framework wants to ingest
23
+ binary thumbs-up/thumbs-down trace signals without preference pairs.
24
+ **Not recommended:** GKD, DistiLLM, MiniLLM, Self-Rewarding LM (rationale at end).
25
+
26
+ ---
27
+
28
+ ## Audit method
29
+
30
+ For each candidate paper (the seven the user named, plus 2026 follow-ups
31
+ discovered via Exa search restricted to `category=research paper, startPublishedDate=2026-01-01`)
32
+ we verified:
33
+
34
+ 1. **Primary source exists.** arXiv abstract page reachable; HTML body parsed
35
+ to extract the actual loss formula (not summarised from secondary sources).
36
+ 2. **Code is real.** Official repo's README was fetched, `last push` date and
37
+ star count recorded. Forks of MiniLLM/DistiLLM that are no longer maintained
38
+ were marked as such.
39
+ 3. **License is permissive enough.** MIT, Apache-2.0, BSD, CC BY 4.0 are
40
+ acceptable for inclusion. GPL or research-only would be flagged.
41
+ 4. **Composability check.** Read the framework's existing
42
+ `composer_replication/__init__.py` and `research/05-trace-replay-distillation.md`,
43
+ then asked: *does this loss replace something we have, or stack on top?*
44
+
45
+ ---
46
+
47
+ ## Candidate 1 — SimPO (Simple Preference Optimization) ⭐ RECOMMENDED
48
+
49
+ ### Sources
50
+ - **arXiv:** https://arxiv.org/abs/2405.14734 (Meng, Xia, Chen — UVA + Princeton, NeurIPS 2024)
51
+ - **GitHub:** https://github.com/princeton-nlp/SimPO
52
+ - License: **MIT**
53
+ - 949 stars, 74 forks, last commit 2024-10-12 (mature, post-NeurIPS)
54
+ - Built on top of `huggingface/alignment-handbook`
55
+ - Maturity: **production-ready**. Released checkpoints for Mistral, Llama-3, Gemma-2 base/instruct. Reproducible training configs ship with the repo.
56
+
57
+ ### Loss core (reference-free preference)
58
+ SimPO replaces the DPO log-ratio (which requires keeping `π_ref` in memory)
59
+ with the **average log-probability** of the sequence under the policy, plus
60
+ a **target reward margin** γ:
61
+
62
+ ```
63
+ r(x, y) = (β / |y|) · log π_θ(y | x) ← length-normalised implicit reward
64
+ (no reference model)
65
+
66
+ L_SimPO(π_θ) = −E_{(x, y_w, y_l) ~ D} [
67
+ log σ( r(x, y_w) − r(x, y_l) − γ )
68
+ ]
69
+ ```
70
+
71
+ where `β` is a temperature (typically 2.0–10) and `γ` is the desired margin
72
+ between chosen and rejected (the repo recommends `γ/β ≈ 0.5` as a starting
73
+ point). Two consequences: (i) no `π_ref` forward pass per step → roughly half
74
+ the memory, and (ii) the implicit reward is exactly the quantity the model
75
+ generates from at decode time, removing a known DPO pathology where
76
+ decoding-time and training-time rewards diverge.
77
+
78
+ ### Why it composes with the existing stack
79
+ - The framework's **channel 3** is multi-teacher trace-replay DPO. SimPO is a
80
+ drop-in replacement for the DPO step inside that channel — same `(x, y_w, y_l)`
81
+ data contract, different loss head. So the trace-replay harvester does not
82
+ change at all.
83
+ - It does **not** touch channel 2 (SDPO/OPSD `generalized_jsd_loss`). The two
84
+ are complementary: JSD-distillation transfers token-level teacher knowledge,
85
+ SimPO sharpens preference structure between trace alternatives.
86
+ - It does **not** duplicate GRPO either. GRPO is online-policy RLVR;
87
+ SimPO is offline preference. Different data sources.
88
+ - The published Mistral-7B and Llama-3-8B SimPO results beat DPO by 4–6 points
89
+ on AlpacaEval-2 LC, which directly translates to "if we already have channel-3
90
+ pairs, SimPO is a free upgrade".
91
+
92
+ ### Implementation cost
93
+ - **~80 LOC** for the trainer hook; the loss itself is ~15 lines (log-probs,
94
+ length-normalise, margin, BCE).
95
+ - Dependencies: nothing new — `torch`, `transformers` already in repo.
96
+ - The reference implementation is a single file in `princeton-nlp/SimPO`
97
+ (`scripts/run_simpo.py` + `alignment/` trainer subclass) under MIT, so we can
98
+ vendor it exactly as we did with OPSD.
99
+
100
+ ---
101
+
102
+ ## Candidate 2 — TAID (Temporally Adaptive Interpolated Distillation) ⭐ RECOMMENDED
103
+
104
+ ### Sources
105
+ - **arXiv:** https://arxiv.org/abs/2501.16937 (Shing, Misaki, Bao, Yokoi, Akiba — Sakana AI, ICLR 2025)
106
+ - **GitHub:** https://github.com/SakanaAI/TAID
107
+ - License: **Apache-2.0**
108
+ - 121 stars, last push 2025-10-06 (actively maintained)
109
+ - Reference implementations of GKD, DistiLLM, Adaptive-KL, CTKD, DKD are also in `src/distil_losses/` for free
110
+ - Released artefacts: `TAID-LLM-1.5B`, `TAID-VLM-2B` on HuggingFace (so the loss is verified at non-trivial scale).
111
+ - Maturity: **published, single-author commits** but reproducibly trained two SoTA compact models with it.
112
+
113
+ ### Loss core (interpolated teacher target)
114
+ Standard distillation losses (forward KL, reverse KL, JSD, including the
115
+ `generalized_jsd_loss` we already have) target a **fixed** teacher distribution
116
+ `p_T`. TAID replaces this fixed target with a **time-dependent interpolated
117
+ target** `p_t` that starts close to the student and moves toward the teacher
118
+ as training progresses:
119
+
120
+ ```
121
+ p_t(y | x) = (1 − t) · q_θ_stop(y | x) + t · p_T(y | x) (1)
122
+
123
+ J_TAID(θ; t) = D_KL( p_t ‖ q_θ ) (2)
124
+ ```
125
+
126
+ `q_θ_stop` is the student's own current distribution with stop-gradient. The
127
+ interpolation coefficient `t ∈ [t_start, 1]` is updated each step by an
128
+ **adaptive momentum schedule** that grows `t` faster when training loss is
129
+ falling and slower when it stalls — this is the "temporally adaptive" part.
130
+ The Sakana paper proves (Theorem 4.1) that for the regression analogue this
131
+ schedule provably prevents the mode-collapse failure mode of pure
132
+ self-distillation.
133
+
134
+ Critically, `D_KL(p_t ‖ q_θ)` is just any divergence on shifted target — you
135
+ can equally well plug in JSD, reverse KL, or **the generalized_jsd_loss the
136
+ framework already exports**. TAID is therefore a *wrapper around an existing
137
+ divergence*, not a competing divergence.
138
+
139
+ ### Why it composes with the existing stack
140
+ - It **wraps** `composer_replication.opsd.generalized_jsd_loss` rather than
141
+ replacing it. The change is "compute the JSD against `p_t` instead of
142
+ `p_T`" — a few lines around the existing call site.
143
+ - Addresses a documented weakness of OPSD-style self-distillation: when the
144
+ teacher's privileged-context distribution is far from the student's
145
+ capacity, the JSD signal can be noisy or push the student into mode
146
+ averaging. TAID's annealed target gives the student a curriculum.
147
+ - Empirical evidence the Sakana paper directly compares with: TAID + JSD
148
+ beats GKD + JSD beats DistiLLM + skew-KL on Phi-3 → TinyLlama distillation,
149
+ with **0.7 h / epoch** vs **9.8 h / epoch** for GKD on identical hardware.
150
+ The speed comes from not needing student-generated outputs (SGOs) at every
151
+ step the way GKD does.
152
+ - Composes additively with channel 1 (GRPO) and channel 3 (trace-replay DPO)
153
+ because TAID lives strictly inside channel 2.
154
+
155
+ ### Implementation cost
156
+ - **~150 LOC**. The change is:
157
+ 1. A `TAIDState` object that holds `t`, the EMA of training loss, and the
158
+ momentum coefficient β (default 0.99).
159
+ 2. A function `taid_target(student_logits, teacher_logits, t)` that returns
160
+ `(1−t)·softmax(student_logits.detach()) + t·softmax(teacher_logits)`.
161
+ 3. A scheduler hook that updates `t` after each backward pass per
162
+ Algorithm 1 of the paper.
163
+ - Dependencies: nothing new.
164
+ - Reference implementation in `SakanaAI/TAID/src/distil_losses/taid.py` is
165
+ Apache-2.0 — vendor-friendly, same pattern as our OPSD lift.
166
+
167
+ ---
168
+
169
+ ## Candidate 3 — Entropy-Aware On-Policy Distillation (Entropy-Aware OPD) ⭐ RECOMMENDED
170
+
171
+ ### Sources
172
+ - **OpenReview (ICLR 2026 Spotlight):** https://openreview.net/forum?id=WSRQ37tzk1
173
+ - **IBM Research page:** https://research.ibm.com/publications/entropy-aware-on-policy-distillation-of-language-models
174
+ - Authors: Woogyeol Jin, Taywon Min, Yongjin Yang, Swanand Kadhe, Yi Zhou, Dennis Wei, Nathalie Baracaldo, Kimin Lee (KAIST + IBM Research)
175
+ - Status: **ICLR 2026 Spotlight**, submission #113. License on the OpenReview record is **CC BY 4.0**.
176
+ - Code: not yet released on GitHub at the time of audit (paper accepted 2026-03-03). IBM authors typically release within the conference window. **Maturity flag: paper-ready, code-pending.** This is the only candidate where we'd need to re-implement from the paper.
177
+
178
+ ### Loss core (entropy-gated forward/reverse KL mixture)
179
+ The paper diagnoses a failure mode in the reverse-KL-on-policy distillation
180
+ recipe used by MiniLLM, OPSD, and (implicitly) by our SDPO channel: when the
181
+ **teacher distribution has high entropy at a given token**, reverse KL's
182
+ mode-seeking gradient becomes noisy and collapses the student's diversity.
183
+ Their fix: at each token `t`, gate between forward and reverse KL based on
184
+ the teacher's entropy:
185
+
186
+ ```
187
+ H_t = − Σ_v p_T(v | x, y_<t) · log p_T(v | x, y_<t) (teacher entropy)
188
+
189
+ α_t = sigmoid( (H_t − τ) / s ) ∈ (0, 1)
190
+
191
+ L_EA(θ) = E_{y ~ q_θ} Σ_t [
192
+ (1 − α_t) · D_KL( q_θ(· | x, y_<t) ‖ p_T(· | x, y_<t) ) ← reverse KL
193
+ + α_t · D_KL( p_T(· | x, y_<t) ‖ q_θ(· | x, y_<t) ) ← forward KL
194
+ ]
195
+ ```
196
+
197
+ `τ` is an entropy threshold (default ≈ 1.0 nat in their experiments) and `s`
198
+ is a temperature controlling how sharp the gate is. When the teacher is
199
+ confident (`H_t` small → `α_t ≈ 0`) the loss is pure reverse KL, identical to
200
+ MiniLLM/OPSD behaviour. When the teacher is uncertain (`H_t` large → `α_t ≈ 1`)
201
+ the loss switches to forward KL, which is mode-covering and preserves
202
+ student diversity.
203
+
204
+ Reported gains over baseline reverse-KL OPD on Qwen3-0.6B/1.7B/4B: Pass@8 on
205
+ six math benchmarks improves by +1.37 / +2.39 / +5.05 respectively. The
206
+ larger gains at larger student size suggest the failure mode reverse KL
207
+ exhibits gets *worse* with capacity, not better.
208
+
209
+ ### Why it composes with the existing stack
210
+ - It is **strictly token-wise**: same trajectory, same teacher logits, same
211
+ rollout pipeline as the existing channel 2. The only change is the loss
212
+ reduction — instead of computing `generalized_jsd_loss` with a single fixed
213
+ β, you compute a per-token mixture of forward and reverse KL with weight
214
+ given by teacher entropy.
215
+ - This is genuinely orthogonal to OPSD/SDPO. OPSD's contribution is
216
+ *privileged-context teacher distribution under student rollouts*. EA-OPD's
217
+ contribution is *which divergence to use at each token of that distribution*.
218
+ Both can be true simultaneously.
219
+ - Directly addresses a failure mode the framework's roadmap will hit:
220
+ multi-teacher trace replay (channel 3) produces high-entropy aggregated
221
+ teacher distributions at exactly the steps where teachers disagree. Those
222
+ are the steps where reverse KL behaves worst. EA-OPD's entropy gate would
223
+ automatically soften the loss on those exact tokens.
224
+ - Composes with TAID (Candidate 2) too — they operate on different axes:
225
+ TAID anneals the *target distribution*, EA-OPD chooses the *divergence
226
+ direction*. Stacking is straightforward and proposed as ADR-007 follow-up.
227
+
228
+ ### Implementation cost
229
+ - **~120 LOC** estimate (no reference code to vendor yet).
230
+ - Dependencies: nothing new. Token-level entropy is `−(p * log p).sum(-1)`,
231
+ forward KL is the existing teacher-on-student term, reverse KL is the
232
+ student-on-teacher term we already compute for the JSD in OPSD. The work is
233
+ re-shaping the existing per-token loss to expose both directions.
234
+ - **Risk note:** code not yet public. We should hold this candidate behind a
235
+ feature flag until the IBM/KAIST team releases reference code (expected by
236
+ ICLR 2026 in May). If the implementation ships sooner we should vendor and
237
+ match line-for-line; if not, we re-derive from the paper formula and add a
238
+ unit test that reproduces their toy entropy-vs-divergence plot.
239
+
240
+ ---
241
+
242
+ ## Honourable mention — KTO (Kahneman-Tversky Optimization)
243
+
244
+ - **arXiv:** https://arxiv.org/abs/2402.01306
245
+ - **Code:** integrated into HuggingFace `trl` library since v0.8 (Apache-2.0).
246
+ - License/maturity: **production**. KTO is a standard `trl` trainer alongside DPO.
247
+
248
+ ### Loss core
249
+ KTO replaces preference pairs with **per-output binary desirability** signals.
250
+ For a desirable output `y_+` and undesirable output `y_−`:
251
+
252
+ ```
253
+ r_θ(x, y) = β · log( π_θ(y|x) / π_ref(y|x) )
254
+
255
+ z_0 = E_{x', y' ~ π_θ}[ KL( π_θ(·|x') ‖ π_ref(·|x') ) ] (reference point)
256
+
257
+ L_KTO = E_{x, y_+} [λ_D · (1 − σ(r_θ(x, y_+) − z_0))] (desirable)
258
+ + E_{x, y_−} [λ_U · (1 − σ(z_0 − r_θ(x, y_−)))] (undesirable)
259
+ ```
260
+
261
+ with default `λ_D = λ_U = 1`. The derivation is via prospect theory: this is
262
+ a Kahneman-Tversky utility function applied to the implicit reward. KTO
263
+ matches DPO at 1B–30B even though it sees only `2n` binary signals where
264
+ DPO sees `n` pairs.
265
+
266
+ ### Why we down-rank it relative to the top-3
267
+ KTO is the right answer **only if** the framework wants to ingest single-side
268
+ trace signals (e.g., "this trace step succeeded" / "this step crashed the
269
+ agent") without constructing pairs. The current
270
+ `research/05-trace-replay-distillation.md` design **does** construct pairs
271
+ from multi-teacher replay (that is the whole point of the multi-teacher
272
+ variance signal), so the marginal value of KTO is small *for channel 3 as
273
+ specified*. If the trace-replay design pivots toward absolute scores per
274
+ step rather than relative pairs, KTO becomes the right loss and is already
275
+ free from `trl`. Add to the backlog as conditional.
276
+
277
+ ---
278
+
279
+ ## Audited but NOT recommended
280
+
281
+ ### GKD — Generalized Knowledge Distillation (Agarwal et al., 2023)
282
+ - **arXiv:** https://arxiv.org/abs/2306.13649 (Google DeepMind)
283
+ - **Loss core:** student samples its own outputs, teacher provides token
284
+ probabilities, divergence is generalized JSD with parameter β:
285
+ ```
286
+ D_JSD(β)(P‖Q) = β·KL(P ‖ βP+(1−β)Q) + (1−β)·KL(Q ��� βP+(1−β)Q)
287
+ ```
288
+ - **Why excluded:** **this is exactly the formula we already have** as
289
+ `composer_replication.opsd.generalized_jsd_loss` (lifted from
290
+ `siyan-zhao/OPSD`). GKD's contribution beyond the loss formula is the
291
+ on-policy student sampling protocol — which OPSD also does. No incremental
292
+ value to add.
293
+
294
+ ### DistiLLM (Ko et al., ICML 2024)
295
+ - **arXiv:** https://arxiv.org/abs/2402.03898
296
+ - **GitHub:** https://github.com/jongwooko/distillm — MIT, last push 2025-03
297
+ - **Loss core:** *Skew KL divergence* `KL(p ‖ λp + (1−λ)q)` plus an *adaptive
298
+ off-policy* student-generated-output (SGO) scheduler.
299
+ - **Why excluded:** the skew-KL is a special case of generalized JSD (set the
300
+ mixture coefficient appropriately) — same family the framework already
301
+ has. The interesting contribution, the SGO scheduler, is a process
302
+ optimisation, not a loss. The TAID paper's own ablation (Table 6) shows
303
+ TAID > Skew KL across student sizes, so TAID dominates this candidate.
304
+
305
+ ### MiniLLM (Gu et al., ICLR 2024)
306
+ - **arXiv:** https://arxiv.org/abs/2306.08543
307
+ - **GitHub:** https://github.com/microsoft/LMOps/tree/main/minillm — MIT, repo
308
+ active (last push 2026-04)
309
+ - **Loss core:** reverse KL minimised by policy-gradient on student rollouts,
310
+ with three optimisation tricks: single-step decomposition (variance
311
+ reduction), teacher-mixed sampling (anti-reward-hacking), length
312
+ normalisation.
313
+ - **Why excluded:** reverse-KL on-policy distillation **is the same recipe
314
+ family as SDPO/OPSD** the framework already implements. Adding MiniLLM
315
+ would be a parallel implementation of the same idea, not an addition.
316
+ Entropy-Aware OPD (Candidate 3) is a *strict improvement* over MiniLLM's
317
+ pure reverse-KL on exactly the failure mode MiniLLM identifies (mode
318
+ collapse in high-entropy regions).
319
+
320
+ ### Self-Rewarding Language Models (Yuan et al., 2024)
321
+ - **arXiv:** https://arxiv.org/abs/2401.10020 (Meta + NYU)
322
+ - **Why excluded:** SRLM is a *training procedure* (iterative DPO with the
323
+ model judging its own outputs), not a loss. The actual loss is plain DPO,
324
+ which the framework already supports. The procedural contribution belongs
325
+ in a future ADR on data generation, not in the distillation channel.
326
+
327
+ ### TAID's relationship to "TAID arXiv 2501.16937 if it exists"
328
+ The user asked us to verify existence. **It exists.** Submitted 2025-01-28,
329
+ ICLR 2025, code at https://github.com/SakanaAI/TAID with two released
330
+ checkpoints (`TAID-LLM-1.5B`, `TAID-VLM-2B`). Confirmed primary source.
331
+
332
+ ---
333
+
334
+ ## 2026 papers found
335
+
336
+ The targeted Exa search (`category=research paper`, `startPublishedDate=2026-01-01`)
337
+ surfaced four 2026 distillation papers worth listing for completeness:
338
+
339
+ 1. **Entropy-Aware On-Policy Distillation** — ICLR 2026 Spotlight. ⭐ Promoted to top-3 above.
340
+ 2. **KL for a KL: On-Policy Distillation with Control Variate Baseline** (arXiv 2605.07865, Oh et al., 2026-05). Variance-reduction trick for on-policy KL distillation. Useful future read but not a new loss — it's a baseline subtraction added to MiniLLM-style policy gradient.
341
+ 3. **Rethinking On-Policy Distillation: Phenomenology, Mechanism, and Recipe** (https://github.com/thunlp/OPD, Tsinghua NLP, last push 2026-04). Empirical study, not a new loss formulation.
342
+ 4. **Hybrid Policy Distillation for LLMs** (ICML 2026 poster, Zhu et al.). Combines off-policy and on-policy distillation; positioned as a recipe rather than a new loss; abstract suggests strong overlap with TAID's annealing argument.
343
+ 5. **Don't Ignore the Tail: Decoupling top-K Probabilities for Efficient Language Model Distillation** (ICML 2026 poster, Dasgupta et al.). Targets the long-tail of teacher distributions. Interesting but currently only an abstract; deferred until the camera-ready PDF is available.
344
+
345
+ None of these except Entropy-Aware OPD are mature enough (released code +
346
+ license + reproducible scale) to recommend adding right now.
347
+
348
+ ---
349
+
350
+ ## Recommended follow-up wiring
351
+
352
+ For ADR-007 the proposed addition is a `composer_replication.distillation`
353
+ sub-package with three pluggable hooks:
354
+
355
+ ```
356
+ composer_replication/
357
+ distillation/
358
+ __init__.py
359
+ targets.py # taid_target(...), fixed_target(...) ← Candidate 2
360
+ losses.py # reuses opsd.generalized_jsd_loss
361
+ # adds entropy_aware_kl_loss(...) ← Candidate 3
362
+ preference/
363
+ simpo.py # simpo_loss(...) ← Candidate 1
364
+ dpo.py # existing trace-replay path
365
+ ```
366
+
367
+ The composition rule for the total loss becomes:
368
+
369
+ ```
370
+ L_total = λ_grpo · L_GRPO (channel 1, unchanged)
371
+ + λ_distill · L_distill (channel 2, see below)
372
+ + λ_pref · L_pref (channel 3, choose DPO or SimPO)
373
+
374
+ L_distill = entropy_aware_kl_loss(
375
+ target = taid_target(student, teacher, t),
376
+ student = student,
377
+ teacher_entropy_gate = α_t
378
+ )
379
+ ```
380
+
381
+ This keeps the existing `generalized_jsd_loss` reachable as a fallback
382
+ (set `α_t ≡ 0` and `t ≡ 1` and you recover SDPO/OPSD exactly).
383
+
384
+ ---
385
+
386
+ ## Sources index
387
+
388
+ | Paper | arXiv | GitHub | License | Last push | Maturity |
389
+ |-------|-------|--------|---------|-----------|----------|
390
+ | SimPO | https://arxiv.org/abs/2405.14734 | https://github.com/princeton-nlp/SimPO | MIT | 2024-10-12 | Production |
391
+ | TAID | https://arxiv.org/abs/2501.16937 | https://github.com/SakanaAI/TAID | Apache-2.0 | 2025-10-06 | Production |
392
+ | Entropy-Aware OPD | n/a (OpenReview WSRQ37tzk1) | code-pending | CC BY 4.0 (paper) | n/a | Paper-only |
393
+ | KTO | https://arxiv.org/abs/2402.01306 | huggingface/trl (built-in) | Apache-2.0 | continuous | Production |
394
+ | GKD | https://arxiv.org/abs/2306.13649 | (no official repo from authors; reproduced inside SakanaAI/TAID and jongwooko/distillm) | n/a | n/a | Reference only |
395
+ | DistiLLM | https://arxiv.org/abs/2402.03898 | https://github.com/jongwooko/distillm | (no LICENSE file at audit time) | 2025-03-13 | Research |
396
+ | MiniLLM | https://arxiv.org/abs/2306.08543 | https://github.com/microsoft/LMOps/tree/main/minillm | MIT | 2026-04-08 | Production |
397
+ | Self-Rewarding LM | https://arxiv.org/abs/2401.10020 | (no canonical repo; integrated into many forks) | n/a | n/a | Procedure, not a loss |
398
+
399
+ ---
400
+
401
+ ## Notes for ADR-007 author
402
+
403
+ 1. **SimPO and TAID can land independently and without coordination.** They
404
+ touch different files and do not compete.
405
+ 2. **Entropy-Aware OPD should land last.** Wait for the IBM/KAIST authors'
406
+ code release; if it's not out by the time we want to ship the change, the
407
+ formula is simple enough to re-derive but we should pin a unit test that
408
+ reproduces the paper's Figure 3 entropy-vs-divergence behaviour.
409
+ 3. **Do not also pull in GKD/DistiLLM/MiniLLM.** Their loss contributions are
410
+ strict subsets of what (TAID + Entropy-Aware OPD + existing
411
+ `generalized_jsd_loss`) covers.
412
+ 4. **KTO should be added as a backlog item** with a "trigger" condition:
413
+ when the trace-replay reward design moves from preference pairs to per-step
414
+ binary signals, switch on the `trl.KTOTrainer` path.
415
+
416
+ ---
417
+
418
+ *Absolute path of this report:* `/mnt/e/CS/HF/composer-replication-framework/docs/research/SELF_DISTILLATION_LANDSCAPE.md`
docs/research/WAVE_13_FINAL_REVIEW.md ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Wave 13 Adversarial Cross-Model Review
2
+
3
+ **Reviewer:** Claude Opus 4.7 (sub-agent via delegate_task)
4
+ **Date:** 2026-05-26
5
+ **Scope:** Wave 13 additions only (35 new tests, 4 ADRs, 6 new modules)
6
+ **Method:** Read-and-grep audit + targeted test runs (CPU)
7
+
8
+ ## Top-line verdict
9
+
10
+ **CONDITIONAL PASS with two BLOCKERs.** Wave 13 substantially advances
11
+ the brief expansion (serverless DiLoCo abstraction, replaysim
12
+ normalization, three distillation losses, PRIME-RL recipe, Monarch
13
+ tie-in). The **distillation losses are the strongest deliverable** —
14
+ real, well-tested, mathematically faithful to the cited papers. The
15
+ serverless-DiLoCo local executor + ObjectStoreAllReduce barrier are
16
+ also genuine and exercised by 3 real multi-process tests.
17
+
18
+ **However, two material claims are not test-validated, and one new
19
+ module silently produces a degenerate loss in its primary code path.**
20
+ ADR claims that say "X is added to compose_loss" describe code that
21
+ wasn't actually written. The MockManager → DiLoCo "drop-in" is
22
+ unverified end-to-end.
23
+
24
+ Wave 11's reviewer found 2 genuine BLOCKERs. This review finds **2
25
+ BLOCKERs + 4 SUGGESTIONs + 2 NITs**.
26
+
27
+ ---
28
+
29
+ ## Finding 1 — BLOCKER: PRIME-RL `composer_loss.loss_fn` SDPO term is mathematically degenerate (always 0)
30
+
31
+ **Severity:** BLOCKER
32
+ **Evidence:** `composer_replication/recipes/prime_rl/composer_loss.py:79-86`
33
+
34
+ The PRIME-RL composer-loss adapter applies `unsqueeze(-1)` to `(B, T)`
35
+ log-prob tensors before passing them to `generalized_jsd_loss`, which
36
+ calls `F.log_softmax(..., dim=-1)`. Softmax of a single-element vector
37
+ is exactly 1.0; its log is 0. Therefore both `student_log_probs` and
38
+ `teacher_log_probs` are identically zero, the JSD between them is 0,
39
+ and the SDPO contribution **is always 0 regardless of `alpha_sdpo` or
40
+ the actual log-prob values.**
41
+
42
+ ```python
43
+ >>> import torch.nn.functional as F
44
+ >>> F.log_softmax(torch.randn(2, 3, 1), dim=-1)
45
+ tensor([[[0.],[0.],[0.]],[[0.],[0.],[0.]]])
46
+ ```
47
+
48
+ The docstring calls this "a deliberate approximation," but it is not
49
+ an approximation — it's a mathematically degenerate operation that
50
+ silently disables channel 2.
51
+
52
+ **Fix direction:**
53
+ - Gate the SDPO branch behind `len(trainer_lp.shape) >= 3`, raising
54
+ `NotImplementedError` until PRIME-RL surfaces full logits.
55
+ - Update `prime_rl_recipe.md` and ADR-006 to stop claiming PRIME-RL
56
+ has working SDPO; mark it deferred.
57
+
58
+ ---
59
+
60
+ ## Finding 2 — BLOCKER: ADR-007 declares `compose_loss` kwargs that were never added
61
+
62
+ **Severity:** BLOCKER
63
+ **Evidence:**
64
+ - `docs/adrs/ADR-007-self-distillation-losses.md:103-108` claims:
65
+ > `composer_replication.compose_loss` gets new optional kwargs:
66
+ > - `dpo_variant: Literal["dpo", "simpo"] = "dpo"` — switches channel 3
67
+ > - `sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none"` — wraps channel 2
68
+ > - `taid_schedule_step: int | None = None`
69
+ > - `taid_total_steps: int | None = None`
70
+ - `composer_replication/loss.py:54-65` actual signature has **none**
71
+ of these. `grep -n "dpo_variant\|sdpo_wrapper\|taid"
72
+ composer_replication/loss.py` returns empty.
73
+
74
+ The new losses live in `composer_replication.distillation` as
75
+ standalone functions but **are not wired into the framework's actual
76
+ loss composition.** A user reading ADR-007 + the README would believe
77
+ `compose_loss(model, inputs, dpo_variant="simpo", sdpo_wrapper="taid", ...)`
78
+ works; it would raise `TypeError`. The 17 distillation tests verify
79
+ the standalone losses but never exercise integration.
80
+
81
+ **Fix direction:**
82
+ - Either (a) add the kwargs to `compose_loss` and write at least one
83
+ integration test combining e.g. SDPO+TAID (~30 LOC change), or
84
+ - (b) downgrade ADR-007 status to "Standalone losses landed;
85
+ integration deferred to Wave 14."
86
+
87
+ ---
88
+
89
+ ## Finding 3 — SUGGESTION: `default.yaml` replaysim recipe uses string ops on list-of-dict fields
90
+
91
+ **Severity:** SUGGESTION (would be BLOCKER if a test exercised the real path)
92
+ **Evidence:**
93
+ - `composer_replication/recipes/replaysim/default.yaml` configures
94
+ `text_length_filter`, `words_num_filter`, `special_characters_filter`,
95
+ `document_deduplicator` with `text_keys: ["chosen", "rejected"]`.
96
+ - In the record produced by `_dpo_pair_to_dj_record`, `chosen` and
97
+ `rejected` are **lists of dicts**
98
+ (`[{"role": "assistant", "content": "..."}]`) — not strings.
99
+ - data-juicer's `text_length_filter` expects string-typed fields;
100
+ running it on a list will either crash or no-op silently.
101
+
102
+ The reason no test catches this: tests only validate the real path *if
103
+ data-juicer is installed*, and even then only check `__init__` succeeds.
104
+ There is no test that calls `normalize()` against a real data-juicer
105
+ executor with the default recipe.
106
+
107
+ **Fix direction:**
108
+ - Reshape `_dpo_pair_to_dj_record` to extract `content` strings
109
+ alongside the messages-format list.
110
+ - Add one test (skip-marked unless `data_juicer` is importable) that
111
+ runs the real op-graph on 3 hand-crafted records.
112
+
113
+ ---
114
+
115
+ ## Finding 4 — SUGGESTION: MockManager → torchft.DiLoCo "drop-in" claim is unverified end-to-end
116
+
117
+ **Severity:** SUGGESTION
118
+ **Evidence:**
119
+ - `composer_replication/diloco/serverless/allreduce.py:188-191` claims
120
+ MockManager "drops into" `make_diloco_outer_loop`.
121
+ - The only test covering MockManager (`test_mock_manager_shape_compat`)
122
+ is a `hasattr` smoke that calls `.allreduce` on a `world_size=1`
123
+ store (passthrough).
124
+ - torchft.Manager has additional surface area
125
+ (`current_step`, `is_leader`, `_pg`, `report_error`,
126
+ internal step accounting) that DiLoCo's `_apply_pseudogradient`
127
+ may consult depending on version.
128
+
129
+ **Fix direction:**
130
+ - Add a single integration test that constructs
131
+ `make_diloco_outer_loop(manager=MockManager(store), ...)` against a
132
+ tiny `nn.Linear` and runs one outer round — even single-process.
133
+ - Audit `torchft/local_sgd.py` for the `Manager`-rooted call sites and
134
+ add stubs for any methods DiLoCo actually consults beyond `allreduce`.
135
+
136
+ ---
137
+
138
+ ## Finding 5 — SUGGESTION: README claim "9 multi-process tests" is mildly inflated
139
+
140
+ **Severity:** SUGGESTION (NIT bordering)
141
+ **Evidence:**
142
+ - README.md and V1_V8_COVERAGE both state: *"9 multi-process tests
143
+ pinning the allreduce barrier."*
144
+ - Actual breakdown:
145
+ - 4 single-process unit tests + `test_mock_manager_shape_compat` (5)
146
+ - 4 multi-process tests spawning subprocesses (parametrized [2,3] of
147
+ `_runs_allreduce_across_replicas`, `_handles_multiple_rounds`,
148
+ `_reports_failed_replicas`)
149
+ - Of the 4 multi-process tests, only **3 actually exercise the
150
+ allreduce barrier**; `_reports_failed_replicas` deliberately raises
151
+ before any allreduce call.
152
+
153
+ **Wave 13 clearly does NOT fake-pass via world_size=1** — the multi-
154
+ process barrier is real. But the count is rounded up.
155
+
156
+ **Fix direction:** Replace "9 multi-process tests" with "9 tests
157
+ covering the serverless DiLoCo layer, of which 4 spawn real
158
+ subprocesses and 3 exercise the allreduce barrier across replicas."
159
+
160
+ ---
161
+
162
+ ## Finding 6 — SUGGESTION: PRIME-RL channel 1 is REINFORCE not GRPO; ignores `inference_logprobs`
163
+
164
+ **Severity:** SUGGESTION
165
+ **Evidence:** `composer_replication/recipes/prime_rl/composer_loss.py:62-68`
166
+ computes:
167
+ ```python
168
+ grpo_loss = -(advantages * trainer_lp * mask).sum() / mask.sum().clamp_min(epsilon)
169
+ ```
170
+
171
+ This is plain REINFORCE with advantage. PRIME-RL's `LossInputs`
172
+ exposes `inference_logprobs` precisely because GRPO-with-replay-buffer
173
+ requires the importance-sampling ratio
174
+ `exp(trainer_lp - inference_lp)` (PPO-style clipped objective).
175
+
176
+ The file says "SKELETON" so this isn't a hidden bug per se, but the
177
+ loss is **labeled GRPO and is not GRPO**.
178
+
179
+ **Fix direction:** Either implement the ratio + clipping (~20 LOC) or
180
+ rename channel-1 comment to "REINFORCE-with-advantage stub" with a TODO.
181
+
182
+ ---
183
+
184
+ ## Finding 7 — NIT: ModalExecutor / HFJobsExecutor are skeleton-only with `NotImplementedError` in `__init__`
185
+
186
+ **Severity:** NIT (this is documented, but README phrasing is slightly soft)
187
+ **Evidence:** Honestly documented as skeletons in the code, ADR-005,
188
+ and README. NIT: a user trying `ModalExecutor()` gets a runtime error
189
+ rather than an import-time clue.
190
+
191
+ **Fix direction:** Low priority. Update README phrase to "skeleton-only
192
+ — raises NotImplementedError until v0.x." Or use a `__getattr__` on
193
+ the package that raises a clearer message.
194
+
195
+ ---
196
+
197
+ ## Finding 8 — NIT: SimPO test uses positive log-probs (impossible values)
198
+
199
+ **Severity:** NIT
200
+ **Evidence:** `test_distillation_losses.py:27-46` calls `simpo_loss`
201
+ with `chosen=tensor([0.5, 0.4, 0.3])`. Log-probabilities are bounded
202
+ above by 0; positive values aren't possible from any softmax. The tests
203
+ still verify the formula correctly, but the test inputs aren't legal.
204
+
205
+ **Fix direction:** Use negative values — purely cosmetic.
206
+
207
+ ---
208
+
209
+ ## Cross-cutting risk check
210
+
211
+ 73 tests passed in 29.29s on the CPU-fast subset. Spike 008 5/5 still
212
+ pass. The new `composer_replication.diloco.serverless` package is
213
+ purely additive; the existing `make_diloco_outer_loop` is untouched.
214
+ **No cross-wave regressions detected on CPU.** GPU tests + slow CPU
215
+ e2e tests not re-run; regression risk low since Wave 13 doesn't touch
216
+ their dependencies.
217
+
218
+ ---
219
+
220
+ ## Summary scorecard
221
+
222
+ | Item | Verdict |
223
+ |---|---|
224
+ | Distillation module (SimPO/TAID/Entropy-Aware OPD) standalone | ✅ Real, well-tested, paper-faithful |
225
+ | Distillation integrated into `compose_loss` | ❌ **Not implemented** despite ADR-007 (Finding 2) |
226
+ | ObjectStoreAllReduce + LocalProcessExecutor | ✅ Real multi-process barrier validated |
227
+ | MockManager → DiLoCo drop-in | 🟡 Shape-checked only; integration unverified (Finding 4) |
228
+ | Modal/HFJobs adapters | 🟡 Honestly documented as skeletons (Finding 7) |
229
+ | Replaysim DJNormalizer passthrough | ✅ Works |
230
+ | Replaysim default.yaml against real data-juicer | ❌ **Recipe field types don't match record shape** (Finding 3) |
231
+ | PRIME-RL composer_loss.loss_fn | ❌ **SDPO term silently 0** (Finding 1); channel 1 is REINFORCE not GRPO (Finding 6) |
232
+ | Monarch actors | ✅ Honest skeleton; raises NotImplementedError |
233
+ | Altered-minds tie-in doc | ✅ Design-only, scoped honestly |
234
+ | 35 new tests | All pass; 3 of 4 multi-process tests are genuine (Finding 5) |
235
+
236
+ **Recommendation:** Address Findings 1 and 2 before publishing the
237
+ Wave 13 expansion as "closed." Findings 3 and 4 should be addressed
238
+ before any user attempts the real data-juicer or real torchft DiLoCo
239
+ path. Findings 5–8 are cleanup.
pyproject.toml CHANGED
@@ -16,16 +16,23 @@ keywords = [
16
  "rlvr",
17
  "grpo",
18
  "sdpo",
 
 
19
  "dpo",
20
  "diloco",
 
21
  "agentic",
22
  "coding-agents",
23
  "composer-2-5",
24
  "cursor",
25
  "trl",
26
  "verl",
 
27
  "openenv",
28
  "torchft",
 
 
 
29
  ]
30
  classifiers = [
31
  "Development Status :: 3 - Alpha",
@@ -47,17 +54,35 @@ dependencies = [
47
  replay = [
48
  "httpx>=0.27",
49
  ]
50
- # DiLoCo outer-loop optimizer
51
  diloco = [
52
  "torchft-nightly",
53
  ]
54
- # Production training (TRL GRPOTrainer subclass)
 
 
 
 
 
 
 
 
 
 
55
  train = [
56
  "trl>=0.12",
57
  "peft>=0.13",
58
  "accelerate>=1.0",
59
  "datasets>=3.0",
60
  ]
 
 
 
 
 
 
 
 
61
  # Everything for development
62
  dev = [
63
  "pytest>=8.0",
 
16
  "rlvr",
17
  "grpo",
18
  "sdpo",
19
+ "simpo",
20
+ "taid",
21
  "dpo",
22
  "diloco",
23
+ "decoupled-diloco",
24
  "agentic",
25
  "coding-agents",
26
  "composer-2-5",
27
  "cursor",
28
  "trl",
29
  "verl",
30
+ "prime-rl",
31
  "openenv",
32
  "torchft",
33
+ "monarch",
34
+ "modal",
35
+ "huggingface-jobs",
36
  ]
37
  classifiers = [
38
  "Development Status :: 3 - Alpha",
 
54
  replay = [
55
  "httpx>=0.27",
56
  ]
57
+ # DiLoCo outer-loop optimizer (single-process)
58
  diloco = [
59
  "torchft-nightly",
60
  ]
61
+ # Decoupled DiLoCo over serverless executors (per ADR-005)
62
+ serverless = [
63
+ "fsspec>=2024.6",
64
+ "huggingface_hub>=0.27", # for hf:// fsspec backend + HF Jobs
65
+ ]
66
+ # Replaysim dataset normalization (per ADR-004)
67
+ replaysim = [
68
+ "data-juicer>=1.0",
69
+ "composer-replication[replay]", # replaysim builds on the replay channel
70
+ ]
71
+ # Production training (TRL GRPOTrainer subclass — Recipe A)
72
  train = [
73
  "trl>=0.12",
74
  "peft>=0.13",
75
  "accelerate>=1.0",
76
  "datasets>=3.0",
77
  ]
78
+ # PRIME-RL recipe (Recipe C — per ADR-006)
79
+ prime-rl = [
80
+ "prime-rl>=0.5",
81
+ ]
82
+ # Monarch actor mesh (per ADR-006)
83
+ monarch = [
84
+ "monarch>=0.4.1",
85
+ ]
86
  # Everything for development
87
  dev = [
88
  "pytest>=8.0",