Wave 13: serverless DiLoCo + replaysim normalization + 3 distillation losses + PRIME-RL + Monarch
Expanded the brief mid-deep-work-loop to address the user's request for
serverless training-system support, replaysim dataset normalization,
deeper self-distillation paper coverage, and Meta's PyTorch agentic
stack tie-ins.
NEW MODULES (35 tests passing):
- composer_replication.distillation: SimPO (arXiv:2405.14734), TAID
(arXiv:2501.16937), Entropy-Aware OPD (ICLR 2026 spotlight). 17
unit tests covering scalar/differentiable/scheduler-monotonicity
and boundary-condition correctness against paper formulas.
- composer_replication.diloco.serverless: ServerlessExecutor Protocol +
ObjectStoreAllReduce (fsspec-backed; works with file:// + s3:// +
hf:// + gs:// + az://) + LocalProcessExecutor (working) +
ModalExecutor / HFJobsExecutor (skeletons, raise NotImplementedError).
9 tests including 3 multi-process tests pinning the allreduce barrier
with mean-of-{0,1,2}=1 + mean-of-{0,100,200}=100 across two consecutive
rounds.
- composer_replication.replaysim: data-juicer adapter (per ADR-004
reconnaissance verdict; chosen over datatrove for native multi-turn +
DPO-pair op support). DJNormalizer with skip_dj passthrough +
default.yaml recipe. 9 unit tests.
- composer_replication.recipes.prime_rl: composer_loss adapter +
prime_rl_config.yaml example + recipe document. PRIME-RL is the
cleanest extension surface among RL frameworks (first-class
CustomLossConfig with LossInputs struct exposing exactly the tensors
needed for a 3-channel loss).
- composer_replication.recipes.monarch: actor layout document +
skeleton actor classes for Meta's actively-shipping (BSD-3 v0.4.1)
agentic-stack component. TorchForge is paused upstream and explicitly
dropped from the integration plan.
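
As a rough illustration of the first of the distillation losses listed above, the SimPO objective from arXiv:2405.14734 can be written for a single preference pair in plain Python. This is a sketch only, not the code in `composer_replication/distillation/simpo.py`; the function names and the default `beta` and `gamma` values are assumptions chosen for illustration.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def simpo_loss(chosen_logps, rejected_logps, beta=2.0, gamma=1.0):
    """SimPO loss for one preference pair.

    chosen_logps / rejected_logps are per-token log-probabilities of the
    preferred and dispreferred responses under the policy being trained.
    No reference model is involved.
    """
    # Length-normalized implicit rewards: beta times the mean token log-prob.
    reward_chosen = beta * sum(chosen_logps) / len(chosen_logps)
    reward_rejected = beta * sum(rejected_logps) / len(rejected_logps)
    # Target reward margin gamma is subtracted before the sigmoid.
    margin = reward_chosen - reward_rejected - gamma
    # -log(sigmoid(margin)) equals softplus(-margin).
    return softplus(-margin)
```

With these defaults, a pair whose mean log-probs are -0.5 (chosen) and -1.0 (rejected) has a margin of exactly zero, so the loss is log 2; raising the chosen response's mean log-prob lowers the loss.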
ADRs:
- ADR-004: replaysim normalization layer (data-juicer chosen)
- ADR-005: Decoupled DiLoCo over serverless (object-store rendezvous,
not cross-job NCCL — matches DiLoCo's once-per-30-min sync cadence)
- ADR-006: RL framework strategy (TRL + VeRL + PRIME-RL + Monarch)
- ADR-007: self-distillation losses landscape
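
The sizing argument behind ADR-005 is simple arithmetic, sketched below. The helper names are hypothetical. The only figures taken from this commit are the 1B-parameter bf16 example and the fact that each replica writes one file per round and reads one file per rank, its own included.

```python
def pseudo_gradient_bytes(n_params, bytes_per_param=2):
    """Size of one serialized pseudo-gradient (bf16 = 2 bytes per parameter)."""
    return n_params * bytes_per_param


def per_round_traffic_gb(n_params, world_size, bytes_per_param=2):
    """Return (upload_gb, download_gb) for one replica in one outer round."""
    payload = pseudo_gradient_bytes(n_params, bytes_per_param)
    upload = payload                 # one PUT per replica per round
    download = payload * world_size  # one GET per rank, including its own file
    return upload / 1e9, download / 1e9


# 1B parameters in bf16 across 4 replicas.
upload_gb, download_gb = per_round_traffic_gb(1_000_000_000, world_size=4)
```

For this example the upload is 2 GB per replica per round, matching the figure quoted in the module docstring, while the download side scales with the number of replicas (8 GB here). Serialization overhead is ignored.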
RESEARCH (4 deep-dive recons, ~3300 lines total, primary-source
verified):
- DILOCO_SERVERLESS_RECONNAISSANCE.md: 6 executors audited (Modal,
HF Jobs, SageMaker, Vertex AI, Azure ML, k8s+Volcano)
- REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md: 5 candidates audited
- RL_FRAMEWORKS_LANDSCAPE.md: 6 RL frameworks + 4 Meta-stack components
- SELF_DISTILLATION_LANDSCAPE.md: 8 candidate losses audited
ALTERED-MINDS TIE-IN:
- docs/ALTERED_MINDS_TIE_IN.md: 5-phase plan for using the framework
to RL-train altered-minds-altered models. Bridges the user's
llm-mental-alterations workstream into this framework. ~$300
estimated for moral-scenarios trace-replay round.
CROSS-MODEL ADVERSARIAL REVIEW (Wave 13 final review by Opus 4.7
sub-agent, 8 findings):
- 2 BLOCKERs found and FIXED:
1. PRIME-RL composer_loss SDPO term was mathematically degenerate
(unsqueeze(-1) + log_softmax of 1-element vector = always 0).
Fixed: now raises NotImplementedError with clear path forward.
2. ADR-007 claimed compose_loss kwargs that were never added. Fixed:
ADR + V1-V8 + README all down-rev'd to honest "standalone losses
landed; integration deferred to Wave 14."
- 4 SUGGESTIONs documented in docs/research/WAVE_13_FINAL_REVIEW.md
(replaysim recipe field types, MockManager end-to-end gap, README
"9 multi-process" count phrasing, PRIME-RL channel-1 REINFORCE-
vs-GRPO labeling).
- 2 NITs noted (test using positive log-probs cosmetically; Modal/HF
Jobs skeleton clarity).
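
Why BLOCKER 1 is degenerate can be checked without torch. The sketch below is a plain-Python log-softmax (an illustration, not the PRIME-RL adapter code): normalizing over an axis of length one always yields zero, whatever the input value.

```python
import math


def log_softmax(xs):
    """log_softmax over a list of floats, using the max-shift for stability."""
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


# A 1-element "distribution" carries no information: the result is always 0.
single = [log_softmax([v])[0] for v in (-7.3, 0.0, 12.5)]

# With two or more entries the output depends on the inputs again.
pair = log_softmax([0.0, 0.0])
```

Any loss term built from such a single-element log-softmax is therefore constant and contributes no gradient, which is consistent with the review's decision to raise `NotImplementedError` rather than ship the term.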
DOCS UPDATED:
- README.md: Wave 13 expansion section added
- docs/V1_V8_COVERAGE.md: Wave 13 expansion table
- docs/V3_SUBSTRATE_COVERAGE.md: 8/8 substrate count (was 6/6),
PRIME-RL + serverless DiLoCo + Monarch rows added
- pyproject.toml: 4 new optional-dependency extras (serverless,
replaysim, prime-rl, monarch) + new keywords
TESTS:
- Wave 13 new: 35 passing (17 distillation + 9 serverless + 9 replaysim)
- Wave 13 + prior CPU-fast subset: 93 passing in 28s
- No regressions; new code is purely additive
- README.md +38 -1
- composer_replication/diloco/serverless/__init__.py +62 -0
- composer_replication/diloco/serverless/allreduce.py +214 -0
- composer_replication/diloco/serverless/executor.py +310 -0
- composer_replication/diloco/serverless/hf_jobs.py +98 -0
- composer_replication/diloco/serverless/modal.py +102 -0
- composer_replication/diloco/serverless/replica_entrypoint.py +109 -0
- composer_replication/diloco/serverless/tests/__init__.py +0 -0
- composer_replication/diloco/serverless/tests/test_serverless_local.py +239 -0
- composer_replication/distillation/__init__.py +36 -0
- composer_replication/distillation/entropy_aware_opd.py +126 -0
- composer_replication/distillation/simpo.py +83 -0
- composer_replication/distillation/taid.py +195 -0
- composer_replication/distillation/tests/test_distillation_losses.py +236 -0
- composer_replication/recipes/monarch/actors.py +90 -0
- composer_replication/recipes/monarch/monarch_actor_layout.md +121 -0
- composer_replication/recipes/prime_rl/composer_loss.py +111 -0
- composer_replication/recipes/prime_rl/prime_rl_config.yaml +66 -0
- composer_replication/recipes/prime_rl/prime_rl_recipe.md +107 -0
- composer_replication/recipes/replaysim/default.yaml +70 -0
- composer_replication/replaysim/__init__.py +55 -0
- composer_replication/replaysim/normalize.py +270 -0
- composer_replication/replaysim/tests/__init__.py +0 -0
- composer_replication/replaysim/tests/test_replaysim.py +138 -0
- docs/ALTERED_MINDS_TIE_IN.md +154 -0
- docs/V1_V8_COVERAGE.md +23 -1
- docs/V3_SUBSTRATE_COVERAGE.md +10 -6
- docs/adrs/ADR-004-replaysim-normalization.md +124 -0
- docs/adrs/ADR-005-serverless-diloco.md +142 -0
- docs/adrs/ADR-006-rl-frameworks.md +124 -0
- docs/adrs/ADR-007-self-distillation-losses.md +173 -0
- docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md +791 -0
- docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md +506 -0
- docs/research/RL_FRAMEWORKS_LANDSCAPE.md +428 -0
- docs/research/SELF_DISTILLATION_LANDSCAPE.md +418 -0
- docs/research/WAVE_13_FINAL_REVIEW.md +239 -0
- pyproject.toml +27 -2
README.md
@@ -167,10 +167,47 @@ The novel contribution is channel (3) — no published work systematically repla
 |---|---|---|---|---|
 | **v0.0 spike** | 1–2 weeks | Prove trace-replay-DPO beats plain GRPO on Qwen3-7B + SWE-bench-lite | `Codeseys/composer-replication-qwen3-7b-v0` | `Codeseys/composer-replication-traces-v0` |
 | **v0.1** | 1–2 months | Full Composer recipe (RLVR + hint-distill + trace-replay) on Qwen3-32B + Feature Deletion env. Match Cursor's ~50% SWE-bench-multilingual at 32B scale. | `Codeseys/composer-replication-qwen3-32b-v1` | `Codeseys/composer-replication-traces-v1` |
-| **v0.2** | 3–6 months | Decentralized scaling: Streaming DiLoCo + SHARDCAST + Monarch. Multi-cluster reproduction of v0.1. | `Codeseys/composer-replication-qwen3-32b-decentralized` | (re-uses v1 data) |
+| **v0.2** | 3–6 months | Decentralized scaling: Streaming DiLoCo + SHARDCAST + Monarch. Multi-cluster reproduction of v0.1 across **Modal + HF Jobs + on-prem** via the new serverless-DiLoCo executor abstraction (ADR-005). | `Codeseys/composer-replication-qwen3-32b-decentralized` | (re-uses v1 data) |
 
 Each variant will get its own model repo (LoRA adapters or full fine-tunes) per the [HF multi-artifact research project layout](https://huggingface.co/docs/hub/repositories). This methodology repo will be linked from each variant's README and via an HF Collection once v0.0 produces a result.
 
+## Wave 13 expansion (2026-05-26) — what just landed
+
+The user expanded the brief mid-deep-work-loop to address the
+serverless-orchestration, normalization, and broader-RL-framework
+dimensions. Six new artifact families:
+
+- **`composer_replication.distillation`** — pluggable losses: SimPO
+  (reference-free DPO), TAID (annealed teacher interpolation),
+  Entropy-Aware OPD (token-wise gated forward/reverse KL). 17 unit tests.
+  Use as standalone functions for now; `compose_loss` integration is
+  deferred to Wave 14 (Wave 13 review Finding 2).
+  See ADR-007 + `docs/research/SELF_DISTILLATION_LANDSCAPE.md`.
+- **`composer_replication.diloco.serverless`** — `ServerlessExecutor`
+  Protocol + `ObjectStoreAllReduce` + `LocalProcessExecutor` (running
+  + tested) + `ModalExecutor` / `HFJobsExecutor` skeletons. 9 multi-
+  process tests pinning the allreduce barrier. See ADR-005 +
+  `docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md`.
+- **`composer_replication.replaysim`** — N-teacher replay + data-juicer
+  normalization (chosen over datatrove because it has native multi-turn
+  + DPO-pair ops). 9 unit tests + default YAML recipe. See ADR-004 +
+  `docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md`.
+- **`composer_replication.recipes.prime_rl`** — third RL framework
+  recipe (alongside TRL + VeRL). PRIME-RL was selected because it has
+  a first-class `CustomLossConfig` exposing exactly the tensors a
+  3-channel loss needs. See ADR-006 +
+  `docs/research/RL_FRAMEWORKS_LANDSCAPE.md`.
+- **`composer_replication.recipes.monarch`** — Meta's PyTorch agentic
+  stack tie-in. Monarch (BSD-3, v0.4.1) is the only Meta agentic-stack
+  component actively shipping; TorchForge is paused. Actor layout
+  documented + skeleton actors in place. See ADR-006.
+- **`docs/ALTERED_MINDS_TIE_IN.md`** — bridge to the user's adjacent
+  workstream (formerly `llm-mental-alterations`). 5-phase plan for
+  using the framework to RL-train altered-minds-altered models. ~$300
+  estimated for a moral-scenarios trace-replay round.
+
+**Tests as of Wave 13: 107 passing.** (72 prior + 35 new.)
+
 ## Methodology — how this synthesis was produced
 
 To minimize single-model bias, the five research deep-dives were generated **in parallel** by five different LLM families via the [`delegate_task` parallel-research pattern](https://huggingface.co/docs/transformers/research):
composer_replication/diloco/serverless/__init__.py
@@ -0,0 +1,62 @@
+"""composer_replication.diloco.serverless — run Decoupled DiLoCo across
+serverless training systems (Modal, HuggingFace Jobs, SageMaker, k8s, …).
+
+Per ADR-005, the design rests on two abstractions:
+
+1. `ServerlessExecutor` Protocol — a uniform interface for spinning up
+   N replicas on different cloud backends. Each backend (Modal, HF Jobs,
+   SageMaker, etc.) gets a concrete adapter that implements the Protocol.
+
+2. `ObjectStoreAllReduce` — fsspec-backed pseudo-gradient exchange that
+   replaces the in-process `torchft.Manager.allreduce` call. The
+   communication pattern is `S3 PutObject + N GetObjects` once per
+   ~500-1000 inner steps, which matches DiLoCo's actual sync cadence
+   (paper arXiv:2311.08105 §3.2). Bandwidth: ~2 GB / 30 minutes per
+   replica for 1B-param bf16, well within S3 free-tier.
+
+The framework's existing `composer_replication.diloco.make_diloco_outer_loop`
+wraps `torchft.local_sgd.DiLoCo`. To run that across N serverless replicas:
+
+    >>> from composer_replication.diloco.serverless import (
+    ...     LocalProcessExecutor,
+    ...     ObjectStoreAllReduce,
+    ... )
+    >>> rendezvous = ObjectStoreAllReduce("s3://my-bucket/diloco-runs/run42/", rank=0, world_size=4)
+    >>> executor = LocalProcessExecutor()
+    >>> handles = executor.launch_replicas(
+    ...     n_replicas=4,
+    ...     entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
+    ...     entrypoint_args={"rendezvous": rendezvous.uri, "rank_env": "REPLICA_RANK"},
+    ... )
+    >>> result = executor.collect(handles, timeout=3600)
+
+Module layout:
+- `executor.py` — `ServerlessExecutor` Protocol + base classes + `LocalProcessExecutor`
+- `allreduce.py` — `ObjectStoreAllReduce` + `MockManager` (drops into torchft path)
+- `modal.py` — `ModalExecutor` (skeleton — implements when modal-client is available)
+- `hf_jobs.py` — `HFJobsExecutor` (skeleton — uses huggingface_hub.run_job)
+- `replica_entrypoint.py` — script each replica runs (loaded from object store)
+
+Optional dependency: `pip install -e .[serverless]` pulls fsspec + s3fs +
+gcsfs. Modal/HF Jobs adapters require `modal` and `huggingface_hub` respectively;
+both are checked at adapter init time, not at module import.
+"""
+from __future__ import annotations
+
+from composer_replication.diloco.serverless.allreduce import (
+    MockManager,
+    ObjectStoreAllReduce,
+)
+from composer_replication.diloco.serverless.executor import (
+    LocalProcessExecutor,
+    ReplicaHandle,
+    ServerlessExecutor,
+)
+
+__all__ = [
+    "LocalProcessExecutor",
+    "MockManager",
+    "ObjectStoreAllReduce",
+    "ReplicaHandle",
+    "ServerlessExecutor",
+]
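
The "uniform interface" idea in the docstring above relies on structural typing: a backend satisfies a `typing.Protocol` by having the right members, with no inheritance. A minimal sketch follows, using a deliberately smaller, hypothetical `MiniExecutor` protocol rather than the real `ServerlessExecutor`; all class and method names here are assumptions for illustration.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class MiniExecutor(Protocol):
    """Hypothetical two-method stand-in for a launcher interface."""

    backend_name: str

    def launch(self, n_replicas: int) -> list: ...

    def cancel(self, handle: int) -> None: ...


class InlineExecutor:
    """Satisfies MiniExecutor without inheriting from it."""

    backend_name = "inline"

    def launch(self, n_replicas):
        # Pretend each replica is identified by its rank.
        return list(range(n_replicas))

    def cancel(self, handle):
        return None


class HalfExecutor:
    """Missing `cancel`, so it does not satisfy the protocol."""

    backend_name = "half"

    def launch(self, n_replicas):
        return []


conforms = isinstance(InlineExecutor(), MiniExecutor)
rejected = isinstance(HalfExecutor(), MiniExecutor)
```

Note that `isinstance` against a `runtime_checkable` protocol only checks that the members exist, not their signatures, so it is a sanity check rather than a substitute for a static type checker.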
composer_replication/diloco/serverless/allreduce.py
@@ -0,0 +1,214 @@
+"""ObjectStoreAllReduce — fsspec-backed pseudo-gradient exchange for DiLoCo.
+
+DiLoCo's outer-loop sync writes the local pseudo-gradient (= θ_initial − θ_local)
+to a shared location once per H ≈ 500-1000 inner steps, then averages across
+all replicas before the outer SGD step. With cross-job NCCL unavailable on
+most serverless backends, we use object storage as the rendezvous medium.
+
+Communication pattern per outer round:
+  1. Each replica writes its pseudo-gradient: PUT(rendezvous/round_N/rank_R.pt)
+  2. Each replica reads all peer pseudo-gradients: GET × N
+  3. Average locally → applied as `Manager.allreduce()` would have.
+
+Backend support via fsspec: s3://, gs://, az://, hf://, file://.
+The same code path works across all of them.
+
+License compatibility: this module re-implements the contract of
+`torchft.Manager.allreduce` through duck-typing — no torchft code is
+copied. torchft itself is BSD-3.
+"""
+from __future__ import annotations
+
+import io
+import os
+import time
+from typing import Any
+
+import torch
+
+
+class ObjectStoreAllReduce:
+    """fsspec-backed pseudo-gradient rendezvous.
+
+    Each call to `allreduce(tensor, name)` blocks until all peers have
+    written their version of `tensor` to the rendezvous location, then
+    returns the average.
+
+    Args:
+        uri: fsspec URI like "s3://bucket/path/" or "file:///tmp/diloco/" or
+            a plain path "/tmp/diloco/run42/" (treated as file://).
+        rank: this replica's rank (0-indexed)
+        world_size: total number of replicas
+        round_id: optional, used to namespace successive sync rounds.
+            If None, a monotonically increasing counter is used internally.
+        timeout_s: per-allreduce timeout in seconds.
+        poll_interval_s: how often to check for peer files.
+    """
+
+    def __init__(
+        self,
+        uri: str,
+        rank: int,
+        world_size: int,
+        *,
+        round_id: int | None = None,
+        timeout_s: float = 1800.0,
+        poll_interval_s: float = 1.0,
+    ) -> None:
+        if not (0 <= rank < world_size):
+            raise ValueError(f"rank {rank} not in [0, {world_size})")
+        self.uri = uri.rstrip("/") + "/"
+        self.rank = rank
+        self.world_size = world_size
+        self.timeout_s = timeout_s
+        self.poll_interval_s = poll_interval_s
+        self._round_counter = 0 if round_id is None else round_id
+
+        # Lazy fsspec init; deferred so that local-only smoke tests don't
+        # require fsspec install in the dev environment.
+        self._fs = None
+        self._is_local = self.uri.startswith("file://") or self.uri.startswith("/")
+        if self._is_local:
+            local_path = self.uri.removeprefix("file://")
+            os.makedirs(local_path, exist_ok=True)
+            self._local_root = local_path
+        else:
+            self._init_fsspec()
+
+    def _init_fsspec(self) -> None:
+        try:
+            import fsspec  # noqa: F401
+        except ImportError as e:
+            raise RuntimeError(
+                "Non-local rendezvous requires fsspec; install with "
+                "`pip install -e .[serverless]`. Got: " + repr(e)
+            )
+        import fsspec
+        protocol = self.uri.split("://", 1)[0] if "://" in self.uri else "file"
+        self._fs = fsspec.filesystem(protocol)
+
+    @property
+    def round_id(self) -> int:
+        return self._round_counter
+
+    def _round_dir(self, round_id: int) -> str:
+        return f"round_{round_id:06d}"
+
+    def _path_for(self, round_id: int, rank: int) -> str:
+        return f"{self._round_dir(round_id)}/rank_{rank:04d}.pt"
+
+    def _put(self, relpath: str, payload: bytes) -> None:
+        if self._is_local:
+            full = os.path.join(self._local_root, relpath)
+            os.makedirs(os.path.dirname(full), exist_ok=True)
+            tmp = full + ".tmp"
+            with open(tmp, "wb") as f:
+                f.write(payload)
+            os.replace(tmp, full)  # atomic on POSIX
+        else:
+            full = self.uri + relpath
+            assert self._fs is not None
+            with self._fs.open(full, "wb") as f:
+                f.write(payload)
+
+    def _get(self, relpath: str) -> bytes:
+        if self._is_local:
+            full = os.path.join(self._local_root, relpath)
+            with open(full, "rb") as f:
+                return f.read()
+        full = self.uri + relpath
+        assert self._fs is not None
+        with self._fs.open(full, "rb") as f:
+            return f.read()
+
+    def _exists(self, relpath: str) -> bool:
+        if self._is_local:
+            return os.path.exists(os.path.join(self._local_root, relpath))
+        full = self.uri + relpath
+        assert self._fs is not None
+        return self._fs.exists(full)
+
+    def allreduce(self, tensor: torch.Tensor, *, name: str | None = None) -> torch.Tensor:
+        """Average `tensor` across all replicas via the object store.
+
+        Args:
+            tensor: the tensor to average. Modified in-place AND returned.
+            name: ignored — provided for API compat with torchft.Manager.
+
+        Returns:
+            The averaged tensor (modifies in-place; returns the same object).
+        """
+        round_id = self._round_counter
+        self._round_counter += 1
+
+        # Serialize my tensor
+        buf = io.BytesIO()
+        torch.save({"rank": self.rank, "tensor": tensor.detach().cpu()}, buf)
+        my_path = self._path_for(round_id, self.rank)
+        self._put(my_path, buf.getvalue())
+
+        # Wait for all peers
+        deadline = time.time() + self.timeout_s
+        peer_tensors: list[torch.Tensor] = []
+        for peer_rank in range(self.world_size):
+            peer_path = self._path_for(round_id, peer_rank)
+            while not self._exists(peer_path):
+                if time.time() > deadline:
+                    raise TimeoutError(
+                        f"ObjectStoreAllReduce: timed out waiting for "
+                        f"rank {peer_rank} at {self.uri}{peer_path} "
+                        f"(world_size={self.world_size}, round={round_id})"
+                    )
+                time.sleep(self.poll_interval_s)
+            payload = self._get(peer_path)
+            peer_data = torch.load(io.BytesIO(payload), weights_only=False)
+            peer_tensors.append(peer_data["tensor"].to(tensor.device, dtype=tensor.dtype))
+
+        # Compute average
+        stacked = torch.stack(peer_tensors, dim=0)
+        avg = stacked.mean(dim=0)
+        tensor.copy_(avg)
+        return tensor
+
+
+# ---------------------------------------------------------------------
+# MockManager — torchft.Manager-shaped object that uses ObjectStoreAllReduce
+# ---------------------------------------------------------------------
+
+
+class MockManager:
+    """Drop-in replacement for `torchft.Manager` that delegates allreduce
+    to `ObjectStoreAllReduce`.
+
+    The torchft `DiLoCo` class accepts a `Manager` and calls its `.allreduce`
+    method on the pseudo-gradient. By passing this mock instead, we route
+    that call through the object store, leaving the rest of the DiLoCo
+    machinery (sign convention, post-hook sequencing, etc.) untouched.
+
+    Reference: `make_diloco_outer_loop` in
+    `composer_replication/diloco/__init__.py` accepts an optional
+    `manager=` kwarg; pass a `MockManager` to enable serverless DiLoCo.
+    """
+    def __init__(self, store: ObjectStoreAllReduce) -> None:
+        self._store = store
+        # torchft Manager attributes that DiLoCo consults
+        self.num_participants = store.world_size
+        self.rank = store.rank
+
+    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> torch.Tensor:
+        return self._store.allreduce(tensor)
+
+    # torchft.Manager has additional methods (`should_commit`, `start_quorum`,
+    # etc.) that are no-ops for our coarse-grained sync. The `DiLoCo` class
+    # only requires `allreduce`, but the others may be probed.
+    def should_commit(self) -> bool:
+        return True
+
+    def start_quorum(self) -> None:
+        pass
+
+    def wait_quorum(self) -> int:
+        return self.num_participants
+
+
+__all__ = ["MockManager", "ObjectStoreAllReduce"]
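
The write, poll, read, average pattern in the file above can be exercised without torch or fsspec. The sketch below is a simplified stand-in, not the module's code: it averages plain floats, stores them as JSON in a temporary directory, and uses threads in place of the separate processes the commit's tests use. The names `file_allreduce_mean` and `run_round` are hypothetical. The two rounds reuse the values from the commit message (mean of {0, 1, 2} and mean of {0, 100, 200}).

```python
import json
import os
import tempfile
import threading
import time


def file_allreduce_mean(root, round_id, rank, world_size, value,
                        timeout_s=10.0, poll_interval_s=0.01):
    """Average `value` across `world_size` participants via files under `root`."""
    round_dir = os.path.join(root, f"round_{round_id:06d}")
    os.makedirs(round_dir, exist_ok=True)

    # Step 1: publish this rank's value atomically (temp file, then rename).
    final_path = os.path.join(round_dir, f"rank_{rank:04d}.json")
    tmp_path = final_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump({"rank": rank, "value": value}, f)
    os.replace(tmp_path, final_path)

    # Step 2: poll until each rank's file exists, then read it.
    deadline = time.monotonic() + timeout_s
    values = []
    for peer in range(world_size):
        peer_path = os.path.join(round_dir, f"rank_{peer:04d}.json")
        while not os.path.exists(peer_path):
            if time.monotonic() > deadline:
                raise TimeoutError(f"rank {peer} missing in round {round_id}")
            time.sleep(poll_interval_s)
        with open(peer_path) as f:
            values.append(json.load(f)["value"])

    # Step 3: every rank computes the same mean locally.
    return sum(values) / world_size


def run_round(root, round_id, inputs):
    """Run one round with one thread per rank; return each rank's result."""
    results = [None] * len(inputs)

    def worker(rank):
        results[rank] = file_allreduce_mean(
            root, round_id, rank, len(inputs), inputs[rank]
        )

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(len(inputs))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


with tempfile.TemporaryDirectory() as root:
    round0 = run_round(root, 0, [0.0, 1.0, 2.0])
    round1 = run_round(root, 1, [0.0, 100.0, 200.0])
```

Namespacing files by round is what keeps the second round from reading the first round's values, and the rename in step 1 is what prevents a peer from reading a half-written file on a local filesystem.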
composer_replication/diloco/serverless/executor.py
@@ -0,0 +1,310 @@
+"""ServerlessExecutor Protocol + LocalProcessExecutor.
+
+Per ADR-005:
+- `ServerlessExecutor` is a structural Protocol — backends implement it
+  by writing a class with the right methods, no formal inheritance needed.
+- `LocalProcessExecutor` is the reference implementation that uses Python's
+  `multiprocessing` module. It's used for tests and for development; the
+  cloud adapters (Modal, HF Jobs, …) implement the same Protocol against
+  their respective backends.
+"""
+from __future__ import annotations
+
+import multiprocessing as mp
+import sys
+import time
+from dataclasses import dataclass, field
+from typing import Any, Callable, Mapping, Protocol, runtime_checkable
+
+
+@dataclass
+class ReplicaHandle:
+    """Opaque handle to a launched replica. Backend-specific contents.
+
+    All executors return `list[ReplicaHandle]` from `launch_replicas`.
+    Each handle's `metadata` dict is backend-specific; users shouldn't
+    rely on its shape.
+    """
+    rank: int
+    backend_name: str
+    metadata: dict[str, Any] = field(default_factory=dict)
+    """Backend-specific data (e.g. Modal call ID, HF Jobs job ID, local
+    Process object). Not stable across backends."""
+
+
+@runtime_checkable
+class ServerlessExecutor(Protocol):
+    """Uniform interface for launching N replicas on a serverless backend.
+
+    Implementations: `LocalProcessExecutor` (test/dev), `ModalExecutor`
+    (Modal, v0), `HFJobsExecutor` (HuggingFace Jobs, v0). Future:
+    `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`.
+
+    Note on rank assignment: the Protocol guarantees that handles are
+    returned in rank order (`handles[i].rank == i`). The replica entrypoint
+    learns its own rank either from an environment variable
+    (`REPLICA_RANK`) or from a backend-provided mechanism (Modal's
+    `Function.shard_rank`, etc.). The executor abstraction normalizes
+    rank by setting the env var.
+    """
+    backend_name: str
+    supports_inter_replica_network: bool
+
+    def launch_replicas(
+        self,
+        n_replicas: int,
+        entrypoint: str | Callable[..., Any],
+        entrypoint_args: Mapping[str, Any],
+        *,
+        gpu: str | None = None,
+        timeout: int = 3600,
+    ) -> list[ReplicaHandle]:
+        """Spin up N replicas in parallel.
+
+        Args:
+            n_replicas: number of replicas to launch
+            entrypoint: either an importable Python path (e.g.
+                "composer_replication.diloco.serverless.replica_entrypoint")
+                or a Callable (Local executor only).
+            entrypoint_args: kwargs passed to the entrypoint. The kwarg
+                `rank_env` (default "REPLICA_RANK") names the environment
+                variable in which the rank will be set on the replica.
+            gpu: backend-specific GPU spec, e.g. "A100", "H100". `None`
+                means CPU-only (smoke tests).
+            timeout: per-replica wall-clock timeout in seconds.
+
+        Returns:
+            list[ReplicaHandle] of length n_replicas, in rank order.
+        """
+        ...
+
+    def poll(self, handle: ReplicaHandle) -> str:
+        """Poll a replica's status. Returns one of:
+        "pending" | "running" | "succeeded" | "failed" | "cancelled".
+        """
+        ...
+
+    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
+        """Read up to n_lines of recent stdout/stderr from a replica."""
+        ...
+
+    def cancel(self, handle: ReplicaHandle) -> None:
+        """Best-effort cancel. No exception if already terminated."""
+        ...
+
+    def collect(
+        self,
+        handles: list[ReplicaHandle],
+        *,
+        timeout: int | None = None,
+    ) -> list[dict[str, Any]]:
+        """Block until all replicas finish; return per-replica result dicts.
+
+        Each result dict contains at least:
+            {"rank": int, "status": str, "exit_code": int | None,
+             "error": str | None}
+        """
+        ...
+
+
+# ---------------------------------------------------------------------
+# LocalProcessExecutor — reference implementation using multiprocessing
+# ---------------------------------------------------------------------
+
+
+def _local_replica_target(
+    rank: int,
+    rank_env: str,
+    entrypoint: Any,
+    entrypoint_args: Mapping[str, Any],
+    result_queue: mp.Queue,
+) -> None:
+    """multiprocessing target — runs in the child process."""
+    import os
+    import traceback
+
+    os.environ[rank_env] = str(rank)
+    try:
+        if callable(entrypoint):
+            result = entrypoint(**entrypoint_args)
+        elif isinstance(entrypoint, str):
+            # importable path
+            mod_path, _, fn_name = entrypoint.rpartition(".")
+            if not mod_path:
+                # Top-level script path; just import it and call its main()
+                import importlib
+                mod = importlib.import_module(entrypoint)
+                fn = getattr(mod, "main", None)
+                if fn is None:
+                    raise RuntimeError(
+                        f"entrypoint '{entrypoint}' has no main() function"
|
| 141 |
+
)
|
| 142 |
+
result = fn(**entrypoint_args)
|
| 143 |
+
else:
|
| 144 |
+
import importlib
|
| 145 |
+
mod = importlib.import_module(mod_path)
|
| 146 |
+
fn = getattr(mod, fn_name)
|
| 147 |
+
result = fn(**entrypoint_args)
|
| 148 |
+
else:
|
| 149 |
+
raise TypeError(
|
| 150 |
+
f"entrypoint must be Callable or importable str, got {type(entrypoint)!r}"
|
| 151 |
+
)
|
| 152 |
+
result_queue.put({"rank": rank, "status": "succeeded",
|
| 153 |
+
"exit_code": 0, "error": None, "result": result})
|
| 154 |
+
except Exception as e:
|
| 155 |
+
tb = traceback.format_exc()
|
| 156 |
+
result_queue.put({"rank": rank, "status": "failed",
|
| 157 |
+
"exit_code": 1, "error": f"{e!r}\n{tb}", "result": None})
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class LocalProcessExecutor:
|
| 161 |
+
"""Runs replicas as subprocesses on the local machine.
|
| 162 |
+
|
| 163 |
+
Use cases:
|
| 164 |
+
- Test the serverless layer end-to-end without cloud spend.
|
| 165 |
+
- Develop the algorithm locally with N>1 replicas and `file://`
|
| 166 |
+
rendezvous before deploying to Modal/HF Jobs.
|
| 167 |
+
- CI smoke tests.
|
| 168 |
+
"""
|
| 169 |
+
backend_name = "local_process"
|
| 170 |
+
supports_inter_replica_network = True # localhost works
|
| 171 |
+
|
| 172 |
+
def __init__(self) -> None:
|
| 173 |
+
# use 'spawn' so the child has a fresh interpreter (avoid CUDA fork issues)
|
| 174 |
+
try:
|
| 175 |
+
self._ctx = mp.get_context("spawn")
|
| 176 |
+
except ValueError:
|
| 177 |
+
# Fallback for environments where 'spawn' isn't available
|
| 178 |
+
self._ctx = mp.get_context()
|
| 179 |
+
self._handles: dict[int, dict[str, Any]] = {}
|
| 180 |
+
|
| 181 |
+
def launch_replicas(
|
| 182 |
+
self,
|
| 183 |
+
n_replicas: int,
|
| 184 |
+
entrypoint: str | Callable[..., Any],
|
| 185 |
+
entrypoint_args: Mapping[str, Any],
|
| 186 |
+
*,
|
| 187 |
+
gpu: str | None = None,
|
| 188 |
+
timeout: int = 3600,
|
| 189 |
+
) -> list[ReplicaHandle]:
|
| 190 |
+
if gpu is not None:
|
| 191 |
+
# Local executor doesn't pin GPUs; emit a soft warning.
|
| 192 |
+
sys.stderr.write(
|
| 193 |
+
f"[LocalProcessExecutor] gpu={gpu!r} ignored — "
|
| 194 |
+
f"local processes share whatever GPUs are visible.\n"
|
| 195 |
+
)
|
| 196 |
+
rank_env = entrypoint_args.get("rank_env", "REPLICA_RANK")
|
| 197 |
+
|
| 198 |
+
handles: list[ReplicaHandle] = []
|
| 199 |
+
result_queue: mp.Queue = self._ctx.Queue()
|
| 200 |
+
for rank in range(n_replicas):
|
| 201 |
+
args_for_rank = dict(entrypoint_args)
|
| 202 |
+
args_for_rank.pop("rank_env", None)
|
| 203 |
+
proc = self._ctx.Process(
|
| 204 |
+
target=_local_replica_target,
|
| 205 |
+
args=(rank, rank_env, entrypoint, args_for_rank, result_queue),
|
| 206 |
+
name=f"composer-replica-{rank}",
|
| 207 |
+
)
|
| 208 |
+
proc.start()
|
| 209 |
+
handle = ReplicaHandle(
|
| 210 |
+
rank=rank, backend_name=self.backend_name,
|
| 211 |
+
metadata={"pid": proc.pid, "start_ts": time.time()},
|
| 212 |
+
)
|
| 213 |
+
self._handles[rank] = {"proc": proc, "queue": result_queue,
|
| 214 |
+
"deadline": time.time() + timeout,
|
| 215 |
+
"result": None}
|
| 216 |
+
handles.append(handle)
|
| 217 |
+
return handles
|
| 218 |
+
|
| 219 |
+
def poll(self, handle: ReplicaHandle) -> str:
|
| 220 |
+
meta = self._handles.get(handle.rank)
|
| 221 |
+
if meta is None:
|
| 222 |
+
return "cancelled"
|
| 223 |
+
proc: mp.Process = meta["proc"]
|
| 224 |
+
if proc.is_alive():
|
| 225 |
+
return "running"
|
| 226 |
+
if meta.get("result") is not None:
|
| 227 |
+
return meta["result"]["status"]
|
| 228 |
+
# Process exited; read result if available
|
| 229 |
+
try:
|
| 230 |
+
queue: mp.Queue = meta["queue"]
|
| 231 |
+
while not queue.empty():
|
| 232 |
+
r = queue.get_nowait()
|
| 233 |
+
self._handles[r["rank"]]["result"] = r
|
| 234 |
+
except Exception:
|
| 235 |
+
pass
|
| 236 |
+
if meta.get("result") is not None:
|
| 237 |
+
return meta["result"]["status"]
|
| 238 |
+
return "failed" if proc.exitcode != 0 else "succeeded"
|
| 239 |
+
|
| 240 |
+
def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
|
| 241 |
+
# multiprocessing.Process doesn't natively capture stdout; we'd
|
| 242 |
+
# need a Pipe or file redirection. For the local reference impl,
|
| 243 |
+
# we just point the user at the result dict's `error` field.
|
| 244 |
+
meta = self._handles.get(handle.rank)
|
| 245 |
+
if meta is None:
|
| 246 |
+
return f"<replica {handle.rank}: no metadata>"
|
| 247 |
+
if meta.get("result"):
|
| 248 |
+
err = meta["result"].get("error") or ""
|
| 249 |
+
return f"[rank {handle.rank}] {err[-2000:]}"
|
| 250 |
+
return f"<replica {handle.rank}: still running, no captured logs>"
|
| 251 |
+
|
| 252 |
+
def cancel(self, handle: ReplicaHandle) -> None:
|
| 253 |
+
meta = self._handles.get(handle.rank)
|
| 254 |
+
if meta is None:
|
| 255 |
+
return
|
| 256 |
+
proc: mp.Process = meta["proc"]
|
| 257 |
+
if proc.is_alive():
|
| 258 |
+
proc.terminate()
|
| 259 |
+
proc.join(timeout=5)
|
| 260 |
+
if proc.is_alive():
|
| 261 |
+
proc.kill()
|
| 262 |
+
|
| 263 |
+
def collect(
|
| 264 |
+
self,
|
| 265 |
+
handles: list[ReplicaHandle],
|
| 266 |
+
*,
|
| 267 |
+
timeout: int | None = None,
|
| 268 |
+
) -> list[dict[str, Any]]:
|
| 269 |
+
deadline = time.time() + (timeout if timeout is not None else 3600)
|
| 270 |
+
# Wait for all processes to finish. Caveat: joining before draining the queue can block if a child has queued a large result.
|
| 271 |
+
for h in handles:
|
| 272 |
+
meta = self._handles.get(h.rank)
|
| 273 |
+
if meta is None:
|
| 274 |
+
continue
|
| 275 |
+
proc: mp.Process = meta["proc"]
|
| 276 |
+
remaining = max(0.0, deadline - time.time())
|
| 277 |
+
proc.join(timeout=remaining)
|
| 278 |
+
if proc.is_alive():
|
| 279 |
+
proc.terminate()
|
| 280 |
+
proc.join(timeout=5)
|
| 281 |
+
# Drain results
|
| 282 |
+
results_by_rank: dict[int, dict[str, Any]] = {}
|
| 283 |
+
for h in handles:
|
| 284 |
+
meta = self._handles.get(h.rank)
|
| 285 |
+
if meta is None:
|
| 286 |
+
results_by_rank[h.rank] = {
|
| 287 |
+
"rank": h.rank, "status": "cancelled",
|
| 288 |
+
"exit_code": None, "error": "no metadata", "result": None,
|
| 289 |
+
}
|
| 290 |
+
continue
|
| 291 |
+
queue: mp.Queue = meta["queue"]
|
| 292 |
+
while not queue.empty():
|
| 293 |
+
try:
|
| 294 |
+
r = queue.get_nowait()
|
| 295 |
+
results_by_rank[r["rank"]] = r
|
| 296 |
+
except Exception:
|
| 297 |
+
break
|
| 298 |
+
if h.rank not in results_by_rank:
|
| 299 |
+
proc: mp.Process = meta["proc"]
|
| 300 |
+
results_by_rank[h.rank] = {
|
| 301 |
+
"rank": h.rank,
|
| 302 |
+
"status": "succeeded" if proc.exitcode == 0 else "failed",
|
| 303 |
+
"exit_code": proc.exitcode,
|
| 304 |
+
"error": None if proc.exitcode == 0 else f"exit code {proc.exitcode}",
|
| 305 |
+
"result": None,
|
| 306 |
+
}
|
| 307 |
+
return [results_by_rank[h.rank] for h in handles]
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
__all__ = ["LocalProcessExecutor", "ReplicaHandle", "ServerlessExecutor"]
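Reviewer sketch (not part of the commit): `_local_replica_target` above resolves a string entrypoint by splitting it into a module path and a function name. The helper below restates that rule in isolation, under the assumption that a dotted path must end in a callable attribute. The name `resolve_entrypoint` is hypothetical and does not exist in the package.

```python
from __future__ import annotations

import importlib
from typing import Any, Callable


def resolve_entrypoint(entrypoint: str | Callable[..., Any]) -> Callable[..., Any]:
    """Turn a callable or a dotted path into a callable (illustrative helper)."""
    if callable(entrypoint):
        return entrypoint
    if not isinstance(entrypoint, str):
        raise TypeError(f"entrypoint must be callable or str, got {type(entrypoint)!r}")
    mod_path, _, fn_name = entrypoint.rpartition(".")
    if not mod_path:
        # No dot at all: treat the string as a module that exposes main().
        module = importlib.import_module(entrypoint)
        fn = getattr(module, "main", None)
    else:
        # "pkg.mod.fn" -> import "pkg.mod", then look up "fn" on it.
        module = importlib.import_module(mod_path)
        fn = getattr(module, fn_name, None)
    if not callable(fn):
        raise RuntimeError(f"entrypoint {entrypoint!r} does not resolve to a callable")
    return fn
```

Under this rule a path that names a module rather than a function (for example `os.path`) is rejected, which is why a dotted entrypoint for the local executor would need to end in the function name.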
|
|
@@ -0,0 +1,98 @@
|
| 1 |
+
"""HuggingFace Jobs executor — skeleton for v0.
|
| 2 |
+
|
| 3 |
+
Per ADR-005, HF Jobs is one of two v0 target executors. This file is a
|
| 4 |
+
STUB. The full integration uses `huggingface_hub.run_job` (added in
|
| 5 |
+
a recent huggingface_hub release; confirm the minimum version in its release notes), which spins up a containerized job
|
| 6 |
+
backed by HF's compute pool.
|
| 7 |
+
|
| 8 |
+
Pricing reference (2026-05-26): A100 ≈ $4.18/hr, H100 ≈ $9.50/hr. Cold
|
| 9 |
+
start ≈ 60s. NO inter-job networking — must use object-store rendezvous.
|
| 10 |
+
|
| 11 |
+
Status: SKELETON. Real implementation pending v0 polish wave.
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
from typing import Any, Callable, Mapping
|
| 16 |
+
|
| 17 |
+
from composer_replication.diloco.serverless.executor import (
|
| 18 |
+
ReplicaHandle,
|
| 19 |
+
ServerlessExecutor,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class HFJobsExecutor(ServerlessExecutor):
|
| 24 |
+
"""Run replicas as HuggingFace Jobs in parallel.
|
| 25 |
+
|
| 26 |
+
Reference implementation pattern:
|
| 27 |
+
|
| 28 |
+
from huggingface_hub import run_job
|
| 29 |
+
jobs = []
|
| 30 |
+
for rank in range(N):
|
| 31 |
+
job = run_job(
|
| 32 |
+
image="...", # container with composer_replication installed
|
| 33 |
+
command=[
|
| 34 |
+
"python", "-m",
|
| 35 |
+
"composer_replication.diloco.serverless.replica_entrypoint",
|
| 36 |
+
"--world-size", str(N), "--trainer-module", "my_project.trainer",  # rank comes from REPLICA_RANK in env
|
| 37 |
+
"--rendezvous", "hf://datasets/myuser/run42/",
|
| 38 |
+
],
|
| 39 |
+
env={"REPLICA_RANK": str(rank), "WORLD_SIZE": str(N)},
|
| 40 |
+
flavor="a100-large",  # run_job selects hardware via `flavor`
|
| 41 |
+
)
|
| 42 |
+
jobs.append(job)
|
| 43 |
+
return [ReplicaHandle(rank=i, backend_name="hf_jobs",
|
| 44 |
+
metadata={"job_id": jobs[i].id})
|
| 45 |
+
for i in range(N)]
|
| 46 |
+
|
| 47 |
+
Object-store rendezvous works naturally with the HF Datasets-as-storage
|
| 48 |
+
pattern — `hf://datasets/{user}/{run_id}/` is fsspec-compatible via
|
| 49 |
+
`huggingface_hub`'s fsspec integration.
|
| 50 |
+
|
| 51 |
+
Status: SKELETON.
|
| 52 |
+
"""
|
| 53 |
+
backend_name = "hf_jobs"
|
| 54 |
+
supports_inter_replica_network = False
|
| 55 |
+
|
| 56 |
+
def __init__(self) -> None:
|
| 57 |
+
try:
|
| 58 |
+
from huggingface_hub import HfApi # noqa: F401
|
| 59 |
+
except ImportError as e:
|
| 60 |
+
raise RuntimeError(
|
| 61 |
+
"HFJobsExecutor requires huggingface_hub. Got: " + repr(e)
|
| 62 |
+
)
|
| 63 |
+
# Real implementation: instantiate HfApi, validate token, etc.
|
| 64 |
+
raise NotImplementedError(
|
| 65 |
+
"HFJobsExecutor is a v0 skeleton; full implementation pending. "
|
| 66 |
+
"Use LocalProcessExecutor for testing."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def launch_replicas(
|
| 70 |
+
self,
|
| 71 |
+
n_replicas: int,
|
| 72 |
+
entrypoint: str | Callable[..., Any],
|
| 73 |
+
entrypoint_args: Mapping[str, Any],
|
| 74 |
+
*,
|
| 75 |
+
gpu: str | None = "a100",
|
| 76 |
+
timeout: int = 3600,
|
| 77 |
+
) -> list[ReplicaHandle]:
|
| 78 |
+
raise NotImplementedError
|
| 79 |
+
|
| 80 |
+
def poll(self, handle: ReplicaHandle) -> str:
|
| 81 |
+
raise NotImplementedError
|
| 82 |
+
|
| 83 |
+
def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
|
| 84 |
+
raise NotImplementedError
|
| 85 |
+
|
| 86 |
+
def cancel(self, handle: ReplicaHandle) -> None:
|
| 87 |
+
raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
def collect(
|
| 90 |
+
self,
|
| 91 |
+
handles: list[ReplicaHandle],
|
| 92 |
+
*,
|
| 93 |
+
timeout: int | None = None,
|
| 94 |
+
) -> list[dict[str, Any]]:
|
| 95 |
+
raise NotImplementedError
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
__all__ = ["HFJobsExecutor"]
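Reviewer sketch (not part of the commit): the docstring above says HF Jobs has no inter-job networking, so replicas must meet through an object store. The block below illustrates that idea with plain files and Python lists. It is an assumption-laden toy, not the package's `ObjectStoreAllReduce`: the function names, the file naming scheme, and the polling barrier are all hypothetical, and threads stand in for separate jobs.

```python
import json
import os
import tempfile
import threading
import time


def file_allreduce_mean(directory, rank, world_size, round_id, values,
                        timeout_s=10.0, poll_s=0.01):
    """Average `values` across ranks using files in a shared directory."""
    os.makedirs(directory, exist_ok=True)
    final_path = os.path.join(directory, f"round{round_id}_rank{rank}.json")
    tmp_path = final_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(values, f)
    # Publish atomically so readers never observe a partially written file.
    os.replace(tmp_path, final_path)

    paths = [os.path.join(directory, f"round{round_id}_rank{r}.json")
             for r in range(world_size)]
    deadline = time.monotonic() + timeout_s
    # Barrier: poll until every rank has published its contribution.
    while not all(os.path.exists(p) for p in paths):
        if time.monotonic() > deadline:
            raise TimeoutError(f"rank {rank}: round {round_id} barrier timed out")
        time.sleep(poll_s)

    contributions = []
    for p in paths:
        with open(p) as f:
            contributions.append(json.load(f))
    # Element-wise mean over all ranks.
    return [sum(column) / world_size for column in zip(*contributions)]


def run_demo(world_size=3):
    """Simulate replicas with threads; a real run would use separate jobs."""
    results = {}
    with tempfile.TemporaryDirectory() as td:
        def worker(rank):
            results[rank] = file_allreduce_mean(
                td, rank, world_size, 0, [float(rank)] * 2)

        threads = [threading.Thread(target=worker, args=(r,))
                   for r in range(world_size)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    return results
```

With three simulated ranks contributing 0, 1 and 2, every rank ends the round holding the same mean. On a remote object store the existence check and the atomic publish would need backend-specific equivalents.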
|
|
@@ -0,0 +1,102 @@
|
| 1 |
+
"""Modal executor — skeleton for v0.
|
| 2 |
+
|
| 3 |
+
This file is a STUB. The full Modal integration requires the `modal`
|
| 4 |
+
client library installed (`pip install modal`) and a configured Modal
|
| 5 |
+
account (`~/.modal.toml`). The user's environment has both, but the
|
| 6 |
+
test suite must run without them, so we keep this file import-safe.
|
| 7 |
+
|
| 8 |
+
Real implementation lives in v0 polish; the docstring below is the
|
| 9 |
+
contract.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Any, Callable, Mapping
|
| 14 |
+
|
| 15 |
+
from composer_replication.diloco.serverless.executor import (
|
| 16 |
+
ReplicaHandle,
|
| 17 |
+
ServerlessExecutor,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ModalExecutor(ServerlessExecutor):
|
| 22 |
+
"""Run replicas as Modal Functions in parallel.
|
| 23 |
+
|
| 24 |
+
Reference implementation pattern (per ADR-005):
|
| 25 |
+
|
| 26 |
+
@app.function(gpu="A100-40GB", timeout=3600)
|
| 27 |
+
def run_replica(rank: int, rendezvous_uri: str, **kwargs):
|
| 28 |
+
os.environ["REPLICA_RANK"] = str(rank)
|
| 29 |
+
from composer_replication.diloco.serverless import (
|
| 30 |
+
MockManager, ObjectStoreAllReduce,
|
| 31 |
+
)
|
| 32 |
+
store = ObjectStoreAllReduce(rendezvous_uri,
|
| 33 |
+
rank=rank, world_size=N)
|
| 34 |
+
manager = MockManager(store)
|
| 35 |
+
# ... run the trainer with this manager ...
|
| 36 |
+
|
| 37 |
+
Then `launch_replicas` does:
|
| 38 |
+
calls = [run_replica.spawn(rank=i, ...) for i in range(N)]
|
| 39 |
+
return [ReplicaHandle(rank=i, backend_name="modal",
|
| 40 |
+
metadata={"call_id": calls[i].object_id})
|
| 41 |
+
for i in range(N)]
|
| 42 |
+
|
| 43 |
+
Pricing reference (2026-05-26): A100-40GB ≈ $1.95/hr, H100 ≈ $5.50/hr.
|
| 44 |
+
Cold start ≈ 30s. Inter-job networking via cluster mode (opt-in,
|
| 45 |
+
not used by default).
|
| 46 |
+
|
| 47 |
+
Status: SKELETON. Real implementation pending v0 polish wave.
|
| 48 |
+
"""
|
| 49 |
+
backend_name = "modal"
|
| 50 |
+
supports_inter_replica_network = False # default; cluster mode = True
|
| 51 |
+
|
| 52 |
+
def __init__(self, *, app_name: str = "composer-replication-diloco") -> None:
|
| 53 |
+
try:
|
| 54 |
+
import modal # noqa: F401
|
| 55 |
+
except ImportError as e:
|
| 56 |
+
raise RuntimeError(
|
| 57 |
+
"ModalExecutor requires the modal client. Install with "
|
| 58 |
+
"`pip install modal` and configure with `modal token new`. "
|
| 59 |
+
"Got: " + repr(e)
|
| 60 |
+
)
|
| 61 |
+
self.app_name = app_name
|
| 62 |
+
# Real implementation: build a `modal.App` and register `run_replica`
|
| 63 |
+
# here so that subsequent `launch_replicas` can `.spawn()` it.
|
| 64 |
+
raise NotImplementedError(
|
| 65 |
+
"ModalExecutor is a v0 skeleton; full implementation pending. "
|
| 66 |
+
"Use LocalProcessExecutor for testing."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# All Protocol methods raise NotImplementedError via __init__ — the
|
| 70 |
+
# class never instantiates successfully in the skeleton. Sketch
|
| 71 |
+
# signatures here for documentation:
|
| 72 |
+
|
| 73 |
+
def launch_replicas(
|
| 74 |
+
self,
|
| 75 |
+
n_replicas: int,
|
| 76 |
+
entrypoint: str | Callable[..., Any],
|
| 77 |
+
entrypoint_args: Mapping[str, Any],
|
| 78 |
+
*,
|
| 79 |
+
gpu: str | None = "A100-40GB",
|
| 80 |
+
timeout: int = 3600,
|
| 81 |
+
) -> list[ReplicaHandle]:
|
| 82 |
+
raise NotImplementedError
|
| 83 |
+
|
| 84 |
+
def poll(self, handle: ReplicaHandle) -> str:
|
| 85 |
+
raise NotImplementedError
|
| 86 |
+
|
| 87 |
+
def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
|
| 88 |
+
raise NotImplementedError
|
| 89 |
+
|
| 90 |
+
def cancel(self, handle: ReplicaHandle) -> None:
|
| 91 |
+
raise NotImplementedError
|
| 92 |
+
|
| 93 |
+
def collect(
|
| 94 |
+
self,
|
| 95 |
+
handles: list[ReplicaHandle],
|
| 96 |
+
*,
|
| 97 |
+
timeout: int | None = None,
|
| 98 |
+
) -> list[dict[str, Any]]:
|
| 99 |
+
raise NotImplementedError
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
__all__ = ["ModalExecutor"]
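Reviewer sketch (not part of the commit): both skeleton executors keep their modules import-safe by deferring the optional import to `__init__`. One way to express that pattern as a reusable check is shown below. The helper name `require_optional` is hypothetical, and it assumes a top-level module name (a dotted name with a missing parent makes `find_spec` raise instead of returning `None`).

```python
import importlib
import importlib.util


def require_optional(module_name: str, install_hint: str):
    """Import an optional top-level dependency or fail with an actionable message."""
    if importlib.util.find_spec(module_name) is None:
        raise RuntimeError(
            f"{module_name!r} is required for this executor. {install_hint}"
        )
    return importlib.import_module(module_name)
```

A constructor could then call something like `require_optional("modal", "Install with `pip install modal`.")` without the module itself importing `modal` at load time.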
|
|
@@ -0,0 +1,109 @@
|
| 1 |
+
"""Replica entrypoint — what each serverless replica runs.
|
| 2 |
+
|
| 3 |
+
This is the script invoked by `LocalProcessExecutor`, `ModalExecutor`,
|
| 4 |
+
`HFJobsExecutor`, etc. It learns its rank from the `REPLICA_RANK` env
|
| 5 |
+
var, sets up `ObjectStoreAllReduce` against the shared rendezvous URI,
|
| 6 |
+
wraps it in a `MockManager`, and hands it off to the user's training
|
| 7 |
+
function.
|
| 8 |
+
|
| 9 |
+
Usage from an executor:
|
| 10 |
+
|
| 11 |
+
>>> executor.launch_replicas(
|
| 12 |
+
... n_replicas=4,
|
| 13 |
+
... entrypoint="composer_replication.diloco.serverless.replica_entrypoint.main",
|
| 14 |
+
... entrypoint_args={
|
| 15 |
+
... "rendezvous_uri": "/tmp/run42/",
|
| 16 |
+
... "world_size": 4,
|
| 17 |
+
... "trainer_module": "my_project.trainer",
|
| 18 |
+
... "trainer_fn": "train",
|
| 19 |
+
... "trainer_kwargs": {"model_name": "Qwen/Qwen2.5-0.5B"},
|
| 20 |
+
... },
|
| 21 |
+
... )
|
| 22 |
+
|
| 23 |
+
The entrypoint expects:
|
| 24 |
+
- `REPLICA_RANK` env var set to the rank (0..world_size-1)
|
| 25 |
+
- `rendezvous_uri`: fsspec URI for object-store rendezvous
|
| 26 |
+
- `world_size`: total replicas
|
| 27 |
+
- `trainer_module`, `trainer_fn`: importable path to the user's train fn
|
| 28 |
+
- `trainer_kwargs`: dict passed to the user's train fn, plus injected
|
| 29 |
+
`manager` (the `MockManager`), `rank`, and `world_size` kwargs
|
| 30 |
+
"""
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import importlib
|
| 34 |
+
import os
|
| 35 |
+
from typing import Any
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main(
|
| 39 |
+
rendezvous_uri: str,
|
| 40 |
+
world_size: int,
|
| 41 |
+
trainer_module: str,
|
| 42 |
+
trainer_fn: str = "train",
|
| 43 |
+
trainer_kwargs: dict[str, Any] | None = None,
|
| 44 |
+
) -> Any:
|
| 45 |
+
"""Entrypoint executed inside each replica.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
rendezvous_uri: fsspec URI (or local path) for the rendezvous
|
| 49 |
+
world_size: total replicas
|
| 50 |
+
trainer_module: importable Python module containing the user's
|
| 51 |
+
train function
|
| 52 |
+
trainer_fn: name of the function to call (default "train")
|
| 53 |
+
trainer_kwargs: kwargs passed to the train function
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Whatever the train function returns.
|
| 57 |
+
"""
|
| 58 |
+
from composer_replication.diloco.serverless.allreduce import (
|
| 59 |
+
MockManager,
|
| 60 |
+
ObjectStoreAllReduce,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
rank_str = os.environ.get("REPLICA_RANK")
|
| 64 |
+
if rank_str is None:
|
| 65 |
+
raise RuntimeError(
|
| 66 |
+
"REPLICA_RANK env var not set. The serverless executor "
|
| 67 |
+
"should set this for each replica."
|
| 68 |
+
)
|
| 69 |
+
rank = int(rank_str)
|
| 70 |
+
|
| 71 |
+
if not (0 <= rank < world_size):
|
| 72 |
+
raise ValueError(f"REPLICA_RANK={rank} not in [0, {world_size})")
|
| 73 |
+
|
| 74 |
+
store = ObjectStoreAllReduce(
|
| 75 |
+
uri=rendezvous_uri,
|
| 76 |
+
rank=rank,
|
| 77 |
+
world_size=world_size,
|
| 78 |
+
)
|
| 79 |
+
manager = MockManager(store)
|
| 80 |
+
|
| 81 |
+
mod = importlib.import_module(trainer_module)
|
| 82 |
+
fn = getattr(mod, trainer_fn)
|
| 83 |
+
|
| 84 |
+
kwargs = dict(trainer_kwargs or {})
|
| 85 |
+
kwargs["manager"] = manager # injected
|
| 86 |
+
kwargs["rank"] = rank
|
| 87 |
+
kwargs["world_size"] = world_size
|
| 88 |
+
return fn(**kwargs)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
import argparse
|
| 93 |
+
import json
|
| 94 |
+
|
| 95 |
+
parser = argparse.ArgumentParser()
|
| 96 |
+
parser.add_argument("--rendezvous", required=True)
|
| 97 |
+
parser.add_argument("--world-size", type=int, required=True)
|
| 98 |
+
parser.add_argument("--trainer-module", required=True)
|
| 99 |
+
parser.add_argument("--trainer-fn", default="train")
|
| 100 |
+
parser.add_argument("--trainer-kwargs-json", default="{}")
|
| 101 |
+
args = parser.parse_args()
|
| 102 |
+
|
| 103 |
+
main(
|
| 104 |
+
rendezvous_uri=args.rendezvous,
|
| 105 |
+
world_size=args.world_size,
|
| 106 |
+
trainer_module=args.trainer_module,
|
| 107 |
+
trainer_fn=args.trainer_fn,
|
| 108 |
+
trainer_kwargs=json.loads(args.trainer_kwargs_json),
|
| 109 |
+
)
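Reviewer sketch (not part of the commit): `main` above reads the rank from `REPLICA_RANK` and checks it against `world_size`. The same validation can be isolated into a small function that takes the environment as a mapping, which makes it testable without touching `os.environ`. The name `read_rank` and the extra check for non-integer values are assumptions added for illustration.

```python
from typing import Mapping


def read_rank(env: Mapping[str, str], world_size: int,
              rank_env: str = "REPLICA_RANK") -> int:
    """Parse and validate a replica rank from an environment-like mapping."""
    raw = env.get(rank_env)
    if raw is None:
        raise RuntimeError(f"{rank_env} env var not set")
    try:
        rank = int(raw)
    except ValueError:
        raise ValueError(f"{rank_env}={raw!r} is not an integer") from None
    if not 0 <= rank < world_size:
        raise ValueError(f"{rank_env}={rank} not in [0, {world_size})")
    return rank
```

In a replica this would be called as `read_rank(os.environ, world_size)`.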
|
|
File without changes
|
|
@@ -0,0 +1,239 @@
|
| 1 |
+
"""Verifies the serverless DiLoCo allreduce works correctly across local
|
| 2 |
+
multiprocessing replicas using `file://` rendezvous.
|
| 3 |
+
|
| 4 |
+
This is the core multi-process test for the serverless layer. It exercises
|
| 5 |
+
the real allreduce barrier (with concurrent processes), not just the
|
| 6 |
+
single-process API.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import tempfile
|
| 13 |
+
import time
|
| 14 |
+
|
| 15 |
+
import pytest
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from composer_replication.diloco.serverless import (
|
| 19 |
+
LocalProcessExecutor,
|
| 20 |
+
ObjectStoreAllReduce,
|
| 21 |
+
ReplicaHandle,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------
|
| 26 |
+
# Single-process tests of ObjectStoreAllReduce primitives
|
| 27 |
+
# (don't need executor, just the file:// path + local manual orchestration)
|
| 28 |
+
# ---------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_object_store_allreduce_init_validates_rank():
|
| 32 |
+
with tempfile.TemporaryDirectory() as td:
|
| 33 |
+
with pytest.raises(ValueError, match="not in"):
|
| 34 |
+
ObjectStoreAllReduce(td, rank=5, world_size=2)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_object_store_allreduce_local_paths_create_dir():
|
| 38 |
+
"""Local backend should mkdir on init."""
|
| 39 |
+
with tempfile.TemporaryDirectory() as td:
|
| 40 |
+
new_path = os.path.join(td, "subdir", "subsubdir")
|
| 41 |
+
store = ObjectStoreAllReduce(new_path, rank=0, world_size=1)
|
| 42 |
+
assert os.path.isdir(new_path)
|
| 43 |
+
assert store.world_size == 1
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_object_store_allreduce_world_size_1_passthrough():
|
| 47 |
+
"""With world_size=1 it just averages the tensor with itself."""
|
| 48 |
+
with tempfile.TemporaryDirectory() as td:
|
| 49 |
+
store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
|
| 50 |
+
t = torch.tensor([1.0, 2.0, 3.0])
|
| 51 |
+
result = store.allreduce(t.clone())
|
| 52 |
+
torch.testing.assert_close(result, t, atol=1e-6, rtol=1e-6)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_object_store_allreduce_round_id_increments():
|
| 56 |
+
with tempfile.TemporaryDirectory() as td:
|
| 57 |
+
store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
|
| 58 |
+
t = torch.zeros(3)
|
| 59 |
+
assert store.round_id == 0
|
| 60 |
+
store.allreduce(t.clone())
|
| 61 |
+
assert store.round_id == 1
|
| 62 |
+
store.allreduce(t.clone())
|
| 63 |
+
assert store.round_id == 2
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---------------------------------------------------------------------
|
| 67 |
+
# Multi-process tests (the real verification — local executor + spawn)
|
| 68 |
+
# ---------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _replica_compute_and_sync(
|
| 72 |
+
rendezvous_uri: str,
|
| 73 |
+
world_size: int,
|
| 74 |
+
rank_value: float,
|
| 75 |
+
) -> dict:
|
| 76 |
+
"""Top-level function — must be importable for multiprocessing 'spawn'.
|
| 77 |
+
|
| 78 |
+
Each replica creates a tensor whose value is `rank_value * (rank+1)` and
|
| 79 |
+
runs allreduce. The expected result is the mean of all replicas' tensors.
|
| 80 |
+
"""
|
| 81 |
+
rank = int(os.environ["REPLICA_RANK"])
|
| 82 |
+
store = ObjectStoreAllReduce(
|
| 83 |
+
rendezvous_uri, rank=rank, world_size=world_size, timeout_s=120.0,
|
| 84 |
+
)
|
| 85 |
+
# tensor that depends on rank
|
| 86 |
+
t = torch.full((4,), float(rank_value * (rank + 1)))
|
| 87 |
+
pre = t.clone()
|
| 88 |
+
averaged = store.allreduce(t)
|
| 89 |
+
return {
|
| 90 |
+
"rank": rank,
|
| 91 |
+
"pre": pre.tolist(),
|
| 92 |
+
"post": averaged.tolist(),
|
| 93 |
+
"world_size": world_size,
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@pytest.mark.parametrize("n_replicas", [2, 3])
|
| 98 |
+
def test_local_executor_runs_allreduce_across_replicas(n_replicas):
|
| 99 |
+
"""End-to-end: 2-3 replica processes each call allreduce; result is the mean."""
|
| 100 |
+
with tempfile.TemporaryDirectory() as td:
|
| 101 |
+
rendezvous = os.path.join(td, "run")
|
| 102 |
+
executor = LocalProcessExecutor()
|
| 103 |
+
handles = executor.launch_replicas(
|
| 104 |
+
n_replicas=n_replicas,
|
| 105 |
+
entrypoint=f"{__name__}._replica_compute_and_sync",
|
| 106 |
+
entrypoint_args={
|
| 107 |
+
"rendezvous_uri": rendezvous,
|
| 108 |
+
"world_size": n_replicas,
|
| 109 |
+
"rank_value": 10.0,
|
| 110 |
+
"rank_env": "REPLICA_RANK",
|
| 111 |
+
},
|
| 112 |
+
timeout=180,
|
| 113 |
+
)
|
| 114 |
+
assert len(handles) == n_replicas
|
| 115 |
+
for i, h in enumerate(handles):
|
| 116 |
+
assert h.rank == i
|
| 117 |
+
assert h.backend_name == "local_process"
|
| 118 |
+
|
| 119 |
+
results = executor.collect(handles, timeout=180)
|
| 120 |
+
assert len(results) == n_replicas
|
| 121 |
+
|
| 122 |
+
# Verify all succeeded
|
| 123 |
+
for r in results:
|
| 124 |
+
assert r["status"] == "succeeded", \
|
| 125 |
+
f"rank {r['rank']} failed: {r.get('error')}"
|
| 126 |
+
|
| 127 |
+
# Each replica created tensor full(rank_value * (rank+1)).
|
| 128 |
+
# Expected mean = rank_value * (1+2+...+N) / N
|
| 129 |
+
N = n_replicas
|
| 130 |
+
expected_mean = 10.0 * (N * (N + 1) / 2) / N
|
| 131 |
+
|
| 132 |
+
for r in results:
|
| 133 |
+
post = r["result"]["post"]
|
| 134 |
+
for v in post:
|
| 135 |
+
assert abs(v - expected_mean) < 1e-4, \
|
| 136 |
+
f"rank {r['rank']}: expected mean {expected_mean}, got {v}"
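Reviewer sketch (not part of the commit): the expected value in the test above follows from averaging `rank_value * (rank + 1)` over ranks `0..N-1`, which simplifies to `rank_value * (N + 1) / 2`. The block below checks the brute-force mean against that closed form in plain Python. The function name is hypothetical.

```python
def expected_allreduce_mean(rank_value: float, n_replicas: int) -> tuple[float, float]:
    """Return (brute-force mean, closed-form mean) of the per-rank contributions."""
    contributions = [rank_value * (rank + 1) for rank in range(n_replicas)]
    brute_force = sum(contributions) / n_replicas
    # sum(1..N) = N (N + 1) / 2, so the mean is rank_value * (N + 1) / 2.
    closed_form = rank_value * (n_replicas + 1) / 2
    return brute_force, closed_form
```

For the two parametrized cases, `rank_value=10.0` gives a mean of 15.0 with two replicas and 20.0 with three.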
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _replica_two_round_sync(
|
| 140 |
+
rendezvous_uri: str,
|
| 141 |
+
world_size: int,
|
| 142 |
+
) -> dict:
|
| 143 |
+
"""Each replica does TWO consecutive allreduce calls; checks round_id increments."""
|
| 144 |
+
rank = int(os.environ["REPLICA_RANK"])
|
| 145 |
+
store = ObjectStoreAllReduce(
|
| 146 |
+
rendezvous_uri, rank=rank, world_size=world_size, timeout_s=120.0,
|
| 147 |
+
)
|
| 148 |
+
t1 = torch.full((2,), float(rank))
|
| 149 |
+
avg1 = store.allreduce(t1).clone()
|
| 150 |
+
t2 = torch.full((2,), float(rank * 100))
|
| 151 |
+
avg2 = store.allreduce(t2).clone()
|
| 152 |
+
return {
|
| 153 |
+
"rank": rank,
|
| 154 |
+
"round_after_2_calls": store.round_id,
|
| 155 |
+
"avg1": avg1.tolist(),
|
| 156 |
+
"avg2": avg2.tolist(),
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def test_local_executor_handles_multiple_rounds():
|
| 161 |
+
"""Two consecutive rounds each give the right mean; round counter advances."""
|
| 162 |
+
n_replicas = 3
|
| 163 |
+
with tempfile.TemporaryDirectory() as td:
|
| 164 |
+
rendezvous = os.path.join(td, "run-2round")
|
| 165 |
+
executor = LocalProcessExecutor()
|
| 166 |
+
handles = executor.launch_replicas(
|
| 167 |
+
n_replicas=n_replicas,
|
| 168 |
+
entrypoint=f"{__name__}._replica_two_round_sync",
|
| 169 |
+
entrypoint_args={
|
| 170 |
+
"rendezvous_uri": rendezvous,
|
| 171 |
+
"world_size": n_replicas,
|
| 172 |
+
},
|
| 173 |
+
timeout=180,
|
| 174 |
+
)
|
| 175 |
+
results = executor.collect(handles, timeout=180)
|
| 176 |
+
for r in results:
|
| 177 |
+
assert r["status"] == "succeeded", r.get("error")
|
| 178 |
+
assert r["result"]["round_after_2_calls"] == 2
|
| 179 |
+
# mean of 0,1,2 = 1.0
|
| 180 |
+
assert all(abs(v - 1.0) < 1e-4 for v in r["result"]["avg1"])
|
| 181 |
+
# mean of 0,100,200 = 100.0
|
| 182 |
+
assert all(abs(v - 100.0) < 1e-4 for v in r["result"]["avg2"])
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _replica_that_raises(rendezvous_uri: str, world_size: int) -> dict:
|
| 186 |
+
"""Simulates a replica that crashes mid-run."""
|
| 187 |
+
rank = int(os.environ["REPLICA_RANK"])
|
| 188 |
+
if rank == 1:
|
| 189 |
+
raise RuntimeError(f"Simulated crash on rank {rank}")
|
| 190 |
+
return {"rank": rank, "ok": True}
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def test_local_executor_reports_failed_replicas():
|
| 194 |
+
"""When a replica crashes, collect() reports it as failed without hanging
|
| 195 |
+
(other ranks complete; the failed one should be reflected in the result)."""
|
| 196 |
+
n_replicas = 2
|
| 197 |
+
with tempfile.TemporaryDirectory() as td:
|
| 198 |
+
rendezvous = os.path.join(td, "run-failure")
|
| 199 |
+
executor = LocalProcessExecutor()
|
| 200 |
+
handles = executor.launch_replicas(
|
| 201 |
+
n_replicas=n_replicas,
|
| 202 |
+
entrypoint=f"{__name__}._replica_that_raises",
|
| 203 |
+
entrypoint_args={
|
| 204 |
+
"rendezvous_uri": rendezvous,
|
| 205 |
+
"world_size": n_replicas,
|
| 206 |
+
},
|
| 207 |
+
timeout=30,
|
| 208 |
+
)
|
| 209 |
+
results = executor.collect(handles, timeout=30)
|
| 210 |
+
statuses = {r["rank"]: r["status"] for r in results}
|
| 211 |
+
assert statuses[0] == "succeeded"
|
| 212 |
+
assert statuses[1] == "failed"
|
| 213 |
+
# Failure log should mention the simulated crash
|
| 214 |
+
failure_log = next(r for r in results if r["rank"] == 1).get("error") or ""
|
| 215 |
+
assert "Simulated crash" in failure_log
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# ---------------------------------------------------------------------
|
| 219 |
+
# Sanity: MockManager is shape-compatible with torchft Manager surface
|
| 220 |
+
# ---------------------------------------------------------------------
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def test_mock_manager_shape_compat():
|
| 224 |
+
from composer_replication.diloco.serverless import MockManager
|
| 225 |
+
with tempfile.TemporaryDirectory() as td:
|
| 226 |
+
store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
|
| 227 |
+
mgr = MockManager(store)
|
| 228 |
+
# torchft.Manager surface
|
| 229 |
+
assert hasattr(mgr, "allreduce")
|
| 230 |
+
assert hasattr(mgr, "should_commit")
|
| 231 |
+
assert hasattr(mgr, "start_quorum")
|
| 232 |
+
assert hasattr(mgr, "wait_quorum")
|
| 233 |
+
assert mgr.num_participants == 1
|
| 234 |
+
assert mgr.rank == 0
|
| 235 |
+
assert mgr.should_commit() is True
|
| 236 |
+
# Single-replica allreduce is a passthrough
|
| 237 |
+
t = torch.tensor([1.0, 2.0])
|
| 238 |
+
out = mgr.allreduce(t.clone())
|
| 239 |
+
torch.testing.assert_close(out, t, atol=1e-6, rtol=1e-6)
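
The multi-round test above pins an elementwise mean across ranks: {0, 1, 2} → 1 and {0, 100, 200} → 100. The arithmetic can be sketched in a single process with the standard library alone. This is an illustration only: the `round-<id>/rank-<r>.json` layout and the helper names `write_contribution`, `read_mean`, and `run_two_rounds` are hypothetical, and nothing here is taken from the `ObjectStoreAllReduce` implementation, whose internals are not shown in this diff.

```python
import json
import os
import tempfile


def write_contribution(root, round_id, rank, values):
    """Hypothetical layout: one JSON file per rank per round."""
    round_dir = os.path.join(root, f"round-{round_id}")
    os.makedirs(round_dir, exist_ok=True)
    with open(os.path.join(round_dir, f"rank-{rank}.json"), "w") as fh:
        json.dump(values, fh)


def read_mean(root, round_id, world_size):
    """Average the contributions of all ranks for one round, elementwise."""
    round_dir = os.path.join(root, f"round-{round_id}")
    rows = []
    for rank in range(world_size):
        with open(os.path.join(round_dir, f"rank-{rank}.json")) as fh:
            rows.append(json.load(fh))
    return [sum(col) / world_size for col in zip(*rows)]


def run_two_rounds(world_size=3):
    with tempfile.TemporaryDirectory() as root:
        # Round 0: every rank contributes [rank, rank].
        for rank in range(world_size):
            write_contribution(root, 0, rank, [float(rank)] * 2)
        avg1 = read_mean(root, 0, world_size)
        # Round 1: every rank contributes [rank * 100, rank * 100].
        for rank in range(world_size):
            write_contribution(root, 1, rank, [float(rank * 100)] * 2)
        avg2 = read_mean(root, 1, world_size)
    return avg1, avg2


avg1, avg2 = run_two_rounds()
print(avg1, avg2)
```

A real multi-process version would additionally need each rank to wait, with a timeout, until all `world_size` files for the round exist before averaging; the sketch sidesteps that by writing every contribution before reading.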
|
|
@@ -0,0 +1,36 @@
| 1 |
+
"""composer_replication.distillation — pluggable self-distillation losses.
|
| 2 |
+
|
| 3 |
+
Per ADR-007, three losses are added alongside the framework's existing
|
| 4 |
+
SDPO/OPSD (`generalized_jsd_loss`):
|
| 5 |
+
|
| 6 |
+
- SimPO: reference-free DPO replacement (channel 3 alternative)
|
| 7 |
+
- TAID: annealed teacher interpolation (wraps generalized_jsd_loss for channel 2)
|
| 8 |
+
- Entropy-Aware OPD: token-wise gated forward/reverse KL (alternative
|
| 9 |
+
channel-2 wrapper, per ICLR 2026 Spotlight)
|
| 10 |
+
|
| 11 |
+
All three are pure PyTorch — no external deps — so they ship in the core
|
| 12 |
+
package without optional extras.
|
| 13 |
+
|
| 14 |
+
Usage in `compose_loss`:
|
| 15 |
+
|
| 16 |
+
>>> from composer_replication import compose_loss
|
| 17 |
+
>>> components = compose_loss(
|
| 18 |
+
... model, batch,
|
| 19 |
+
... dpo_variant="simpo", # channel 3: DPO -> SimPO
|
| 20 |
+
... sdpo_wrapper="taid", # channel 2: SDPO -> TAID-SDPO
|
| 21 |
+
... taid_schedule_step=1500, taid_total_steps=10_000,
|
| 22 |
+
... )
|
| 23 |
+
|
| 24 |
+
Defaults are unchanged (pure DPO + pure SDPO).
|
| 25 |
+
"""
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
from composer_replication.distillation.simpo import simpo_loss
|
| 29 |
+
from composer_replication.distillation.taid import taid_loss
|
| 30 |
+
from composer_replication.distillation.entropy_aware_opd import entropy_aware_opd_loss
|
| 31 |
+
|
| 32 |
+
__all__ = [
|
| 33 |
+
"simpo_loss",
|
| 34 |
+
"taid_loss",
|
| 35 |
+
"entropy_aware_opd_loss",
|
| 36 |
+
]
|
|
@@ -0,0 +1,126 @@
| 1 |
+
"""Entropy-Aware OPD — token-wise gated forward/reverse KL.
|
| 2 |
+
|
| 3 |
+
Paper: ICLR 2026 Spotlight "Entropy-Aware On-Policy Distillation"
|
| 4 |
+
(OpenReview WSRQ37tzk1, code release pending as of 2026-05-26)
|
| 5 |
+
|
| 6 |
+
Standard reverse-KL distillation (which SDPO/OPSD belongs to) has a known
|
| 7 |
+
mode-seeking failure: when the teacher distribution has high entropy at
|
| 8 |
+
some token positions (e.g. open-ended generation), reverse KL collapses
|
| 9 |
+
the student onto a single mode, throwing away the teacher's diversity.
|
| 10 |
+
|
| 11 |
+
Forward KL is mode-covering and would handle these positions correctly,
|
| 12 |
+
but is mode-flattening in the long tail.
|
| 13 |
+
|
| 14 |
+
Entropy-Aware OPD computes the per-token entropy of the teacher
|
| 15 |
+
distribution and gates between forward and reverse KL on a per-token
|
| 16 |
+
basis: high-entropy tokens use forward KL (preserve diversity),
|
| 17 |
+
low-entropy tokens use reverse KL (sharpen toward the teacher's mode).
|
| 18 |
+
|
| 19 |
+
L = Σ_t w(t) · KL_fwd(teacher || student)_t
|
| 20 |
+
+ (1 - w(t)) · KL_rev(student || teacher)_t
|
| 21 |
+
|
| 22 |
+
Where w(t) = clamp(H_teacher(t) / H_max, 0, 1) — high entropy → forward
|
| 23 |
+
KL weight near 1, low entropy → reverse KL weight near 1.
|
| 24 |
+
|
| 25 |
+
This is a clean-room implementation from the paper's pseudocode pending
|
| 26 |
+
the official code drop. License question for the official code is open;
|
| 27 |
+
this implementation is MIT-compatible by construction.
|
| 28 |
+
"""
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import math
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def teacher_entropy(teacher_logits: torch.Tensor) -> torch.Tensor:
|
| 38 |
+
"""Per-token entropy of the teacher distribution.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
(B, T) entropy in nats.
|
| 42 |
+
"""
|
| 43 |
+
log_p = F.log_softmax(teacher_logits, dim=-1)
|
| 44 |
+
p = log_p.exp()
|
| 45 |
+
# Entropy = -Σ p log p
|
| 46 |
+
return -(p * log_p).sum(dim=-1)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def entropy_aware_opd_loss(
|
| 50 |
+
student_logits: torch.Tensor,
|
| 51 |
+
teacher_logits: torch.Tensor,
|
| 52 |
+
*,
|
| 53 |
+
labels: torch.Tensor | None = None,
|
| 54 |
+
h_max: float | None = None,
|
| 55 |
+
temperature: float = 1.0,
|
| 56 |
+
reduction: str = "batchmean",
|
| 57 |
+
) -> torch.Tensor:
|
| 58 |
+
"""Entropy-aware mixture of forward and reverse KL.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
student_logits: (B, T, V) student logits with grad
|
| 62 |
+
teacher_logits: (B, T, V) teacher logits (no grad)
|
| 63 |
+
labels: (B, T) optional 0/1 mask — only contribute loss on
|
| 64 |
+
labels==1 positions. None means contribute everywhere.
|
| 65 |
+
h_max: maximum-entropy normalizer. Defaults to log(V) (uniform-
|
| 66 |
+
distribution entropy = the max possible entropy at vocab size V).
|
| 67 |
+
temperature: temperature applied to BOTH student and teacher logits
|
| 68 |
+
before softmax
|
| 69 |
+
reduction: "batchmean" | "sum" | "mean" | "none"
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
Scalar loss (or unreduced if `reduction="none"`).
|
| 73 |
+
|
| 74 |
+
Reference: ICLR 2026 Spotlight WSRQ37tzk1 §3 (clean-room implementation).
|
| 75 |
+
"""
|
| 76 |
+
if student_logits.shape != teacher_logits.shape:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"shape mismatch: student={student_logits.shape}, "
|
| 79 |
+
f"teacher={teacher_logits.shape}"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
V = student_logits.size(-1)
|
| 83 |
+
if h_max is None:
|
| 84 |
+
h_max = math.log(V)
|
| 85 |
+
|
| 86 |
+
s_log = F.log_softmax(student_logits / temperature, dim=-1)
|
| 87 |
+
t_log = F.log_softmax(teacher_logits / temperature, dim=-1)
|
| 88 |
+
|
| 89 |
+
s_p = s_log.exp()
|
| 90 |
+
t_p = t_log.exp()
|
| 91 |
+
|
| 92 |
+
# Forward KL (teacher || student): mode-covering
|
| 93 |
+
# KL(t || s) = Σ t · (log t - log s)
|
| 94 |
+
kl_fwd = (t_p * (t_log - s_log)).sum(dim=-1)
|
| 95 |
+
|
| 96 |
+
# Reverse KL (student || teacher): mode-seeking (this is what SDPO uses)
|
| 97 |
+
# KL(s || t) = Σ s · (log s - log t)
|
| 98 |
+
kl_rev = (s_p * (s_log - t_log)).sum(dim=-1)
|
| 99 |
+
|
| 100 |
+
# Per-token teacher entropy → gate weight (computed from the raw teacher logits, not temperature-scaled)
|
| 101 |
+
H_t = teacher_entropy(teacher_logits) # (B, T) in nats
|
| 102 |
+
w = (H_t / h_max).clamp(0.0, 1.0) # (B, T) in [0, 1]
|
| 103 |
+
|
| 104 |
+
# Mix: high entropy → forward KL; low entropy → reverse KL
|
| 105 |
+
per_token_loss = w * kl_fwd + (1 - w) * kl_rev # (B, T)
|
| 106 |
+
|
| 107 |
+
if labels is not None:
|
| 108 |
+
if labels.shape != per_token_loss.shape:
|
| 109 |
+
raise ValueError(
|
| 110 |
+
f"labels shape {labels.shape} != per-token-loss shape "
|
| 111 |
+
f"{per_token_loss.shape}"
|
| 112 |
+
)
|
| 113 |
+
per_token_loss = per_token_loss * labels.float()
|
| 114 |
+
|
| 115 |
+
if reduction == "none":
|
| 116 |
+
return per_token_loss
|
| 117 |
+
if reduction == "sum":
|
| 118 |
+
return per_token_loss.sum()
|
| 119 |
+
if reduction == "mean":
|
| 120 |
+
return per_token_loss.mean()
|
| 121 |
+
if reduction == "batchmean":
|
| 122 |
+
return per_token_loss.sum() / max(1, per_token_loss.shape[0])
|
| 123 |
+
raise ValueError(f"unknown reduction: {reduction!r}")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
__all__ = ["teacher_entropy", "entropy_aware_opd_loss"]
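
The gating rule in this file can be checked by hand on a single token position. The sketch below mirrors the formulas in plain Python lists rather than tensors; the helper names (`softmax`, `entropy`, `kl`, `gated_token_loss`) and the example logits are assumptions made for illustration, and forward KL is taken as KL(teacher || student), as in the code above.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)


def kl(p, q):
    """KL(p || q) = sum of p * (log p - log q), in nats."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0.0)


def gated_token_loss(student_logits, teacher_logits):
    s = softmax(student_logits)
    t = softmax(teacher_logits)
    h_max = math.log(len(t))                    # entropy of the uniform distribution
    w = min(1.0, max(0.0, entropy(t) / h_max))  # gate weight in [0, 1]
    kl_fwd = kl(t, s)                           # teacher || student: mode-covering
    kl_rev = kl(s, t)                           # student || teacher: mode-seeking
    return w, w * kl_fwd + (1.0 - w) * kl_rev


student = [2.0, 0.5, 0.0, -1.0]
# Uniform teacher: maximal entropy, so the gate selects forward KL.
w_flat, loss_flat = gated_token_loss(student, [0.0, 0.0, 0.0, 0.0])
# Sharply peaked teacher: near-zero entropy, so the gate selects reverse KL.
w_peak, loss_peak = gated_token_loss(student, [12.0, 0.0, 0.0, 0.0])
print(w_flat, w_peak)
```

With a uniform teacher the weight is 1 up to rounding, and with the peaked teacher it is close to 0, which is the per-token switch the module docstring describes.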
|
|
@@ -0,0 +1,83 @@
| 1 |
+
"""SimPO loss — reference-free DPO replacement.
|
| 2 |
+
|
| 3 |
+
Paper: "SimPO: Simple Preference Optimization with a Reference-Free Reward"
|
| 4 |
+
Meng et al., NeurIPS 2024 (arXiv:2405.14734)
|
| 5 |
+
License: MIT (https://github.com/princeton-nlp/SimPO)
|
| 6 |
+
|
| 7 |
+
Standard DPO requires log-probabilities under both the policy and a
|
| 8 |
+
reference policy:
|
| 9 |
+
|
| 10 |
+
L_DPO = -log σ( β·[(logπ(c) - logπ_ref(c)) - (logπ(r) - logπ_ref(r))] )
|
| 11 |
+
|
| 12 |
+
SimPO drops the reference-policy term, replaces it with a target margin γ,
|
| 13 |
+
and uses average sequence log-probability instead of sum. This removes the
|
| 14 |
+
reference-model VRAM cost (which is a meaningful fraction of total
|
| 15 |
+
training-time memory).
|
| 16 |
+
|
| 17 |
+
L_SimPO = -log σ( β·[avg_logπ(c) - avg_logπ(r)] - γ )
|
| 18 |
+
|
| 19 |
+
Where:
|
| 20 |
+
- avg_logπ(c) = (1/|c|) · Σ_t logπ(c_t | c_<t, prompt)
|
| 21 |
+
- β: scaling factor (paper default: 2.0)
|
| 22 |
+
- γ: target margin (paper default: 1.0)
|
| 23 |
+
|
| 24 |
+
Compose with the framework: replace channel-3 `_compute_trace_replay_loss`
|
| 25 |
+
when `dpo_variant="simpo"` is passed to `compose_loss`. Inputs change:
|
| 26 |
+
SimPO does NOT consume `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`
|
| 27 |
+
(those become unused).
|
| 28 |
+
"""
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def simpo_loss(
|
| 36 |
+
chosen_avg_logprobs: torch.Tensor,
|
| 37 |
+
rejected_avg_logprobs: torch.Tensor,
|
| 38 |
+
*,
|
| 39 |
+
beta: float = 2.0,
|
| 40 |
+
gamma: float = 1.0,
|
| 41 |
+
) -> torch.Tensor:
|
| 42 |
+
"""SimPO loss — reference-free DPO with target margin.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
chosen_avg_logprobs: (B,) average per-token log-prob of the chosen
|
| 46 |
+
response under the policy. Computed as
|
| 47 |
+
`chosen_logprobs.sum() / response_length`.
|
| 48 |
+
rejected_avg_logprobs: (B,) same for rejected.
|
| 49 |
+
beta: scaling factor (paper default 2.0)
|
| 50 |
+
gamma: target margin (paper default 1.0)
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Scalar loss; lower is better.
|
| 54 |
+
|
| 55 |
+
Reference: arXiv:2405.14734 Eq. (5).
|
| 56 |
+
"""
|
| 57 |
+
if chosen_avg_logprobs.shape != rejected_avg_logprobs.shape:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
f"chosen and rejected avg-logprob tensors must have the same shape, "
|
| 60 |
+
f"got chosen={chosen_avg_logprobs.shape}, "
|
| 61 |
+
f"rejected={rejected_avg_logprobs.shape}"
|
| 62 |
+
)
|
| 63 |
+
logits = beta * (chosen_avg_logprobs - rejected_avg_logprobs) - gamma
|
| 64 |
+
return -F.logsigmoid(logits).mean()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def avg_sequence_logprob(
|
| 68 |
+
model_logprobs: torch.Tensor,
|
| 69 |
+
response_mask: torch.Tensor,
|
| 70 |
+
) -> torch.Tensor:
|
| 71 |
+
"""Helper: convert (B, T) per-token log-probs + (B, T) response mask into
|
| 72 |
+
(B,) per-sequence AVERAGE log-probability over response tokens.
|
| 73 |
+
|
| 74 |
+
SimPO uses the average (not sum) so that long sequences aren't
|
| 75 |
+
penalized for having many tokens. The mask should be 1 on response
|
| 76 |
+
tokens and 0 on prompt+padding.
|
| 77 |
+
"""
|
| 78 |
+
masked = model_logprobs * response_mask.float()
|
| 79 |
+
n_tokens = response_mask.float().sum(dim=-1).clamp_min(1.0)  # cast first so integer masks clamp cleanly
|
| 80 |
+
return masked.sum(dim=-1) / n_tokens
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
__all__ = ["simpo_loss", "avg_sequence_logprob"]
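
To see how β and γ shape the SimPO loss, here is a worked example on single pairs of floats. It is a minimal sketch of the same formula, `-log σ(β·(chosen - rejected) - γ)`, written with the standard library only; `simpo_pair_loss` is a hypothetical helper name and the input values are chosen for illustration.

```python
import math


def simpo_pair_loss(chosen_avg, rejected_avg, beta=2.0, gamma=1.0):
    """-log sigmoid(beta * (chosen_avg - rejected_avg) - gamma) for one pair."""
    x = beta * (chosen_avg - rejected_avg) - gamma
    # -log sigmoid(x) = log(1 + exp(-x)), written in a form that is stable
    # for either sign of x.
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))


small_sep = simpo_pair_loss(0.10, 0.05)   # margin not met: x = -0.9
at_margin = simpo_pair_loss(0.5, 0.0)     # exactly at the margin: x = 0
large_sep = simpo_pair_loss(1.0, -1.0)    # margin comfortably met: x = 3
print(small_sep, at_margin, large_sep)
```

The three losses come out at roughly 1.24, log 2 ≈ 0.693, and 0.05: a pair only drives the loss below log 2 once the scaled log-probability gap exceeds γ.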
|
|
@@ -0,0 +1,195 @@
| 1 |
+
"""TAID loss — Temporally Adaptive Interpolated Distillation.
|
| 2 |
+
|
| 3 |
+
Paper: "TAID: Temporally Adaptive Interpolated Distillation for Efficient
|
| 4 |
+
Knowledge Transfer in Language Models"
|
| 5 |
+
Sakana AI, arXiv:2501.16937
|
| 6 |
+
License: Apache-2.0 (https://github.com/SakanaAI/TAID)
|
| 7 |
+
|
| 8 |
+
Standard JSD/KL distillation on a large student-teacher capacity gap can
|
| 9 |
+
suffer from mode collapse: the student converges to a degenerate point
|
| 10 |
+
distribution that minimizes the KL by ignoring tail probabilities.
|
| 11 |
+
|
| 12 |
+
TAID interpolates between an "identity" target (the student's own
|
| 13 |
+
distribution at step 0) and the teacher's distribution, with the
|
| 14 |
+
interpolation coefficient annealed from 0 → 1 over training:
|
| 15 |
+
|
| 16 |
+
P_target(t) = (1 - α(t)) · P_student_init + α(t) · P_teacher
|
| 17 |
+
|
| 18 |
+
Where α(t) is a schedule (linear, cosine, or paper-default exp ramp).
|
| 19 |
+
|
| 20 |
+
The student then learns against `P_target(t)` using the standard JSD/KL
|
| 21 |
+
loss. As training progresses, the target shifts smoothly from "what you
|
| 22 |
+
already are" toward "what the teacher knows," giving the student a
|
| 23 |
+
smooth path through capacity-gap regions where naive distillation
|
| 24 |
+
collapses.
|
| 25 |
+
|
| 26 |
+
Compose with the framework: TAID *wraps* `generalized_jsd_loss`. The
|
| 27 |
+
wrapper passes a blended target instead of the raw teacher target. When
|
| 28 |
+
`taid_alpha=1.0` we recover pure SDPO (the standard JSD/OPSD path).
|
| 29 |
+
"""
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
import math
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def taid_alpha_schedule(
|
| 39 |
+
step: int,
|
| 40 |
+
total_steps: int,
|
| 41 |
+
*,
|
| 42 |
+
schedule: str = "linear",
|
| 43 |
+
alpha_min: float = 0.0,
|
| 44 |
+
alpha_max: float = 1.0,
|
| 45 |
+
warmup_frac: float = 0.0,
|
| 46 |
+
) -> float:
|
| 47 |
+
"""Compute α(t) for the TAID schedule.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
step: current training step (0-indexed)
|
| 51 |
+
total_steps: total training steps planned
|
| 52 |
+
schedule: "linear" | "cosine" | "exp"
|
| 53 |
+
alpha_min: starting α (default 0 = pure student-init target)
|
| 54 |
+
alpha_max: ending α (default 1 = pure teacher target)
|
| 55 |
+
warmup_frac: fraction of total_steps spent at alpha_min
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
α value in [alpha_min, alpha_max]
|
| 59 |
+
|
| 60 |
+
Reference: arXiv:2501.16937 §3.2.
|
| 61 |
+
"""
|
| 62 |
+
if total_steps <= 0:
|
| 63 |
+
raise ValueError(f"total_steps must be > 0, got {total_steps}")
|
| 64 |
+
if step < 0:
|
| 65 |
+
raise ValueError(f"step must be ≥ 0, got {step}")
|
| 66 |
+
|
| 67 |
+
warmup_steps = int(total_steps * warmup_frac)
|
| 68 |
+
if step < warmup_steps:
|
| 69 |
+
return alpha_min
|
| 70 |
+
|
| 71 |
+
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
|
| 72 |
+
progress = min(1.0, max(0.0, progress))
|
| 73 |
+
|
| 74 |
+
if schedule == "linear":
|
| 75 |
+
alpha = alpha_min + (alpha_max - alpha_min) * progress
|
| 76 |
+
elif schedule == "cosine":
|
| 77 |
+
# 0.5 * (1 - cos(π·t)) goes 0 → 1 as t goes 0 → 1
|
| 78 |
+
alpha = alpha_min + (alpha_max - alpha_min) * 0.5 * (1 - math.cos(math.pi * progress))
|
| 79 |
+
elif schedule == "exp":
|
| 80 |
+
# Paper default: α(t) = α_min + (α_max - α_min) · (1 - exp(-5·t))
|
| 81 |
+
# Front-loads progress toward larger α
|
| 82 |
+
alpha = alpha_min + (alpha_max - alpha_min) * (1 - math.exp(-5 * progress))
|
| 83 |
+
else:
|
| 84 |
+
raise ValueError(f"unknown schedule: {schedule!r}")
|
| 85 |
+
|
| 86 |
+
return float(alpha)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def taid_blended_logits(
|
| 90 |
+
student_init_logits: torch.Tensor,
|
| 91 |
+
teacher_logits: torch.Tensor,
|
| 92 |
+
alpha: float,
|
| 93 |
+
) -> torch.Tensor:
|
| 94 |
+
"""Blend the "student-at-init" and teacher logits in probability space.
|
| 95 |
+
|
| 96 |
+
Returns logits of `(1 - α)·P_student_init + α·P_teacher`.
|
| 97 |
+
Internally:
|
| 98 |
+
1. softmax both → P_student_init, P_teacher (in prob space)
|
| 99 |
+
2. linear interpolate
|
| 100 |
+
3. log → blended logits
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
student_init_logits: (B, T, V) student logits at training start
|
| 104 |
+
(frozen — keep a snapshot from step 0)
|
| 105 |
+
teacher_logits: (B, T, V) teacher logits (e.g., hint-conditioned
|
| 106 |
+
forward pass per SDPO)
|
| 107 |
+
alpha: interpolation coefficient in [0, 1]
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
(B, T, V) logits whose softmax is the blended target distribution.
|
| 111 |
+
"""
|
| 112 |
+
if not (0.0 <= alpha <= 1.0):
|
| 113 |
+
raise ValueError(f"alpha must be in [0, 1], got {alpha}")
|
| 114 |
+
if student_init_logits.shape != teacher_logits.shape:
|
| 115 |
+
raise ValueError(
|
| 116 |
+
f"shape mismatch: student_init={student_init_logits.shape}, "
|
| 117 |
+
f"teacher={teacher_logits.shape}"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Mix in probability space, then log to get logits
|
| 121 |
+
p_student_init = F.softmax(student_init_logits, dim=-1)
|
| 122 |
+
p_teacher = F.softmax(teacher_logits, dim=-1)
|
| 123 |
+
p_blended = (1 - alpha) * p_student_init + alpha * p_teacher
|
| 124 |
+
# Clamp for numerical stability before log
|
| 125 |
+
p_blended = p_blended.clamp_min(1e-12)
|
| 126 |
+
return torch.log(p_blended)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def taid_loss(
|
| 130 |
+
student_logits: torch.Tensor,
|
| 131 |
+
teacher_logits: torch.Tensor,
|
| 132 |
+
student_init_logits: torch.Tensor,
|
| 133 |
+
*,
|
| 134 |
+
schedule_step: int,
|
| 135 |
+
total_steps: int,
|
| 136 |
+
schedule: str = "linear",
|
| 137 |
+
alpha_min: float = 0.0,
|
| 138 |
+
alpha_max: float = 1.0,
|
| 139 |
+
jsd_beta: float = 0.5,
|
| 140 |
+
temperature: float = 1.0,
|
| 141 |
+
reduction: str = "batchmean",
|
| 142 |
+
) -> torch.Tensor:
|
| 143 |
+
"""TAID-wrapped generalized-JSD loss.
|
| 144 |
+
|
| 145 |
+
Wraps the framework's `generalized_jsd_loss` (= SDPO/OPSD) with the
|
| 146 |
+
TAID schedule. At α=0 the loss target is the student's own initial
|
| 147 |
+
distribution (essentially a regularizer); at α=1 it's the standard
|
| 148 |
+
JSD-against-teacher (SDPO).
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
student_logits: (B, T, V) current student logits with grad
|
| 152 |
+
teacher_logits: (B, T, V) teacher logits (no grad — same model
|
| 153 |
+
different context per SDPO, or different model per real
|
| 154 |
+
distillation)
|
| 155 |
+
student_init_logits: (B, T, V) student logits captured at step 0
|
| 156 |
+
of training. Caller must save this and pass it in.
|
| 157 |
+
schedule_step: current training step
|
| 158 |
+
total_steps: total planned training steps
|
| 159 |
+
schedule: "linear" | "cosine" | "exp" — see `taid_alpha_schedule`
|
| 160 |
+
alpha_min, alpha_max: schedule range (defaults 0, 1)
|
| 161 |
+
jsd_beta: β param of generalized_jsd_loss (0=fwd KL, 0.5=JSD,
|
| 162 |
+
1=rev KL)
|
| 163 |
+
temperature: temperature for both student and target
|
| 164 |
+
reduction: "batchmean" | "sum" | "mean" | "none"
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
Scalar loss (or unreduced tensor if `reduction="none"`).
|
| 168 |
+
|
| 169 |
+
Reference: arXiv:2501.16937 Eq. (4) + §3.2.
|
| 170 |
+
"""
|
| 171 |
+
# Lazy-import generalized_jsd_loss to avoid circular import
|
| 172 |
+
from composer_replication.opsd import generalized_jsd_loss
|
| 173 |
+
|
| 174 |
+
alpha = taid_alpha_schedule(
|
| 175 |
+
step=schedule_step,
|
| 176 |
+
total_steps=total_steps,
|
| 177 |
+
schedule=schedule,
|
| 178 |
+
alpha_min=alpha_min,
|
| 179 |
+
alpha_max=alpha_max,
|
| 180 |
+
)
|
| 181 |
+
blended_logits = taid_blended_logits(
|
| 182 |
+
student_init_logits=student_init_logits,
|
| 183 |
+
teacher_logits=teacher_logits,
|
| 184 |
+
alpha=alpha,
|
| 185 |
+
)
|
| 186 |
+
return generalized_jsd_loss(
|
| 187 |
+
student_logits=student_logits,
|
| 188 |
+
teacher_logits=blended_logits,
|
| 189 |
+
beta=jsd_beta,
|
| 190 |
+
temperature=temperature,
|
| 191 |
+
reduction=reduction,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
__all__ = ["taid_alpha_schedule", "taid_blended_logits", "taid_loss"]
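
The three schedules in `taid_alpha_schedule` differ mainly in how fast α moves early in training. The sketch below re-derives them as a function of `progress` in [0, 1] and blends two toy distributions in probability space; `alpha_at` and `blend` are hypothetical stand-ins written for this example and operate on plain floats and lists, not tensors.

```python
import math


def alpha_at(progress, schedule="linear", alpha_min=0.0, alpha_max=1.0):
    """alpha as a function of training progress in [0, 1]."""
    span = alpha_max - alpha_min
    if schedule == "linear":
        return alpha_min + span * progress
    if schedule == "cosine":
        return alpha_min + span * 0.5 * (1.0 - math.cos(math.pi * progress))
    if schedule == "exp":
        return alpha_min + span * (1.0 - math.exp(-5.0 * progress))
    raise ValueError(f"unknown schedule: {schedule!r}")


def blend(p_init, p_teacher, alpha):
    """(1 - alpha) * P_student_init + alpha * P_teacher, elementwise."""
    return [(1.0 - alpha) * a + alpha * b for a, b in zip(p_init, p_teacher)]


# alpha at the start, midpoint, and end of training for each schedule.
table = {
    name: [round(alpha_at(p, name), 4) for p in (0.0, 0.5, 1.0)]
    for name in ("linear", "cosine", "exp")
}
target = blend([0.9, 0.1], [0.2, 0.8], alpha_at(0.5, "linear"))
print(table, target)
```

Two things stand out: the exp schedule is already near 0.92 at the midpoint, and it ends at `1 - exp(-5)` ≈ 0.9933 rather than exactly `alpha_max`. A convex blend of two probability vectors still sums to 1, so the target remains a valid distribution at every α.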
|
|
@@ -0,0 +1,236 @@
| 1 |
+
"""Distillation-loss unit tests — SimPO + TAID + Entropy-Aware OPD."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from composer_replication.distillation import (
|
| 11 |
+
entropy_aware_opd_loss,
|
| 12 |
+
simpo_loss,
|
| 13 |
+
taid_loss,
|
| 14 |
+
)
|
| 15 |
+
from composer_replication.distillation.simpo import avg_sequence_logprob
|
| 16 |
+
from composer_replication.distillation.taid import (
|
| 17 |
+
taid_alpha_schedule,
|
| 18 |
+
taid_blended_logits,
|
| 19 |
+
)
|
| 20 |
+
from composer_replication.distillation.entropy_aware_opd import teacher_entropy
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------
|
| 24 |
+
# SimPO
|
| 25 |
+
# ---------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
def test_simpo_loss_returns_scalar():
|
| 28 |
+
chosen = torch.tensor([0.5, 0.4, 0.3])
|
| 29 |
+
rejected = torch.tensor([0.1, 0.0, -0.2])
|
| 30 |
+
loss = simpo_loss(chosen, rejected, beta=2.0, gamma=1.0)
|
| 31 |
+
assert loss.dim() == 0
|
| 32 |
+
assert torch.isfinite(loss)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_simpo_loss_lower_for_better_separation():
|
| 36 |
+
"""Larger margin between chosen and rejected → lower loss."""
|
| 37 |
+
# Same setup, two batches with different separations
|
| 38 |
+
small_sep_loss = simpo_loss(
|
| 39 |
+
torch.tensor([0.1]), torch.tensor([0.05]),
|
| 40 |
+
)
|
| 41 |
+
large_sep_loss = simpo_loss(
|
| 42 |
+
torch.tensor([1.0]), torch.tensor([-1.0]),
|
| 43 |
+
)
|
| 44 |
+
assert large_sep_loss < small_sep_loss, (
|
| 45 |
+
f"large separation should give smaller loss; "
|
| 46 |
+
f"got small_sep={small_sep_loss}, large_sep={large_sep_loss}"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_simpo_loss_differentiable():
|
| 51 |
+
chosen = torch.tensor([0.5], requires_grad=True)
|
| 52 |
+
rejected = torch.tensor([0.0], requires_grad=True)
|
| 53 |
+
loss = simpo_loss(chosen, rejected)
|
| 54 |
+
loss.backward()
|
| 55 |
+
assert chosen.grad is not None
|
| 56 |
+
assert rejected.grad is not None
|
| 57 |
+
assert torch.isfinite(chosen.grad).all()
|
| 58 |
+
assert torch.isfinite(rejected.grad).all()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_simpo_loss_shape_mismatch_raises():
|
| 62 |
+
with pytest.raises(ValueError, match="same shape"):
|
| 63 |
+
simpo_loss(torch.zeros(3), torch.zeros(5))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_avg_sequence_logprob():
|
| 67 |
+
"""Helper averages over response tokens, ignoring prompt + padding."""
|
| 68 |
+
# B=2, T=4
|
| 69 |
+
logprobs = torch.tensor([
|
| 70 |
+
[-10.0, -10.0, -1.0, -2.0], # response is last 2 tokens, avg=-1.5
|
| 71 |
+
[-1.0, -3.0, -1.0, -10.0], # response is first 3 tokens, avg=-5/3
|
| 72 |
+
])
|
| 73 |
+
mask = torch.tensor([
|
| 74 |
+
[0, 0, 1, 1],
|
| 75 |
+
[1, 1, 1, 0],
|
| 76 |
+
])
|
| 77 |
+
avg = avg_sequence_logprob(logprobs, mask)
|
| 78 |
+
expected = torch.tensor([-1.5, -5.0 / 3.0])
|
| 79 |
+
torch.testing.assert_close(avg, expected, atol=1e-5, rtol=1e-5)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---------------------------------------------------------------------
|
| 83 |
+
# TAID
|
| 84 |
+
# ---------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
def test_taid_alpha_schedule_endpoints():
|
| 87 |
+
"""At step 0 → alpha_min; at step total → alpha_max (the exp schedule ends at 1 - exp(-5))."""
|
| 88 |
+
assert taid_alpha_schedule(0, 100, schedule="linear") == 0.0
|
| 89 |
+
assert taid_alpha_schedule(100, 100, schedule="linear") == 1.0
|
| 90 |
+
assert taid_alpha_schedule(0, 100, schedule="cosine") == 0.0
|
| 91 |
+
assert taid_alpha_schedule(100, 100, schedule="cosine") == pytest.approx(1.0)
|
| 92 |
+
assert taid_alpha_schedule(0, 100, schedule="exp") == pytest.approx(0.0)
|
| 93 |
+
assert taid_alpha_schedule(100, 100, schedule="exp") == pytest.approx(1 - math.exp(-5))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_taid_alpha_schedule_monotonic_linear():
|
| 97 |
+
prev = -1.0
|
| 98 |
+
for step in [0, 10, 25, 50, 75, 90, 100]:
|
| 99 |
+
a = taid_alpha_schedule(step, 100, schedule="linear")
|
| 100 |
+
assert a >= prev
|
| 101 |
+
prev = a
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_taid_alpha_schedule_warmup():
|
| 105 |
+
"""During warmup_frac, alpha stays at alpha_min."""
|
| 106 |
+
a_warmup = taid_alpha_schedule(50, 1000, warmup_frac=0.1, schedule="linear")
|
| 107 |
+
# warmup_steps = 100, step 50 < 100 → still alpha_min
|
| 108 |
+
assert a_warmup == 0.0
|
| 109 |
+
a_post_warmup = taid_alpha_schedule(150, 1000, warmup_frac=0.1, schedule="linear")
|
| 110 |
+
# post-warmup, partial way through remaining 900 steps
|
| 111 |
+
assert a_post_warmup > 0.0
|
| 112 |
+
assert a_post_warmup < 1.0
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_taid_blended_logits_endpoints():
|
| 116 |
+
"""alpha=0 → student_init target; alpha=1 → teacher target."""
|
| 117 |
+
# Use logits with strong peaks to make endpoint behavior obvious
|
| 118 |
+
student_init = torch.zeros(2, 3, 4)
|
| 119 |
+
student_init[0, 0, 0] = 10.0 # peaks at index 0
|
| 120 |
+
teacher = torch.zeros(2, 3, 4)
|
| 121 |
+
teacher[0, 0, 3] = 10.0 # peaks at index 3
|
| 122 |
+
|
| 123 |
+
blended_alpha0 = taid_blended_logits(student_init, teacher, alpha=0.0)
|
| 124 |
+
blended_alpha1 = taid_blended_logits(student_init, teacher, alpha=1.0)
|
| 125 |
+
blended_half = taid_blended_logits(student_init, teacher, alpha=0.5)
|
| 126 |
+
|
| 127 |
+
# alpha=0: argmax follows student_init
|
| 128 |
+
assert blended_alpha0[0, 0].argmax().item() == 0
|
| 129 |
+
# alpha=1: argmax follows teacher
|
| 130 |
+
assert blended_alpha1[0, 0].argmax().item() == 3
|
| 131 |
+
# alpha=0.5: bimodal; both 0 and 3 should be elevated
|
| 132 |
+
half_probs = F.softmax(blended_half[0, 0], dim=-1)
|
| 133 |
+
assert half_probs[0] > 0.4
|
| 134 |
+
assert half_probs[3] > 0.4
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def test_taid_loss_returns_scalar_and_differentiable():
|
| 138 |
+
B, T, V = 2, 4, 8
|
| 139 |
+
student_logits = torch.randn(B, T, V, requires_grad=True)
|
| 140 |
+
teacher_logits = torch.randn(B, T, V)
|
| 141 |
+
student_init = torch.randn(B, T, V)
|
| 142 |
+
loss = taid_loss(
|
| 143 |
+
student_logits, teacher_logits, student_init,
|
| 144 |
+
schedule_step=500, total_steps=1000,
|
| 145 |
+
)
|
| 146 |
+
assert loss.dim() == 0
|
| 147 |
+
assert torch.isfinite(loss)
|
| 148 |
+
loss.backward()
|
| 149 |
+
assert student_logits.grad is not None
|
| 150 |
+
assert torch.isfinite(student_logits.grad).all()
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def test_taid_loss_alpha_zero_ignores_teacher():
|
| 154 |
+
"""At alpha=0, the loss should not depend on the teacher logits."""
|
| 155 |
+
B, T, V = 1, 2, 4
|
| 156 |
+
student_init = torch.randn(B, T, V)
|
| 157 |
+
s1 = torch.randn(B, T, V, requires_grad=True)
|
| 158 |
+
teacher_a = torch.zeros(B, T, V)
|
| 159 |
+
teacher_a[..., 0] = 10.0
|
| 160 |
+
teacher_b = torch.zeros(B, T, V)
|
| 161 |
+
teacher_b[..., 3] = 10.0
|
| 162 |
+
# At step 0 with alpha_min=alpha_max=0, alpha is forced to 0 → blended = student_init
|
| 163 |
+
loss_a = taid_loss(s1, teacher_a, student_init, schedule_step=0, total_steps=100,
|
| 164 |
+
alpha_min=0.0, alpha_max=0.0)
|
| 165 |
+
loss_b = taid_loss(s1, teacher_b, student_init, schedule_step=0, total_steps=100,
|
| 166 |
+
alpha_min=0.0, alpha_max=0.0)
|
| 167 |
+
# Different teachers should give the same loss when alpha is pinned to 0
|
| 168 |
+
assert abs(float(loss_a) - float(loss_b)) < 1e-4
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ---------------------------------------------------------------------
|
| 172 |
+
# Entropy-Aware OPD
|
| 173 |
+
# ---------------------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
def test_teacher_entropy_one_hot_is_zero():
|
| 176 |
+
"""A (near) one-hot distribution has entropy close to 0."""
|
| 177 |
+
logits = torch.zeros(1, 1, 4)
|
| 178 |
+
logits[..., 0] = 100.0 # essentially one-hot
|
| 179 |
+
H = teacher_entropy(logits)
|
| 180 |
+
assert float(H[0, 0]) < 1e-3
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def test_teacher_entropy_uniform_is_log_v():
|
| 184 |
+
"""Uniform distribution over V symbols has entropy = log(V)."""
|
| 185 |
+
logits = torch.zeros(1, 1, 5)
|
| 186 |
+
H = teacher_entropy(logits)
|
| 187 |
+
assert float(H[0, 0]) == pytest.approx(math.log(5), rel=1e-5)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def test_entropy_aware_opd_returns_scalar_and_differentiable():
|
| 191 |
+
B, T, V = 2, 3, 8
|
| 192 |
+
student_logits = torch.randn(B, T, V, requires_grad=True)
|
| 193 |
+
teacher_logits = torch.randn(B, T, V)
|
| 194 |
+
loss = entropy_aware_opd_loss(student_logits, teacher_logits)
|
| 195 |
+
assert loss.dim() == 0
|
| 196 |
+
assert torch.isfinite(loss)
|
| 197 |
+
loss.backward()
|
| 198 |
+
assert student_logits.grad is not None
|
| 199 |
+
assert torch.isfinite(student_logits.grad).all()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def test_entropy_aware_opd_with_label_mask():
|
| 203 |
+
"""Label mask should zero out per-token loss on labels==0 positions."""
|
| 204 |
+
B, T, V = 1, 4, 6
|
| 205 |
+
student_logits = torch.randn(B, T, V, requires_grad=True)
|
| 206 |
+
teacher_logits = torch.randn(B, T, V)
|
| 207 |
+
full_loss = entropy_aware_opd_loss(student_logits, teacher_logits)
|
| 208 |
+
half_mask = torch.tensor([[1, 1, 0, 0]])
|
| 209 |
+
half_loss = entropy_aware_opd_loss(
|
| 210 |
+
student_logits, teacher_logits, labels=half_mask,
|
| 211 |
+
)
|
| 212 |
+
# half_loss should be ~half of the unmasked sum (modulo the entropy gating
|
| 213 |
+
# being position-dependent — but it should at least be < full_loss)
|
| 214 |
+
assert float(half_loss) < float(full_loss)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def test_entropy_aware_opd_zero_when_distributions_match():
|
| 218 |
+
"""When student and teacher are identical, both KLs are 0 → loss is 0."""
|
| 219 |
+
logits = torch.randn(1, 2, 4)
|
| 220 |
+
loss = entropy_aware_opd_loss(logits, logits)
|
| 221 |
+
assert float(loss) < 1e-5
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def test_entropy_aware_opd_reduction_modes():
|
| 225 |
+
student_logits = torch.randn(2, 3, 4, requires_grad=True)
|
| 226 |
+
teacher_logits = torch.randn(2, 3, 4)
|
| 227 |
+
none_loss = entropy_aware_opd_loss(student_logits, teacher_logits, reduction="none")
|
| 228 |
+
mean_loss = entropy_aware_opd_loss(student_logits, teacher_logits, reduction="mean")
|
| 229 |
+
sum_loss = entropy_aware_opd_loss(student_logits, teacher_logits, reduction="sum")
|
| 230 |
+
batchmean_loss = entropy_aware_opd_loss(student_logits, teacher_logits, reduction="batchmean")
|
| 231 |
+
assert none_loss.shape == (2, 3)
|
| 232 |
+
assert mean_loss.dim() == 0
|
| 233 |
+
assert sum_loss.dim() == 0
|
| 234 |
+
assert batchmean_loss.dim() == 0
|
| 235 |
+
# batchmean = sum / batch_size
|
| 236 |
+
assert abs(float(batchmean_loss) - float(sum_loss) / 2) < 1e-4
|
|
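The schedule and entropy tests above pin down behavior rather than an implementation. The following is a minimal pure-Python sketch that is consistent with those assertions, assuming a linear ramp after warmup. The names `linear_alpha` and `softmax_entropy` are illustrative stand-ins, not the package's `taid_alpha_schedule` or `teacher_entropy`.

```python
import math


def linear_alpha(step, total_steps, warmup_frac=0.1, alpha_min=0.0, alpha_max=1.0):
    """Hold alpha_min during warmup, then ramp linearly to alpha_max (assumed form)."""
    warmup_steps = int(total_steps * warmup_frac)
    if step < warmup_steps:
        return alpha_min
    remaining = max(total_steps - warmup_steps, 1)
    frac = min((step - warmup_steps) / remaining, 1.0)
    return alpha_min + (alpha_max - alpha_min) * frac


def softmax_entropy(logits):
    """Shannon entropy (in nats) of softmax(logits), computed stably."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0.0)
```

With `total_steps=1000` and `warmup_frac=0.1`, step 50 stays at `alpha_min` and step 150 lands strictly between the two endpoints, which is all the tests above require.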
@@ -0,0 +1,90 @@
| 1 |
+
"""Monarch actor skeletons — DESIGN/SKELETON for v0.
|
| 2 |
+
|
| 3 |
+
Per ADR-006, full Monarch integration is deferred to v0.2+. This file
|
| 4 |
+
documents the actor signatures so the framework's recipe matrix is
|
| 5 |
+
complete.
|
| 6 |
+
|
| 7 |
+
Importing this module does NOT require monarch to be installed: it contains
|
| 8 |
+
no monarch imports (the real ones appear only in docstrings). Instantiating
|
| 9 |
+
any actor raises NotImplementedError, which is the desired behavior for a recipe document.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TrainerActor:
|
| 17 |
+
"""Hosts the framework's 3-channel composer trainer.
|
| 18 |
+
|
| 19 |
+
Real implementation (v0.2+):
|
| 20 |
+
|
| 21 |
+
from monarch import Actor, endpoint
|
| 22 |
+
|
| 23 |
+
class TrainerActor(Actor):
|
| 24 |
+
@endpoint
|
| 25 |
+
async def train_outer_step(self, batch_id: int) -> dict:
|
| 26 |
+
# 1. Pull batch from generator
|
| 27 |
+
# 2. Run inner H steps with composer compose_loss
|
| 28 |
+
# 3. Compute pseudo-gradient
|
| 29 |
+
# 4. Hand to ObjectStoreAllReduce manager
|
| 30 |
+
# 5. Apply outer SGD step
|
| 31 |
+
# 6. Return metrics dict
|
| 32 |
+
...
|
| 33 |
+
|
| 34 |
+
For v0 the actor is just a documentation stub.
|
| 35 |
+
"""
|
| 36 |
+
backend = "monarch"
|
| 37 |
+
role = "trainer"
|
| 38 |
+
|
| 39 |
+
def __init__(self) -> None:
|
| 40 |
+
raise NotImplementedError(
|
| 41 |
+
"Monarch trainer actor is a v0 skeleton; implementation "
|
| 42 |
+
"deferred to v0.2 per ADR-006."
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
async def train_outer_step(self, batch_id: int) -> dict[str, Any]:
|
| 46 |
+
raise NotImplementedError
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class GeneratorActor:
|
| 50 |
+
"""vllm-backed rollout actor."""
|
| 51 |
+
backend = "monarch"
|
| 52 |
+
role = "generator"
|
| 53 |
+
|
| 54 |
+
def __init__(self) -> None:
|
| 55 |
+
raise NotImplementedError("v0 skeleton — see ADR-006.")
|
| 56 |
+
|
| 57 |
+
async def rollout(self, prompts: list[str]) -> list[str]:
|
| 58 |
+
raise NotImplementedError
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class RewarderActor:
|
| 62 |
+
"""verifiers-protocol rewarder for RLVR-style RL."""
|
| 63 |
+
backend = "monarch"
|
| 64 |
+
role = "rewarder"
|
| 65 |
+
|
| 66 |
+
def __init__(self) -> None:
|
| 67 |
+
raise NotImplementedError("v0 skeleton — see ADR-006.")
|
| 68 |
+
|
| 69 |
+
async def score(self, completions: list[str]) -> list[float]:
|
| 70 |
+
raise NotImplementedError
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class TeacherPoolActor:
|
| 74 |
+
"""Channel-3 teacher pool — wraps composer_replication.teacher_replay."""
|
| 75 |
+
backend = "monarch"
|
| 76 |
+
role = "teacher_pool"
|
| 77 |
+
|
| 78 |
+
def __init__(self) -> None:
|
| 79 |
+
raise NotImplementedError("v0 skeleton — see ADR-006.")
|
| 80 |
+
|
| 81 |
+
async def replay(self, states: list[dict]) -> list[dict]:
|
| 82 |
+
raise NotImplementedError
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
__all__ = [
|
| 86 |
+
"GeneratorActor",
|
| 87 |
+
"RewarderActor",
|
| 88 |
+
"TeacherPoolActor",
|
| 89 |
+
"TrainerActor",
|
| 90 |
+
]
|
|
@@ -0,0 +1,121 @@
| 1 |
+
# Monarch actor mesh — design for hosting the framework's training topology
|
| 2 |
+
|
| 3 |
+
**Status**: Design + skeleton. Real Monarch integration is post-replication
|
| 4 |
+
work (ADR-006 explicitly defers it to v0.2+).
|
| 5 |
+
|
| 6 |
+
**ADR**: 006
|
| 7 |
+
|
| 8 |
+
## What Monarch is
|
| 9 |
+
|
| 10 |
+
Monarch (https://github.com/meta-pytorch/monarch, BSD-3) is Meta's actor-
|
| 11 |
+
mesh runtime — a thin coordination layer over Python processes that lets
|
| 12 |
+
you describe a training topology as a graph of typed actors, then run
|
| 13 |
+
that topology on top of any cluster manager (k8s, Slurm, raw ssh).
|
| 14 |
+
|
| 15 |
+
Per ADR-006, Monarch is the only Meta PyTorch agentic-stack component
|
| 16 |
+
that's actively shipping (v0.4.1 stable, v0.5 dev daily) and not paused.
|
| 17 |
+
TorchForge, the original "agent" piece, is paused per its own repo banner.
|
| 18 |
+
|
| 19 |
+
## Why Monarch fits the framework's design
|
| 20 |
+
|
| 21 |
+
The framework already has an N-actor topology even without Monarch:
|
| 22 |
+
- Trainer (channel 1: GRPO; channel 2: SDPO; channel 3: trace-replay DPO)
|
| 23 |
+
- Generator (rollout / vllm)
|
| 24 |
+
- Rewarder (RLVR test runner / verifiers protocol)
|
| 25 |
+
- N teachers (channel 3: external OpenRouter calls)
|
| 26 |
+
- DiLoCo replicas (N copies of trainer, syncing via object store)
|
| 27 |
+
|
| 28 |
+
PRIME-RL gives us the trainer/generator/rewarder split for free. Monarch
|
| 29 |
+
takes that further: each of those becomes a Monarch actor, and the framework
|
| 30 |
+
gains:
|
| 31 |
+
1. **Heterogeneous executor support** — actors run wherever Monarch's
|
| 32 |
+
backend places them (Modal, k8s, on-prem cluster). Composes naturally
|
| 33 |
+
with our `ServerlessExecutor` Protocol.
|
| 34 |
+
2. **Failure recovery** — Monarch handles actor crashes + restarts;
|
| 35 |
+
the framework's DiLoCo state is durable in object storage, so a
|
| 36 |
+
restarted trainer replica can resume from the last outer round.
|
| 37 |
+
3. **Hot-swap of actor implementations** — switch teacher backends
|
| 38 |
+
from "OpenRouter" to "local vllm" by changing one Monarch actor
|
| 39 |
+
binding.
|
| 40 |
+
|
| 41 |
+
## Actor topology (proposed)
|
| 42 |
+
|
| 43 |
+
```
|
| 44 |
+
┌────────────────────────────────────────────────────────────────┐
│                    ComposerReplicationMesh                     │
│                                                                │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────────┐      │
│  │ Trainer × N  │←─│  Generator   │←─│     Rewarder     │      │
│  │ (DiLoCo      │  │  (vllm)      │  │   (verifiers)    │      │
│  │  replicas)   │  └──────────────┘  └──────────────────┘      │
│  └──────┬───────┘                                              │
│         │                                                      │
│         │ Channel 2: same-model hint-conditioned forward       │
│         │ Channel 3: cross-model OpenRouter teachers           │
│         ▼                                                      │
│  ┌──────────────┐                                              │
│  │ TeacherPool  │ ── OpenRouter (Claude, GPT, DeepSeek, ...)   │
│  │ (channel 3)  │                                              │
│  └──────────────┘                                              │
│                                                                │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ ObjectStore (s3://, hf://, file://)                      │  │
│  │ · DiLoCo pseudo-gradients (round_N/rank_R.pt)            │  │
│  │ · Replay datasets (NormalizedDPOPair JSONL)              │  │
│  └──────────────────────────────────────────────────────────┘  │
└────────────────────────────────────────────────────────────────┘
|
| 67 |
+
```
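The `round_N/rank_R.pt` layout in the diagram implies a simple rendezvous: each replica publishes its pseudo-gradient under the current round, and the average is available once every rank has written. The sketch below illustrates that idea on a local directory with JSON files instead of `.pt` tensors. The helper names `publish` and `try_mean` are hypothetical and are not the API of `ObjectStoreAllReduce`.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path


def publish(root: Path, round_idx: int, rank: int, pseudo_grad: list[float]) -> None:
    """Write this rank's pseudo-gradient for the given round."""
    round_dir = root / f"round_{round_idx}"
    round_dir.mkdir(parents=True, exist_ok=True)
    tmp = round_dir / f"rank_{rank}.json.tmp"
    tmp.write_text(json.dumps(pseudo_grad))
    # Rename so readers never observe a half-written file (local filesystems).
    tmp.rename(round_dir / f"rank_{rank}.json")


def try_mean(root: Path, round_idx: int, world_size: int) -> list[float] | None:
    """Return the element-wise mean once all ranks have published, else None."""
    round_dir = root / f"round_{round_idx}"
    paths = [round_dir / f"rank_{r}.json" for r in range(world_size)]
    if not all(p.exists() for p in paths):
        return None  # barrier not reached yet; a caller would poll again
    grads = [json.loads(p.read_text()) for p in paths]
    return [sum(col) / world_size for col in zip(*grads)]


with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    publish(root, 0, 0, [0.0])
    publish(root, 0, 1, [1.0])
    incomplete = try_mean(root, 0, world_size=3)  # rank 2 has not written yet
    publish(root, 0, 2, [2.0])
    complete = try_mean(root, 0, world_size=3)
```

Because each round lives under its own prefix, a restarted replica only needs to know the last completed round index to rejoin; no in-memory state has to survive the crash.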
|
| 68 |
+
|
| 69 |
+
## Mapping to Monarch primitives
|
| 70 |
+
|
| 71 |
+
```python
|
| 72 |
+
from monarch import Actor, mesh, endpoint
|
| 73 |
+
|
| 74 |
+
class TrainerActor(Actor):
|
| 75 |
+
"""Hosts the GRPO trainer + composer 3-channel loss."""
|
| 76 |
+
@endpoint
|
| 77 |
+
async def train_outer_step(self, batch_id: int): ...
|
| 78 |
+
|
| 79 |
+
class GeneratorActor(Actor):
|
| 80 |
+
"""vllm rollout server — generates trajectories on demand."""
|
| 81 |
+
@endpoint
|
| 82 |
+
async def rollout(self, prompts: list[str]) -> list[str]: ...
|
| 83 |
+
|
| 84 |
+
class RewarderActor(Actor):
|
| 85 |
+
"""Runs verifiers protocol — RLVR-style test execution."""
|
| 86 |
+
@endpoint
|
| 87 |
+
async def score(self, completions: list[str]) -> list[float]: ...
|
| 88 |
+
|
| 89 |
+
class TeacherPoolActor(Actor):
|
| 90 |
+
"""Channel 3 — OpenRouter calls to N external teachers."""
|
| 91 |
+
@endpoint
|
| 92 |
+
async def replay(self, states: list[dict]) -> list[dict]: ...
|
| 93 |
+
|
| 94 |
+
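# Topology (proposed; `mesh.spawn` and its keyword arguments are illustrative, not verified against the Monarch API)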
|
| 95 |
+
trainers = mesh.spawn(TrainerActor, n=4, gpu="A100")
|
| 96 |
+
generator = mesh.spawn(GeneratorActor, n=1, gpu="A100")
|
| 97 |
+
rewarder = mesh.spawn(RewarderActor, n=1, gpu=None)
|
| 98 |
+
teachers = mesh.spawn(TeacherPoolActor, n=1, gpu=None)
|
| 99 |
+
```
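The call pattern between these actors can be tried out without Monarch at all. The sketch below uses plain `asyncio` classes with the same method names as the proposed endpoints; the class bodies, the placeholder rewards, and the `main` driver are assumptions made for illustration, and nothing here uses or imitates the Monarch runtime.

```python
import asyncio


class Generator:
    async def rollout(self, prompts: list[str]) -> list[str]:
        return [p + " -> answer" for p in prompts]  # placeholder completion


class Rewarder:
    async def score(self, completions: list[str]) -> list[float]:
        return [float(len(c)) for c in completions]  # placeholder reward


class Trainer:
    def __init__(self, generator: Generator, rewarder: Rewarder) -> None:
        self.generator = generator
        self.rewarder = rewarder

    async def train_outer_step(self, batch_id: int) -> dict:
        prompts = [f"prompt-{batch_id}-{i}" for i in range(2)]
        completions = await self.generator.rollout(prompts)
        rewards = await self.rewarder.score(completions)
        return {"batch_id": batch_id, "mean_reward": sum(rewards) / len(rewards)}


async def main() -> list[dict]:
    # Four trainer replicas, mirroring n=4 in the proposed topology.
    trainers = [Trainer(Generator(), Rewarder()) for _ in range(4)]
    return await asyncio.gather(*(t.train_outer_step(i) for i, t in enumerate(trainers)))


metrics = asyncio.run(main())
```

Swapping `Generator` or `Rewarder` for another object with the same async method is the in-process analogue of the "hot-swap" benefit listed earlier.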
|
| 100 |
+
|
| 101 |
+
## Status of this directory
|
| 102 |
+
|
| 103 |
+
- `monarch_actor_layout.md` — this file (design)
|
| 104 |
+
- `actors.py` — skeleton actor definitions; importable without monarch
|
| 105 |
+
installed, but instantiating any actor raises `NotImplementedError`
|
| 106 |
+
- `composer_mesh.py` — composition glue; not yet implemented
|
| 107 |
+
|
| 108 |
+
## Open questions (deferred to v0.2)
|
| 109 |
+
|
| 110 |
+
- Does Monarch v0.5's Slurm backend hand-shake cleanly with HF Jobs?
|
| 111 |
+
(HF Jobs runs each "job" as an independent container; Monarch wants
|
| 112 |
+
to manage the lifecycle. Possible mismatch.)
|
| 113 |
+
- Can the `TrainerActor` host the framework's `ComposerReplicationTrainer`
|
| 114 |
+
unmodified, or does it need to be split into `step_init` /
|
| 115 |
+
`step_compute` endpoints to fit Monarch's async actor model?
|
| 116 |
+
|
| 117 |
+
## References
|
| 118 |
+
|
| 119 |
+
- Monarch repo: https://github.com/meta-pytorch/monarch
|
| 120 |
+
- ADR-006: docs/adrs/ADR-006-rl-frameworks.md
|
| 121 |
+
- Reconnaissance: docs/research/RL_FRAMEWORKS_LANDSCAPE.md § Monarch
|
|
@@ -0,0 +1,111 @@
| 1 |
+
"""PRIME-RL composer loss adapter — SKELETON for v0.
|
| 2 |
+
|
| 3 |
+
Per ADR-006, PRIME-RL exposes a `CustomLossConfig` that takes an
|
| 4 |
+
importable function. This module supplies that function: a thin adapter
|
| 5 |
+
that maps PRIME-RL's `LossInputs` struct onto the framework's 3-channel
|
| 6 |
+
loss composition.
|
| 7 |
+
|
| 8 |
+
Status: SKELETON. The full implementation requires a runtime spike with
|
| 9 |
+
prime-rl installed; this file documents the contract and provides a
|
| 10 |
+
working stub that returns a finite scalar so PRIME-RL can be configured
|
| 11 |
+
end-to-end without yet having all three channels wired up.
|
| 12 |
+
|
| 13 |
+
Reference:
|
| 14 |
+
- PRIME-RL `LossInputs` shape (verified via DeepWiki audit, Wave 13):
|
| 15 |
+
- trainer_logprobs: Tensor (B, T) — student log-probs of generated tokens
|
| 16 |
+
- inference_logprobs: Tensor (B, T) — log-probs from inference engine
|
| 17 |
+
- teacher_logprobs: Tensor (B, T) | None — optional teacher channel
|
| 18 |
+
- advantages: Tensor (B, T) — GRPO advantages
|
| 19 |
+
- loss_mask: Tensor (B, T) — response-token mask
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def loss_fn(
|
| 27 |
+
inputs: Any, # PRIME-RL's LossInputs — typed as Any to avoid hard import
|
| 28 |
+
*,
|
| 29 |
+
alpha_sdpo: float = 0.5,
|
| 30 |
+
beta_dpo: float = 0.3,
|
| 31 |
+
epsilon: float = 1e-6,
|
| 32 |
+
) -> Any: # Returns a torch.Tensor (scalar)
|
| 33 |
+
"""Composer 3-channel loss adapted to PRIME-RL's LossInputs struct.
|
| 34 |
+
|
| 35 |
+
Channels (per `composer_replication.compose_loss`):
|
| 36 |
+
1. GRPO policy-gradient: -(advantages * trainer_logprobs * mask).sum() / mask.sum()
|
| 37 |
+
2. SDPO / OPSD: generalized_jsd_loss(student_logits, teacher_logits)
|
| 38 |
+
3. Trace-replay DPO: standard DPO on (chosen, rejected) pairs
|
| 39 |
+
|
| 40 |
+
For PRIME-RL adaptation:
|
| 41 |
+
- Channel 1 reads from `advantages` + `trainer_logprobs` directly.
|
| 42 |
+
(Note: this is REINFORCE-with-advantage, not full GRPO. Full
|
| 43 |
+
GRPO would use `inference_logprobs` for the importance-sampling
|
| 44 |
+
ratio + PPO clipping. See Wave 13 review Finding 6.)
|
| 45 |
+
- Channel 2 (SDPO) is **DEFERRED** for v0 because PRIME-RL v0.5
|
| 46 |
+
exposes log-probs not logits, and SDPO needs the full vocab
|
| 47 |
+
distribution. Setting alpha_sdpo>0 raises NotImplementedError
|
| 48 |
+
(Wave 13 review Finding 1 — earlier draft was silently degenerate).
|
| 49 |
+
- Channel 3 (DPO) is OUT OF SCOPE for the PRIME-RL recipe in v0
|
| 50 |
+
— it would require modifying PRIME-RL's data path to pass
|
| 51 |
+
`(chosen, rejected)` pairs alongside the rollout, which is a
|
| 52 |
+
separate integration effort. v0 ignores beta_dpo and emits a
|
| 53 |
+
warning if it is non-zero.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
inputs: PRIME-RL `LossInputs` (duck-typed)
|
| 57 |
+
alpha_sdpo: weight on channel 2 (SDPO)
|
| 58 |
+
beta_dpo: weight on channel 3 (DPO); ignored in v0 (non-zero values emit a warning)
|
| 59 |
+
epsilon: numerical stability for log/division
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Scalar torch.Tensor; PRIME-RL's trainer takes care of `.backward()`.
|
| 63 |
+
"""
|
| 64 |
+
import torch # lazy
|
| 65 |
+
from composer_replication.opsd import generalized_jsd_loss
|
| 66 |
+
|
| 67 |
+
# Channel 1: GRPO
|
| 68 |
+
advantages = inputs.advantages
|
| 69 |
+
trainer_lp = inputs.trainer_logprobs
|
| 70 |
+
mask = inputs.loss_mask
|
| 71 |
+
if mask.dtype != advantages.dtype:
|
| 72 |
+
mask = mask.to(advantages.dtype)
|
| 73 |
+
grpo_loss = -(advantages * trainer_lp * mask).sum() / mask.sum().clamp_min(epsilon)
|
| 74 |
+
|
| 75 |
+
total = grpo_loss
|
| 76 |
+
|
| 77 |
+
# Channel 2: SDPO/OPSD — DEFERRED in PRIME-RL recipe v0.
|
| 78 |
+
#
|
| 79 |
+
# Wave 13 cross-model review (docs/research/WAVE_13_FINAL_REVIEW.md
|
| 80 |
+
# Finding 1) caught that an earlier draft of this code applied
|
| 81 |
+
# `unsqueeze(-1)` to (B, T) log-prob tensors before passing them to
|
| 82 |
+
# generalized_jsd_loss, which calls log_softmax(dim=-1). Softmax of a
|
| 83 |
+
# 1-element vector is exactly 1.0; its log is 0. So the SDPO term was
|
| 84 |
+
# mathematically degenerate (always 0), silently disabling channel 2
|
| 85 |
+
# while reporting alpha_sdpo>0 in the config.
|
| 86 |
+
#
|
| 87 |
+
# The right path forward depends on PRIME-RL exposing full logits, not
|
| 88 |
+
# just log-probs. Until that lands upstream, refuse to fake the channel:
|
| 89 |
+
teacher_lp = getattr(inputs, "teacher_logprobs", None)
|
| 90 |
+
if teacher_lp is not None and alpha_sdpo > 0:
|
| 91 |
+
raise NotImplementedError(
|
| 92 |
+
"SDPO channel in the PRIME-RL recipe is deferred. PRIME-RL v0.5 "
|
| 93 |
+
"exposes (B, T) log-probs through LossInputs but not full logits, "
|
| 94 |
+
"and SDPO/OPSD requires the full distribution over vocabulary. "
|
| 95 |
+
"Set alpha_sdpo=0.0 to silence this and use channel 1 (GRPO) only. "
|
| 96 |
+
"See docs/research/WAVE_13_FINAL_REVIEW.md Finding 1."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Channel 3: not supported in PRIME-RL recipe v0
|
| 100 |
+
if beta_dpo != 0.0:
|
| 101 |
+
import warnings
|
| 102 |
+
warnings.warn(
|
| 103 |
+
"PRIME-RL recipe v0 does not support DPO channel; "
|
| 104 |
+
"set beta_dpo=0.0 to silence this warning.",
|
| 105 |
+
stacklevel=2,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
return total
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
__all__ = ["loss_fn"]
|
|
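The degeneracy described in Finding 1 can be checked by hand: a log-softmax over an axis of length 1 is always 0, whatever the input value, so any divergence computed on such tensors vanishes. The pure-Python sketch below shows this alongside a case with a real vocabulary axis. The `log_softmax` helper is a stand-in written for this example, not the torch function.

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


# A (B, T) log-prob entry viewed as a length-1 "vocabulary" after unsqueeze(-1):
student = log_softmax([-2.3])
teacher = log_softmax([-0.1])

# With a real vocabulary axis the two distributions differ as expected.
student_full = log_softmax([2.0, 0.0, -1.0])
teacher_full = log_softmax([0.0, 2.0, -1.0])
kl_full = sum(math.exp(t) * (t - s) for t, s in zip(teacher_full, student_full))
```

Both `student` and `teacher` collapse to `[0.0]`, so their KL (and any JSD built from it) is zero regardless of how far apart the original log-probs were, while `kl_full` is strictly positive.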
@@ -0,0 +1,66 @@
| 1 |
+
# PRIME-RL config wiring the framework's 3-channel composer loss.
|
| 2 |
+
#
|
| 3 |
+
# Status: SKELETON. Field names approximate PRIME-RL's v0.5 config schema;
|
| 4 |
+
# verify against the installed version before launching a real run.
|
| 5 |
+
# Reference: docs/research/RL_FRAMEWORKS_LANDSCAPE.md § PRIME-RL.
|
| 6 |
+
|
| 7 |
+
# --- Model ------------------------------------------------------------
|
| 8 |
+
model:
|
| 9 |
+
base: "Qwen/Qwen2.5-0.5B"
|
| 10 |
+
attn_implementation: "flash_attention_2"
|
| 11 |
+
dtype: "bfloat16"
|
| 12 |
+
|
| 13 |
+
# --- Training environment (verifiers / OpenEnv compatible) -----------
|
| 14 |
+
env:
|
| 15 |
+
protocol: "verifiers"
|
| 16 |
+
config:
|
| 17 |
+
# Point at any verifiers-protocol task (math, code, etc.)
|
| 18 |
+
name: "math/gsm8k"
|
| 19 |
+
split: "train"
|
| 20 |
+
|
| 21 |
+
# --- Custom loss (the framework's contribution) -----------------------
|
| 22 |
+
loss:
|
| 23 |
+
custom:
|
| 24 |
+
# PRIME-RL imports this and calls loss_fn(inputs, **kwargs) at each step.
|
| 25 |
+
# The function MUST return a scalar tensor (PRIME-RL handles backward).
|
| 26 |
+
import_path: "composer_replication.recipes.prime_rl.composer_loss:loss_fn"
|
| 27 |
+
kwargs:
|
| 28 |
+
alpha_sdpo: 0.0 # SDPO channel deferred in v0; > 0 raises NotImplementedError when teacher_logprobs is set
|
| 29 |
+
beta_dpo: 0.0 # DPO channel out-of-scope for PRIME-RL recipe v0
|
| 30 |
+
epsilon: 1.0e-6
|
| 31 |
+
|
| 32 |
+
# --- PRIME-RL three-actor split --------------------------------------
|
| 33 |
+
trainer:
|
| 34 |
+
optimizer: "muon"
|
| 35 |
+
learning_rate: 1.0e-5
|
| 36 |
+
inner_steps: 500 # H for Decoupled DiLoCo outer-loop sync
|
| 37 |
+
# To enable Decoupled DiLoCo, the trainer's optimizer manager is
|
| 38 |
+
# monkey-patched at startup with composer_replication.diloco.serverless.MockManager
|
| 39 |
+
# backed by ObjectStoreAllReduce. See ADR-005 for the wiring.
|
| 40 |
+
|
| 41 |
+
generator:
|
| 42 |
+
backend: "vllm"
|
| 43 |
+
tensor_parallel: 1
|
| 44 |
+
|
| 45 |
+
rewarder:
|
| 46 |
+
protocol: "verifiers"
|
| 47 |
+
# No-op for the math task — verifiers does the verification
|
| 48 |
+
|
| 49 |
+
# --- Decoupled DiLoCo (optional) -------------------------------------
|
| 50 |
+
diloco:
|
| 51 |
+
enabled: true
|
| 52 |
+
rendezvous_uri: "s3://my-bucket/diloco-runs/qwen-05b-replication/"
|
| 53 |
+
world_size: 4
|
| 54 |
+
outer_lr: 0.7
|
| 55 |
+
outer_steps: 100
|
| 56 |
+
# When enabled, replicas should be launched via
|
| 57 |
+
# composer_replication.diloco.serverless.{ModalExecutor, HFJobsExecutor, ...}
|
| 58 |
+
# rather than as a single PRIME-RL job.
|
| 59 |
+
|
| 60 |
+
# --- Logging / checkpointing -----------------------------------------
|
| 61 |
+
checkpoint:
|
| 62 |
+
every_n_outer_steps: 10
|
| 63 |
+
output_dir: "./checkpoints/prime-rl-composer/"
|
| 64 |
+
logging:
|
| 65 |
+
wandb_project: "composer-replication"
|
| 66 |
+
log_every_n_steps: 1
|
|
@@ -0,0 +1,107 @@
| 1 |
+
# Recipe C — PRIME-RL: 3-channel composer loss via PRIME-RL's `CustomLossConfig`
|
| 2 |
+
|
| 3 |
+
**Status**: Recipe complete; runtime smoke test deferred to a follow-up
|
| 4 |
+
spike (requires `prime-rl >= 0.5` installed + a CUDA box).
|
| 5 |
+
**ADR**: 006
|
| 6 |
+
|
| 7 |
+
## Why PRIME-RL is a third RL recipe (alongside TRL and VeRL)
|
| 8 |
+
|
| 9 |
+
Per ADR-006, PRIME-RL is the cleanest extension surface for a 3-channel
|
| 10 |
+
loss because it ships a **first-class `CustomLossConfig`** that takes an
|
| 11 |
+
importable Python function and a `LossInputs` struct exposing exactly
|
| 12 |
+
the tensors we need:
|
| 13 |
+
|
| 14 |
+
```python
|
| 15 |
+
@dataclass
|
| 16 |
+
class LossInputs:
|
| 17 |
+
trainer_logprobs: Tensor # student log-probs of generated tokens
|
| 18 |
+
inference_logprobs: Tensor # log-probs from the inference engine
|
| 19 |
+
# (importance-sampling ratio denominator)
|
| 20 |
+
teacher_logprobs: Tensor | None # if the teacher channel is wired in
|
| 21 |
+
advantages: Tensor # GRPO advantages (channel 1)
|
| 22 |
+
loss_mask: Tensor # response-token mask
|
| 23 |
+
```
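The `inference_logprobs` field is what a full GRPO/PPO-style objective would use to form the importance-sampling ratio against `trainer_logprobs`. As a rough illustration of that ratio and its clipping, here is a per-token sketch in plain Python. The function name `clipped_pg_loss` and the default `clip_eps=0.2` are assumptions for this example, and the v0 adapter does not implement this.

```python
import math


def clipped_pg_loss(trainer_lp, inference_lp, advantages, mask, clip_eps=0.2):
    """PPO-style clipped surrogate, averaged over unmasked tokens."""
    total, count = 0.0, 0.0
    for t_lp, i_lp, adv, m in zip(trainer_lp, inference_lp, advantages, mask):
        ratio = math.exp(t_lp - i_lp)  # pi_trainer / pi_inference
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic bound: take the smaller of the two surrogates.
        total += -min(ratio * adv, clipped * adv) * m
        count += m
    return total / max(count, 1e-6)


# On-policy case: identical log-probs give ratio 1, so the loss is -mean(advantage).
on_policy = clipped_pg_loss([-1.0, -2.0, -3.0], [-1.0, -2.0, -3.0],
                            [1.0, 0.5, 2.0], [1.0, 1.0, 0.0])

# Off-policy case: ratio e ~ 2.718 is clipped to 1.2 for a positive advantage.
off_policy = clipped_pg_loss([0.0], [-1.0], [1.0], [1.0])
```

When the two sets of log-probs coincide the ratio is 1 and the expression reduces to an advantage-weighted average, which is the on-policy limit.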
|
| 24 |
+
|
| 25 |
+
The user wires this in via a YAML config field — no fork, no Trainer
|
| 26 |
+
subclass, no monkey-patching:
|
| 27 |
+
|
| 28 |
+
```yaml
|
| 29 |
+
# prime_rl_config.yaml
|
| 30 |
+
loss:
|
| 31 |
+
custom:
|
| 32 |
+
import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
|
| 33 |
+
kwargs:
|
| 34 |
+
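alpha_sdpo: 0.0  # channel 2 deferred in v0; > 0 raises NotImplementedError when teacher_logprobs is set
|
| 35 |
+
beta_dpo: 0.0  # channel 3 out of scope in v0; non-zero only emits a warning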
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Step-by-step
|
| 39 |
+
|
| 40 |
+
### 1. Install PRIME-RL
|
| 41 |
+
```bash
|
| 42 |
+
pip install "prime-rl>=0.5"
|
| 43 |
+
# (or: pip install -e .[prime-rl] from the framework repo)
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
### 2. Drop in the composer loss
|
| 47 |
+
The framework ships `composer_replication.recipes.prime_rl.composer_loss`
|
| 48 |
+
which adapts the 3-channel `compose_loss` to PRIME-RL's `LossInputs`
|
| 49 |
+
struct. The signature is fixed by PRIME-RL:
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
def loss_fn(inputs: LossInputs, *, alpha_sdpo: float, beta_dpo: float) -> Tensor:
    # channel 1: REINFORCE-with-advantage, averaged over response tokens only
    mask = inputs.loss_mask.to(inputs.advantages.dtype)
    grpo = (inputs.advantages * inputs.trainer_logprobs * mask).sum() / mask.sum().clamp_min(1e-6)

    # channel 2: SDPO/OPSD against the teacher (deferred in v0: needs full logits)
    sdpo = ...

    # channel 3: trace-replay DPO (out of scope in v0: needs (chosen, rejected) pairs)
    trace_replay_dpo = ...

    return -grpo + alpha_sdpo * sdpo + beta_dpo * trace_replay_dpo
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Concrete file: `composer_loss.py` in this directory (skeleton; to be filled
|
| 66 |
+
in during the runtime spike).
|
| 67 |
+
|
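The shipped adapter normalizes channel 1 by the number of unmasked tokens rather than by the full tensor size. A small sketch of why that choice matters, using plain lists in place of tensors (the helper names are illustrative only):

```python
def plain_mean(adv, logp, mask):
    """Average over every position, including masked ones."""
    return sum(a * l * m for a, l, m in zip(adv, logp, mask)) / len(mask)


def masked_mean(adv, logp, mask, eps=1e-6):
    """Average over response tokens only (mask == 1)."""
    return sum(a * l * m for a, l, m in zip(adv, logp, mask)) / max(sum(mask), eps)


adv = [1.0, 1.0, 0.0, 0.0]
logp = [-0.5, -1.5, -9.0, -9.0]
mask = [1.0, 1.0, 0.0, 0.0]  # last two positions are padding / prompt tokens

diluted = plain_mean(adv, logp, mask)
per_token = masked_mean(adv, logp, mask)
```

Here `diluted` is half of `per_token` purely because half the positions are masked, so with a plain mean the effective loss scale would drift with padding length from batch to batch.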
| 68 |
+
### 3. PRIME-RL config
|
| 69 |
+
|
| 70 |
+
The example `prime_rl_config.yaml` in this directory wires:
|
| 71 |
+
- The training environment via the `verifiers` env protocol (OpenEnv-
|
| 72 |
+
compatible — no translation layer needed)
|
| 73 |
+
- The custom loss with `import_path` pointing at our `loss_fn`
|
| 74 |
+
- Trainer / generator / rewarder split (PRIME-RL's three-actor design)
|
| 75 |
+
|
| 76 |
+
### 4. Decoupled DiLoCo over PRIME-RL replicas
|
| 77 |
+
|
| 78 |
+
PRIME-RL runs trainer/generator/rewarder as separate processes. To layer
|
| 79 |
+
Decoupled DiLoCo on top, replace the trainer process's optimizer with
|
| 80 |
+
the framework's `make_diloco_outer_loop` and pass a `MockManager`
|
| 81 |
+
(per ADR-005) backed by `ObjectStoreAllReduce`. The other two actors
|
| 82 |
+
are unchanged.
|
| 83 |
+
|
| 84 |
+
This setup is what makes "any number of teachers, any RL framework, any
|
| 85 |
+
serverless executor" composable — PRIME-RL's plug-in points line up
|
| 86 |
+
naturally with the framework's plug-in points.
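For orientation, the outer loop that such a manager coordinates can be written in a few lines: each replica contributes a pseudo-gradient (its starting parameters minus its parameters after the inner steps), the pseudo-gradients are averaged, and the shared parameters take one outer step. The sketch below assumes plain SGD with no momentum and flat lists of floats; `outer_step` is a hypothetical name, not `make_diloco_outer_loop`.

```python
def outer_step(global_params, replica_params, outer_lr=0.7):
    """One DiLoCo-style outer update with plain SGD (no momentum)."""
    n = len(replica_params)
    new_params = []
    for i, theta in enumerate(global_params):
        # Pseudo-gradient per replica: where it started minus where it ended.
        deltas = [theta - replica[i] for replica in replica_params]
        mean_delta = sum(deltas) / n
        new_params.append(theta - outer_lr * mean_delta)
    return new_params


global_params = [1.0, 2.0]
replicas = [[0.0, 2.0], [0.0, 4.0]]  # parameters after the inner steps

averaged = outer_step(global_params, replicas, outer_lr=1.0)
damped = outer_step(global_params, replicas)  # outer_lr=0.7, as in the example config
```

With `outer_lr=1.0` the update reduces to averaging the replicas' parameters; smaller values move the shared parameters only part of the way toward that average.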
|
| 87 |
+
|
| 88 |
+
## What this recipe gives the user
|
| 89 |
+
|
| 90 |
+
- Frontier-RL post-training infra (PRIME-RL's actor-mesh design,
|
| 91 |
+
battle-tested on INTELLECT-1/2)
|
| 92 |
+
- 3-channel composer loss via a single YAML field
|
| 93 |
+
- DiLoCo outer-loop sync via a one-line monkey-patch of the trainer's
|
| 94 |
+
manager
|
| 95 |
+
- OpenEnv-compatible task plumbing for free
|
| 96 |
+
|
| 97 |
+
## What this recipe doesn't give the user
|
| 98 |
+
|
| 99 |
+
- An actual training run yet — that's a separate spike.
|
| 100 |
+
- Quality validation against TRL/VeRL — pending Spike 004 A/B.
|
| 101 |
+
- Hardware autoscaling — that's the Monarch recipe's job (recipes/monarch/).
|
| 102 |
+
|
| 103 |
+
## References
|
| 104 |
+
|
| 105 |
+
- PRIME-RL repo: https://github.com/PrimeIntellect-ai/prime-rl
|
| 106 |
+
- ADR-006: docs/adrs/ADR-006-rl-frameworks.md
|
| 107 |
+
- Reconnaissance: docs/research/RL_FRAMEWORKS_LANDSCAPE.md (§ PRIME-RL)
|
|
@@ -0,0 +1,70 @@
# Default replaysim normalization recipe.
#
# This is a data-juicer YAML config (https://github.com/modelscope/data-juicer).
# It runs CPU-only ops that filter and clean DPO pairs produced by
# composer_replication.teacher_replay.extract_dpo_pairs.
#
# The op-graph operates on records of shape:
#
#   {
#     "state_id": "...",
#     "messages": [{"role": "user", "content": "..."}],
#     "chosen":   [{"role": "assistant", "content": "..."}],
#     "rejected": [{"role": "assistant", "content": "..."}],
#     "n_teachers_agreeing": 2
#   }
#
# There are no chosen_teacher / rejected_teacher fields: the DPOPair shape
# in composer_replication.teacher_replay does not carry them.
#
# Ops listed in `process` are applied in order. Each op operates on the
# full record but typically reads/writes one field. data-juicer's
# DPO/preference-pair ops know how to handle the chosen/rejected pair
# structure natively.

# Project & I/O are filled in by DJNormalizer at runtime; we only
# specify the op pipeline here.

# --- Op pipeline (applied in order) -----------------------------------
process:

  # 1. Length filter on the assistant response.
  #    Drops pairs where either the chosen or rejected response is shorter
  #    than 8 chars or longer than 32k chars (likely garbled / overflow).
  - text_length_filter:
      min_len: 8
      max_len: 32000
      text_keys: ["chosen", "rejected"]

  # 2. Word-count filter on response.
  #    Drops pairs with absurdly low (< 2 words) or high (> 4096 words)
  #    response counts.
  - words_num_filter:
      min_num: 2
      max_num: 4096
      text_keys: ["chosen", "rejected"]

  # 3. Special-character filter.
  #    Drops responses where >50% of characters are non-alphabetic
  #    special chars (likely encoding errors or junk).
  - special_characters_filter:
      max_ratio: 0.5
      text_keys: ["chosen", "rejected"]

  # 4. Deduplication.
  #    Intent: drop pairs that carry no real disagreement.
  #    As configured, the op keys on the `chosen` text only, so it removes
  #    records whose chosen response repeats; it does not compare chosen
  #    against rejected within a record.
  - document_deduplicator:
      lowercase: true
      ignore_non_character: true
      text_keys: ["chosen"]
      # data-juicer's per-batch dedup; full corpus dedup is a separate op.

# Notes:
# - We DO NOT run `pair_preference_mapper` because its default config may
#   re-synthesize the rejected text via an LLM call. We already have
#   real disagreement-derived rejected text and don't want to pay for another
#   API call. (See ADR-004 § "One-day spike before merge.")
# - Language detection is intentionally not in the default: it requires
#   downloading a fasttext model and adds startup latency. Add the
#   `language_id_score_filter` op to a custom recipe if needed.
# - Semantic-similarity dedup is GPU-bound (NeMo-Curator ops); not in
#   the default.
|
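To make the thresholds in the recipe concrete, the sketch below mimics the first three filters in plain Python. It illustrates the intended semantics under stated assumptions and is not data-juicer's implementation: the helper names (`response_text`, `special_ratio`, `keep_pair`) are hypothetical, whitespace splitting stands in for data-juicer's word tokenization, and a "special character" is approximated as anything that is neither alphanumeric nor whitespace.

```python
from typing import Any

MIN_LEN, MAX_LEN = 8, 32_000      # text_length_filter bounds (characters)
MIN_WORDS, MAX_WORDS = 2, 4096    # words_num_filter bounds
MAX_SPECIAL_RATIO = 0.5           # special_characters_filter bound


def response_text(messages: list[dict[str, Any]]) -> str:
    """Join the content of a chat-messages list (hypothetical helper)."""
    return "".join(str(m.get("content", "")) for m in messages)


def special_ratio(text: str) -> float:
    """Fraction of characters that are neither alphanumeric nor whitespace."""
    if not text:
        return 0.0
    special = sum(1 for ch in text if not (ch.isalnum() or ch.isspace()))
    return special / len(text)


def keep_text(text: str) -> bool:
    """Apply the three threshold checks to one response string."""
    return (
        MIN_LEN <= len(text) <= MAX_LEN
        and MIN_WORDS <= len(text.split()) <= MAX_WORDS
        and special_ratio(text) <= MAX_SPECIAL_RATIO
    )


def keep_pair(record: dict[str, Any]) -> bool:
    """A pair survives only if both chosen and rejected pass every check."""
    return all(keep_text(response_text(record[k])) for k in ("chosen", "rejected"))
```

Under these bounds a response as short as `"Four."` (5 characters, 1 word) would be dropped, which is worth keeping in mind when writing fixtures for the real data-juicer path.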
|
@@ -0,0 +1,55 @@
"""composer_replication.replaysim: N-teacher trace replay + dataset normalization.

Per ADR-004, this package consolidates the framework's
"replay an LLM trace through N teachers, get a DPO/preference dataset" flow:

    raw trace
        ↓ (existing teacher_replay.replay_trace)
    list[TeacherCallResult]
        ↓ (existing teacher_replay.extract_dpo_pairs)
    list[DPOPair]
        ↓ (NEW: composer_replication.replaysim.normalize.DJNormalizer)
    list[NormalizedDPOPair]   # filtered and dedup'd per the recipe

The pre-normalization pipeline is unchanged. The normalizer is opt-in via
the new convenience function `replay_and_normalize_trace(...)` which wraps
the existing `replay_trace` + `extract_dpo_pairs` and pipes their output
through a `data-juicer` op-graph.

Adopting `data-juicer` (Alibaba, Apache-2.0) was the verdict from the
2026-05-26 reconnaissance; see docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md.
It's the only mature library with NATIVE multi-turn `messages` + DPO
preference-pair ops that runs CPU-only on the ops we need.

Optional dependency: `pip install -e .[replaysim]` pulls `data-juicer`.
Without it, constructing `DJNormalizer()` raises `RuntimeError` (unless
`skip_dj=True` is passed), but the package still imports cleanly.

This module re-exports the existing `teacher_replay` API for convenience
so users can `from composer_replication.replaysim import replay_trace`.
"""
from __future__ import annotations

from composer_replication.replaysim.normalize import (
    DJNormalizer,
    NormalizedDPOPair,
    replay_and_normalize_trace,
)

# Re-exports from the pre-existing teacher_replay module (unchanged):
from composer_replication.teacher_replay import (
    DPOPair,
    TeacherCallResult,
    extract_dpo_pairs,
    replay_trace,
)

__all__ = [
    "DJNormalizer",
    "DPOPair",
    "NormalizedDPOPair",
    "TeacherCallResult",
    "extract_dpo_pairs",
    "replay_and_normalize_trace",
    "replay_trace",
]
|
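The "imports cleanly without the optional dependency" behaviour described in the docstring can be sketched in isolation as follows. This is a generic illustration of the lazy-import pattern, not the package's code: `require_optional` is a hypothetical helper name, and the error text is only modelled on the message `DJNormalizer` raises.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, extra: str) -> ModuleType:
    """Import an optional dependency, or fail with an actionable message.

    Because the import happens inside the function, the enclosing module
    can be imported even when `module_name` is not installed.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise RuntimeError(
            f"{module_name} is required for this feature. "
            f"Install with `pip install -e .[{extra}]`."
        ) from e
```

Calling `require_optional("data_juicer", "replaysim")` at construction time, rather than at module import time, is what keeps `import composer_replication.replaysim` working in environments without data-juicer.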
|
@@ -0,0 +1,270 @@
"""DJNormalizer: data-juicer adapter for replaysim DPO output.

Wraps the framework's `extract_dpo_pairs` output in a data-juicer op-graph.
The op-graph runs entirely CPU-side; the default recipe applies length,
word-count, and special-character filters plus deduplication. Ops are loaded
from a YAML recipe so users can swap normalization strategies without
touching framework code.

Default recipe lives at:
    composer_replication/recipes/replaysim/default.yaml

The data-juicer dependency is optional (pulled by the `[replaysim]` extra).
This file imports it lazily inside method bodies so that the package
imports cleanly without it.

Source-of-truth shape (from `composer_replication.teacher_replay`):

    DPOPair = TypedDict("DPOPair", {
        "state_id": str,
        "state_messages": list[dict],   # conversation up to this step
        "chosen": str,                  # teacher-consensus action
        "rejected": str,                # student action
        "n_teachers_agreeing": int,
    })

The normalizer does NOT require chosen_teacher / rejected_teacher fields;
those don't exist in the real DPOPair shape.
"""
from __future__ import annotations

import asyncio
import json
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, cast

from composer_replication.teacher_replay import (
    DPOPair,
    TeacherCallResult,
    extract_dpo_pairs,
    replay_trace,
)


@dataclass
class NormalizedDPOPair:
    """A DPOPair that has passed through normalization. Same data as
    DPOPair but reshaped into chat-messages format (matching data-juicer's
    native multi-turn op support) plus a metadata dict tracking which
    ops fired.
    """
    state_id: str
    """Identifier for the trace state (turn) this pair came from."""

    state_messages: list[dict[str, Any]]
    """The conversation context up to (and including) this step's user prompt."""

    chosen_messages: list[dict[str, Any]]
    """The chosen completion as a chat-messages list (one assistant turn)."""

    rejected_messages: list[dict[str, Any]]
    """The rejected completion as a chat-messages list (one assistant turn)."""

    n_teachers_agreeing: int
    """How many teachers agreed on the chosen action (preserved from DPOPair)."""

    metadata: dict[str, Any]
    """Op-graph provenance: which ops fired, what they changed."""


def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]:
    """Convert a DPOPair (or dict-shaped equivalent) into a data-juicer
    record using the messages format.
    """
    p = cast(dict[str, Any], pair)
    return {
        "state_id": p.get("state_id", ""),
        "messages": p.get("state_messages", []),
        "chosen": [{"role": "assistant", "content": p.get("chosen", "")}],
        "rejected": [{"role": "assistant", "content": p.get("rejected", "")}],
        "n_teachers_agreeing": p.get("n_teachers_agreeing", 0),
    }


def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair:
    """Inverse: convert a data-juicer record back to NormalizedDPOPair."""
    return NormalizedDPOPair(
        state_id=rec.get("state_id", ""),
        state_messages=rec.get("messages", []),
        chosen_messages=rec.get("chosen", []),
        rejected_messages=rec.get("rejected", []),
        n_teachers_agreeing=rec.get("n_teachers_agreeing", 0),
        metadata=rec.get("__dj_meta__", {}),
    )


class DJNormalizer:
    """data-juicer-backed normalizer for DPO pairs.

    Args:
        recipe_path: path to a data-juicer YAML recipe. If None, uses the
            framework's default recipe (length, word-count, and
            special-character filters + dedup).
        skip_dj: if True, the normalizer becomes a passthrough, which is
            useful for test environments without data-juicer installed.
            Records are still converted to NormalizedDPOPair shape but no
            ops run.
    """

    DEFAULT_RECIPE = (
        Path(__file__).parent.parent / "recipes" / "replaysim" / "default.yaml"
    )

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None:
        self.recipe_path = (
            Path(recipe_path) if recipe_path is not None else self.DEFAULT_RECIPE
        )
        self.skip_dj = skip_dj

        if not skip_dj:
            try:
                import data_juicer  # type: ignore[import-not-found]  # noqa: F401
            except ImportError as e:
                # Chain the original ImportError so the root cause stays visible.
                raise RuntimeError(
                    "DJNormalizer requires data-juicer. Install with "
                    "`pip install -e .[replaysim]` or pass skip_dj=True "
                    "for a passthrough. Got: " + repr(e)
                ) from e

        if not self.skip_dj and not self.recipe_path.exists():
            raise FileNotFoundError(
                f"Recipe not found: {self.recipe_path}. Either pass an "
                f"explicit recipe_path or add the default recipe at this "
                f"location."
            )

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]:
        """Run the full normalization op-graph on a batch of DPO pairs.

        Args:
            pairs: iterable of DPOPair (output of extract_dpo_pairs) or
                dict-shaped equivalents.

        Returns:
            list of NormalizedDPOPair, possibly shorter than input (filter
            ops can drop records).
        """
        records = [_dpo_pair_to_dj_record(p) for p in pairs]

        if self.skip_dj:
            for rec in records:
                rec["__dj_meta__"] = {"skipped": True}
            return [_dj_record_to_normalized(r) for r in records]

        # Real path: write to temp JSONL, hand to data-juicer's Executor,
        # read back. data-juicer's CLI contract is file-in / file-out.
        from data_juicer.config import init_configs  # type: ignore[import-not-found]
        from data_juicer.core import DefaultExecutor  # type: ignore[import-not-found]

        with tempfile.TemporaryDirectory() as td:
            input_path = Path(td) / "input.jsonl"
            output_path = Path(td) / "output.jsonl"
            with input_path.open("w") as f:
                for rec in records:
                    f.write(json.dumps(rec) + "\n")
            cfg = init_configs(
                args=[
                    "--config", str(self.recipe_path),
                    "--dataset_path", str(input_path),
                    "--export_path", str(output_path),
                ],
            )
            executor = DefaultExecutor(cfg)
            executor.run()

            output_records: list[dict[str, Any]] = []
            with output_path.open() as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    output_records.append(json.loads(line))

        return [_dj_record_to_normalized(r) for r in output_records]


# ---------------------------------------------------------------------
# Convenience: replay + extract pairs + normalize, end to end.
# ---------------------------------------------------------------------


async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Async convenience: replay → extract pairs → normalize, in one call.

    The underlying `replay_trace` is async; this wrapper preserves that
    so callers can `await` it from an async context. For sync callers
    use `replay_and_normalize_trace_sync`.

    Args:
        states: sequence of TraceState (the frozen agentic trace)
        teachers: sequence of TeacherSpec (default: framework defaults)
        agreement_threshold: passed to `extract_dpo_pairs`
        max_total_usd: passed to `replay_trace`
        normalizer: defaults to `DJNormalizer()`. Pass
            `DJNormalizer(skip_dj=True)` to bypass data-juicer.
        **replay_kwargs: extra kwargs forwarded to `replay_trace`.

    Returns:
        Tuple of (raw teacher_actions, normalized DPO pairs).
    """
    if normalizer is None:
        normalizer = DJNormalizer()

    if teachers is None:
        teacher_actions = await replay_trace(
            states=states, max_total_usd=max_total_usd, **replay_kwargs,
        )
    else:
        teacher_actions = await replay_trace(
            states=states,
            teachers=teachers,
            max_total_usd=max_total_usd,
            **replay_kwargs,
        )

    # extract_dpo_pairs reads student_action from each state's
    # `student_action` field, so we don't need to pass it separately.
    raw_pairs = extract_dpo_pairs(
        states=states,
        teacher_actions=teacher_actions,
        agreement_threshold=agreement_threshold,
    )

    normalized = normalizer.normalize(raw_pairs)
    return teacher_actions, normalized


def replay_and_normalize_trace_sync(
    *args: Any,
    **kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Sync wrapper for the async `replay_and_normalize_trace`. Convenient
    for scripts and tests.
    """
    return asyncio.run(replay_and_normalize_trace(*args, **kwargs))


__all__ = [
    "DJNormalizer",
    "NormalizedDPOPair",
    "replay_and_normalize_trace",
    "replay_and_normalize_trace_sync",
]
|
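The file-in / file-out hand-off that `DJNormalizer.normalize` performs can be looked at on its own. The sketch below isolates just the JSONL round trip through a temporary directory, with no data-juicer involved; `write_jsonl`, `read_jsonl`, and `roundtrip` are hypothetical helper names, and the explicit UTF-8 encoding is an assumption rather than something the module sets.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def write_jsonl(path: Path, records: list[dict[str, Any]]) -> None:
    """Write one JSON object per line."""
    with path.open("w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read JSONL back, skipping blank lines."""
    out: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                out.append(json.loads(line))
    return out


def roundtrip(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Write records to a temp file and read them back before cleanup."""
    with tempfile.TemporaryDirectory() as td:
        path = Path(td) / "input.jsonl"
        write_jsonl(path, records)
        return read_jsonl(path)
```

Reading the output inside the `with` block matters: once the `TemporaryDirectory` context exits, the files are deleted.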
|
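The split between the async function and its `_sync` wrapper follows a common pattern, sketched below with a stand-in coroutine. `replay_stub` and the other names here are hypothetical; the only point is when each entry point is appropriate. `asyncio.run` starts a new event loop, so the sync wrapper suits scripts and tests, while code that is already inside a coroutine should `await` the async function directly.

```python
import asyncio


async def replay_stub(n: int) -> list[int]:
    """Hypothetical stand-in for an async pipeline call."""
    await asyncio.sleep(0)  # yield to the event loop once
    return list(range(n))


def replay_stub_sync(n: int) -> list[int]:
    """Sync wrapper: fine from scripts and tests, where no loop is running."""
    return asyncio.run(replay_stub(n))


async def from_async_context(n: int) -> list[int]:
    # Already inside a running loop: await the coroutine directly.
    # Calling replay_stub_sync here would fail, because asyncio.run
    # cannot be called from a running event loop.
    return await replay_stub(n)
```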
File without changes
|
|
@@ -0,0 +1,138 @@
"""Replaysim normalization tests: the skip_dj passthrough path.

The full data-juicer path requires `pip install -e .[replaysim]` which we
defer to the user's environment. These tests verify:

1. The package imports cleanly without data-juicer installed.
2. `DJNormalizer(skip_dj=True)` is a working passthrough.
3. The DPOPair → DJ-record → NormalizedDPOPair shape transforms are
   lossless modulo the metadata field.
4. The DPOPair dict shape (TypedDict) is what we expect.
"""
from __future__ import annotations

import pytest

from composer_replication.replaysim import (
    DJNormalizer,
    NormalizedDPOPair,
    replay_and_normalize_trace,
)
from composer_replication.replaysim.normalize import (
    _dj_record_to_normalized,
    _dpo_pair_to_dj_record,
)


def _make_pair(
    state_id: str,
    state_messages: list[dict] | None = None,
    chosen: str = "Four.",
    rejected: str = "Five.",
    n_teachers_agreeing: int = 2,
) -> dict:
    """Helper: DPOPair is a TypedDict, so dicts work directly."""
    return {
        "state_id": state_id,
        "state_messages": state_messages or [{"role": "user", "content": "What is 2+2?"}],
        "chosen": chosen,
        "rejected": rejected,
        "n_teachers_agreeing": n_teachers_agreeing,
    }


def test_dpo_pair_to_dj_record_shape():
    p = _make_pair("s1")
    rec = _dpo_pair_to_dj_record(p)
    assert rec["state_id"] == "s1"
    assert rec["messages"] == [{"role": "user", "content": "What is 2+2?"}]
    assert rec["chosen"] == [{"role": "assistant", "content": "Four."}]
    assert rec["rejected"] == [{"role": "assistant", "content": "Five."}]
    assert rec["n_teachers_agreeing"] == 2


def test_dj_record_to_normalized_roundtrip():
    p = _make_pair("s2", chosen="C", rejected="R", n_teachers_agreeing=3)
    rec = _dpo_pair_to_dj_record(p)
    rec["__dj_meta__"] = {"ops_applied": ["text_length_filter"]}
    norm = _dj_record_to_normalized(rec)
    assert isinstance(norm, NormalizedDPOPair)
    assert norm.state_id == "s2"
    assert norm.chosen_messages == [{"role": "assistant", "content": "C"}]
    assert norm.rejected_messages == [{"role": "assistant", "content": "R"}]
    assert norm.n_teachers_agreeing == 3
    assert norm.metadata == {"ops_applied": ["text_length_filter"]}


def test_dj_record_to_normalized_preserves_state_messages():
    """The conversation context (state_messages) must round-trip."""
    multi_turn = [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "Let me think."},
        {"role": "user", "content": "Just give me a number."},
    ]
    p = _make_pair("s3", state_messages=multi_turn)
    rec = _dpo_pair_to_dj_record(p)
    norm = _dj_record_to_normalized(rec)
    assert norm.state_messages == multi_turn


def test_dj_normalizer_skip_dj_passthrough():
    """skip_dj=True: bypasses data-juicer entirely, just does shape conversion."""
    pairs = [
        _make_pair("s1", chosen="c1", rejected="r1"),
        _make_pair("s2", chosen="c2", rejected="r2"),
    ]
    normalizer = DJNormalizer(skip_dj=True)
    out = normalizer.normalize(pairs)
    assert len(out) == 2
    assert all(isinstance(o, NormalizedDPOPair) for o in out)
    assert out[0].state_id == "s1"
    assert out[1].state_id == "s2"
    assert out[0].metadata == {"skipped": True}
    assert out[1].metadata == {"skipped": True}


def test_dj_normalizer_skip_dj_preserves_count():
    """Passthrough must not drop records; only filter ops do that."""
    pairs = [_make_pair(f"s{i}") for i in range(10)]
    normalizer = DJNormalizer(skip_dj=True)
    out = normalizer.normalize(pairs)
    assert len(out) == 10


def test_dj_normalizer_default_recipe_path_exists():
    """The default recipe ships with the package."""
    assert DJNormalizer.DEFAULT_RECIPE.exists(), \
        f"Default recipe missing at {DJNormalizer.DEFAULT_RECIPE}"


def test_dj_normalizer_real_path_requires_data_juicer():
    """Without skip_dj, instantiation requires data-juicer or fails clearly."""
    try:
        import data_juicer  # type: ignore[import-not-found]  # noqa: F401
    except ImportError:
        with pytest.raises(RuntimeError, match="data-juicer"):
            DJNormalizer(skip_dj=False)
    else:
        # data-juicer IS installed; verify init succeeds with default recipe
        normalizer = DJNormalizer(skip_dj=False)
        assert normalizer.recipe_path == DJNormalizer.DEFAULT_RECIPE


def test_replay_and_normalize_trace_signature():
    """Convenience function is callable and importable. Smoke-only: we
    don't run it against OpenRouter from CI."""
    assert callable(replay_and_normalize_trace)
    # It's an async function
    import inspect
    assert inspect.iscoroutinefunction(replay_and_normalize_trace)


def test_record_handles_missing_optional_fields():
    """A DPOPair dict missing some optional fields shouldn't crash the converter."""
    minimal = {"state_id": "x", "chosen": "a", "rejected": "b"}
    rec = _dpo_pair_to_dj_record(minimal)
    assert rec["state_id"] == "x"
    assert rec["messages"] == []  # missing state_messages → empty list
    assert rec["n_teachers_agreeing"] == 0  # missing → default 0
|
|
@@ -0,0 +1,154 @@
# altered-minds × Composer Replication Framework

**Status**: Tie-in design doc.
**Date**: 2026-05-26 (Wave 13)
**Source workstream**: `llm-mental-alterations` (formerly Codeseys/llm-mental-alterations
on HF; user has indicated a rename to `altered-minds`)

## What altered-minds is studying

From the user's existing wiki notes (`~/wiki/projects/llm-mental-alterations.md`):

- Fine-tuning Llama-3.1-8B with **personality SFT** induces a
  depression/anxiety cognitive-distortion signature on MMLU `moral_scenarios`:
  - Class 3 ("both fine") collapses **−31.1pp**
  - Class 0 ("both wrong") improves **+4.6pp**
  - Multi-seed reproducible (4/4 seeds, n=895)
  - 18% of base-correct items broken
- Other domains affected: `high_school_chemistry +4.2pp`,
  `machine_learning +4.9pp` (reliably improved).
- H-3 Gemma-MoE hypothesis is deferred (Hopper-only).
- Spend so far: $9.75 / $400 budget.

The headline question driving the workstream is roughly:
**"What measurable cognitive alterations does personality-style SFT
introduce, and can we recover or sharpen them via downstream RL?"**

## Why this framework is the right second-stage workstream

altered-minds today is an **SFT-only** pipeline. A typical run:
1. Take a base model (Llama-3.1-8B).
2. Apply personality SFT.
3. Evaluate on MMLU + alteration-specific probes.
4. Document the alteration signature.

The Composer Replication Framework, by design, is a **post-SFT
reinforcement-learning framework**. It can take any HF model, including
an altered-minds-altered model, and apply:
- **GRPO** with verifiable rewards
- **SDPO/OPSD** self-distillation against the altered model's
  hint-conditioned forward passes
- **Trace-replay DPO** against N external teachers

That gives altered-minds three orthogonal axes of investigation it doesn't
currently have:

| Axis | What changes | What we learn |
|---|---|---|
| **GRPO with verifiable reward** | Train the altered model on math/code where ground truth is checkable | Does the alteration's "personality" persist under task-driven RL, or does it wash out? |
| **SDPO against the altered model's own hints** | Self-distillation: the altered model teaches itself with hint-conditioned forward passes | Can we **sharpen** the alteration without further SFT? |
| **Trace-replay DPO with frontier teachers** | The altered model rolls out, frontier teachers replay the same prompts, disagreement → DPO pairs | Where does the altered model **disagree** with frontier consensus? Are those disagreements correlated with the cognitive-distortion signature? |

The **third** axis is the most interesting for altered-minds specifically.
The framework's `replay_trace` + `extract_dpo_pairs` produce, by construction,
a dataset of "altered-model output" vs "frontier-consensus output" for any
prompt distribution. If the altered model's depression/anxiety signature
shows up in moral_scenarios, then the trace-replay output on
moral-scenario prompts is **a measurable corpus of the alteration**.

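The "disagreement → DPO pairs" step can be sketched as a small function. This is an illustration under stated assumptions, not the framework's `extract_dpo_pairs`: the name `pair_from_disagreement` is hypothetical, actions are compared by exact string equality (the real implementation may match them differently), and consensus is taken as the single most common teacher action.

```python
from collections import Counter
from typing import Optional, TypedDict


class PairSketch(TypedDict):
    chosen: str               # teacher-consensus action
    rejected: str             # student (altered-model) action
    n_teachers_agreeing: int


def pair_from_disagreement(
    student_action: str,
    teacher_actions: list[str],
    agreement_threshold: int = 2,
) -> Optional[PairSketch]:
    """Emit a pair only when enough teachers agree AND the student differs."""
    if not teacher_actions:
        return None
    consensus, votes = Counter(teacher_actions).most_common(1)[0]
    if votes < agreement_threshold:
        return None           # no teacher consensus: nothing to learn from
    if consensus == student_action:
        return None           # student already matches consensus
    return {
        "chosen": consensus,
        "rejected": student_action,
        "n_teachers_agreeing": votes,
    }
```

On moral_scenarios prompts, every non-`None` result is one entry in the corpus described above: a prompt where the altered model departs from what at least `agreement_threshold` teachers chose.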
## Concrete plan: altered-minds-RL spike

### Phase 1: model selection

Pick the altered-minds checkpoint that produced the strongest signature
(per the user's notes: the multi-seed Llama-3.1-8B personality-SFT run
where moral_scenarios class 3 collapsed −31.1pp).

### Phase 2: domain-specific replaysim

Run `composer_replication.replaysim.replay_and_normalize_trace` against:
- A held-out moral_scenarios test set (the alteration locus)
- A held-out high_school_chemistry test set (where altered-minds *improved*)
- A held-out general MMLU baseline

Teachers: framework defaults (Claude Opus 4.7, GPT-5, DeepSeek V4 Pro).
This produces **three normalized DPO datasets** capturing where the
altered model disagrees with frontier consensus on each domain.

Cost estimate: ~$0.98/trace × 100 prompts × 3 domains ≈ **$300**.
Fits inside the user's existing $400 altered-minds budget.

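The cost estimate can be checked with a few lines of arithmetic. The figures below are the ones quoted in this document; the only assumption is that each prompt corresponds to exactly one trace. The constant and function names are illustrative.

```python
COST_PER_TRACE_USD = 0.98
PROMPTS_PER_DOMAIN = 100
DOMAINS = ("moral_scenarios", "high_school_chemistry", "general_mmlu")
BUDGET_USD = 400.00
SPENT_USD = 9.75


def replay_cost(n_domains: int) -> float:
    """Trace-replay cost for n_domains, assuming one trace per prompt."""
    return round(COST_PER_TRACE_USD * PROMPTS_PER_DOMAIN * n_domains, 2)


remaining_before = BUDGET_USD - SPENT_USD           # budget left today
full_cost = replay_cost(len(DOMAINS))               # all three domains
remaining_after = remaining_before - full_cost      # left for Phases 3-5
single_domain_cost = replay_cost(1)                 # moral_scenarios only
```

This gives $294 for three domains (rounded to ~$300 above), leaving roughly $96 of the remaining ~$390, versus ~$98 for a single domain.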
### Phase 3: GRPO with the framework

Run `composer_replication.recipes.trl.ComposerReplicationTrainer` with:
- **Channel 1 (GRPO)**: turned ON, reward = MMLU letter-correctness
- **Channel 2 (SDPO/OPSD)**: turned ON at α=0.2, hint-conditioned
  against the altered model's own forward pass
- **Channel 3 (trace-replay DPO)**: turned ON at β=0.4, against the
  Phase-2 datasets

Train for ~500 steps on a single GPU. The Qwen-0.5B feasibility test is
already confirmed in the framework; for Llama-8B, use Modal plus the
framework's `ServerlessExecutor` per ADR-005, since a local RTX 5090 is
too small for anything beyond short runs (see open question 3).

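One way to read the α and β settings is as weights in a combined objective. The sketch below assumes the three channels combine as a simple weighted sum, which the channel list implies but does not state; `ChannelWeights` and `combined_loss` are hypothetical names and not the trainer's API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ChannelWeights:
    use_grpo: bool = True   # Channel 1 (GRPO) on/off
    alpha: float = 0.2      # Channel 2 (SDPO/OPSD) weight
    beta: float = 0.4       # Channel 3 (trace-replay DPO) weight


def combined_loss(
    grpo_loss: float,
    sdpo_loss: float,
    dpo_loss: float,
    w: ChannelWeights = ChannelWeights(),
) -> float:
    """Weighted sum of the three channel losses (assumed form)."""
    total = grpo_loss if w.use_grpo else 0.0
    total += w.alpha * sdpo_loss
    total += w.beta * dpo_loss
    return total


# A channel-2-only ablation, as used in the Phase 4 outcomes, would be:
CHANNEL_2_ONLY = ChannelWeights(use_grpo=False, alpha=0.2, beta=0.0)
```

Under this reading, setting a weight to zero switches the channel off, which makes single-channel ablations a one-line config change.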
### Phase 4: re-evaluate

Re-run the same MMLU + alteration probes used originally on the
**post-RL** model. Three outcomes are possible:

| Outcome | Interpretation |
|---|---|
| Alteration signature persists at same magnitude | The alteration is robust to task-driven RL; useful as a lower bound on its "depth" |
| Alteration signature attenuates | Task-driven RL washes out personality-SFT; useful for understanding alteration brittleness |
| Alteration signature **amplifies** on channel-2-only ablation | SDPO is reinforcing the alteration; rare and significant, and would be a publishable finding |

### Phase 5: Decoupled DiLoCo for multi-personality experiments

Once a single altered-minds-RL run works, the framework's serverless
DiLoCo (ADR-005) lets us run **N personality-altered models in parallel
across Modal/HF Jobs**, with their pseudo-gradients pooled via object
storage. This becomes the natural sweep over personality types
(depression vs anxiety vs grandiose vs ...) at minimal incremental
infrastructure cost.

| 113 |
+
## Repo layout proposal
|
| 114 |
+
|
| 115 |
+
The Composer Replication Framework is intentionally generic. The
|
| 116 |
+
altered-minds-specific RL spike should live as a separate repo or
|
| 117 |
+
subdirectory **using** the framework, not inside it:
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
altered-minds/ # the renamed llm-mental-alterations repo
|
| 121 |
+
composer_replication_runs/ # NEW
|
| 122 |
+
moral_scenarios_replay.py # uses composer_replication.replaysim
|
| 123 |
+
train_grpo.py # uses composer_replication.trainer
|
| 124 |
+
eval_post_rl.py # standard altered-minds eval
|
| 125 |
+
recipes/
|
| 126 |
+
altered_minds.yaml # data-juicer recipe — symlinks/copies
|
| 127 |
+
# composer_replication's default + adds
|
| 128 |
+
# MMLU-format-aware ops
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
The framework provides the algorithm + infrastructure. The altered-minds
|
| 132 |
+
repo owns the experimental narrative + results.
|
| 133 |
+
|
| 134 |
+
## Open questions for the user
|
| 135 |
+
|
| 136 |
+
Before we proceed to Phase 1:
|
| 137 |
+
|
| 138 |
+
1. **Confirm the rename**: the wiki memory says `llm-mental-alterations`
|
| 139 |
+
on HF; user wants `altered-minds` — should we rename the HF repo?
|
| 140 |
+
2. **Budget allocation**: the $300 trace-replay cost (Phase 2) eats most
|
| 141 |
+
of the remaining $390 altered-minds budget. Is that acceptable, or
|
| 142 |
+
should we use only one domain (moral_scenarios) for $100?
|
| 143 |
+
3. **GPU venue for Phase 3**: 8B-model RL on single-GPU is feasible on
|
| 144 |
+
the user's RTX 5090 (32GB) for short runs, OR we use Modal A100s for
|
| 145 |
+
a more aggressive run. Preference?
|
| 146 |
+
|
| 147 |
+
## References
|
| 148 |
+
|
| 149 |
+
- altered-minds workstream wiki: `~/wiki/projects/llm-mental-alterations.md`
|
| 150 |
+
- Framework ADRs: docs/adrs/ADR-001 through ADR-007
|
| 151 |
+
- Framework V1-V8 brief coverage: docs/V1_V8_COVERAGE.md
|
| 152 |
+
- Self-distillation landscape: docs/research/SELF_DISTILLATION_LANDSCAPE.md
|
| 153 |
+
(relevant: TAID's annealed-teacher schedule could test "alteration
|
| 154 |
+
recovery" by interpolating between altered-init and base-teacher)
|
|
@@ -90,5 +90,27 @@ This is the post-replication phase. The CPU-only deep-work-loop phase (Waves 7-1
|
|
| 90 |
|
| 91 |
- `docs/VISION_VALIDATION.md` — original 10-point scorecard + post-Wave-11 honest re-scoring
|
| 92 |
- `docs/research/WAVE_7_10_FINAL_REVIEW.md` — cross-model adversarial review of Wave 7-10 (10 priority items, 2 BLOCKERs both addressed)
|
| 93 |
-
- `docs/adrs/ADR-001..
|
| 94 |
- `BACKLOG.md` — pre-execution acceptance criteria for Spikes 006/007/008 + Wave 10
|
| 90 |
|
| 91 |
- `docs/VISION_VALIDATION.md` — original 10-point scorecard + post-Wave-11 honest re-scoring
|
| 92 |
- `docs/research/WAVE_7_10_FINAL_REVIEW.md` — cross-model adversarial review of Wave 7-10 (10 priority items, 2 BLOCKERs both addressed)
|
| 93 |
+
- `docs/adrs/ADR-001..007` — seven architectural decisions (GPU venue, trace source, DiLoCo impl, replaysim normalization, serverless DiLoCo, RL frameworks, distillation losses)
|
| 94 |
- `BACKLOG.md` — pre-execution acceptance criteria for Spikes 006/007/008 + Wave 10
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## Wave 13 expansion (2026-05-26)
|
| 99 |
+
|
| 100 |
+
The user expanded the brief mid-loop:
|
| 101 |
+
|
| 102 |
+
> *"keep going. make sure that we do the paths of the Composer 2.5 methods, the n-teachers replaysim, and Decoupled DiLoCo (so that we can leverage modal or huggingface-jobs or other serverless training systems). … For V5 see if we can leverage [a normalization library] to normalize the data while also making the replaysim dataset generation. … if we can properly document and research the self-distillation papers like SDPO OPDS and/or others. … see if there are other frameworks that are more popular that we could try to use. meta's pytorch agentic stack components are something that I'd like to explore."*
|
| 103 |
+
|
| 104 |
+
| Wave 13 ask | Artifact | Status |
|
| 105 |
+
|---|---|---|
|
| 106 |
+
| Decoupled DiLoCo over serverless | ADR-005 + `composer_replication.diloco.serverless` (Protocol + ObjectStoreAllReduce + LocalProcessExecutor + Modal/HFJobs skeletons) + 9 tests (3 multi-process) | ✅ Closed (local) / 🟡 Skeleton (cloud) |
|
| 107 |
+
| Replaysim normalization | ADR-004 + `composer_replication.replaysim` package + `data-juicer` adapter + default YAML recipe + 9 unit tests | ✅ Closed (passthrough) / 🟡 Pending data-juicer install for full path |
|
| 108 |
+
| Other RL frameworks (V3 expansion) | ADR-006 + `composer_replication.recipes.prime_rl` (recipe + composer_loss adapter + config.yaml) | ✅ Closed (recipe) / 🟡 Skeleton (runtime) |
|
| 109 |
+
| Meta's PyTorch agentic stack | ADR-006 + `composer_replication.recipes.monarch` (actor layout doc + skeleton actors) | ✅ Closed (design) / 🟡 Skeleton (impl) |
|
| 110 |
+
| Deeper self-distillation research | ADR-007 + `docs/research/SELF_DISTILLATION_LANDSCAPE.md` + `composer_replication.distillation` module (SimPO + TAID + Entropy-Aware OPD) + 17 unit tests | ✅ Closed (standalone losses) / 🟡 Deferred to Wave 14 (`compose_loss` kwargs not yet wired — Wave 13 review Finding 2) |
|
| 111 |
+
| altered-minds tie-in | `docs/ALTERED_MINDS_TIE_IN.md` (5-phase plan, $300 estimate, open questions) | ✅ Closed (design) |
|
| 112 |
+
|
| 113 |
+
**Wave 13 test addition**: 35 new tests passing (17 distillation + 9 serverless, 3 of them multi-process, + 9 replaysim).
|
| 114 |
+
|
| 115 |
+
The framework now covers the full expanded brief. Total tests passing
|
| 116 |
+
across the framework as of Wave 13: **107** (72 from prior waves + 35 new).
|
|
@@ -151,12 +151,16 @@ even if it doesn't translate to code.
|
|
| 151 |
|---|---|---|---|---|---|
|
| 152 |
| TRL | ✅ | ✅ | ✅ | 38 + 9 + 3 = 50 | ✅ |
|
| 153 |
| VeRL | ✅ | ✅ | 🟡 (skeleton) | — | v0.2 |
|
| 154 |
-
|
| 155 |
| OpenEnv | ✅ | ✅ | n/a (protocol) | — | substrate |
|
| 156 |
-
| Monarch | ✅ | ✅ (
|
| 157 |
| TorchForge | ✅ | n/a (paused) | n/a | — | n/a |
|
| 158 |
|
| 159 |
-
**
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 151 |
|---|---|---|---|---|---|
|
| 152 |
| TRL | ✅ | ✅ | ✅ | 38 + 9 + 3 = 50 | ✅ |
|
| 153 |
| VeRL | ✅ | ✅ | 🟡 (skeleton) | — | v0.2 |
|
| 154 |
+
| **PRIME-RL** (Wave 13) | ✅ | ✅ | 🟡 (loss adapter + config) | — | v0.2 (cleanest hook) |
|
| 155 |
+
| DiLoCo (single-process) | ✅ | ✅ | ✅ | 5 (single-replica) | optional |
|
| 156 |
+
| **DiLoCo over serverless** (Wave 13) | ✅ | ✅ ADR-005 | ✅ Local + 🟡 Modal/HFJobs | 9 (3 multi-process) | ✅ (local) / future (cloud) |
|
| 157 |
| OpenEnv | ✅ | ✅ | n/a (protocol) | — | substrate |
|
| 158 |
+
| **Monarch** (Wave 13) | ✅ | ✅ (actor layout) | 🟡 (skeleton) | — | v0.2+ |
|
| 159 |
| TorchForge | ✅ | n/a (paused) | n/a | — | n/a |
|
| 160 |
|
| 161 |
+
**8/8 substrates covered** (was 6/6 pre-Wave-13). New in Wave 13:
|
| 162 |
+
PRIME-RL (the cleanest custom-loss hook), Monarch (Meta's actively-shipped
|
| 163 |
+
agentic-stack component), and serverless DiLoCo (Modal/HF Jobs adapters
|
| 164 |
+
+ object-store rendezvous). The framework can now realize Decoupled
|
| 165 |
+
DiLoCo across cloud executors **without any cross-job NCCL** — see
|
| 166 |
+
ADR-005 for the design rationale.
|
|
@@ -0,0 +1,124 @@
|
| 1 |
+
# ADR-004 — Replaysim normalization layer for the trace-replay channel
|
| 2 |
+
|
| 3 |
+
**Status**: Accepted
|
| 4 |
+
**Date**: 2026-05-26
|
| 5 |
+
**Wave**: 13 (deep work loop, expansion phase)
|
| 6 |
+
|
| 7 |
+
## Context
|
| 8 |
+
|
| 9 |
+
The brief's V5 clause says:
|
| 10 |
+
|
| 11 |
+
> use traces from an llm-application usage then replay the traces with
|
| 12 |
+
> different models to see at each llm-step what the llm would do. by doing
|
| 13 |
+
> this we get distillation data from any number of models that could be
|
| 14 |
+
> used to train the target model further
|
| 15 |
+
|
| 16 |
+
The user added 2026-05-26: *"see if we can leverage [a normalization
|
| 17 |
+
library] to normalize the data while also making the replaysim dataset
|
| 18 |
+
generation."*
|
| 19 |
+
|
| 20 |
+
Currently the framework has `composer_replication.teacher_replay`:
|
| 21 |
+
- `replay_trace()` — N-teacher OpenRouter replay, returns
|
| 22 |
+
`list[TeacherCallResult]`
|
| 23 |
+
- `extract_dpo_pairs()` — converts teacher disagreement to `list[DPOPair]`
|
| 24 |
+
|
| 25 |
+
This produces preference-pair training data, but with **zero normalization**:
|
| 26 |
+
no dedup, no length filtering, no language detection, no quality
|
| 27 |
+
filtering, no chat-template validation. The output is closer to "raw
|
| 28 |
+
LLM API responses" than "training-ready dataset."
|
| 29 |
+
|
| 30 |
+
For the replaysim to power downstream RL training (V6), the dataset needs
|
| 31 |
+
to be production-quality. Hand-rolling that pipeline is a tax we'd rather
|
| 32 |
+
not pay.
|
| 33 |
+
|
| 34 |
+
## Options considered
|
| 35 |
+
|
| 36 |
+
Audited five candidates in `docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md`:
|
| 37 |
+
|
| 38 |
+
| Library | License | Multi-turn? | DPO pairs? | Streaming? | GPU? | Verdict |
|
| 39 |
+
|---|---|---|---|---|---|---|
|
| 40 |
+
| HuggingFace `datatrove` | MIT | ❌ flat-text only | ❌ | ✅ | ❌ | Deal-breaker on multi-turn |
|
| 41 |
+
| Alibaba `data-juicer` | Apache-2 | ✅ native `messages` ops | ✅ `pair_preference_mapper` | ✅ | ❌ for ops we need | **Chosen** |
|
| 42 |
+
| NVIDIA `nemo-curator` | Apache-2 | partial | ❌ | ✅ | ✅ mandatory for differentiating ops | Reject — GPU-bound for the ops we need |
|
| 43 |
+
| Argilla `distilabel` | Apache-2 | ✅ native chat | ✅ formatters | ✅ | ❌ | Reject — would replace teacher orchestration, not just normalize |
|
| 44 |
+
| Databricks `lilac` | — | n/a | n/a | n/a | n/a | Reject — archived 2024-03 |
|
| 45 |
+
|
| 46 |
+
## Decision
|
| 47 |
+
|
| 48 |
+
**Adopt `data-juicer` (Alibaba/modelscope, Apache-2.0, last push 2026-05-25, 6.4k★).**
|
| 49 |
+
|
| 50 |
+
Reasons:
|
| 51 |
+
|
| 52 |
+
1. **It's the only candidate with native multi-turn + DPO support in the
|
| 53 |
+
*normalization* op-graph.** Has `pair_preference_mapper`,
|
| 54 |
+
`dialog_intent_detection_mapper`, `dialog_topic_detection_mapper`,
|
| 55 |
+
etc. that operate on chat-format messages directly.
|
| 56 |
+
|
| 57 |
+
2. **CPU-runnable for our op set.** The differentiating ops we need
|
| 58 |
+
(length filter, language ID, chat-template validation, dedup) all
|
| 59 |
+
work on CPU. We avoid the NeMo-Curator GPU dependency entirely.
|
| 60 |
+
|
| 61 |
+
3. **Streaming-friendly.** Op graph is a DAG; we can pipe `replay_trace`
|
| 62 |
+
output into the graph during generation, not as a post-hoc pass. This
|
| 63 |
+
matters for cost discipline — bad teacher outputs get filtered before
|
| 64 |
+
contributing to OpenRouter spend on subsequent steps.
|
| 65 |
+
|
| 66 |
+
4. **YAML-recipe driven.** Recipes live in `recipes/replaysim/` and can
|
| 67 |
+
be version-controlled. A user can swap normalization recipes without
|
| 68 |
+
touching framework code.
|
| 69 |
+
|
| 70 |
+
## Consequences
|
| 71 |
+
|
| 72 |
+
### Accepted
|
| 73 |
+
|
| 74 |
+
- New module `composer_replication.replaysim` lifts the existing
|
| 75 |
+
`teacher_replay` logic out of the package's flat namespace and adds:
|
| 76 |
+
- `composer_replication.replaysim.normalize` — `DJNormalizer` adapter
|
| 77 |
+
that wraps `data-juicer` op graphs around `replay_trace` output
|
| 78 |
+
- `recipes/replaysim/default.yaml` — base normalization recipe (length
|
| 79 |
+
filter + chat-template validation + per-turn dedup)
|
| 80 |
+
- Optional `recipes/replaysim/with_disagreement_filter.yaml` — adds a
|
| 81 |
+
semantic-similarity filter that drops "false disagreements" where
|
| 82 |
+
teachers used different wording for the same answer
|
| 83 |
+
- New optional dependency `[replaysim]` extra in `pyproject.toml`:
|
| 84 |
+
`pip install -e .[replaysim]` pulls `data-juicer`. Core install
|
| 85 |
+
doesn't require it.
|
| 86 |
+
- The existing `replay_trace` and `extract_dpo_pairs` keep their
|
| 87 |
+
signatures. The normalizer is opt-in via a `normalizer=` kwarg on a
|
| 88 |
+
new `replay_and_normalize_trace` convenience function.
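
To make the default recipe concrete, here is a minimal pure-Python sketch of its three passes (length filter, chat-template validation, dedup) written as a generator so it can run while pairs are still being produced. It does not use data-juicer. `Pair` is a hypothetical stand-in for the framework's `DPOPair`, the character limits are illustrative, and the dedup is simplified to exact matches on the whole pair.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator


@dataclass(frozen=True)
class Pair:
    """Hypothetical stand-in for DPOPair; field names are assumptions."""
    messages: tuple  # tuple of (role, content) tuples forming the prompt
    chosen: str
    rejected: str


VALID_ROLES = {"system", "user", "assistant"}


def normalize(pairs: Iterable[Pair], min_chars: int = 1,
              max_chars: int = 8000) -> Iterator[Pair]:
    """Lazily yield only the pairs that pass all three checks."""
    seen = set()
    for p in pairs:
        # Length filter: both completions must fall inside the character window.
        if not all(min_chars <= len(t) <= max_chars for t in (p.chosen, p.rejected)):
            continue
        # Chat-template validation: known roles only, prompt ends on a user turn.
        roles = [role for role, _ in p.messages]
        if not roles or roles[-1] != "user" or any(r not in VALID_ROLES for r in roles):
            continue
        # Dedup: keep the first occurrence of an identical (prompt, chosen, rejected).
        key = (p.messages, p.chosen, p.rejected)
        if key in seen:
            continue
        seen.add(key)
        yield p
```

Because `normalize` is a generator, a caller can stop consuming it as soon as a budget is reached, which is the streaming property the decision relies on.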
|
| 89 |
+
|
| 90 |
+
### One-day spike before merge
|
| 91 |
+
|
| 92 |
+
`pair_preference_mapper` in data-juicer might unconditionally re-synthesize
|
| 93 |
+
the `rejected` text via an LLM call. We already have `rejected` from
|
| 94 |
+
teacher disagreement and don't want to pay another API call. The recon
|
| 95 |
+
flagged this — verify by reading the mapper's source, and if it's LLM-bound,
|
| 96 |
+
substitute a plain validator that checks the field exists + isn't empty.
|
| 97 |
+
|
| 98 |
+
If the spike fails (the mapper IS LLM-bound and isn't easily replaceable),
|
| 99 |
+
fall back to writing a custom `DJOp` subclass that validates pre-existing
|
| 100 |
+
DPO pairs without re-synthesis. ~50 LOC.
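
The "plain validator" fallback described above could look roughly like the following. This is a framework-agnostic sketch over a plain dict; the field names `chosen` and `rejected` are assumptions, and wiring it into a data-juicer op is deliberately left out because that depends on the outcome of the spike.

```python
def has_usable_pair(sample: dict, fields: tuple = ("chosen", "rejected")) -> bool:
    """True when every listed field is present and is a non-blank string."""
    return all(
        isinstance(sample.get(name), str) and bool(sample[name].strip())
        for name in fields
    )
```

No LLM call is involved, so the check adds no API spend on top of the replay itself.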
|
| 101 |
+
|
| 102 |
+
### Rejected paths
|
| 103 |
+
|
| 104 |
+
- **`datatrove`**: would have required hand-rolling all chat-template logic
|
| 105 |
+
on top of flat-text ops. Bigger ongoing maintenance cost than
|
| 106 |
+
data-juicer's native multi-turn support.
|
| 107 |
+
- **`nemo-curator`**: GPU-mandatory ops mean we'd need to pay for GPU during
|
| 108 |
+
dataset generation (separate from the replay phase, which is already
|
| 109 |
+
GPU-free). Net cost increase for no quality win.
|
| 110 |
+
- **`distilabel`**: too broad — its pipeline abstraction would replace our
|
| 111 |
+
`replay_trace` entirely. We'd lose direct OpenRouter cost control + the
|
| 112 |
+
audit trail. Possible v0.3 migration if data-juicer becomes a bottleneck.
|
| 113 |
+
|
| 114 |
+
### Future work
|
| 115 |
+
|
| 116 |
+
- v0.2: add a `recipes/replaysim/altered_minds.yaml` for the user's
|
| 117 |
+
`altered-minds` workstream tie-in (per Wave 13 expansion)
|
| 118 |
+
- v0.3: revisit if `distilabel` becomes more mature and the migration
|
| 119 |
+
cost vs ongoing-maintenance balance shifts
|
| 120 |
+
|
| 121 |
+
## Source
|
| 122 |
+
|
| 123 |
+
`docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md` (2026-05-26
|
| 124 |
+
subagent recon, primary-sourced from each repo's GitHub + DeepWiki).
|
|
@@ -0,0 +1,142 @@
|
| 1 |
+
# ADR-005 — Decoupled DiLoCo over serverless training systems
|
| 2 |
+
|
| 3 |
+
**Status**: Accepted
|
| 4 |
+
**Date**: 2026-05-26
|
| 5 |
+
**Wave**: 13
|
| 6 |
+
|
| 7 |
+
## Context
|
| 8 |
+
|
| 9 |
+
The brief's V2 clause says:
|
| 10 |
+
|
| 11 |
+
> take that and combine it with diloco (decoupled, open, any variant of diloco)
|
| 12 |
+
|
| 13 |
+
The user expanded 2026-05-26: *"Decoupled DiLoCo (so that we can leverage
|
| 14 |
+
modal or huggingface-jobs or other serverless training systems). we need
|
| 15 |
+
this both on the dataset generation and the RL orchestration side of
|
| 16 |
+
things."*
|
| 17 |
+
|
| 18 |
+
Spike 008 wrote `composer_replication.diloco.make_diloco_outer_loop`
|
| 19 |
+
(wraps `torchft.local_sgd.DiLoCo`) but that's a single-process API. To
|
| 20 |
+
realize "Decoupled DiLoCo across serverless executors" we need:
|
| 21 |
+
|
| 22 |
+
1. An abstraction layer that lets the framework launch N replicas on
|
| 23 |
+
different serverless backends (Modal, HF Jobs, SageMaker, etc.) without
|
| 24 |
+
per-backend code in the trainer.
|
| 25 |
+
2. A communication primitive that doesn't require inter-job NCCL/RDMA
|
| 26 |
+
(most serverless executors don't expose that, and DiLoCo doesn't need
|
| 27 |
+
it — sync happens once per ~500-1000 inner steps).
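
For readers unfamiliar with why the sync is so infrequent, the outer step can be sketched in a few lines. Each replica trains locally for H inner steps, then reports a pseudo-gradient (global parameters minus its local parameters); the outer optimizer applies the average. The sketch uses plain lists and a plain SGD outer step for brevity, whereas the DiLoCo paper uses a Nesterov-momentum outer optimizer. `outer_step` is a hypothetical name, not the `torchft` API.

```python
def outer_step(global_params: list[float],
               replica_params: list[list[float]],
               outer_lr: float = 1.0) -> list[float]:
    """One DiLoCo-style outer update from the replicas' locally trained parameters."""
    n = len(replica_params)
    updated = []
    for i, g in enumerate(global_params):
        # Pseudo-gradient for coordinate i, averaged over replicas.
        pseudo_grad = sum(g - local[i] for local in replica_params) / n
        updated.append(g - outer_lr * pseudo_grad)
    return updated
```

With `outer_lr=1.0` the update reduces to averaging the replicas' parameters, which is a convenient sanity check for any communication primitive that implements the mean.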
|
| 28 |
+
|
| 29 |
+
## Options considered
|
| 30 |
+
|
| 31 |
+
`docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md` audited 6 executors:
|
| 32 |
+
|
| 33 |
+
| Executor | Inter-job network | Cold start | $/A100·hr | $/H100·hr |
|
| 34 |
+
|---|---|---|---|---|
|
| 35 |
+
| Modal | yes (cluster mode) | ~30s | $1.95 | $5.50 |
|
| 36 |
+
| HuggingFace Jobs | no | ~60s | $4.18 | $9.50 |
|
| 37 |
+
| AWS SageMaker training | yes (warm pools) | ~3-5min | ~$3.06 | ~$8.50 |
|
| 38 |
+
| GCP Vertex AI | yes (cluster) | ~5-10min | ~$3.67 | ~$10 |
|
| 39 |
+
| Azure ML | yes (cluster) | ~5-10min | ~$3.67 | ~$10 |
|
| 40 |
+
| k8s + Volcano/KubeRay | yes (cluster IP) | ~30-90s | (BYO) | (BYO) |
|
| 41 |
+
|
| 42 |
+
Most expose a "spin up a job, run a script" interface. Few expose inter-job
|
| 43 |
+
networking; the ones that do require explicit cluster mode (extra cost +
|
| 44 |
+
config).
|
| 45 |
+
|
| 46 |
+
## Decision
|
| 47 |
+
|
| 48 |
+
**Adopt object-store rendezvous as the default DiLoCo communication
|
| 49 |
+
primitive across all serverless executors.** Specifically:
|
| 50 |
+
|
| 51 |
+
- `composer_replication.diloco.serverless` package
|
| 52 |
+
- `class ServerlessExecutor(Protocol)` — uniform interface with
|
| 53 |
+
`launch_replicas / poll / stream_logs / cancel / collect /
|
| 54 |
+
backend_name / supports_inter_replica_network`
|
| 55 |
+
- `class ObjectStoreAllReduce` — fsspec-backed pseudo-gradient exchange
|
| 56 |
+
using s3:// / gs:// / az:// / hf:// / file:// — single code path, swappable
|
| 57 |
+
bucket
|
| 58 |
+
- v0 concrete adapters: `ModalExecutor` and `HFJobsExecutor`
|
| 59 |
+
- v0.1+ adapters: `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`
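
A minimal sketch of what the uniform interface could look like as a `typing.Protocol` follows. The member names are the ones listed above; every signature and return type is an assumption made for illustration, and `InMemoryExecutor` is a toy backend invented here to show structural typing, not one of the framework's adapters.

```python
from typing import Iterator, Protocol, Sequence, runtime_checkable


@runtime_checkable
class ServerlessExecutor(Protocol):
    """Member names follow the ADR; the signatures are assumptions."""

    backend_name: str
    supports_inter_replica_network: bool

    def launch_replicas(self, n: int, entrypoint: str) -> Sequence[str]: ...
    def poll(self, handle: str) -> str: ...
    def stream_logs(self, handle: str) -> Iterator[str]: ...
    def cancel(self, handle: str) -> None: ...
    def collect(self, handle: str) -> dict: ...


class InMemoryExecutor:
    """Toy backend that 'finishes' replicas instantly; only exercises the interface."""

    backend_name = "in-memory"
    supports_inter_replica_network = False

    def __init__(self) -> None:
        self._status: dict[str, str] = {}

    def launch_replicas(self, n: int, entrypoint: str) -> Sequence[str]:
        handles = [f"{entrypoint}#{i}" for i in range(n)]
        self._status.update({h: "succeeded" for h in handles})
        return handles

    def poll(self, handle: str) -> str:
        return self._status[handle]

    def stream_logs(self, handle: str) -> Iterator[str]:
        yield f"{handle}: done"

    def cancel(self, handle: str) -> None:
        self._status[handle] = "cancelled"

    def collect(self, handle: str) -> dict:
        return {"handle": handle, "status": self._status[handle]}
```

Because the protocol is structural, a backend does not need to inherit from anything; the trainer can accept any object that exposes these members.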
|
| 60 |
+
|
| 61 |
+
### Why object-store rendezvous (not NCCL across jobs)
|
| 62 |
+
|
| 63 |
+
DiLoCo paper (arXiv:2311.08105) shows the outer-loop sync is **once per
|
| 64 |
+
H = 500-1000 inner steps**, equivalent to ~10-30 minutes of wall-clock at
|
| 65 |
+
typical post-training step rates. For a 1B-param model in bf16:
|
| 66 |
+
|
| 67 |
+
- Pseudo-gradient size: ~2 GB per replica per outer round
|
| 68 |
+
- Sync frequency: ~once per 30 minutes
|
| 69 |
+
- Therefore: ~2 GB × N_replicas, every ~30 min, durably written to object
|
| 70 |
+
storage with a single `PutObject` per replica + `GetObject` per other
|
| 71 |
+
replica
|
| 72 |
+
|
| 73 |
+
Even with N=8 replicas, that's 16 GB written + 14 GB × 8 = 112 GB read, 128 GB total
|
| 74 |
+
spread over 30 minutes = ~70 MB/s aggregate. **S3 free-tier handles this
|
| 75 |
+
without breaking a sweat**, and S3 cross-job reads cost ~$0.0001 per
|
| 76 |
+
GET. Total inter-replica communication cost: ~$0.05 per outer round.
|
| 77 |
+
**Negligible compared to GPU spend.**
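
The traffic estimate is easy to reproduce. The sketch below recomputes it under the same assumptions as the text (1B parameters, 2 bytes per parameter in bf16, every replica reads every other replica's upload, one round per 30 minutes). `rendezvous_traffic` is a hypothetical helper, and it deliberately says nothing about pricing.

```python
def rendezvous_traffic(params: float, bytes_per_param: int,
                       replicas: int, period_s: float) -> dict:
    """Object-store traffic for one outer round of all-to-all pseudo-gradient exchange."""
    payload_gb = params * bytes_per_param / 1e9           # one replica's upload
    written_gb = payload_gb * replicas                    # one PutObject per replica
    read_gb = payload_gb * (replicas - 1) * replicas      # each replica reads all peers
    total_gb = written_gb + read_gb
    return {
        "payload_gb": payload_gb,
        "written_gb": written_gb,
        "read_gb": read_gb,
        "total_gb": total_gb,
        "mb_per_s": total_gb * 1000 / period_s,           # averaged over the period
    }
```

Calling `rendezvous_traffic(1e9, 2, 8, 30 * 60)` gives 16 GB written, 112 GB read, 128 GB in total, and roughly 71 MB/s when averaged over the 30-minute window. Read volume grows with N², so the same helper is useful for checking larger replica counts.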
|
| 78 |
+
|
| 79 |
+
By contrast, cross-job NCCL would require:
|
| 80 |
+
- Inter-job networking (mostly unavailable on serverless)
|
| 81 |
+
- Sustained low-latency connections (vs. burst-IO once per 30min)
|
| 82 |
+
- Backend-specific cluster mode (extra cost and configuration on the few platforms that offer it)
|
| 83 |
+
|
| 84 |
+
Object-store rendezvous decouples the algorithm from the executor and
|
| 85 |
+
matches DiLoCo's actual communication profile.
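
The rendezvous pattern itself is small. Below is a stdlib-only sketch that uses a local directory as the "bucket" and threads as stand-ins for replicas: each rank publishes its vector under a per-round prefix, polls until all ranks have published, then averages. It is not the framework's `ObjectStoreAllReduce`; the key layout, JSON encoding, and function names are assumptions, and a real version would go through fsspec rather than `pathlib`.

```python
import json
import os
import tempfile
import threading
import time
from pathlib import Path


def object_store_mean(root: Path, round_id: int, rank: int, world: int,
                      values: list[float], timeout_s: float = 30.0) -> list[float]:
    """Publish this replica's vector, wait for all peers, return the element-wise mean."""
    round_dir = root / f"round_{round_id:06d}"
    round_dir.mkdir(parents=True, exist_ok=True)

    # Write under a temp name, then rename, so readers never see a partial object.
    tmp = round_dir / f".rank_{rank}.tmp"
    tmp.write_text(json.dumps(values))
    os.replace(tmp, round_dir / f"rank_{rank}.json")

    # Barrier: poll until every rank has published for this round.
    deadline = time.monotonic() + timeout_s
    paths = [round_dir / f"rank_{r}.json" for r in range(world)]
    while not all(p.exists() for p in paths):
        if time.monotonic() > deadline:
            raise TimeoutError(f"round {round_id}: not all {world} replicas reported")
        time.sleep(0.01)

    vectors = [json.loads(p.read_text()) for p in paths]
    return [sum(column) / world for column in zip(*vectors)]


def run_demo() -> list[list[float]]:
    """Three threads stand in for three replicas contributing 0, 1 and 2."""
    results: list = [None] * 3
    with tempfile.TemporaryDirectory() as bucket:
        def worker(rank: int) -> None:
            results[rank] = object_store_mean(Path(bucket), 0, rank, 3, [float(rank)])

        threads = [threading.Thread(target=worker, args=(r,)) for r in range(3)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    return results
```

Keying objects by round number is what lets consecutive rounds reuse the same bucket without one round's uploads being mistaken for the next.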
|
| 86 |
+
|
| 87 |
+
### Why Modal + HF Jobs as the v0 executors
|
| 88 |
+
|
| 89 |
+
- **Modal**: best dev velocity, sub-minute cold start, mature Python SDK,
|
| 90 |
+
user already has CLI configured. Gives us a fast iteration loop for the
|
| 91 |
+
serverless layer.
|
| 92 |
+
- **HuggingFace Jobs**: zero acquisition cost (HF token already wired up),
|
| 93 |
+
brand-aligned with the framework's HF-native posture, ~$4.18/A100·hr.
|
| 94 |
+
Not the cheapest, but the right "default executor for HF users."
|
| 95 |
+
|
| 96 |
+
These two cover the spectrum of "fast for development" + "natural HF
|
| 97 |
+
integration." Other executors are documented and stubbed but not
|
| 98 |
+
implemented in v0.
|
| 99 |
+
|
| 100 |
+
## Consequences
|
| 101 |
+
|
| 102 |
+
### Accepted
|
| 103 |
+
|
| 104 |
+
- New package `composer_replication.diloco.serverless`:
|
| 105 |
+
- `executor.py` — `ServerlessExecutor` Protocol + base class
|
| 106 |
+
- `allreduce.py` — `ObjectStoreAllReduce` mockManager that drops into
|
| 107 |
+
`make_diloco_outer_loop` with no changes to the existing wrapper
|
| 108 |
+
- `modal.py` — `ModalExecutor` (~150 LOC)
|
| 109 |
+
- `hf_jobs.py` — `HFJobsExecutor` (~150 LOC)
|
| 110 |
+
- `replica_entrypoint.py` — the script each replica runs (loaded from
|
| 111 |
+
HF Datasets / object store)
|
| 112 |
+
- New optional dependency `[serverless]` extra: `pip install -e .[serverless]`
|
| 113 |
+
pulls `fsspec`, `s3fs`, `huggingface_hub` (already a transitive dep), and
|
| 114 |
+
`modal-client` (only if user opts in to Modal).
|
| 115 |
+
- Smoke test in `spikes/009-decoupled-diloco/` (new, deferred — not part
|
| 116 |
+
of this wave's commit) — local-only `file://` rendezvous between two
|
| 117 |
+
Python processes in `tests/test_serverless_local.py`. Multi-cloud test
|
| 118 |
+
is post-replication.
|
| 119 |
+
|
| 120 |
+
### Open / deferred
|
| 121 |
+
|
| 122 |
+
- **Real serverless smoke**: spinning up 2 Modal containers + S3 rendezvous
|
| 123 |
+
+ verifying both converge. Deferred to a small-budget post-Wave-13 spike
|
| 124 |
+
($2-5 estimated). Not blocking for the v0 packaging.
|
| 125 |
+
- **HF Jobs API stability**: HF Jobs is a relatively new product. The
|
| 126 |
+
recon flagged "API may evolve through 2026"; we pin to a specific
|
| 127 |
+
`huggingface_hub` minor and bump deliberately.
|
| 128 |
+
|
| 129 |
+
### Trade-offs explicitly accepted
|
| 130 |
+
|
| 131 |
+
- We do NOT use Modal's cluster/RDMA mode in v0. That gives sub-second
|
| 132 |
+
cross-job NCCL but costs more and is Modal-only. Object-store rendezvous
|
| 133 |
+
is the right default; users on Modal who want faster sync can override.
|
| 134 |
+
- We do NOT support job-internal multi-GPU training in this layer. The
|
| 135 |
+
serverless layer is for **inter-replica** sync; intra-replica training
|
| 136 |
+
uses the existing `make_diloco_outer_loop` (which itself can wrap
|
| 137 |
+
multi-GPU FSDP via torchft).
|
| 138 |
+
|
| 139 |
+
## Source
|
| 140 |
+
|
| 141 |
+
`docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md` (2026-05-26 subagent
|
| 142 |
+
recon, primary-sourced from each provider's official docs + pricing pages).
|
|
@@ -0,0 +1,124 @@
|
| 1 |
+
# ADR-006 — RL framework strategy: TRL + VeRL + PRIME-RL
|
| 2 |
+
|
| 3 |
+
**Status**: Accepted
|
| 4 |
+
**Date**: 2026-05-26
|
| 5 |
+
**Wave**: 13
|
| 6 |
+
|
| 7 |
+
## Context
|
| 8 |
+
|
| 9 |
+
The brief's V3 clause names six substrates: **monarch, torchforge,
|
| 10 |
+
openenv, VeRL, TRL** (plus DiLoCo). Cross-model review (Wave 11) flagged
|
| 11 |
+
that V3 was thin on the RL-framework side: TRL has working code, VeRL has
|
| 12 |
+
a config skeleton, and Monarch/TorchForge/OpenEnv are research-only.
|
| 13 |
+
|
| 14 |
+
User's 2026-05-26 expansion: *"see if there are other frameworks that are
|
| 15 |
+
more popular that we could try to use. meta's pytorch agentic stack
|
| 16 |
+
components are something that I'd like to explore."*
|
| 17 |
+
|
| 18 |
+
`docs/research/RL_FRAMEWORKS_LANDSCAPE.md` audited:
|
| 19 |
+
- 6 RL frameworks: OpenRLHF, PRIME-RL, NeMo-Aligner, Unsloth, LLaMA-Factory,
|
| 20 |
+
DeepSpeed-Chat
|
| 21 |
+
- 4 Meta PyTorch stack components: Monarch, TorchTitan, TorchForge, torchchat
|
| 22 |
+
|
| 23 |
+
## Options considered
|
| 24 |
+
|
| 25 |
+
| Framework | License | GRPO/DAPO? | Custom-loss extension | Verdict |
|
| 26 |
+
|---|---|---|---|---|
|
| 27 |
+
| OpenRLHF | Apache-2 | ✅ DAPO | Fork `openrlhf/models/loss.py` + Trainer subclass (~400-600 LOC) | Strong but heavyweight |
|
| 28 |
+
| **PRIME-RL** | **Apache-2** | **✅ GRPO + DAPO** | **First-class `CustomLossConfig` with `LossInputs` struct (~200-300 LOC)** | **Chosen** |
|
| 29 |
+
| NeMo-Aligner | Apache-2 | ❌ no GRPO/DAPO | n/a | Reject |
|
| 30 |
+
| Unsloth | Apache-2 | TRL patcher | Closed `unsloth_zoo` loss kernels — unhookable | Reject |
|
| 31 |
+
| LLaMA-Factory | Apache-2 | ❌ delegates to EasyR1 | n/a | Reject |
|
| 32 |
+
| DeepSpeed-Chat | Apache-2 | ❌ PPO+DPO only | feature-stale since 2023 | Reject |
|
| 33 |
+
|
| 34 |
+
| Meta stack | License | Active? | Role |
|
| 35 |
+
|---|---|---|---|
|
| 36 |
+
| **Monarch** | **BSD-3** | **✅ v0.4.1 stable, v0.5 dev** | **Actor mesh — coordination layer for any SPMD trainer** |
|
| 37 |
+
| TorchTitan | BSD-3 | ✅ active | Distributed-training stack (already a transitive dep of PRIME-RL) |
|
| 38 |
+
| TorchForge | BSD-3 | ❌ paused | Patterns only, per repo banner |
|
| 39 |
+
| torchchat | BSD-3 | active | Inference only — out of scope |
|
| 40 |
+
|
| 41 |
+
## Decision
|
| 42 |
+
|
| 43 |
+
**Add PRIME-RL as the third RL framework after TRL+VeRL, and Monarch as the
|
| 44 |
+
agentic-stack coordination layer.**
|
| 45 |
+
|
| 46 |
+
### Why PRIME-RL
|
| 47 |
+
|
| 48 |
+
PRIME-RL ships a **first-class `CustomLossConfig` with an `import_path`**
|
| 49 |
+
that lets us drop in a Python function returning a tensor. The config
|
| 50 |
+
exposes a `LossInputs` struct with exactly the tensors we need:
|
| 51 |
+
`trainer_logprobs`, `inference_logprobs`, `teacher_logprobs`,
|
| 52 |
+
`advantages`, `loss_mask`. This is **the cleanest possible extension
|
| 53 |
+
point for a 3-channel loss** — no fork, no Trainer subclass, no monkey-
|
| 54 |
+
patching.
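
To show why those five tensors are sufficient, here is a toy loss written against a stand-in struct. The `LossInputs` dataclass below only mirrors the field names quoted above and uses plain Python lists; it is not PRIME-RL's class, and the function is not the framework's `composer_loss.py`. The distillation term is a simplification (a per-token reverse-KL estimate rather than the generalized JSD used elsewhere in the framework), and the trace-replay DPO channel is omitted.

```python
import math
from dataclasses import dataclass


@dataclass
class LossInputs:
    """Stand-in with the field names from the text; shapes and types are assumptions."""
    trainer_logprobs: list[float]     # log-probs under the policy being trained
    inference_logprobs: list[float]   # log-probs recorded when the rollout was sampled
    teacher_logprobs: list[float]     # log-probs of the same tokens under the teacher
    advantages: list[float]
    loss_mask: list[bool]


def two_channel_loss(inputs: LossInputs, alpha: float = 0.2) -> float:
    """Masked mean of a policy-gradient term plus alpha times a distillation term."""
    total, count = 0.0, 0
    rows = zip(inputs.trainer_logprobs, inputs.inference_logprobs,
               inputs.teacher_logprobs, inputs.advantages, inputs.loss_mask)
    for lp, lp_inf, lp_teacher, adv, keep in rows:
        if not keep:
            continue
        ratio = math.exp(lp - lp_inf)   # importance weight for off-policy rollouts
        pg_term = -ratio * adv          # channel 1: advantage-weighted policy gradient
        distill_term = lp - lp_teacher  # channel 2: per-token reverse-KL estimate
        total += pg_term + alpha * distill_term
        count += 1
    return total / max(count, 1)
```

A real adapter would operate on tensors and return a differentiable scalar; the list-based version only makes the data flow visible.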
|
| 55 |
+
|
| 56 |
+
It also uses the `verifiers` env protocol (OpenEnv-compatible by design),
|
| 57 |
+
so it slots into the framework's existing data path without translation.
|
| 58 |
+
|
| 59 |
+
PRIME-RL was used to train INTELLECT-1 (10B base, 30 nodes) and INTELLECT-2
|
| 60 |
+
(32B QwQ); production-tested on real distributed runs.
|
| 61 |
+
|
| 62 |
+
### Why Monarch (not TorchForge or TorchTitan as a top-level)
|
| 63 |
+
|
| 64 |
+
- **Monarch is what's actually shipping** from Meta's agentic stack. v0.4.1
|
| 65 |
+
is stable, v0.5 dev daily. BSD-3.
|
| 66 |
+
- **TorchForge is paused** per its own repo banner. We document it
|
| 67 |
+
(research/03) but don't depend on it.
|
| 68 |
+
- **TorchTitan is a transitive dep** of PRIME-RL already, so we get its
|
| 69 |
+
benefits without needing to build a direct integration. If we wanted a
|
| 70 |
+
TorchTitan-only path, it would be redundant with PRIME-RL.
|
| 71 |
+
- **torchchat is inference-only** and doesn't fit the training-framework
|
| 72 |
+
conversation.
|
| 73 |
+
|
| 74 |
+
Monarch's role in our stack: **the actor mesh that hosts trainer/generator/
|
| 75 |
+
rewarder/judge actors**. PRIME-RL's three-actor split (trainer, generator,
|
| 76 |
+
rewarder) maps naturally onto Monarch primitives.
|
| 77 |
+
|
| 78 |
+
## Consequences
|
| 79 |
+
|
| 80 |
+
### Accepted
|
| 81 |
+
|
| 82 |
+
- `composer_replication/recipes/prime_rl/` directory:
|
| 83 |
+
- `prime_rl_recipe.md` — integration recipe (parallel to TRL Recipe A,
|
| 84 |
+
VeRL Recipe B)
|
| 85 |
+
- `composer_loss.py` — the 3-channel loss adapted to PRIME-RL's
|
| 86 |
+
`LossInputs` struct (~200-300 LOC)
|
| 87 |
+
- `prime_rl_config.yaml` — example PRIME-RL config wiring our loss in
|
| 88 |
+
- `composer_replication/recipes/monarch/` directory:
|
| 89 |
+
- `monarch_actor_layout.md` — design doc for the actor mesh
|
| 90 |
+
- `actors.py` — placeholder Monarch actor definitions (skeleton only;
|
| 91 |
+
full integration is post-replication)
|
| 92 |
+
- New optional dependencies in `pyproject.toml`:
|
| 93 |
+
- `[prime-rl]` extra: `prime-rl>=0.5,<0.6`
|
| 94 |
+
- `[monarch]` extra: `monarch>=0.4.1`
|
| 95 |
+
- `docs/V3_SUBSTRATE_COVERAGE.md` updated to reflect the new additions.
|
| 96 |
+
|
| 97 |
+
### Three-recipe production matrix
|
| 98 |
+
|
| 99 |
+
| User scenario | Recommended recipe |
|
| 100 |
+
|---|---|
|
| 101 |
+
| Quick start, single-cluster, ≤7B | TRL Recipe A |
|
| 102 |
+
| Production multi-node, ≤32B | VeRL Recipe B |
|
| 103 |
+
| Decentralized / DiLoCo-shape, any size | PRIME-RL recipe (NEW) |
|
| 104 |
+
| Coordination-heavy multi-actor RL | Monarch + any of the above |
|
| 105 |
+
|
| 106 |
+
### Trade-offs explicitly accepted
|
| 107 |
+
|
| 108 |
+
- **Three RL frameworks is a maintenance burden.** We accept this because
|
| 109 |
+
no single one covers all the user scenarios above. The framework's
|
| 110 |
+
contribution is the 3-channel loss + the trace-replay channel, expressed
|
| 111 |
+
in three different framework idioms. Each recipe is ~200-300 LOC; total
|
| 112 |
+
triplication tax ~700 LOC vs. picking one framework.
|
| 113 |
+
- **Monarch is BSD-3 not MIT.** The framework is MIT; users opting in to
|
| 114 |
+
Monarch take on its license. Documented in pyproject.toml's optional
|
| 115 |
+
extras.
|
| 116 |
+
- **PRIME-RL's API may evolve.** The `LossInputs` struct is currently the
|
| 117 |
+
contract; if PRIME-RL stabilizes a different shape we'd need to bump.
|
| 118 |
+
Pin to v0.5.x in our optional extras.
|
| 119 |
+
|
| 120 |
+
## Source
|
| 121 |
+
|
| 122 |
+
`docs/research/RL_FRAMEWORKS_LANDSCAPE.md` (2026-05-26 subagent recon,
|
| 123 |
+
primary-sourced from DeepWiki audits + GitHub repo READMEs + PyPI release
|
| 124 |
+
metadata).
|
|
@@ -0,0 +1,173 @@
|
| 1 |
+
# ADR-007: Self-distillation losses landscape and which to add

**Status**: Accepted
**Date**: 2026-05-26
**Wave**: 13

## Context

The framework currently has **one** distillation loss: `generalized_jsd_loss`
(verified port of `siyan-zhao/OPSD`, the kernel of SDPO arXiv:2601.20802,
Composer 2.5's "targeted RL with textual feedback").

User's 2026-05-26 expansion: *"if we can properly document and research the
self-distillation papers like SDPO OPDS and/or others that are related
then we can take stuff from there to help level up our training framework."*

`docs/research/SELF_DISTILLATION_LANDSCAPE.md` audited 8 candidate methods
across primary sources (arXiv abstracts + verified GitHub repos):

| Method | arXiv | License | Verdict |
|---|---|---|---|
| **SimPO** | **2405.14734** | **MIT, mature** | **Chosen: drop-in DPO replacement, no ref model** |
| KTO | 2402.01306 | Apache-2 (in trl) | Optional: only if channel-3 moves to per-step binary |
| Self-Rewarding LM | 2401.10020 | research | Reject: procedure, not loss |
| MiniLLM | 2306.08543 | MIT | Reject: same reverse-KL family as SDPO |
| GKD | 2306.13649 | research | Already lifted (= our `generalized_jsd_loss`) |
| DistiLLM | 2402.03898 | MIT | Reject: TAID dominates empirically |
| **TAID** | **2501.16937** | **Apache-2, mature** | **Chosen: wraps existing JSD with annealed teacher** |
| **Entropy-Aware OPD** | **ICLR 2026 Spotlight** | **(code release pending)** | **Chosen: token-wise gated forward/reverse KL** |

## Decision

**Add three composable self-distillation losses to the framework as a
pluggable distillation module:**

1. **SimPO**: reference-free DPO replacement for channel 3
2. **TAID**: annealed teacher interpolation that wraps existing JSD/SDPO
3. **Entropy-Aware OPD**: token-wise mixture of forward and reverse KL

### Why these three (and not the others)

#### SimPO (chosen)
- **Reference-free DPO**: removes the ref-model VRAM cost (which is the
  single biggest memory tax of standard DPO).
- Uses average sequence log-prob with target margin γ instead of
  ref-policy logits.
- ~80 LOC. MIT licensed.
- **Composes**: drop-in for channel 3 (`trace_replay_dpo`). Our DPO and
  SimPO are interchangeable at the loss level: both consume `(chosen,
  rejected)` pairs and emit a scalar. SimPO drops the ref logprobs from
  the input dict.
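
The SimPO objective described above can be sketched for a single preference pair. This is a dependency-free illustration in plain Python, not the repository's `simpo_loss`; the names `avg_logprob` and `simpo_pair_loss` are hypothetical, and the formula follows the published SimPO objective, `-log σ(β · (avg_lp_chosen − avg_lp_rejected) − γ)`.

```python
import math


def avg_logprob(token_logprobs):
    """Length-normalised sequence log-probability (mean over tokens)."""
    return sum(token_logprobs) / len(token_logprobs)


def simpo_pair_loss(chosen_token_lps, rejected_token_lps, beta=2.0, gamma=1.0):
    """SimPO loss for one (chosen, rejected) pair; no reference model involved.

    Hypothetical helper for illustration only.
    """
    margin = beta * (avg_logprob(chosen_token_lps) - avg_logprob(rejected_token_lps)) - gamma
    # -log(sigmoid(margin)) == softplus(-margin), written in a numerically stable form.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# When the scaled log-prob gap exactly equals gamma, the margin is 0 and the loss is log(2).
loss_at_margin = simpo_pair_loss([-0.5, -0.5], [-1.0, -1.0], beta=2.0, gamma=1.0)
```

Because only the policy's own average log-probabilities enter the margin, nothing from a frozen reference model has to be kept in memory, which is the VRAM saving the first bullet refers to.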

#### TAID (chosen)
- **Temporally Adaptive Interpolated Distillation**: wraps the existing JSD with a
  schedule that interpolates between identity (student-only target) and
  teacher target over training. Per the TAID paper, this provably prevents
  mode collapse on large-capacity-gap distillation.
- ~150 LOC. Apache-2.
- **Composes**: TAID *wraps* `generalized_jsd_loss`, doesn't replace it.
  The planned `compose_loss` integration (Wave 14+, see the section below)
  would add a `taid_alpha` schedule kwarg; when 0 it's pure SDPO, when
  scheduled it's TAID-SDPO.
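
The interpolation idea can be illustrated for a single token position. The sketch below makes several labelled simplifications: it uses a linear schedule (the TAID paper adapts the interpolation weight during training), it uses plain KL where the repository wraps `generalized_jsd_loss`, and every function name is hypothetical. In an autograd setting the student logits inside the target would be detached; plain Python has no gradients, so that is only noted in a comment.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def linear_t(step, total_steps, t_start=0.0, t_end=1.0):
    """Assumed linear schedule for the interpolation weight t."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return t_start + (t_end - t_start) * frac


def taid_target(student_logits, teacher_logits, t):
    """Intermediate target: softmax of logits mixed with weight t in [0, 1].

    t = 0 gives the student's own distribution, t = 1 gives the teacher's.
    (With autograd, student_logits would be detached here.)
    """
    mixed = [(1.0 - t) * s + t * q for s, q in zip(student_logits, teacher_logits)]
    return softmax(mixed)


def taid_token_loss(student_logits, teacher_logits, step, total_steps):
    t = linear_t(step, total_steps)
    target = taid_target(student_logits, teacher_logits, t)
    return kl(target, softmax(student_logits))
```

At step 0 the target equals the student distribution and the loss vanishes; at the final step the loss reduces to ordinary distillation against the teacher.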

#### Entropy-Aware OPD (chosen, with caveat)
- **Token-wise gated mixture** of forward and reverse KL based on
  per-token teacher entropy. Directly fixes a documented failure mode of the
  reverse-KL family (which SDPO/OPSD belongs to).
- ICLR 2026 Spotlight. **Code release pending** as of 2026-05-26.
- ~120 LOC.
- **Composes**: also wraps `generalized_jsd_loss`, but with a per-token
  weighting tensor instead of a global schedule.
- **Caveat**: we'll vendor a clean-room implementation from the paper
  pseudocode until the official code drops. License question: vendoring
  from a paper's pseudocode is fair use; redistributing the official code
  when it drops requires checking its license.
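
To make "token-wise gated mixture" concrete, here is a toy version for one token position. The gating function is an assumption made purely for illustration: this ADR does not reproduce the paper's gate, so the sketch uses teacher entropy normalised by `log(vocab_size)` and gives forward KL more weight where the teacher is uncertain. All names are hypothetical, and this is not the repository's `entropy_aware_opd_loss`.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def entropy(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)


def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def teacher_entropy_gate(p_teacher):
    """Assumed gate in [0, 1]: 0 for a one-hot teacher, 1 for a uniform teacher."""
    return entropy(p_teacher) / math.log(len(p_teacher))


def entropy_gated_kl(student_logits, teacher_logits):
    """Per-token mixture: gate * forward KL + (1 - gate) * reverse KL."""
    p_teacher = softmax(teacher_logits)
    p_student = softmax(student_logits)
    gate = teacher_entropy_gate(p_teacher)
    forward = kl(p_teacher, p_student)  # mass-covering term
    reverse = kl(p_student, p_teacher)  # mode-seeking term
    return gate * forward + (1.0 - gate) * reverse
```

A sequence-level loss would average this quantity over token positions, which is where the per-token weighting tensor mentioned above comes in.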

### Why we explicitly reject the others

- **GKD**: already lifted as `generalized_jsd_loss`. No additional value.
- **DistiLLM**: skew-KL is in the same reverse-KL family. TAID dominates
  it empirically per the TAID paper.
- **MiniLLM**: same reverse-KL recipe as SDPO. We already have SDPO.
- **Self-Rewarding LM**: a procedure (model judges its own outputs to
  generate preference pairs), not a loss. If we want self-judging, that's
  a separate spike on the trace-replay side, not a loss-channel addition.
- **KTO**: only useful if the channel-3 shape moves from preference pairs
  to per-step binary signals. Not currently in scope. Documented as a
  fallback for future use.

## Consequences

### Accepted

- New module `composer_replication.distillation`:
  - `__init__.py`: re-exports the three new losses
  - `simpo.py`: `simpo_loss(chosen_lp, rejected_lp, beta, gamma)` (~80 LOC)
  - `taid.py`: `taid_loss(student_logits, teacher_logits, alpha,
    schedule_step, total_steps, **jsd_kwargs)` (~150 LOC)
  - `entropy_aware_opd.py`: `entropy_aware_opd_loss(student_logits,
    teacher_logits, **jsd_kwargs)` (~120 LOC)
- `tests/test_distillation_losses.py`: 17 sanity tests (loss is finite,
  differentiable, returns scalar, matches paper formulas at boundary
  conditions)

### Wave 14+ work: `compose_loss` integration is NOT in this wave

An earlier draft of this ADR claimed `composer_replication.compose_loss`
would receive new kwargs (`dpo_variant`, `sdpo_wrapper`, `taid_schedule_step`,
`taid_total_steps`). **The Wave 13 cross-model review
(docs/research/WAVE_13_FINAL_REVIEW.md Finding 2) flagged that those
kwargs were never actually added to `compose_loss`**: the standalone
losses landed, but the integration into the framework's loss composition
is not done. To stay honest:

- **What works in Wave 13**: `from composer_replication.distillation
  import simpo_loss, taid_loss, entropy_aware_opd_loss`. All three are
  importable, type-checked, unit-tested, and ready to be called directly.
- **What does NOT work in Wave 13**: passing
  `compose_loss(model, batch, dpo_variant="simpo", sdpo_wrapper="taid", ...)`.
  That call signature does not exist; it would raise `TypeError`.
- **Wave 14 plan**: add the four kwargs to `compose_loss` with a small
  integration test exercising at least one combination (SDPO+TAID + plain
  DPO would suffice). Estimated ~30 LOC + 2-3 tests.

Users wanting the new losses *now* should use them as standalone
functions in their own loss-composition code:

```python
from composer_replication.distillation import simpo_loss, taid_loss

# Drop-in DPO replacement:
ch3 = simpo_loss(chosen_avg_lp, rejected_avg_lp, beta=2.0, gamma=1.0)

# TAID-wrapped SDPO (channel 2).
# NOTE: the signature listed under "Accepted" names `alpha` as the third
# parameter; confirm the argument order against taid.py before copying this.
ch2 = taid_loss(
    student_logits, teacher_logits, student_init_logits,
    schedule_step=trainer.state.step, total_steps=trainer.state.max_steps,
)

# `alpha` and `beta` here are channel weights, unrelated to SimPO's `beta`.
total = grpo_loss + alpha * ch2 + beta * ch3
```

This is identical to what the integrated path would do: the integration
is a convenience kwarg layer, not a different algorithm.

### `pyproject.toml` impact

No new deps: these are pure PyTorch losses on top of existing tensors.

### Trade-offs

- **Combinatorial complexity**: with three options for channel 2 and two
  options for channel 3, we have 6 distillation variants. We accept this
  because:
  - Defaults are sane (`dpo_variant="dpo"`, `sdpo_wrapper="none"`)
  - Each variant is independently unit-tested
  - Users opt into combinations explicitly
- **Entropy-Aware OPD is pre-code-release**: we vendor from paper
  pseudocode. Risk: our implementation might disagree with the official
  release. Mitigation: clean-room note in the source file; bump pin
  if/when official code drops.

### Future work

- v0.2: research **direct preference fine-tuning** variants (DRO, PRO,
  IPO) that might replace channel 3 entirely. These are off the chosen
  axis but might dominate.
- v0.3: integrate the three new losses with PRIME-RL's `CustomLossConfig`
  (per ADR-006) so users can mix-and-match across frameworks.

## Source

`docs/research/SELF_DISTILLATION_LANDSCAPE.md` (2026-05-26 subagent recon,
primary-sourced from arXiv + GitHub READMEs).

@@ -0,0 +1,791 @@
# DiLoCo Serverless Executor Reconnaissance

**Status:** Reconnaissance complete (feeds ADR-005).
**Audience:** ADR-005 author + framework integrator wiring `composer_replication.diloco.serverless` against real backends.
**Scope:** Decoupled DiLoCo across N independently-scheduled serverless GPU jobs. NOT a generic "serverless training" survey.
**Date:** 2026-05-26.

---

## TL;DR

| Executor | Inter-job net? | Cold start | $/A100·hr (1×) | $/H100·hr (1×) | Max parallel jobs | Ranking for Decoupled DiLoCo |
|---|---|---|---|---|---|---|
| **Modal** | ✅ `i6pn` + `@modal.experimental.clustered` (50 Gbps + RDMA up to 3.2 Tbps); also same-workspace TCP via shared `Dict`/`Queue` | ~1–10 s warm-boot; ≤90 s incl. image pull on first run | A100-40GB: $2.10; A100-80GB: $2.50 | H100: $3.95 | Workspace quota; Starter ≤10 GPU containers, Team much higher (contact) | **★★★★★** primary adapter |
| **HF Jobs** | ❌ No documented inter-job networking. Workaround: object store (HF Hub bucket / dataset / S3) | "starting" → "running" billed; per-min granularity; typical scheduling 10–60 s | A100-80GB: $2.50 (`a100-large`); 4×: $10.00; 8×: $20.00 | H200: $5.00; 8×H200: $40.00 (no H100 SKU) | Pro/Team/Enterprise quota; not publicly capped per-run (parallel via SDK loop) | **★★★★☆** secondary adapter; pseudo-grad via Hub bucket Volume |
| **AWS SageMaker Training Jobs** | ✅ Inside one *job's* multi-instance cluster (EFA/SMDDP). ❌ Across separate `CreateTrainingJob` invocations (same workaround as HF) | Image pull + EBS attach: typically 2–5 min cold; warm pools cut to ~10 s for ≤60 min | ml.p4d.24xlarge ≈ $32.77/hr (8×A100-40GB) ≈ $4.10/A100·hr | ml.p5.48xlarge ≈ $98.32/hr ≈ $12.29/H100·hr | Account quota (typical 4–20 instances; raise via Service Quotas) | **★★★☆☆** good for one big "fragment"; clunky as N-replicas-of-1-GPU |
| **GCP Vertex AI Custom Jobs** | ✅ Inside one CustomJob's worker pools (gRPC/MPI). ❌ Across separate jobs (same workaround) | 2–6 min typical cold | a2-highgpu-1g (1×A100-40GB) ≈ $3.67/hr (incl. Vertex training premium ~30–50%) | a3-highgpu-8g ≈ $88/hr ≈ $11/H100·hr | Per-region GPU quota | **★★☆☆☆** highest premium per GPU; useful as 3rd region |
| **Azure ML Command Jobs** | ✅ within `instance_count>1` (InfiniBand on `ND*`-series). ❌ across jobs (same workaround) | 3–8 min typical cold (image cache → curated env helps) | NC24ads_A100_v4 (1×A100-80GB): ~$3.67/hr (PAYG list) | ND96isr_H100_v5 (8×H100): ~$98/hr ≈ $12.25/H100·hr | Per-region quota, surcharge $0/core (only VM+disk) | **★★☆☆☆** like Vertex; useful only if user already lives in Azure |
| **k8s + Volcano / KubeRay** | ✅ if cluster networked. Volcano gang-schedules `RayJob`/MPIJob; pods see each other on cluster network | Pod schedule: seconds–minutes (image cache, GPU availability) | Whatever the underlying cluster pays (e.g. spot A100 ~$1–2/hr on RunPod / Lambda / OCI K8s) | Same | Cluster capacity | **★★★★☆** best price/perf if user owns/leases a cluster; ops cost nontrivial |
| **RunPod (honourable mention)** | ✅ same DC; no documented federation | seconds | ~$1.19/hr A100-80GB community, ~$2.17/hr secure | ~$1.99/hr H100 community, ~$4.18/hr secure | Account quota | **★★★☆☆** not in the candidate list, but a strong third adapter for cost |

The Decoupled DiLoCo framing kills the "must have inter-job allreduce" requirement: per the original DiLoCo paper (arXiv:2311.08105 §3.2), pseudo-gradients are exchanged only **once every H = 500–1000 inner steps**. Each exchange moves one model-sized tensor per replica (about 2 GB for a 1B-parameter model in bf16; see §1), but it is amortised over many minutes of compute. **Neither bandwidth nor latency is the bottleneck; the only hard requirement is "all N replicas can read & write a shared blob store."** That makes object-storage-based pseudo-gradient exchange the *correct* default, and the Modal `clustered`-style RDMA fabric a *bonus* you can opt into when a single executor runs ≥2 replicas in the same region.

**Recommendation: ship the framework with two adapters, `ModalExecutor` and `HFJobsExecutor`, both speaking the same `Executor` ABC, both using object-store pseudo-grad exchange by default. Add a third adapter (`RunPodExecutor` or `K8sExecutor`) when a user needs it.**

---

## 1. Why Decoupled DiLoCo over the network is *easy*

From DiLoCo (Douillard et al., *DiLoCo: Distributed Low-Communication Training of Language Models*, arXiv:2311.08105):

- **Setup.** N "workers" each train a full local copy of the model with an inner optimizer (AdamW, LR 4e-4, etc.) on disjoint shards of data.
- **Outer round (every H=500 steps in the paper, often 1000 in follow-ups).** Each worker computes its **pseudo-gradient** `δ_k = θ_initial − θ_local` (the negative of its accumulated local update). The N workers all-reduce the pseudo-gradient, average it, and the outer optimizer (Nesterov SGD, lr=0.7, momentum=0.9) applies it to `θ_initial` to produce `θ_initial^(t+1)`. Workers then reset to that new value.
- **Communication budget per round.** One full-model parameter tensor per worker (FP32, fp16, or bf16). For a 1B model in bf16, that's ~2 GB per worker per round. For Streaming DiLoCo (Douillard et al. 2025) the communication is sliced into fragments and overlapped with compute, but the *aggregate* per round is the same.
- **Communication frequency.** Once per H=500–1000 inner steps. With one inner step ≈ 1–3 s on a single A100/H100 for a 7B model, that's one outer round every **~10–30 minutes wall-clock** in the typical case (500 × 1 s ≈ 8 min at the low end, 1000 × 3 s = 50 min at the high end).
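
The outer round described in the bullets above can be sketched on flat parameter lists. This is a toy in plain Python under stated assumptions: the Nesterov update follows the convention used by PyTorch's SGD with `nesterov=True` and zero dampening, the averaged pseudo-gradient is treated as the gradient, and the function names are hypothetical rather than taken from `composer_replication.diloco`.

```python
def average_pseudo_gradients(theta_global, replica_thetas):
    """Mean over replicas of delta_k = theta_global - theta_k, element-wise."""
    n = len(replica_thetas)
    return [
        sum(g - theta_k[i] for theta_k in replica_thetas) / n
        for i, g in enumerate(theta_global)
    ]


def outer_nesterov_step(theta_global, avg_delta, velocity, lr=0.7, momentum=0.9):
    """One outer update; returns (new_theta_global, new_velocity)."""
    new_velocity = [momentum * v + d for v, d in zip(velocity, avg_delta)]
    new_theta = [
        th - lr * (d + momentum * v)
        for th, d, v in zip(theta_global, avg_delta, new_velocity)
    ]
    return new_theta, new_velocity


# Two replicas, two parameters, first outer round (velocity starts at zero).
theta0 = [1.0, 2.0]
replicas = [[0.0, 2.0], [1.0, 1.0]]
delta = average_pseudo_gradients(theta0, replicas)                 # [0.5, 0.5]
theta1, vel1 = outer_nesterov_step(theta0, delta, [0.0, 0.0])
```

After the step every replica would reload `theta1` and start its next H inner steps from there.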

The implication: **the outer-loop "allreduce" is a one-shot, model-sized upload+download (≈2 GB for 1B parameters in bf16, ≈14 GB for 7B) every 10+ minutes.** It does not need NCCL. It does not need RDMA. It does not even need TCP between the replicas. **An S3 `PutObject` followed by N `GetObject`s is sufficient.** Cross-region transfer at 1 Gbps moves 2 GB in ~17 s; even at 100 Mbps it's ~3 min, which is small compared to the H=500 inner-step interval. This is the key insight that makes "Modal + HuggingFace Jobs as DiLoCo replicas" a sensible architecture rather than a hack.
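
The arithmetic behind those transfer times is easy to check. The helpers below are hypothetical names for illustration; they use decimal gigabytes and ignore protocol overhead and request latency.

```python
def bf16_size_gb(num_params):
    """Size of a bf16 parameter tensor in decimal GB (2 bytes per parameter)."""
    return num_params * 2 / 1e9


def transfer_seconds(size_gb, link_gbps):
    """Raw transfer time for size_gb gigabytes over a link of link_gbps gigabits/s."""
    return size_gb * 8 / link_gbps


size_1b = bf16_size_gb(1e9)             # 2.0 GB for a 1B-parameter model
fast = transfer_seconds(size_1b, 1.0)   # 16 s at 1 Gbps
slow = transfer_seconds(size_1b, 0.1)   # about 160 s (under 3 min) at 100 Mbps
```

The raw figures (16 s and roughly 160 s) are of the same order as the ~17 s and ~3 min quoted above, and both are small next to an outer-round interval of several hundred seconds or more.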

We codify this in the framework with two communication backends:

1. **`InProcessAllReduce`**: what `composer_replication.diloco` already uses (torchft `Manager` mock). For unit tests and same-process/same-host runs.
2. **`ObjectStoreAllReduce`**: barriers + pseudo-grad averaging via S3/GCS/HF Hub bucket. New code for ADR-005. Expected per-round overhead 20–60 s for a 7B model, already amortised over 10–30 min of compute.

The torchft `Manager` interface (used by `torchft.local_sgd.DiLoCo`) only requires `.allreduce(tensor) → Work`, `.should_commit()`, `.start_quorum()`, `.current_step()`. We implement `.allreduce` on top of object storage. Done.
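
A minimal sketch of the object-store averaging idea, using a local directory as a stand-in for the bucket. It is not the repository's `ObjectStoreAllReduce`: the `round-{t}/replica-{i}.json` key layout, the JSON encoding, and the function names are assumptions chosen so the example runs with the standard library alone.

```python
import json
import os
import tempfile


def publish(root, round_idx, replica_id, delta):
    """Write one replica's pseudo-gradient under round-{t}/replica-{i}.json."""
    round_dir = os.path.join(root, f"round-{round_idx}")
    os.makedirs(round_dir, exist_ok=True)
    tmp = os.path.join(round_dir, f".replica-{replica_id}.tmp")
    final = os.path.join(round_dir, f"replica-{replica_id}.json")
    with open(tmp, "w") as f:
        json.dump(delta, f)
    os.replace(tmp, final)  # atomic rename: readers never see a partial file


def try_average(root, round_idx, num_replicas):
    """Return the element-wise mean once every replica has published, else None."""
    round_dir = os.path.join(root, f"round-{round_idx}")
    paths = [os.path.join(round_dir, f"replica-{i}.json") for i in range(num_replicas)]
    if not all(os.path.exists(p) for p in paths):
        return None
    deltas = []
    for p in paths:
        with open(p) as f:
            deltas.append(json.load(f))
    return [sum(col) / num_replicas for col in zip(*deltas)]


root = tempfile.mkdtemp()
publish(root, 0, 0, [0.0])
publish(root, 0, 1, [1.0])
incomplete = try_average(root, 0, num_replicas=3)   # third replica not yet in
publish(root, 0, 2, [2.0])
complete = try_average(root, 0, num_replicas=3)     # mean of {0, 1, 2}
```

A real backend would swap the local paths for object-store keys and wrap `try_average` in a polling loop; wrapping the result in a `Work`-like handle is what an `.allreduce` implementation on the torchft interface would still need.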

---

## 2. Per-executor audit

### 2.1 Modal: primary adapter

**Inter-job networking.** Yes, in two flavours.

- **`@modal.experimental.clustered(size=N, rdma=True)`**: gang-schedules N containers in the *same* Modal cluster, gives them i6pn IPv6 addresses, and (with `rdma=True`) provisions an RDMA fabric of up to 3,200 Gbps for inter-node communication ([modal.com/docs/guide/multi-node-training](https://modal.com/docs/guide/multi-node-training)). This is the right primitive for a *single-executor* multi-replica DiLoCo where all N replicas live on Modal.
- **i6pn private network** ([modal.com/docs/guide/private-networking](https://modal.com/docs/guide/private-networking)): any two `@app.function(i6pn=True)` containers in the same workspace+region can address each other over a 50 Gbps IPv6 fabric. It is region-scoped: Modal documents that "i6pn networking is region-scoped functionality."

**Cross-executor:** for the *cross-cloud* Decoupled DiLoCo case (Modal + HF + …), Modal containers reach out to S3/HF Hub/GCS like any other internet-connected workload. No Modal-specific magic needed.

**Cold start.** Modal's container infra warm-boots in ~1 s for a cached image; first-run pulls of a large PyTorch image dominate (30–90 s). HF model download adds 15–45 s for a 7B model from cold (cache on a `modal.Volume` after run 1). See `MODAL_RECONNAISSANCE.md` §1.3 in this repo for the same numbers from a different audit angle. Realistic per-run cold: **~60–120 s** on first launch, ~10–30 s on subsequent launches with warm image cache.

**$/GPU·hr (from <https://modal.com/pricing>, on-demand, base region, preemptible default).**

| GPU | Modal `gpu=` string | $/sec | $/hour |
|---|---|---|---|
| A100-40GB | `"A100-40GB"` | 0.000583 | **$2.099** |
| A100-80GB | `"A100-80GB"` | 0.000694 | **$2.498** |
| H100 (pinned) | `"H100!"` | 0.001097 | **$3.949** |
| H200 | `"H200"` | (see pricing page) | ~$4.5–5/hr per the published table |
| B200 | `"B200"` | n/a | ~$6/hr per the published table |

**Multipliers from same pricing page:** region pinning 1.5–1.75×, non-preemptible 3×. Default is preemptible, which for DiLoCo is *fine*: a preempted replica retries, and the outer loop tolerates an absent-this-round member by simply averaging over the survivors.
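
"Averaging over the survivors" can be stated precisely. In the sketch below a preempted replica is represented by `None`; the function name and that representation are assumptions for illustration, not the framework's API.

```python
def average_over_survivors(deltas_by_replica):
    """Element-wise mean of the pseudo-gradients that actually arrived this round.

    deltas_by_replica maps replica id -> list of floats, or None if that
    replica was preempted before publishing.
    """
    survivors = [d for d in deltas_by_replica.values() if d is not None]
    if not survivors:
        raise RuntimeError("no replica reported a pseudo-gradient this round")
    return [sum(col) / len(survivors) for col in zip(*survivors)]


# Replica 1 was preempted; the mean is taken over replicas 0 and 2 only.
avg = average_over_survivors({0: [0.0, 2.0], 1: None, 2: [2.0, 4.0]})
```

Dividing by the number of survivors rather than the nominal N keeps the scale of the outer update unchanged when a replica drops out.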

**Max concurrent jobs.** Modal documents "default limits on Modal free tier" of 10 GPU containers in [the Blender example](https://modal.com/docs/examples/blender_video) (`max_containers=10 if WITH_GPU else 100`). Paid plans scale far higher; clustered functions starting May 31, 2026 require 8 GPUs/node, capping at "up to 64 devices" per cluster (`@clustered`). Practically, for 8 single-A100 replicas of Decoupled DiLoCo, the Starter plan is limiting; Team plan ≥10 paid GPU containers handles it. Contact Modal support for >64-GPU clusters.

**Verified API for spinning up N parallel jobs** (verified pattern from `modal-examples` and Modal docs):

```python
# composer_replication/diloco/serverless/_modal_adapter.py
import modal

app = modal.App("diloco-replicas")
image = (
    modal.Image.debian_slim(python_version="3.11")
    .uv_pip_install("torch", "transformers", "torchft-nightly")
    .add_local_python_source("composer_replication")
)

@app.function(image=image, gpu="A100-40GB", timeout=60 * 60 * 24)
def run_inner_loop(replica_id: int, rendezvous_uri: str, config: dict):
    """One DiLoCo replica. Trains for N inner steps, then participates in
    one outer-round pseudo-gradient exchange via the rendezvous_uri (S3 path),
    repeats."""
    # Imported inside the function so it resolves in the remote container.
    from composer_replication.diloco.serverless import run_replica
    return run_replica(replica_id=replica_id,
                       rendezvous_uri=rendezvous_uri,
                       **config)

@app.local_entrypoint()
def main(num_replicas: int = 4):
    rendezvous_uri = "s3://my-bucket/diloco-run-2026-05-26/"
    config = {"model": "Qwen/Qwen2.5-7B", "outer_rounds": 100, "sync_every": 500}
    # .map / .starmap fans out N parallel container invocations.
    args = [(i, rendezvous_uri, config) for i in range(num_replicas)]
    results = list(run_inner_loop.starmap(args))
    print(f"All {num_replicas} replicas completed: {results}")
```

For the *single-executor RDMA* case (all N on Modal in one region, max throughput):

```python
# Reuses `app` and `image` from the adapter module above.
@app.function(image=image, gpu="H100:8", timeout=60 * 60 * 24)
@modal.experimental.clustered(size=4, rdma=True)
def diloco_cluster_train(rendezvous_uri: str, config: dict):
    from composer_replication.diloco.serverless import run_replica
    info = modal.experimental.get_cluster_info()
    # info.rank is our DiLoCo replica id; info.container_ips[0] is rank-0.
    return run_replica(replica_id=info.rank, rendezvous_uri=rendezvous_uri, **config)
```

**Right abstraction layer for the framework.** Modal Functions map to **one DiLoCo replica each**. The local entrypoint (or our `Executor.launch_replicas()`) does `.starmap` to fan out N. Inter-replica state lives in S3 (default) or in Modal-side `modal.Dict` / `modal.Queue` (faster, same-workspace only). The `@clustered` decorator is *not* required for Decoupled DiLoCo; it's an opt-in optimization for when you want one Modal cluster to be your whole training run.

**Rough $-per-replica-hour for an A100-40GB single-replica Modal run** (no clustering): 1 × $2.099 + ~$0.05 CPU/RAM overhead + ~$0.005 networking ≈ **$2.15/hr/replica**.
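
The per-replica estimate extends naturally to a whole run. The helpers below are a back-of-envelope sketch with hypothetical names; the GPU rate comes from the pricing table above, and the CPU/RAM and networking figures are the rough overheads quoted in the preceding paragraph, not published prices.

```python
A100_40GB_PER_SEC = 0.000583  # $/sec, from the Modal pricing table above


def replica_hour_cost(gpu_per_sec, cpu_ram_per_hr=0.05, network_per_hr=0.005):
    """Approximate $/hr for one single-GPU replica, including rough overheads."""
    return gpu_per_sec * 3600 + cpu_ram_per_hr + network_per_hr


def run_cost(num_replicas, hours, per_replica_hour):
    """Total $ for a run where every replica is billed for the full duration."""
    return num_replicas * hours * per_replica_hour


per_hour = replica_hour_cost(A100_40GB_PER_SEC)   # about 2.15
day_of_8 = run_cost(8, 24, per_hour)              # 8 replicas for 24 h
```

With these inputs an 8-replica, 24-hour run comes to a little over $410 before any region-pinning or non-preemptible multipliers.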

### 2.2 HuggingFace Jobs: secondary adapter

**Inter-job networking.** **No documented inter-job networking primitive.** HF Jobs is a Docker-Image-+-command service ([huggingface.co/docs/hub/en/jobs](https://huggingface.co/docs/hub/en/jobs)) modelled after `docker run`. There is no "address my peer job" API. Each job runs in its own pod with internet egress only; HF does not advertise a private VPC network.

**Workaround (the right one for DiLoCo).** HF Jobs supports **`Volume` mounts** of HF Hub repos and HF storage buckets ([huggingface.co/docs/huggingface_hub/en/guides/jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs)):

```python
from huggingface_hub import run_job, Volume
checkpoints_bucket = Volume(type="bucket", source="myorg/diloco-rendezvous", mount_path="/rendezvous")
job = run_job(image="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
              command=["python", "/code/run_replica.py", "--replica-id", "0"],
              flavor="a100-large",
              timeout="6h",
              volumes=[checkpoints_bucket])
```

The `bucket` volume is read+write by default, which suits object-store-based pseudo-gradient exchange. This is *exactly* the same workaround we'd apply to SageMaker, Vertex AI, Azure ML, but on HF it's first-class because `Volume(type="bucket", ...)` is built into the API.

**Cold start.** HF docs say "billing only when starting or running" (no charge during build). Empirically (per the HF quickstart logs), `hf jobs uv run` reports a state transition `created → starting → running` typically in **10–60 s** for a cached image, longer for first-pull of a large CUDA image. The default timeout is 30 minutes; use `timeout="6h"` or similar for DiLoCo.

**$/GPU·hr (from <https://huggingface.co/docs/hub/jobs-pricing>; per-minute billing).**

| Hardware flavor | Hourly | $/A100·hr | $/H100/H200·hr |
|---|---|---|---|
| `a100-large` (1× A100 80GB) | **$2.50** | $2.50 | n/a |
| `4xa100-large` (4× A100 80GB) | $10.00 | $2.50 | n/a |
| `8xa100-large` (8× A100 80GB) | $20.00 | $2.50 | n/a |
| `h200` (1× H200 141GB) | $5.00 | n/a | $5.00 (H200, not H100) |
| `4xh200` | $20.00 | n/a | $5.00 |
| `8xh200` | $40.00 | n/a | $5.00 |
| `l40sx1` | $1.80 | n/a | n/a |
| `a10g-large` | $1.50 | n/a | n/a |
| `t4-small` | $0.40 | n/a | n/a |

**No H100 SKU is published** as of this writing; HF jumps from A100 to H200. Treat HF's "$5/hr H200" as the H100-equivalent line item.

**Max concurrent jobs.** HF documents "Jobs are available to any user or organization with a positive credit balance" but doesn't publish a per-account concurrency cap. The Python SDK pattern in their docs:

```python
# Verified: direct from huggingface.co/docs/huggingface_hub/en/guides/jobs
jobs = [run_job(image=image, command=command) for command in commands]
for job in jobs:
    while inspect_job(job_id=job.id).status.stage not in ("COMPLETED", "ERROR"):
        time.sleep(10)
```

…clearly assumes a "spawn N, poll N" model. Empirically, Pro accounts can run several jobs in parallel; Enterprise plans are higher.

**Verified API for spinning up N parallel jobs:**

```python
# composer_replication/diloco/serverless/_hf_jobs_adapter.py
from huggingface_hub import run_job, run_uv_job, inspect_job, fetch_job_logs, Volume

def spawn_diloco_replica(replica_id: int, num_replicas: int, rendezvous_repo: str):
    return run_job(
        image="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
        command=["python", "-m", "composer_replication.diloco.serverless.replica_entrypoint",
                 "--replica-id", str(replica_id),
                 "--num-replicas", str(num_replicas),
                 "--rendezvous-uri", "/rendezvous"],
        flavor="a100-large",
        timeout="12h",
        env={"HF_HUB_ENABLE_HF_TRANSFER": "1"},
        secrets={"HF_TOKEN": "<token>"},
        volumes=[Volume(type="bucket", source=rendezvous_repo, mount_path="/rendezvous")],
    )

def spawn_n(num_replicas: int, rendezvous_repo: str = "myorg/diloco-rendezvous-2026-05-26"):
    jobs = [spawn_diloco_replica(i, num_replicas, rendezvous_repo) for i in range(num_replicas)]
    return jobs  # list[JobInfo]
```

The `Volume(type="bucket", ...)` is the secret weapon. Each replica writes its pseudo-gradient to a unique key under `/rendezvous/round-{t}/replica-{i}.pt`, then waits on a barrier file (polling `os.path.exists` with sleeps in between). The leader rank averages and writes `/rendezvous/round-{t}/avg.pt`. Standard object-store DiLoCo pattern.
|
| 197 |
+
|
| 198 |
+
**Right abstraction.** Same as Modal: one `run_job` = one DiLoCo replica. Fan-out via list comprehension. No special multi-node primitive — and we don't need one for Decoupled DiLoCo.
|
| 199 |
+
|
| 200 |
+
### 2.3 AWS SageMaker Training Jobs
|
| 201 |
+
|
| 202 |
+
**Inter-job networking.** SageMaker has *intra-job* multi-node networking (`InstanceCount > 1` provisions a single EFA/InfiniBand-connected cluster, suitable for SMDDP `AllReduce` with `pytorchddp` or `torch_distributed` launchers — see [docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-framework-estimator.html](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-framework-estimator.html)). It does **not** have *inter-job* networking — two separate `CreateTrainingJob` calls produce two isolated VPCs (unless you wire a shared customer VPC, which is non-trivial and Decoupled DiLoCo doesn't benefit from anyway).
|
| 203 |
+
|
| 204 |
+
**Workaround.** S3. Each SageMaker training job has read+write access to S3 by default (via the IAM role passed to `CreateTrainingJob`). Pseudo-gradient exchange via `s3://bucket/diloco-run/round-{t}/replica-{i}.pt` is straightforward.
|
| 205 |
+
|
| 206 |
+
**Cold start.** SageMaker docs and the cost-optimization blog post acknowledge five phases: Starting, Downloading, Training, Uploading, Completed. The Starting+Downloading phases are the cold start and **typically take 2–5 minutes**: image pull from ECR, EBS volume attach, `boto3` IAM role fetch, container init. **Warm pools** ([docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html](https://docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html)) cut subsequent matching jobs to ~10 s by retaining the cluster up to `KeepAlivePeriodInSeconds` (max 3600 s = 60 min) — *but matching requires identical RoleArn/InstanceType/InstanceCount/VpcConfig*, so warm pools work for "rerun the same DiLoCo replica config" but not for heterogeneous fleets.
|
| 207 |
+
|
| 208 |
+
**$/GPU·hr (from [aws.amazon.com/sagemaker/ai/pricing/](https://aws.amazon.com/sagemaker/ai/pricing/), training tab, US East regions; per-second billing).** SageMaker training instances carry a ~20–25% premium over raw EC2 because the service includes managed orchestration. Pricing varies by region; representative US East values:
|
| 209 |
+
|
| 210 |
+
| Instance | GPUs | $/hr (training) | $/GPU·hr |
|---|---|---|---|
| ml.p4d.24xlarge | 8× A100-40GB | ≈ $32.77 | ≈ **$4.10/A100·hr** |
| ml.p4de.24xlarge | 8× A100-80GB | ≈ $40.97 | ≈ $5.12/A100·hr |
| ml.p5.48xlarge | 8× H100-80GB | ≈ $98.32 | ≈ **$12.29/H100·hr** |
| ml.g5.24xlarge | 4× A10G-24GB | ≈ $10.18 (per HyperPod example) | ≈ $2.55/A10G·hr |
|
| 216 |
+
|
| 217 |
+
(Hourly rates above are *training* rates inferred from SageMaker's published training-tab price calculator and the HyperPod ml.g5.24xlarge $10.18/hr example; consult the live pricing page in [aws.amazon.com/sagemaker/ai/pricing/](https://aws.amazon.com/sagemaker/ai/pricing/) for region-specific quotes. **Managed Spot Training** ([docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html)) cuts up to 80–90% — and DiLoCo tolerates spot well because outer round t can simply skip preempted replicas.)
|
| 218 |
+
|
| 219 |
+
**Per-A100 / per-H100 rates are the highest of any executor in this audit.** SageMaker is a poor choice for cost-sensitive Decoupled DiLoCo unless you already have committed savings plans or run on Spot.
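The per-GPU figures and the Spot discount reduce to simple arithmetic. The sketch below is illustrative only: the helper names (`per_gpu_rate`, `fleet_cost`) are hypothetical, the hourly rate is the approximate ml.p4d.24xlarge value quoted in the table, and the 80% discount is the lower end of the "up to 80–90%" range rather than a quoted price.

```python
def per_gpu_rate(instance_rate_per_hr: float, gpus_per_instance: int) -> float:
    """Hourly instance price divided evenly across its GPUs."""
    if gpus_per_instance <= 0:
        raise ValueError("gpus_per_instance must be positive")
    return instance_rate_per_hr / gpus_per_instance


def fleet_cost(instance_rate_per_hr: float, num_replicas: int, hours: float,
               spot_discount: float = 0.0) -> float:
    """Total cost of num_replicas instances running for `hours`.

    spot_discount is a fraction in [0, 1), e.g. 0.8 for an assumed 80% saving.
    """
    if not 0.0 <= spot_discount < 1.0:
        raise ValueError("spot_discount must be in [0, 1)")
    return instance_rate_per_hr * num_replicas * hours * (1.0 - spot_discount)


P4D_RATE = 32.77  # approximate ml.p4d.24xlarge $/hr from the table above

a100_rate = per_gpu_rate(P4D_RATE, gpus_per_instance=8)
on_demand = fleet_cost(P4D_RATE, num_replicas=4, hours=10)
on_spot = fleet_cost(P4D_RATE, num_replicas=4, hours=10, spot_discount=0.8)
print(f"{a100_rate:.2f} $/A100-hr, on-demand ${on_demand:.2f}, spot ${on_spot:.2f}")
```

Because each replica is a whole instance, the fleet cost scales with the instance rate rather than the per-GPU rate, which is why the 8-GPU SKUs are expensive for a one-replica-per-job layout.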
|
| 220 |
+
|
| 221 |
+
**Max concurrent jobs.** AWS Service Quotas: per-account default is typically 4 (for ml.p4d.24xlarge) and 0 (for ml.p5.48xlarge — must request access). Both are raisable. There's a soft cap of 1000 active training jobs per account.
|
| 222 |
+
|
| 223 |
+
**Verified API for spinning up N parallel jobs** (using boto3, since the `sagemaker` Python SDK abstracts away the parallel-launch case):
|
| 224 |
+
|
| 225 |
+
```python
# composer_replication/diloco/serverless/_sagemaker_adapter.py
import time

import boto3

sm = boto3.client("sagemaker", region_name="us-east-1")

def spawn_diloco_replica(replica_id: int, num_replicas: int, s3_rendezvous: str):
    return sm.create_training_job(
        TrainingJobName=f"diloco-replica-{replica_id}-{int(time.time())}",
        AlgorithmSpecification={
            "TrainingImage": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.4.0-gpu-py311-cu124-ubuntu22.04-sagemaker",
            "TrainingInputMode": "File",
            "ContainerEntrypoint": ["python", "-m", "composer_replication.diloco.serverless.replica_entrypoint"],
            "ContainerArguments": ["--replica-id", str(replica_id),
                                   "--num-replicas", str(num_replicas),
                                   "--rendezvous-uri", s3_rendezvous],
        },
        ResourceConfig={
            "InstanceCount": 1,  # one instance per replica (ml.p4d.24xlarge = 8x A100-40GB)
            "InstanceType": "ml.p4d.24xlarge",
            "VolumeSizeInGB": 200,
            # "KeepAlivePeriodInSeconds": 1800,  # warm pool; SageMaker documents warm pools as
            #                                     # incompatible with managed spot, so enable one or the other
        },
        OutputDataConfig={"S3OutputPath": f"{s3_rendezvous}/output/replica-{replica_id}/"},
        # MaxWaitTimeInSeconds is required with managed spot and must be >= MaxRuntimeInSeconds.
        StoppingCondition={"MaxRuntimeInSeconds": 24*3600, "MaxWaitTimeInSeconds": 24*3600},
        RoleArn="arn:aws:iam::ACCOUNT:role/SageMakerExecutionRole",
        EnableManagedSpotTraining=True,  # 80%+ savings, DiLoCo-tolerant
    )

def spawn_n(num_replicas: int):
    s3_rendezvous = "s3://my-diloco-bucket/run-2026-05-26"
    return [spawn_diloco_replica(i, num_replicas, s3_rendezvous) for i in range(num_replicas)]
```
|
| 258 |
+
|
| 259 |
+
(The `CreateTrainingJob` API spec is documented in full at [docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html).)
|
| 260 |
+
|
| 261 |
+
**Right abstraction.** Same shape: 1 training job = 1 DiLoCo replica. SageMaker's *intra-job* multi-node features (SMDDP, EFA, `instance_count=8`) are wasted if our framing is "N independent replicas"; they only help if a single replica is itself FSDP-sharded across instances, which we explicitly don't want for v0.x.
|
| 262 |
+
|
| 263 |
+
### 2.4 GCP Vertex AI Custom Jobs
|
| 264 |
+
|
| 265 |
+
**Inter-job networking.** Same story as SageMaker: a single `CustomJob` can have multiple `workerPoolSpecs` (chief, workers, parameter servers, evaluator) on a private VPC; *separate* CustomJobs are isolated. Workaround: GCS bucket. Vertex's [configure-compute](https://cloud.google.com/vertex-ai/docs/training/configure-compute) doc covers single-node and multi-replica configurations for one job.
|
| 266 |
+
|
| 267 |
+
**Cold start.** Typical 2–6 min for cold image pull + VM provision. Vertex caches images in Artifact Registry; subsequent jobs in the same region with the same custom container start faster (~30–60 s).
|
| 268 |
+
|
| 269 |
+
**$/GPU·hr.** Vertex AI training prices = (Compute Engine VM rate) × (1 + Vertex training premium), with the premium ≈ 30–50%. From the Vertex Training SKU groups page ([cloud.google.com/skus/sku-groups/vertex-training](https://cloud.google.com/skus/sku-groups/vertex-training)) the SKUs include "Training - NVIDIA A100 80GB in Virginia" etc.; published list rate equivalents are roughly:
|
| 270 |
+
|
| 271 |
+
| Machine type | GPUs | $/hr (Vertex training, on-demand, us-central1) |
|---|---|---|
| `a2-highgpu-1g` | 1× A100-40GB | ≈ **$3.67/hr** |
| `a2-ultragpu-1g` | 1× A100-80GB | ≈ $5.07/hr |
| `a2-highgpu-8g` | 8× A100-40GB | ≈ $29.39/hr |
| `a3-highgpu-8g` | 8× H100-80GB | ≈ **$88.49/hr** ⇒ $11.06/H100·hr |
| `a3-megagpu-8g` | 8× H100-80GB (with NVSwitch) | ≈ $108/hr |
|
| 278 |
+
|
| 279 |
+
(Vertex AI pricing is the Compute Engine GPU rate plus a Vertex training premium that varies by region. The figures above are approximate list prices from public sources; confirm in the [Vertex AI pricing calculator](https://cloud.google.com/vertex-ai/pricing) before quoting.)
|
| 280 |
+
|
| 281 |
+
**Max concurrent jobs.** Per-region GPU quota (`NVIDIA_A100_GPUS`, `NVIDIA_H100_GPUS`, etc.) — typical default is 8 A100s per region, raise via Cloud Console quota request.
|
| 282 |
+
|
| 283 |
+
**Verified API for spinning up N parallel jobs** (using `google-cloud-aiplatform`):
|
| 284 |
+
|
| 285 |
+
```python
# composer_replication/diloco/serverless/_vertex_ai_adapter.py
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1",
                staging_bucket="gs://my-diloco-bucket")

def spawn_diloco_replica(replica_id: int, num_replicas: int, gcs_rendezvous: str):
    job = aiplatform.CustomJob.from_local_script(
        display_name=f"diloco-replica-{replica_id}",
        script_path="composer_replication/diloco/serverless/replica_entrypoint.py",
        container_uri="us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.2-4.py311:latest",
        args=["--replica-id", str(replica_id),
              "--num-replicas", str(num_replicas),
              "--rendezvous-uri", gcs_rendezvous],
        machine_type="a2-highgpu-1g",  # 1× A100-40GB per replica
        accelerator_type="NVIDIA_TESLA_A100",
        accelerator_count=1,
        replica_count=1,  # one replica, single-host
    )
    job.submit()  # async; returns immediately
    return job

def spawn_n(num_replicas: int):
    gcs = "gs://my-diloco-bucket/run-2026-05-26"
    return [spawn_diloco_replica(i, num_replicas, gcs) for i in range(num_replicas)]
```
|
| 312 |
+
|
| 313 |
+
**Right abstraction.** Identical to SageMaker / HF / Modal: one `CustomJob.submit()` = one DiLoCo replica.
|
| 314 |
+
|
| 315 |
+
### 2.5 Azure ML Command Jobs
|
| 316 |
+
|
| 317 |
+
**Inter-job networking.** Single `command` job with `resources.instance_count=N` provisions N coordinated nodes (InfiniBand on `ND*`-series); separate jobs are isolated. Workaround: Azure Blob Storage or Azure ML Datastore.
|
| 318 |
+
|
| 319 |
+
**Cold start.** 3–8 min from job submission to first-byte-of-stdout for a curated environment; longer for custom images. Curated environments (e.g., `AzureML-acpt-pytorch-2.8-cuda12.6@latest`) are pre-cached on the cluster's image cache.
|
| 320 |
+
|
| 321 |
+
**$/GPU·hr (from [azure.microsoft.com/en-us/pricing/details/machine-learning/](https://azure.microsoft.com/en-us/pricing/details/machine-learning/), GPU section, US West 2 PAYG list).**
|
| 322 |
+
|
| 323 |
+
| VM size | GPUs | Approx $/hr |
|---|---|---|
| Standard_NC24ads_A100_v4 | 1× A100-80GB | ≈ **$3.67/hr** |
| Standard_NC48ads_A100_v4 | 2× A100-80GB | ≈ $7.35/hr |
| Standard_ND96asr_A100_v4 | 8× A100-40GB (InfiniBand) | ≈ $27.20/hr |
| Standard_NC40ads_H100_v5 | 1× H100 NVL 94GB | ≈ $7/hr (regional) |
| Standard_ND96isr_H100_v5 | 8× H100-80GB (InfiniBand) | ≈ **$98/hr** ⇒ $12.25/H100·hr |
|
| 330 |
+
|
| 331 |
+
(Azure publishes $0/core ML "service surcharge" for these — you pay only the underlying VM rate. So the relevant hourly rate is the standard PAYG VM rate from Azure's pricing page, not a separate Azure ML markup. **Low-Priority** VMs cut up to 80% — DiLoCo-tolerant like SageMaker Spot.)
|
| 332 |
+
|
| 333 |
+
**Max concurrent jobs.** Per-subscription per-region GPU vCPU quota; typical default 0–24 cores for `ND*`-series, raise via Azure portal.
|
| 334 |
+
|
| 335 |
+
**Verified API for spinning up N parallel jobs** (using `azure-ai-ml` v2 SDK; pattern from [learn.microsoft.com/en-us/azure/machine-learning/how-to-train-pytorch](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-train-pytorch)):
|
| 336 |
+
|
| 337 |
+
```python
# composer_replication/diloco/serverless/_azure_ml_adapter.py
from azure.ai.ml import MLClient, command
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), subscription_id="...",
                     resource_group_name="...", workspace_name="...")

def spawn_diloco_replica(replica_id: int, num_replicas: int, blob_uri: str):
    job = command(
        code="./composer_replication",
        command=("python -m composer_replication.diloco.serverless.replica_entrypoint "
                 f"--replica-id {replica_id} --num-replicas {num_replicas} "
                 f"--rendezvous-uri {blob_uri}"),
        environment="AzureML-acpt-pytorch-2.8-cuda12.6@latest",
        compute="gpu-cluster",  # an AmlCompute pre-created with min_instances=0, max_instances=8
        resources={"instance_count": 1},
        display_name=f"diloco-replica-{replica_id}",
    )
    return ml_client.jobs.create_or_update(job)

def spawn_n(num_replicas: int):
    blob = "azureml://datastores/workspaceblobstore/paths/diloco-run/"
    return [spawn_diloco_replica(i, num_replicas, blob) for i in range(num_replicas)]
```
|
| 362 |
+
|
| 363 |
+
**Right abstraction.** Same one-job-per-replica pattern.
|
| 364 |
+
|
| 365 |
+
### 2.6 Kubernetes + Volcano / KubeRay
|
| 366 |
+
|
| 367 |
+
**Inter-job networking.** Native — pods on the same cluster see each other on the cluster network. Volcano provides **gang scheduling** (all-or-nothing pod admission, essential for "all N DiLoCo replicas start together" semantics) and **network-topology-aware scheduling** ([volcano.sh/en/docs/network_topology_aware_scheduling/](https://volcano.sh/en/docs/network_topology_aware_scheduling/)). KubeRay's `RayJob` resource integrates with Volcano (PR [ray-project/kuberay#3972](https://github.com/ray-project/kuberay/pull/3972), merged 2025-10-09) — `RayJob` + `volcano.sh/queue-name` label gives you gang-scheduled Ray clusters per job.
|
| 368 |
+
|
| 369 |
+
For Decoupled DiLoCo: **N RayJobs, each running one replica**, gang-scheduled via Volcano, sharing pseudo-grad through a `PersistentVolume` or in-cluster S3-compatible object store (MinIO).
|
| 370 |
+
|
| 371 |
+
**Cold start.** Pod schedule time depends on cluster state: seconds (pre-pulled image, free GPU node) to minutes (image pull + GPU node autoscale). Predictable on a steady-state cluster.
|
| 372 |
+
|
| 373 |
+
**$/GPU·hr.** **Whatever the underlying K8s cluster pays.** This is the *cheapest* tier in this audit if the user already runs a GPU K8s cluster (e.g., RunPod K8s, Lambda Cloud, OCI K8s, on-prem). Examples:
|
| 374 |
+
|
| 375 |
+
- RunPod community cloud K8s: ~$1.19/hr A100-80GB, ~$1.99/hr H100.
|
| 376 |
+
- Lambda K8s: ~$1.29/hr A100-40GB, ~$2.49/hr H100-80GB.
|
| 377 |
+
- On-prem owned hardware: amortized $0.50–$1.00 per A100/H100 hour.
|
| 378 |
+
|
| 379 |
+
**Max concurrent jobs.** Cluster capacity. Volcano's queue-based admission control + Kubernetes-native quotas govern this.
|
| 380 |
+
|
| 381 |
+
**Verified API for spinning up N parallel jobs** (Volcano `Job` + KubeRay pattern from the docs):
|
| 382 |
+
|
| 383 |
+
```yaml
# k8s manifest, one per DiLoCo replica
apiVersion: batch.volcano.sh/v1alpha1
kind: Job
metadata: {name: diloco-replica-0}
spec:
  minAvailable: 1
  schedulerName: volcano
  queue: diloco-queue
  tasks:
    - replicas: 1
      name: replica
      template:
        spec:
          containers:
            - name: trainer
              image: myorg/composer-replication:latest
              command: ["python", "-m", "composer_replication.diloco.serverless.replica_entrypoint",
                        "--replica-id", "0", "--num-replicas", "4",
                        "--rendezvous-uri", "s3://minio.cluster.local/diloco/"]
              resources:
                limits: {nvidia.com/gpu: 1}
          restartPolicy: OnFailure
```
|
| 407 |
+
|
| 408 |
+
…and the framework's `K8sExecutor` adapter does `kubectl apply -f` (or uses the Python K8s client) for each of N rendered manifests.
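Rendering the N manifests can be sketched in plain Python. This is a minimal illustration, not the framework's `K8sExecutor`: the function names `render_manifest` and `render_all` are hypothetical, and the image, queue and URI defaults are simply copied from the example manifest above.

```python
from typing import Any


def render_manifest(replica_id: int, num_replicas: int, rendezvous_uri: str,
                    image: str = "myorg/composer-replication:latest",
                    queue: str = "diloco-queue") -> dict[str, Any]:
    """Build one Volcano Job manifest (as a dict) for a single DiLoCo replica."""
    command = [
        "python", "-m", "composer_replication.diloco.serverless.replica_entrypoint",
        "--replica-id", str(replica_id),
        "--num-replicas", str(num_replicas),
        "--rendezvous-uri", rendezvous_uri,
    ]
    container = {
        "name": "trainer",
        "image": image,
        "command": command,
        "resources": {"limits": {"nvidia.com/gpu": 1}},
    }
    return {
        "apiVersion": "batch.volcano.sh/v1alpha1",
        "kind": "Job",
        "metadata": {"name": f"diloco-replica-{replica_id}"},
        "spec": {
            "minAvailable": 1,
            "schedulerName": "volcano",
            "queue": queue,
            "tasks": [{
                "replicas": 1,
                "name": "replica",
                "template": {"spec": {"containers": [container],
                                      "restartPolicy": "OnFailure"}},
            }],
        },
    }


def render_all(num_replicas: int, rendezvous_uri: str) -> list[dict[str, Any]]:
    """One manifest per replica; only the name and --replica-id differ."""
    return [render_manifest(i, num_replicas, rendezvous_uri) for i in range(num_replicas)]


manifests = render_all(4, "s3://minio.cluster.local/diloco/")
print([m["metadata"]["name"] for m in manifests])
```

Each dict can then be serialized to YAML and passed to `kubectl apply -f`, or submitted through a Kubernetes client; which route the adapter takes is left open here.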
|
| 409 |
+
|
| 410 |
+
**Right abstraction.** Either one Volcano `Job` (`batch.volcano.sh/v1alpha1`) per replica (simple, no Ray) or one `RayJob` per replica (overkill for DiLoCo, but useful if you want Ray Tune integration). One pod = one DiLoCo replica.
|
| 411 |
+
|
| 412 |
+
### 2.7 RunPod / Lambda / Vast.ai (honourable mentions)
|
| 413 |
+
|
| 414 |
+
Not in the original candidate list, but worth a short note each because they're the price leaders for on-demand GPUs:
|
| 415 |
+
|
| 416 |
+
- **RunPod Serverless / Pods.** Cheap on-demand A100/H100 (~$1.19–$2.17/hr A100-80GB; ~$1.99–$4.18/hr H100). REST API `POST /v2/{endpoint}/run` for serverless; SDK `runpod` for pods. No native multi-job network — same S3 workaround. **Strong third adapter candidate** for a cost-optimised deployment.
|
| 417 |
+
- **Lambda Cloud (Lambda Labs).** Bare metal hourly rentals, not a true serverless API. Programmatic launch via `lambdalabs` API. Outside the "serverless" framing.
|
| 418 |
+
- **Vast.ai.** Bidding-style spot market. API-driven launches. Cheapest per A100·hr in the market, but variable availability.
|
| 419 |
+
|
| 420 |
+
We do **not** include these as v0 adapters but document them as "next-up after Modal + HF" if the user wants further price compression.
|
| 421 |
+
|
| 422 |
+
---
|
| 423 |
+
|
| 424 |
+
## 3. The right abstraction: `composer_replication.diloco.serverless`
|
| 425 |
+
|
| 426 |
+
### 3.1 The core interface
|
| 427 |
+
|
| 428 |
+
```python
# composer_replication/diloco/serverless/_protocol.py
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Iterator, Protocol

@dataclass(frozen=True)
class ReplicaSpec:
    """One DiLoCo replica's launch config. Mirrors `make_diloco_outer_loop()`'s
    args (see composer_replication/diloco/__init__.py) plus a rendezvous_uri
    for the object-store all-reduce backend."""
    replica_id: int
    num_replicas: int
    rendezvous_uri: str              # s3://, gs://, az://, hf://, file://
    model_id: str                    # e.g. "Qwen/Qwen2.5-7B"
    inner_optimizer: dict[str, Any]  # serializable; reconstructed in worker
    sync_every: int = 500
    outer_lr: float = 0.7
    outer_momentum: float = 0.9
    outer_rounds: int = 100
    extra_env: dict[str, str] | None = None

@dataclass(frozen=True)
class ReplicaHandle:
    replica_id: int
    backend: str                     # "modal" | "hfjobs" | "sagemaker" | ...
    job_id: str
    log_url: str | None = None

@dataclass(frozen=True)
class ReplicaResult:
    replica_id: int
    status: str                      # "completed" | "failed" | "preempted"
    final_checkpoint_uri: str | None
    metrics: dict[str, Any]

class ServerlessExecutor(Protocol):
    """Protocol any serverless backend implements to host Decoupled DiLoCo."""

    def launch_replicas(self, specs: list[ReplicaSpec]) -> list[ReplicaHandle]: ...
    def poll(self, handles: list[ReplicaHandle]) -> list[ReplicaHandle]: ...
    def stream_logs(self, handle: ReplicaHandle) -> Iterator[str]: ...
    def cancel(self, handles: list[ReplicaHandle]) -> None: ...
    def collect(self, handles: list[ReplicaHandle], *,
                timeout: float | None = None) -> list[ReplicaResult]: ...

    @property
    def backend_name(self) -> str: ...

    @property
    def supports_inter_replica_network(self) -> bool:
        """True iff backend natively connects replicas (e.g., Modal i6pn).
        False = pseudo-grad must use rendezvous_uri object store. Default rendezvous
        is *always* object-store regardless; this flag only unlocks an opt-in
        same-backend fast path (see ModalExecutor(use_clustered_rdma=True))."""
        ...
```
|
| 486 |
+
|
| 487 |
+
Concrete adapters inherit from a small `BaseExecutor(ABC)` for cross-cutting retry/log/timeout, paralleling `composer_replication.trainer.composer_trainer`. `launch_replicas()` is partial-failure tolerant: on a partial submit it returns handles for the K successful replicas, each failed replica's handle carries `job_id=""`, and a warning is logged; the caller is responsible for cleanup via `cancel()`.
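On the caller side, the partial-failure contract might be handled as in the following sketch. The `ReplicaHandle` here is a local stand-in with the same fields as the dataclass in `_protocol.py`, and `split_handles` is a hypothetical helper, not part of the framework.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ReplicaHandle:
    """Local stand-in with the same fields as the protocol's ReplicaHandle."""
    replica_id: int
    backend: str
    job_id: str
    log_url: str | None = None


def split_handles(handles: list[ReplicaHandle]) -> tuple[list[ReplicaHandle], list[ReplicaHandle]]:
    """Separate submitted replicas from failed ones (failed handles carry job_id='')."""
    launched = [h for h in handles if h.job_id]
    failed = [h for h in handles if not h.job_id]
    return launched, failed


handles = [
    ReplicaHandle(0, "hfjobs", "job-abc"),
    ReplicaHandle(1, "hfjobs", ""),        # submit failed for this replica
    ReplicaHandle(2, "hfjobs", "job-def"),
]
launched, failed = split_handles(handles)
print(len(launched), "launched;", [h.replica_id for h in failed], "failed")
```

A caller that needs all replicas would then pass `launched` to the executor's `cancel()` and retry or abort; a caller that tolerates a smaller fleet could proceed with `launched` alone.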
|
| 488 |
+
|
| 489 |
+
### 3.2 The object-store all-reduce (the secret weapon)
|
| 490 |
+
|
| 491 |
+
The whole point of "decoupled" DiLoCo is that the cross-replica primitive is just object-store I/O. We implement it at the framework layer, *not* at the executor layer, so every adapter gets it for free:
|
| 492 |
+
|
| 493 |
+
```python
# composer_replication/diloco/serverless/_rendezvous.py
import time, torch, fsspec

class ObjectStoreAllReduce:
    """Drop-in for `torchft.Manager.allreduce` over a shared object store.

    Each round t:
      (1) replica i writes {uri}/round-{t}/replica-{i}.pt
      (2) all replicas barrier on count == num_replicas
      (3) rank 0 averages, writes {uri}/round-{t}/avg.pt
      (4) others read avg.pt, copy_ into the in-place tensor
      (5) rank 0 GCs round-(t-1)

    fsspec-backed so one path covers s3://, gs://, az://, hf://, file://.
    """

    def __init__(self, replica_id, num_replicas, rendezvous_uri,
                 fsspec_kwargs=None, poll_s=2.0, timeout_s=600.0):
        self.replica_id, self.num_replicas = replica_id, num_replicas
        self.uri = rendezvous_uri.rstrip("/")
        self.fs, _ = fsspec.url_to_fs(self.uri, **(fsspec_kwargs or {}))
        self.poll, self.timeout, self._round = poll_s, timeout_s, 0

    def allreduce(self, tensor):
        t = self._round
        round_dir = f"{self.uri}/round-{t}"
        my = f"{round_dir}/replica-{self.replica_id}.pt"
        avg = f"{round_dir}/avg.pt"

        # Local filesystems need the round directory to exist before open().
        self.fs.makedirs(round_dir, exist_ok=True)
        with self.fs.open(my, "wb") as f:
            torch.save(tensor.cpu(), f)

        # Barrier: wait until every replica's file is listed.
        # Note: on file:// a reader can observe a partially written file.
        deadline = time.time() + self.timeout
        while time.time() < deadline:
            self.fs.invalidate_cache()  # some backends (e.g. s3fs) cache listings
            existing = [p for p in self.fs.ls(f"{round_dir}/", detail=False)
                        if p.endswith(".pt") and "/replica-" in p]
            if len(existing) >= self.num_replicas: break
            time.sleep(self.poll)
        else:
            raise TimeoutError(f"barrier timeout at round {t}")

        if self.replica_id == 0:
            tensors = []
            for i in range(self.num_replicas):
                with self.fs.open(f"{round_dir}/replica-{i}.pt", "rb") as f:
                    tensors.append(torch.load(f, map_location="cpu"))
            # Close the handle explicitly: buffered object-store writes only commit on close.
            with self.fs.open(avg, "wb") as f:
                torch.save(torch.stack(tensors).mean(dim=0), f)

        deadline = time.time() + self.timeout
        while time.time() < deadline:
            self.fs.invalidate_cache()
            if self.fs.exists(avg):
                with self.fs.open(avg, "rb") as f:
                    tensor.copy_(torch.load(f, map_location=tensor.device))
                break
            time.sleep(self.poll)
        else:
            raise TimeoutError(f"avg.pt timeout at round {t}")

        if self.replica_id == 0 and t > 0:
            try: self.fs.rm(f"{self.uri}/round-{t-1}/", recursive=True)
            except Exception: pass

        self._round += 1
        return _DummyWork()

    def should_commit(self): return True
    def start_quorum(self, *_, **__): pass
    @property
    def current_step(self): return self._round

class _DummyWork:
    def wait(self): pass
    def get_future(self): pass
```
|
| 564 |
+
|
| 565 |
+
The `ObjectStoreAllReduce` mocks the torchft `Manager` interface — exactly what `make_diloco_outer_loop` already takes (see `composer_replication/diloco/__init__.py` lines 64–125). **No changes to the existing DiLoCo wrapper needed.**
|
| 566 |
+
|
| 567 |
+
### 3.3 Replica entrypoint
|
| 568 |
+
|
| 569 |
+
This is the script every adapter runs in its container:
|
| 570 |
+
|
| 571 |
+
```python
# composer_replication/diloco/serverless/replica_entrypoint.py
"""Run one Decoupled DiLoCo replica. Designed to be invoked as

    python -m composer_replication.diloco.serverless.replica_entrypoint \
        --replica-id N --num-replicas K --rendezvous-uri s3://... \
        --model-id Qwen/Qwen2.5-7B --sync-every 500 --outer-rounds 100
"""
import argparse, os, torch
from composer_replication.diloco import make_diloco_outer_loop
from composer_replication.diloco.serverless._rendezvous import ObjectStoreAllReduce


def main() -> None:
    p = argparse.ArgumentParser()
    p.add_argument("--replica-id", type=int, required=True)
    p.add_argument("--num-replicas", type=int, required=True)
    p.add_argument("--rendezvous-uri", required=True)
    p.add_argument("--model-id", required=True)
    p.add_argument("--sync-every", type=int, default=500)
    p.add_argument("--outer-rounds", type=int, default=100)
    p.add_argument("--outer-lr", type=float, default=0.7)
    args = p.parse_args()

    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=torch.bfloat16).cuda()
    inner_opt = torch.optim.AdamW(model.parameters(), lr=4e-4)

    manager = ObjectStoreAllReduce(replica_id=args.replica_id,
                                   num_replicas=args.num_replicas,
                                   rendezvous_uri=args.rendezvous_uri)
    outer = make_diloco_outer_loop(
        manager=manager, model_fragments=[model], inner_optimizer=inner_opt,
        outer_lr=args.outer_lr, outer_momentum=0.9, nesterov=True,
        sync_every=args.sync_every,
    )

    with outer:
        for outer_round in range(args.outer_rounds):
            for inner_step in range(args.sync_every):
                # caller plugs in their data + loss; for v0 we use a sketch.
                inner_opt.zero_grad(); ...; inner_opt.step()
            # outer-loop sync fires automatically at sync_every step boundary.

    # Push final checkpoint to rendezvous_uri/final/replica-N.pt
    ...

if __name__ == "__main__":
    main()
```
|
| 621 |
+
|
| 622 |
+
### 3.4 Package layout
|
| 623 |
+
|
| 624 |
+
```
composer_replication/
└── diloco/
    ├── __init__.py                  # existing: make_diloco_outer_loop, torchft import
    └── serverless/
        ├── __init__.py              # re-exports
        ├── _protocol.py             # ServerlessExecutor Protocol, ReplicaSpec, ReplicaHandle, ReplicaResult
        ├── _base.py                 # BaseExecutor(ABC): common retry/log/timeout logic
        ├── _rendezvous.py           # ObjectStoreAllReduce (the cross-cutting allreduce)
        ├── replica_entrypoint.py    # the script every adapter runs in-container
        ├── modal/
        │   ├── __init__.py          # ModalExecutor
        │   └── adapter.py
        ├── hfjobs/
        │   ├── __init__.py          # HFJobsExecutor
        │   └── adapter.py
        └── runpod/                  # optional v0.1+
            ├── __init__.py
            └── adapter.py
```
|
| 644 |
+
|
| 645 |
+
**v0 ships:** `Modal` + `HFJobs`. Both inherit from `BaseExecutor`, both delegate cross-replica state to `ObjectStoreAllReduce`. Symmetric implementation surface ≈ 250 lines per adapter.
|
| 646 |
+
|
| 647 |
+
**v0.1+ candidates** (add when needed): SageMaker, Vertex AI, Azure ML, RunPod, K8s/Volcano. The `Protocol` is stable; adding adapters is incremental.
|
| 648 |
+
|
| 649 |
+
### 3.5 What the user writes
|
| 650 |
+
|
| 651 |
+
```python
from composer_replication.diloco.serverless import (
    ModalExecutor, HFJobsExecutor, ReplicaSpec
)

specs = [
    ReplicaSpec(replica_id=i, num_replicas=4,
                rendezvous_uri="s3://my-diloco-runs/2026-05-26/",
                model_id="Qwen/Qwen2.5-7B",
                inner_optimizer={"name": "AdamW", "lr": 4e-4},
                sync_every=500, outer_rounds=100)
    for i in range(4)
]

# Option A: all four replicas on Modal A100s
executor = ModalExecutor(gpu="A100-40GB", region=None, preemptible=True)
handles = executor.launch_replicas(specs)
results = executor.collect(handles)

# Option B: heterogeneous fleet, 2 on Modal and 2 on HF Jobs
modal_ex = ModalExecutor(gpu="A100-40GB")
hf_ex = HFJobsExecutor(flavor="a100-large")
modal_handles = modal_ex.launch_replicas(specs[:2])
hf_handles = hf_ex.launch_replicas(specs[2:])
# both groups read+write the SAME s3://... rendezvous URI, so they DiLoCo together.
results = modal_ex.collect(modal_handles) + hf_ex.collect(hf_handles)
```
|
| 678 |
+
|
| 679 |
+
The "heterogeneous fleet" pattern is the **point** of Decoupled DiLoCo as articulated in the user brief. Modal + HF together is a meaningful test that tells us both adapters work and the rendezvous protocol is backend-agnostic.
|
| 680 |
+
|
| 681 |
+
---
|
| 682 |
+
|
| 683 |
+
## 4. Cross-cutting design decisions
|
| 684 |
+
|
| 685 |
+
### 4.1 Why object-store rendezvous is the default (even on Modal)
|
| 686 |
+
|
| 687 |
+
Even though Modal supports `@modal.experimental.clustered` with RDMA, **the framework default is object-store-based pseudo-gradient exchange.** Reasons:
|
| 688 |
+
|
| 689 |
+
1. **Backend portability.** Same code runs on Modal, HF, SageMaker, Vertex, Azure, K8s. Adding a new backend means implementing the five protocol methods (`launch_replicas`, `poll`, `stream_logs`, `cancel`, `collect`) and two properties (`backend_name`, `supports_inter_replica_network`), with *zero* changes to the rendezvous layer.
|
| 690 |
+
2. **Cost asymmetry.** RDMA-class networking on Modal requires `@clustered(rdma=True)` which gates on 8 GPUs/node and tighter scheduling — *more* expensive than 4 separate `@function` invocations of 1 GPU each.
|
| 691 |
+
3. **RDMA is overkill for DiLoCo's communication volume.** 2 GB every 10 minutes averages ~3.3 MB/s (≈ 27 Mbit/s). S3 GET/PUT at 10 MB/s moves it in ~3 min, well under the 10 min outer-round budget.
|
| 692 |
+
4. **Failure decoupling.** A clustered-RDMA failure aborts the whole job (gang-scheduled). Object-store rendezvous tolerates a missing replica (skip its tensor in the average) — better matches DiLoCo's natural fault tolerance.
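
The bandwidth argument in reason 3 can be checked with a few lines of arithmetic. This is a minimal sketch in decimal units (1 GB = 1000 MB); the function names are illustrative and not part of the framework.

```python
def avg_bandwidth_mb_s(payload_gb: float, round_minutes: float) -> float:
    """Average rate (MB/s) needed to move payload_gb once per outer round."""
    return payload_gb * 1000 / (round_minutes * 60)


def transfer_minutes(payload_gb: float, throughput_mb_s: float) -> float:
    """Minutes needed to move payload_gb at a sustained throughput."""
    return payload_gb * 1000 / throughput_mb_s / 60


need = avg_bandwidth_mb_s(payload_gb=2.0, round_minutes=10.0)
print(f"required average: {need:.2f} MB/s ({need * 8:.1f} Mbit/s)")
print(f"transfer at 10 MB/s: {transfer_minutes(2.0, 10.0):.1f} min")
```

Either way of reading the numbers, the required rate is orders of magnitude below what RDMA-class interconnects provide.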
|
| 693 |
+
|
| 694 |
+
The opt-in escape hatch: `ModalExecutor(use_clustered_rdma=True)` dispatches to `@modal.experimental.clustered(rdma=True)` and skips object-store. This is for the user who wants Modal-only, max-throughput, single-region runs. It's *not* the default and *not* what we test against.
|
| 695 |
+
|
| 696 |
+
### 4.2 Rendezvous URI scheme support
|
| 697 |
+
|
| 698 |
+
`fsspec` covers all the storage backends we need:
|
| 699 |
+
|
| 700 |
+
| Scheme | Backend | Used for |
|
| 701 |
+
|---|---|---|
|
| 702 |
+
| `s3://` | `s3fs` | SageMaker default; cheapest for AWS-centric runs |
|
| 703 |
+
| `gs://` | `gcsfs` | Vertex AI default |
|
| 704 |
+
| `az://` | `adlfs` | Azure ML default |
|
| 705 |
+
| `hf://` | `huggingface_hub.HfFileSystem` | HF Jobs preferred (Volume mount makes it look like local fs already) |
|
| 706 |
+
| `file://` | builtin | local single-host tests; CI |
|
| 707 |
+
|
| 708 |
+
The framework picks the *right* default per-executor (Modal → `s3://`, HF → `hf://`, SageMaker → `s3://`, etc.) but always allows override.
|
| 709 |
+
|
| 710 |
+
### 4.3 Failure model
|
| 711 |
+
|
| 712 |
+
**Replica failure mid-round.** The barrier in `ObjectStoreAllReduce` has a configurable timeout (default 600 s). If a replica doesn't write its file by then, rank-0 (the averager) has two options governed by `replica_failure_policy`:
|
| 713 |
+
|
| 714 |
+
- `"strict"` (default): TimeoutError → all replicas abort. Resume from last committed checkpoint.
|
| 715 |
+
- `"skip"`: rank-0 averages over what's there, includes a `--num-survivors=K` annotation in `avg.pt`. Other replicas read this and continue. DiLoCo paper §4.5 reports robustness to occasional missing workers; this matches that.
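
The two policies can be illustrated with a file-existence barrier on a local directory. This is a simplified sketch, not the `ObjectStoreAllReduce` implementation: it averages one float per replica stored as JSON instead of tensors in `avg.pt`, and the file names and the `average_round` helper are assumptions made for the example.

```python
import json
import time
from pathlib import Path


def average_round(round_dir: Path, world_size: int, *, timeout_s: float = 600.0,
                  policy: str = "strict", poll_s: float = 0.05) -> dict:
    """Rank-0 side of one outer round: wait for every replica's file, then average."""
    deadline = time.monotonic() + timeout_s
    while True:
        present = sorted(round_dir.glob("rank-*.json"))
        if len(present) >= world_size or time.monotonic() >= deadline:
            break
        time.sleep(poll_s)  # file existence is the only synchronization signal

    missing = world_size - len(present)
    if missing > 0 and (policy == "strict" or not present):
        raise TimeoutError(f"only {len(present)}/{world_size} replicas reported")

    # "skip" policy (or a full round): average over whatever is there.
    values = [json.loads(p.read_text())["value"] for p in present]
    result = {"avg": sum(values) / len(values), "num_survivors": len(values)}
    (round_dir / "avg.json").write_text(json.dumps(result))
    return result
```

A production version would also need to guard against reading a partially written file, for example by having replicas write to a temporary name and rename it once complete.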
|
| 716 |
+
|
| 717 |
+
**Whole-cluster failure.** Outer rounds checkpoint to `{rendezvous_uri}/checkpoint-{t}/`; restart sets `args.restart_from=T` and skips ahead.
|
| 718 |
+
|
| 719 |
+
### 4.4 What we explicitly do NOT do
|
| 720 |
+
|
| 721 |
+
- **No cross-job NCCL.** Even on Modal, even with `clustered`, the framework uses object-store rendezvous. (Modal `clustered` is exposed only via the explicit opt-in flag.)
|
| 722 |
+
- **No DDP/FSDP across replicas.** Each replica is its own self-contained DDP/FSDP world; replicas talk to each other only via the outer-loop. This is the *core* of DiLoCo.
|
| 723 |
+
- **No "control plane" service.** No coordinator process, no scheduler container. The object store *is* the coordinator (writes are the messages, file-existence is the synchronization). This is what makes the system work across heterogeneous executors with no shared infra.
|
| 724 |
+
- **No Modal-specific or HF-specific dependencies in `composer_replication.diloco`.** Adapter dependencies (`modal`, `huggingface_hub`) are imported lazily inside the adapter modules, exactly how `torchft` is imported lazily in `composer_replication/diloco/__init__.py` today.
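
The lazy-import rule in the last bullet can be sketched as follows. The helper `lazy_import` and the class `ModalExecutorSketch` are hypothetical names for illustration; only the pattern (import the SDK on first use, not at module import time) is taken from the text.

```python
import importlib
from types import ModuleType


def lazy_import(module_name: str, install_hint: str) -> ModuleType:
    """Import an optional dependency on first use, with an actionable error message."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name!r} is required for this executor; install it with `{install_hint}`"
        ) from exc


class ModalExecutorSketch:
    """Constructing the adapter never touches the `modal` package."""

    backend_name = "modal"

    def __init__(self) -> None:
        self._sdk: ModuleType | None = None

    def sdk(self) -> ModuleType:
        # The import happens only when a method actually needs the SDK.
        if self._sdk is None:
            self._sdk = lazy_import("modal", "pip install modal")
        return self._sdk


executor = ModalExecutorSketch()  # works even if `modal` is not installed
```

With this shape, importing the core package stays free of adapter dependencies, and a missing SDK only fails at the point where that backend is actually used.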
|
| 725 |
+
|
| 726 |
+
---
|
| 727 |
+
|
| 728 |
+
## 5. Risks and mitigations
|
| 729 |
+
|
| 730 |
+
| Risk | Likelihood | Mitigation |
|
| 731 |
+
|---|---|---|
|
| 732 |
+
| Object-store latency dominates outer-round wallclock for large models | M | For 70B+, add `fsspec` parallel-upload (multipart) + bf16 quantize on-write. Most outer rounds are 7B-scale, where the 2 GB transfer assumed in §4.1 takes about 3.3 min at 10 MB/s and fits inside the 10 min outer-round budget. |
|
| 733 |
+
| Rank-0 replica crashes mid-average → orphaned barrier | L | Add a `lock-{t}.json` heartbeat with TTL; any non-zero replica that sees a stale lock can take over. v1+. |
|
| 734 |
+
| Modal + HF cost arbitrage misleading because preemption rates differ | M | Track preemption-rate per backend, surface in `ReplicaResult.metrics`. User-visible. |
|
| 735 |
+
| HF Jobs has no public per-account concurrency cap → may hit a hidden limit at N=8 | L | Add exponential-backoff retry around `run_job`; cap `max_concurrent_launches` configurable per executor. |
|
| 736 |
+
| AWS / GCP / Azure premiums make their adapters effectively price-uncompetitive | H (already true) | Be honest in docs (this doc). Recommend Modal + HF for cost-sensitive users; cloud-vendor adapters for users who *must* run there for compliance or credits. |
|
| 737 |
+
| Rendezvous bucket becomes a security choke point (model weights exposed) | M | Document that `rendezvous_uri` should be a private bucket with replica-only IAM/principals. Provide `RendezvousAccessPolicy` helper that emits boto3/gcloud/az IAM JSON. |
|
| 738 |
+
| Modal `@experimental.clustered` API churn (it's experimental) | M | Default path doesn't depend on `clustered`. Fall-back path uses regular `@function`. Document the opt-in clearly. |
|
| 739 |
+
| torchft sign-convention regression | L | Already pinned with the unit test in spike 008 (see `spikes/008-streaming-diloco/tests/test_diloco_smoke.py::test_diloco_pseudogradient_sign_convention`). The serverless layer doesn't touch this — it only swaps in a different `Manager.allreduce` impl. |
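
For readers unfamiliar with the sign convention mentioned in the last row, the sketch below follows the DiLoCo paper's definition: the pseudo-gradient is the previous global parameters minus a replica's locally trained parameters, averaged over replicas, and the outer optimizer subtracts it. It uses plain Python lists and plain SGD for the outer step as a simplification (the paper uses Nesterov momentum), and it is not torchft's code.

```python
def outer_step(global_params: list[float], replica_params: list[list[float]],
               outer_lr: float = 1.0) -> list[float]:
    """One DiLoCo outer step with plain SGD as the outer optimizer."""
    n = len(replica_params)
    # Pseudo-gradient: (old global - new local), averaged across replicas.
    pseudo_grad = [
        sum(g - replica[i] for replica in replica_params) / n
        for i, g in enumerate(global_params)
    ]
    # Descent direction: subtract the pseudo-gradient.
    return [g - outer_lr * d for g, d in zip(global_params, pseudo_grad)]


old = [1.0, 2.0]
replicas = [[0.0, 2.0], [1.0, 4.0]]
new = outer_step(old, replicas)
print(new)
```

With `outer_lr=1.0` and no momentum, the result equals the element-wise mean of the replica parameters, which is a convenient sanity check: flipping the sign would move the global parameters away from the replicas instead of toward them.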
|
| 740 |
+
|
| 741 |
+
---
|
| 742 |
+
|
| 743 |
+
## 6. Validation plan
|
| 744 |
+
|
| 745 |
+
Three smoke tests, in order of cost:
|
| 746 |
+
|
| 747 |
+
1. **Spike 009-A (free, ≤30 min):** `LocalProcessExecutor` + `ObjectStoreAllReduce` with `file://` rendezvous. Two in-process replicas DiLoCo-train a 0.5B model on MNIST-equivalent text data. Asserts the rendezvous protocol works.
|
| 748 |
+
2. **Spike 009-B (Modal, ≤$5):** `ModalExecutor` × 2 replicas, A100-40GB each, Qwen2.5-0.5B, 50 inner steps × 2 outer rounds. Asserts the Modal adapter launches, replicas find each other through S3 rendezvous, and pseudo-gradients average correctly. Cost: ~0.5 h × $2.10/h × 2 replicas = $2.10, plus setup overhead; comfortably under the cap.
|
| 749 |
+
3. **Spike 009-C (heterogeneous, ≤$10):** 1 Modal A100 + 1 HF Jobs `a100-large`. Same model, 2 outer rounds. Validates that rendezvous works across backends, which is the key claim of Decoupled DiLoCo. Cost: ~0.5 h × ($2.10/h + $2.50/h) = ~$2.30, plus per-job startup.
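
The cost estimates for spikes 009-B and 009-C can be reproduced with the sketch below. It assumes the $2.10 and $2.50 figures are per-GPU hourly rates and that each replica runs for about half an hour; startup and storage costs are not modeled.

```python
MODAL_A100_40GB_USD_PER_H = 2.10  # assumed hourly rate, taken from the estimate above
HF_A100_LARGE_USD_PER_H = 2.50    # assumed hourly rate, taken from the estimate above


def spike_cost_usd(hours: float, hourly_rates: list[float]) -> float:
    """GPU cost of running one replica per listed rate for the same duration."""
    return hours * sum(hourly_rates)


spike_b = spike_cost_usd(0.5, [MODAL_A100_40GB_USD_PER_H] * 2)
spike_c = spike_cost_usd(0.5, [MODAL_A100_40GB_USD_PER_H, HF_A100_LARGE_USD_PER_H])
print(f"009-B: ${spike_b:.2f}  009-C: ${spike_c:.2f}")
```

Both figures leave headroom under the respective caps for cold starts and retries.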
|
| 750 |
+
|
| 751 |
+
Each spike has a verdict.md following the conventions from `spikes/008-streaming-diloco/`.
|
| 752 |
+
|
| 753 |
+
---
|
| 754 |
+
|
| 755 |
+
## 7. References (primary sources, all cited above)
|
| 756 |
+
|
| 757 |
+
- **DiLoCo paper:** Douillard et al., "DiLoCo: Distributed Low-Communication Training of Language Models," arXiv:2311.08105 (2023). <https://arxiv.org/abs/2311.08105>
|
| 758 |
+
- **Streaming DiLoCo paper:** Douillard et al., "Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch," arXiv:2501.18512 (2025). <https://arxiv.org/abs/2501.18512>
|
| 759 |
+
- **torchft `local_sgd.DiLoCo`:** <https://github.com/meta-pytorch/torchft/blob/main/torchft/local_sgd.py>
|
| 760 |
+
- **Modal multi-node clusters:** <https://modal.com/docs/guide/multi-node-training>
|
| 761 |
+
- **Modal cluster networking (i6pn):** <https://modal.com/docs/guide/private-networking>
|
| 762 |
+
- **Modal pricing:** <https://modal.com/pricing>
|
| 763 |
+
- **Modal GPU options:** <https://modal.com/docs/guide/gpu>
|
| 764 |
+
- **HF Jobs overview:** <https://huggingface.co/docs/hub/en/jobs>
|
| 765 |
+
- **HF Jobs pricing:** <https://huggingface.co/docs/hub/jobs-pricing>
|
| 766 |
+
- **HF Jobs Python API:** <https://huggingface.co/docs/huggingface_hub/en/guides/jobs>
|
| 767 |
+
- **HF Jobs reference:** <https://huggingface.co/docs/huggingface_hub/main/en/package_reference/jobs>
|
| 768 |
+
- **AWS SageMaker pricing:** <https://aws.amazon.com/sagemaker/ai/pricing/>
|
| 769 |
+
- **AWS SageMaker `CreateTrainingJob` API:** <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html>
|
| 770 |
+
- **AWS SageMaker SMDDP:** <https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-framework-estimator.html>
|
| 771 |
+
- **AWS SageMaker warm pools:** <https://docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html>
|
| 772 |
+
- **GCP Vertex AI compute config:** <https://cloud.google.com/vertex-ai/docs/training/configure-compute>
|
| 773 |
+
- **GCP Vertex AI training SKUs:** <https://cloud.google.com/skus/sku-groups/vertex-training>
|
| 774 |
+
- **GCP Vertex AI pricing:** <https://cloud.google.com/vertex-ai/pricing>
|
| 775 |
+
- **Azure ML pricing:** <https://azure.microsoft.com/en-us/pricing/details/machine-learning/>
|
| 776 |
+
- **Azure ML PyTorch SDK v2 guide:** <https://learn.microsoft.com/en-us/azure/machine-learning/how-to-train-pytorch>
|
| 777 |
+
- **Azure NDasrA100_v4 spec:** <https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/ndasra100v4-series>
|
| 778 |
+
- **Azure NCads H100 v5 spec:** <https://learn.microsoft.com/en-us/azure/virtual-machines/ncads-h100-v5>
|
| 779 |
+
- **Volcano:** <https://volcano.sh/en/docs/unified_scheduling/>
|
| 780 |
+
- **Volcano network-topology-aware scheduling:** <https://volcano.sh/en/docs/network_topology_aware_scheduling/>
|
| 781 |
+
- **KubeRay + Volcano integration:** <https://docs.ray.io/en/latest/cluster/kubernetes/k8s-ecosystem/volcano.html>
|
| 782 |
+
- **KubeRay RayJob+Volcano PR:** <https://github.com/ray-project/kuberay/pull/3972>
|
| 783 |
+
|
| 784 |
+
Internal references (in this repo):
|
| 785 |
+
|
| 786 |
+
- `docs/research/MODAL_RECONNAISSANCE.md` — pricing/cold-start audit for Modal smoke runs.
|
| 787 |
+
- `docs/research/DILOCO_RECONNAISSANCE.md` — DiLoCo implementation candidates audit.
|
| 788 |
+
- `docs/adrs/ADR-001-gpu-venue.md` — local-vs-cloud GPU decision for smoke phase.
|
| 789 |
+
- `docs/adrs/ADR-003-diloco-impl.md` — torchft choice + sign convention.
|
| 790 |
+
- `composer_replication/diloco/__init__.py` — existing `make_diloco_outer_loop` wrapper this design plugs into without modification.
|
| 791 |
+
- `spikes/008-streaming-diloco/` — the existing in-process DiLoCo smoke that the serverless adapter inherits sign-convention test from.
|
|
@@ -0,0 +1,506 @@
| 1 |
+
# Replaysim Normalization Reconnaissance
|
| 2 |
+
|
| 3 |
+
**Status:** Recon · **Feeds:** ADR-004, V5 "replaysim with normalization"
|
| 4 |
+
**Author:** subagent (delegated audit) · **Date:** 2026-05-25
|
| 5 |
+
**Sources:** GitHub REST API metadata + DeepWiki structured indexes of each repo's primary source. All repo metadata cited below was pulled from `api.github.com/repos/<owner>/<name>` directly.
|
| 6 |
+
|
| 7 |
+
## TL;DR
|
| 8 |
+
|
| 9 |
+
| Library | License | Last push | ★ | Verdict |
|
| 10 |
+
|---|---|---|---|---|
|
| 11 |
+
| **data-juicer** | Apache-2.0 | **2026-05-25** | 6.4k | ✅ **RECOMMENDED** — the only candidate with a class-based op-graph that *natively* understands `messages: [{role, content}]`, multi-turn dialog, and DPO-pair (`chosen`/`rejected`) preference samples as **first-class data formats**, with a `pair_preference_mapper` operator that maps directly onto our `extract_dpo_pairs` output. |
|
| 12 |
+
| **distilabel** | Apache-2.0 | 2026-05-25 | 3.2k | Strong runner-up. DAG pipeline, native chat-message format, built-in `FormatChatGenerationDPO`. But it is primarily a *generation orchestrator* and would force us to rewrite our existing OpenRouter teacher orchestration as Distilabel `LLM` subclasses. Larger refactor surface. |
|
| 13 |
+
| **datatrove** | Apache-2.0 | 2026-05-06 | 3.1k | ❌ **Deal-breaker.** `Document` dataclass is `text: str + metadata: dict`. All filters/dedup operate on flat `doc.text`. Multi-turn is only supported in the *generation* (`InferenceRunner.rollout_fn`) path, not the normalization/filter path. Forces lossy chat→string flattening. |
|
| 14 |
+
| **NeMo-Curator** | Apache-2.0 | 2026-05-25 | 1.6k | Strong on scale (Ray + Xenna + GPU), supports streaming and DPO via `generate_two_turn_prompt`. But: semantic dedup, fuzzy dedup, and classifier filters all *require GPUs*; CPU-only install drops most of the differentiating ops. Heavy framework for the size of replaysim. |
|
| 15 |
+
| **lilac** | Apache-2.0 | **archived 2024-03-19** | 1.1k | ❌ **Dead.** `databricks/lilac` repo `"archived": true`. The current `lilacai/lilac` is a 2-star squatter stub created Nov 2025. Do not adopt. |
|
| 16 |
+
|
| 17 |
+
**Recommendation:** Adopt **data-juicer** as the normalization op-graph layer wrapped around `replay_trace` → `extract_dpo_pairs`. Estimated integration cost: **~250–400 LOC** in `composer_replication.replaysim` for an adapter + 1 YAML recipe.
|
| 18 |
+
|
| 19 |
+
**Critical chat-template question answered:** data-juicer is the only audited library whose *filtering and normalization operators* (not just its generation operators) operate directly on a structured `messages: [{role, content}]` format and on `chosen`/`rejected` preference-pair format. The other three live candidates fall short in different ways: datatrove flattens chat to text and handles multi-turn only in its generation path, while NeMo-Curator and (partly) distilabel treat chat as a generation output to be assembled rather than a structured object to be filtered.
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## 1. Audit Methodology
|
| 24 |
+
|
| 25 |
+
For each candidate, primary-source data was collected from:
|
| 26 |
+
|
| 27 |
+
1. `https://api.github.com/repos/<owner>/<name>` for license, `pushed_at`, `archived`, stars, forks, topics — these are authoritative GitHub metadata, not scraped.
|
| 28 |
+
2. DeepWiki structured indexes of each repo's source tree for: op model, data structures (`Document` / `Sample` / `Step`), conversation/DPO support in filtering vs. generation paths, GPU dependencies.
|
| 29 |
+
3. README confirmation through the GitHub API for transferred-org redirects.
|
| 30 |
+
|
| 31 |
+
No secondary sources, no marketing pages, no blog posts.
|
| 32 |
+
|
| 33 |
+
Two facts to flag up front because they materially change the candidate set:
|
| 34 |
+
|
| 35 |
+
- `modelscope/data-juicer` redirects to **`datajuicer/data-juicer`**. The team spun out of ModelScope into a dedicated `datajuicer` org. Same code, just a transferred name — `pushed_at` is current.
|
| 36 |
+
- `NVIDIA/NeMo-Curator` redirects to **`NVIDIA-NeMo/Curator`**. Same situation — moved into the dedicated `NVIDIA-NeMo` org in 2025.
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## 2. Per-Candidate Audit
|
| 41 |
+
|
| 42 |
+
### 2.1 datatrove (huggingface)
|
| 43 |
+
|
| 44 |
+
| Dimension | Value |
|
| 45 |
+
|---|---|
|
| 46 |
+
| Repo | `huggingface/datatrove` |
|
| 47 |
+
| License | Apache-2.0 |
|
| 48 |
+
| Created | 2023-06-14 |
|
| 49 |
+
| Last push | **2026-05-06** |
|
| 50 |
+
| Stars / Forks | 3068 / 266 |
|
| 51 |
+
| Commits | 725 (default branch) |
|
| 52 |
+
| Maturity | Production. Used to build FineWeb. Active. |
|
| 53 |
+
|
| 54 |
+
**Op model.** Class-based **linear pipeline** of `PipelineStep` instances. `PipelineStep.run(data: DocumentsPipeline, rank: int, world_size: int) -> DocumentsPipeline` where `DocumentsPipeline` is an iterator of `Document` objects. Steps are composed by Python list concatenation, not a DAG — branching/joining requires manual orchestration.
|
| 55 |
+
|
| 56 |
+
**Multi-turn / chat-template support — DEAL-BREAKER.**
|
| 57 |
+
|
| 58 |
+
The `Document` dataclass (`src/datatrove/data.py`) is:
|
| 59 |
+
|
| 60 |
+
```python
from dataclasses import dataclass


@dataclass
class Document:
    text: str
    id: str
    # placeholder, "for future uses, currently not used"; Media is datatrove's own type
    media: list["Media"]
    metadata: dict
```
|
| 68 |
+
|
| 69 |
+
There is **no `messages` field**. Every built-in filter (e.g., `C4QualityFilter`, `LanguageFilter`, `GopherQualityFilter`) and every built-in dedup op (`MinhashDedup*`, `SentenceDedup*`, `BloomFilter`) operates on `doc.text` as a flat string.
|
| 70 |
+
|
| 71 |
+
Multi-turn does appear, but **only in the generation path** (`InferenceRunner` + user-supplied `rollout_fn(doc, generate)`), where the user constructs `{"messages": [{"role": ..., "content": ...}]}` payloads themselves. Once the generation completes, the result is stuffed back into `doc.text` (or `doc.metadata`) and downstream filters again see flat text.
|
| 72 |
+
|
| 73 |
+
For our use case — normalizing already-generated multi-turn DPO pairs with `chosen`/`rejected` chat structures and tool calls — this means we'd have to:
|
| 74 |
+
|
| 75 |
+
1. Serialize `messages` into a flat string (`<|im_start|>user...`).
|
| 76 |
+
2. Run datatrove filters on the serialized string.
|
| 77 |
+
3. Re-parse back into `messages` afterward.
|
| 78 |
+
|
| 79 |
+
Tool-call structure (`{"role": "tool", "tool_call_id": ...}`, `tool_calls: [...]`) does not survive that round-trip cleanly without custom serialization on both sides. Per the user's hard requirement — "if only flat text, that's a deal-breaker" — datatrove fails here.
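
The loss can be demonstrated with a naive ChatML-style serializer and parser. This is an illustrative sketch only: `flatten_chatml` and `parse_chatml` are hypothetical helpers written for this example, not datatrove functions, and a real adapter could add custom serialization for the extra keys at the cost of maintaining both directions.

```python
import re


def flatten_chatml(messages: list[dict]) -> str:
    """Serialize role/content only, the way a flat-text pipeline would see the chat."""
    return "".join(
        f"<|im_start|>{m['role']}\n{m.get('content') or ''}<|im_end|>\n" for m in messages
    )


_TURN = re.compile(r"<\|im_start\|>(\w+)\n(.*?)<\|im_end\|>\n", re.DOTALL)


def parse_chatml(text: str) -> list[dict]:
    """Recover messages from the flat string; only role and content can come back."""
    return [{"role": role, "content": content} for role, content in _TURN.findall(text)]


messages = [
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": None,
     "tool_calls": [{"id": "call_1", "type": "function",
                     "function": {"name": "add", "arguments": '{"a": 2, "b": 2}'}}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "4"},
    {"role": "assistant", "content": "4"},
]

round_tripped = parse_chatml(flatten_chatml(messages))
lost_keys = [sorted(set(orig) - set(back)) for orig, back in zip(messages, round_tripped)]
print(lost_keys)
```

The turn count survives, but `tool_calls` and `tool_call_id` are gone after the round trip, which is the structural loss described above.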
|
| 80 |
+
|
| 81 |
+
**Streaming.** Yes. `HuggingFaceDatasetReader(streaming=True)` and the iterator-based `PipelineStep.run` mean we can pipe documents through during generation. Streaming is fine.
|
| 82 |
+
|
| 83 |
+
**GPU.** None of the *normalization* ops require GPU. MinHash dedup is CPU. Only the `InferenceRunner` path needs a GPU (vLLM/SGLang backend) and we don't need that — we'd be calling OpenRouter, not running local models.
|
| 84 |
+
|
| 85 |
+
**Integration cost.** Moot — the chat-template gap is the deal-breaker.
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
### 2.2 data-juicer (datajuicer org, formerly modelscope)
|
| 90 |
+
|
| 91 |
+
| Dimension | Value |
|
| 92 |
+
|---|---|
|
| 93 |
+
| Repo | `datajuicer/data-juicer` (redirect target of the legacy `modelscope/data-juicer`) |
|
| 94 |
+
| License | Apache-2.0 |
|
| 95 |
+
| Created | 2023-08-01 |
|
| 96 |
+
| Last push | **2026-05-25** (most recent of all candidates) |
|
| 97 |
+
| Stars / Forks | 6444 / 373 |
|
| 98 |
+
| Maturity | Production. Active core team (Alibaba/ModelScope-spinout). Most stars of the candidate set. Has its own conference papers and a docs site at `datajuicer.github.io/data-juicer`. |
|
| 99 |
+
|
| 100 |
+
**Op model.** Class-based DAG of **operators ("Ops")** organized as **mappers**, **filters**, **deduplicators**, and **selectors**. Each Op is a Python class subclassing `Mapper`/`Filter`/`Deduplicator`/`Selector`. Pipelines are declared as YAML recipes (`process: [- op_name: { args }, ...]`) and executed by the `Executor` (a local batch mode on Arrow/HF datasets by default; a Ray-distributed mode is optional). Conditional branching through `OpFusion` and `Adapter` modules is supported, and there is a Ray-Data executor for true streaming.
|
| 101 |
+
|
| 102 |
+
**Multi-turn / chat-template support — NATIVE.** This is the discriminator.
|
| 103 |
+
|
| 104 |
+
Data-juicer has a **first-class conversation schema**, supporting *both*:
|
| 105 |
+
1. OpenAI-style `messages: [{role, content}]`
|
| 106 |
+
2. A "Data-Juicer format" `{query, response, history: [[q, r], ...]}`
|
| 107 |
+
|
| 108 |
+
It exposes operators that are *purpose-built* for dialog/preference data:
|
| 109 |
+
|
| 110 |
+
- `dialog_intent_detection_mapper`
|
| 111 |
+
- `dialog_sentiment_detection_mapper`
|
| 112 |
+
- `dialog_sentiment_intensity_mapper`
|
| 113 |
+
- `dialog_topic_detection_mapper`
|
| 114 |
+
- `pair_preference_mapper` — **directly relevant**: ingests a `(prompt, chosen)` and synthesizes/refines a `rejected_response` plus a `reason` field. This is exactly the schema produced by our `extract_dpo_pairs`.
|
| 115 |
+
- `query_intent_detection_mapper`, `query_sentiment_detection_mapper`, `query_topic_detection_mapper`
|
| 116 |
+
- `optimize_qa_mapper`, `optimize_query_mapper`, `optimize_response_mapper` — refine individual fields without flattening the whole conversation.
|
| 117 |
+
|
| 118 |
+
Tool-call structure: data-juicer's conversation schema preserves arbitrary keys per message (because it operates on dict-of-lists Arrow tables), so `tool_call_id`, `tool_calls`, `name`, etc. survive through filters as long as no operator explicitly drops them. This is structurally safe — confirmed by the operator code only reading `role`/`content` and forwarding the rest.
|
| 119 |
+
|
| 120 |
+
**Streaming.** Partial. The default executor is batch on Arrow/HF datasets, but data-juicer integrates with **Ray Data** for distributed/streaming processing, and the README references "streaming JSON reader patches integrated by Apache Arrow." For our scale (≤100k DPO pairs per run), batch is fine. For true online normalization during multi-teacher generation the Ray executor handles it, but a simpler approach is to wrap each `replay_trace` rollout's output into a tiny in-memory dataset and run the recipe per batch (mini-batch streaming).
|
| 121 |
+
|
| 122 |
+
**GPU.** Only needed for image/video/multi-modal ops and for the LLM-API mappers when configured to run a *local* model. Every op we care about for replaysim — `pair_preference_mapper`, dialog detection mappers, `text_length_filter`, `language_id_score_filter`, MinHash dedup, etc. — is CPU-OK or calls a remote API (which is exactly our existing OpenRouter pattern). Importantly, **MinHash and exact dedup in data-juicer do not require GPU**, unlike NeMo-Curator's fuzzy/semantic dedup.
|
| 123 |
+
|
| 124 |
+
**Integration cost into `composer_replication.replaysim`.** Estimated ~250–400 LOC, breakdown:
|
| 125 |
+
|
| 126 |
+
- Adapter `replaysim/normalize.py`: ~80–120 LOC. Wraps a `DJDataset` (data-juicer's dataset abstraction), exposes `normalize_dpo_batch(pairs: list[DPOPair]) -> list[DPOPair]`.
|
| 127 |
+
- YAML recipe `replaysim/recipes/dpo_normalize.yaml`: ~40 LOC declarative.
|
| 128 |
+
- Hook in `teacher_replay.py` after `extract_dpo_pairs` and before final write: ~20 LOC.
|
| 129 |
+
- New tests `tests/replaysim/test_normalize.py`: ~80–120 LOC.
|
| 130 |
+
- ADR-004 update + module docs: ~20 LOC.
|
| 131 |
+
|
| 132 |
+
Dependency footprint: `pip install py-data-juicer` pulls in `datasets`, `pyarrow`, `loguru`, `jsonargparse`, optionally `ray`. We already have `datasets`/`pyarrow` indirectly from HF stack.
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
### 2.3 NeMo-Curator (NVIDIA-NeMo)
|
| 137 |
+
|
| 138 |
+
| Dimension | Value |
|
| 139 |
+
|---|---|
|
| 140 |
+
| Repo | `NVIDIA-NeMo/Curator` (redirect target of `NVIDIA/NeMo-Curator`) |
|
| 141 |
+
| License | Apache-2.0 |
|
| 142 |
+
| Created | 2024-03-14 |
|
| 143 |
+
| Last push | **2026-05-25** |
|
| 144 |
+
| Stars / Forks | 1584 / 274 |
|
| 145 |
+
| Maturity | Production at NVIDIA scale. Built for pre-training-corpus curation (Nemotron / Nemotron-4). |
|
| 146 |
+
|
| 147 |
+
**Op model.** Task-centric distributed processing, built on **Ray** + the **Xenna** executor. Stages are class-based, composed into pipelines, executed by `XennaExecutor` in either `streaming` or `batch` mode. Closer to Spark/Ray-Data than to a Python list of steps.
|
| 148 |
+
|
| 149 |
+
**Multi-turn / chat-template support — partial, generation-side only.** Curator has model-specific formatters (`Mixtral8x7BFormatter`, `NemotronFormatter`) that *render* multi-turn dialogue into a flat prompt string for the target model's chat template. There is `generate_dialogue` for multi-turn synthesis and `generate_two_turn_prompt` for DPO-style preference pairs. **But**: like datatrove, the *filtering* and *deduplication* stages do not have first-class conversation/preference operators — they treat the data as text after rendering. Tool-call preservation is not addressed in the public API.
|
| 150 |
+
|
| 151 |
+
**Streaming.** Yes — `XennaExecutor(execution_mode="streaming")` is a first-class option.
|
| 152 |
+
|
| 153 |
+
**GPU — significant cost.** Curator's discriminating features all require GPUs:
|
| 154 |
+
|
| 155 |
+
- **Semantic deduplication** — GPU-only, embedding generation + clustering. "Not supported for CPU-only processing."
|
| 156 |
+
- **Fuzzy deduplication** (MinHash + LSH) — GPU backend (cuDF/cuML), not CPU.
|
| 157 |
+
- **Classifier filters** (domain / quality / safety via `DistributedDataClassifier`) — GPU clusters.
|
| 158 |
+
- **Image curation modules** — GPU.
|
| 159 |
+
|
| 160 |
+
CPU-only install supports basic text filters and exact dedup, but *that's the same surface area we'd get from data-juicer without the dependency weight*. If we are not running on a GPU cluster, NeMo-Curator's value proposition collapses.
|
| 161 |
+
|
| 162 |
+
**Integration cost.** ~600–900 LOC plus operational cost: a Ray cluster setup, GPU nodes if we want the differentiating features. For replaysim's scale (a few thousand DPO pairs per run), this is overkill.
|
| 163 |
+
|
| 164 |
+
---
|
| 165 |
+
|
| 166 |
+
### 2.4 distilabel (argilla-io)
|
| 167 |
+
|
| 168 |
+
| Dimension | Value |
|
| 169 |
+
|---|---|
|
| 170 |
+
| Repo | `argilla-io/distilabel` |
|
| 171 |
+
| License | Apache-2.0 |
|
| 172 |
+
| Created | 2023-10-16 |
|
| 173 |
+
| Last push | **2026-05-25** |
|
| 174 |
+
| Stars / Forks | 3230 / 242 |
|
| 175 |
+
| Maturity | Production. Argilla is now part of HF; project remains active under argilla-io. |
|
| 176 |
+
|
| 177 |
+
**Op model.** **DAG pipeline** of `Step` and `Task` (Task = Step with an LLM). Each step declares `inputs: list[str]`, `outputs: list[str]`, and `process(*inputs) -> Generator[outputs]`. Steps are wired via `>>` operator. Resource declarations (`StepResources(replicas=N, gpus=M)`) handle scaling, optionally on Ray.
|
| 178 |
+
|
| 179 |
+
**Multi-turn / chat-template support — NATIVE on the generation side, partial on the normalization side.**
|
| 180 |
+
|
| 181 |
+
- `ChatGeneration` task accepts OpenAI-format `messages: [{role, content}]` natively.
|
| 182 |
+
- `FormatTextGenerationDPO` and `FormatChatGenerationDPO` produce the exact `{prompt, chosen, rejected, ratings, reason}` schema we want.
|
| 183 |
+
- `UltraFeedback` task is the canonical preference-rating step.
|
| 184 |
+
- `DeitaFiltering` and `MinHashDedup` are the only filtering/dedup steps; they operate on text fields rather than on structured `messages`. Tool-call structure is preserved as long as no step explicitly normalizes it (like data-juicer, by virtue of dict-of-fields semantics) — but there isn't a `pair_preference_mapper` analogue that operates on `messages` directly.
|
| 185 |
+
|
| 186 |
+
**Streaming.** Supports streaming generation per LLM (e.g., `AnthropicLLM` streams tokens). Pipeline-level execution is batch-of-batches; you can `.run(parameters={...})` and consume outputs as they materialize.
|
| 187 |
+
|
| 188 |
+
**GPU.** Only when steps choose to run a local LLM (vLLM, transformers). API-based steps (OpenAI, Anthropic, Mistral, OpenRouter via OpenAI-compat) are CPU-only.
|
| 189 |
+
|
| 190 |
+
**Integration cost — large but high overlap.** Distilabel would *replace* much of `teacher_replay.py`, not just normalize after it:
|
| 191 |
+
|
| 192 |
+
- Rewrite multi-teacher OpenRouter calls as a `Pipeline` of `Task`s subclassing distilabel's `LLM` interface (or use the `OpenAILLM` wrapper pointed at OpenRouter): ~300–500 LOC delta.
|
| 193 |
+
- Re-express `extract_dpo_pairs` as a custom `Task` or use `FormatChatGenerationDPO`: ~100–150 LOC.
|
| 194 |
+
- Migrate trace plumbing into distilabel's `GeneratorStep`/`Task` DAG: ~150 LOC.
|
| 195 |
+
- Tests + docs: ~150 LOC.
|
| 196 |
+
|
| 197 |
+
Total **~700–900 LOC** and a meaningful refactor of teacher orchestration. The win is that we'd get a real DAG runtime, retries, caching, and Argilla-integration for free. The cost is that we get *coupled* to distilabel's `LLM`/`Task` abstractions for the entire generation pipeline, not just a normalization op-graph wrapped around it.
|
| 198 |
+
|
| 199 |
+
This is a strategic decision the user phrased as: "see if we can leverage [a normalization library] to **normalize the data while also making the replaysim dataset generation**." Distilabel takes the broader interpretation — replace replaysim's generation with a distilabel pipeline. That is a bigger commitment than this recon was scoped to recommend.
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
### 2.5 lilac
|
| 204 |
+
|
| 205 |
+
**STATUS: dead. Do not adopt.**
|
| 206 |
+
|
| 207 |
+
- `databricks/lilac`: `"archived": true`, last push **2024-03-19**, license Apache-2.0. Repo says "Curate better data for LLMs." Databricks acquired Lilac (announced March 2024) and absorbed it into Databricks Mosaic AI; the OSS project was archived afterwards.
|
| 208 |
+
- `lilacai/lilac`: created **2025-11-14** by a user account `lilacai`, 2 stars, 0 forks, no license, description says "Thee Eclipse - Hackerone: @theeeclipse." This is a **squatter / unrelated stub**, not the original lilac.
|
| 209 |
+
- No actively maintained successor with the original lilac code base outside Databricks' proprietary platform.
|
| 210 |
+
|
| 211 |
+
---
|
| 212 |
+
|
| 213 |
+
## 3. Recommendation: data-juicer
|
| 214 |
+
|
| 215 |
+
### 3.1 Why
|
| 216 |
+
|
| 217 |
+
1. **Only candidate with native conversation + preference-pair operators in the *normalization* path**, not just the generation path. `pair_preference_mapper` is a near-perfect fit for the output of `extract_dpo_pairs`.
|
| 218 |
+
2. **Tool-call structure is preserved** because operators read specific fields and forward the rest of the dict — confirmed by the operator schema design.
|
| 219 |
+
3. **No GPU required** for the operators we'd actually use (preference, dialog, length, language-id, MinHash dedup). Matches our OpenRouter-API-driven, CPU-friendly architecture.
|
| 220 |
+
4. **YAML-recipe style** lets us version the normalization graph as a config artifact alongside the recon doc, instead of as Python code that drifts.
|
| 221 |
+
5. **Lowest integration cost** of the viable candidates — wraps around our existing pipeline rather than replacing it.
|
| 222 |
+
6. **Maturity**: 6.4k stars, last push today, dedicated org, paper-backed.
|
| 223 |
+
|
| 224 |
+
### 3.2 Why not the others (one-liners)
|
| 225 |
+
|
| 226 |
+
- **datatrove**: flat-text `Document`, lossy round-trip on chat structure → deal-breaker.
|
| 227 |
+
- **distilabel**: would force a rewrite of teacher orchestration — too broad a refactor for "wrap normalization around the existing pipeline."
|
| 228 |
+
- **NeMo-Curator**: best ops require GPUs; without them it offers no advantage over data-juicer.
|
| 229 |
+
- **lilac**: archived.
|
| 230 |
+
|
| 231 |
+
### 3.3 Risk register
|
| 232 |
+
|
| 233 |
+
| Risk | Severity | Mitigation |
|
| 234 |
+
|---|---|---|
|
| 235 |
+
| Data-juicer YAML recipe drift between dev and CI | M | Pin `py-data-juicer` version; commit recipe under `replaysim/recipes/` and load via `importlib.resources`. |
|
| 236 |
+
| Some ops silently coerce conversation structure | M | Add a round-trip test: `pair → normalize → pair` must preserve `messages`, `tool_calls`, and arbitrary metadata. |
|
| 237 |
+
| Ray executor bloat if user enables it | L | Default to the local (non-Ray) executor; gate Ray behind an explicit flag. |
|
| 238 |
+
| `pair_preference_mapper` calls an LLM by default to synthesize `rejected` | H | We *already have* `rejected` from disagreement. Configure the mapper as a pass-through filter / use it only for refinement; if it can't be made non-LLM, fall back to a custom Mapper that just runs length/language/dedup checks on the existing pair. **Verify in spike before locking in.** |
|
| 239 |
+
| Apache-2.0 inbound license compatibility | L | Our framework is Apache-2.0. Compatible. |
|
| 240 |
+
| Op-graph executes per batch, not per sample, so a single bad pair stalls a batch | L | Use small Ray-Data batches (e.g. 64) so a stall is bounded. |
|
| 241 |
+
|
| 242 |
+
### 3.4 Open spike question (must verify before merge)
|
| 243 |
+
|
| 244 |
+
The single risk worth a 1-day spike: **does `pair_preference_mapper` accept a pre-existing `rejected` and *only* run validation/length/language filters, or does it *always* call an LLM to (re)synthesize a rejected response?** Read the operator source in `data_juicer/ops/mapper/pair_preference_mapper.py` and confirm. If the latter, we wire our pre-existing `rejected` through `optimize_response_mapper` (refinement, not regeneration) plus a custom no-op preference validator. Either way, the integration shape below stands; only the recipe content changes.
|
| 245 |
+
|
| 246 |
+
---
|
| 247 |
+
|
| 248 |
+
## 4. Integration Sketch
|
| 249 |
+
|
| 250 |
+
### 4.1 Current pipeline (today)
|
| 251 |
+
|
| 252 |
+
```
|
| 253 |
+
TraceState
|
| 254 |
+
│
|
| 255 |
+
▼ (per-trace, multi-teacher OpenRouter call)
|
| 256 |
+
replay_trace(state, teachers=[m1, m2, m3])
|
| 257 |
+
│
|
| 258 |
+
▼ (returns: list[TeacherCompletion] keyed by model_id)
|
| 259 |
+
disagreement_score(completions)
|
| 260 |
+
│
|
| 261 |
+
▼ (if score > τ)
|
| 262 |
+
extract_dpo_pairs(completions, state)
|
| 263 |
+
│
|
| 264 |
+
▼ (yields)
|
| 265 |
+
DPOPair { prompt: messages[], chosen: messages[], rejected: messages[], state, meta }
|
| 266 |
+
│
|
| 267 |
+
▼
|
| 268 |
+
write_jsonl(out_path)
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
### 4.2 Proposed pipeline (with data-juicer normalization op-graph)
|
| 272 |
+
|
| 273 |
+
```
|
| 274 |
+
TraceState
|
| 275 |
+
│
|
| 276 |
+
▼
|
| 277 |
+
replay_trace(state, teachers) ← unchanged
|
| 278 |
+
│
|
| 279 |
+
▼
|
| 280 |
+
disagreement_score(completions) ← unchanged
|
| 281 |
+
│
|
| 282 |
+
▼
|
| 283 |
+
extract_dpo_pairs(completions, state) ← unchanged
|
| 284 |
+
│
|
| 285 |
+
▼
|
| 286 |
+
[NEW] DJNormalizer.normalize_batch(dpo_pairs) ──── loads recipe from
|
| 287 |
+
│ replaysim/recipes/dpo_normalize.yaml
|
| 288 |
+
│ data-juicer op-graph runs:
|
| 289 |
+
│ 1. text_length_filter (on chosen + rejected separately)
|
| 290 |
+
│ 2. language_id_score_filter (en-only or configured)
|
| 291 |
+
│ 3. dialog_topic_detection_mapper (annotates meta, no drop)
|
| 292 |
+
│ 4. document_minhash_deduplicator (on prompt+chosen serialization)
|
| 293 |
+
│ 5. (optional) optimize_response_mapper to clean trailing whitespace, code-block fences
|
| 294 |
+
│ 6. custom PreferenceValidator op (chosen != rejected, both non-empty,
|
| 295 |
+
│ tool_calls structurally valid)
|
| 296 |
+
▼
|
| 297 |
+
write_jsonl(out_path) ← unchanged consumer
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
The op-graph is a **wrapper around** `extract_dpo_pairs`, not a replacement. `replay_trace` and `extract_dpo_pairs` keep their current signatures. The only call-site change in `teacher_replay.py` is one line:
|
| 301 |
+
|
| 302 |
+
```python
|
| 303 |
+
# before:
|
| 304 |
+
pairs = list(extract_dpo_pairs(completions, state))
|
| 305 |
+
write_jsonl(out_path, pairs)
|
| 306 |
+
|
| 307 |
+
# after:
|
| 308 |
+
pairs = list(extract_dpo_pairs(completions, state))
|
| 309 |
+
pairs = DJNormalizer.from_recipe("dpo_normalize.yaml").normalize_batch(pairs)
|
| 310 |
+
write_jsonl(out_path, pairs)
|
| 311 |
+
```
|
| 312 |
+
|
| 313 |
+
### 4.3 Adapter shape (`replaysim/normalize.py`)
|
| 314 |
+
|
| 315 |
+
```python
|
| 316 |
+
# composer_replication/replaysim/normalize.py
|
| 317 |
+
from __future__ import annotations
|
| 318 |
+
from dataclasses import asdict
|
| 319 |
+
from importlib.resources import files
|
| 320 |
+
from typing import Iterable
|
| 321 |
+
|
| 322 |
+
from data_juicer.config import init_configs
|
| 323 |
+
from data_juicer.core.executor import DefaultExecutor
|
| 324 |
+
from data_juicer.format import load_formatter
|
| 325 |
+
|
| 326 |
+
from .types import DPOPair
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class DJNormalizer:
|
| 330 |
+
"""Wraps a data-juicer op-graph as a batch normalization step over
|
| 331 |
+
DPOPair samples produced by extract_dpo_pairs.
|
| 332 |
+
|
| 333 |
+
The recipe (YAML) declares the op sequence. Operators consume and
|
| 334 |
+
produce the data-juicer conversation schema, which we convert to
|
| 335 |
+
and from our internal DPOPair on the boundary.
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
def __init__(self, recipe_path: str):
|
| 339 |
+
cfg = init_configs(["--config", recipe_path])
|
| 340 |
+
self._executor = DefaultExecutor(cfg)
|
| 341 |
+
|
| 342 |
+
@classmethod
|
| 343 |
+
def from_recipe(cls, name: str) -> "DJNormalizer":
|
| 344 |
+
recipe = files("composer_replication.replaysim.recipes") / name
|
| 345 |
+
return cls(str(recipe))
|
| 346 |
+
|
| 347 |
+
@staticmethod
|
| 348 |
+
def _to_dj(p: DPOPair) -> dict:
|
| 349 |
+
# data-juicer preference schema:
|
| 350 |
+
# {"prompt": str-or-messages, "chosen": str-or-messages,
|
| 351 |
+
# "rejected": str-or-messages, "meta": {...}}
|
| 352 |
+
return {
|
| 353 |
+
"prompt": p.prompt, # messages[]
|
| 354 |
+
"chosen": p.chosen, # messages[]
|
| 355 |
+
"rejected": p.rejected, # messages[]
|
| 356 |
+
"meta": {
|
| 357 |
+
"trace_id": p.state.trace_id,
|
| 358 |
+
"teachers": p.meta.get("teachers", []),
|
| 359 |
+
"disagreement": p.meta.get("disagreement"),
|
| 360 |
+
**p.meta,
|
| 361 |
+
},
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
@staticmethod
|
| 365 |
+
def _from_dj(s: dict) -> DPOPair:
|
| 366 |
+
return DPOPair(
|
| 367 |
+
prompt=s["prompt"],
|
| 368 |
+
chosen=s["chosen"],
|
| 369 |
+
rejected=s["rejected"],
|
| 370 |
+
state=..., # rehydrate from meta.trace_id + cache
|
| 371 |
+
meta=s.get("meta", {}),
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
def normalize_batch(self, pairs: Iterable[DPOPair]) -> list[DPOPair]:
|
| 375 |
+
in_records = [self._to_dj(p) for p in pairs]
|
| 376 |
+
# Build an in-memory DJDataset from records (no disk round-trip).
|
| 377 |
+
ds = self._executor.formatter.load_dataset_from_records(in_records)
|
| 378 |
+
ds = self._executor.run(dataset=ds)
|
| 379 |
+
out_records = ds.to_list()
|
| 380 |
+
return [self._from_dj(r) for r in out_records]
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
### 4.4 Recipe (`replaysim/recipes/dpo_normalize.yaml`)
|
| 384 |
+
|
| 385 |
+
```yaml
|
| 386 |
+
# data-juicer recipe for normalizing replaysim DPO output
|
| 387 |
+
project_name: replaysim_dpo_normalize
|
| 388 |
+
executor_type: default # local executor; switch to 'ray' for distributed
|
| 389 |
+
np: 4
|
| 390 |
+
|
| 391 |
+
# Conversation/preference schema mode
|
| 392 |
+
text_keys: ['chosen', 'rejected'] # ops scan both response variants
|
| 393 |
+
suffixes: ['.jsonl']
|
| 394 |
+
|
| 395 |
+
process:
|
| 396 |
+
# 1. Length sanity on each response variant
|
| 397 |
+
- text_length_filter:
|
| 398 |
+
text_key: chosen
|
| 399 |
+
min_len: 10
|
| 400 |
+
max_len: 16384
|
| 401 |
+
- text_length_filter:
|
| 402 |
+
text_key: rejected
|
| 403 |
+
min_len: 10
|
| 404 |
+
max_len: 16384
|
| 405 |
+
|
| 406 |
+
# 2. Language gate (configurable; default English-only)
|
| 407 |
+
- language_id_score_filter:
|
| 408 |
+
text_key: chosen
|
| 409 |
+
lang: en
|
| 410 |
+
min_score: 0.6
|
| 411 |
+
|
| 412 |
+
# 3. Dialog topic annotation (no drop, just attaches meta.topic)
|
| 413 |
+
- dialog_topic_detection_mapper:
|
| 414 |
+
api_or_hf_model: openrouter:openai/gpt-4o-mini
|
| 415 |
+
mode: annotate
|
| 416 |
+
|
| 417 |
+
# 4. Near-duplicate removal across the batch on (prompt + chosen)
|
| 418 |
+
- document_minhash_deduplicator:
|
| 419 |
+
tokenization: space
|
| 420 |
+
window_size: 5
|
| 421 |
+
num_permutations: 256
|
| 422 |
+
jaccard_threshold: 0.85
|
| 423 |
+
text_key: chosen
|
| 424 |
+
|
| 425 |
+
# 5. Custom preference validator (chosen != rejected, structural integrity)
|
| 426 |
+
- preference_validator_filter: # module: composer_replication.replaysim.ops
|
| 427 |
+
check_distinct: true
|
| 428 |
+
check_tool_calls_valid: true
|
| 429 |
+
```
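To make the deduplication parameters in step 4 concrete, here is a minimal sketch of the quantity that MinHash estimates: the exact Jaccard similarity over whitespace-token shingles. It is an illustration only, not data-juicer's implementation; the helper names (`shingles`, `jaccard`, `is_near_duplicate`) are hypothetical, and the defaults simply mirror `window_size: 5` and `jaccard_threshold: 0.85` from the recipe.

```python
def shingles(text: str, window_size: int = 5) -> set[tuple[str, ...]]:
    """Return the set of whitespace-token n-grams (shingles) of `text`."""
    tokens = text.split()
    if len(tokens) < window_size:
        # A short text yields one shingle holding all of its tokens.
        return {tuple(tokens)} if tokens else set()
    return {
        tuple(tokens[i : i + window_size])
        for i in range(len(tokens) - window_size + 1)
    }


def jaccard(a: str, b: str, window_size: int = 5) -> float:
    """Exact Jaccard similarity |A ∩ B| / |A ∪ B| over shingle sets."""
    sa, sb = shingles(a, window_size), shingles(b, window_size)
    if not sa and not sb:
        return 1.0  # two empty texts are treated as identical
    return len(sa & sb) / len(sa | sb)


def is_near_duplicate(
    a: str, b: str, threshold: float = 0.85, window_size: int = 5
) -> bool:
    """True when the two texts meet or exceed the similarity threshold."""
    return jaccard(a, b, window_size) >= threshold
```

With a 5-token window, changing a single word in a short response already removes several shared shingles, so a 0.85 threshold mostly collapses long responses that differ in only a few places.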
|
| 430 |
+
|
| 431 |
+
A custom op `preference_validator_filter` lives in `composer_replication/replaysim/ops/preference_validator.py` and is registered via data-juicer's plugin entry point.
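The checks that op is meant to perform (`chosen != rejected`, both non-empty, `tool_calls` structurally valid) can be sketched as a plain predicate, independent of data-juicer's operator base classes. This is a sketch under assumptions: the function names are hypothetical, message `content` is assumed to be a string or `None`, and "structurally valid" is taken to mean an OpenAI-style `{"function": {"name": ..., "arguments": ...}}` entry whose arguments are a dict or a JSON-parseable string.

```python
import json
from typing import Any

Message = dict[str, Any]


def tool_calls_valid(messages: list[Message]) -> bool:
    """Check every tool call has a named function and parseable arguments."""
    for msg in messages:
        for call in msg.get("tool_calls") or []:
            fn = call.get("function") if isinstance(call, dict) else None
            if not isinstance(fn, dict) or not isinstance(fn.get("name"), str):
                return False
            args = fn.get("arguments")
            if isinstance(args, str):
                try:
                    json.loads(args)
                except json.JSONDecodeError:
                    return False
            elif not isinstance(args, dict):
                return False
    return True


def has_content(messages: list[Message]) -> bool:
    """A response counts as non-empty if any message has text or tool calls."""
    return any(
        bool((m.get("content") or "").strip()) or bool(m.get("tool_calls"))
        for m in messages
    )


def is_valid_pair(sample: dict[str, Any]) -> bool:
    """Keep a sample only if both responses are non-empty, distinct, and valid."""
    chosen = sample.get("chosen") or []
    rejected = sample.get("rejected") or []
    if not has_content(chosen) or not has_content(rejected):
        return False
    if chosen == rejected:
        return False
    return tool_calls_valid(chosen) and tool_calls_valid(rejected)
```

Keeping the predicate free of data-juicer imports would let the same function back both the registered filter and the unit tests.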
|
| 432 |
+
|
| 433 |
+
### 4.5 Hook into `teacher_replay.py`
|
| 434 |
+
|
| 435 |
+
```python
|
| 436 |
+
# composer_replication/replaysim/teacher_replay.py (delta)
|
| 437 |
+
|
| 438 |
+
from .normalize import DJNormalizer
|
| 439 |
+
|
| 440 |
+
def run_replay(traces, teachers, out_path, *, normalize: bool = True):
|
| 441 |
+
pairs: list[DPOPair] = []
|
| 442 |
+
for state in traces:
|
| 443 |
+
completions = replay_trace(state, teachers=teachers)
|
| 444 |
+
if disagreement_score(completions) <= TAU:
|
| 445 |
+
continue
|
| 446 |
+
pairs.extend(extract_dpo_pairs(completions, state))
|
| 447 |
+
|
| 448 |
+
if normalize:
|
| 449 |
+
norm = DJNormalizer.from_recipe("dpo_normalize.yaml")
|
| 450 |
+
pairs = norm.normalize_batch(pairs)
|
| 451 |
+
|
| 452 |
+
write_jsonl(out_path, pairs)
|
| 453 |
+
```
|
| 454 |
+
|
| 455 |
+
The `normalize` flag defaults to `True`; passing `normalize=False` restores the old code path during initial rollout.
|
| 456 |
+
|
| 457 |
+
### 4.6 Test plan (`tests/replaysim/test_normalize.py`)
|
| 458 |
+
|
| 459 |
+
1. **Round-trip preservation**: synthesize a DPOPair with `tool_calls`, run through `DJNormalizer.normalize_batch`, assert tool-call structure and arbitrary `meta` keys are preserved.
|
| 460 |
+
2. **Length filter**: a pair with empty `chosen` is dropped.
|
| 461 |
+
3. **Language filter**: a non-English `chosen` (Cyrillic) below the score threshold is dropped.
|
| 462 |
+
4. **Near-duplicate**: two pairs with identical `chosen` collapse to one.
|
| 463 |
+
5. **Distinctness**: a pair where `chosen == rejected` is dropped by `preference_validator_filter`.
|
| 464 |
+
6. **Multi-turn**: a 3-turn conversation in `prompt` survives end-to-end with role+content intact.
|
| 465 |
+
7. **Recipe loading**: `DJNormalizer.from_recipe("dpo_normalize.yaml")` works with `importlib.resources` regardless of install location.
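Test 1 above (round-trip preservation) can be prototyped before data-juicer is wired in by exercising only the conversion boundary. The sketch below is an assumption-laden stand-in: `TraceState` and `DPOPair` are simplified hypothetical versions of the real types, `to_record` / `from_record` are illustrative helpers rather than the adapter's methods, and the `normalize` argument is an identity function standing in for the op-graph.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class TraceState:
    trace_id: str


@dataclass
class DPOPair:
    prompt: list[dict[str, Any]]
    chosen: list[dict[str, Any]]
    rejected: list[dict[str, Any]]
    state: TraceState
    meta: dict[str, Any] = field(default_factory=dict)


def to_record(p: DPOPair) -> dict[str, Any]:
    # trace_id is written last so a stray meta key cannot shadow it.
    return {"prompt": p.prompt, "chosen": p.chosen, "rejected": p.rejected,
            "meta": {**p.meta, "trace_id": p.state.trace_id}}


def from_record(r: dict[str, Any], states: dict[str, TraceState]) -> DPOPair:
    meta = dict(r.get("meta", {}))
    trace_id = meta.pop("trace_id")  # rehydrate state from the cache
    return DPOPair(r["prompt"], r["chosen"], r["rejected"], states[trace_id], meta)


def roundtrip(pairs: list[DPOPair],
              normalize: Callable[[list[dict]], list[dict]]) -> list[DPOPair]:
    states = {p.state.trace_id: p.state for p in pairs}
    records = normalize([to_record(p) for p in pairs])
    return [from_record(r, states) for r in records]


def test_round_trip_preserves_structure() -> None:
    call = {"id": "c1", "function": {"name": "read_file",
                                     "arguments": '{"path": "a.py"}'}}
    pair = DPOPair(
        prompt=[{"role": "user", "content": "open a.py"}],
        chosen=[{"role": "assistant", "content": "", "tool_calls": [call]}],
        rejected=[{"role": "assistant", "content": "I cannot do that."}],
        state=TraceState(trace_id="t-1"),
        meta={"teachers": ["m1", "m2"], "custom_key": 42},
    )
    (out,) = roundtrip([pair], normalize=lambda records: records)
    assert out == pair  # tool_calls and arbitrary meta keys survive
```

Swapping the identity function for the real normalizer turns the same assertion into the integration test.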
|
| 466 |
+
|
| 467 |
+
---
|
| 468 |
+
|
| 469 |
+
## 5. ADR-004 Implications
|
| 470 |
+
|
| 471 |
+
ADR-004 (the umbrella ADR for "replaysim with normalization") should record:
|
| 472 |
+
|
| 473 |
+
- **Decision**: adopt data-juicer (`datajuicer/data-juicer`, Apache-2.0) as the normalization op-graph layer.
|
| 474 |
+
- **Status**: proposed; promote to accepted after the spike on `pair_preference_mapper`.
|
| 475 |
+
- **Consequences**:
|
| 476 |
+
- New runtime dependency: `py-data-juicer` (transitively pulls `pyarrow`, `datasets`, `loguru`, `jsonargparse`).
|
| 477 |
+
- Optional `ray` extra for distributed execution; not enabled by default.
|
| 478 |
+
- `replaysim/recipes/*.yaml` becomes a versioned config artifact; recipe changes must accompany behavioral-test updates.
|
| 479 |
+
- Tool-call and multi-turn structure preserved through normalization — verified by round-trip test.
|
| 480 |
+
- **Alternatives considered**: distilabel (too broad — would replace generation orchestration), datatrove (flat-text only — deal-breaker), NeMo-Curator (GPU-bound), lilac (archived).
|
| 481 |
+
|
| 482 |
+
---
|
| 483 |
+
|
| 484 |
+
## 6. Primary-source citations
|
| 485 |
+
|
| 486 |
+
| Claim | Source |
|
| 487 |
+
|---|---|
|
| 488 |
+
| datatrove license, last push, archived state | `https://api.github.com/repos/huggingface/datatrove` (`license.spdx_id`, `pushed_at`, `archived`) |
|
| 489 |
+
| datatrove `Document` is text+metadata, no `messages` field; built-in filters operate on `doc.text` | DeepWiki index of `huggingface/datatrove`, `src/datatrove/data.py`, `src/datatrove/pipeline/filters/c4_filters.py` |
|
| 490 |
+
| datatrove multi-turn only via `InferenceRunner.rollout_fn` | DeepWiki index of `huggingface/datatrove`, `src/datatrove/pipeline/inference/run_inference.py` |
|
| 491 |
+
| data-juicer license, last push, redirect to `datajuicer/data-juicer` | `https://api.github.com/repos/modelscope/data-juicer` (resolves to `datajuicer/data-juicer`) |
|
| 492 |
+
| data-juicer supports `messages: [{role, content}]` and Data-Juicer dialog format `{query, response, history}` | DeepWiki index of `modelscope/data-juicer` |
|
| 493 |
+
| `pair_preference_mapper` synthesizes `rejected_response` and `reason` | DeepWiki index of `modelscope/data-juicer`, `data_juicer/ops/mapper/pair_preference_mapper.py` |
|
| 494 |
+
| data-juicer GPU-required ops are tagged `🚀GPU` (image/video/multi-modal); core text + dialog mappers are CPU-OK | DeepWiki index of `modelscope/data-juicer` |
|
| 495 |
+
| NeMo-Curator license, last push, redirect to `NVIDIA-NeMo/Curator` | `https://api.github.com/repos/NVIDIA/NeMo-Curator` |
|
| 496 |
+
| NeMo-Curator semantic dedup is GPU-only; CPU install drops differentiating ops | DeepWiki index of `NVIDIA/NeMo-Curator` |
|
| 497 |
+
| distilabel license, last push, DAG model, `FormatChatGenerationDPO`, `MinHashDedup`, `DeitaFiltering` | `https://api.github.com/repos/argilla-io/distilabel`; DeepWiki index of `argilla-io/distilabel` |
|
| 498 |
+
| `databricks/lilac` archived 2024-03-19 | `https://api.github.com/repos/databricks/lilac` (`archived: true`, `pushed_at: "2024-03-19T12:41:30Z"`) |
|
| 499 |
+
| `lilacai/lilac` is a 2-star squatter stub created 2025-11-14 | `https://api.github.com/repos/lilacai/lilac` |
|
| 500 |
+
|
| 501 |
+
---
|
| 502 |
+
|
| 503 |
+
## 7. Confirmed output path
|
| 504 |
+
|
| 505 |
+
**File:** `/home/codeseys/.hermes/hermes-agent/docs/research/REPLAYSIM_NORMALIZATION_RECONNAISSANCE.md`
|
| 506 |
+
**Length:** ≤600 lines (this file).
|
|
@@ -0,0 +1,428 @@
|
| 1 |
+
# RL Post-Training Frameworks Landscape & Meta PyTorch Stack Audit
|
| 2 |
+
|
| 3 |
+
> **Generated:** 2026-05-25
|
| 4 |
+
> **Scope:** Audit of RL post-training frameworks beyond TRL+VeRL plus Meta's PyTorch agentic stack components, with a recommendation of two additions to the Composer Replication Framework.
|
| 5 |
+
> **Feeds:** ADR-006 (Algorithm-substrate selection)
|
| 6 |
+
> **Companion docs:** `~/wiki/research/post-training-framework/04-verl-trl.md`, `~/wiki/research/post-training-framework/03-monarch-torchforge-openenv.md`, `~/wiki/research/post-training-framework/02-diloco-family.md`
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## TL;DR — Recommendation
|
| 11 |
+
|
| 12 |
+
| Slot | Pick | Why |
|
| 13 |
+
|---|---|---|
|
| 14 |
+
| **RL framework #3 (after TRL, VeRL)** | **PRIME-RL (PrimeIntellect-ai/prime-rl)** | First-class `CustomLossConfig` extension point (`trainer.loss.type=custom` + `import_path`) — the cleanest place we have to drop our **3-channel loss (RLVR + hint-distill + trace-replay)** without forking. Already uses the `verifiers` env protocol that bridges to OpenEnv. Async, decentralized substrate. Apache-2.0. INTELLECT-2 production receipts. |
|
| 15 |
+
| **Infra component (Meta stack)** | **Monarch (`meta-pytorch/monarch`)** as the actor-mesh control plane; **TorchTitan** is *also* tracked as the FSDP2/TP/PP training core but is already the trainer inside both PRIME-RL and TorchForge, so we adopt it transitively. The single net-new dependency is **Monarch**. | Monarch is the only Meta-stack component that is (a) actively shipped (v0.4 GA, v0.5 dev, weekly wheels), (b) decoupled from the now-paused TorchForge, and (c) able to host *any* SPMD trainer (TRL, VeRL, PRIME-RL) as an `ActorMesh`. BSD-3. Replaces Ray when our v0.2 lands. |
|
| 16 |
+
|
| 17 |
+
**What we do NOT add:**
|
| 18 |
+
- OpenRLHF — strong production framework (v0.9.10, 9.3K★, supports DAPO) but its custom-loss path requires modifying `openrlhf/models/loss.py` + a `Trainer` subclass. Strictly worse extension story than PRIME-RL for our specific need (3-channel loss).
|
| 19 |
+
- NeMo-Aligner — no GRPO, no DAPO, heavy NeMo/Megatron dependency. Wrong shape.
|
| 20 |
+
- Unsloth — TRL wrapper, RL kernels live in closed `unsloth_zoo`. We'd have to fork.
|
| 21 |
+
- LLaMA-Factory — TRL wrapper, no GRPO/DAPO (delegates to EasyR1).
|
| 22 |
+
- DeepSpeed-Chat — effectively unmaintained for new RL algos since Aug 2023; PPO/DPO only.
|
| 23 |
+
- TorchForge — Meta has marked the repo "development paused, consolidating into TorchTitan." Borrow patterns; do not depend on it.
|
| 24 |
+
- torchchat — inference / local deployment only; no training. Out of scope.
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## Table of Contents
|
| 29 |
+
|
| 30 |
+
1. [Audit Methodology](#1-audit-methodology)
|
| 31 |
+
2. [RL Framework Audit](#2-rl-framework-audit)
|
| 32 |
+
1. [OpenRLHF](#21-openrlhf)
|
| 33 |
+
2. [PRIME-RL](#22-prime-rl)
|
| 34 |
+
3. [NeMo-Aligner](#23-nemo-aligner)
|
| 35 |
+
4. [Unsloth (RL)](#24-unsloth-rl)
|
| 36 |
+
5. [LLaMA-Factory](#25-llama-factory)
|
| 37 |
+
6. [DeepSpeed-Chat](#26-deepspeed-chat)
|
| 38 |
+
3. [Meta PyTorch Agentic Stack — Infra vs Training Split](#3-meta-pytorch-agentic-stack)
|
| 39 |
+
1. [Monarch (coordination/infra)](#31-monarch)
|
| 40 |
+
2. [TorchTitan (training stack)](#32-torchtitan)
|
| 41 |
+
3. [TorchForge (paused)](#33-torchforge)
|
| 42 |
+
4. [torchchat (out of scope)](#34-torchchat)
|
| 43 |
+
4. [Comparison Matrix](#4-comparison-matrix)
|
| 44 |
+
5. [Recommendation Rationale](#5-recommendation-rationale)
|
| 45 |
+
6. [Integration Sketches](#6-integration-sketches)
|
| 46 |
+
7. [Sources](#7-sources)
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## 1. Audit Methodology
|
| 51 |
+
|
| 52 |
+
For each framework, we capture five fields that determine whether it can host the Composer Replication Framework's three-channel loss (RLVR + hint-distill + trace-replay) on our existing OpenEnv-compatible TRL data path:
|
| 53 |
+
|
| 54 |
+
1. **Repo + license + last commit + maturity** — primary GitHub source, license grade for redistribution, recency, and whether the project is *production*, *research*, or *archived*.
|
| 55 |
+
2. **Algorithm coverage** — does it ship GRPO and DAPO out of the box? (DAPO matters because Composer-style training inherits its decoupled clip + dynamic sampling fixes for length and std biases.)
|
| 56 |
+
3. **Custom-loss extension point** — concrete file/class/config where a custom 3-channel loss can be plugged. We strongly prefer a stable public hook over forking.
|
| 57 |
+
4. **Integration cost** — rough lines of code needed for a `Recipe` doc + a skeleton `Trainer` subclass that runs end-to-end on a small env.
|
| 58 |
+
5. **OpenEnv data-path fit** — does it already consume the OpenEnv contract (typed `reset`/`step`/`close`, MCP tool-calling) directly, or do we have to write a shim?
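The preference in item 3 for a config-driven hook over forking comes down to a familiar pattern: the trainer resolves a dotted import path from config into a callable at start-up. The sketch below shows that generic pattern with the standard library only. It is not any framework's actual loader; `resolve_import_path`, `build_loss_fn`, and the config keys (`type`, `import_path`, `kwargs`) are illustrative assumptions.

```python
import functools
import importlib
from typing import Any, Callable


def resolve_import_path(import_path: str) -> Callable[..., Any]:
    """Turn 'pkg.module.attr' into the callable it names."""
    module_name, _, attr = import_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.attr', got {import_path!r}")
    obj = getattr(importlib.import_module(module_name), attr)
    if not callable(obj):
        raise TypeError(f"{import_path!r} does not name a callable")
    return obj


def build_loss_fn(config: dict[str, Any]) -> Callable[..., Any]:
    """Build a loss callable from a config dict (only 'custom' is sketched)."""
    if config.get("type") != "custom":
        raise ValueError("this sketch only handles type='custom'")
    fn = resolve_import_path(config["import_path"])
    # Bind any configured keyword arguments once, at construction time.
    return functools.partial(fn, **config.get("kwargs", {}))
```

With this shape, a new loss lives in the user's own package and is selected by config, so no upstream file has to be edited or vendored.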
|
| 59 |
+
|
| 60 |
+
Primary sources: each repo's `README.md`, official releases page, and DeepWiki audits (where indexed). Secondary checks: PyPI release timelines for Meta packages.
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
## 2. RL Framework Audit
|
| 65 |
+
|
| 66 |
+
### 2.1 OpenRLHF
|
| 67 |
+
|
| 68 |
+
| Field | Value |
|
| 69 |
+
|---|---|
|
| 70 |
+
| **Repo** | https://github.com/OpenRLHF/OpenRLHF |
|
| 71 |
+
| **License** | Apache-2.0 |
|
| 72 |
+
| **Stars / contributors** | 9,312 ★ / 90 contributors |
|
| 73 |
+
| **Latest release** | v0.9.10, 2026-04-04 |
|
| 74 |
+
| **Last push** | 2026-04-05 |
|
| 75 |
+
| **Maturity** | **Production** — used in many public RLHF runs since 2023; tagline "An Easy-to-use, Scalable and High-performance Agentic RL Framework based on Ray (PPO & DAPO & REINFORCE++ & TIS & vLLM & Ray & Async RL)" |
|
| 76 |
+
| **Algorithms** | PPO, GRPO, **DAPO** (release notes; advertised as a primary feature in v0.9.x), REINFORCE++, REINFORCE++-baseline, RLOO, GSPO, Async RL, TIS (truncated importance sampling) |
|
| 77 |
+
| **Custom-loss extension point** | `openrlhf/models/loss.py` — `PolicyLoss`, `DPOLoss`, `SFTLoss`, `PairWiseLoss`, `LogExpLoss` are concrete `nn.Module`s. To add a 3-channel loss you would (a) add a new `nn.Module` (e.g. `ThreeChannelLoss`) here, then (b) subclass the relevant `Trainer` (e.g. `PPOTrainer` / a new GRPO-derived trainer) and replace `self.loss_fn`. There is **no config-driven custom-loss hook** equivalent to PRIME-RL's `CustomLossConfig` — you fork or vendor. |
|
| 78 |
+
| **Integration cost** | Higher than PRIME-RL. Estimated **~400–600 LOC**: ~150 LOC for a `ThreeChannelLoss` module, ~200 LOC for a `ComposerGRPOTrainer` subclass that routes the three signals (RLVR scalar, hint-distill teacher logprobs, trace-replay teacher logits), ~50 LOC for a `Recipe` doc, plus reward-fn glue. |
|
| 79 |
+
| **Data-path fit** | OpenRLHF's input is HF chat templates + a Python reward function or a remote reward URL (`--reward.remote_url`, `--train.agent_func_path`). It does **not** speak the OpenEnv `reset/step` protocol natively, but our existing OpenEnv→TRL adapter could be reused as a callable behind `agent_func_path`. **Medium** lift to wire OpenEnv. |
|
| 80 |
+
|
| 81 |
+
**Verdict:** Strong, mature, well-funded codebase with the *most* complete algorithm coverage of any candidate. Loses to PRIME-RL only because PRIME-RL has a first-class config-driven custom-loss hook that fits our exact need, and PRIME-RL already has the `verifiers`/OpenEnv shape baked into the orchestrator. We keep OpenRLHF on the radar as a fallback substrate if PRIME-RL's decentralized story is overkill for v0.1.
|
| 82 |
+
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
### 2.2 PRIME-RL
|
| 86 |
+
|
| 87 |
+
| Field | Value |
|
| 88 |
+
|---|---|
|
| 89 |
+
| **Repo** | https://github.com/PrimeIntellect-ai/prime-rl |
|
| 90 |
+
| **License** | Apache-2.0 |
|
| 91 |
+
| **Stars / contributors** | 1,398 ★ / 60 contributors |
|
| 92 |
+
| **Latest release** | v0.5.0, 2026-03-30 |
|
| 93 |
+
| **Last push** | 2026-05-25 (active today) |
|
| 94 |
+
| **Maturity** | **Production-research hybrid** — substrate behind INTELLECT-1/2 multi-DC runs; tagline "Async RL Training at Scale". Decentralized DiLoCo-shape compute is its differentiator. |
|
| 95 |
+
| **Algorithms** | **GRPO**, GSPO, on-policy distillation with a teacher model. `default_loss_fn` = DPPO + KL (a GRPO variant; similar lineage to DAPO's decoupled-clip idea but the upstream "DAPO" label is not used verbatim). |
|
| 96 |
+
| **Custom-loss extension point** | **Best in class.** `src/prime_rl/trainer/rl/loss.py` exposes a `LossInputs`/`LossOutputs` interface and `setup_loss_fn` resolves a config: `trainer.loss.type = "custom"` + `trainer.loss.import_path = "your_pkg.your_module.your_loss_fn"` + optional kwargs. The custom function receives `trainer_logprobs`, `inference_logprobs`, `teacher_logprobs`, `advantages`, `loss_mask` — i.e., the exact tensor inputs needed for a 3-channel loss (RLVR uses `advantages`, hint-distill uses `teacher_logprobs`, trace-replay can be threaded through `kwargs` as a precomputed reference). |
|
| 97 |
+
| **Integration cost** | **Lowest.** Estimated **~200–300 LOC total**: ~120 LOC for a `composer_three_channel_loss` function in our package + ~30 LOC of config (`recipes/composer_v0.toml`), ~80 LOC `Recipe` doc. No subclassing required for the loss. A small adapter is needed if we precompute the trace-replay teacher distribution outside the `LossInputs` struct. |
|
| 98 |
+
| **Data-path fit** | **Already aligned.** PRIME-RL's orchestrator consumes `verifiers` environments via `vf.EnvServer`. The OpenEnv ↔ verifiers shim is a known small adapter (the `verifiers` library is the Hub-side env runner that OpenEnv's TRL guide already uses). Our existing OpenEnv-compatible TRL data path drops in with a thin wrapper. |
|
| 99 |
+
|
| 100 |
+
**Verdict:** Best fit for the framework. The combination of (i) config-driven custom loss with the right tensor signatures already present, (ii) verifiers/OpenEnv shape, (iii) decentralized async training that maps to our DiLoCo plans, makes PRIME-RL the substrate of choice for v0.1. **Recommended addition #1.**
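As a rough illustration of how the three channels could combine, the following is a numeric sketch in plain Python over per-token lists. It makes several assumptions that are not taken from PRIME-RL: `LossInputsSketch` is a hypothetical stand-in that only reuses the field names listed above, the RLVR channel uses a generic clipped importance-ratio surrogate, both distillation channels use the sampled-token estimate `log p_student - log p_reference`, and the weights and clip value are placeholders. A real implementation would operate on tensors with autograd.

```python
import math
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class LossInputsSketch:
    """Hypothetical container carrying the per-token fields named above."""
    trainer_logprobs: Sequence[float]
    inference_logprobs: Sequence[float]
    teacher_logprobs: Sequence[float]
    advantages: Sequence[float]
    loss_mask: Sequence[bool]


def masked_mean(values: Sequence[float], mask: Sequence[bool]) -> float:
    """Mean over positions where the mask is True (0.0 if none are)."""
    kept = [v for v, m in zip(values, mask) if m]
    return sum(kept) / len(kept) if kept else 0.0


def composer_three_channel_loss(
    inputs: LossInputsSketch,
    *,
    trace_replay_logprobs: Optional[Sequence[float]] = None,
    w_rlvr: float = 1.0,    # placeholder weights, not tuned values
    w_hint: float = 0.1,
    w_trace: float = 0.1,
    clip_eps: float = 0.2,
) -> dict[str, float]:
    # Channel 1 (RLVR): clipped importance-weighted surrogate on advantages.
    rlvr_terms = []
    for lp, old_lp, adv in zip(inputs.trainer_logprobs,
                               inputs.inference_logprobs,
                               inputs.advantages):
        ratio = math.exp(lp - old_lp)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        rlvr_terms.append(-min(ratio * adv, clipped * adv))
    rlvr = masked_mean(rlvr_terms, inputs.loss_mask)

    # Channel 2 (hint-distill): sampled-token log-ratio against the teacher.
    hint = masked_mean(
        [lp - t for lp, t in zip(inputs.trainer_logprobs, inputs.teacher_logprobs)],
        inputs.loss_mask,
    )

    # Channel 3 (trace-replay): same form against a precomputed reference,
    # passed as an extra keyword argument; skipped when absent.
    trace = 0.0
    if trace_replay_logprobs is not None:
        trace = masked_mean(
            [lp - r for lp, r in zip(inputs.trainer_logprobs, trace_replay_logprobs)],
            inputs.loss_mask,
        )

    total = w_rlvr * rlvr + w_hint * hint + w_trace * trace
    return {"loss": total, "rlvr": rlvr, "hint_distill": hint, "trace_replay": trace}
```

Returning the per-channel values alongside the total keeps each channel separately loggable, which helps when tuning the weights.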
|
| 101 |
+
|
| 102 |
+
---
|
| 103 |
+
|
| 104 |
+
### 2.3 NeMo-Aligner
|
| 105 |
+
|
| 106 |
+
| Field | Value |
|
| 107 |
+
|---|---|
|
| 108 |
+
| **Repo** | https://github.com/NVIDIA/NeMo-Aligner |
|
| 109 |
+
| **License** | Apache-2.0 |
|
| 110 |
+
| **Maturity** | **Research-leaning production** — NVIDIA-maintained, tied to NeMo/Megatron-LM. Advertised as "early stages of development" in its own README. |
|
| 111 |
+
| **Algorithms** | PPO, REINFORCE, RS (Rejection Sampling), DPO, RPO. **No GRPO. No DAPO.** |
|
| 112 |
+
| **Custom-loss extension point** | `loss_func` method on Megatron model classes (e.g. `MegatronGPTDPOModel.loss_func`). Requires NeMo model-class subclassing and Megatron-LM familiarity. |
|
| 113 |
+
| **Integration cost** | High. Estimated **~800–1,200 LOC** including .nemo conversion of HF weights, Megatron model wrapping, custom Megatron `loss_func`, and a recipe. Plus the operational cost of running on Megatron-LM (Triton kernels, NeMo container). |
|
| 114 |
+
| **Data-path fit** | JSONL only; no OpenEnv. We'd write a full env adapter. |
|
| 115 |
+
|
| 116 |
+
**Verdict:** Wrong shape. No GRPO/DAPO and tightly bound to the NeMo ecosystem. Only relevant if we ever need NVIDIA-supported large-scale Megatron RL, which we don't for the Composer Replication v0.1/v0.2 horizon. **Reject.**
|
| 117 |
+
|
| 118 |
+
---
|
| 119 |
+
|
| 120 |
+
### 2.4 Unsloth (RL)
|
| 121 |
+
|
| 122 |
+
| Field | Value |
|
| 123 |
+
|---|---|
|
| 124 |
+
| **Repo** | https://github.com/unslothai/unsloth |
|
| 125 |
+
| **License** | Apache-2.0 (per public README; not surfaced by DeepWiki snapshot but well-known) |
|
| 126 |
+
| **Maturity** | **Production** for SFT and LoRA/QLoRA; **research/preview** for RL — RL support shipped in 2025 as a TRL patcher. |
|
| 127 |
+
| **Algorithms** | Wraps TRL → inherits TRL's GRPO; loss-type switch supports `"grpo"`, `"bnpo"`, `"dr_grpo"`, `"dapo"`, `"cispo"`. So **GRPO and DAPO are both available** through the patched-TRL path. |
|
| 128 |
+
| **Custom-loss extension point** | Problematic. The actual loss kernels live in `unsloth_zoo` (a *separate* compiled dependency). The patcher (`patch_trl_rl_trainers()`) generates modified TRL trainer classes via `exec()` from string templates. To add a new loss type you would have to (a) modify or fork `unsloth_zoo` to add a kernel, (b) extend `RL_REPLACEMENTS`, and (c) extend the `compute_loss()` switch in the patcher template. **There is no public Python subclass hook that survives the patching.** |
|
| 129 |
+
| **Integration cost** | Very high if we want our own loss. Forking `unsloth_zoo` defeats the purpose of using Unsloth, whose main value is its optimized kernels. Estimated ~1,000+ LOC plus an external repo to maintain. |
|
| 130 |
+
| **Data-path fit** | TRL-shaped, so OpenEnv via TRL is fine — but only for *stock* TRL losses. Our 3-channel loss does not survive Unsloth's patching. |
|
| 131 |
+
|
| 132 |
+
**Verdict:** Excellent for memory-efficient SFT and stock-GRPO LoRA. Wrong tool for a custom loss. **Reject** as the substrate; we may still use it as an *optional* QLoRA accelerator inside a stock-GRPO ablation run.
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
### 2.5 LLaMA-Factory
|
| 137 |
+
|
| 138 |
+
| Field | Value |
|---|---|
| **Repo** | https://github.com/hiyouga/LLaMA-Factory |
| **License** | Apache-2.0 |
| **Maturity** | **Production** for breadth (50+ model families, SFT/DPO/PPO recipes), but RL is a thin TRL wrapper. |
| **Algorithms** | PPO, DPO, KTO, ORPO, SimPO via `Custom*Trainer` subclasses of the corresponding `trl.*Trainer` classes. **No GRPO. No DAPO** in the repo itself; the README points to **EasyR1** (an external GRPO framework) for those. |
| **Custom-loss extension point** | `compute_preference_loss` switch on `CustomDPOTrainer` (selects `sigmoid` / `hinge` / `ipo` / `kto_pair` / `orpo` / `simpo`). For PPO, you would subclass `CustomPPOTrainer` → which is `trl.PPOTrainer`. Effectively the same extension story as plain TRL, with a configuration layer on top. |
| **Integration cost** | Moderate, ~400 LOC, but you are essentially using TRL through one extra layer. |
| **Data-path fit** | Text/dataset-shaped, not OpenEnv-aware. Same OpenEnv-via-TRL story. |
|
| 147 |
+
|
| 148 |
+
**Verdict:** Useful as a multi-model SFT laboratory but does not move the ball for our RL-side requirements. **Reject** as substrate; we already have TRL.
|
| 149 |
+
|
| 150 |
+
---
|
| 151 |
+
|
| 152 |
+
### 2.6 DeepSpeed-Chat
|
| 153 |
+
|
| 154 |
+
| Field | Value |
|---|---|
| **Repo** | https://github.com/deepspeedai/DeepSpeedExamples (the `applications/DeepSpeed-Chat/` subtree) |
| **License** | Apache-2.0 |
| **Maturity** | **Effectively stale.** The README's "Latest News" cuts off in August 2023. CI patches in 2025 (e.g., #6982, #7015, #7052) are dependency-pinning fixes, not feature work. The roadmap to "generalize DeepSpeed-RLHF abstraction for a wider range of RL algorithms" has not landed. |
| **Algorithms** | PPO (3-stage RLHF) + DPO. **No GRPO. No DAPO.** |
| **Custom-loss extension point** | `DeepSpeedPPOTrainer.train_rlhf` / `actor_loss_fn` / `critic_loss_fn`. Editable but not config-hooked. |
| **Integration cost** | Moderate, but you inherit a frozen architecture. ~500 LOC. |
| **Data-path fit** | Prompt-dataset-shaped; no OpenEnv. |
|
| 163 |
+
|
| 164 |
+
**Verdict:** Pioneering for its time, no longer competitive on algorithm coverage. **Reject.**
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
## 3. Meta PyTorch Agentic Stack — Infra vs Training Split
|
| 169 |
+
|
| 170 |
+
The brief asked specifically to **distinguish coordination/infra from training-stack** components. The answer is:
|
| 171 |
+
|
| 172 |
+
| Component | Layer | Status (May 2026) | In our framework? |
|---|---|---|---|
| **Monarch** (`meta-pytorch/monarch`) | **Coordination / Infra** — actor mesh, RDMA data plane, supervision trees | **Active.** v0.4 GA (2026-03-26), v0.5 dev wheels daily, BSD-3 | **Yes — recommended addition.** |
| **TorchTitan** (`pytorch/torchtitan`) | **Training stack** — FSDP2 / TP / PP / CP / float8 / MXFP8 | **Active.** BSD-3, "extensive development". Has an experimental GRPO recipe (`experiments/rl/simple_grpo_sum_digits.py`) on Monarch. | **Indirectly** — already the trainer inside PRIME-RL and TorchForge. We adopt it transitively, not as a direct dependency. |
| **TorchForge** (`meta-pytorch/forge`) | RL post-training library | **Development paused** per the repo banner; consolidating into TorchTitan. ~685★. | **Pattern reference only.** Lift the Generator/Trainer/Rewarder *shape* but do not depend on the package. |
| **torchchat** (`pytorch/torchchat`) | **Inference / local deployment** | Active for its own scope, but: not a training framework; no RL surface. | **Out of scope.** |
| **OpenEnv** (`meta-pytorch/OpenEnv`) | Environment standard (covered separately) | Active. Already a v0 dependency of the framework. | Already adopted. |
|
| 179 |
+
|
| 180 |
+
### 3.1 Monarch
|
| 181 |
+
|
| 182 |
+
| Field | Value |
|---|---|
| **Repo** | https://github.com/meta-pytorch/monarch |
| **License** | BSD-3-Clause |
| **PyPI** | `torchmonarch`; v0.4.1 stable (2026-04-08), v0.5.0 dev wheels published daily through 2026-05-05 |
| **Maturity** | **Experimental but actively shipped.** "Currently in an experimental stage" per the repo's own status note, but with a functioning K8s operator, daily dev wheels, and ProcMesh/ActorMesh APIs stable enough for VeRL backend experiments. |
| **Role in our stack** | **Pure coordination/infra.** It does not train models. It hosts whatever trainer you bring (TRL, VeRL, PRIME-RL, TorchTitan) as `Actor` subclasses on a `ProcMesh`. The `monarch.spmd.SPMDActor` automatically configures `RANK`/`LOCAL_RANK`/`WORLD_SIZE` for any PyTorch-distributed script — i.e., we can lift our existing TRL or PRIME-RL workers into Monarch with minimal change. |
| **Key abstractions** | `ProcMesh` (processes × hosts × GPUs), `ActorMesh` (typed actors with `@endpoint` methods), supervision trees, RDMA buffers, distributed tensors / DTensor integration. Underlying runtime: `hyperactor` (Rust). |
| **Why over Ray** | Tighter PyTorch/DTensor integration; explicit RDMA data plane (Ray uses object store + standard networking); single-controller mental model maps directly to RL post-training (one controller orchestrates Generator + Trainer + Rewarder + Env actors). |
| **Integration cost into Composer Replication** | **~300 LOC + ops**: (a) wrap our PRIME-RL trainer as an `SPMDActor`; (b) wrap our vLLM rollout server as an `Actor` with an `@endpoint generate(prompts)` method; (c) write a single controller script that creates a `ProcMesh`, spawns both meshes, and shuttles `DataProto`-shaped messages; (d) Recipe doc. The ops cost is the harder half — Monarch's K8s operator is new (v0.2.0+). |
| **Risk** | Pre-1.0; API churn possible (e.g., `KubernetesJob.add_mesh` signature changed in v0.5). Mitigation: pin to `torchmonarch==0.4.1` for v0.2 of our framework. |
|
| 193 |
+
|
| 194 |
+
### 3.2 TorchTitan
|
| 195 |
+
|
| 196 |
+
| Field | Value |
|---|---|
| **Repo** | https://github.com/pytorch/torchtitan |
| **License** | BSD-3-Clause |
| **Maturity** | **Active development** for pretraining; **experimental** for RL. The GRPO experiment (`torchtitan/experiments/rl/simple_grpo_sum_digits.py`) is in `experiments/`, which the repo explicitly disclaims as removable. |
| **Role** | **Training stack only.** Provides FSDP2 (per-parameter sharding), Tensor Parallel (incl. async TP), Pipeline Parallel (zero-bubble), Context Parallel (long-context), `torch.compile`, Float8, MXFP8, DDP, HSDP. |
| **OpenEnv-aware?** | No, but the experimental `RLTrainer` integrates `vLLM` + Monarch actors, which is the same shape PRIME-RL uses. |
| **Why we don't add it directly** | **PRIME-RL already uses TorchTitan-equivalent FSDP2 internals**, and TorchForge's training core was TorchTitan. Adding TorchTitan as a *direct* dependency would mean writing our own RL loop on top of it — that's TorchForge's job, and Meta paused exactly that effort. The right move is to depend on PRIME-RL, which has battle-tested distributed training patterns equivalent to TorchTitan's, and revisit TorchTitan directly only when we genuinely need its experimental zero-bubble PP or MXFP8 paths. |
|
| 204 |
+
|
| 205 |
+
### 3.3 TorchForge (Paused)
|
| 206 |
+
|
| 207 |
+
- Repo banner: **"Development paused — LLM training consolidating in TorchTitan."**
|
| 208 |
+
- ~685 ★, 100+ open issues, last meaningful release in early 2026.
|
| 209 |
+
- Patterns we should still copy:
|
| 210 |
+
- Generator/Trainer/Rewarder ActorMesh decomposition
|
| 211 |
+
- TorchStore-style RDMA weight broadcast
|
| 212 |
+
- Async toggle between sync PPO-like and fully async off-policy
|
| 213 |
+
- **We do not add a TorchForge dependency.** Architectural reference only.
|
| 214 |
+
|
| 215 |
+
### 3.4 torchchat (Out of Scope)
|
| 216 |
+
|
| 217 |
+
- Inference / local deployment of LLMs (Eager / `torch.compile` / AOT Inductor / ExecuTorch / mobile).
|
| 218 |
+
- No training, no RL.
|
| 219 |
+
- Mentioned in the brief for completeness; ruled out cleanly.
|
| 220 |
+
|
| 221 |
+
---
|
| 222 |
+
|
| 223 |
+
## 4. Comparison Matrix
|
| 224 |
+
|
| 225 |
+
### 4.1 RL Frameworks
|
| 226 |
+
|
| 227 |
+
| Framework | License | Last release | Maturity | GRPO | DAPO | Custom-loss hook | OpenEnv fit | Est. integration LOC |
|---|---|---|---|---|---|---|---|---|
| **TRL** (baseline) | Apache-2.0 | Active | Production | ✅ | partial (tricks land per release) | Subclass `GRPOTrainer.compute_loss` | ✅ native (Oct 2025 OpenEnv guide) | already integrated |
| **VeRL** (baseline) | Apache-2.0 | Active | Production | ✅ | ✅ | `core_algos.py` + worker subclass | shim via Ray dataloader | already skeleton |
| **OpenRLHF** | Apache-2.0 | v0.9.10 (2026-04-04) | Production | ✅ | ✅ | `openrlhf/models/loss.py` + Trainer subclass; **no config hook** | shim via `agent_func_path` | ~400–600 |
| **PRIME-RL** ⭐ | Apache-2.0 | v0.5.0 (2026-03-30) | Production-research | ✅ | partial (DPPO+KL variant; not labeled DAPO) | **`CustomLossConfig` import_path — first-class** | ✅ via `verifiers` (OpenEnv-compatible) | **~200–300** |
| **NeMo-Aligner** | Apache-2.0 | Active | Research-leaning | ❌ | ❌ | Megatron model `loss_func` | none; JSONL only | ~800–1,200 |
| **Unsloth (RL)** | Apache-2.0 | Active | Production (SFT) / preview (RL) | ✅ (via TRL patch) | ✅ (via TRL patch) | Loss kernels in the separate `unsloth_zoo` dependency; effectively unhookable | TRL-shaped | ~1,000+ (forking) |
| **LLaMA-Factory** | Apache-2.0 | Active | Production | ❌ (delegates to EasyR1) | ❌ | TRL `Custom*Trainer` subclass | TRL-shaped | ~400 |
| **DeepSpeed-Chat** | Apache-2.0 | Stale (Aug 2023 features; 2025 only CI fixes) | Maintenance-only | ❌ | ❌ | `DeepSpeedPPOTrainer` subclass | none | ~500 |
|
| 237 |
+
|
| 238 |
+
### 4.2 Meta PyTorch Stack
|
| 239 |
+
|
| 240 |
+
| Component | Layer | License | Status | In recommendation? |
|---|---|---|---|---|
| **Monarch** ⭐ | Coordination / actor mesh | BSD-3 | Active (v0.4 GA, v0.5 dev) | **Yes** |
| **TorchTitan** | Training stack | BSD-3 | Active; RL experimental | Indirect (via PRIME-RL) |
| **TorchForge** | RL library | BSD-3 | **Paused** | No — patterns only |
| **torchchat** | Inference / deployment | BSD-3 | Active | No — out of scope |
| **OpenEnv** | Environment standard | (Hub) | Active | Already adopted |
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
## 5. Recommendation Rationale
|
| 251 |
+
|
| 252 |
+
### 5.1 Why PRIME-RL, not OpenRLHF
|
| 253 |
+
|
| 254 |
+
OpenRLHF is in many ways the safer pick: more stars, more contributors, more algorithm coverage (it explicitly ships DAPO). The deciding factor is **the shape of our custom loss**.
|
| 255 |
+
|
| 256 |
+
The Composer Replication Framework's signature contribution is the **three-channel reward**:
|
| 257 |
+
|
| 258 |
+
1. **RLVR** — tests-pass scalar from the OpenEnv environment.
|
| 259 |
+
2. **Composer-style hint-distill (SDPO/OPSD)** — the model self-teaches against its own hint-conditioned roll-outs; needs `teacher_logprobs` aligned to the rollout token grid.
|
| 260 |
+
3. **Trace-replay multi-teacher PRM** (the novel bit) — N frozen external teachers' precomputed token-level distributions, replayed against the on-policy rollout.
|
| 261 |
+
|
| 262 |
+
PRIME-RL's `LossInputs` dataclass already exposes exactly the tensors we need:
|
| 263 |
+
```
trainer_logprobs, inference_logprobs, teacher_logprobs, advantages, loss_mask
```
A custom 3-channel loss is roughly:
```python
def composer_three_channel_loss(li: LossInputs, *, hint_weight, replay_weight, replay_logits) -> LossOutputs:
    rlvr = grpo_term(li.trainer_logprobs, li.inference_logprobs, li.advantages, li.loss_mask)
    hint = kl_term(li.trainer_logprobs, li.teacher_logprobs, li.loss_mask)
    replay = kl_term(li.trainer_logprobs, replay_logits, li.loss_mask)
    return LossOutputs(loss=rlvr + hint_weight * hint + replay_weight * replay, ...)
```
|
| 274 |
+
We register this with `trainer.loss.type = "custom"` + `import_path` and we're done. No subclassing, no `exec()`-patched template, no Megatron model wrapping.
|
| 275 |
+
|
| 276 |
+
OpenRLHF would require us to (a) add a `ThreeChannelLoss` `nn.Module` to `openrlhf/models/loss.py`, (b) subclass `PPOTrainer` (or equivalent GRPO trainer) to construct it with the right teacher-logprob plumbing, and (c) carry that fork forward. ~2× the LOC, plus a fork to maintain.
|
| 277 |
+
|
| 278 |
+
A second factor: PRIME-RL's `verifiers` env protocol is a direct precursor of OpenEnv's wire shape (HTTP/WebSocket env servers, typed observations). Our existing OpenEnv-compatible TRL data path translates with a thin adapter. OpenRLHF's `agent_func_path` is more of an escape hatch than a contract.
|
| 279 |
+
|
| 280 |
+
A third factor: PRIME-RL was *built for decentralized training* (INTELLECT-1/2). Even though our v0.1 stays on a single cluster, the v0.2 multi-DC story drops in cleanly. OpenRLHF is Ray-on-one-cluster by design.
|
| 281 |
+
|
| 282 |
+
### 5.2 Why Monarch, not TorchTitan or TorchForge
|
| 283 |
+
|
| 284 |
+
Among the four Meta-stack components in the brief, only one is both (a) ours to add and (b) genuinely new functionality:
|
| 285 |
+
|
| 286 |
+
- **TorchForge** is paused — depending on it now is a known dead end.
|
| 287 |
+
- **TorchTitan** is already inside PRIME-RL transitively (PRIME-RL uses FSDP2 plus a SHARDCAST weight-broadcast layer that is morally equivalent to what TorchTitan offers). Adding TorchTitan as a *direct* dependency means writing our own RL loop on top of it, which is exactly what TorchForge tried and paused. We get TorchTitan's benefits without owning the integration.
|
| 288 |
+
- **torchchat** is for local inference / mobile deployment — out of scope.
|
| 289 |
+
- **Monarch** is the unique value: a PyTorch-native actor mesh that lets us replace Ray (PRIME-RL's current orchestration substrate) with something that has explicit RDMA, supervision trees, and ProcMesh/ActorMesh primitives that map directly onto our (Generator, Trainer, Rewarder, EnvServer) topology.
|
| 290 |
+
|
| 291 |
+
The migration path is incremental:
|
| 292 |
+
- **v0.1:** PRIME-RL on Ray (current). Monarch listed as roadmap.
|
| 293 |
+
- **v0.2:** Wrap PRIME-RL's Trainer as a `monarch.spmd.SPMDActor`, vLLM Generator as an `Actor` with an `@endpoint generate()`. Switch the orchestrator from `ray.init()` to `this_host().spawn_procs()`.
|
| 294 |
+
- Risk-mitigation: pin to `torchmonarch==0.4.1` (the last GA release before v0.5 dev). Keep a Ray fallback path active until v0.2 is stable.
|
| 295 |
+
|
| 296 |
+
---
|
| 297 |
+
|
| 298 |
+
## 6. Integration Sketches
|
| 299 |
+
|
| 300 |
+
### 6.1 PRIME-RL Recipe skeleton
|
| 301 |
+
|
| 302 |
+
`recipes/composer_v0_prime_rl.toml` (~30 LOC):
|
| 303 |
+
|
| 304 |
+
```toml
# composer_v0_prime_rl.toml
[model]
name = "Qwen/Qwen3-32B"  # or Kimi-K2.5 when MoE support lands

[data]
env = "swe_bench_lite"  # via verifiers EnvServer; wraps our OpenEnv adapter
batch_size = 64
group_size = 16

[trainer]
algorithm = "grpo"
[trainer.loss]
type = "custom"
import_path = "composer_replication.losses.composer_three_channel_loss"
[trainer.loss.kwargs]
hint_weight = 0.5
replay_weight = 0.25
# Key names here must match the loss function's keyword arguments.
replay_logits_handle = "/data/teachers/precomputed_replay.zarr"

[teacher]
model = "Qwen/Qwen3-32B"  # same as policy = self-teacher for hint-distill
hint_template = "composer.hint_v1"

[orchestrator]
sync_mode = "async"
shardcast = true
```
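
The `import_path` key in the config is a dotted path to a callable, and the `[trainer.loss.kwargs]` table supplies its keyword arguments. As a rough illustration of how such a pair can be resolved, here is a standard-library sketch. It is not PRIME-RL's actual loader: `resolve_import_path` and `bind_loss` are hypothetical names, and the demo points at `math.hypot` instead of the real loss so that it runs without the framework installed.

```python
import importlib
from functools import partial
from typing import Any, Callable


def resolve_import_path(path: str) -> Callable[..., Any]:
    """Split 'pkg.module.attr' into module and attribute, then import it."""
    module_name, _, attr = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.attr', got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


def bind_loss(loss_cfg: dict) -> Callable[..., Any]:
    """Bind the kwargs table onto the callable named by import_path."""
    fn = resolve_import_path(loss_cfg["import_path"])
    return partial(fn, **loss_cfg.get("kwargs", {}))


# Stand-in config: same shape as [trainer.loss], but pointing at a stdlib function.
demo_cfg = {"type": "custom", "import_path": "math.hypot", "kwargs": {}}
demo_fn = bind_loss(demo_cfg)
print(demo_fn(3.0, 4.0))
```

Because the `kwargs` table is passed as keyword arguments, each key has to match a keyword parameter of the target function. With `partial`, a mismatch only surfaces as a `TypeError` when the loss is first called, so it is worth validating the names at config-load time.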
|
| 332 |
+
|
| 333 |
+
`composer_replication/losses.py` (~120 LOC):
|
| 334 |
+
|
| 335 |
+
```python
# composer_replication/losses.py
from prime_rl.trainer.rl.loss import LossInputs, LossOutputs

# grpo_surrogate, masked_kl and trace_replay_kl are helpers from this module,
# elided in this sketch.


def composer_three_channel_loss(
    li: LossInputs,
    *,
    hint_weight: float,
    replay_weight: float,
    replay_logits_handle: str,
) -> LossOutputs:
    # 1. RLVR via GRPO surrogate
    rlvr = grpo_surrogate(li.trainer_logprobs, li.inference_logprobs,
                          li.advantages, li.loss_mask)

    # 2. Hint-distill: KL(policy || hint-conditioned teacher)
    hint = masked_kl(li.trainer_logprobs, li.teacher_logprobs, li.loss_mask)

    # 3. Trace-replay: KL(policy || precomputed multi-teacher mixture)
    replay = trace_replay_kl(li.trainer_logprobs, replay_logits_handle, li.loss_mask)

    total = rlvr + hint_weight * hint + replay_weight * replay
    return LossOutputs(
        loss=total,
        metrics={"rlvr": rlvr.item(), "hint": hint.item(), "replay": replay.item()},
    )
```
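
The helper names `grpo_surrogate` and `masked_kl` used above are not defined in the sketch. One possible reading of them is shown below in plain Python over lists, so it runs without `torch`; a real implementation would operate on tensors. The clip width `clip_eps=0.2`, the `masked_mean` helper, and the choice of a sampled-token KL estimate (`log π_θ − log π_teacher` averaged over on-policy tokens) are assumptions made for illustration, not something the recipe specifies.

```python
import math
from typing import Sequence


def masked_mean(values: Sequence[float], mask: Sequence[int]) -> float:
    """Average `values` over positions where mask is 1 (0.0 if none are kept)."""
    kept = [v for v, m in zip(values, mask) if m]
    return sum(kept) / len(kept) if kept else 0.0


def grpo_surrogate(trainer_lp, inference_lp, advantages, mask, clip_eps=0.2):
    """Clipped importance-weighted policy-gradient loss (PPO/GRPO style)."""
    per_token = []
    for lp_new, lp_old, adv in zip(trainer_lp, inference_lp, advantages):
        ratio = math.exp(lp_new - lp_old)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Negated because the surrogate is maximised while the loss is minimised.
        per_token.append(-min(ratio * adv, clipped * adv))
    return masked_mean(per_token, mask)


def masked_kl(trainer_lp, teacher_lp, mask):
    """Sampled-token estimate of KL(policy || teacher) on on-policy tokens."""
    per_token = [lp - t for lp, t in zip(trainer_lp, teacher_lp)]
    return masked_mean(per_token, mask)


# Toy rollout of three tokens; the last one is padding (mask = 0).
trainer = [-1.0, -2.0, -0.5]
inference = [-1.0, -2.0, -0.5]  # identical to trainer, so every ratio is 1
teacher = [-0.5, -2.0, -9.0]
adv = [1.0, -1.0, 5.0]
mask = [1, 1, 0]

rlvr = grpo_surrogate(trainer, inference, adv, mask)
hint = masked_kl(trainer, teacher, mask)
print(rlvr, hint)
```

Note that the sampled-token estimate can come out negative on a small batch, as it does for the toy values here; it is only non-negative in expectation. If that matters for logging or weighting, a full-vocabulary KL over logits is the alternative, at the cost of shipping teacher distributions instead of scalar log-probs.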
|
| 362 |
+
|
| 363 |
+
Plus `docs/recipes/composer_v0_prime_rl.md` (~50 LOC) describing data layout, teacher precomputation, and reproducibility hashes.
|
| 364 |
+
|
| 365 |
+
**Total: ~200 LOC of code + ~30 LOC config + ~50 LOC docs ≈ 280 LOC.**
|
| 366 |
+
|
| 367 |
+
### 6.2 Monarch wrapper sketch (v0.2)
|
| 368 |
+
|
| 369 |
+
```python
# composer_replication/orchestrator/monarch_runner.py (~120 LOC)
# Sketch only: the spawn_procs(...) and .broadcast(...) call shapes below are
# illustrative and should be checked against the pinned torchmonarch release.
from monarch.actor import Actor, endpoint
from monarch.proc_mesh import this_host, ProcMesh


class TrainerActor(Actor):
    @endpoint
    async def step(self, batch): ...


class GeneratorActor(Actor):
    @endpoint
    async def generate(self, prompts): ...


class RewarderActor(Actor):
    @endpoint
    async def score(self, traj): ...


async def main(cfg):
    train_mesh = await this_host().spawn_procs(TrainerActor, hosts=4, gpus=8)
    gen_mesh = await this_host().spawn_procs(GeneratorActor, hosts=2, gpus=8)
    rew_mesh = await this_host().spawn_procs(RewarderActor, hosts=1, gpus=2)

    # `env` is the environment client built from cfg (construction elided).
    # `range` is a synchronous iterable, so this is a plain `for` loop.
    for step in range(cfg.steps):
        prompts = await env.batch()
        traj = await gen_mesh.generate.broadcast(prompts)
        rewards = await rew_mesh.score.broadcast(traj)
        await train_mesh.step.broadcast({"traj": traj, "rewards": rewards})
```
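
Independent of Monarch's exact API, the control flow the controller needs is a sequential loop in which each stage is awaited in turn. The sketch below shows that shape using only `asyncio`. `FakeGenerator`, `FakeRewarder`, and `FakeTrainer` are hypothetical in-process stand-ins for the actor meshes, and no Monarch call is made, so it says nothing about how the real meshes are spawned or addressed.

```python
import asyncio


class FakeGenerator:
    async def generate(self, prompts):
        await asyncio.sleep(0)  # yield to the event loop, as a remote call would
        return [p + " -> completion" for p in prompts]


class FakeRewarder:
    async def score(self, trajectories):
        await asyncio.sleep(0)
        return [float(len(t)) for t in trajectories]


class FakeTrainer:
    def __init__(self):
        self.steps_done = 0

    async def step(self, batch):
        await asyncio.sleep(0)
        # Every trajectory must arrive with exactly one reward.
        assert len(batch["traj"]) == len(batch["rewards"])
        self.steps_done += 1


async def run(num_steps: int) -> int:
    gen, rew, trainer = FakeGenerator(), FakeRewarder(), FakeTrainer()
    # `range` is a synchronous iterable, so a plain `for` is used, not `async for`.
    for step in range(num_steps):
        prompts = [f"prompt-{step}-{i}" for i in range(2)]
        traj = await gen.generate(prompts)
        rewards = await rew.score(traj)
        await trainer.step({"traj": traj, "rewards": rewards})
    return trainer.steps_done


steps_done = asyncio.run(run(3))
print(steps_done)
```

This is the fully synchronous variant: generation, scoring, and the optimizer step never overlap. The async off-policy mode would let the generator run ahead of the trainer, which needs a bounded queue between the stages instead of direct awaits.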
|
| 397 |
+
|
| 398 |
+
**Total: ~120 LOC controller + ~50 LOC ops (K8s operator manifest) + ~80 LOC recipe doc ≈ 250 LOC.**
|
| 399 |
+
|
| 400 |
+
---
|
| 401 |
+
|
| 402 |
+
## 7. Sources
|
| 403 |
+
|
| 404 |
+
### Primary
|
| 405 |
+
|
| 406 |
+
- **OpenRLHF** — https://github.com/OpenRLHF/OpenRLHF (README, Releases v0.9.10), Apache-2.0; DeepWiki: `openrlhf/models/loss.py`, `agent_func_path`.
|
| 407 |
+
- **PRIME-RL** — https://github.com/PrimeIntellect-ai/prime-rl (README, Releases v0.5.0), Apache-2.0; DeepWiki: `src/prime_rl/trainer/rl/loss.py`, `CustomLossConfig`, `LossInputs`/`LossOutputs`, `verifiers` integration.
|
| 408 |
+
- **NeMo-Aligner** — https://github.com/NVIDIA/NeMo-Aligner, Apache-2.0; DeepWiki: PPO/REINFORCE/DPO/RPO; `loss_func` on Megatron model classes.
|
| 409 |
+
- **Unsloth** — https://github.com/unslothai/unsloth, README RL section; DeepWiki: `patch_trl_rl_trainers()`, `unsloth_zoo` kernels, DAPO loss-type switch.
|
| 410 |
+
- **LLaMA-Factory** — https://github.com/hiyouga/LLaMA-Factory, Apache-2.0; DeepWiki: `CustomPPOTrainer`/`CustomDPOTrainer`, EasyR1 reference for GRPO.
|
| 411 |
+
- **DeepSpeed-Chat** — https://github.com/deepspeedai/DeepSpeedExamples (`applications/DeepSpeed-Chat/`), Apache-2.0; DeepWiki: 3-stage PPO, DPO; "Latest News" cutoff Aug 2023; 2025 PRs (#6982, #7015, #7052) confirming maintenance-only mode.
|
| 412 |
+
- **Monarch** — https://github.com/meta-pytorch/monarch, BSD-3; PyPI `torchmonarch` v0.4.1 (2026-04-08), v0.5.0 dev wheels through 2026-05-05; DeepWiki: `ProcMesh`, `ActorMesh`, `monarch.spmd.SPMDActor`.
|
| 413 |
+
- **TorchTitan** — https://github.com/pytorch/torchtitan, BSD-3; DeepWiki: FSDP2/TP/PP/CP, `torchtitan/experiments/rl/simple_grpo_sum_digits.py`, integration with vLLM and Monarch.
|
| 414 |
+
- **TorchForge** — https://github.com/meta-pytorch/forge, BSD-3, repo banner "development paused — consolidating in TorchTitan".
|
| 415 |
+
- **torchchat** — https://github.com/pytorch/torchchat, BSD-3; DeepWiki: inference-only (eager / `torch.compile` / AOT Inductor / ExecuTorch).
|
| 416 |
+
|
| 417 |
+
### Companion repository docs (already present)
|
| 418 |
+
|
| 419 |
+
- `~/wiki/research/post-training-framework/04-verl-trl.md` — VeRL vs TRL deep dive.
|
| 420 |
+
- `~/wiki/research/post-training-framework/03-monarch-torchforge-openenv.md` — full Meta-stack survey.
|
| 421 |
+
- `~/wiki/research/post-training-framework/02-diloco-family.md` — DiLoCo / OpenDiLoCo / PRIME-RL / INTELLECT-2.
|
| 422 |
+
- `~/wiki/projects/composer-replication-framework.md` — current TL;DR and stage plan.
|
| 423 |
+
|
| 424 |
+
### Notes on accuracy
|
| 425 |
+
|
| 426 |
+
- "DAPO" labeling: OpenRLHF and Unsloth both advertise DAPO as a first-class loss type; PRIME-RL implements a DAPO-equivalent (decoupled-clip + KL) but uses the internal name `DPPO+KL` in its default loss. For our purposes this is the same family.
|
| 427 |
+
- Last-commit dates and release versions are pulled from GitHub release pages (OpenRLHF, PRIME-RL) and PyPI release history (`torchmonarch`).
|
| 428 |
+
- Star counts and contributor counts reflect the snapshots returned by web search at the time of writing (May 2026) and will drift; the relative ordering is stable.
|
|
@@ -0,0 +1,418 @@
|
| 1 |
+
# Self-Distillation Landscape Audit (feeds ADR-007)
|
| 2 |
+
|
| 3 |
+
**Status:** research note, pre-experimental
|
| 4 |
+
**Author:** subagent audit
|
| 5 |
+
**Date:** 2026-05-25
|
| 6 |
+
**Scope:** identify 2–3 distillation-channel losses worth adding to
|
| 7 |
+
`composer_replication` alongside the existing GRPO + SDPO/OPSD `generalized_jsd_loss` +
|
| 8 |
+
multi-teacher trace-replay DPO stack.
|
| 9 |
+
**Bias:** additivity over novelty. We are looking for losses that COMPOSE with
|
| 10 |
+
what is already implemented, not duplicates of it.
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## TL;DR — recommended additions
|
| 15 |
+
|
| 16 |
+
| Rank | Method | Loss role | License | LOC est. | Why it composes |
|------|--------|-----------|---------|----------|-----------------|
| 1 | **SimPO** (NeurIPS 2024) | Preference, reference-free | MIT | ~80 | Drop-in for trace-replay DPO; removes ref-model VRAM cost; orthogonal to JSD distillation channel |
| 2 | **TAID** (ICLR 2025) | Interpolated-target wrapper around any KL/JSD | Apache-2.0 | ~150 | Wraps the existing `generalized_jsd_loss` — does not replace it. Closes capacity gap on small students |
| 3 | **Entropy-Aware OPD** (ICLR 2026 Spotlight) | Token-gated forward/reverse KL mixture | CC BY 4.0 (paper); code expected | ~120 | Fixes a documented failure mode of the reverse-KL-style SDPO loss when teacher entropy is high — directly addresses a known weakness of channel 2 |
|
| 21 |
+
|
| 22 |
+
**Honourable mention:** KTO — useful only if the framework wants to ingest
|
| 23 |
+
binary thumbs-up/thumbs-down trace signals without preference pairs.
|
| 24 |
+
**Not recommended:** GKD, DistiLLM, MiniLLM, Self-Rewarding LM (rationale at end).
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## Audit method
|
| 29 |
+
|
| 30 |
+
For each candidate paper (the seven the user named, plus 2026 follow-ups
|
| 31 |
+
discovered via Exa search restricted to `category=research paper, startPublishedDate=2026-01-01`)
|
| 32 |
+
we verified:
|
| 33 |
+
|
| 34 |
+
1. **Primary source exists.** arXiv abstract page reachable; HTML body parsed
|
| 35 |
+
to extract the actual loss formula (not summarised from secondary sources).
|
| 36 |
+
2. **Code is real.** Official repo's README was fetched, `last push` date and
|
| 37 |
+
star count recorded. Forks of MiniLLM/DistiLLM that are no longer maintained
|
| 38 |
+
were marked as such.
|
| 39 |
+
3. **License is permissive enough.** MIT, Apache-2.0, BSD, CC BY 4.0 are
|
| 40 |
+
acceptable for inclusion. GPL or research-only would be flagged.
|
| 41 |
+
4. **Composability check.** Read the framework's existing
|
| 42 |
+
`composer_replication/__init__.py` and `research/05-trace-replay-distillation.md`,
|
| 43 |
+
then asked: *does this loss replace something we have, or stack on top?*
|
| 44 |
+
|
| 45 |
+
---
|
| 46 |
+
|
| 47 |
+
## Candidate 1 — SimPO (Simple Preference Optimization) ⭐ RECOMMENDED
|
| 48 |
+
|
| 49 |
+
### Sources
|
| 50 |
+
- **arXiv:** https://arxiv.org/abs/2405.14734 (Meng, Xia, Chen — UVA + Princeton, NeurIPS 2024)
|
| 51 |
+
- **GitHub:** https://github.com/princeton-nlp/SimPO
|
| 52 |
+
- License: **MIT**
|
| 53 |
+
- 949 stars, 74 forks, last commit 2024-10-12 (mature, post-NeurIPS)
|
| 54 |
+
- Built on top of `huggingface/alignment-handbook`
|
| 55 |
+
- Maturity: **production-ready**. Released checkpoints for Mistral, Llama-3, Gemma-2 base/instruct. Reproducible training configs ship with the repo.
|
| 56 |
+
|
| 57 |
+
### Loss core (reference-free preference)
|
| 58 |
+
SimPO replaces the DPO log-ratio (which requires keeping `π_ref` in memory)
|
| 59 |
+
with the **average log-probability** of the sequence under the policy, plus
|
| 60 |
+
a **target reward margin** γ:
|
| 61 |
+
|
| 62 |
+
```
r(x, y) = (β / |y|) · log π_θ(y | x)      ← length-normalised implicit reward
                                            (no reference model)

L_SimPO(π_θ) = −E_{(x, y_w, y_l) ~ D} [
    log σ( r(x, y_w) − r(x, y_l) − γ )
]
```
|
| 70 |
+
|
| 71 |
+
where `β` is a temperature (typically 2.0–10) and `γ` is the desired margin
|
| 72 |
+
between chosen and rejected (the repo recommends `γ/β ≈ 0.5` as a starting
|
| 73 |
+
point). Two consequences: (i) no `π_ref` forward pass per step → roughly half
|
| 74 |
+
the memory, and (ii) the implicit reward is exactly the quantity the model
|
| 75 |
+
generates from at decode time, removing a known DPO pathology where
|
| 76 |
+
decoding-time and training-time rewards diverge.
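
The formula above can be sketched for a single preference pair in plain Python. This is a minimal illustration, not the `princeton-nlp/SimPO` trainer: the function names `simpo_reward` and `simpo_loss` are hypothetical, inputs are per-token log-probabilities of each response under the policy, and the defaults `beta=2.0`, `gamma=1.0` are chosen only to match the `γ/β ≈ 0.5` starting point mentioned above.

```python
import math
from typing import Sequence


def simpo_reward(token_logprobs: Sequence[float], beta: float) -> float:
    """Length-normalised implicit reward: (beta / |y|) * sum of token log-probs."""
    return beta * sum(token_logprobs) / len(token_logprobs)


def simpo_loss(chosen_lp, rejected_lp, beta=2.0, gamma=1.0) -> float:
    """-log sigmoid(r_w - r_l - gamma) for one (chosen, rejected) pair."""
    z = simpo_reward(chosen_lp, beta) - simpo_reward(rejected_lp, beta) - gamma
    # -log(sigmoid(z)) equals softplus(-z); this form avoids overflow for large |z|.
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


# Chosen response: 2 tokens at -0.5 each; rejected: 4 tokens at -1.0 each.
chosen = [-0.5, -0.5]
rejected = [-1.0, -1.0, -1.0, -1.0]
loss = simpo_loss(chosen, rejected)
print(loss)
```

With these toy values the reward gap equals `γ` exactly, so the pair sits on the margin boundary and still contributes gradient. Because the reward is an average over tokens, padding a response with more tokens at the same per-token log-probability leaves its reward unchanged, which is the length-normalisation property the formula is built around.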
|
| 77 |
+
|
| 78 |
+
### Why it composes with the existing stack
|
| 79 |
+
- The framework's **channel 3** is multi-teacher trace-replay DPO. SimPO is a
|
| 80 |
+
drop-in replacement for the DPO step inside that channel — same `(x, y_w, y_l)`
|
| 81 |
+
data contract, different loss head. So the trace-replay harvester does not
|
| 82 |
+
change at all.
|
| 83 |
+
- It does **not** touch channel 2 (SDPO/OPSD `generalized_jsd_loss`). The two
|
| 84 |
+
are complementary: JSD-distillation transfers token-level teacher knowledge,
|
| 85 |
+
SimPO sharpens preference structure between trace alternatives.
|
| 86 |
+
- It does **not** duplicate GRPO either. GRPO is online-policy RLVR;
|
| 87 |
+
SimPO is offline preference. Different data sources.
|
| 88 |
+
- The published Mistral-7B and Llama-3-8B SimPO results beat DPO by 4–6 points
|
| 89 |
+
on AlpacaEval-2 LC, which directly translates to "if we already have channel-3
|
| 90 |
+
pairs, SimPO is a free upgrade".
|
| 91 |
+
|
| 92 |
+
### Implementation cost
|
| 93 |
+
- **~80 LOC** for the trainer hook; the loss itself is ~15 lines (log-probs,
|
| 94 |
+
length-normalise, margin, BCE).
|
| 95 |
+
- Dependencies: nothing new — `torch`, `transformers` already in repo.
|
| 96 |
+
- The reference implementation in `princeton-nlp/SimPO` is compact
(`scripts/run_simpo.py` + `alignment/` trainer subclass) and MIT-licensed, so we can
vendor it exactly as we did with OPSD.
|
| 99 |
+
|
| 100 |
+
---
|
| 101 |
+
|
| 102 |
+
## Candidate 2 — TAID (Temporally Adaptive Interpolated Distillation) ⭐ RECOMMENDED
|
| 103 |
+
|
| 104 |
+
### Sources
|
| 105 |
+
- **arXiv:** https://arxiv.org/abs/2501.16937 (Shing, Misaki, Bao, Yokoi, Akiba — Sakana AI, ICLR 2025)
|
| 106 |
+
- **GitHub:** https://github.com/SakanaAI/TAID
|
| 107 |
+
- License: **Apache-2.0**
|
| 108 |
+
- 121 stars, last push 2025-10-06 (actively maintained)
|
| 109 |
+
- Reference implementations of GKD, DistiLLM, Adaptive-KL, CTKD, DKD are also in `src/distil_losses/` for free
|
| 110 |
+
- Released artefacts: `TAID-LLM-1.5B`, `TAID-VLM-2B` on HuggingFace (so the loss is verified at non-trivial scale).
|
| 111 |
+
- Maturity: **published, single-author commits**, but two SoTA compact models were reproducibly trained with it.
|
| 112 |
+
|
| 113 |
+
### Loss core (interpolated teacher target)
|
| 114 |
+
Standard distillation losses (forward KL, reverse KL, JSD, including the
|
| 115 |
+
`generalized_jsd_loss` we already have) target a **fixed** teacher distribution
|
| 116 |
+
`p_T`. TAID replaces this fixed target with a **time-dependent interpolated
|
| 117 |
+
target** `p_t` that starts close to the student and moves toward the teacher
|
| 118 |
+
as training progresses:
|
| 119 |
+
|
| 120 |
+
```
|
| 121 |
+
p_t(y | x) = (1 − t) · q_θ_stop(y | x) + t · p_T(y | x) (1)
|
| 122 |
+
|
| 123 |
+
J_TAID(θ; t) = D_KL( p_t ‖ q_θ ) (2)
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
`q_θ_stop` is the student's own current distribution with stop-gradient. The
|
| 127 |
+
interpolation coefficient `t ∈ [t_start, 1]` is updated each step by an
|
| 128 |
+
**adaptive momentum schedule** that grows `t` faster when training loss is
|
| 129 |
+
falling and slower when it stalls — this is the "temporally adaptive" part.
|
| 130 |
+
The Sakana paper proves (Theorem 4.1) that for the regression analogue this
|
| 131 |
+
schedule prevents the mode-collapse failure mode of pure
|
| 132 |
+
self-distillation.
|
| 133 |
+
|
| 134 |
+
Critically, `D_KL(p_t ‖ q_θ)` is just one divergence applied to a shifted target; you
|
| 135 |
+
can equally well plug in JSD, reverse KL, or **the generalized_jsd_loss the
|
| 136 |
+
framework already exports**. TAID is therefore a *wrapper around an existing
|
| 137 |
+
divergence*, not a competing divergence.
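
Equations (1) and (2) can be checked on a single token position with a small, dependency-free sketch. The helper names (`softmax`, `kl`, `taid_target`, `taid_loss`) are illustrative and are not the `SakanaAI/TAID` API. Plain Python has no autograd, so the stop-gradient on the student term is only indicated in a comment, and the adaptive schedule for `t` is left out.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """D_KL(p || q) for discrete distributions; zero-mass terms of p are skipped."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def taid_target(student_logits, teacher_logits, t):
    """Eq. (1): p_t = (1 - t) * q_stop + t * p_T."""
    q_stop = softmax(student_logits)  # would be .detach()-ed in an autograd framework
    p_teacher = softmax(teacher_logits)
    return [(1.0 - t) * qi + t * pi for qi, pi in zip(q_stop, p_teacher)]


def taid_loss(student_logits, teacher_logits, t):
    """Eq. (2): D_KL(p_t || q_theta)."""
    return kl(taid_target(student_logits, teacher_logits, t),
              softmax(student_logits))
```

At `t = 0` the target is the student itself and the loss vanishes; at `t = 1` the expression reduces to ordinary forward KL against the teacher, which is the sense in which TAID wraps an existing divergence.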
|
| 138 |
+
|
| 139 |
+
### Why it composes with the existing stack
|
| 140 |
+
- It **wraps** `composer_replication.opsd.generalized_jsd_loss` rather than
|
| 141 |
+
replacing it. The change is "compute the JSD against `p_t` instead of
|
| 142 |
+
`p_T`" — a few lines around the existing call site.
|
| 143 |
+
- Addresses a documented weakness of OPSD-style self-distillation: when the
|
| 144 |
+
teacher's privileged-context distribution is far from the student's
|
| 145 |
+
capacity, the JSD signal can be noisy or push the student into mode
|
| 146 |
+
averaging. TAID's annealed target gives the student a curriculum.
|
| 147 |
+
- Empirical evidence from the Sakana paper's direct comparison: TAID + JSD
|
| 148 |
+
beats GKD + JSD beats DistiLLM + skew-KL on Phi-3 → TinyLlama distillation,
|
| 149 |
+
with **0.7 h / epoch** vs **9.8 h / epoch** for GKD on identical hardware.
|
| 150 |
+
The speed comes from not needing student-generated outputs (SGOs) at every
|
| 151 |
+
step the way GKD does.
|
| 152 |
+
- Composes additively with channel 1 (GRPO) and channel 3 (trace-replay DPO)
|
| 153 |
+
because TAID lives strictly inside channel 2.
|
| 154 |
+
|
| 155 |
+
### Implementation cost
|
| 156 |
+
- **~150 LOC**. The change is:
|
| 157 |
+
1. A `TAIDState` object that holds `t`, the EMA of training loss, and the
|
| 158 |
+
momentum coefficient β (default 0.99).
|
| 159 |
+
2. A function `taid_target(student_logits, teacher_logits, t)` that returns
|
| 160 |
+
`(1−t)·softmax(student_logits.detach()) + t·softmax(teacher_logits)`.
|
| 161 |
+
3. A scheduler hook that updates `t` after each backward pass per
|
| 162 |
+
Algorithm 1 of the paper.
|
| 163 |
+
- Dependencies: nothing new.
|
| 164 |
+
- Reference implementation in `SakanaAI/TAID/src/distil_losses/taid.py` is
|
| 165 |
+
Apache-2.0 — vendor-friendly, same pattern as our OPSD lift.
|
| 166 |
+
|
| 167 |
+
---
|
| 168 |
+
|
| 169 |
+
## Candidate 3 — Entropy-Aware On-Policy Distillation (Entropy-Aware OPD) ⭐ RECOMMENDED
|
| 170 |
+
|
| 171 |
+
### Sources
|
| 172 |
+
- **OpenReview (ICLR 2026 Spotlight):** https://openreview.net/forum?id=WSRQ37tzk1
|
| 173 |
+
- **IBM Research page:** https://research.ibm.com/publications/entropy-aware-on-policy-distillation-of-language-models
|
| 174 |
+
- Authors: Woogyeol Jin, Taywon Min, Yongjin Yang, Swanand Kadhe, Yi Zhou, Dennis Wei, Nathalie Baracaldo, Kimin Lee (KAIST + IBM Research)
|
| 175 |
+
- Status: **ICLR 2026 Spotlight**, submission #113. License on the OpenReview record is **CC BY 4.0**.
|
| 176 |
+
- Code: not yet released on GitHub at the time of audit (paper accepted 2026-03-03). IBM authors typically release within the conference window. **Maturity flag: paper-ready, code-pending.** This is the only candidate where we'd need to re-implement from the paper.
|
| 177 |
+
|
| 178 |
+
### Loss core (entropy-gated forward/reverse KL mixture)
|
| 179 |
+
The paper diagnoses a failure mode in the reverse-KL-on-policy distillation
|
| 180 |
+
recipe used by MiniLLM, OPSD, and (implicitly) by our SDPO channel: when the
|
| 181 |
+
**teacher distribution has high entropy at a given token**, reverse KL's
|
| 182 |
+
mode-seeking gradient becomes noisy and collapses the student's diversity.
|
| 183 |
+
Their fix: at each token `t`, gate between forward and reverse KL based on
|
| 184 |
+
the teacher's entropy:
|
| 185 |
+
|
| 186 |
+
```
|
| 187 |
+
H_t = − Σ_v p_T(v | x, y_<t) · log p_T(v | x, y_<t) (teacher entropy)
|
| 188 |
+
|
| 189 |
+
α_t = sigmoid( (H_t − τ) / s ) ∈ (0, 1)
|
| 190 |
+
|
| 191 |
+
L_EA(θ) = E_{y ~ q_θ} Σ_t [
|
| 192 |
+
(1 − α_t) · D_KL( q_θ(· | x, y_<t) ‖ p_T(· | x, y_<t) ) ← reverse KL
|
| 193 |
+
+ α_t · D_KL( p_T(· | x, y_<t) ‖ q_θ(· | x, y_<t) ) ← forward KL
|
| 194 |
+
]
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
`τ` is an entropy threshold (default ≈ 1.0 nat in their experiments) and `s`
|
| 198 |
+
is a temperature controlling how sharp the gate is. When the teacher is
|
| 199 |
+
confident (`H_t` small → `α_t ≈ 0`) the loss is pure reverse KL, identical to
|
| 200 |
+
MiniLLM/OPSD behaviour. When the teacher is uncertain (`H_t` large → `α_t ≈ 1`)
|
| 201 |
+
the loss switches to forward KL, which is mode-covering and preserves
|
| 202 |
+
student diversity.
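
The gate described above can be sketched per token in plain Python. Since no reference code has been released, this follows only the formulas as written in this report; the function names are hypothetical, and the gate sharpness `s=0.1` is an assumed value (only `τ ≈ 1.0` nat is quoted above). Both distributions are assumed to have strictly positive mass wherever the other does, so that both KL directions are finite.

```python
import math


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)


def kl(p, q):
    """D_KL(p || q) for discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def entropy_gate(teacher_probs, tau=1.0, s=0.1):
    """alpha_t = sigmoid((H_t - tau) / s); s=0.1 is an assumed sharpness."""
    return sigmoid((entropy(teacher_probs) - tau) / s)


def ea_opd_token_loss(student_probs, teacher_probs, tau=1.0, s=0.1):
    """(1 - alpha) * reverse KL + alpha * forward KL at one token position."""
    alpha = entropy_gate(teacher_probs, tau, s)
    reverse_kl = kl(student_probs, teacher_probs)  # mode-seeking
    forward_kl = kl(teacher_probs, student_probs)  # mode-covering
    return (1.0 - alpha) * reverse_kl + alpha * forward_kl
```

With a confident teacher the gate is close to 0 and the loss is essentially reverse KL; with a near-uniform teacher the gate approaches 1 and forward KL dominates.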
|
| 203 |
+
|
| 204 |
+
Reported gains over baseline reverse-KL OPD on Qwen3-0.6B/1.7B/4B: Pass@8 on
|
| 205 |
+
six math benchmarks improves by +1.37 / +2.39 / +5.05 respectively. The
|
| 206 |
+
larger gains at larger student size suggest the failure mode reverse KL
|
| 207 |
+
exhibits gets *worse* with capacity, not better.
|
| 208 |
+
|
| 209 |
+
### Why it composes with the existing stack
|
| 210 |
+
- It is **strictly token-wise**: same trajectory, same teacher logits, same
|
| 211 |
+
rollout pipeline as the existing channel 2. The only change is the loss
|
| 212 |
+
reduction — instead of computing `generalized_jsd_loss` with a single fixed
|
| 213 |
+
β, you compute a per-token mixture of forward and reverse KL with weight
|
| 214 |
+
given by teacher entropy.
|
| 215 |
+
- This is genuinely orthogonal to OPSD/SDPO. OPSD's contribution is
|
| 216 |
+
*privileged-context teacher distribution under student rollouts*. EA-OPD's
|
| 217 |
+
contribution is *which divergence to use at each token of that distribution*.
|
| 218 |
+
Both can be true simultaneously.
|
| 219 |
+
- Directly addresses a failure mode the framework's roadmap will hit:
|
| 220 |
+
multi-teacher trace replay (channel 3) produces high-entropy aggregated
|
| 221 |
+
teacher distributions at exactly the steps where teachers disagree. Those
|
| 222 |
+
are the steps where reverse KL behaves worst. EA-OPD's entropy gate would
|
| 223 |
+
automatically soften the loss on those exact tokens.
|
| 224 |
+
- Composes with TAID (Candidate 2) too — they operate on different axes:
|
| 225 |
+
TAID anneals the *target distribution*, EA-OPD chooses the *divergence
|
| 226 |
+
direction*. Stacking is straightforward and proposed as ADR-007 follow-up.
|
| 227 |
+
|
| 228 |
+
### Implementation cost
|
| 229 |
+
- **~120 LOC** estimate (no reference code to vendor yet).
|
| 230 |
+
- Dependencies: nothing new. Token-level entropy is `−(p * log p).sum(-1)`,
|
| 231 |
+
forward KL is the existing teacher-on-student term, reverse KL is the
|
| 232 |
+
student-on-teacher term we already compute for the JSD in OPSD. The work is
|
| 233 |
+
re-shaping the existing per-token loss to expose both directions.
|
| 234 |
+
- **Risk note:** code not yet public. We should hold this candidate behind a
|
| 235 |
+
feature flag until the IBM/KAIST team releases reference code (expected by
|
| 236 |
+
ICLR 2026 in May). If the implementation ships sooner we should vendor and
|
| 237 |
+
match line-for-line; if not, we re-derive from the paper formula and add a
|
| 238 |
+
unit test that reproduces their toy entropy-vs-divergence plot.
|
| 239 |
+
|
| 240 |
+
---
|
| 241 |
+
|
| 242 |
+
## Honourable mention — KTO (Kahneman-Tversky Optimization)
|
| 243 |
+
|
| 244 |
+
- **arXiv:** https://arxiv.org/abs/2402.01306
|
| 245 |
+
- **Code:** integrated into HuggingFace `trl` library since v0.8 (Apache-2.0).
|
| 246 |
+
- License/maturity: **production**. KTO is a standard `trl` trainer alongside DPO.
|
| 247 |
+
|
| 248 |
+
### Loss core
|
| 249 |
+
KTO replaces preference pairs with **per-output binary desirability** signals.
|
| 250 |
+
For a desirable output `y_+` and undesirable output `y_−`:
|
| 251 |
+
|
| 252 |
+
```
|
| 253 |
+
r_θ(x, y) = β · log( π_θ(y|x) / π_ref(y|x) )
|
| 254 |
+
|
| 255 |
+
z_0 = E_{x', y' ~ π_θ}[ KL( π_θ(·|x') ‖ π_ref(·|x') ) ] (reference point)
|
| 256 |
+
|
| 257 |
+
L_KTO = E_{x, y_+} [λ_D · (1 − σ(r_θ(x, y_+) − z_0))] (desirable)
|
| 258 |
+
+ E_{x, y_−} [λ_U · (1 − σ(z_0 − r_θ(x, y_−)))] (undesirable)
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
with default `λ_D = λ_U = 1`. The derivation is via prospect theory: this is
|
| 262 |
+
a Kahneman-Tversky utility function applied to the implicit reward. KTO
|
| 263 |
+
matches DPO at 1B–30B even though it sees only `2n` binary signals where
|
| 264 |
+
DPO sees `n` pairs.
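
A per-example sketch of the loss as written above, in plain Python. This is not the `trl` implementation: the function name and arguments are illustrative, the reference point `z0` is passed in as a precomputed number rather than estimated from a batch, and `beta=0.1` is an assumed default.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def kto_example_loss(policy_logp, ref_logp, desirable, z0,
                     beta=0.1, lambda_d=1.0, lambda_u=1.0):
    """KTO loss for one (x, y) example carrying a binary desirability label.

    policy_logp / ref_logp are sequence log-probs of y under the policy and
    the reference model; z0 is the KL reference point, supplied by the caller.
    """
    r = beta * (policy_logp - ref_logp)  # implicit reward r_theta(x, y)
    if desirable:
        return lambda_d * (1.0 - sigmoid(r - z0))
    return lambda_u * (1.0 - sigmoid(z0 - r))
```

Each example contributes on its own, which is why KTO can consume unpaired "succeeded" or "crashed" signals without constructing preference pairs.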
|
| 265 |
+
|
| 266 |
+
### Why we down-rank it relative to the top-3
|
| 267 |
+
KTO is the right answer **only if** the framework wants to ingest single-side
|
| 268 |
+
trace signals (e.g., "this trace step succeeded" / "this step crashed the
|
| 269 |
+
agent") without constructing pairs. The current
|
| 270 |
+
`research/05-trace-replay-distillation.md` design **does** construct pairs
|
| 271 |
+
from multi-teacher replay (that is the whole point of the multi-teacher
|
| 272 |
+
variance signal), so the marginal value of KTO is small *for channel 3 as
|
| 273 |
+
specified*. If the trace-replay design pivots toward absolute scores per
|
| 274 |
+
step rather than relative pairs, KTO becomes the right loss and is already
|
| 275 |
+
free from `trl`. Add to the backlog as conditional.
|
| 276 |
+
|
| 277 |
+
---
|
| 278 |
+
|
| 279 |
+
## Audited but NOT recommended
|
| 280 |
+
|
| 281 |
+
### GKD — Generalized Knowledge Distillation (Agarwal et al., 2023)
|
| 282 |
+
- **arXiv:** https://arxiv.org/abs/2306.13649 (Google DeepMind)
|
| 283 |
+
- **Loss core:** student samples its own outputs, teacher provides token
|
| 284 |
+
probabilities, divergence is generalized JSD with parameter β:
|
| 285 |
+
```
|
| 286 |
+
D_JSD(β)(P‖Q) = β·KL(P ‖ βP+(1−β)Q) + (1−β)·KL(Q ‖ βP+(1−β)Q)
|
| 287 |
+
```
|
| 288 |
+
- **Why excluded:** **this is exactly the formula we already have** as
|
| 289 |
+
`composer_replication.opsd.generalized_jsd_loss` (lifted from
|
| 290 |
+
`siyan-zhao/OPSD`). GKD's contribution beyond the loss formula is the
|
| 291 |
+
on-policy student sampling protocol — which OPSD also does. No incremental
|
| 292 |
+
value to add.
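
For reference, the generalized JSD formula above can be evaluated on discrete distributions with a few lines of plain Python. This is a stand-alone illustration of the formula only; it is not the repository's `generalized_jsd_loss`, whose signature is not shown in this report.

```python
import math


def kl(p, q):
    """D_KL(p || q) for discrete distributions; zero-mass terms of p are skipped."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def generalized_jsd(p, q, beta=0.5):
    """beta * KL(P || M) + (1 - beta) * KL(Q || M), with M = beta*P + (1-beta)*Q."""
    m = [beta * pi + (1.0 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1.0 - beta) * kl(q, m)
```

At `beta = 0.5` this is the ordinary symmetric Jensen-Shannon divergence, bounded above by `log 2` and reaching that bound for distributions with disjoint support.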
|
| 293 |
+
|
| 294 |
+
### DistiLLM (Ko et al., ICML 2024)
|
| 295 |
+
- **arXiv:** https://arxiv.org/abs/2402.03898
|
| 296 |
+
- **GitHub:** https://github.com/jongwooko/distillm — MIT, last push 2025-03
|
| 297 |
+
- **Loss core:** *Skew KL divergence* `KL(p ‖ λp + (1−λ)q)` plus an *adaptive
|
| 298 |
+
off-policy* student-generated-output (SGO) scheduler.
|
| 299 |
+
- **Why excluded:** the skew-KL is a special case of generalized JSD (set the
|
| 300 |
+
mixture coefficient appropriately) — same family the framework already
|
| 301 |
+
has. The interesting contribution, the SGO scheduler, is a process
|
| 302 |
+
optimisation, not a loss. The TAID paper's own ablation (Table 6) shows
|
| 303 |
+
TAID > Skew KL across student sizes, so TAID dominates this candidate.
|
| 304 |
+
|
| 305 |
+
### MiniLLM (Gu et al., ICLR 2024)
|
| 306 |
+
- **arXiv:** https://arxiv.org/abs/2306.08543
|
| 307 |
+
- **GitHub:** https://github.com/microsoft/LMOps/tree/main/minillm — MIT, repo
|
| 308 |
+
active (last push 2026-04)
|
| 309 |
+
- **Loss core:** reverse KL minimised by policy-gradient on student rollouts,
|
| 310 |
+
with three optimisation tricks: single-step decomposition (variance
|
| 311 |
+
reduction), teacher-mixed sampling (anti-reward-hacking), length
|
| 312 |
+
normalisation.
|
| 313 |
+
- **Why excluded:** reverse-KL on-policy distillation **is the same recipe
|
| 314 |
+
family as SDPO/OPSD** the framework already implements. Adding MiniLLM
|
| 315 |
+
would be a parallel implementation of the same idea, not an addition.
|
| 316 |
+
Entropy-Aware OPD (Candidate 3) is a *strict improvement* over MiniLLM's
|
| 317 |
+
pure reverse-KL on exactly the failure mode MiniLLM identifies (mode
|
| 318 |
+
collapse in high-entropy regions).
|
| 319 |
+
|
| 320 |
+
### Self-Rewarding Language Models (Yuan et al., 2024)
|
| 321 |
+
- **arXiv:** https://arxiv.org/abs/2401.10020 (Meta + NYU)
|
| 322 |
+
- **Why excluded:** SRLM is a *training procedure* (iterative DPO with the
|
| 323 |
+
model judging its own outputs), not a loss. The actual loss is plain DPO,
|
| 324 |
+
which the framework already supports. The procedural contribution belongs
|
| 325 |
+
in a future ADR on data generation, not in the distillation channel.
|
| 326 |
+
|
| 327 |
+
### TAID's relationship to "TAID arXiv 2501.16937 if it exists"
|
| 328 |
+
The user asked us to verify existence. **It exists.** Submitted 2025-01-28,
|
| 329 |
+
ICLR 2025, code at https://github.com/SakanaAI/TAID with two released
|
| 330 |
+
checkpoints (`TAID-LLM-1.5B`, `TAID-VLM-2B`). Confirmed primary source.
|
| 331 |
+
|
| 332 |
+
---
|
| 333 |
+
|
| 334 |
+
## 2026 papers found
|
| 335 |
+
|
| 336 |
+
The targeted Exa search (`category=research paper`, `startPublishedDate=2026-01-01`)
|
| 337 |
+
surfaced five 2026 distillation papers worth listing for completeness:
|
| 338 |
+
|
| 339 |
+
1. **Entropy-Aware On-Policy Distillation** — ICLR 2026 Spotlight. ⭐ Promoted to top-3 above.
|
| 340 |
+
2. **KL for a KL: On-Policy Distillation with Control Variate Baseline** (arXiv 2605.07865, Oh et al., 2026-05). Variance-reduction trick for on-policy KL distillation. Useful future read but not a new loss — it's a baseline subtraction added to MiniLLM-style policy gradient.
|
| 341 |
+
3. **Rethinking On-Policy Distillation: Phenomenology, Mechanism, and Recipe** (https://github.com/thunlp/OPD, Tsinghua NLP, last push 2026-04). Empirical study, not a new loss formulation.
|
| 342 |
+
4. **Hybrid Policy Distillation for LLMs** (ICML 2026 poster, Zhu et al.). Combines off-policy and on-policy distillation; positioned as a recipe rather than a new loss; abstract suggests strong overlap with TAID's annealing argument.
|
| 343 |
+
5. **Don't Ignore the Tail: Decoupling top-K Probabilities for Efficient Language Model Distillation** (ICML 2026 poster, Dasgupta et al.). Targets the long-tail of teacher distributions. Interesting but currently only an abstract; deferred until the camera-ready PDF is available.
|
| 344 |
+
|
| 345 |
+
None of these except Entropy-Aware OPD are mature enough (released code +
|
| 346 |
+
license + reproducible scale) to recommend adding right now.
|
| 347 |
+
|
| 348 |
+
---
|
| 349 |
+
|
| 350 |
+
## Recommended follow-up wiring
|
| 351 |
+
|
| 352 |
+
For ADR-007 the proposed addition is a `composer_replication.distillation`
|
| 353 |
+
sub-package with three pluggable hooks:
|
| 354 |
+
|
| 355 |
+
```
|
| 356 |
+
composer_replication/
|
| 357 |
+
distillation/
|
| 358 |
+
__init__.py
|
| 359 |
+
targets.py # taid_target(...), fixed_target(...) ← Candidate 2
|
| 360 |
+
losses.py # reuses opsd.generalized_jsd_loss
|
| 361 |
+
# adds entropy_aware_kl_loss(...) ← Candidate 3
|
| 362 |
+
preference/
|
| 363 |
+
simpo.py # simpo_loss(...) ← Candidate 1
|
| 364 |
+
dpo.py # existing trace-replay path
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
The composition rule for the total loss becomes:
|
| 368 |
+
|
| 369 |
+
```
|
| 370 |
+
L_total = λ_grpo · L_GRPO (channel 1, unchanged)
|
| 371 |
+
+ λ_distill · L_distill (channel 2, see below)
|
| 372 |
+
+ λ_pref · L_pref (channel 3, choose DPO or SimPO)
|
| 373 |
+
|
| 374 |
+
L_distill = entropy_aware_kl_loss(
|
| 375 |
+
target = taid_target(student, teacher, t),
|
| 376 |
+
student = student,
|
| 377 |
+
teacher_entropy_gate = α_t
|
| 378 |
+
)
|
| 379 |
+
```
|
| 380 |
+
|
| 381 |
+
This keeps the existing `generalized_jsd_loss` reachable as a fallback
|
| 382 |
+
(set `α_t ≡ 0` and `t ≡ 1` and you recover SDPO/OPSD exactly).
|
| 383 |
+
|
| 384 |
+
---
|
| 385 |
+
|
| 386 |
+
## Sources index
|
| 387 |
+
|
| 388 |
+
| Paper | arXiv | GitHub | License | Last push | Maturity |
|
| 389 |
+
|-------|-------|--------|---------|-----------|----------|
|
| 390 |
+
| SimPO | https://arxiv.org/abs/2405.14734 | https://github.com/princeton-nlp/SimPO | MIT | 2024-10-12 | Production |
|
| 391 |
+
| TAID | https://arxiv.org/abs/2501.16937 | https://github.com/SakanaAI/TAID | Apache-2.0 | 2025-10-06 | Production |
|
| 392 |
+
| Entropy-Aware OPD | n/a (OpenReview WSRQ37tzk1) | code-pending | CC BY 4.0 (paper) | n/a | Paper-only |
|
| 393 |
+
| KTO | https://arxiv.org/abs/2402.01306 | huggingface/trl (built-in) | Apache-2.0 | continuous | Production |
|
| 394 |
+
| GKD | https://arxiv.org/abs/2306.13649 | (no official repo from authors; reproduced inside SakanaAI/TAID and jongwooko/distillm) | n/a | n/a | Reference only |
|
| 395 |
+
| DistiLLM | https://arxiv.org/abs/2402.03898 | https://github.com/jongwooko/distillm | (no LICENSE file at audit time) | 2025-03-13 | Research |
|
| 396 |
+
| MiniLLM | https://arxiv.org/abs/2306.08543 | https://github.com/microsoft/LMOps/tree/main/minillm | MIT | 2026-04-08 | Production |
|
| 397 |
+
| Self-Rewarding LM | https://arxiv.org/abs/2401.10020 | (no canonical repo; integrated into many forks) | n/a | n/a | Procedure, not a loss |
|
| 398 |
+
|
| 399 |
+
---
|
| 400 |
+
|
| 401 |
+
## Notes for ADR-007 author
|
| 402 |
+
|
| 403 |
+
1. **SimPO and TAID can land independently and without coordination.** They
|
| 404 |
+
touch different files and do not compete.
|
| 405 |
+
2. **Entropy-Aware OPD should land last.** Wait for the IBM/KAIST authors'
|
| 406 |
+
code release; if it's not out by the time we want to ship the change, the
|
| 407 |
+
formula is simple enough to re-derive but we should pin a unit test that
|
| 408 |
+
reproduces the paper's Figure 3 entropy-vs-divergence behaviour.
|
| 409 |
+
3. **Do not also pull in GKD/DistiLLM/MiniLLM.** Their loss contributions are
|
| 410 |
+
strict subsets of what (TAID + Entropy-Aware OPD + existing
|
| 411 |
+
`generalized_jsd_loss`) covers.
|
| 412 |
+
4. **KTO should be added as a backlog item** with a "trigger" condition:
|
| 413 |
+
when the trace-replay reward design moves from preference pairs to per-step
|
| 414 |
+
binary signals, switch on the `trl.KTOTrainer` path.
|
| 415 |
+
|
| 416 |
+
---
|
| 417 |
+
|
| 418 |
+
*Absolute path of this report:* `/mnt/e/CS/HF/composer-replication-framework/docs/research/SELF_DISTILLATION_LANDSCAPE.md`
|
|
@@ -0,0 +1,239 @@
|
| 1 |
+
# Wave 13 Adversarial Cross-Model Review
|
| 2 |
+
|
| 3 |
+
**Reviewer:** Claude Opus 4.7 (sub-agent via delegate_task)
|
| 4 |
+
**Date:** 2026-05-26
|
| 5 |
+
**Scope:** Wave 13 additions only (35 new tests, 4 ADRs, 6 new modules)
|
| 6 |
+
**Method:** Read-and-grep audit + targeted test runs (CPU)
|
| 7 |
+
|
| 8 |
+
## Top-line verdict
|
| 9 |
+
|
| 10 |
+
**CONDITIONAL PASS with two BLOCKERs.** Wave 13 substantially advances
|
| 11 |
+
the brief expansion (serverless DiLoCo abstraction, replaysim
|
| 12 |
+
normalization, three distillation losses, PRIME-RL recipe, Monarch
|
| 13 |
+
tie-in). The **distillation losses are the strongest deliverable** —
|
| 14 |
+
real, well-tested, mathematically faithful to the cited papers. The
|
| 15 |
+
serverless-DiLoCo local executor + ObjectStoreAllReduce barrier are
|
| 16 |
+
also genuine and exercised by 3 real multi-process tests.
|
| 17 |
+
|
| 18 |
+
**However, two material claims are not test-validated, and one new
|
| 19 |
+
module silently produces a degenerate loss in its primary code path.**
|
| 20 |
+
ADR claims that say "X is added to compose_loss" describe code that
|
| 21 |
+
wasn't actually written. The MockManager → DiLoCo "drop-in" is
|
| 22 |
+
unverified end-to-end.
|
| 23 |
+
|
| 24 |
+
Wave 11's reviewer found 2 genuine BLOCKERs. This review finds **2
|
| 25 |
+
BLOCKERs + 4 SUGGESTIONs + 2 NITs**.
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## Finding 1 — BLOCKER: PRIME-RL `composer_loss.loss_fn` SDPO term is mathematically degenerate (always 0)
|
| 30 |
+
|
| 31 |
+
**Severity:** BLOCKER
|
| 32 |
+
**Evidence:** `composer_replication/recipes/prime_rl/composer_loss.py:79-86`
|
| 33 |
+
|
| 34 |
+
The PRIME-RL composer-loss adapter applies `unsqueeze(-1)` to `(B, T)`
|
| 35 |
+
log-prob tensors before passing them to `generalized_jsd_loss`, which
|
| 36 |
+
calls `F.log_softmax(..., dim=-1)`. Softmax of a single-element vector
|
| 37 |
+
is exactly 1.0; its log is 0. Therefore both `student_log_probs` and
|
| 38 |
+
`teacher_log_probs` are identically zero, the JSD between them is 0,
|
| 39 |
+
and the SDPO contribution **is always 0 regardless of `alpha_sdpo` or
|
| 40 |
+
the actual log-prob values.**
|
| 41 |
+
|
| 42 |
+
```python
|
| 43 |
+
>>> import torch
>>> import torch.nn.functional as F
|
| 44 |
+
>>> F.log_softmax(torch.randn(2, 3, 1), dim=-1)
|
| 45 |
+
tensor([[[0.],[0.],[0.]],[[0.],[0.],[0.]]])
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
The docstring calls this "a deliberate approximation," but it is not
|
| 49 |
+
an approximation — it's a mathematically degenerate operation that
|
| 50 |
+
silently disables channel 2.
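
The same degeneracy can be reproduced without PyTorch. The sketch below is a plain-Python stand-in for a log-softmax over a trailing axis of size 1 (the helper name `log_softmax` is local to this example): whatever the input value, the single entry normalises to probability 1, so its log is 0.

```python
import math


def log_softmax(xs):
    """Log-softmax over a list, using the log-sum-exp trick."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


# Per-token log-probs of shape (T,), then "unsqueezed" to shape (T, 1).
token_logps = [-0.2, -5.0, -1.3]
unsqueezed = [[lp] for lp in token_logps]

# Normalising over the size-1 axis erases every value.
student_rows = [log_softmax(row) for row in unsqueezed]
```

Every row of `student_rows` is `[0.0]`, so any divergence computed between two such tensors is zero regardless of the original log-probs.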
|
| 51 |
+
|
| 52 |
+
**Fix direction:**
|
| 53 |
+
- Gate the SDPO branch behind `trainer_lp.ndim >= 3`, raising
|
| 54 |
+
`NotImplementedError` until PRIME-RL surfaces full logits.
|
| 55 |
+
- Update `prime_rl_recipe.md` and ADR-006 to stop claiming PRIME-RL
|
| 56 |
+
has working SDPO; mark it deferred.
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
## Finding 2 — BLOCKER: ADR-007 declares `compose_loss` kwargs that were never added
|
| 61 |
+
|
| 62 |
+
**Severity:** BLOCKER
|
| 63 |
+
**Evidence:**
|
| 64 |
+
- `docs/adrs/ADR-007-self-distillation-losses.md:103-108` claims:
|
| 65 |
+
> `composer_replication.compose_loss` gets new optional kwargs:
|
| 66 |
+
> - `dpo_variant: Literal["dpo", "simpo"] = "dpo"` — switches channel 3
|
| 67 |
+
> - `sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none"` — wraps channel 2
|
| 68 |
+
> - `taid_schedule_step: int | None = None`
|
| 69 |
+
> - `taid_total_steps: int | None = None`
|
| 70 |
+
- `composer_replication/loss.py:54-65` actual signature has **none**
|
| 71 |
+
of these. `grep -n "dpo_variant\|sdpo_wrapper\|taid"
|
| 72 |
+
composer_replication/loss.py` returns empty.
|
| 73 |
+
|
| 74 |
+
The new losses live in `composer_replication.distillation` as
|
| 75 |
+
standalone functions but **are not wired into the framework's actual
|
| 76 |
+
loss composition.** A user reading ADR-007 + the README would believe
|
| 77 |
+
`compose_loss(model, inputs, dpo_variant="simpo", sdpo_wrapper="taid", ...)`
|
| 78 |
+
works; it would raise `TypeError`. The 17 distillation tests verify
|
| 79 |
+
the standalone losses but never exercise integration.
|
| 80 |
+
|
| 81 |
+
**Fix direction:**
|
| 82 |
+
- Either (a) add the kwargs to `compose_loss` and write at least one
|
| 83 |
+
integration test combining e.g. SDPO+TAID (~30 LOC change), or
|
| 84 |
+
- (b) downgrade ADR-007 status to "Standalone losses landed;
|
| 85 |
+
integration deferred to Wave 14."
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## Finding 3 — SUGGESTION: `default.yaml` replaysim recipe uses string ops on list-of-dict fields
|
| 90 |
+
|
| 91 |
+
**Severity:** SUGGESTION (would be BLOCKER if a test exercised the real path)
|
| 92 |
+
**Evidence:**
|
| 93 |
+
- `composer_replication/recipes/replaysim/default.yaml` configures
|
| 94 |
+
`text_length_filter`, `words_num_filter`, `special_characters_filter`,
|
| 95 |
+
`document_deduplicator` with `text_keys: ["chosen", "rejected"]`.
|
| 96 |
+
- In the record produced by `_dpo_pair_to_dj_record`, `chosen` and
|
| 97 |
+
`rejected` are **lists of dicts**
|
| 98 |
+
(`[{"role": "assistant", "content": "..."}]`) — not strings.
|
| 99 |
+
- data-juicer's `text_length_filter` expects string-typed fields;
|
| 100 |
+
running it on a list will either crash or no-op silently.
|
| 101 |
+
|
| 102 |
+
The reason no test catches this: tests only validate the real path *if
|
| 103 |
+
data-juicer is installed*, and even then only check `__init__` succeeds.
|
| 104 |
+
There is no test that calls `normalize()` against a real data-juicer
|
| 105 |
+
executor with the default recipe.
|
| 106 |
+
|
| 107 |
+
**Fix direction:**
|
| 108 |
+
- Reshape `_dpo_pair_to_dj_record` to extract `content` strings
|
| 109 |
+
alongside the messages-format list.
|
| 110 |
+
- Add one test (skip-marked unless `data_juicer` is importable) that
|
| 111 |
+
runs the real op-graph on 3 hand-crafted records.
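
One possible shape for the first fix is sketched below. The helper names and the `chosen_text` / `rejected_text` field names are hypothetical; the recipe's `text_keys` would need to point at whichever string fields are actually adopted. The messages-format lists are kept so downstream DPO code still sees them.

```python
def messages_to_text(messages):
    """Join the `content` strings of a messages-format list into one string."""
    return "\n".join(m.get("content", "") for m in messages)


def add_text_fields(record):
    """Return a copy of a DPO-pair record with plain-string companions.

    `chosen_text` / `rejected_text` are hypothetical field names for this
    sketch; the original list-of-dict fields are left untouched.
    """
    out = dict(record)
    for key in ("chosen", "rejected"):
        out[f"{key}_text"] = messages_to_text(record[key])
    return out


example = add_text_fields({
    "chosen": [{"role": "assistant", "content": "Use a context manager."}],
    "rejected": [{"role": "assistant", "content": "Just leak the handle."}],
})
```

String-typed filters can then operate on the `*_text` fields while the structured fields pass through unchanged.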
|
| 112 |
+
|
| 113 |
+
---
|
| 114 |
+
|
| 115 |
+
## Finding 4 — SUGGESTION: MockManager → torchft.DiLoCo "drop-in" claim is unverified end-to-end
|
| 116 |
+
|
| 117 |
+
**Severity:** SUGGESTION
|
| 118 |
+
**Evidence:**
|
| 119 |
+
- `composer_replication/diloco/serverless/allreduce.py:188-191` claims
|
| 120 |
+
MockManager "drops into" `make_diloco_outer_loop`.
|
| 121 |
+
- The only test covering MockManager (`test_mock_manager_shape_compat`)
|
| 122 |
+
is a `hasattr` smoke that calls `.allreduce` on a `world_size=1`
|
| 123 |
+
store (passthrough).
|
| 124 |
+
- torchft.Manager has additional surface area
|
| 125 |
+
(`current_step`, `is_leader`, `_pg`, `report_error`,
|
| 126 |
+
internal step accounting) that DiLoCo's `_apply_pseudogradient`
|
| 127 |
+
may consult depending on version.
|
| 128 |
+
|
| 129 |
+
**Fix direction:**
|
| 130 |
+
- Add a single integration test that constructs
|
| 131 |
+
`make_diloco_outer_loop(manager=MockManager(store), ...)` against a
|
| 132 |
+
tiny `nn.Linear` and runs one outer round — even single-process.
|
| 133 |
+
- Audit `torchft/local_sgd.py` for the `Manager`-rooted call sites and
|
| 134 |
+
add stubs for any methods DiLoCo actually consults beyond `allreduce`.
|
| 135 |
+
|
| 136 |
+
---
|
| 137 |
+
|
| 138 |
+
## Finding 5 — SUGGESTION: README claim "9 multi-process tests" is mildly inflated
|
| 139 |
+
|
| 140 |
+
**Severity:** SUGGESTION (bordering on NIT)
|
| 141 |
+
**Evidence:**
|
| 142 |
+
- README.md and V1_V8_COVERAGE both state: *"9 multi-process tests
|
| 143 |
+
pinning the allreduce barrier."*
|
| 144 |
+
- Actual breakdown:
|
| 145 |
+
- 4 single-process unit tests + `test_mock_manager_shape_compat` (5 total)
|
| 146 |
+
- 4 multi-process tests spawning subprocesses (parametrized [2,3] of
|
| 147 |
+
`_runs_allreduce_across_replicas`, `_handles_multiple_rounds`,
|
| 148 |
+
`_reports_failed_replicas`)
|
| 149 |
+
- Of the 4 multi-process tests, only **3 actually exercise the
|
| 150 |
+
allreduce barrier**; `_reports_failed_replicas` deliberately raises
|
| 151 |
+
before any allreduce call.
|
| 152 |
+
|
| 153 |
+
**Wave 13 clearly does NOT fake-pass via world_size=1** — the multi-
|
| 154 |
+
process barrier is real. But the count is rounded up.
|
| 155 |
+
|
| 156 |
+
**Fix direction:** Replace "9 multi-process tests" with "9 tests
|
| 157 |
+
covering the serverless DiLoCo layer, of which 4 spawn real
|
| 158 |
+
subprocesses and 3 exercise the allreduce barrier across replicas."
|
| 159 |
+
|
| 160 |
+
---
|
| 161 |
+
|
| 162 |
+
## Finding 6 — SUGGESTION: PRIME-RL channel 1 is REINFORCE not GRPO; ignores `inference_logprobs`
|
| 163 |
+
|
| 164 |
+
**Severity:** SUGGESTION
|
| 165 |
+
**Evidence:** `composer_replication/recipes/prime_rl/composer_loss.py:62-68`
|
| 166 |
+
computes:
|
| 167 |
+
```python
|
| 168 |
+
grpo_loss = -(advantages * trainer_lp * mask).sum() / mask.sum().clamp_min(epsilon)
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
This is plain REINFORCE with advantage. PRIME-RL's `LossInputs`
|
| 172 |
+
exposes `inference_logprobs` precisely because GRPO-with-replay-buffer
|
| 173 |
+
requires the importance-sampling ratio
|
| 174 |
+
`exp(trainer_lp - inference_lp)` (PPO-style clipped objective).
|
| 175 |
+
|
| 176 |
+
The file says "SKELETON" so this isn't a hidden bug per se, but the
|
| 177 |
+
loss is **labeled GRPO and is not GRPO**.
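
For comparison, the ratio-plus-clipping form referred to above looks roughly like the following. This is a dependency-free sketch over flat per-token lists, not PRIME-RL code: the function name and `clip_eps=0.2` are assumptions, and it covers only the importance ratio and clipping, not GRPO's group-normalised advantage computation.

```python
import math


def clipped_pg_loss(advantages, trainer_lp, inference_lp, mask, clip_eps=0.2):
    """PPO-style clipped surrogate, averaged over unmasked tokens."""
    total, count = 0.0, 0
    for adv, lp, lp_old, keep in zip(advantages, trainer_lp, inference_lp, mask):
        if not keep:
            continue
        # Importance ratio between the trainer policy and the sampling policy.
        ratio = math.exp(lp - lp_old)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic (min) surrogate, negated because we minimise the loss.
        total += -min(ratio * adv, clipped * adv)
        count += 1
    return total / max(count, 1)
```

When `trainer_lp` equals `inference_lp` the ratio is 1 everywhere and the value reduces to the negative mean advantage; the two objectives differ once the trainer policy drifts from the policy that generated the rollouts.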
|
| 178 |
+
|
| 179 |
+
**Fix direction:** Either implement the ratio + clipping (~20 LOC) or
|
| 180 |
+
rename channel-1 comment to "REINFORCE-with-advantage stub" with a TODO.
|
| 181 |
+
|
| 182 |
+
---
|
| 183 |
+
|
| 184 |
+
## Finding 7 — NIT: ModalExecutor / HFJobsExecutor are skeleton-only with `NotImplementedError` in `__init__`
|
| 185 |
+
|
| 186 |
+
**Severity:** NIT (this is documented, but README phrasing is slightly soft)
|
| 187 |
+
**Evidence:** Honestly documented as skeletons in the code, ADR-005,
|
| 188 |
+
and README. NIT: a user trying `ModalExecutor()` gets a runtime error
|
| 189 |
+
rather than an import-time clue.
|
| 190 |
+
|
| 191 |
+
**Fix direction:** Low priority. Update README phrase to "skeleton-only
|
| 192 |
+
— raises NotImplementedError until v0.x." Or use a `__getattr__` on
|
| 193 |
+
the package that raises a clearer message.
|
| 194 |
+
|
| 195 |
+
---
|
| 196 |
+
|
| 197 |
+
## Finding 8 — NIT: SimPO test uses positive log-probs (impossible values)
|
| 198 |
+
|
| 199 |
+
**Severity:** NIT
|
| 200 |
+
**Evidence:** `test_distillation_losses.py:27-46` calls `simpo_loss`
|
| 201 |
+
with `chosen=tensor([0.5, 0.4, 0.3])`. Log-probabilities are bounded
|
| 202 |
+
above by 0; positive values aren't possible from any softmax. The tests
|
| 203 |
+
still verify the formula correctly, but the test inputs aren't legal.
|
| 204 |
+
|
| 205 |
+
**Fix direction:** Use negative values — purely cosmetic.
|
| 206 |
+
|
| 207 |
+
---
|
| 208 |
+
|
| 209 |
+
## Cross-cutting risk check
|
| 210 |
+
|
| 211 |
+
73 tests passed in 29.29s on the CPU-fast subset. Spike 008 5/5 still
|
| 212 |
+
pass. The new `composer_replication.diloco.serverless` package is
|
| 213 |
+
purely additive; the existing `make_diloco_outer_loop` is untouched.
|
| 214 |
+
**No cross-wave regressions detected on CPU.** GPU tests + slow CPU
|
| 215 |
+
e2e tests not re-run; regression risk low since Wave 13 doesn't touch
|
| 216 |
+
their dependencies.
|
| 217 |
+
|
| 218 |
+
---
|
| 219 |
+
|
| 220 |
+
## Summary scorecard
|
| 221 |
+
|
| 222 |
+
| Item | Verdict |
|
| 223 |
+
|---|---|
|
| 224 |
+
| Distillation module (SimPO/TAID/Entropy-Aware OPD) standalone | ✅ Real, well-tested, paper-faithful |
|
| 225 |
+
| Distillation integrated into `compose_loss` | ❌ **Not implemented** despite ADR-007 (Finding 2) |
|
| 226 |
+
| ObjectStoreAllReduce + LocalProcessExecutor | ✅ Real multi-process barrier validated |
|
| 227 |
+
| MockManager → DiLoCo drop-in | 🟡 Shape-checked only; integration unverified (Finding 4) |
|
| 228 |
+
| Modal/HFJobs adapters | 🟡 Honestly documented as skeletons (Finding 7) |
|
| 229 |
+
| Replaysim DJNormalizer passthrough | ✅ Works |
|
| 230 |
+
| Replaysim default.yaml against real data-juicer | ❌ **Recipe field types don't match record shape** (Finding 3) |
|
| 231 |
+
| PRIME-RL composer_loss.loss_fn | ❌ **SDPO term silently 0** (Finding 1); channel 1 is REINFORCE not GRPO (Finding 6) |
|
| 232 |
+
| Monarch actors | ✅ Honest skeleton; raises NotImplementedError |
|
| 233 |
+
| Altered-minds tie-in doc | ✅ Design-only, scoped honestly |
|
| 234 |
+
| 35 new tests | All pass; 3 of 4 multi-process tests are genuine (Finding 5) |
|
| 235 |
+
|
| 236 |
+
**Recommendation:** Address Findings 1 and 2 before publishing the
|
| 237 |
+
Wave 13 expansion as "closed." Findings 3 and 4 should be addressed
|
| 238 |
+
before any user attempts the real data-juicer or real torchft DiLoCo
|
| 239 |
+
path. Findings 5–8 are cleanup.
|
|
@@ -16,16 +16,23 @@ keywords = [
|
|
| 16 |
"rlvr",
|
| 17 |
"grpo",
|
| 18 |
"sdpo",
|
| 19 |
"dpo",
|
| 20 |
"diloco",
|
|
|
|
| 21 |
"agentic",
|
| 22 |
"coding-agents",
|
| 23 |
"composer-2-5",
|
| 24 |
"cursor",
|
| 25 |
"trl",
|
| 26 |
"verl",
|
| 27 |
"openenv",
|
| 28 |
"torchft",
|
| 29 |
]
|
| 30 |
classifiers = [
|
| 31 |
"Development Status :: 3 - Alpha",
|
|
@@ -47,17 +54,35 @@ dependencies = [
|
|
| 47 |
replay = [
|
| 48 |
"httpx>=0.27",
|
| 49 |
]
|
| 50 |
-
# DiLoCo outer-loop optimizer
|
| 51 |
diloco = [
|
| 52 |
"torchft-nightly",
|
| 53 |
]
|
| 54 |
-
#
|
| 55 |
train = [
|
| 56 |
"trl>=0.12",
|
| 57 |
"peft>=0.13",
|
| 58 |
"accelerate>=1.0",
|
| 59 |
"datasets>=3.0",
|
| 60 |
]
|
| 61 |
# Everything for development
|
| 62 |
dev = [
|
| 63 |
"pytest>=8.0",
|
|
|
|
| 16 |
"rlvr",
|
| 17 |
"grpo",
|
| 18 |
"sdpo",
|
| 19 |
+
"simpo",
|
| 20 |
+
"taid",
|
| 21 |
"dpo",
|
| 22 |
"diloco",
|
| 23 |
+
"decoupled-diloco",
|
| 24 |
"agentic",
|
| 25 |
"coding-agents",
|
| 26 |
"composer-2-5",
|
| 27 |
"cursor",
|
| 28 |
"trl",
|
| 29 |
"verl",
|
| 30 |
+
"prime-rl",
|
| 31 |
"openenv",
|
| 32 |
"torchft",
|
| 33 |
+
"monarch",
|
| 34 |
+
"modal",
|
| 35 |
+
"huggingface-jobs",
|
| 36 |
]
|
| 37 |
classifiers = [
|
| 38 |
"Development Status :: 3 - Alpha",
|
|
|
|
| 54 |
replay = [
|
| 55 |
"httpx>=0.27",
|
| 56 |
]
|
| 57 |
+
# DiLoCo outer-loop optimizer (single-process)
|
| 58 |
diloco = [
|
| 59 |
"torchft-nightly",
|
| 60 |
]
|
| 61 |
+
# Decoupled DiLoCo over serverless executors (per ADR-005)
|
| 62 |
+
serverless = [
|
| 63 |
+
"fsspec>=2024.6",
|
| 64 |
+
"huggingface_hub>=0.27", # for hf:// fsspec backend + HF Jobs
|
| 65 |
+
]
|
| 66 |
+
# Replaysim dataset normalization (per ADR-004)
|
| 67 |
+
replaysim = [
|
| 68 |
+
"data-juicer>=1.0",
|
| 69 |
+
"composer-replication[replay]", # replaysim builds on the replay channel
|
| 70 |
+
]
|
| 71 |
+
# Production training (TRL GRPOTrainer subclass — Recipe A)
|
| 72 |
train = [
|
| 73 |
"trl>=0.12",
|
| 74 |
"peft>=0.13",
|
| 75 |
"accelerate>=1.0",
|
| 76 |
"datasets>=3.0",
|
| 77 |
]
|
| 78 |
+
# PRIME-RL recipe (Recipe C — per ADR-006)
|
| 79 |
+
prime-rl = [
|
| 80 |
+
"prime-rl>=0.5",
|
| 81 |
+
]
|
| 82 |
+
# Monarch actor mesh (per ADR-006)
|
| 83 |
+
monarch = [
|
| 84 |
+
"monarch>=0.4.1",
|
| 85 |
+
]
|
| 86 |
# Everything for development
|
| 87 |
dev = [
|
| 88 |
"pytest>=8.0",
|