Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto") - Notebooks
- Google Colab
- Kaggle
Wave 14: close every Wave 13 review finding + 4 documentation files; Wave 14b: real PRIME-RL parity + multi-process DiLoCo convergence
Browse filesPHASE A: 4 parallel impl subagents closed every Wave 13 cross-model review item.
T1 \u2014 compose_loss integration (closes W13 BLOCKER 2):
- Added `dpo_variant: 'dpo'|'simpo'`, `sdpo_wrapper: 'none'|'taid'|'entropy_opd'`,
`taid_schedule_step`, `taid_total_steps` plus 5 tuning kwargs
- 11 new integration tests in composer_replication/tests/test_compose_loss_integration.py
- Bit-exact reproduction of legacy compose_loss output when defaults preserved
- All 38 spike-005 tests still pass
T2 \u2014 replaysim DJ adapter reshape (closes W13 Suggestion 3):
- _dpo_pair_to_dj_record now emits BOTH flat strings AND chat-messages
for dual-shape compatibility with data-juicer's text ops
- _dj_record_to_normalized round-trips both shapes
- default.yaml fixed: text_keys plural \u2192 text_key singular (caught a
related bug in the same op set)
- Real data-juicer e2e test runs DefaultExecutor on a 3-record fixture
- 15 tests (was 9)
T3 \u2014 MockManager DiLoCo integration (closes W13 Suggestion 4):
- Audited torchft Manager surface; added 6 missing methods:
current_step, disallow_state_dict_read, allow_state_dict_read,
register_state_dict_fn, _use_async_quorum, is_leader
- _ImmediateWork wraps allreduce return so DiLoCo can call .wait()
- New integration test runs make_diloco_outer_loop(MockManager(store), nn.Linear)
end-to-end and verifies model parameters change
T4 \u2014 PRIME-RL real GRPO (initial, BUGGY \u2014 see Wave 14b for fix):
- Wave 14 first attempt got the formula wrong; documented in
docs/research/WAVE_14_FINAL_REVIEW.md and re-fixed in Wave 14b.
PHASE B: 4 parallel doc subagents produced 3,930 lines:
- docs/USER_GUIDE.md (670 lines) \u2014 8-section end-to-end narrative
- docs/API_REFERENCE.md (1471 lines) \u2014 every public symbol with
signature + params + return + raises + example, marked tested vs
\u26a0\ufe0f untested vs \ud83d\udfe1 skeleton
- docs/TROUBLESHOOTING.md (797 lines) \u2014 11 failure modes with
SYMPTOM/DIAGNOSIS/FIX/VERIFICATION
- docs/INTEGRATION_RECIPES.md (993 lines) \u2014 5 recipes (TRL/VeRL/
PRIME-RL/serverless/Monarch) with 7-part template + comparison matrix
PHASE C: cross-model adversarial review (Opus 4.7 sub-agent) cloned
PRIME-RL upstream and verified T4's implementation. Found 1 BLOCKER:
T4 thought it matched PRIME-RL but didn't:
- Mask gate was on log_ratio, should be on probs_diff (probability-space)
- Missing importance_ratio multiplication (was REINFORCE)
- Missing advantage-sign-conditioned mask
- Missing KL term
- Wrong defaults (4.0/-4.0 vs PRIME-RL's actual 0.2/0.2)
- Plus 4 SUGGESTIONs.
PHASE C2 (Wave 14b): 2 parallel subagents closed everything.
Subagent 1 re-implemented PRIME-RL composer_loss against upstream:
- Verified formula in /tmp/prime-rl-clone/src/prime_rl/trainer/rl/loss.py
default_loss_fn (lines 116-165) and DefaultLossConfig (412-425)
- Now byte-for-byte matches PRIME-RL: probs_diff masking,
importance_ratio multiplication, advantage-sign-conditioned mask, KL term
- Defaults corrected: dppo_mask_high=0.2, dppo_mask_low=0.2,
adv_tau=1.0, kl_tau=1e-3
- 16 tests including parity test that imports PRIME-RL's default_loss_fn
(skip-marked when prime-rl not installed)
- Updated docstring + USER_GUIDE / API_REFERENCE / TROUBLESHOOTING /
prime_rl_recipe.md / prime_rl_config.yaml all repaired
- FLAGGED for Wave 15: PRIME-RL's setup_loss_fns expects LossOutputs(loss,
metrics) return shape, not bare scalar \u2014 separate adapter-level issue
Subagent 2 closed 3 doc/test SUGGESTIONs:
- ADR-007 updated to reflect Wave 14 closure of compose_loss integration
- INTEGRATION_RECIPES.md: 4 ModalExecutor/HFJobsExecutor dead-code
constructor calls fixed (substituted LocalProcessExecutor)
- New multi-process MockManager+DiLoCo convergence test:
spawns 2 replicas, runs 1 outer round, asserts both replicas converge
to identical weights end-to-end. Test design corrected after initial
attempt: shared-init-seed + rank-specific-data is canonical DiLoCo
setup (not rank-specific-init which is divergent by design).
DOCUMENTATION INDEX (post-Wave-14b):
- README.md (Wave 13 expansion section + roadmap)
- docs/USER_GUIDE.md (start-here narrative)
- docs/API_REFERENCE.md (every public symbol)
- docs/INTEGRATION_RECIPES.md (5 recipes)
- docs/TROUBLESHOOTING.md (11 failure modes + bug-report template)
- docs/V1_V8_COVERAGE.md (brief coverage matrix)
- docs/V3_SUBSTRATE_COVERAGE.md (substrate matrix, 8/8 covered)
- docs/VISION_VALIDATION.md (10-point scorecard)
- docs/ALTERED_MINDS_TIE_IN.md (workstream bridge)
- docs/adrs/ADR-001..007 (7 architectural decisions)
- docs/research/WAVE_7_10_FINAL_REVIEW.md (Wave 11 cross-model review)
- docs/research/WAVE_13_FINAL_REVIEW.md (Wave 13 cross-model review)
- docs/research/WAVE_14_FINAL_REVIEW.md (Wave 14 + 14b cross-model review)
- docs/research/{DILOCO_RECONNAISSANCE, DILOCO_SERVERLESS_RECONNAISSANCE,
MODAL_RECONNAISSANCE, REPLAYSIM_NORMALIZATION_RECONNAISSANCE,
RL_FRAMEWORKS_LANDSCAPE, SELF_DISTILLATION_LANDSCAPE,
TRACE_SOURCE_RECONNAISSANCE}.md (primary-source recons, ~3,300 lines)
TESTS: 130 passing + 1 skip-marked (PRIME-RL parity test runs when
prime-rl installed). Was 93 at end of Wave 13.
NO REGRESSIONS: every prior wave's tests still pass. New code is
either purely additive (distillation, replaysim, serverless DiLoCo) or
backward-compatible (compose_loss kwargs default to legacy behavior).
- composer_replication/diloco/serverless/allreduce.py +122 -10
- composer_replication/diloco/serverless/tests/test_serverless_diloco_integration.py +352 -0
- composer_replication/diloco/serverless/tests/test_serverless_local.py +15 -4
- composer_replication/loss.py +209 -25
- composer_replication/recipes/prime_rl/composer_loss.py +213 -73
- composer_replication/recipes/prime_rl/prime_rl_config.yaml +17 -3
- composer_replication/recipes/prime_rl/prime_rl_recipe.md +90 -22
- composer_replication/recipes/prime_rl/tests/test_composer_loss.py +484 -0
- composer_replication/recipes/replaysim/default.yaml +49 -15
- composer_replication/replaysim/normalize.py +53 -7
- composer_replication/replaysim/tests/test_replaysim.py +191 -2
- composer_replication/tests/test_compose_loss_integration.py +416 -0
- docs/API_REFERENCE.md +1484 -0
- docs/INTEGRATION_RECIPES.md +998 -0
- docs/TROUBLESHOOTING.md +823 -0
- docs/USER_GUIDE.md +683 -0
- docs/adrs/ADR-007-self-distillation-losses.md +46 -22
- docs/research/WAVE_14_FINAL_REVIEW.md +264 -0
|
@@ -176,6 +176,42 @@ class ObjectStoreAllReduce:
|
|
| 176 |
# ---------------------------------------------------------------------
|
| 177 |
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
class MockManager:
|
| 180 |
"""Drop-in replacement for `torchft.Manager` that delegates allreduce
|
| 181 |
to `ObjectStoreAllReduce`.
|
|
@@ -188,27 +224,103 @@ class MockManager:
|
|
| 188 |
Reference: `make_diloco_outer_loop` in
|
| 189 |
`composer_replication/diloco/__init__.py` accepts an optional
|
| 190 |
`manager=` kwarg; pass a `MockManager` to enable serverless DiLoCo.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
"""
|
|
|
|
| 192 |
def __init__(self, store: ObjectStoreAllReduce) -> None:
|
| 193 |
self._store = store
|
| 194 |
-
# torchft Manager attributes that DiLoCo consults
|
|
|
|
| 195 |
self.num_participants = store.world_size
|
| 196 |
self.rank = store.rank
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
def should_commit(self) -> bool:
|
|
|
|
|
|
|
|
|
|
| 205 |
return True
|
| 206 |
|
| 207 |
def start_quorum(self) -> None:
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
def wait_quorum(self) -> int:
|
| 211 |
return self.num_participants
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
-
__all__ = ["MockManager", "ObjectStoreAllReduce"]
|
|
|
|
| 176 |
# ---------------------------------------------------------------------
|
| 177 |
|
| 178 |
|
| 179 |
+
class _ImmediateWork:
|
| 180 |
+
"""Work-shaped wrapper for an already-completed allreduce.
|
| 181 |
+
|
| 182 |
+
`torchft.Manager.allreduce` returns a `torch.distributed.Work` (or
|
| 183 |
+
`torchft.work._DummyWork`) which DiLoCo calls `.wait()` on inside
|
| 184 |
+
`_StreamingDiLoCoFragment.perform_sync`. Our `ObjectStoreAllReduce`
|
| 185 |
+
is synchronous — by the time it returns, the average is already in
|
| 186 |
+
the tensor — so `.wait()` is a no-op.
|
| 187 |
+
|
| 188 |
+
We deliberately don't subclass `torch.distributed._Work` to keep this
|
| 189 |
+
module importable in environments without a full torch distributed
|
| 190 |
+
build; DiLoCo only does `work.wait()`, nothing more.
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
__slots__ = ("_tensor",)
|
| 194 |
+
|
| 195 |
+
def __init__(self, tensor: torch.Tensor) -> None:
|
| 196 |
+
self._tensor = tensor
|
| 197 |
+
|
| 198 |
+
def wait(self, *_args: Any, **_kwargs: Any) -> bool:
|
| 199 |
+
return True
|
| 200 |
+
|
| 201 |
+
def get_future(self) -> Any:
|
| 202 |
+
# Torch >=2.x sometimes calls Work.get_future(); provide a satisfied
|
| 203 |
+
# future so callers don't crash. We only need to be defensive here;
|
| 204 |
+
# DiLoCo itself doesn't call this.
|
| 205 |
+
try:
|
| 206 |
+
import torch.futures as _f
|
| 207 |
+
|
| 208 |
+
fut = _f.Future()
|
| 209 |
+
fut.set_result(self._tensor)
|
| 210 |
+
return fut
|
| 211 |
+
except Exception: # pragma: no cover — defensive only
|
| 212 |
+
return None
|
| 213 |
+
|
| 214 |
+
|
| 215 |
class MockManager:
|
| 216 |
"""Drop-in replacement for `torchft.Manager` that delegates allreduce
|
| 217 |
to `ObjectStoreAllReduce`.
|
|
|
|
| 224 |
Reference: `make_diloco_outer_loop` in
|
| 225 |
`composer_replication/diloco/__init__.py` accepts an optional
|
| 226 |
`manager=` kwarg; pass a `MockManager` to enable serverless DiLoCo.
|
| 227 |
+
|
| 228 |
+
torchft.Manager surface audited from
|
| 229 |
+
``torchft/local_sgd.py`` (DiLoCo + _StreamingDiLoCoFragment paths) and
|
| 230 |
+
``torchft/manager.py``. Methods/attributes DiLoCo touches:
|
| 231 |
+
|
| 232 |
+
* ``allreduce(tensor, should_quantize=...) -> Work`` — must return an
|
| 233 |
+
object with ``.wait()`` (DiLoCo calls ``work.wait()`` in
|
| 234 |
+
``perform_sync``).
|
| 235 |
+
* ``should_commit() -> bool`` — gates the outer-optimizer step.
|
| 236 |
+
* ``start_quorum()`` — called once per outer round, before
|
| 237 |
+
``prepare_sync``.
|
| 238 |
+
* ``current_step() -> int`` — used to pick the streaming-DiLoCo
|
| 239 |
+
fragment for this round (``step % len(fragments)``).
|
| 240 |
+
* ``disallow_state_dict_read()`` / ``allow_state_dict_read()`` —
|
| 241 |
+
called every inner step from the optimizer pre/post hooks.
|
| 242 |
+
* ``register_state_dict_fn(key, load_fn, save_fn)`` — called once
|
| 243 |
+
per fragment from ``DiLoCo.__init__``.
|
| 244 |
+
* ``_use_async_quorum`` (attribute) — DiLoCo's constructor refuses
|
| 245 |
+
to start if this is truthy. Must exist and be False.
|
| 246 |
+
* ``num_participants`` / ``rank`` — read by upstream callers.
|
| 247 |
"""
|
| 248 |
+
|
| 249 |
def __init__(self, store: ObjectStoreAllReduce) -> None:
|
| 250 |
self._store = store
|
| 251 |
+
# torchft Manager attributes that DiLoCo consults at construction time
|
| 252 |
+
# or in user code paths.
|
| 253 |
self.num_participants = store.world_size
|
| 254 |
self.rank = store.rank
|
| 255 |
+
# DiLoCo.__init__ raises if this is truthy (line 622 of
|
| 256 |
+
# torchft/local_sgd.py). Object-store sync is synchronous → False.
|
| 257 |
+
self._use_async_quorum: bool = False
|
| 258 |
+
# Mirror the upstream Manager's monotonic step counter. DiLoCo reads
|
| 259 |
+
# this via current_step() to decide which fragment to sync each round.
|
| 260 |
+
# Bumped from start_quorum() so it advances exactly once per outer round.
|
| 261 |
+
self._step: int = 0
|
| 262 |
+
# State-dict-fn registry: torchft uses this for fault-tolerant
|
| 263 |
+
# checkpoint restore. We're single-shot serverless — record but never
|
| 264 |
+
# invoke. Tests can introspect this dict to confirm registration.
|
| 265 |
+
self._state_dict_fns: dict[str, tuple[Any, Any]] = {}
|
| 266 |
+
|
| 267 |
+
# ---- Core collective ------------------------------------------------
|
| 268 |
+
def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> _ImmediateWork:
|
| 269 |
+
# DiLoCo expects a Work-like return value (it stores it in a list
|
| 270 |
+
# then calls .wait() later). Object-store all-reduce is synchronous,
|
| 271 |
+
# so the tensor is already averaged when we hand back the wrapper.
|
| 272 |
+
averaged = self._store.allreduce(tensor)
|
| 273 |
+
return _ImmediateWork(averaged)
|
| 274 |
+
|
| 275 |
+
# ---- Quorum / commit lifecycle -------------------------------------
|
| 276 |
def should_commit(self) -> bool:
|
| 277 |
+
# No fault-tolerance failover in serverless mode: every quorum
|
| 278 |
+
# always commits. Replica failure is handled by the orchestration
|
| 279 |
+
# layer (HF Jobs / Modal restart), not by DiLoCo skipping a round.
|
| 280 |
return True
|
| 281 |
|
| 282 |
def start_quorum(self) -> None:
|
| 283 |
+
# The upstream Manager bumps its step counter inside the quorum
|
| 284 |
+
# bookkeeping. Do the same so current_step() advances per round
|
| 285 |
+
# and DiLoCo's fragment-rotation math matches across replicas.
|
| 286 |
+
self._step += 1
|
| 287 |
|
| 288 |
def wait_quorum(self) -> int:
|
| 289 |
return self.num_participants
|
| 290 |
|
| 291 |
+
# ---- Step counter ---------------------------------------------------
|
| 292 |
+
def current_step(self) -> int:
|
| 293 |
+
return self._step
|
| 294 |
+
|
| 295 |
+
# ---- State-dict read gating ----------------------------------------
|
| 296 |
+
# torchft uses these to make checkpoint restore thread-safe. In a
|
| 297 |
+
# single-process serverless mock there's no concurrent reader, so they
|
| 298 |
+
# are no-ops — but they MUST exist (DiLoCo's pre/post optimizer hooks
|
| 299 |
+
# call them on every inner step).
|
| 300 |
+
def allow_state_dict_read(self) -> None:
|
| 301 |
+
pass
|
| 302 |
+
|
| 303 |
+
def disallow_state_dict_read(self) -> None:
|
| 304 |
+
pass
|
| 305 |
+
|
| 306 |
+
# ---- Checkpoint hook registry --------------------------------------
|
| 307 |
+
def register_state_dict_fn(
|
| 308 |
+
self,
|
| 309 |
+
key: str,
|
| 310 |
+
load_fn: Any,
|
| 311 |
+
save_fn: Any,
|
| 312 |
+
) -> None:
|
| 313 |
+
# DiLoCo registers one (load, save) pair per fragment so torchft can
|
| 314 |
+
# checkpoint the outer-optimizer state and original-parameter backup.
|
| 315 |
+
# In serverless mode we capture the registration so tests can verify
|
| 316 |
+
# it happened, but never invoke it — there's no HA failover.
|
| 317 |
+
self._state_dict_fns[key] = (load_fn, save_fn)
|
| 318 |
+
|
| 319 |
+
# ---- Convenience ----------------------------------------------------
|
| 320 |
+
def is_leader(self) -> bool:
|
| 321 |
+
# Not strictly required by DiLoCo but referenced in some torchft
|
| 322 |
+
# integrations / our own code that may swap MockManager in.
|
| 323 |
+
return self.rank == 0
|
| 324 |
+
|
| 325 |
|
| 326 |
+
__all__ = ["MockManager", "ObjectStoreAllReduce", "_ImmediateWork"]
|
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""End-to-end MockManager × torchft.DiLoCo integration test.
|
| 2 |
+
|
| 3 |
+
Closes the Wave 13 cross-model adversarial-review gap (Suggestion 4):
|
| 4 |
+
the original MockManager was advertised as a drop-in for torchft.Manager
|
| 5 |
+
but only stubbed `.allreduce / .should_commit / .start_quorum`. DiLoCo's
|
| 6 |
+
real call surface (audited from `torchft/local_sgd.py` v2026-spring) also
|
| 7 |
+
includes `current_step()`, `disallow_state_dict_read()`,
|
| 8 |
+
`allow_state_dict_read()`, `register_state_dict_fn()`, and the
|
| 9 |
+
`_use_async_quorum` attribute — plus `allreduce()` must return a Work-like
|
| 10 |
+
object with `.wait()`, not a raw tensor.
|
| 11 |
+
|
| 12 |
+
This test runs ONE full DiLoCo outer round (sync_every inner steps + the
|
| 13 |
+
sync) against a tiny `nn.Linear(4, 4)` with `world_size=1` so the
|
| 14 |
+
object-store rendezvous is trivial. It verifies:
|
| 15 |
+
|
| 16 |
+
1. Construction does not raise.
|
| 17 |
+
2. Running through one full outer round does not raise AttributeError
|
| 18 |
+
(which is what the old MockManager would have hit at `current_step()`).
|
| 19 |
+
3. The model parameters change after the outer step fires (proving the
|
| 20 |
+
outer SGD path actually executed end-to-end, not just that the
|
| 21 |
+
inner-step hooks ran).
|
| 22 |
+
4. The MockManager's step counter advanced exactly once (one outer round
|
| 23 |
+
⇒ one start_quorum bump).
|
| 24 |
+
5. DiLoCo registered a state-dict fn per fragment.
|
| 25 |
+
"""
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import pytest
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
torchft = pytest.importorskip(
|
| 32 |
+
"torchft.local_sgd",
|
| 33 |
+
reason="torchft must be installed to run the DiLoCo integration test",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from composer_replication.diloco import make_diloco_outer_loop
|
| 37 |
+
from composer_replication.diloco.serverless.allreduce import (
|
| 38 |
+
MockManager,
|
| 39 |
+
ObjectStoreAllReduce,
|
| 40 |
+
_ImmediateWork,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _make_store(tmp_path) -> ObjectStoreAllReduce:
|
| 45 |
+
return ObjectStoreAllReduce(
|
| 46 |
+
uri=str(tmp_path),
|
| 47 |
+
rank=0,
|
| 48 |
+
world_size=1,
|
| 49 |
+
timeout_s=10.0,
|
| 50 |
+
poll_interval_s=0.05,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_mockmanager_has_full_diloco_call_surface(tmp_path):
|
| 55 |
+
"""Audited methods/attrs from torchft/local_sgd.py DiLoCo path must exist."""
|
| 56 |
+
mgr = MockManager(_make_store(tmp_path))
|
| 57 |
+
# Methods DiLoCo invokes
|
| 58 |
+
for attr in (
|
| 59 |
+
"allreduce",
|
| 60 |
+
"should_commit",
|
| 61 |
+
"start_quorum",
|
| 62 |
+
"current_step",
|
| 63 |
+
"disallow_state_dict_read",
|
| 64 |
+
"allow_state_dict_read",
|
| 65 |
+
"register_state_dict_fn",
|
| 66 |
+
"wait_quorum",
|
| 67 |
+
"is_leader",
|
| 68 |
+
):
|
| 69 |
+
assert callable(getattr(mgr, attr)), f"MockManager missing method: {attr}"
|
| 70 |
+
# Attributes DiLoCo reads at construction / runtime
|
| 71 |
+
assert hasattr(mgr, "_use_async_quorum")
|
| 72 |
+
assert mgr._use_async_quorum is False # DiLoCo.__init__ rejects True
|
| 73 |
+
assert hasattr(mgr, "num_participants")
|
| 74 |
+
assert hasattr(mgr, "rank")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_mockmanager_allreduce_returns_workshaped(tmp_path):
|
| 78 |
+
"""DiLoCo stores the allreduce return in a list and calls `.wait()` later."""
|
| 79 |
+
mgr = MockManager(_make_store(tmp_path))
|
| 80 |
+
work = mgr.allreduce(torch.zeros(2, 2))
|
| 81 |
+
# It must look like torch.distributed.Work / torchft._DummyWork
|
| 82 |
+
assert hasattr(work, "wait"), "allreduce return must have .wait() (DiLoCo calls it)"
|
| 83 |
+
assert callable(work.wait)
|
| 84 |
+
# No-op .wait() must not raise on a synchronous mock.
|
| 85 |
+
assert work.wait() is True
|
| 86 |
+
# Defensive: get_future() should also work (some torch paths probe it).
|
| 87 |
+
fut = work.get_future()
|
| 88 |
+
assert fut is None or hasattr(fut, "wait")
|
| 89 |
+
# Concrete type
|
| 90 |
+
assert isinstance(work, _ImmediateWork)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_mockmanager_diloco_outer_round_completes(tmp_path):
|
| 94 |
+
"""Run one full inner+outer DiLoCo round and verify params change.
|
| 95 |
+
|
| 96 |
+
With world_size=1 + MockManager → ObjectStoreAllReduce(file://), the
|
| 97 |
+
rendezvous is single-process, so this test runs synchronously. We
|
| 98 |
+
use `sync_every=4` and run exactly 4 inner-optimizer steps; at the
|
| 99 |
+
4th step DiLoCo's post-hook fires `prepare_sync` then `perform_sync`,
|
| 100 |
+
exercising the entire MockManager surface.
|
| 101 |
+
"""
|
| 102 |
+
torch.manual_seed(0)
|
| 103 |
+
model = torch.nn.Linear(4, 4, bias=False)
|
| 104 |
+
initial_params = model.weight.detach().clone()
|
| 105 |
+
|
| 106 |
+
inner_optim = torch.optim.SGD(model.parameters(), lr=0.1)
|
| 107 |
+
|
| 108 |
+
store = _make_store(tmp_path)
|
| 109 |
+
manager = MockManager(store)
|
| 110 |
+
|
| 111 |
+
diloco = make_diloco_outer_loop(
|
| 112 |
+
manager=manager,
|
| 113 |
+
model_fragments=[model],
|
| 114 |
+
inner_optimizer=inner_optim,
|
| 115 |
+
outer_lr=0.7,
|
| 116 |
+
outer_momentum=0.9,
|
| 117 |
+
nesterov=True,
|
| 118 |
+
sync_every=4,
|
| 119 |
+
fragment_sync_delay=0,
|
| 120 |
+
fragment_update_alpha=0.0,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Sanity: DiLoCo registered a state-dict fn for our single fragment.
|
| 124 |
+
assert len(manager._state_dict_fns) == 1, (
|
| 125 |
+
f"expected 1 fragment registration, got {list(manager._state_dict_fns)}"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
x = torch.randn(2, 4)
|
| 129 |
+
target = torch.randn(2, 4)
|
| 130 |
+
|
| 131 |
+
with diloco:
|
| 132 |
+
for _ in range(4): # exactly sync_every inner steps → one outer round
|
| 133 |
+
inner_optim.zero_grad()
|
| 134 |
+
loss = ((model(x) - target) ** 2).mean()
|
| 135 |
+
loss.backward()
|
| 136 |
+
# Must NOT raise AttributeError on current_step / state_dict_read /
|
| 137 |
+
# register_state_dict_fn / etc. The original MockManager would have
|
| 138 |
+
# crashed here on the very first step's _step_pre_hook calling
|
| 139 |
+
# disallow_state_dict_read.
|
| 140 |
+
inner_optim.step()
|
| 141 |
+
|
| 142 |
+
# After exactly one outer round, the MockManager's step counter
|
| 143 |
+
# should have advanced exactly once (start_quorum is called once).
|
| 144 |
+
assert manager.current_step() == 1, (
|
| 145 |
+
f"expected current_step()==1 after one outer round, got {manager.current_step()}"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# The outer SGD step actually fired ⇒ params differ from initial.
|
| 149 |
+
final_params = model.weight.detach().clone()
|
| 150 |
+
assert not torch.allclose(initial_params, final_params), (
|
| 151 |
+
"model params unchanged after outer round — outer optimizer never ran"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _diloco_replica_one_outer_round(
|
| 156 |
+
rendezvous_uri: str,
|
| 157 |
+
world_size: int,
|
| 158 |
+
sync_every: int,
|
| 159 |
+
) -> dict:
|
| 160 |
+
"""Top-level entry — must be importable for multiprocessing 'spawn'.
|
| 161 |
+
|
| 162 |
+
Each replica:
|
| 163 |
+
1. seeds torch with a SHARED seed for model init (DiLoCo's standard
|
| 164 |
+
assumption: all replicas start with identical weights — DiLoCo
|
| 165 |
+
only averages pseudo-gradients, not absolute weights, so divergent
|
| 166 |
+
inits would never reconcile).
|
| 167 |
+
2. builds nn.Linear(4, 4, bias=False) + SGD inner optimizer.
|
| 168 |
+
3. trains on RANK-SPECIFIC data so each replica's inner-trained
|
| 169 |
+
weights diverge during the inner loop (this is what gives the
|
| 170 |
+
pseudo-gradient real cross-rank variance — without it, the
|
| 171 |
+
averaging is observationally a no-op).
|
| 172 |
+
4. runs `sync_every` inner steps inside `make_diloco_outer_loop` —
|
| 173 |
+
this fires exactly one outer round.
|
| 174 |
+
5. returns the final flattened weight vector and the pre-outer
|
| 175 |
+
(purely-inner) weights.
|
| 176 |
+
|
| 177 |
+
The test then asserts both ranks' final weights are identical
|
| 178 |
+
(allclose), which proves the cross-replica allreduce of the
|
| 179 |
+
pseudo-gradient ran end-to-end. The pre-outer weights MUST differ
|
| 180 |
+
across ranks (proving rank-specific data drove divergence in the
|
| 181 |
+
inner loop) — otherwise the convergence assertion is vacuous.
|
| 182 |
+
"""
|
| 183 |
+
import os as _os
|
| 184 |
+
import torch as _torch
|
| 185 |
+
import torch.nn as _nn
|
| 186 |
+
|
| 187 |
+
from composer_replication.diloco import make_diloco_outer_loop
|
| 188 |
+
from composer_replication.diloco.serverless.allreduce import (
|
| 189 |
+
MockManager,
|
| 190 |
+
ObjectStoreAllReduce,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
rank = int(_os.environ["REPLICA_RANK"])
|
| 194 |
+
|
| 195 |
+
# SHARED init seed — both replicas start with identical weights, as
|
| 196 |
+
# DiLoCo assumes. (DiLoCo averages pseudo-gradients, not weights, so
|
| 197 |
+
# divergent inits would never reconcile and the convergence claim
|
| 198 |
+
# would be incorrect.)
|
| 199 |
+
_torch.manual_seed(0)
|
| 200 |
+
model = _nn.Linear(4, 4, bias=False)
|
| 201 |
+
initial = model.weight.detach().clone()
|
| 202 |
+
|
| 203 |
+
inner_optim = _torch.optim.SGD(model.parameters(), lr=0.1)
|
| 204 |
+
|
| 205 |
+
store = ObjectStoreAllReduce(
|
| 206 |
+
rendezvous_uri,
|
| 207 |
+
rank=rank,
|
| 208 |
+
world_size=world_size,
|
| 209 |
+
timeout_s=120.0,
|
| 210 |
+
poll_interval_s=0.05,
|
| 211 |
+
)
|
| 212 |
+
manager = MockManager(store)
|
| 213 |
+
|
| 214 |
+
diloco = make_diloco_outer_loop(
|
| 215 |
+
manager=manager,
|
| 216 |
+
model_fragments=[model],
|
| 217 |
+
inner_optimizer=inner_optim,
|
| 218 |
+
sync_every=sync_every,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# RANK-SPECIFIC data so the inner-trained weights diverge before the
|
| 222 |
+
# outer sync — this is what makes "post-sync convergence" a real
|
| 223 |
+
# property to verify rather than a tautology.
|
| 224 |
+
_torch.manual_seed(100 + rank)
|
| 225 |
+
x = _torch.randn(2, 4)
|
| 226 |
+
target = _torch.randn(2, 4)
|
| 227 |
+
|
| 228 |
+
with diloco:
|
| 229 |
+
for _ in range(sync_every):
|
| 230 |
+
inner_optim.zero_grad()
|
| 231 |
+
loss = ((model(x) - target) ** 2).mean()
|
| 232 |
+
loss.backward()
|
| 233 |
+
inner_optim.step()
|
| 234 |
+
|
| 235 |
+
final = model.weight.detach().clone()
|
| 236 |
+
return {
|
| 237 |
+
"rank": rank,
|
| 238 |
+
"initial": initial.flatten().tolist(),
|
| 239 |
+
"final": final.flatten().tolist(),
|
| 240 |
+
"current_step": manager.current_step(),
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def test_mockmanager_diloco_multi_process_weights_converge(tmp_path):
|
| 245 |
+
"""Wave 14 (Suggestion 4): cross-replica weight convergence after one outer round.
|
| 246 |
+
|
| 247 |
+
Spawns n_replicas=2 subprocesses with IDENTICAL initial weights
|
| 248 |
+
(DiLoCo's standard assumption — it averages pseudo-gradients, not
|
| 249 |
+
absolute weights) but RANK-SPECIFIC training data. After exactly
|
| 250 |
+
one DiLoCo outer round, both replicas must end with IDENTICAL
|
| 251 |
+
weights, because:
|
| 252 |
+
|
| 253 |
+
pseudo_grad_i = init - inner_trained_i # per-rank, differ
|
| 254 |
+
avg_pseudo = mean_i(pseudo_grad_i) # same on all ranks
|
| 255 |
+
final = init - outer_lr * avg_pseudo # same on all ranks
|
| 256 |
+
|
| 257 |
+
This catches averaging-direction bugs that the world_size=1
|
| 258 |
+
single-process test silently misses (a single-rank allreduce is a
|
| 259 |
+
no-op and can hide bugs in the multi-rank averaging arithmetic, the
|
| 260 |
+
file-staging round-id increment, or the weight redistribution after
|
| 261 |
+
the outer SGD step).
|
| 262 |
+
"""
|
| 263 |
+
import os as _os
|
| 264 |
+
import tempfile as _tempfile
|
| 265 |
+
|
| 266 |
+
from composer_replication.diloco.serverless import LocalProcessExecutor
|
| 267 |
+
|
| 268 |
+
n_replicas = 2
|
| 269 |
+
sync_every = 2
|
| 270 |
+
with _tempfile.TemporaryDirectory() as td:
|
| 271 |
+
rendezvous = _os.path.join(td, "diloco-multiproc-run")
|
| 272 |
+
executor = LocalProcessExecutor()
|
| 273 |
+
handles = executor.launch_replicas(
|
| 274 |
+
n_replicas=n_replicas,
|
| 275 |
+
entrypoint=f"{__name__}._diloco_replica_one_outer_round",
|
| 276 |
+
entrypoint_args={
|
| 277 |
+
"rendezvous_uri": rendezvous,
|
| 278 |
+
"world_size": n_replicas,
|
| 279 |
+
"sync_every": sync_every,
|
| 280 |
+
"rank_env": "REPLICA_RANK",
|
| 281 |
+
},
|
| 282 |
+
timeout=180,
|
| 283 |
+
)
|
| 284 |
+
results = executor.collect(handles, timeout=180)
|
| 285 |
+
|
| 286 |
+
# Diagnostic-friendly failure: surface per-rank error if any replica died.
|
| 287 |
+
statuses = {r["rank"]: r["status"] for r in results}
|
| 288 |
+
for rank in range(n_replicas):
|
| 289 |
+
assert statuses[rank] == "succeeded", (
|
| 290 |
+
f"rank {rank} failed: "
|
| 291 |
+
f"{next(r for r in results if r['rank'] == rank).get('error')}"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
payloads = sorted([r["result"] for r in results], key=lambda d: d["rank"])
|
| 295 |
+
rank0, rank1 = payloads[0], payloads[1]
|
| 296 |
+
|
| 297 |
+
# Sanity: each replica really did fire exactly one outer round.
|
| 298 |
+
assert rank0["current_step"] == 1, rank0
|
| 299 |
+
assert rank1["current_step"] == 1, rank1
|
| 300 |
+
|
| 301 |
+
# Sanity: replicas STARTED with identical weights (DiLoCo assumption).
|
| 302 |
+
assert rank0["initial"] == rank1["initial"], (
|
| 303 |
+
"replicas started with different initial weights — DiLoCo only "
|
| 304 |
+
"averages pseudo-gradients, not weights, so this would prevent "
|
| 305 |
+
"convergence even with a perfectly correct allreduce"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# The actual property: after one full outer round both replicas must
|
| 309 |
+
# have the SAME final weights. Tight tolerance because the only
|
| 310 |
+
# arithmetic between them is SGD + a single allreduce-mean.
|
| 311 |
+
final0 = torch.tensor(rank0["final"])
|
| 312 |
+
final1 = torch.tensor(rank1["final"])
|
| 313 |
+
if not torch.allclose(final0, final1, atol=1e-5, rtol=1e-5):
|
| 314 |
+
max_abs_diff = (final0 - final1).abs().max().item()
|
| 315 |
+
pytest.fail(
|
| 316 |
+
"Multi-process DiLoCo did NOT converge to identical weights "
|
| 317 |
+
"after one outer round.\n"
|
| 318 |
+
f" rank0 final = {final0.tolist()}\n"
|
| 319 |
+
f" rank1 final = {final1.tolist()}\n"
|
| 320 |
+
f" max|diff| = {max_abs_diff}\n"
|
| 321 |
+
"This indicates a real cross-replica-averaging bug "
|
| 322 |
+
"(averaging direction, round-id desync, or weight redistribution)."
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def test_mockmanager_diloco_two_outer_rounds_step_counter(tmp_path):
|
| 327 |
+
"""Two outer rounds must bump current_step() to 2 (fragment rotation safety)."""
|
| 328 |
+
torch.manual_seed(1)
|
| 329 |
+
model = torch.nn.Linear(4, 4, bias=False)
|
| 330 |
+
inner_optim = torch.optim.SGD(model.parameters(), lr=0.05)
|
| 331 |
+
|
| 332 |
+
manager = MockManager(_make_store(tmp_path))
|
| 333 |
+
|
| 334 |
+
diloco = make_diloco_outer_loop(
|
| 335 |
+
manager=manager,
|
| 336 |
+
model_fragments=[model],
|
| 337 |
+
inner_optimizer=inner_optim,
|
| 338 |
+
sync_every=2,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
x = torch.randn(2, 4)
|
| 342 |
+
target = torch.randn(2, 4)
|
| 343 |
+
|
| 344 |
+
with diloco:
|
| 345 |
+
for _ in range(4): # 2 outer rounds at sync_every=2
|
| 346 |
+
inner_optim.zero_grad()
|
| 347 |
+
(((model(x) - target) ** 2).mean()).backward()
|
| 348 |
+
inner_optim.step()
|
| 349 |
+
|
| 350 |
+
assert manager.current_step() == 2, (
|
| 351 |
+
f"expected current_step()==2 after two outer rounds, got {manager.current_step()}"
|
| 352 |
+
)
|
|
@@ -225,15 +225,26 @@ def test_mock_manager_shape_compat():
|
|
| 225 |
with tempfile.TemporaryDirectory() as td:
|
| 226 |
store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
|
| 227 |
mgr = MockManager(store)
|
| 228 |
-
# torchft.Manager surface
|
| 229 |
assert hasattr(mgr, "allreduce")
|
| 230 |
assert hasattr(mgr, "should_commit")
|
| 231 |
assert hasattr(mgr, "start_quorum")
|
| 232 |
assert hasattr(mgr, "wait_quorum")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
assert mgr.num_participants == 1
|
| 234 |
assert mgr.rank == 0
|
| 235 |
assert mgr.should_commit() is True
|
| 236 |
-
# Single-replica allreduce is a passthrough
|
|
|
|
|
|
|
| 237 |
t = torch.tensor([1.0, 2.0])
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
with tempfile.TemporaryDirectory() as td:
|
| 226 |
store = ObjectStoreAllReduce(td, rank=0, world_size=1, timeout_s=10.0)
|
| 227 |
mgr = MockManager(store)
|
| 228 |
+
# torchft.Manager surface (audited from torchft/local_sgd.py DiLoCo path)
|
| 229 |
assert hasattr(mgr, "allreduce")
|
| 230 |
assert hasattr(mgr, "should_commit")
|
| 231 |
assert hasattr(mgr, "start_quorum")
|
| 232 |
assert hasattr(mgr, "wait_quorum")
|
| 233 |
+
assert hasattr(mgr, "current_step")
|
| 234 |
+
assert hasattr(mgr, "disallow_state_dict_read")
|
| 235 |
+
assert hasattr(mgr, "allow_state_dict_read")
|
| 236 |
+
assert hasattr(mgr, "register_state_dict_fn")
|
| 237 |
+
assert hasattr(mgr, "_use_async_quorum")
|
| 238 |
+
assert mgr._use_async_quorum is False
|
| 239 |
assert mgr.num_participants == 1
|
| 240 |
assert mgr.rank == 0
|
| 241 |
assert mgr.should_commit() is True
|
| 242 |
+
# Single-replica allreduce: averaging is a passthrough, but the return
|
| 243 |
+
# must be a Work-shaped object (DiLoCo calls .wait() on it). The
|
| 244 |
+
# tensor itself is mutated in place by ObjectStoreAllReduce.
|
| 245 |
t = torch.tensor([1.0, 2.0])
|
| 246 |
+
buf = t.clone()
|
| 247 |
+
work = mgr.allreduce(buf)
|
| 248 |
+
assert hasattr(work, "wait") and callable(work.wait)
|
| 249 |
+
assert work.wait() is True
|
| 250 |
+
torch.testing.assert_close(buf, t, atol=1e-6, rtol=1e-6)
|
|
@@ -21,12 +21,27 @@ Channels:
|
|
| 21 |
- lm_ce: standard cross-entropy on assistant-response tokens (GRPO stub)
|
| 22 |
- sdpo_jsd: generalized JSD between student and hint-conditioned-teacher logits
|
| 23 |
- trace_replay_dpo: DPO loss over (chosen, rejected) teacher-disagreement pairs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"""
|
| 25 |
from __future__ import annotations
|
| 26 |
|
| 27 |
-
import sys
|
| 28 |
from dataclasses import dataclass
|
| 29 |
-
from
|
| 30 |
|
| 31 |
import torch
|
| 32 |
import torch.nn.functional as F
|
|
@@ -62,6 +77,20 @@ def compose_loss(
|
|
| 62 |
sdpo_token_clip: float | None = None,
|
| 63 |
replay_dpo_beta: float = 0.1,
|
| 64 |
lm_ce_label_smoothing: float = 0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
) -> LossComponents:
|
| 66 |
"""Compute total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo.
|
| 67 |
|
|
@@ -73,11 +102,40 @@ def compose_loss(
|
|
| 73 |
SDPO:
|
| 74 |
- ctx_teacher_input_ids: (B, T_t) hint-conditioned context
|
| 75 |
- sdpo_loss_mask: (B, T_t) 1 at error-turn tokens
|
| 76 |
-
DPO:
|
| 77 |
- dpo_chosen_input_ids, dpo_chosen_response_mask
|
| 78 |
- dpo_rejected_input_ids, dpo_rejected_response_mask
|
| 79 |
- dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs (precomputed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
device = _device_of(model)
|
| 82 |
|
| 83 |
# ------------------------------------------------------------------
|
|
@@ -92,6 +150,7 @@ def compose_loss(
|
|
| 92 |
|
| 93 |
# ------------------------------------------------------------------
|
| 94 |
# Channel 2 (SDPO): generalized JSD on hint-conditioned forward
|
|
|
|
| 95 |
# ------------------------------------------------------------------
|
| 96 |
sdpo_jsd = _zero(device)
|
| 97 |
if (
|
|
@@ -104,22 +163,58 @@ def compose_loss(
|
|
| 104 |
teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
|
| 105 |
|
| 106 |
if student_logits.shape == teacher_logits.shape:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
# else: silently zero — the data collator is responsible for shape
|
| 117 |
# alignment in production. For the smoke we accept misalignment and
|
| 118 |
# exercise the fallback path.
|
| 119 |
|
| 120 |
# ------------------------------------------------------------------
|
| 121 |
# Channel 3 (trace-replay DPO): standard DPO loss on teacher-disagreement
|
| 122 |
-
# pairs.
|
| 123 |
# ------------------------------------------------------------------
|
| 124 |
trace_replay_dpo = _zero(device)
|
| 125 |
if (
|
|
@@ -127,18 +222,42 @@ def compose_loss(
|
|
| 127 |
and "dpo_chosen_input_ids" in inputs
|
| 128 |
and inputs["dpo_chosen_input_ids"].numel() > 0
|
| 129 |
):
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo
|
| 144 |
|
|
@@ -208,4 +327,69 @@ def _sequence_logprobs(
|
|
| 208 |
return masked.sum(dim=-1)
|
| 209 |
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
__all__ = ["compose_loss", "LossComponents"]
|
|
|
|
| 21 |
- lm_ce: standard cross-entropy on assistant-response tokens (GRPO stub)
|
| 22 |
- sdpo_jsd: generalized JSD between student and hint-conditioned-teacher logits
|
| 23 |
- trace_replay_dpo: DPO loss over (chosen, rejected) teacher-disagreement pairs
|
| 24 |
+
|
| 25 |
+
ADR-007 extensions
|
| 26 |
+
------------------
|
| 27 |
+
Three pluggable distillation losses can swap the default DPO/SDPO channels:
|
| 28 |
+
|
| 29 |
+
- ``dpo_variant="simpo"`` — channel 3 uses SimPO (reference-free DPO with
|
| 30 |
+
margin) instead of standard DPO. Reference logprobs are no longer required.
|
| 31 |
+
- ``sdpo_wrapper="taid"`` — channel 2 wraps SDPO with TAID (Temporally
|
| 32 |
+
Adaptive Interpolated Distillation). Requires ``taid_schedule_step`` and
|
| 33 |
+
``taid_total_steps`` plus either ``inputs["student_init_logits"]`` or
|
| 34 |
+
``inputs["student_init_input_ids"]`` for the frozen-init forward pass.
|
| 35 |
+
- ``sdpo_wrapper="entropy_opd"`` — channel 2 uses Entropy-Aware OPD, a
|
| 36 |
+
per-token gated forward/reverse KL.
|
| 37 |
+
|
| 38 |
+
All three default to off; passing the new kwargs at their defaults is
|
| 39 |
+
bit-exact equivalent to the legacy 3-channel composition.
|
| 40 |
"""
|
| 41 |
from __future__ import annotations
|
| 42 |
|
|
|
|
| 43 |
from dataclasses import dataclass
|
| 44 |
+
from typing import Literal
|
| 45 |
|
| 46 |
import torch
|
| 47 |
import torch.nn.functional as F
|
|
|
|
| 77 |
sdpo_token_clip: float | None = None,
|
| 78 |
replay_dpo_beta: float = 0.1,
|
| 79 |
lm_ce_label_smoothing: float = 0.0,
|
| 80 |
+
# ADR-007 extensions ------------------------------------------------
|
| 81 |
+
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 82 |
+
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 83 |
+
taid_schedule_step: int | None = None,
|
| 84 |
+
taid_total_steps: int | None = None,
|
| 85 |
+
# SimPO knobs (only used when dpo_variant="simpo") ------------------
|
| 86 |
+
simpo_beta: float = 2.0,
|
| 87 |
+
simpo_gamma: float = 1.0,
|
| 88 |
+
# TAID knobs (only used when sdpo_wrapper="taid") -------------------
|
| 89 |
+
taid_schedule: str = "linear",
|
| 90 |
+
taid_alpha_min: float = 0.0,
|
| 91 |
+
taid_alpha_max: float = 1.0,
|
| 92 |
+
# Entropy-Aware OPD knobs (only used when sdpo_wrapper="entropy_opd")
|
| 93 |
+
entropy_opd_h_max: float | None = None,
|
| 94 |
) -> LossComponents:
|
| 95 |
"""Compute total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo.
|
| 96 |
|
|
|
|
| 102 |
SDPO:
|
| 103 |
- ctx_teacher_input_ids: (B, T_t) hint-conditioned context
|
| 104 |
- sdpo_loss_mask: (B, T_t) 1 at error-turn tokens
|
| 105 |
+
DPO (dpo_variant="dpo"):
|
| 106 |
- dpo_chosen_input_ids, dpo_chosen_response_mask
|
| 107 |
- dpo_rejected_input_ids, dpo_rejected_response_mask
|
| 108 |
- dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs (precomputed)
|
| 109 |
+
SimPO (dpo_variant="simpo"):
|
| 110 |
+
- dpo_chosen_input_ids, dpo_chosen_response_mask
|
| 111 |
+
- dpo_rejected_input_ids, dpo_rejected_response_mask
|
| 112 |
+
(reference logprobs not required and silently ignored)
|
| 113 |
+
TAID (sdpo_wrapper="taid"):
|
| 114 |
+
- student_init_logits: (B, T_t, V) precomputed frozen init logits, OR
|
| 115 |
+
- student_init_input_ids: (B, T_t) frozen student snapshot — a frozen
|
| 116 |
+
forward pass through `model` produces the init logits (this assumes
|
| 117 |
+
`model` has not yet drifted from init; production callers should
|
| 118 |
+
prefer the precomputed path with a saved init snapshot).
|
| 119 |
"""
|
| 120 |
+
if dpo_variant not in ("dpo", "simpo"):
|
| 121 |
+
raise ValueError(
|
| 122 |
+
f"dpo_variant must be 'dpo' or 'simpo', got {dpo_variant!r}"
|
| 123 |
+
)
|
| 124 |
+
if sdpo_wrapper not in ("none", "taid", "entropy_opd"):
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"sdpo_wrapper must be 'none', 'taid', or 'entropy_opd', "
|
| 127 |
+
f"got {sdpo_wrapper!r}"
|
| 128 |
+
)
|
| 129 |
+
if sdpo_wrapper == "taid":
|
| 130 |
+
if taid_schedule_step is None:
|
| 131 |
+
raise ValueError(
|
| 132 |
+
"sdpo_wrapper='taid' requires taid_schedule_step (int)"
|
| 133 |
+
)
|
| 134 |
+
if taid_total_steps is None:
|
| 135 |
+
raise ValueError(
|
| 136 |
+
"sdpo_wrapper='taid' requires taid_total_steps (int)"
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
device = _device_of(model)
|
| 140 |
|
| 141 |
# ------------------------------------------------------------------
|
|
|
|
| 150 |
|
| 151 |
# ------------------------------------------------------------------
|
| 152 |
# Channel 2 (SDPO): generalized JSD on hint-conditioned forward
|
| 153 |
+
# Optionally wrapped by TAID or replaced by Entropy-Aware OPD.
|
| 154 |
# ------------------------------------------------------------------
|
| 155 |
sdpo_jsd = _zero(device)
|
| 156 |
if (
|
|
|
|
| 163 |
teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
|
| 164 |
|
| 165 |
if student_logits.shape == teacher_logits.shape:
|
| 166 |
+
if sdpo_wrapper == "none":
|
| 167 |
+
sdpo_jsd = generalized_jsd_loss(
|
| 168 |
+
student_logits=student_logits,
|
| 169 |
+
teacher_logits=teacher_logits,
|
| 170 |
+
labels=inputs.get("sdpo_loss_mask"),
|
| 171 |
+
beta=sdpo_jsd_beta,
|
| 172 |
+
temperature=sdpo_temperature,
|
| 173 |
+
token_clip=sdpo_token_clip,
|
| 174 |
+
reduction="batchmean",
|
| 175 |
+
)
|
| 176 |
+
elif sdpo_wrapper == "taid":
|
| 177 |
+
from composer_replication.distillation import taid_loss
|
| 178 |
+
|
| 179 |
+
student_init_logits = _resolve_student_init_logits(
|
| 180 |
+
model, inputs, expected_shape=teacher_logits.shape
|
| 181 |
+
)
|
| 182 |
+
# taid_schedule_step / taid_total_steps validated non-None above.
|
| 183 |
+
assert taid_schedule_step is not None
|
| 184 |
+
assert taid_total_steps is not None
|
| 185 |
+
sdpo_jsd = taid_loss(
|
| 186 |
+
student_logits=student_logits,
|
| 187 |
+
teacher_logits=teacher_logits,
|
| 188 |
+
student_init_logits=student_init_logits,
|
| 189 |
+
schedule_step=int(taid_schedule_step),
|
| 190 |
+
total_steps=int(taid_total_steps),
|
| 191 |
+
schedule=taid_schedule,
|
| 192 |
+
alpha_min=taid_alpha_min,
|
| 193 |
+
alpha_max=taid_alpha_max,
|
| 194 |
+
jsd_beta=sdpo_jsd_beta,
|
| 195 |
+
temperature=sdpo_temperature,
|
| 196 |
+
reduction="batchmean",
|
| 197 |
+
)
|
| 198 |
+
elif sdpo_wrapper == "entropy_opd":
|
| 199 |
+
from composer_replication.distillation import (
|
| 200 |
+
entropy_aware_opd_loss,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
sdpo_jsd = entropy_aware_opd_loss(
|
| 204 |
+
student_logits=student_logits,
|
| 205 |
+
teacher_logits=teacher_logits,
|
| 206 |
+
labels=inputs.get("sdpo_loss_mask"),
|
| 207 |
+
h_max=entropy_opd_h_max,
|
| 208 |
+
temperature=sdpo_temperature,
|
| 209 |
+
reduction="batchmean",
|
| 210 |
+
)
|
| 211 |
# else: silently zero — the data collator is responsible for shape
|
| 212 |
# alignment in production. For the smoke we accept misalignment and
|
| 213 |
# exercise the fallback path.
|
| 214 |
|
| 215 |
# ------------------------------------------------------------------
|
| 216 |
# Channel 3 (trace-replay DPO): standard DPO loss on teacher-disagreement
|
| 217 |
+
# pairs. With dpo_variant="simpo", swap to SimPO (reference-free).
|
| 218 |
# ------------------------------------------------------------------
|
| 219 |
trace_replay_dpo = _zero(device)
|
| 220 |
if (
|
|
|
|
| 222 |
and "dpo_chosen_input_ids" in inputs
|
| 223 |
and inputs["dpo_chosen_input_ids"].numel() > 0
|
| 224 |
):
|
| 225 |
+
if dpo_variant == "dpo":
|
| 226 |
+
chosen_lp = _sequence_logprobs(
|
| 227 |
+
model,
|
| 228 |
+
inputs["dpo_chosen_input_ids"],
|
| 229 |
+
inputs["dpo_chosen_response_mask"],
|
| 230 |
+
)
|
| 231 |
+
rejected_lp = _sequence_logprobs(
|
| 232 |
+
model,
|
| 233 |
+
inputs["dpo_rejected_input_ids"],
|
| 234 |
+
inputs["dpo_rejected_response_mask"],
|
| 235 |
+
)
|
| 236 |
+
ref_chosen = inputs["dpo_chosen_ref_logprobs"]
|
| 237 |
+
ref_rejected = inputs["dpo_rejected_ref_logprobs"]
|
| 238 |
+
dpo_logits = replay_dpo_beta * (
|
| 239 |
+
(chosen_lp - ref_chosen) - (rejected_lp - ref_rejected)
|
| 240 |
+
)
|
| 241 |
+
trace_replay_dpo = -F.logsigmoid(dpo_logits).mean()
|
| 242 |
+
else: # dpo_variant == "simpo"
|
| 243 |
+
from composer_replication.distillation import simpo_loss
|
| 244 |
+
|
| 245 |
+
chosen_avg_lp = _avg_sequence_logprobs(
|
| 246 |
+
model,
|
| 247 |
+
inputs["dpo_chosen_input_ids"],
|
| 248 |
+
inputs["dpo_chosen_response_mask"],
|
| 249 |
+
)
|
| 250 |
+
rejected_avg_lp = _avg_sequence_logprobs(
|
| 251 |
+
model,
|
| 252 |
+
inputs["dpo_rejected_input_ids"],
|
| 253 |
+
inputs["dpo_rejected_response_mask"],
|
| 254 |
+
)
|
| 255 |
+
trace_replay_dpo = simpo_loss(
|
| 256 |
+
chosen_avg_logprobs=chosen_avg_lp,
|
| 257 |
+
rejected_avg_logprobs=rejected_avg_lp,
|
| 258 |
+
beta=simpo_beta,
|
| 259 |
+
gamma=simpo_gamma,
|
| 260 |
+
)
|
| 261 |
|
| 262 |
total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo
|
| 263 |
|
|
|
|
| 327 |
return masked.sum(dim=-1)
|
| 328 |
|
| 329 |
|
| 330 |
+
def _avg_sequence_logprobs(
|
| 331 |
+
model: torch.nn.Module,
|
| 332 |
+
input_ids: torch.Tensor,
|
| 333 |
+
response_mask: torch.Tensor,
|
| 334 |
+
) -> torch.Tensor:
|
| 335 |
+
"""Per-sequence AVERAGE next-token logprob over response tokens.
|
| 336 |
+
|
| 337 |
+
SimPO accounting: divide the sum by the number of response tokens so
|
| 338 |
+
long sequences aren't penalized for length.
|
| 339 |
+
"""
|
| 340 |
+
outputs = model(input_ids=input_ids)
|
| 341 |
+
logits = outputs.logits[:, :-1, :]
|
| 342 |
+
targets = input_ids[:, 1:]
|
| 343 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 344 |
+
token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
|
| 345 |
+
mask = response_mask[:, 1:].float()
|
| 346 |
+
masked = token_lp * mask
|
| 347 |
+
n_tokens = mask.sum(dim=-1).clamp_min(1.0)
|
| 348 |
+
return masked.sum(dim=-1) / n_tokens
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def _resolve_student_init_logits(
|
| 352 |
+
model: torch.nn.Module,
|
| 353 |
+
inputs: dict[str, torch.Tensor],
|
| 354 |
+
*,
|
| 355 |
+
expected_shape: torch.Size,
|
| 356 |
+
) -> torch.Tensor:
|
| 357 |
+
"""Return frozen student-init logits for TAID.
|
| 358 |
+
|
| 359 |
+
Preferred path: caller pre-saves a snapshot at training step 0 and passes
|
| 360 |
+
it via ``inputs['student_init_logits']``. Fallback path (only valid early
|
| 361 |
+
in training before the model has drifted): pass
|
| 362 |
+
``inputs['student_init_input_ids']`` and we run a no-grad forward through
|
| 363 |
+
``model``. Always returns a tensor on the same device as ``model``.
|
| 364 |
+
"""
|
| 365 |
+
if "student_init_logits" in inputs and inputs["student_init_logits"].numel() > 0:
|
| 366 |
+
student_init = inputs["student_init_logits"]
|
| 367 |
+
if student_init.shape != expected_shape:
|
| 368 |
+
raise ValueError(
|
| 369 |
+
f"inputs['student_init_logits'] shape {tuple(student_init.shape)} "
|
| 370 |
+
f"does not match teacher logits shape {tuple(expected_shape)}"
|
| 371 |
+
)
|
| 372 |
+
return student_init.detach()
|
| 373 |
+
|
| 374 |
+
if (
|
| 375 |
+
"student_init_input_ids" in inputs
|
| 376 |
+
and inputs["student_init_input_ids"].numel() > 0
|
| 377 |
+
):
|
| 378 |
+
with torch.no_grad():
|
| 379 |
+
init_logits = model(input_ids=inputs["student_init_input_ids"]).logits
|
| 380 |
+
if init_logits.shape != expected_shape:
|
| 381 |
+
raise ValueError(
|
| 382 |
+
f"frozen forward on student_init_input_ids gave shape "
|
| 383 |
+
f"{tuple(init_logits.shape)} which does not match teacher "
|
| 384 |
+
f"logits shape {tuple(expected_shape)}"
|
| 385 |
+
)
|
| 386 |
+
return init_logits
|
| 387 |
+
|
| 388 |
+
raise ValueError(
|
| 389 |
+
"sdpo_wrapper='taid' requires either inputs['student_init_logits'] "
|
| 390 |
+
"(precomputed) or inputs['student_init_input_ids'] (frozen forward "
|
| 391 |
+
"fallback) to be present."
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
__all__ = ["compose_loss", "LossComponents"]
|
|
@@ -1,22 +1,90 @@
|
|
| 1 |
-
"""PRIME-RL composer loss adapter
|
| 2 |
|
| 3 |
-
Per ADR-006, PRIME-RL exposes a `CustomLossConfig` that takes an
|
| 4 |
importable function. This module supplies that function: a thin adapter
|
| 5 |
-
that maps PRIME-RL's `LossInputs` struct onto the framework's 3-channel
|
| 6 |
loss composition.
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
"""
|
| 21 |
from __future__ import annotations
|
| 22 |
|
|
@@ -26,79 +94,151 @@ from typing import Any
|
|
| 26 |
def loss_fn(
|
| 27 |
inputs: Any, # PRIME-RL's LossInputs — typed as Any to avoid hard import
|
| 28 |
*,
|
| 29 |
-
alpha_sdpo: float = 0.
|
| 30 |
-
beta_dpo: float = 0.
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
- Channel 1 reads from `advantages` + `trainer_logprobs` directly.
|
| 42 |
-
(Note: this is REINFORCE-with-advantage, not full GRPO. Full
|
| 43 |
-
GRPO would use `inference_logprobs` for the importance-sampling
|
| 44 |
-
ratio + PPO clipping. See Wave 13 review Finding 6.)
|
| 45 |
-
- Channel 2 (SDPO) is **DEFERRED** for v0 because PRIME-RL v0.5
|
| 46 |
-
exposes log-probs not logits, and SDPO needs the full vocab
|
| 47 |
-
distribution. Setting alpha_sdpo>0 raises NotImplementedError
|
| 48 |
-
(Wave 13 review Finding 1 — earlier draft was silently degenerate).
|
| 49 |
-
- Channel 3 (DPO) is OUT OF SCOPE for the PRIME-RL recipe in v0
|
| 50 |
-
— it would require modifying PRIME-RL's data path to pass
|
| 51 |
-
`(chosen, rejected)` pairs alongside the rollout, which is a
|
| 52 |
-
separate integration effort. v0 emits beta_dpo=0 with a
|
| 53 |
-
warning if non-zero.
|
| 54 |
|
| 55 |
Args:
|
| 56 |
-
inputs: PRIME-RL `LossInputs` (duck-typed)
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
Returns:
|
| 62 |
-
Scalar torch.Tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
"""
|
| 64 |
-
import torch # lazy
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
# Channel 1: GRPO
|
| 68 |
advantages = inputs.advantages
|
| 69 |
trainer_lp = inputs.trainer_logprobs
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
#
|
| 79 |
-
# Wave 13 cross-model review
|
| 80 |
-
#
|
| 81 |
-
#
|
| 82 |
-
#
|
| 83 |
-
#
|
| 84 |
-
#
|
| 85 |
-
#
|
| 86 |
-
#
|
| 87 |
-
# The right path forward depends on PRIME-RL exposing full logits, not
|
| 88 |
-
# just log-probs. Until that lands upstream, refuse to fake the channel:
|
| 89 |
teacher_lp = getattr(inputs, "teacher_logprobs", None)
|
| 90 |
-
if
|
| 91 |
raise NotImplementedError(
|
| 92 |
-
"SDPO channel in the PRIME-RL recipe is deferred. PRIME-RL
|
| 93 |
-
"exposes (
|
| 94 |
-
"and SDPO/OPSD requires the full
|
| 95 |
-
"Set alpha_sdpo=0.0 to silence this and use
|
| 96 |
-
"
|
|
|
|
|
|
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
-
# Channel 3: not supported in PRIME-RL recipe v0
|
| 100 |
if beta_dpo != 0.0:
|
| 101 |
import warnings
|
|
|
|
| 102 |
warnings.warn(
|
| 103 |
"PRIME-RL recipe v0 does not support DPO channel; "
|
| 104 |
"set beta_dpo=0.0 to silence this warning.",
|
|
|
|
| 1 |
+
"""PRIME-RL composer loss adapter.
|
| 2 |
|
| 3 |
+
Per ADR-006, PRIME-RL exposes a ``CustomLossConfig`` that takes an
|
| 4 |
importable function. This module supplies that function: a thin adapter
|
| 5 |
+
that maps PRIME-RL's ``LossInputs`` struct onto the framework's 3-channel
|
| 6 |
loss composition.
|
| 7 |
|
| 8 |
+
Channel status (v0):
|
| 9 |
+
1. **DPPO + KL on the importance-sampling ratio** — implemented to
|
| 10 |
+
match PRIME-RL's upstream ``default_loss_fn`` byte-for-byte.
|
| 11 |
+
2. **SDPO / OPSD** — deferred (raises ``NotImplementedError`` when
|
| 12 |
+
enabled). PRIME-RL v0.5 exposes log-probs, not full logits, and
|
| 13 |
+
SDPO requires the full vocabulary distribution.
|
| 14 |
+
3. **Trace-replay DPO** — out of scope for this recipe; emits a
|
| 15 |
+
warning if ``beta_dpo != 0``.
|
| 16 |
+
|
| 17 |
+
LossInputs shape (verified against PrimeIntellect-ai/prime-rl
|
| 18 |
+
``src/prime_rl/trainer/rl/loss.py`` lines 13-22):
|
| 19 |
+
|
| 20 |
+
.. code-block:: python
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class LossInputs:
|
| 24 |
+
trainer_logprobs: Float[Tensor, ' seq'] # current policy log-probs
|
| 25 |
+
inference_logprobs: Float[Tensor, ' seq'] # rollout-time policy log-probs
|
| 26 |
+
teacher_logprobs: Float[Tensor, ' seq'] | None
|
| 27 |
+
advantages: Float[Tensor, ' seq'] # per-token advantage
|
| 28 |
+
loss_mask: Bool[Tensor, ' seq'] # which tokens count
|
| 29 |
+
|
| 30 |
+
PRIME-RL calls the loss function once per sample, not on a batched
|
| 31 |
+
``(B, T)`` tensor.
|
| 32 |
+
|
| 33 |
+
PRIME-RL's ``default_loss_fn`` (upstream)
|
| 34 |
+
-----------------------------------------
|
| 35 |
+
Verbatim from ``prime_rl/trainer/rl/loss.py`` lines 116-165 and the
|
| 36 |
+
``DefaultLossConfig`` defaults at
|
| 37 |
+
``packages/prime-rl-configs/src/prime_rl/configs/trainer.py`` lines
|
| 38 |
+
412-425::
|
| 39 |
+
|
| 40 |
+
def default_loss_fn(inputs, loss_config):
|
| 41 |
+
# line 133-135
|
| 42 |
+
log_importance_ratio = trainer_logprobs - inference_logprobs
|
| 43 |
+
importance_ratio = exp(log_importance_ratio)
|
| 44 |
+
mismatch_kl = importance_ratio - log_importance_ratio - 1
|
| 45 |
+
|
| 46 |
+
# line 137: NOTE — probability-space diff, not log-ratio
|
| 47 |
+
probs_diff = exp(trainer_logprobs) - exp(inference_logprobs)
|
| 48 |
+
# lines 138-139
|
| 49 |
+
dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high
|
| 50 |
+
dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low
|
| 51 |
+
# lines 140-142: sign-of-advantage gate
|
| 52 |
+
positive_advantages = advantages > 0
|
| 53 |
+
dppo_invalid_mask = where(positive_advantages,
|
| 54 |
+
dppo_invalid_mask_high,
|
| 55 |
+
dppo_invalid_mask_low)
|
| 56 |
+
# lines 147-148
|
| 57 |
+
drop_mask = loss_mask & dppo_invalid_mask
|
| 58 |
+
keep_mask = loss_mask & ~dppo_invalid_mask
|
| 59 |
+
|
| 60 |
+
# lines 150-153
|
| 61 |
+
advantages = loss_config.adv_tau * advantages
|
| 62 |
+
pg_loss = keep_mask * advantages * importance_ratio
|
| 63 |
+
kl_loss = loss_mask * log_importance_ratio**2
|
| 64 |
+
loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum()
|
| 65 |
+
|
| 66 |
+
Defaults: ``dppo_mask_low=0.2``, ``dppo_mask_high=0.2``,
|
| 67 |
+
``adv_tau=1.0``, ``kl_tau=1e-3`` — all ``Field(..., ge=0)``.
|
| 68 |
+
|
| 69 |
+
Three things this differs from a textbook PPO-clip:
|
| 70 |
+
1. The mask gate is on **probability-space** ``probs_diff``, not on
|
| 71 |
+
the log-ratio. ``-loss_config.dppo_mask_low`` flips the sign so
|
| 72 |
+
``dppo_mask_low`` is itself non-negative.
|
| 73 |
+
2. The policy-gradient term is multiplied by ``importance_ratio``
|
| 74 |
+
(= ``exp(trainer_lp - inference_lp)``), giving a proper IS-corrected
|
| 75 |
+
gradient — not a plain REINFORCE on ``trainer_lp``.
|
| 76 |
+
3. The mask is **conditioned on advantage sign**: a positive-advantage
|
| 77 |
+
token is dropped when ``probs_diff`` exceeds ``dppo_mask_high``
|
| 78 |
+
(we'd be upweighting it too aggressively); a negative-advantage
|
| 79 |
+
token is dropped when ``probs_diff`` falls below ``-dppo_mask_low``
|
| 80 |
+
(we'd be downweighting it too aggressively). Zero-advantage tokens
|
| 81 |
+
are never DPPO-masked.
|
| 82 |
+
|
| 83 |
+
The reduction is a plain ``sum()`` (PRIME-RL's outer ``compute_loss``
|
| 84 |
+
divides by ``loss_scale``); we mirror that.
|
| 85 |
+
|
| 86 |
+
License: MIT (matches the rest of the framework). PRIME-RL is Apache-2;
|
| 87 |
+
we reference its algorithm and convention but vendor no code.
|
| 88 |
"""
|
| 89 |
from __future__ import annotations
|
| 90 |
|
|
|
|
| 94 |
def loss_fn(
|
| 95 |
inputs: Any, # PRIME-RL's LossInputs — typed as Any to avoid hard import
|
| 96 |
*,
|
| 97 |
+
alpha_sdpo: float = 0.0,
|
| 98 |
+
beta_dpo: float = 0.0,
|
| 99 |
+
dppo_mask_high: float = 0.2,
|
| 100 |
+
dppo_mask_low: float = 0.2,
|
| 101 |
+
adv_tau: float = 1.0,
|
| 102 |
+
kl_tau: float = 1e-3,
|
| 103 |
+
) -> Any: # Returns a torch.Tensor (scalar) matching PRIME-RL's contract
|
| 104 |
+
"""Composer 3-channel loss adapted to PRIME-RL's ``LossInputs`` struct.
|
| 105 |
+
|
| 106 |
+
Channel 1 mirrors PRIME-RL's ``default_loss_fn`` exactly so configs
|
| 107 |
+
from PRIME-RL's own examples translate. Channels 2 and 3 are
|
| 108 |
+
deferred — see module docstring.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
Args:
|
| 111 |
+
inputs: PRIME-RL ``LossInputs`` (duck-typed). All tensor fields
|
| 112 |
+
are expected to be 1-D with shape ``(seq,)``.
|
| 113 |
+
alpha_sdpo: weight on channel 2 (SDPO). Must be 0 in v0; >0
|
| 114 |
+
raises :class:`NotImplementedError`.
|
| 115 |
+
beta_dpo: weight on channel 3 (DPO). Non-zero emits a warning;
|
| 116 |
+
channel 3 is not yet wired in this recipe.
|
| 117 |
+
dppo_mask_high: upper DPPO masking threshold on
|
| 118 |
+
``exp(trainer_lp) - exp(inference_lp)``. Tokens with
|
| 119 |
+
**positive advantage** whose ``probs_diff`` exceeds this
|
| 120 |
+
value are dropped. PRIME-RL default: ``0.2``. Must be >= 0.
|
| 121 |
+
dppo_mask_low: magnitude of the lower DPPO masking threshold.
|
| 122 |
+
Tokens with **negative advantage** whose ``probs_diff`` is
|
| 123 |
+
below ``-dppo_mask_low`` are dropped. PRIME-RL default:
|
| 124 |
+
``0.2``. Must be >= 0 (note: PRIME-RL stores the magnitude;
|
| 125 |
+
the sign flip is internal to the comparison).
|
| 126 |
+
adv_tau: temperature on the advantage term. PRIME-RL default
|
| 127 |
+
``1.0``. Must be >= 0.
|
| 128 |
+
kl_tau: temperature on the KL term ``log_importance_ratio**2``.
|
| 129 |
+
PRIME-RL default ``1e-3``. Must be >= 0.
|
| 130 |
|
| 131 |
Returns:
|
| 132 |
+
Scalar ``torch.Tensor``. PRIME-RL's outer ``compute_loss``
|
| 133 |
+
divides by ``loss_scale`` and calls ``.backward()``.
|
| 134 |
+
|
| 135 |
+
Raises:
|
| 136 |
+
ValueError: if any of ``trainer_logprobs``, ``inference_logprobs``,
|
| 137 |
+
``advantages``, ``loss_mask`` is not 1-D, or any of
|
| 138 |
+
``dppo_mask_high``, ``dppo_mask_low``, ``adv_tau``, ``kl_tau``
|
| 139 |
+
is negative.
|
| 140 |
+
NotImplementedError: if ``alpha_sdpo > 0`` (channel 2 is deferred).
|
| 141 |
"""
|
| 142 |
+
import torch # lazy — keep module importable without torch installed
|
| 143 |
+
|
| 144 |
+
# PRIME-RL enforces these via Pydantic Field(..., ge=0); we mirror it.
|
| 145 |
+
for name, val in (
|
| 146 |
+
("dppo_mask_high", dppo_mask_high),
|
| 147 |
+
("dppo_mask_low", dppo_mask_low),
|
| 148 |
+
("adv_tau", adv_tau),
|
| 149 |
+
("kl_tau", kl_tau),
|
| 150 |
+
):
|
| 151 |
+
if val < 0:
|
| 152 |
+
raise ValueError(
|
| 153 |
+
f"{name} must be >= 0 (PRIME-RL config contract); got {val}"
|
| 154 |
+
)
|
| 155 |
|
|
|
|
| 156 |
advantages = inputs.advantages
|
| 157 |
trainer_lp = inputs.trainer_logprobs
|
| 158 |
+
inference_lp = inputs.inference_logprobs
|
| 159 |
+
loss_mask = inputs.loss_mask
|
| 160 |
+
|
| 161 |
+
# --- Shape validation -------------------------------------------------
|
| 162 |
+
# PRIME-RL passes per-sample (seq,) tensors. Reject (B, T) explicitly so
|
| 163 |
+
# callers don't silently get the wrong reduction.
|
| 164 |
+
for name, t in (
|
| 165 |
+
("trainer_logprobs", trainer_lp),
|
| 166 |
+
("inference_logprobs", inference_lp),
|
| 167 |
+
("advantages", advantages),
|
| 168 |
+
("loss_mask", loss_mask),
|
| 169 |
+
):
|
| 170 |
+
if t.dim() != 1:
|
| 171 |
+
raise ValueError(
|
| 172 |
+
f"PRIME-RL loss_fn expects 1-D (seq,) tensors per "
|
| 173 |
+
f"PRIME-RL's LossInputs contract; got {name} with shape "
|
| 174 |
+
f"{tuple(t.shape)} (dim={t.dim()}). PRIME-RL calls the loss "
|
| 175 |
+
f"function once per sample, not on a batched (B, T) tensor."
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# --- Channel 1: DPPO + KL on the importance ratio --------------------
|
| 179 |
+
# Mirrors prime_rl/trainer/rl/loss.py default_loss_fn lines 133-153.
|
| 180 |
+
|
| 181 |
+
log_importance_ratio = trainer_lp - inference_lp
|
| 182 |
+
importance_ratio = torch.exp(log_importance_ratio)
|
| 183 |
+
|
| 184 |
+
# NOTE: probability-space diff, NOT log-ratio. This is the key
|
| 185 |
+
# divergence from a naive PPO-clip implementation.
|
| 186 |
+
probs_diff = torch.exp(trainer_lp) - torch.exp(inference_lp)
|
| 187 |
+
|
| 188 |
+
dppo_invalid_mask_high = probs_diff > dppo_mask_high
|
| 189 |
+
dppo_invalid_mask_low = probs_diff < -dppo_mask_low
|
| 190 |
+
|
| 191 |
+
positive_advantages = advantages > 0
|
| 192 |
+
# Sign-of-advantage gate: positive-advantage tokens use the "high"
|
| 193 |
+
# threshold; negative-advantage tokens use the "low" threshold.
|
| 194 |
+
# Zero-advantage tokens fall through ``positive_advantages == False``,
|
| 195 |
+
# so they are gated by the (negative-advantage) low check; in practice
|
| 196 |
+
# zero-advantage tokens contribute zero to ``pg_loss`` regardless.
|
| 197 |
+
dppo_invalid_mask = torch.where(
|
| 198 |
+
positive_advantages, dppo_invalid_mask_high, dppo_invalid_mask_low
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# loss_mask may be bool; combine via boolean ops to match upstream
|
| 202 |
+
# exactly, then cast to the working dtype for the multiply.
|
| 203 |
+
if loss_mask.dtype != torch.bool:
|
| 204 |
+
loss_mask_bool = loss_mask.to(torch.bool)
|
| 205 |
+
else:
|
| 206 |
+
loss_mask_bool = loss_mask
|
| 207 |
+
keep_mask_bool = loss_mask_bool & ~dppo_invalid_mask
|
| 208 |
+
keep_mask = keep_mask_bool.to(trainer_lp.dtype)
|
| 209 |
+
loss_mask_f = loss_mask_bool.to(trainer_lp.dtype)
|
| 210 |
+
|
| 211 |
+
scaled_advantages = adv_tau * advantages
|
| 212 |
+
pg_loss = keep_mask * scaled_advantages * importance_ratio
|
| 213 |
+
kl_loss = loss_mask_f * log_importance_ratio**2
|
| 214 |
+
total = (-pg_loss + kl_tau * kl_loss).sum()
|
| 215 |
+
|
| 216 |
+
# --- Channel 2: SDPO/OPSD — DEFERRED in PRIME-RL recipe v0 -----------
|
| 217 |
#
|
| 218 |
+
# Wave 13 cross-model review caught that an earlier draft applied
|
| 219 |
+
# `unsqueeze(-1)` to log-prob tensors before generalized_jsd_loss,
|
| 220 |
+
# which calls log_softmax(dim=-1). Softmax of a 1-element vector is
|
| 221 |
+
# exactly 1.0; its log is 0. The SDPO term was mathematically
|
| 222 |
+
# degenerate (always 0), silently disabling channel 2 while reporting
|
| 223 |
+
# alpha_sdpo>0 in the config. Until PRIME-RL exposes full logits we
|
| 224 |
+
# refuse to fake the channel:
|
|
|
|
|
|
|
|
|
|
| 225 |
teacher_lp = getattr(inputs, "teacher_logprobs", None)
|
| 226 |
+
if alpha_sdpo > 0:
|
| 227 |
raise NotImplementedError(
|
| 228 |
+
"SDPO channel in the PRIME-RL recipe is deferred. PRIME-RL "
|
| 229 |
+
"v0.5 exposes (seq,) log-probs through LossInputs but not "
|
| 230 |
+
"full vocabulary logits, and SDPO/OPSD requires the full "
|
| 231 |
+
"distribution. Set alpha_sdpo=0.0 to silence this and use "
|
| 232 |
+
"channel 1 (DPPO+KL) only. teacher_logprobs is "
|
| 233 |
+
f"{'present' if teacher_lp is not None else 'absent'} in this "
|
| 234 |
+
"call but unused. See docs/research/WAVE_13_FINAL_REVIEW.md "
|
| 235 |
+
"Finding 1."
|
| 236 |
)
|
| 237 |
|
| 238 |
+
# --- Channel 3: not supported in PRIME-RL recipe v0 -------------------
|
| 239 |
if beta_dpo != 0.0:
|
| 240 |
import warnings
|
| 241 |
+
|
| 242 |
warnings.warn(
|
| 243 |
"PRIME-RL recipe v0 does not support DPO channel; "
|
| 244 |
"set beta_dpo=0.0 to silence this warning.",
|
|
@@ -25,9 +25,23 @@ loss:
|
|
| 25 |
# The function MUST return a scalar tensor (PRIME-RL handles backward).
|
| 26 |
import_path: "composer_replication.recipes.prime_rl.composer_loss:loss_fn"
|
| 27 |
kwargs:
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# --- PRIME-RL three-actor split --------------------------------------
|
| 33 |
trainer:
|
|
|
|
| 25 |
# The function MUST return a scalar tensor (PRIME-RL handles backward).
|
| 26 |
import_path: "composer_replication.recipes.prime_rl.composer_loss:loss_fn"
|
| 27 |
kwargs:
|
| 28 |
+
# Channel 2 (SDPO/OPSD) is deferred in v0 — set >0 to fail fast
|
| 29 |
+
# rather than silently no-op until PRIME-RL exposes full logits.
|
| 30 |
+
alpha_sdpo: 0.0
|
| 31 |
+
# Channel 3 (DPO) is out-of-scope for PRIME-RL recipe v0.
|
| 32 |
+
beta_dpo: 0.0
|
| 33 |
+
# DPPO mask thresholds (PRIME-RL convention, NOT textbook PPO).
|
| 34 |
+
# Tokens whose probability-space diff
|
| 35 |
+
# exp(trainer_lp) - exp(inference_lp)
|
| 36 |
+
# exceeds dppo_mask_high (for positive-advantage tokens) or falls
|
| 37 |
+
# below -dppo_mask_low (for negative-advantage tokens) are dropped
|
| 38 |
+
# from the policy-gradient term. Defaults match PRIME-RL's
|
| 39 |
+
# DefaultLossConfig (Field(..., ge=0), so both must be non-negative).
|
| 40 |
+
dppo_mask_high: 0.2
|
| 41 |
+
dppo_mask_low: 0.2
|
| 42 |
+
# Advantage / KL temperatures from PRIME-RL DefaultLossConfig.
|
| 43 |
+
adv_tau: 1.0
|
| 44 |
+
kl_tau: 1.0e-3
|
| 45 |
|
| 46 |
# --- PRIME-RL three-actor split --------------------------------------
|
| 47 |
trainer:
|
|
@@ -14,14 +14,19 @@ the tensors we need:
|
|
| 14 |
```python
|
| 15 |
@dataclass
|
| 16 |
class LossInputs:
|
| 17 |
-
trainer_logprobs:
|
| 18 |
-
inference_logprobs: Tensor
|
| 19 |
-
|
| 20 |
-
teacher_logprobs:
|
| 21 |
-
advantages:
|
| 22 |
-
loss_mask:
|
| 23 |
```
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
The user wires this in via a YAML config field — no fork, no Trainer
|
| 26 |
subclass, no monkey-patching:
|
| 27 |
|
|
@@ -31,8 +36,12 @@ loss:
|
|
| 31 |
custom:
|
| 32 |
import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
|
| 33 |
kwargs:
|
| 34 |
-
alpha_sdpo:
|
| 35 |
-
beta_dpo:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
```
|
| 37 |
|
| 38 |
## Step-by-step
|
|
@@ -44,26 +53,78 @@ pip install prime-rl>=0.5
|
|
| 44 |
```
|
| 45 |
|
| 46 |
### 2. Drop in the composer loss
|
|
|
|
| 47 |
The framework ships `composer_replication.recipes.prime_rl.composer_loss`
|
| 48 |
which adapts the 3-channel `compose_loss` to PRIME-RL's `LossInputs`
|
| 49 |
-
struct.
|
|
|
|
| 50 |
|
| 51 |
```python
|
| 52 |
-
def loss_fn(
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
```
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
### 3. PRIME-RL config
|
| 69 |
|
|
@@ -99,9 +160,16 @@ naturally with the framework's plug-in points.
|
|
| 99 |
- An actual training run yet — that's a separate spike.
|
| 100 |
- Quality validation against TRL/VeRL — pending Spike 004 A/B.
|
| 101 |
- Hardware autoscaling — that's the Monarch recipe's job (recipes/monarch/).
|
|
|
|
|
|
|
| 102 |
|
| 103 |
## References
|
| 104 |
|
| 105 |
- PRIME-RL repo: https://github.com/PrimeIntellect-ai/prime-rl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
- ADR-006: docs/adrs/ADR-006-rl-frameworks.md
|
| 107 |
- Reconnaissance: docs/research/RL_FRAMEWORKS_LANDSCAPE.md (§ PRIME-RL)
|
|
|
|
| 14 |
```python
|
| 15 |
@dataclass
|
| 16 |
class LossInputs:
|
| 17 |
+
trainer_logprobs: Tensor[' seq'] # current-policy log-probs (per-sample, 1-D)
|
| 18 |
+
inference_logprobs: Tensor[' seq'] # rollout-time policy log-probs
|
| 19 |
+
# (importance-sampling-ratio denominator)
|
| 20 |
+
teacher_logprobs: Tensor[' seq'] | None # if the teacher channel is wired in
|
| 21 |
+
advantages: Tensor[' seq'] # GRPO advantages (channel 1)
|
| 22 |
+
loss_mask: Tensor[' seq'] # response-token mask
|
| 23 |
```
|
| 24 |
|
| 25 |
+
> **Shape note.** PRIME-RL calls the loss function **once per sample**;
|
| 26 |
+
> the tensors above are 1-D ``(seq,)``, *not* batched ``(B, T)``. An
|
| 27 |
+
> earlier draft of `composer_loss.py` assumed `(B, T)` and was caught
|
| 28 |
+
> in the Wave 13 cross-model review.
|
| 29 |
+
|
| 30 |
The user wires this in via a YAML config field — no fork, no Trainer
|
| 31 |
subclass, no monkey-patching:
|
| 32 |
|
|
|
|
| 36 |
custom:
|
| 37 |
import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
|
| 38 |
kwargs:
|
| 39 |
+
alpha_sdpo: 0.0 # channel 2 deferred in v0 — see below
|
| 40 |
+
beta_dpo: 0.0 # channel 3 out-of-scope in v0
|
| 41 |
+
dppo_mask_high: 0.2 # PRIME-RL DPPO mask (probability-space)
|
| 42 |
+
dppo_mask_low: 0.2
|
| 43 |
+
adv_tau: 1.0
|
| 44 |
+
kl_tau: 1.0e-3
|
| 45 |
```
|
| 46 |
|
| 47 |
## Step-by-step
|
|
|
|
| 53 |
```
|
| 54 |
|
| 55 |
### 2. Drop in the composer loss
|
| 56 |
+
|
| 57 |
The framework ships `composer_replication.recipes.prime_rl.composer_loss`
|
| 58 |
which adapts the 3-channel `compose_loss` to PRIME-RL's `LossInputs`
|
| 59 |
+
struct. Channel 1 mirrors PRIME-RL's upstream `default_loss_fn` exactly
|
| 60 |
+
(verified in `prime_rl/trainer/rl/loss.py` lines 116-165):
|
| 61 |
|
| 62 |
```python
|
| 63 |
+
def loss_fn(
|
| 64 |
+
inputs: LossInputs,
|
| 65 |
+
*,
|
| 66 |
+
alpha_sdpo: float = 0.0,
|
| 67 |
+
beta_dpo: float = 0.0,
|
| 68 |
+
# PRIME-RL DefaultLossConfig defaults (Field(..., ge=0)):
|
| 69 |
+
dppo_mask_high: float = 0.2,
|
| 70 |
+
dppo_mask_low: float = 0.2,
|
| 71 |
+
adv_tau: float = 1.0,
|
| 72 |
+
kl_tau: float = 1e-3,
|
| 73 |
+
) -> Tensor:
|
| 74 |
+
# ----- Channel 1: DPPO + KL on the importance ratio -----
|
| 75 |
+
log_ir = trainer_lp - inference_lp
|
| 76 |
+
ir = exp(log_ir)
|
| 77 |
+
probs_diff = exp(trainer_lp) - exp(inference_lp) # NB: probability space
|
| 78 |
+
invalid_high = probs_diff > dppo_mask_high
|
| 79 |
+
invalid_low = probs_diff < -dppo_mask_low
|
| 80 |
+
pos_adv = advantages > 0
|
| 81 |
+
invalid = where(pos_adv, invalid_high, invalid_low) # advantage-conditioned
|
| 82 |
+
keep = loss_mask & ~invalid
|
| 83 |
+
|
| 84 |
+
pg_loss = keep * (adv_tau * advantages) * ir
|
| 85 |
+
kl_loss = loss_mask * log_ir**2
|
| 86 |
+
loss = (-pg_loss + kl_tau * kl_loss).sum() # SUM, not mean
|
| 87 |
+
|
| 88 |
+
# ----- Channel 2: SDPO/OPSD against teacher_logprobs -----
|
| 89 |
+
# DEFERRED — PRIME-RL v0.5 exposes log-probs not full logits.
|
| 90 |
+
# alpha_sdpo > 0 raises NotImplementedError until that lands.
|
| 91 |
+
|
| 92 |
+
# ----- Channel 3: trace-replay DPO -----
|
| 93 |
+
# Out of scope for PRIME-RL recipe v0.
|
| 94 |
+
|
| 95 |
+
return loss
|
| 96 |
```
|
| 97 |
|
| 98 |
+
**DPPO clip semantics — three things to know.** This is *not* a
|
| 99 |
+
textbook PPO clipped surrogate; it is exactly what PRIME-RL's
|
| 100 |
+
`default_loss_fn` does:
|
| 101 |
+
|
| 102 |
+
1. The mask gate is on **probability-space**
|
| 103 |
+
`probs_diff = exp(trainer_lp) - exp(inference_lp)`, **not** on the
|
| 104 |
+
log-ratio. (`-dppo_mask_low` flips the sign so the threshold itself
|
| 105 |
+
is non-negative; PRIME-RL stores both bounds as `Field(..., ge=0)`.)
|
| 106 |
+
2. The policy-gradient term is multiplied by
|
| 107 |
+
`importance_ratio = exp(trainer_lp - inference_lp)`, not by
|
| 108 |
+
`trainer_lp` directly — this is a proper IS-corrected gradient, not
|
| 109 |
+
plain REINFORCE.
|
| 110 |
+
3. The mask is **conditioned on the sign of the advantage**: a
|
| 111 |
+
positive-advantage token is dropped iff `probs_diff > dppo_mask_high`
|
| 112 |
+
(we'd be upweighting an already-too-high-probability token); a
|
| 113 |
+
negative-advantage token is dropped iff `probs_diff < -dppo_mask_low`
|
| 114 |
+
(we'd be downweighting an already-too-low-probability token). The
|
| 115 |
+
gates are *not* OR'd together.
|
| 116 |
+
|
| 117 |
+
The reduction is a plain `sum()`; PRIME-RL's outer `compute_loss`
|
| 118 |
+
divides by `loss_scale` and aggregates across the packed batch.
|
| 119 |
+
|
| 120 |
+
There is also a per-token KL term `kl_tau * log_importance_ratio**2`
|
| 121 |
+
(the Kimi-K2.5 KL, see PRIME-RL's docstring on line 119-126), gated by
|
| 122 |
+
`loss_mask` only — DPPO masking does not affect it.
|
| 123 |
+
|
| 124 |
+
Concrete file: `composer_loss.py` in this directory. Tests:
|
| 125 |
+
`tests/test_composer_loss.py` (16 cases including a parity test against
|
| 126 |
+
PRIME-RL's own `default_loss_fn`, skip-marked when prime-rl is not
|
| 127 |
+
installed).
|
| 128 |
|
| 129 |
### 3. PRIME-RL config
|
| 130 |
|
|
|
|
| 160 |
- An actual training run yet — that's a separate spike.
|
| 161 |
- Quality validation against TRL/VeRL — pending Spike 004 A/B.
|
| 162 |
- Hardware autoscaling — that's the Monarch recipe's job (recipes/monarch/).
|
| 163 |
+
- **SDPO/OPSD channel** — deferred until PRIME-RL exposes full logits
|
| 164 |
+
through `LossInputs` (currently only log-probs).
|
| 165 |
|
| 166 |
## References
|
| 167 |
|
| 168 |
- PRIME-RL repo: https://github.com/PrimeIntellect-ai/prime-rl
|
| 169 |
+
- PRIME-RL upstream `default_loss_fn`: `src/prime_rl/trainer/rl/loss.py`
|
| 170 |
+
lines 116-165
|
| 171 |
+
- PRIME-RL `DefaultLossConfig` defaults:
|
| 172 |
+
`packages/prime-rl-configs/src/prime_rl/configs/trainer.py` lines
|
| 173 |
+
412-425
|
| 174 |
- ADR-006: docs/adrs/ADR-006-rl-frameworks.md
|
| 175 |
- Reconnaissance: docs/research/RL_FRAMEWORKS_LANDSCAPE.md (§ PRIME-RL)
|
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for the PRIME-RL composer-loss adapter.
|
| 2 |
+
|
| 3 |
+
Verifies parity with PRIME-RL's upstream ``default_loss_fn``
|
| 4 |
+
(``src/prime_rl/trainer/rl/loss.py`` lines 116-165). Hand-computed
|
| 5 |
+
expected values use the upstream formula; the parity test at the bottom
|
| 6 |
+
imports PRIME-RL itself (skip-marked when not installed) and compares
|
| 7 |
+
outputs end-to-end.
|
| 8 |
+
|
| 9 |
+
License: MIT.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from composer_replication.recipes.prime_rl.composer_loss import loss_fn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Try to import PRIME-RL upstream for the parity test; skip-mark if
|
| 24 |
+
# unavailable. PRIME-RL pulls in heavy deps (jaxtyping, beartype) and
|
| 25 |
+
# is not part of the framework's own test environment.
|
| 26 |
+
try:
|
| 27 |
+
from prime_rl.trainer.rl.loss import ( # type: ignore[import-not-found]
|
| 28 |
+
LossInputs as PrimeRLLossInputs,
|
| 29 |
+
default_loss_fn as prime_rl_default_loss_fn,
|
| 30 |
+
)
|
| 31 |
+
from prime_rl.configs.trainer import ( # type: ignore[import-not-found]
|
| 32 |
+
DefaultLossConfig as PrimeRLDefaultLossConfig,
|
| 33 |
+
)
|
| 34 |
+
_HAS_PRIME_RL = True
|
| 35 |
+
except Exception: # noqa: BLE001 — broad: missing module, version skew, etc.
|
| 36 |
+
_HAS_PRIME_RL = False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------
|
| 40 |
+
# Test double — duck-typed stand-in for PRIME-RL's LossInputs
|
| 41 |
+
# ---------------------------------------------------------------------
|
| 42 |
+
@dataclass
|
| 43 |
+
class FakeLossInputs:
|
| 44 |
+
trainer_logprobs: torch.Tensor
|
| 45 |
+
inference_logprobs: torch.Tensor
|
| 46 |
+
advantages: torch.Tensor
|
| 47 |
+
loss_mask: torch.Tensor
|
| 48 |
+
teacher_logprobs: Optional[torch.Tensor] = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _make_inputs(
|
| 52 |
+
seq: int = 8,
|
| 53 |
+
*,
|
| 54 |
+
same_logprobs: bool = True,
|
| 55 |
+
teacher: bool = False,
|
| 56 |
+
seed: int = 0,
|
| 57 |
+
) -> FakeLossInputs:
|
| 58 |
+
"""Build a realistic (seq,) LossInputs stand-in.
|
| 59 |
+
|
| 60 |
+
Uses ``requires_grad`` on ``trainer_logprobs`` so callers can also
|
| 61 |
+
sanity-check that the loss is differentiable end-to-end. Default
|
| 62 |
+
log-probs are clamped to a moderate negative range so
|
| 63 |
+
``exp(trainer_lp) - exp(inference_lp)`` stays inside the 0.2 PRIME-RL
|
| 64 |
+
default DPPO band — i.e. tokens are not all DPPO-masked by chance.
|
| 65 |
+
"""
|
| 66 |
+
g = torch.Generator().manual_seed(seed)
|
| 67 |
+
# Negative log-probs in [-2, -0.5] keep exp() in roughly [0.13, 0.6]
|
| 68 |
+
# so probs_diff differences stay tiny under small perturbation.
|
| 69 |
+
trainer = -(0.5 + 1.5 * torch.rand(seq, generator=g))
|
| 70 |
+
trainer = trainer.detach().clone().requires_grad_(True)
|
| 71 |
+
if same_logprobs:
|
| 72 |
+
# Tiny perturbation -> probs_diff ~ 0, no DPPO masking.
|
| 73 |
+
inference = trainer.detach().clone() + 0.001 * torch.randn(
|
| 74 |
+
seq, generator=g
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
inference = -(0.5 + 1.5 * torch.rand(seq, generator=g))
|
| 78 |
+
advantages = torch.randn(seq, generator=g)
|
| 79 |
+
loss_mask = torch.ones(seq, dtype=torch.bool)
|
| 80 |
+
teacher_lp = torch.randn(seq, generator=g) if teacher else None
|
| 81 |
+
return FakeLossInputs(
|
| 82 |
+
trainer_logprobs=trainer,
|
| 83 |
+
inference_logprobs=inference,
|
| 84 |
+
advantages=advantages,
|
| 85 |
+
loss_mask=loss_mask,
|
| 86 |
+
teacher_logprobs=teacher_lp,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ---------------------------------------------------------------------
|
| 91 |
+
# Reference re-implementation (independent restatement of upstream).
|
| 92 |
+
# Used by hand-computed expected-value tests so we don't accidentally
|
| 93 |
+
# encode our own bugs as ground truth.
|
| 94 |
+
# ---------------------------------------------------------------------
|
| 95 |
+
def _reference_default_loss(
|
| 96 |
+
trainer_lp: torch.Tensor,
|
| 97 |
+
inference_lp: torch.Tensor,
|
| 98 |
+
advantages: torch.Tensor,
|
| 99 |
+
loss_mask: torch.Tensor,
|
| 100 |
+
*,
|
| 101 |
+
dppo_mask_high: float,
|
| 102 |
+
dppo_mask_low: float,
|
| 103 |
+
adv_tau: float,
|
| 104 |
+
kl_tau: float,
|
| 105 |
+
) -> torch.Tensor:
|
| 106 |
+
log_ir = trainer_lp - inference_lp
|
| 107 |
+
ir = torch.exp(log_ir)
|
| 108 |
+
probs_diff = torch.exp(trainer_lp) - torch.exp(inference_lp)
|
| 109 |
+
invalid_high = probs_diff > dppo_mask_high
|
| 110 |
+
invalid_low = probs_diff < -dppo_mask_low
|
| 111 |
+
pos_adv = advantages > 0
|
| 112 |
+
invalid = torch.where(pos_adv, invalid_high, invalid_low)
|
| 113 |
+
keep = loss_mask.to(torch.bool) & ~invalid
|
| 114 |
+
keep_f = keep.to(trainer_lp.dtype)
|
| 115 |
+
lm_f = loss_mask.to(trainer_lp.dtype)
|
| 116 |
+
pg = keep_f * (adv_tau * advantages) * ir
|
| 117 |
+
kl = lm_f * log_ir**2
|
| 118 |
+
return (-pg + kl_tau * kl).sum()
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ---------------------------------------------------------------------
|
| 122 |
+
# Test 1 — finite scalar on realistic (seq,) tensors
|
| 123 |
+
# ---------------------------------------------------------------------
|
| 124 |
+
def test_returns_finite_scalar():
|
| 125 |
+
inputs = _make_inputs(seq=16)
|
| 126 |
+
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 127 |
+
|
| 128 |
+
assert isinstance(out, torch.Tensor)
|
| 129 |
+
assert out.shape == (), f"expected scalar, got shape {tuple(out.shape)}"
|
| 130 |
+
assert torch.isfinite(out).item()
|
| 131 |
+
# Differentiable: gradient flows to trainer_logprobs.
|
| 132 |
+
out.backward()
|
| 133 |
+
assert inputs.trainer_logprobs.grad is not None
|
| 134 |
+
assert torch.isfinite(inputs.trainer_logprobs.grad).all().item()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ---------------------------------------------------------------------
|
| 138 |
+
# Test 2 — DPPO mask drops tokens whose probs_diff exceeds dppo_mask_high
|
| 139 |
+
# (advantage-conditioned: positive advantages use the high gate)
|
| 140 |
+
# ---------------------------------------------------------------------
|
| 141 |
+
def test_dppo_mask_high_drops_positive_advantage_outliers():
|
| 142 |
+
"""Token with positive advantage and probs_diff > dppo_mask_high is dropped.
|
| 143 |
+
|
| 144 |
+
Build a 4-token sample where token 0 has ``probs_diff`` huge and
|
| 145 |
+
positive (trainer prob ~ 1, inference prob ~ 0) AND positive
|
| 146 |
+
advantage. Tokens 1..3 have tiny probs_diff. With the upstream
|
| 147 |
+
sign-conditioned gate, only token 0 should be dropped.
|
| 148 |
+
"""
|
| 149 |
+
# trainer_lp ~ 0 -> exp ~ 1; inference_lp = -10 -> exp ~ 4.5e-5.
|
| 150 |
+
# probs_diff[0] ~ 1.0 >> dppo_mask_high (0.2).
|
| 151 |
+
trainer_lp = torch.tensor(
|
| 152 |
+
[0.0, math.log(0.30), math.log(0.40), math.log(0.50)],
|
| 153 |
+
requires_grad=True,
|
| 154 |
+
)
|
| 155 |
+
inference_lp = torch.tensor(
|
| 156 |
+
[-10.0, math.log(0.31), math.log(0.39), math.log(0.51)]
|
| 157 |
+
)
|
| 158 |
+
advantages = torch.tensor([+5.0, +1.0, -1.0, +1.0])
|
| 159 |
+
mask = torch.ones(4, dtype=torch.bool)
|
| 160 |
+
|
| 161 |
+
inputs = FakeLossInputs(
|
| 162 |
+
trainer_logprobs=trainer_lp,
|
| 163 |
+
inference_logprobs=inference_lp,
|
| 164 |
+
advantages=advantages,
|
| 165 |
+
loss_mask=mask,
|
| 166 |
+
)
|
| 167 |
+
out = loss_fn(
|
| 168 |
+
inputs,
|
| 169 |
+
alpha_sdpo=0.0,
|
| 170 |
+
beta_dpo=0.0,
|
| 171 |
+
dppo_mask_high=0.2,
|
| 172 |
+
dppo_mask_low=0.2,
|
| 173 |
+
adv_tau=1.0,
|
| 174 |
+
kl_tau=1e-3,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
expected = _reference_default_loss(
|
| 178 |
+
trainer_lp.detach(),
|
| 179 |
+
inference_lp,
|
| 180 |
+
advantages,
|
| 181 |
+
mask,
|
| 182 |
+
dppo_mask_high=0.2,
|
| 183 |
+
dppo_mask_low=0.2,
|
| 184 |
+
adv_tau=1.0,
|
| 185 |
+
kl_tau=1e-3,
|
| 186 |
+
)
|
| 187 |
+
assert torch.isclose(out, expected, atol=1e-5), (
|
| 188 |
+
f"got {out.item()}, expected {expected.item()}"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# Token 0 was DPPO-dropped from pg_loss but still contributes to kl_loss
|
| 192 |
+
# (loss_mask gates KL, not the DPPO mask). The pg gradient on token 0
|
| 193 |
+
# should be zero; KL contributes a small grad. We assert the pg path
|
| 194 |
+
# is masked by checking the gradient magnitude is dominated by the
|
| 195 |
+
# tiny kl_tau * 2 * log_ir term, not by the +5 advantage.
|
| 196 |
+
out.backward()
|
| 197 |
+
g0 = inputs.trainer_logprobs.grad[0].item()
|
| 198 |
+
# If pg weren't masked, |g0| would be on the order of
|
| 199 |
+
# advantage * importance_ratio * 1 ~ 5 * exp(10) ~ 1e5.
|
| 200 |
+
# With pg masked, |g0| is on the order of
|
| 201 |
+
# 2 * kl_tau * log_ir ~ 2 * 1e-3 * 10 = 0.02.
|
| 202 |
+
assert abs(g0) < 1.0, (
|
| 203 |
+
f"DPPO mask should suppress the pg gradient on token 0; got |g0|={abs(g0)}"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# ---------------------------------------------------------------------
|
| 208 |
+
# Test 3 — DPPO mask catches the lower bound on negative-advantage tokens
|
| 209 |
+
# ---------------------------------------------------------------------
|
| 210 |
+
def test_dppo_mask_low_drops_negative_advantage_outliers():
|
| 211 |
+
"""Symmetric coverage: probs_diff < -dppo_mask_low drops a NEGATIVE-adv token."""
|
| 212 |
+
# Token 0: trainer prob ~ 0, inference prob ~ 1, so probs_diff ~ -1.
|
| 213 |
+
# Negative advantage -> the low gate applies -> dropped.
|
| 214 |
+
trainer_lp = torch.tensor(
|
| 215 |
+
[-10.0, math.log(0.30), math.log(0.40)], requires_grad=True
|
| 216 |
+
)
|
| 217 |
+
inference_lp = torch.tensor(
|
| 218 |
+
[0.0, math.log(0.31), math.log(0.39)]
|
| 219 |
+
)
|
| 220 |
+
advantages = torch.tensor([-5.0, +1.0, -1.0])
|
| 221 |
+
mask = torch.ones(3, dtype=torch.bool)
|
| 222 |
+
|
| 223 |
+
inputs = FakeLossInputs(
|
| 224 |
+
trainer_logprobs=trainer_lp,
|
| 225 |
+
inference_logprobs=inference_lp,
|
| 226 |
+
advantages=advantages,
|
| 227 |
+
loss_mask=mask,
|
| 228 |
+
)
|
| 229 |
+
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 230 |
+
|
| 231 |
+
expected = _reference_default_loss(
|
| 232 |
+
trainer_lp.detach(),
|
| 233 |
+
inference_lp,
|
| 234 |
+
advantages,
|
| 235 |
+
mask,
|
| 236 |
+
dppo_mask_high=0.2,
|
| 237 |
+
dppo_mask_low=0.2,
|
| 238 |
+
adv_tau=1.0,
|
| 239 |
+
kl_tau=1e-3,
|
| 240 |
+
)
|
| 241 |
+
assert torch.isclose(out, expected, atol=1e-5)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ---------------------------------------------------------------------
|
| 245 |
+
# Test 4 — sign-conditioning: a positive-advantage token whose probs_diff
|
| 246 |
+
# is *negative* (and large in magnitude) is NOT dropped, because the
|
| 247 |
+
# high gate doesn't fire on a negative probs_diff.
|
| 248 |
+
# ---------------------------------------------------------------------
|
| 249 |
+
def test_dppo_mask_sign_conditioned_on_advantage():
|
| 250 |
+
"""A positive-advantage token with probs_diff < -dppo_mask_low survives.
|
| 251 |
+
|
| 252 |
+
PRIME-RL's gate is ``where(positive_advantages, invalid_high, invalid_low)``.
|
| 253 |
+
For positive advantages it only checks the upper bound, so
|
| 254 |
+
``probs_diff = -0.9`` with a positive advantage is KEPT; with a
|
| 255 |
+
negative advantage it would be DROPPED.
|
| 256 |
+
"""
|
| 257 |
+
# Token 0: probs_diff = exp(-10) - exp(0) ~ -1. Massively negative.
|
| 258 |
+
trainer_lp_pos = torch.tensor([-10.0], requires_grad=True)
|
| 259 |
+
inference_lp_pos = torch.tensor([0.0])
|
| 260 |
+
adv_pos = torch.tensor([+1.0])
|
| 261 |
+
mask = torch.ones(1, dtype=torch.bool)
|
| 262 |
+
|
| 263 |
+
inputs_pos = FakeLossInputs(
|
| 264 |
+
trainer_logprobs=trainer_lp_pos,
|
| 265 |
+
inference_logprobs=inference_lp_pos,
|
| 266 |
+
advantages=adv_pos,
|
| 267 |
+
loss_mask=mask,
|
| 268 |
+
)
|
| 269 |
+
out_pos = loss_fn(inputs_pos, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 270 |
+
|
| 271 |
+
# With positive advantage the LOW bound is not checked; the token is
|
| 272 |
+
# KEPT. pg = +1 * exp(-10 - 0) = ~4.5e-5; kl = (-10)^2 = 100.
|
| 273 |
+
# loss = -pg + 1e-3 * 100 ~ 0.1.
|
| 274 |
+
expected_pos = _reference_default_loss(
|
| 275 |
+
trainer_lp_pos.detach(),
|
| 276 |
+
inference_lp_pos,
|
| 277 |
+
adv_pos,
|
| 278 |
+
mask,
|
| 279 |
+
dppo_mask_high=0.2,
|
| 280 |
+
dppo_mask_low=0.2,
|
| 281 |
+
adv_tau=1.0,
|
| 282 |
+
kl_tau=1e-3,
|
| 283 |
+
)
|
| 284 |
+
assert torch.isclose(out_pos, expected_pos, atol=1e-5)
|
| 285 |
+
# Sanity: token wasn't masked, so kl_tau alone shouldn't dominate to
|
| 286 |
+
# zero — loss should be ~0.1, definitely not zero.
|
| 287 |
+
assert out_pos.item() > 0.05
|
| 288 |
+
|
| 289 |
+
# Same probs_diff but negative advantage -> DROPPED from pg.
|
| 290 |
+
trainer_lp_neg = torch.tensor([-10.0], requires_grad=True)
|
| 291 |
+
inputs_neg = FakeLossInputs(
|
| 292 |
+
trainer_logprobs=trainer_lp_neg,
|
| 293 |
+
inference_logprobs=inference_lp_pos,
|
| 294 |
+
advantages=torch.tensor([-1.0]),
|
| 295 |
+
loss_mask=mask,
|
| 296 |
+
)
|
| 297 |
+
out_neg = loss_fn(inputs_neg, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 298 |
+
expected_neg = _reference_default_loss(
|
| 299 |
+
trainer_lp_neg.detach(),
|
| 300 |
+
inference_lp_pos,
|
| 301 |
+
torch.tensor([-1.0]),
|
| 302 |
+
mask,
|
| 303 |
+
dppo_mask_high=0.2,
|
| 304 |
+
dppo_mask_low=0.2,
|
| 305 |
+
adv_tau=1.0,
|
| 306 |
+
kl_tau=1e-3,
|
| 307 |
+
)
|
| 308 |
+
assert torch.isclose(out_neg, expected_neg, atol=1e-5)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# ---------------------------------------------------------------------
|
| 312 |
+
# Test 5 — alpha_sdpo=0 must not raise (channel 2 disabled)
|
| 313 |
+
# ---------------------------------------------------------------------
|
| 314 |
+
def test_alpha_sdpo_zero_does_not_raise():
|
| 315 |
+
inputs = _make_inputs(seq=6, teacher=True)
|
| 316 |
+
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 317 |
+
assert torch.isfinite(out).item()
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# ---------------------------------------------------------------------
|
| 321 |
+
# Test 6 — alpha_sdpo>0 still raises NotImplementedError
|
| 322 |
+
# ---------------------------------------------------------------------
|
| 323 |
+
def test_alpha_sdpo_nonzero_raises_not_implemented():
|
| 324 |
+
inputs = _make_inputs(seq=6, teacher=True)
|
| 325 |
+
with pytest.raises(NotImplementedError, match="SDPO"):
|
| 326 |
+
loss_fn(inputs, alpha_sdpo=0.5, beta_dpo=0.0)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def test_alpha_sdpo_nonzero_no_teacher_also_raises():
|
| 330 |
+
"""Defensive: even without teacher_logprobs, alpha_sdpo>0 must fail
|
| 331 |
+
rather than silently no-op."""
|
| 332 |
+
inputs = _make_inputs(seq=6, teacher=False)
|
| 333 |
+
with pytest.raises(NotImplementedError):
|
| 334 |
+
loss_fn(inputs, alpha_sdpo=0.5, beta_dpo=0.0)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# ---------------------------------------------------------------------
|
| 338 |
+
# Test 7 — shape validation: (seq,) accepted, (B, T) rejected
|
| 339 |
+
# ---------------------------------------------------------------------
|
| 340 |
+
def test_advantages_shape_validates_seq_accepted():
|
| 341 |
+
inputs = _make_inputs(seq=12)
|
| 342 |
+
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 343 |
+
assert out.shape == ()
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def test_advantages_shape_validates_bt_rejected():
|
| 347 |
+
B, T = 2, 4
|
| 348 |
+
bad = FakeLossInputs(
|
| 349 |
+
trainer_logprobs=torch.zeros(B, T, requires_grad=True),
|
| 350 |
+
inference_logprobs=torch.zeros(B, T),
|
| 351 |
+
advantages=torch.zeros(B, T),
|
| 352 |
+
loss_mask=torch.ones(B, T, dtype=torch.bool),
|
| 353 |
+
)
|
| 354 |
+
with pytest.raises(ValueError, match="1-D"):
|
| 355 |
+
loss_fn(bad, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# ---------------------------------------------------------------------
|
| 359 |
+
# Test 8 — beta_dpo != 0 emits a warning but does not raise
|
| 360 |
+
# ---------------------------------------------------------------------
|
| 361 |
+
def test_beta_dpo_nonzero_warns():
|
| 362 |
+
inputs = _make_inputs(seq=8)
|
| 363 |
+
with pytest.warns(UserWarning, match="DPO channel"):
|
| 364 |
+
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.3)
|
| 365 |
+
assert torch.isfinite(out).item()
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# ---------------------------------------------------------------------
|
| 369 |
+
# Test 9 — config-validation knobs match PRIME-RL Field(..., ge=0)
|
| 370 |
+
# ---------------------------------------------------------------------
|
| 371 |
+
@pytest.mark.parametrize(
|
| 372 |
+
"kw",
|
| 373 |
+
[
|
| 374 |
+
{"dppo_mask_high": -0.1},
|
| 375 |
+
{"dppo_mask_low": -0.1},
|
| 376 |
+
{"adv_tau": -0.1},
|
| 377 |
+
{"kl_tau": -0.1},
|
| 378 |
+
],
|
| 379 |
+
)
|
| 380 |
+
def test_negative_knobs_rejected(kw):
|
| 381 |
+
inputs = _make_inputs(seq=4)
|
| 382 |
+
with pytest.raises(ValueError, match=">= 0"):
|
| 383 |
+
loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0, **kw)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# ---------------------------------------------------------------------
|
| 387 |
+
# Test 10 — disabling masking via wide bounds gives plain DPPO+KL on all
|
| 388 |
+
# tokens. This pins the "pure IS-corrected REINFORCE + KL" baseline.
|
| 389 |
+
# ---------------------------------------------------------------------
|
| 390 |
+
def test_dppo_bounds_can_be_disabled():
|
| 391 |
+
"""Setting bounds to a huge value disables DPPO masking.
|
| 392 |
+
|
| 393 |
+
At dppo_mask_high=dppo_mask_low=1e6, ``probs_diff`` never exceeds the
|
| 394 |
+
threshold so ``keep_mask == loss_mask`` and the loss reduces to the
|
| 395 |
+
plain DPPO+KL on the whole sequence.
|
| 396 |
+
"""
|
| 397 |
+
seq = 4
|
| 398 |
+
trainer_lp = torch.tensor(
|
| 399 |
+
[math.log(0.10), math.log(0.30), math.log(0.20), math.log(0.40)],
|
| 400 |
+
requires_grad=True,
|
| 401 |
+
)
|
| 402 |
+
inference_lp = torch.tensor(
|
| 403 |
+
[math.log(0.11), math.log(0.31), math.log(0.21), math.log(0.39)]
|
| 404 |
+
)
|
| 405 |
+
advantages = torch.tensor([+1.0, -1.0, +0.5, -0.5])
|
| 406 |
+
mask = torch.ones(seq, dtype=torch.bool)
|
| 407 |
+
inputs = FakeLossInputs(
|
| 408 |
+
trainer_logprobs=trainer_lp,
|
| 409 |
+
inference_logprobs=inference_lp,
|
| 410 |
+
advantages=advantages,
|
| 411 |
+
loss_mask=mask,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
out = loss_fn(
|
| 415 |
+
inputs,
|
| 416 |
+
alpha_sdpo=0.0,
|
| 417 |
+
beta_dpo=0.0,
|
| 418 |
+
dppo_mask_high=1e6,
|
| 419 |
+
dppo_mask_low=1e6,
|
| 420 |
+
adv_tau=1.0,
|
| 421 |
+
kl_tau=1e-3,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
expected = _reference_default_loss(
|
| 425 |
+
trainer_lp.detach(),
|
| 426 |
+
inference_lp,
|
| 427 |
+
advantages,
|
| 428 |
+
mask,
|
| 429 |
+
dppo_mask_high=1e6,
|
| 430 |
+
dppo_mask_low=1e6,
|
| 431 |
+
adv_tau=1.0,
|
| 432 |
+
kl_tau=1e-3,
|
| 433 |
+
)
|
| 434 |
+
assert torch.isclose(out, expected, atol=1e-6)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# ---------------------------------------------------------------------
|
| 438 |
+
# Test 11 — PARITY against PRIME-RL upstream's default_loss_fn.
|
| 439 |
+
# Skip-marked when prime-rl is not installable.
|
| 440 |
+
# ---------------------------------------------------------------------
|
| 441 |
+
@pytest.mark.skipif(
|
| 442 |
+
not _HAS_PRIME_RL,
|
| 443 |
+
reason="prime-rl not installed; skipping upstream parity test",
|
| 444 |
+
)
|
| 445 |
+
def test_parity_with_prime_rl_default_loss_fn():
|
| 446 |
+
"""Run identical inputs through ours and PRIME-RL's; loss must match."""
|
| 447 |
+
seq = 32
|
| 448 |
+
g = torch.Generator().manual_seed(42)
|
| 449 |
+
trainer_lp = -(0.1 + 2.0 * torch.rand(seq, generator=g)).to(torch.float32)
|
| 450 |
+
inference_lp = (trainer_lp + 0.05 * torch.randn(seq, generator=g)).to(torch.float32)
|
| 451 |
+
advantages = torch.randn(seq, generator=g, dtype=torch.float32)
|
| 452 |
+
loss_mask = torch.ones(seq, dtype=torch.bool)
|
| 453 |
+
|
| 454 |
+
# Use PRIME-RL's defaults (dppo_mask_high=0.2, etc.) directly.
|
| 455 |
+
cfg = PrimeRLDefaultLossConfig() # type: ignore[name-defined]
|
| 456 |
+
|
| 457 |
+
upstream_inputs = PrimeRLLossInputs( # type: ignore[name-defined]
|
| 458 |
+
trainer_logprobs=trainer_lp,
|
| 459 |
+
inference_logprobs=inference_lp,
|
| 460 |
+
teacher_logprobs=None,
|
| 461 |
+
advantages=advantages,
|
| 462 |
+
loss_mask=loss_mask,
|
| 463 |
+
)
|
| 464 |
+
upstream_out = prime_rl_default_loss_fn(upstream_inputs, cfg) # type: ignore[name-defined]
|
| 465 |
+
|
| 466 |
+
ours = loss_fn(
|
| 467 |
+
FakeLossInputs(
|
| 468 |
+
trainer_logprobs=trainer_lp.clone(),
|
| 469 |
+
inference_logprobs=inference_lp.clone(),
|
| 470 |
+
advantages=advantages.clone(),
|
| 471 |
+
loss_mask=loss_mask.clone(),
|
| 472 |
+
),
|
| 473 |
+
alpha_sdpo=0.0,
|
| 474 |
+
beta_dpo=0.0,
|
| 475 |
+
dppo_mask_high=cfg.dppo_mask_high,
|
| 476 |
+
dppo_mask_low=cfg.dppo_mask_low,
|
| 477 |
+
adv_tau=cfg.adv_tau,
|
| 478 |
+
kl_tau=cfg.kl_tau,
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
assert torch.isclose(ours, upstream_out.loss, atol=1e-5, rtol=1e-5), (
|
| 482 |
+
f"Parity mismatch with PRIME-RL upstream: ours={ours.item()}, "
|
| 483 |
+
f"upstream={upstream_out.loss.item()}"
|
| 484 |
+
)
|
|
@@ -8,21 +8,43 @@
|
|
| 8 |
#
|
| 9 |
# {
|
| 10 |
# "state_id": "...",
|
| 11 |
-
# "messages":
|
| 12 |
-
#
|
| 13 |
-
# "
|
| 14 |
-
# "
|
| 15 |
-
#
|
|
|
|
|
|
|
|
|
|
| 16 |
# }
|
| 17 |
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Ops listed in `process` are applied in order. Each op operates on the
|
| 19 |
-
# full record but
|
| 20 |
-
# DPO/preference-pair ops know how to handle the chosen/rejected pair
|
| 21 |
-
# structure natively.
|
| 22 |
|
| 23 |
# Project & I/O are filled in by DJNormalizer at runtime; we only
|
| 24 |
# specify the op pipeline here.
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# --- Op pipeline (applied in order) -----------------------------------
|
| 27 |
process:
|
| 28 |
|
|
@@ -32,7 +54,11 @@ process:
|
|
| 32 |
- text_length_filter:
|
| 33 |
min_len: 8
|
| 34 |
max_len: 32000
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# 2. Word-count filter on response.
|
| 38 |
# Drops pairs with absurdly low (< 2 words) or high (> 4096 words)
|
|
@@ -40,23 +66,31 @@ process:
|
|
| 40 |
- words_num_filter:
|
| 41 |
min_num: 2
|
| 42 |
max_num: 4096
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# 3. Special-character filter.
|
| 46 |
# Drops responses where >50% of characters are non-alphabetic
|
| 47 |
# special chars (likely encoding errors or junk).
|
| 48 |
- special_characters_filter:
|
| 49 |
max_ratio: 0.5
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# 4. Per-conversation deduplication.
|
| 53 |
-
#
|
| 54 |
-
#
|
|
|
|
|
|
|
| 55 |
- document_deduplicator:
|
| 56 |
lowercase: true
|
| 57 |
ignore_non_character: true
|
| 58 |
-
|
| 59 |
-
# data-juicer's per-batch dedup; full corpus dedup is a separate op.
|
| 60 |
|
| 61 |
# Notes:
|
| 62 |
# - We DO NOT run `pair_preference_mapper` because its default config may
|
|
|
|
| 8 |
#
|
| 9 |
# {
|
| 10 |
# "state_id": "...",
|
| 11 |
+
# "messages": [{"role": "user", "content": "..."}], # context
|
| 12 |
+
# # --- flat-string shape (consumed by length/word/special-char/dedup filters) ---
|
| 13 |
+
# "chosen": "the chosen response as a plain string",
|
| 14 |
+
# "rejected": "the rejected response as a plain string",
|
| 15 |
+
# # --- chat-messages shape (preserved for chat-aware ops + round-trip) ---
|
| 16 |
+
# "chosen_messages": [{"role": "assistant", "content": "..."}],
|
| 17 |
+
# "rejected_messages": [{"role": "assistant", "content": "..."}],
|
| 18 |
+
# "n_teachers_agreeing": 2
|
| 19 |
# }
|
| 20 |
#
|
| 21 |
+
# IMPORTANT — field-key contract:
|
| 22 |
+
# data-juicer's `text_length_filter`, `words_num_filter`,
|
| 23 |
+
# `special_characters_filter` and `document_deduplicator` all read a SINGLE
|
| 24 |
+
# string field named by `text_key` (singular). They expect plain strings.
|
| 25 |
+
# Pointing them at a list-of-dicts (the chat-messages shape) crashes or
|
| 26 |
+
# silently no-ops. We therefore keep two parallel representations:
|
| 27 |
+
# * `chosen` / `rejected` — plain strings, fed to filter ops below.
|
| 28 |
+
# * `chosen_messages` / `rejected_messages` — chat-messages list, preserved
|
| 29 |
+
# untouched for downstream chat-aware consumers and the round-trip.
|
| 30 |
+
#
|
| 31 |
+
# data-juicer caveat: each filter op accepts only ONE `text_key`. To filter
|
| 32 |
+
# both `chosen` AND `rejected`, we duplicate each op — once with
|
| 33 |
+
# `text_key: chosen`, once with `text_key: rejected`. The top-level
|
| 34 |
+
# `text_keys: chosen` below also satisfies data-juicer's dataset-load
|
| 35 |
+
# validation (the formatter checks the global text_key exists in the dataset).
|
| 36 |
+
#
|
| 37 |
# Ops listed in `process` are applied in order. Each op operates on the
|
| 38 |
+
# full record but reads/writes one field.
|
|
|
|
|
|
|
| 39 |
|
| 40 |
# Project & I/O are filled in by DJNormalizer at runtime; we only
|
| 41 |
# specify the op pipeline here.
|
| 42 |
|
| 43 |
+
# --- Global text-key contract (see header note) -----------------------
|
| 44 |
+
# data-juicer validates this exists on the dataset before any op runs, and
|
| 45 |
+
# uses it as the default text_key for ops that don't specify their own.
|
| 46 |
+
text_keys: chosen
|
| 47 |
+
|
| 48 |
# --- Op pipeline (applied in order) -----------------------------------
|
| 49 |
process:
|
| 50 |
|
|
|
|
| 54 |
- text_length_filter:
|
| 55 |
min_len: 8
|
| 56 |
max_len: 32000
|
| 57 |
+
text_key: chosen
|
| 58 |
+
- text_length_filter:
|
| 59 |
+
min_len: 8
|
| 60 |
+
max_len: 32000
|
| 61 |
+
text_key: rejected
|
| 62 |
|
| 63 |
# 2. Word-count filter on response.
|
| 64 |
# Drops pairs with absurdly low (< 2 words) or high (> 4096 words)
|
|
|
|
| 66 |
- words_num_filter:
|
| 67 |
min_num: 2
|
| 68 |
max_num: 4096
|
| 69 |
+
text_key: chosen
|
| 70 |
+
- words_num_filter:
|
| 71 |
+
min_num: 2
|
| 72 |
+
max_num: 4096
|
| 73 |
+
text_key: rejected
|
| 74 |
|
| 75 |
# 3. Special-character filter.
|
| 76 |
# Drops responses where >50% of characters are non-alphabetic
|
| 77 |
# special chars (likely encoding errors or junk).
|
| 78 |
- special_characters_filter:
|
| 79 |
max_ratio: 0.5
|
| 80 |
+
text_key: chosen
|
| 81 |
+
- special_characters_filter:
|
| 82 |
+
max_ratio: 0.5
|
| 83 |
+
text_key: rejected
|
| 84 |
|
| 85 |
# 4. Per-conversation deduplication.
|
| 86 |
+
# Within the batch, drop records where the `chosen` field is a
|
| 87 |
+
# duplicate of another record's `chosen`. (data-juicer's
|
| 88 |
+
# document_deduplicator is per-batch hashing — full-corpus dedup is
|
| 89 |
+
# a separate op family.)
|
| 90 |
- document_deduplicator:
|
| 91 |
lowercase: true
|
| 92 |
ignore_non_character: true
|
| 93 |
+
text_key: chosen
|
|
|
|
| 94 |
|
| 95 |
# Notes:
|
| 96 |
# - We DO NOT run `pair_preference_mapper` because its default config may
|
|
@@ -71,26 +71,72 @@ class NormalizedDPOPair:
|
|
| 71 |
|
| 72 |
|
| 73 |
def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]:
|
| 74 |
-
"""Convert a DPOPair (or dict-shaped equivalent) into a data-juicer
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
"""
|
| 77 |
p = cast(dict[str, Any], pair)
|
|
|
|
|
|
|
| 78 |
return {
|
| 79 |
"state_id": p.get("state_id", ""),
|
| 80 |
"messages": p.get("state_messages", []),
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
"n_teachers_agreeing": p.get("n_teachers_agreeing", 0),
|
| 84 |
}
|
| 85 |
|
| 86 |
|
| 87 |
def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair:
|
| 88 |
-
"""Inverse — convert a data-juicer record back to NormalizedDPOPair.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
return NormalizedDPOPair(
|
| 90 |
state_id=rec.get("state_id", ""),
|
| 91 |
state_messages=rec.get("messages", []),
|
| 92 |
-
chosen_messages=
|
| 93 |
-
rejected_messages=
|
| 94 |
n_teachers_agreeing=rec.get("n_teachers_agreeing", 0),
|
| 95 |
metadata=rec.get("__dj_meta__", {}),
|
| 96 |
)
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]:
|
| 74 |
+
"""Convert a DPOPair (or dict-shaped equivalent) into a data-juicer record.
|
| 75 |
+
|
| 76 |
+
The record carries TWO shapes for chosen/rejected so that data-juicer ops
|
| 77 |
+
that expect string-typed text fields (e.g. ``text_length_filter``,
|
| 78 |
+
``words_num_filter``, ``special_characters_filter``,
|
| 79 |
+
``document_deduplicator``) work alongside chat-aware ops:
|
| 80 |
+
|
| 81 |
+
- ``chosen`` / ``rejected``: flat strings (drives the standard text ops
|
| 82 |
+
that read string fields via ``text_keys``).
|
| 83 |
+
- ``chosen_messages`` / ``rejected_messages``: chat-messages list
|
| 84 |
+
(one assistant turn each), preserving the multi-turn-aware shape.
|
| 85 |
+
|
| 86 |
+
The ``messages`` field carries the conversation context (matches
|
| 87 |
+
data-juicer's ``messages`` convention for chat-aware filters).
|
| 88 |
"""
|
| 89 |
p = cast(dict[str, Any], pair)
|
| 90 |
+
chosen_str = p.get("chosen", "") or ""
|
| 91 |
+
rejected_str = p.get("rejected", "") or ""
|
| 92 |
return {
|
| 93 |
"state_id": p.get("state_id", ""),
|
| 94 |
"messages": p.get("state_messages", []),
|
| 95 |
+
# Flat-string shape for length/word/special-char/dedup filters
|
| 96 |
+
# that expect text_keys to point at strings.
|
| 97 |
+
"chosen": chosen_str,
|
| 98 |
+
"rejected": rejected_str,
|
| 99 |
+
# Chat-messages shape for chat-aware ops and the NormalizedDPOPair
|
| 100 |
+
# round-trip.
|
| 101 |
+
"chosen_messages": [{"role": "assistant", "content": chosen_str}],
|
| 102 |
+
"rejected_messages": [{"role": "assistant", "content": rejected_str}],
|
| 103 |
"n_teachers_agreeing": p.get("n_teachers_agreeing", 0),
|
| 104 |
}
|
| 105 |
|
| 106 |
|
| 107 |
def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair:
|
| 108 |
+
"""Inverse — convert a data-juicer record back to NormalizedDPOPair.
|
| 109 |
+
|
| 110 |
+
Tolerates records that only carry one of the two shapes:
|
| 111 |
+
|
| 112 |
+
- If ``chosen_messages``/``rejected_messages`` are present, use them
|
| 113 |
+
directly.
|
| 114 |
+
- Otherwise wrap the flat-string ``chosen``/``rejected`` fields into
|
| 115 |
+
a single-assistant-turn messages list. This handles the case where
|
| 116 |
+
a data-juicer op rewrites the string field but doesn't touch the
|
| 117 |
+
messages field.
|
| 118 |
+
"""
|
| 119 |
+
def _to_messages(val: Any, fallback_str: Any) -> list[dict[str, Any]]:
|
| 120 |
+
if isinstance(val, list) and val:
|
| 121 |
+
return val # already chat-messages shape
|
| 122 |
+
if isinstance(fallback_str, str) and fallback_str:
|
| 123 |
+
return [{"role": "assistant", "content": fallback_str}]
|
| 124 |
+
if isinstance(fallback_str, list):
|
| 125 |
+
# Edge case: someone put the messages list in the flat field.
|
| 126 |
+
return fallback_str
|
| 127 |
+
return []
|
| 128 |
+
|
| 129 |
+
chosen_messages = _to_messages(
|
| 130 |
+
rec.get("chosen_messages"), rec.get("chosen", "")
|
| 131 |
+
)
|
| 132 |
+
rejected_messages = _to_messages(
|
| 133 |
+
rec.get("rejected_messages"), rec.get("rejected", "")
|
| 134 |
+
)
|
| 135 |
return NormalizedDPOPair(
|
| 136 |
state_id=rec.get("state_id", ""),
|
| 137 |
state_messages=rec.get("messages", []),
|
| 138 |
+
chosen_messages=chosen_messages,
|
| 139 |
+
rejected_messages=rejected_messages,
|
| 140 |
n_teachers_agreeing=rec.get("n_teachers_agreeing", 0),
|
| 141 |
metadata=rec.get("__dj_meta__", {}),
|
| 142 |
)
|
|
@@ -42,12 +42,24 @@ def _make_pair(
|
|
| 42 |
|
| 43 |
|
| 44 |
def test_dpo_pair_to_dj_record_shape():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
p = _make_pair("s1")
|
| 46 |
rec = _dpo_pair_to_dj_record(p)
|
| 47 |
assert rec["state_id"] == "s1"
|
| 48 |
assert rec["messages"] == [{"role": "user", "content": "What is 2+2?"}]
|
| 49 |
-
|
| 50 |
-
assert rec["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
assert rec["n_teachers_agreeing"] == 2
|
| 52 |
|
| 53 |
|
|
@@ -136,3 +148,180 @@ def test_record_handles_missing_optional_fields():
|
|
| 136 |
assert rec["state_id"] == "x"
|
| 137 |
assert rec["messages"] == [] # missing state_messages → empty list
|
| 138 |
assert rec["n_teachers_agreeing"] == 0 # missing → default 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
def test_dpo_pair_to_dj_record_shape():
|
| 45 |
+
"""Records carry BOTH flat-string and chat-messages shapes for chosen/rejected.
|
| 46 |
+
|
| 47 |
+
See default.yaml header for why: data-juicer's text_length_filter et al
|
| 48 |
+
consume the flat strings; chat-aware consumers and the round-trip use
|
| 49 |
+
the *_messages fields.
|
| 50 |
+
"""
|
| 51 |
p = _make_pair("s1")
|
| 52 |
rec = _dpo_pair_to_dj_record(p)
|
| 53 |
assert rec["state_id"] == "s1"
|
| 54 |
assert rec["messages"] == [{"role": "user", "content": "What is 2+2?"}]
|
| 55 |
+
# Flat-string shape (drives text_length_filter, words_num_filter, ...)
|
| 56 |
+
assert rec["chosen"] == "Four."
|
| 57 |
+
assert rec["rejected"] == "Five."
|
| 58 |
+
assert isinstance(rec["chosen"], str)
|
| 59 |
+
assert isinstance(rec["rejected"], str)
|
| 60 |
+
# Chat-messages shape (preserved for chat-aware ops)
|
| 61 |
+
assert rec["chosen_messages"] == [{"role": "assistant", "content": "Four."}]
|
| 62 |
+
assert rec["rejected_messages"] == [{"role": "assistant", "content": "Five."}]
|
| 63 |
assert rec["n_teachers_agreeing"] == 2
|
| 64 |
|
| 65 |
|
|
|
|
| 148 |
assert rec["state_id"] == "x"
|
| 149 |
assert rec["messages"] == [] # missing state_messages → empty list
|
| 150 |
assert rec["n_teachers_agreeing"] == 0 # missing → default 0
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ---------------------------------------------------------------------
|
| 154 |
+
# Dual-shape contract (Wave 13 review Suggestion 3)
|
| 155 |
+
# ---------------------------------------------------------------------
|
| 156 |
+
#
|
| 157 |
+
# data-juicer's text_length_filter / words_num_filter /
|
| 158 |
+
# special_characters_filter / document_deduplicator all expect string-typed
|
| 159 |
+
# fields under `text_keys`. Earlier the converter wrapped chosen/rejected
|
| 160 |
+
# into list-of-dicts (chat-messages), which would have caused those ops to
|
| 161 |
+
# crash or no-op silently. The fix carries BOTH shapes:
|
| 162 |
+
# - chosen / rejected → flat strings (filter ops)
|
| 163 |
+
# - chosen_messages / rejected_messages → list-of-dicts (chat-aware ops + round-trip)
|
| 164 |
+
#
|
| 165 |
+
# The tests below pin that contract.
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def test_record_chosen_rejected_are_flat_strings_for_dj_text_ops():
|
| 169 |
+
"""text_length_filter & friends expect text_keys to point at strings.
|
| 170 |
+
|
| 171 |
+
If we ever regress to wrapping `chosen`/`rejected` into list-of-dicts,
|
| 172 |
+
data-juicer's text-key ops break. Keep this red-line explicit.
|
| 173 |
+
"""
|
| 174 |
+
p = _make_pair(
|
| 175 |
+
"s_strings",
|
| 176 |
+
chosen="A long-enough chosen response.",
|
| 177 |
+
rejected="A long-enough rejected response.",
|
| 178 |
+
)
|
| 179 |
+
rec = _dpo_pair_to_dj_record(p)
|
| 180 |
+
|
| 181 |
+
assert isinstance(rec["chosen"], str)
|
| 182 |
+
assert isinstance(rec["rejected"], str)
|
| 183 |
+
assert rec["chosen"] == "A long-enough chosen response."
|
| 184 |
+
assert rec["rejected"] == "A long-enough rejected response."
|
| 185 |
+
# Sanity: text_length_filter style usage works without crashing.
|
| 186 |
+
assert len(rec["chosen"]) >= 8
|
| 187 |
+
assert len(rec["rejected"]) >= 8
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def test_record_chosen_rejected_messages_carry_chat_shape():
|
| 191 |
+
"""The *_messages variants preserve the chat-template-aware shape."""
|
| 192 |
+
p = _make_pair("s_msgs", chosen="hello world", rejected="goodbye world")
|
| 193 |
+
rec = _dpo_pair_to_dj_record(p)
|
| 194 |
+
|
| 195 |
+
assert isinstance(rec["chosen_messages"], list)
|
| 196 |
+
assert isinstance(rec["rejected_messages"], list)
|
| 197 |
+
assert rec["chosen_messages"] == [
|
| 198 |
+
{"role": "assistant", "content": "hello world"}
|
| 199 |
+
]
|
| 200 |
+
assert rec["rejected_messages"] == [
|
| 201 |
+
{"role": "assistant", "content": "goodbye world"}
|
| 202 |
+
]
|
| 203 |
+
# Both shapes must agree on content.
|
| 204 |
+
assert rec["chosen_messages"][0]["content"] == rec["chosen"]
|
| 205 |
+
assert rec["rejected_messages"][0]["content"] == rec["rejected"]
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def test_dj_record_to_normalized_uses_chat_messages_when_present():
|
| 209 |
+
"""When *_messages fields are present, the round-trip uses them directly
|
| 210 |
+
(does not re-wrap the flat string)."""
|
| 211 |
+
rec = {
|
| 212 |
+
"state_id": "s_present",
|
| 213 |
+
"messages": [{"role": "user", "content": "q"}],
|
| 214 |
+
"chosen": "some flat str — should be ignored when _messages present",
|
| 215 |
+
"rejected": "another flat str",
|
| 216 |
+
"chosen_messages": [
|
| 217 |
+
{"role": "assistant", "content": "real chosen"},
|
| 218 |
+
],
|
| 219 |
+
"rejected_messages": [
|
| 220 |
+
{"role": "assistant", "content": "real rejected"},
|
| 221 |
+
],
|
| 222 |
+
"n_teachers_agreeing": 4,
|
| 223 |
+
}
|
| 224 |
+
norm = _dj_record_to_normalized(rec)
|
| 225 |
+
assert norm.chosen_messages == [{"role": "assistant", "content": "real chosen"}]
|
| 226 |
+
assert norm.rejected_messages == [{"role": "assistant", "content": "real rejected"}]
|
| 227 |
+
assert norm.n_teachers_agreeing == 4
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def test_dj_record_to_normalized_falls_back_to_flat_strings():
|
| 231 |
+
"""When *_messages fields are absent (e.g. an op only rewrote the flat
|
| 232 |
+
string), the round-trip wraps the flat string into a single assistant
|
| 233 |
+
turn so downstream consumers always see the chat-messages shape."""
|
| 234 |
+
rec = {
|
| 235 |
+
"state_id": "s_fallback",
|
| 236 |
+
"messages": [{"role": "user", "content": "q"}],
|
| 237 |
+
"chosen": "rewritten chosen",
|
| 238 |
+
"rejected": "rewritten rejected",
|
| 239 |
+
# NOTE: no chosen_messages / rejected_messages
|
| 240 |
+
"n_teachers_agreeing": 1,
|
| 241 |
+
}
|
| 242 |
+
norm = _dj_record_to_normalized(rec)
|
| 243 |
+
assert norm.chosen_messages == [
|
| 244 |
+
{"role": "assistant", "content": "rewritten chosen"}
|
| 245 |
+
]
|
| 246 |
+
assert norm.rejected_messages == [
|
| 247 |
+
{"role": "assistant", "content": "rewritten rejected"}
|
| 248 |
+
]
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def test_round_trip_preserves_strings_through_skip_dj():
|
| 252 |
+
"""End-to-end shape sanity: pair → normalize(skip_dj=True) → assert
|
| 253 |
+
chat-messages content matches original strings."""
|
| 254 |
+
pairs = [
|
| 255 |
+
_make_pair("rt1", chosen="alpha", rejected="beta", n_teachers_agreeing=2),
|
| 256 |
+
_make_pair("rt2", chosen="gamma", rejected="delta", n_teachers_agreeing=3),
|
| 257 |
+
]
|
| 258 |
+
out = DJNormalizer(skip_dj=True).normalize(pairs)
|
| 259 |
+
assert len(out) == 2
|
| 260 |
+
assert out[0].chosen_messages[0]["content"] == "alpha"
|
| 261 |
+
assert out[0].rejected_messages[0]["content"] == "beta"
|
| 262 |
+
assert out[1].chosen_messages[0]["content"] == "gamma"
|
| 263 |
+
assert out[1].rejected_messages[0]["content"] == "delta"
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ---------------------------------------------------------------------
|
| 267 |
+
# End-to-end test against the real data-juicer engine.
|
| 268 |
+
# ---------------------------------------------------------------------
|
| 269 |
+
#
|
| 270 |
+
# Install path tried during Wave 13 fix: `pip install py-data-juicer`
|
| 271 |
+
# (the canonical PyPI distribution name; `data-juicer` redirects there).
|
| 272 |
+
# If that succeeded in the runtime environment, the e2e test runs the
|
| 273 |
+
# actual op-graph from default.yaml against a tiny fixture and verifies
|
| 274 |
+
# the dual-shape contract holds at the JSONL boundary. If data-juicer
|
| 275 |
+
# is NOT importable, the test is skipped.
|
| 276 |
+
|
| 277 |
+
try:
|
| 278 |
+
import data_juicer # type: ignore[import-not-found] # noqa: F401
|
| 279 |
+
_HAS_DJ = True
|
| 280 |
+
except ImportError:
|
| 281 |
+
_HAS_DJ = False
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@pytest.mark.skipif(not _HAS_DJ, reason="data-juicer not installed")
|
| 285 |
+
def test_dj_normalizer_e2e_default_recipe(tmp_path):
|
| 286 |
+
"""E2E: real data-juicer engine + default.yaml on a 3-record fixture.
|
| 287 |
+
|
| 288 |
+
Verifies:
|
| 289 |
+
1. The engine runs without a type-mismatch crash on the flat-string
|
| 290 |
+
text_keys (this is the bug Wave 13 Suggestion 3 flagged).
|
| 291 |
+
2. Output records survive the round-trip with both shapes intact.
|
| 292 |
+
"""
|
| 293 |
+
pairs = [
|
| 294 |
+
_make_pair(
|
| 295 |
+
"e2e1",
|
| 296 |
+
chosen="A reasonably long chosen response with several words.",
|
| 297 |
+
rejected="A reasonably long rejected response with several words.",
|
| 298 |
+
n_teachers_agreeing=2,
|
| 299 |
+
),
|
| 300 |
+
_make_pair(
|
| 301 |
+
"e2e2",
|
| 302 |
+
chosen="Another solid chosen completion that has enough text.",
|
| 303 |
+
rejected="Another solid rejected completion that has enough text.",
|
| 304 |
+
n_teachers_agreeing=3,
|
| 305 |
+
),
|
| 306 |
+
_make_pair(
|
| 307 |
+
"e2e3",
|
| 308 |
+
chosen="Third chosen example with sufficient length to pass.",
|
| 309 |
+
rejected="Third rejected example with sufficient length to pass.",
|
| 310 |
+
n_teachers_agreeing=2,
|
| 311 |
+
),
|
| 312 |
+
]
|
| 313 |
+
|
| 314 |
+
normalizer = DJNormalizer(skip_dj=False)
|
| 315 |
+
out = normalizer.normalize(pairs)
|
| 316 |
+
|
| 317 |
+
# Length filter, etc., should NOT drop any of these — all are
|
| 318 |
+
# comfortably within bounds. If we get back fewer than 1, the op-graph
|
| 319 |
+
# is misconfigured.
|
| 320 |
+
assert len(out) >= 1
|
| 321 |
+
for n in out:
|
| 322 |
+
assert isinstance(n, NormalizedDPOPair)
|
| 323 |
+
# Round-trip should always give us chat-messages shape on the way out.
|
| 324 |
+
assert isinstance(n.chosen_messages, list)
|
| 325 |
+
assert isinstance(n.rejected_messages, list)
|
| 326 |
+
assert n.chosen_messages and n.chosen_messages[0]["role"] == "assistant"
|
| 327 |
+
assert n.rejected_messages and n.rejected_messages[0]["role"] == "assistant"
|
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integration tests for ADR-007 distillation kwargs in compose_loss.
|
| 2 |
+
|
| 3 |
+
These tests exercise the wiring between `compose_loss` and the three
|
| 4 |
+
pluggable losses (SimPO, TAID, Entropy-Aware OPD). They use a tiny
|
| 5 |
+
hand-rolled language model wrapper (no HF, no TRL) so the tests run
|
| 6 |
+
in <1s on CPU and are isolated from external library churn.
|
| 7 |
+
|
| 8 |
+
Coverage requirements (from Wave 13 BLOCKER 2 fix):
|
| 9 |
+
(a) defaults reproduce existing compose_loss output bit-exact
|
| 10 |
+
(b) dpo_variant='simpo' produces a different total than dpo
|
| 11 |
+
(c) sdpo_wrapper='taid' with schedule_step=0 reproduces existing SDPO
|
| 12 |
+
when alpha_min=alpha_max=1.0
|
| 13 |
+
(d) sdpo_wrapper='taid' interpolates as expected when
|
| 14 |
+
schedule_step=total_steps/2
|
| 15 |
+
(e) sdpo_wrapper='entropy_opd' returns a finite differentiable scalar
|
| 16 |
+
(f) error case: sdpo_wrapper='taid' without taid_schedule_step raises
|
| 17 |
+
ValueError
|
| 18 |
+
"""
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import pytest
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
|
| 25 |
+
from composer_replication import LossComponents, compose_loss
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ----------------------------------------------------------------------
|
| 29 |
+
# Tiny LM stand-in
|
| 30 |
+
# ----------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
class TinyLM(nn.Module):
|
| 33 |
+
"""Minimal `nn.Module` with the HF-style `model(input_ids=...).logits` API.
|
| 34 |
+
|
| 35 |
+
Vocab=32, hidden=16, two-layer MLP head. Tiny enough that all tests
|
| 36 |
+
run in milliseconds on CPU.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, vocab: int = 32, hidden: int = 16, seed: int = 0):
|
| 40 |
+
super().__init__()
|
| 41 |
+
torch.manual_seed(seed)
|
| 42 |
+
self.embed = nn.Embedding(vocab, hidden)
|
| 43 |
+
self.fc = nn.Linear(hidden, hidden)
|
| 44 |
+
self.head = nn.Linear(hidden, vocab)
|
| 45 |
+
|
| 46 |
+
def forward(self, input_ids: torch.Tensor):
|
| 47 |
+
h = torch.tanh(self.fc(self.embed(input_ids)))
|
| 48 |
+
logits = self.head(h)
|
| 49 |
+
|
| 50 |
+
class _Out:
|
| 51 |
+
pass
|
| 52 |
+
out = _Out()
|
| 53 |
+
out.logits = logits
|
| 54 |
+
return out
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ----------------------------------------------------------------------
|
| 58 |
+
# Batch fixtures
|
| 59 |
+
# ----------------------------------------------------------------------
|
| 60 |
+
|
| 61 |
+
VOCAB = 32
|
| 62 |
+
B = 2
|
| 63 |
+
T = 8
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _base_batch(seed: int = 7, *, with_dpo: bool = True) -> dict[str, torch.Tensor]:
|
| 67 |
+
"""Build a deterministic input batch with all 3 channels populated."""
|
| 68 |
+
g = torch.Generator().manual_seed(seed)
|
| 69 |
+
inputs: dict[str, torch.Tensor] = {
|
| 70 |
+
"input_ids": torch.randint(0, VOCAB, (B, T), generator=g),
|
| 71 |
+
"response_mask": torch.zeros(B, T, dtype=torch.long),
|
| 72 |
+
"ctx_teacher_input_ids": torch.randint(0, VOCAB, (B, T), generator=g),
|
| 73 |
+
"sdpo_loss_mask": torch.zeros(B, T, dtype=torch.long),
|
| 74 |
+
}
|
| 75 |
+
# Mark the second half as response tokens so the LM-CE channel is non-trivial.
|
| 76 |
+
inputs["response_mask"][:, T // 2:] = 1
|
| 77 |
+
inputs["sdpo_loss_mask"][:, T // 2:] = 1
|
| 78 |
+
|
| 79 |
+
if with_dpo:
|
| 80 |
+
inputs["dpo_chosen_input_ids"] = torch.randint(0, VOCAB, (B, T), generator=g)
|
| 81 |
+
inputs["dpo_chosen_response_mask"] = torch.ones(B, T, dtype=torch.long)
|
| 82 |
+
inputs["dpo_rejected_input_ids"] = torch.randint(0, VOCAB, (B, T), generator=g)
|
| 83 |
+
inputs["dpo_rejected_response_mask"] = torch.ones(B, T, dtype=torch.long)
|
| 84 |
+
# Standard DPO needs ref logprobs; SimPO ignores them.
|
| 85 |
+
inputs["dpo_chosen_ref_logprobs"] = torch.randn(B, generator=g)
|
| 86 |
+
inputs["dpo_rejected_ref_logprobs"] = torch.randn(B, generator=g)
|
| 87 |
+
return inputs
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _model_seeded(seed: int = 0) -> TinyLM:
|
| 91 |
+
m = TinyLM(vocab=VOCAB, hidden=16, seed=seed)
|
| 92 |
+
m.eval() # Deterministic forward — no dropout.
|
| 93 |
+
return m
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ----------------------------------------------------------------------
|
| 97 |
+
# (a) Defaults reproduce existing output bit-exact
|
| 98 |
+
# ----------------------------------------------------------------------
|
| 99 |
+
|
| 100 |
+
def test_defaults_bit_exact_with_legacy_kwargs():
|
| 101 |
+
"""Calling compose_loss with new kwargs at their defaults must equal
|
| 102 |
+
calling it with only the legacy kwargs. Bit-exact: every channel +
|
| 103 |
+
total agree to 0 ULPs because the code path is identical.
|
| 104 |
+
"""
|
| 105 |
+
inputs = _base_batch()
|
| 106 |
+
|
| 107 |
+
model_a = _model_seeded(seed=0)
|
| 108 |
+
out_legacy = compose_loss(
|
| 109 |
+
model_a,
|
| 110 |
+
inputs,
|
| 111 |
+
alpha_sdpo=0.1,
|
| 112 |
+
beta_replay=0.05,
|
| 113 |
+
sdpo_jsd_beta=0.5,
|
| 114 |
+
sdpo_temperature=1.0,
|
| 115 |
+
replay_dpo_beta=0.1,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
model_b = _model_seeded(seed=0)
|
| 119 |
+
out_new = compose_loss(
|
| 120 |
+
model_b,
|
| 121 |
+
inputs,
|
| 122 |
+
alpha_sdpo=0.1,
|
| 123 |
+
beta_replay=0.05,
|
| 124 |
+
sdpo_jsd_beta=0.5,
|
| 125 |
+
sdpo_temperature=1.0,
|
| 126 |
+
replay_dpo_beta=0.1,
|
| 127 |
+
dpo_variant="dpo",
|
| 128 |
+
sdpo_wrapper="none",
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
assert isinstance(out_new, LossComponents)
|
| 132 |
+
assert torch.equal(out_legacy.lm_ce, out_new.lm_ce)
|
| 133 |
+
assert torch.equal(out_legacy.sdpo_jsd, out_new.sdpo_jsd)
|
| 134 |
+
assert torch.equal(out_legacy.trace_replay_dpo, out_new.trace_replay_dpo)
|
| 135 |
+
assert torch.equal(out_legacy.total, out_new.total)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ----------------------------------------------------------------------
|
| 139 |
+
# (b) dpo_variant='simpo' produces a different total than dpo
|
| 140 |
+
# ----------------------------------------------------------------------
|
| 141 |
+
|
| 142 |
+
def test_simpo_variant_changes_total():
|
| 143 |
+
"""SimPO uses average-logprob and drops the reference subtraction, so
|
| 144 |
+
it must produce a different (and finite) trace_replay_dpo + total."""
|
| 145 |
+
inputs = _base_batch()
|
| 146 |
+
|
| 147 |
+
model_a = _model_seeded(seed=0)
|
| 148 |
+
out_dpo = compose_loss(
|
| 149 |
+
model_a, inputs,
|
| 150 |
+
alpha_sdpo=0.0, # isolate channel 3
|
| 151 |
+
beta_replay=0.05,
|
| 152 |
+
dpo_variant="dpo",
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
model_b = _model_seeded(seed=0)
|
| 156 |
+
out_simpo = compose_loss(
|
| 157 |
+
model_b, inputs,
|
| 158 |
+
alpha_sdpo=0.0,
|
| 159 |
+
beta_replay=0.05,
|
| 160 |
+
dpo_variant="simpo",
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
assert torch.isfinite(out_simpo.total)
|
| 164 |
+
assert torch.isfinite(out_simpo.trace_replay_dpo)
|
| 165 |
+
# Different formulae => different values.
|
| 166 |
+
assert not torch.allclose(
|
| 167 |
+
out_dpo.trace_replay_dpo, out_simpo.trace_replay_dpo
|
| 168 |
+
)
|
| 169 |
+
assert not torch.allclose(out_dpo.total, out_simpo.total)
|
| 170 |
+
# Gradient flow check.
|
| 171 |
+
out_simpo.total.backward()
|
| 172 |
+
assert any(
|
| 173 |
+
p.grad is not None and torch.isfinite(p.grad).all()
|
| 174 |
+
for p in model_b.parameters()
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def test_simpo_does_not_require_ref_logprobs():
|
| 179 |
+
"""SimPO is reference-free; compose_loss should run when those keys are
|
| 180 |
+
absent from `inputs` (only when dpo_variant='simpo')."""
|
| 181 |
+
inputs = _base_batch()
|
| 182 |
+
inputs.pop("dpo_chosen_ref_logprobs")
|
| 183 |
+
inputs.pop("dpo_rejected_ref_logprobs")
|
| 184 |
+
|
| 185 |
+
model = _model_seeded(seed=0)
|
| 186 |
+
out = compose_loss(
|
| 187 |
+
model, inputs,
|
| 188 |
+
alpha_sdpo=0.0,
|
| 189 |
+
beta_replay=0.05,
|
| 190 |
+
dpo_variant="simpo",
|
| 191 |
+
)
|
| 192 |
+
assert torch.isfinite(out.total)
|
| 193 |
+
assert torch.isfinite(out.trace_replay_dpo)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# ----------------------------------------------------------------------
|
| 197 |
+
# (c) TAID with schedule_step=0, alpha_min=alpha_max=1.0 ==> pure SDPO
|
| 198 |
+
# ----------------------------------------------------------------------
|
| 199 |
+
|
| 200 |
+
def test_taid_alpha_one_recovers_sdpo():
|
| 201 |
+
"""With alpha_min=alpha_max=1.0, the TAID schedule is pinned at α=1
|
| 202 |
+
regardless of step. The blended target collapses to pure teacher,
|
| 203 |
+
making channel 2 numerically equivalent to the standard SDPO path
|
| 204 |
+
(modulo the softmax→log roundtrip in `taid_blended_logits`, which is
|
| 205 |
+
bit-equivalent for finite logits).
|
| 206 |
+
"""
|
| 207 |
+
inputs = _base_batch(with_dpo=False)
|
| 208 |
+
|
| 209 |
+
model_a = _model_seeded(seed=1)
|
| 210 |
+
out_sdpo = compose_loss(
|
| 211 |
+
model_a, inputs,
|
| 212 |
+
alpha_sdpo=0.1,
|
| 213 |
+
beta_replay=0.0, # disable channel 3 so we isolate channel 2
|
| 214 |
+
sdpo_wrapper="none",
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
model_b = _model_seeded(seed=1)
|
| 218 |
+
# Provide a student_init_logits snapshot — for α=1 its value doesn't
|
| 219 |
+
# affect the blended target (P_blended = teacher when α=1), so any
|
| 220 |
+
# valid-shape tensor works. Use the teacher shape.
|
| 221 |
+
with torch.no_grad():
|
| 222 |
+
init_logits = model_b(input_ids=inputs["ctx_teacher_input_ids"]).logits.clone()
|
| 223 |
+
inputs_taid = dict(inputs)
|
| 224 |
+
inputs_taid["student_init_logits"] = init_logits
|
| 225 |
+
|
| 226 |
+
out_taid = compose_loss(
|
| 227 |
+
model_b, inputs_taid,
|
| 228 |
+
alpha_sdpo=0.1,
|
| 229 |
+
beta_replay=0.0,
|
| 230 |
+
sdpo_wrapper="taid",
|
| 231 |
+
taid_schedule_step=0,
|
| 232 |
+
taid_total_steps=100,
|
| 233 |
+
taid_alpha_min=1.0,
|
| 234 |
+
taid_alpha_max=1.0,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# Same channel-2 value up to numerical roundtrip through softmax→log.
|
| 238 |
+
assert torch.allclose(out_sdpo.sdpo_jsd, out_taid.sdpo_jsd, atol=1e-5, rtol=1e-5)
|
| 239 |
+
assert torch.allclose(out_sdpo.total, out_taid.total, atol=1e-5, rtol=1e-5)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# ----------------------------------------------------------------------
|
| 243 |
+
# (d) TAID interpolates at schedule_step = total_steps / 2
|
| 244 |
+
# ----------------------------------------------------------------------
|
| 245 |
+
|
| 246 |
+
def test_taid_interpolates_at_midpoint():
|
| 247 |
+
"""At step=total_steps/2 with schedule='linear' and alpha_min=0,
|
| 248 |
+
alpha_max=1, the schedule yields α=0.5. The resulting loss must
|
| 249 |
+
differ from both endpoints (α=0 → init-only target, α=1 → pure SDPO),
|
| 250 |
+
and must be finite + differentiable.
|
| 251 |
+
"""
|
| 252 |
+
inputs = _base_batch(with_dpo=False)
|
| 253 |
+
|
| 254 |
+
# Build a single shared student_init_logits snapshot. We use a
|
| 255 |
+
# *different-seed* model to produce it so the blended target actually
|
| 256 |
+
# differs from the live student's teacher forward (otherwise α=0 and
|
| 257 |
+
# α=1 would both target the same distribution and the test would
|
| 258 |
+
# become vacuous).
|
| 259 |
+
snapshot_model = _model_seeded(seed=99)
|
| 260 |
+
with torch.no_grad():
|
| 261 |
+
init_logits = snapshot_model(
|
| 262 |
+
input_ids=inputs["ctx_teacher_input_ids"]
|
| 263 |
+
).logits.clone()
|
| 264 |
+
inputs = dict(inputs)
|
| 265 |
+
inputs["student_init_logits"] = init_logits
|
| 266 |
+
|
| 267 |
+
# Endpoint α=1 (pure SDPO target — init_logits ignored)
|
| 268 |
+
model_end = _model_seeded(seed=2)
|
| 269 |
+
out_alpha_one = compose_loss(
|
| 270 |
+
model_end, inputs,
|
| 271 |
+
alpha_sdpo=0.1, beta_replay=0.0,
|
| 272 |
+
sdpo_wrapper="taid",
|
| 273 |
+
taid_schedule_step=100, taid_total_steps=100,
|
| 274 |
+
taid_alpha_min=0.0, taid_alpha_max=1.0,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Endpoint α=0 (pure init target — teacher_logits ignored)
|
| 278 |
+
model_start = _model_seeded(seed=2)
|
| 279 |
+
out_alpha_zero = compose_loss(
|
| 280 |
+
model_start, inputs,
|
| 281 |
+
alpha_sdpo=0.1, beta_replay=0.0,
|
| 282 |
+
sdpo_wrapper="taid",
|
| 283 |
+
taid_schedule_step=0, taid_total_steps=100,
|
| 284 |
+
taid_alpha_min=0.0, taid_alpha_max=1.0,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Midpoint α=0.5
|
| 288 |
+
model_mid = _model_seeded(seed=2)
|
| 289 |
+
out_mid = compose_loss(
|
| 290 |
+
model_mid, inputs,
|
| 291 |
+
alpha_sdpo=0.1, beta_replay=0.0,
|
| 292 |
+
sdpo_wrapper="taid",
|
| 293 |
+
taid_schedule_step=50, taid_total_steps=100,
|
| 294 |
+
taid_alpha_min=0.0, taid_alpha_max=1.0,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# All finite.
|
| 298 |
+
for out in (out_alpha_zero, out_mid, out_alpha_one):
|
| 299 |
+
assert torch.isfinite(out.total), f"non-finite total: {out.total}"
|
| 300 |
+
assert torch.isfinite(out.sdpo_jsd), f"non-finite sdpo_jsd: {out.sdpo_jsd}"
|
| 301 |
+
|
| 302 |
+
# Midpoint must differ from both endpoints — different blended target.
|
| 303 |
+
assert not torch.allclose(
|
| 304 |
+
out_mid.sdpo_jsd, out_alpha_zero.sdpo_jsd, atol=1e-5
|
| 305 |
+
), "midpoint TAID matches α=0 endpoint — schedule not interpolating"
|
| 306 |
+
assert not torch.allclose(
|
| 307 |
+
out_mid.sdpo_jsd, out_alpha_one.sdpo_jsd, atol=1e-5
|
| 308 |
+
), "midpoint TAID matches α=1 endpoint — schedule not interpolating"
|
| 309 |
+
|
| 310 |
+
# Differentiable.
|
| 311 |
+
out_mid.total.backward()
|
| 312 |
+
assert any(
|
| 313 |
+
p.grad is not None and torch.isfinite(p.grad).all()
|
| 314 |
+
for p in model_mid.parameters()
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# ----------------------------------------------------------------------
|
| 319 |
+
# (e) Entropy-Aware OPD returns a finite differentiable scalar
|
| 320 |
+
# ----------------------------------------------------------------------
|
| 321 |
+
|
| 322 |
+
def test_entropy_opd_returns_finite_differentiable_scalar():
|
| 323 |
+
inputs = _base_batch(with_dpo=False)
|
| 324 |
+
|
| 325 |
+
model = _model_seeded(seed=3)
|
| 326 |
+
out = compose_loss(
|
| 327 |
+
model, inputs,
|
| 328 |
+
alpha_sdpo=0.1,
|
| 329 |
+
beta_replay=0.0,
|
| 330 |
+
sdpo_wrapper="entropy_opd",
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
assert isinstance(out, LossComponents)
|
| 334 |
+
assert out.total.shape == ()
|
| 335 |
+
assert torch.isfinite(out.total)
|
| 336 |
+
assert torch.isfinite(out.sdpo_jsd)
|
| 337 |
+
assert out.total.requires_grad
|
| 338 |
+
|
| 339 |
+
out.total.backward()
|
| 340 |
+
grads = [p.grad for p in model.parameters() if p.grad is not None]
|
| 341 |
+
assert len(grads) > 0
|
| 342 |
+
assert all(torch.isfinite(g).all() for g in grads)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# ----------------------------------------------------------------------
|
| 346 |
+
# (f) Error: sdpo_wrapper='taid' without taid_schedule_step
|
| 347 |
+
# ----------------------------------------------------------------------
|
| 348 |
+
|
| 349 |
+
def test_taid_requires_schedule_step():
|
| 350 |
+
inputs = _base_batch(with_dpo=False)
|
| 351 |
+
model = _model_seeded(seed=4)
|
| 352 |
+
with pytest.raises(ValueError, match="taid_schedule_step"):
|
| 353 |
+
compose_loss(
|
| 354 |
+
model, inputs,
|
| 355 |
+
alpha_sdpo=0.1, beta_replay=0.0,
|
| 356 |
+
sdpo_wrapper="taid",
|
| 357 |
+
taid_total_steps=100,
|
| 358 |
+
# taid_schedule_step omitted on purpose
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def test_taid_requires_total_steps():
|
| 363 |
+
inputs = _base_batch(with_dpo=False)
|
| 364 |
+
model = _model_seeded(seed=4)
|
| 365 |
+
with pytest.raises(ValueError, match="taid_total_steps"):
|
| 366 |
+
compose_loss(
|
| 367 |
+
model, inputs,
|
| 368 |
+
alpha_sdpo=0.1, beta_replay=0.0,
|
| 369 |
+
sdpo_wrapper="taid",
|
| 370 |
+
taid_schedule_step=0,
|
| 371 |
+
# taid_total_steps omitted on purpose
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def test_invalid_dpo_variant_raises():
|
| 376 |
+
inputs = _base_batch()
|
| 377 |
+
model = _model_seeded(seed=5)
|
| 378 |
+
with pytest.raises(ValueError, match="dpo_variant"):
|
| 379 |
+
compose_loss(
|
| 380 |
+
model, inputs,
|
| 381 |
+
dpo_variant="bogus", # type: ignore[arg-type]
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def test_invalid_sdpo_wrapper_raises():
|
| 386 |
+
inputs = _base_batch()
|
| 387 |
+
model = _model_seeded(seed=5)
|
| 388 |
+
with pytest.raises(ValueError, match="sdpo_wrapper"):
|
| 389 |
+
compose_loss(
|
| 390 |
+
model, inputs,
|
| 391 |
+
sdpo_wrapper="bogus", # type: ignore[arg-type]
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
# ----------------------------------------------------------------------
|
| 396 |
+
# Bonus: TAID accepts precomputed init logits
|
| 397 |
+
# ----------------------------------------------------------------------
|
| 398 |
+
|
| 399 |
+
def test_taid_accepts_precomputed_student_init_logits():
|
| 400 |
+
"""The preferred path: caller saves a step-0 logits snapshot and
|
| 401 |
+
passes it as `inputs['student_init_logits']`."""
|
| 402 |
+
inputs = _base_batch(with_dpo=False)
|
| 403 |
+
model = _model_seeded(seed=6)
|
| 404 |
+
|
| 405 |
+
# Pre-compute init logits the way a real trainer would.
|
| 406 |
+
with torch.no_grad():
|
| 407 |
+
init_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits.clone()
|
| 408 |
+
inputs["student_init_logits"] = init_logits
|
| 409 |
+
|
| 410 |
+
out = compose_loss(
|
| 411 |
+
model, inputs,
|
| 412 |
+
alpha_sdpo=0.1, beta_replay=0.0,
|
| 413 |
+
sdpo_wrapper="taid",
|
| 414 |
+
taid_schedule_step=10, taid_total_steps=100,
|
| 415 |
+
)
|
| 416 |
+
assert torch.isfinite(out.total)
|
|
@@ -0,0 +1,1484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# API Reference — composer-replication-framework
|
| 2 |
+
|
| 3 |
+
Complete reference for every public symbol in `composer_replication`. Source-of-truth is the `.py` files in `composer_replication/`; docstrings have been pulled verbatim where they exist and supplemented where missing.
|
| 4 |
+
|
| 5 |
+
**Legend**
|
| 6 |
+
|
| 7 |
+
- ⚠️ **UNTESTED-CONTRACT** — symbol exists and is callable, but its behaviour is not pinned by an automated test in `composer_replication/**/tests/` or `spikes/**/tests/`.
|
| 8 |
+
- 🟡 **SKELETON** — class/method body raises `NotImplementedError`; ships as design-of-record per ADR-005 / ADR-006.
|
| 9 |
+
|
| 10 |
+
**Module groups (in this document)**
|
| 11 |
+
|
| 12 |
+
1. `composer_replication` (top-level re-exports)
|
| 13 |
+
2. `composer_replication.loss`
|
| 14 |
+
3. `composer_replication.batch`
|
| 15 |
+
4. `composer_replication.opsd`
|
| 16 |
+
5. `composer_replication.distillation`
|
| 17 |
+
6. `composer_replication.teacher_replay`
|
| 18 |
+
7. `composer_replication.replaysim`
|
| 19 |
+
8. `composer_replication.ingestion` (+ `.claude_code`)
|
| 20 |
+
9. `composer_replication.hint_generator`
|
| 21 |
+
10. `composer_replication.trainer` (+ `.composer_trainer`, `.data_collator`)
|
| 22 |
+
11. `composer_replication.diloco`
|
| 23 |
+
12. `composer_replication.diloco.serverless` (+ `.executor`, `.allreduce`, `.modal`, `.hf_jobs`, `.replica_entrypoint`)
|
| 24 |
+
13. `composer_replication.recipes.prime_rl.composer_loss`
|
| 25 |
+
14. `composer_replication.recipes.monarch.actors`
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## 1. `composer_replication` — top-level package
|
| 30 |
+
|
| 31 |
+
The package re-exports the most common entry points from sub-modules. `__all__` is the canonical list of public top-level names.
|
| 32 |
+
|
| 33 |
+
### `composer_replication.__version__: str`
|
| 34 |
+
|
| 35 |
+
Package version string. Currently `"0.1.0"`.
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
import composer_replication
|
| 39 |
+
print(composer_replication.__version__) # "0.1.0"
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
### `composer_replication._DILOCO_AVAILABLE: bool`
|
| 43 |
+
|
| 44 |
+
`True` iff `torchft` is importable in the running Python environment (gates `make_diloco_outer_loop`). Set to `False` and `make_diloco_outer_loop` is set to `None` when `torchft` is missing.
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
from composer_replication import _DILOCO_AVAILABLE
|
| 48 |
+
if _DILOCO_AVAILABLE:
|
| 49 |
+
from composer_replication import make_diloco_outer_loop
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Re-exports
|
| 53 |
+
|
| 54 |
+
| Name | Source module |
|
| 55 |
+
|---|---|
|
| 56 |
+
| `compose_loss` | `composer_replication.loss` |
|
| 57 |
+
| `LossComponents` | `composer_replication.loss` |
|
| 58 |
+
| `build_batch` | `composer_replication.batch` |
|
| 59 |
+
| `generalized_jsd_loss` | `composer_replication.opsd` |
|
| 60 |
+
| `ClaudeCodeIngester` | `composer_replication.ingestion.claude_code` |
|
| 61 |
+
| `IngestionStats` | `composer_replication.ingestion.claude_code` |
|
| 62 |
+
| `SYSTEM_PROMPT` | `composer_replication.ingestion.claude_code` |
|
| 63 |
+
| `DEFAULT_TEACHERS` | `composer_replication.teacher_replay` |
|
| 64 |
+
| `DPOPair` | `composer_replication.teacher_replay` |
|
| 65 |
+
| `TeacherCallResult` | `composer_replication.teacher_replay` |
|
| 66 |
+
| `TeacherSpec` | `composer_replication.teacher_replay` |
|
| 67 |
+
| `TraceState` | `composer_replication.teacher_replay` |
|
| 68 |
+
| `extract_dpo_pairs` | `composer_replication.teacher_replay` |
|
| 69 |
+
| `replay_trace` | `composer_replication.teacher_replay` |
|
| 70 |
+
| `ComposerReplicationTrainer` | `composer_replication.trainer` |
|
| 71 |
+
| `make_diloco_outer_loop` | `composer_replication.diloco` (or `None` if `torchft` missing) |
|
| 72 |
+
|
| 73 |
+
See each source module below for full signatures.
|
| 74 |
+
|
| 75 |
+
---
|
| 76 |
+
|
| 77 |
+
## 2. `composer_replication.loss`
|
| 78 |
+
|
| 79 |
+
Verification-harness 3-channel loss. Free function, does not depend on `trl`.
|
| 80 |
+
|
| 81 |
+
### `class LossComponents`
|
| 82 |
+
|
| 83 |
+
```python
|
| 84 |
+
@dataclass
|
| 85 |
+
class LossComponents:
|
| 86 |
+
lm_ce: torch.Tensor
|
| 87 |
+
sdpo_jsd: torch.Tensor
|
| 88 |
+
trace_replay_dpo: torch.Tensor
|
| 89 |
+
total: torch.Tensor
|
| 90 |
+
|
| 91 |
+
def detached(self) -> dict[str, float]: ...
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
Per-channel breakdown of the total loss for logging and ablation. All four fields are scalar `torch.Tensor`s (`shape=()`); `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`.
|
| 95 |
+
|
| 96 |
+
**`detached() -> dict[str, float]`** — returns Python-float copies of all four fields with no grad. Useful for W&B logging.
|
| 97 |
+
|
| 98 |
+
```python
|
| 99 |
+
from composer_replication import compose_loss, build_batch
|
| 100 |
+
components = compose_loss(model, build_batch(tokenizer))
|
| 101 |
+
print(components.detached()) # {'lm_ce': 2.34, 'sdpo_jsd': 0.12, ...}
|
| 102 |
+
components.total.backward()
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### `compose_loss(model, inputs, *, ...) -> LossComponents`
|
| 106 |
+
|
| 107 |
+
```python
|
| 108 |
+
def compose_loss(
|
| 109 |
+
model: torch.nn.Module,
|
| 110 |
+
inputs: dict[str, torch.Tensor],
|
| 111 |
+
*,
|
| 112 |
+
alpha_sdpo: float = 0.1,
|
| 113 |
+
beta_replay: float = 0.05,
|
| 114 |
+
sdpo_jsd_beta: float = 0.5,
|
| 115 |
+
sdpo_temperature: float = 1.0,
|
| 116 |
+
sdpo_token_clip: float | None = None,
|
| 117 |
+
replay_dpo_beta: float = 0.1,
|
| 118 |
+
lm_ce_label_smoothing: float = 0.0,
|
| 119 |
+
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 120 |
+
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 121 |
+
taid_schedule_step: int | None = None,
|
| 122 |
+
taid_total_steps: int | None = None,
|
| 123 |
+
simpo_beta: float = 2.0,
|
| 124 |
+
simpo_gamma: float = 1.0,
|
| 125 |
+
taid_schedule: str = "linear",
|
| 126 |
+
taid_alpha_min: float = 0.0,
|
| 127 |
+
taid_alpha_max: float = 1.0,
|
| 128 |
+
entropy_opd_h_max: float | None = None,
|
| 129 |
+
) -> LossComponents
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
Compute `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`.
|
| 133 |
+
|
| 134 |
+
**Required keys in `inputs`**
|
| 135 |
+
|
| 136 |
+
- `input_ids`: `(B, T_s)` student rollout token ids.
|
| 137 |
+
- `response_mask`: `(B, T_s)` 1 on assistant-response tokens, 0 elsewhere.
|
| 138 |
+
|
| 139 |
+
**Optional keys** (channel auto-disables if missing OR if its weight = 0):
|
| 140 |
+
|
| 141 |
+
- SDPO: `ctx_teacher_input_ids` `(B, T_t)`, `sdpo_loss_mask` `(B, T_t)`.
|
| 142 |
+
- DPO (`dpo_variant="dpo"`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs` (precomputed).
|
| 143 |
+
- SimPO (`dpo_variant="simpo"`): same DPO ids/masks; reference logprobs are silently ignored.
|
| 144 |
+
- TAID (`sdpo_wrapper="taid"`): `student_init_logits` `(B, T_t, V)` precomputed, OR `student_init_input_ids` `(B, T_t)` for a no-grad-fallback forward.
|
| 145 |
+
|
| 146 |
+
**Parameters**
|
| 147 |
+
|
| 148 |
+
| Name | Type | Default | Meaning |
|
| 149 |
+
|---|---|---|---|
|
| 150 |
+
| `model` | `torch.nn.Module` | — | HF causal-LM. Must accept `input_ids=` and return an object with `.logits`. |
|
| 151 |
+
| `inputs` | `dict[str, torch.Tensor]` | — | Batch dict (see required/optional keys above). |
|
| 152 |
+
| `alpha_sdpo` | `float` | `0.1` | Weight on SDPO/JSD channel. `0.0` disables. |
|
| 153 |
+
| `beta_replay` | `float` | `0.05` | Weight on trace-replay DPO channel. `0.0` disables. |
|
| 154 |
+
| `sdpo_jsd_beta` | `float` | `0.5` | β param for `generalized_jsd_loss` (0=fwd KL, 0.5=JSD, 1=rev KL). |
|
| 155 |
+
| `sdpo_temperature` | `float` | `1.0` | Softmax temperature in SDPO. |
|
| 156 |
+
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clamp. |
|
| 157 |
+
| `replay_dpo_beta` | `float` | `0.1` | β in standard DPO logit. |
|
| 158 |
+
| `lm_ce_label_smoothing` | `float` | `0.0` | `F.cross_entropy(label_smoothing=)`. |
|
| 159 |
+
| `dpo_variant` | `Literal["dpo","simpo"]` | `"dpo"` | Channel-3 algorithm. |
|
| 160 |
+
| `sdpo_wrapper` | `Literal["none","taid","entropy_opd"]` | `"none"` | Channel-2 wrapper. |
|
| 161 |
+
| `taid_schedule_step` | `int \| None` | `None` | Required when `sdpo_wrapper="taid"`. |
|
| 162 |
+
| `taid_total_steps` | `int \| None` | `None` | Required when `sdpo_wrapper="taid"`. |
|
| 163 |
+
| `simpo_beta` | `float` | `2.0` | SimPO β (paper default). |
|
| 164 |
+
| `simpo_gamma` | `float` | `1.0` | SimPO target margin γ (paper default). |
|
| 165 |
+
| `taid_schedule` | `str` | `"linear"` | One of `"linear"`, `"cosine"`, `"exp"`. |
|
| 166 |
+
| `taid_alpha_min` | `float` | `0.0` | Lower α bound. |
|
| 167 |
+
| `taid_alpha_max` | `float` | `1.0` | Upper α bound. |
|
| 168 |
+
| `entropy_opd_h_max` | `float \| None` | `None` | Max-entropy normalizer; `None` ⇒ `log(V)`. |
|
| 169 |
+
|
| 170 |
+
**Returns** `LossComponents` (see above).
|
| 171 |
+
|
| 172 |
+
**Raises** `ValueError` if `dpo_variant` or `sdpo_wrapper` is unknown, if `sdpo_wrapper="taid"` is requested without both `taid_schedule_step` and `taid_total_steps`, or if TAID's frozen-init logits cannot be resolved (neither `student_init_logits` nor `student_init_input_ids` provided / shape mismatch).
|
| 173 |
+
|
| 174 |
+
```python
|
| 175 |
+
from composer_replication import compose_loss, build_batch
|
| 176 |
+
batch = build_batch(tokenizer)
|
| 177 |
+
out = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
|
| 178 |
+
out.total.backward()
|
| 179 |
+
print(out.detached())
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
---
|
| 183 |
+
|
| 184 |
+
## 3. `composer_replication.batch`
|
| 185 |
+
|
| 186 |
+
Verification-harness batch builder.
|
| 187 |
+
|
| 188 |
+
### `build_batch(tokenizer, *, ...) -> dict[str, torch.Tensor]`
|
| 189 |
+
|
| 190 |
+
```python
|
| 191 |
+
def build_batch(
|
| 192 |
+
tokenizer: Any,
|
| 193 |
+
*,
|
| 194 |
+
device: torch.device | str = "cpu",
|
| 195 |
+
seed: int = 42,
|
| 196 |
+
variant: str = "factorial",
|
| 197 |
+
align_sdpo_shapes: bool = False,
|
| 198 |
+
) -> dict[str, torch.Tensor]
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
Construct a full 3-channel batch from a real HF tokenizer. The DPO ref-logprobs are dummy tensors (the smoke verifies loss composition wires together, not the reference-policy precompute).
|
| 202 |
+
|
| 203 |
+
**Returned keys**: `input_ids`, `response_mask`, `ctx_teacher_input_ids`, `sdpo_loss_mask`, `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs`.
|
| 204 |
+
|
| 205 |
+
**Parameters**
|
| 206 |
+
|
| 207 |
+
| Name | Type | Default | Meaning |
|
| 208 |
+
|---|---|---|---|
|
| 209 |
+
| `tokenizer` | HF `AutoTokenizer` (duck-typed) | — | Must support `apply_chat_template` and `__call__`. |
|
| 210 |
+
| `device` | `torch.device \| str` | `"cpu"` | Target device for all returned tensors. |
|
| 211 |
+
| `seed` | `int` | `42` | Fixes `torch.manual_seed`. |
|
| 212 |
+
| `variant` | `str` | `"factorial"` | One of `"factorial"`, `"binary_search"`. |
|
| 213 |
+
| `align_sdpo_shapes` | `bool` | `False` | If True, truncate/pad `ctx_teacher_input_ids` to `input_ids` length so the SDPO channel actually fires. |
|
| 214 |
+
|
| 215 |
+
**Raises** `ValueError` if `variant` is unknown.
|
| 216 |
+
|
| 217 |
+
```python
|
| 218 |
+
from transformers import AutoTokenizer
|
| 219 |
+
from composer_replication import build_batch
|
| 220 |
+
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
| 221 |
+
batch = build_batch(tok, variant="factorial", align_sdpo_shapes=True)
|
| 222 |
+
print({k: v.shape for k, v in batch.items()})
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
---
|
| 226 |
+
|
| 227 |
+
## 4. `composer_replication.opsd`
|
| 228 |
+
|
| 229 |
+
Self-distillation generalized-JSD loss, lifted verbatim from `siyan-zhao/OPSD` (MIT) per ADR-006.
|
| 230 |
+
|
| 231 |
+
### `generalized_jsd_loss(student_logits, teacher_logits, labels=None, beta=0.5, ...) -> torch.Tensor`
|
| 232 |
+
|
| 233 |
+
```python
|
| 234 |
+
def generalized_jsd_loss(
|
| 235 |
+
student_logits: torch.Tensor,
|
| 236 |
+
teacher_logits: torch.Tensor,
|
| 237 |
+
labels: torch.Tensor | None = None,
|
| 238 |
+
beta: float = 0.5,
|
| 239 |
+
temperature: float = 1.0,
|
| 240 |
+
reduction: str = "batchmean",
|
| 241 |
+
logits_are_probs: bool = False,
|
| 242 |
+
top_k: int | None = None,
|
| 243 |
+
token_clip: float | None = None,
|
| 244 |
+
) -> torch.Tensor
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
Generalized JSD between student and teacher distributions. Same model on different contexts in the SDPO recipe; student and teacher params come from the SAME model.
|
| 248 |
+
|
| 249 |
+
**Parameters**
|
| 250 |
+
|
| 251 |
+
| Name | Type | Default | Meaning |
|
| 252 |
+
|---|---|---|---|
|
| 253 |
+
| `student_logits` | `Tensor (B, T, V)` | — | Student logits with grad. |
|
| 254 |
+
| `teacher_logits` | `Tensor (B, T, V)` | — | Teacher logits (no grad in SDPO). |
|
| 255 |
+
| `labels` | `Tensor (B, T) \| None` | `None` | Per-token mask. `-100` positions are ignored (HF convention). |
|
| 256 |
+
| `beta` | `float` in [0, 1] | `0.5` | 0=fwd KL, 1=rev KL, 0.5=symmetric JSD. |
|
| 257 |
+
| `temperature` | `float` | `1.0` | Softmax temperature. |
|
| 258 |
+
| `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
|
| 259 |
+
| `logits_are_probs` | `bool` | `False` | Skip softmax if inputs are already probabilities. |
|
| 260 |
+
| `top_k` | `int \| None` | `None` | Restrict KL to teacher's top-k tokens. |
|
| 261 |
+
| `token_clip` | `float \| None` | `None` | Clip per-token JSD for stability. |
|
| 262 |
+
|
| 263 |
+
**Returns** scalar tensor (or `(B, T)` if `reduction="none"`).
|
| 264 |
+
|
| 265 |
+
**Raises** `ValueError` for unknown `reduction`.
|
| 266 |
+
|
| 267 |
+
```python
|
| 268 |
+
import torch
|
| 269 |
+
from composer_replication.opsd import generalized_jsd_loss
|
| 270 |
+
s = torch.randn(2, 8, 32, requires_grad=True)
|
| 271 |
+
t = torch.randn(2, 8, 32)
|
| 272 |
+
loss = generalized_jsd_loss(s, t, beta=0.5, reduction="batchmean")
|
| 273 |
+
loss.backward()
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
---
|
| 277 |
+
|
| 278 |
+
## 5. `composer_replication.distillation`
|
| 279 |
+
|
| 280 |
+
Pluggable self-distillation losses (ADR-007). All pure PyTorch.
|
| 281 |
+
|
| 282 |
+
### `simpo_loss(chosen_avg_logprobs, rejected_avg_logprobs, *, beta=2.0, gamma=1.0) -> torch.Tensor`
|
| 283 |
+
|
| 284 |
+
```python
|
| 285 |
+
def simpo_loss(
|
| 286 |
+
chosen_avg_logprobs: torch.Tensor,
|
| 287 |
+
rejected_avg_logprobs: torch.Tensor,
|
| 288 |
+
*,
|
| 289 |
+
beta: float = 2.0,
|
| 290 |
+
gamma: float = 1.0,
|
| 291 |
+
) -> torch.Tensor
|
| 292 |
+
```
|
| 293 |
+
|
| 294 |
+
Reference-free DPO with target margin γ (Meng et al., NeurIPS 2024). `L = -log σ(β · (avg_logπ(c) − avg_logπ(r)) − γ)`.
|
| 295 |
+
|
| 296 |
+
**Parameters**
|
| 297 |
+
|
| 298 |
+
| Name | Type | Default | Meaning |
|
| 299 |
+
|---|---|---|---|
|
| 300 |
+
| `chosen_avg_logprobs` | `Tensor (B,)` | — | Per-sequence avg logprob over chosen response tokens. |
|
| 301 |
+
| `rejected_avg_logprobs` | `Tensor (B,)` | — | Same for rejected. |
|
| 302 |
+
| `beta` | `float` | `2.0` | Scaling factor (paper default). |
|
| 303 |
+
| `gamma` | `float` | `1.0` | Target margin (paper default). |
|
| 304 |
+
|
| 305 |
+
**Returns** scalar; **Raises** `ValueError` if shapes mismatch.
|
| 306 |
+
|
| 307 |
+
```python
|
| 308 |
+
import torch
|
| 309 |
+
from composer_replication.distillation import simpo_loss
|
| 310 |
+
loss = simpo_loss(torch.tensor([-2.1, -1.8]), torch.tensor([-3.0, -2.5]),
|
| 311 |
+
beta=2.0, gamma=1.0)
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
### `avg_sequence_logprob(model_logprobs, response_mask) -> torch.Tensor`
|
| 315 |
+
|
| 316 |
+
⚠️ UNTESTED-CONTRACT (helper exported from `simpo.py` but not asserted by a test).
|
| 317 |
+
|
| 318 |
+
```python
|
| 319 |
+
def avg_sequence_logprob(
|
| 320 |
+
model_logprobs: torch.Tensor,
|
| 321 |
+
response_mask: torch.Tensor,
|
| 322 |
+
) -> torch.Tensor
|
| 323 |
+
```
|
| 324 |
+
|
| 325 |
+
Convert `(B, T)` per-token logprobs + `(B, T)` response mask into `(B,)` per-sequence average over response tokens.
|
| 326 |
+
|
| 327 |
+
```python
|
| 328 |
+
from composer_replication.distillation.simpo import avg_sequence_logprob
|
| 329 |
+
import torch
|
| 330 |
+
lp = torch.randn(2, 8); m = torch.tensor([[0,0,1,1,1,0,0,0],[0,1,1,1,1,1,0,0]])
|
| 331 |
+
out = avg_sequence_logprob(lp, m) # shape (2,)
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
### `taid_loss(student_logits, teacher_logits, student_init_logits, *, schedule_step, total_steps, ...) -> torch.Tensor`
|
| 335 |
+
|
| 336 |
+
```python
|
| 337 |
+
def taid_loss(
|
| 338 |
+
student_logits: torch.Tensor,
|
| 339 |
+
teacher_logits: torch.Tensor,
|
| 340 |
+
student_init_logits: torch.Tensor,
|
| 341 |
+
*,
|
| 342 |
+
schedule_step: int,
|
| 343 |
+
total_steps: int,
|
| 344 |
+
schedule: str = "linear",
|
| 345 |
+
alpha_min: float = 0.0,
|
| 346 |
+
alpha_max: float = 1.0,
|
| 347 |
+
jsd_beta: float = 0.5,
|
| 348 |
+
temperature: float = 1.0,
|
| 349 |
+
reduction: str = "batchmean",
|
| 350 |
+
) -> torch.Tensor
|
| 351 |
+
```
|
| 352 |
+
|
| 353 |
+
TAID-wrapped generalized-JSD: target distribution is `(1-α)·P_student_init + α·P_teacher` with α annealed by `schedule_step / total_steps`. At α=0 you regularize toward init; at α=1 it reduces to plain SDPO.
|
| 354 |
+
|
| 355 |
+
**Parameters**
|
| 356 |
+
|
| 357 |
+
| Name | Type | Default | Meaning |
|
| 358 |
+
|---|---|---|---|
|
| 359 |
+
| `student_logits` | `Tensor (B,T,V)` | — | Current student (with grad). |
|
| 360 |
+
| `teacher_logits` | `Tensor (B,T,V)` | — | Teacher logits (no grad). |
|
| 361 |
+
| `student_init_logits` | `Tensor (B,T,V)` | — | Frozen step-0 student logits. Caller must keep a snapshot. |
|
| 362 |
+
| `schedule_step` | `int` | — | Current training step. |
|
| 363 |
+
| `total_steps` | `int` | — | Total planned steps. |
|
| 364 |
+
| `schedule` | `str` | `"linear"` | One of `"linear"`, `"cosine"`, `"exp"`. |
|
| 365 |
+
| `alpha_min`, `alpha_max` | `float`, `float` | `0.0`, `1.0` | Schedule range. |
|
| 366 |
+
| `jsd_beta` | `float` | `0.5` | β param of `generalized_jsd_loss`. |
|
| 367 |
+
| `temperature` | `float` | `1.0` | Softmax temperature. |
|
| 368 |
+
| `reduction` | `str` | `"batchmean"` | Forwarded to `generalized_jsd_loss`. |
|
| 369 |
+
|
| 370 |
+
**Raises** `ValueError` for unknown `schedule`, non-positive `total_steps`, negative `step`, or shape mismatch.
|
| 371 |
+
|
| 372 |
+
```python
|
| 373 |
+
from composer_replication.distillation import taid_loss
|
| 374 |
+
loss = taid_loss(s_logits, t_logits, init_logits,
|
| 375 |
+
schedule_step=500, total_steps=10_000, schedule="linear")
|
| 376 |
+
```
|
| 377 |
+
|
| 378 |
+
### `taid_alpha_schedule(step, total_steps, *, schedule="linear", alpha_min=0.0, alpha_max=1.0, warmup_frac=0.0) -> float`
|
| 379 |
+
|
| 380 |
+
```python
|
| 381 |
+
def taid_alpha_schedule(
|
| 382 |
+
step: int, total_steps: int, *,
|
| 383 |
+
schedule: str = "linear",
|
| 384 |
+
alpha_min: float = 0.0,
|
| 385 |
+
alpha_max: float = 1.0,
|
| 386 |
+
warmup_frac: float = 0.0,
|
| 387 |
+
) -> float
|
| 388 |
+
```
|
| 389 |
+
|
| 390 |
+
Compute α(t) for the TAID schedule. Returns a Python float in `[alpha_min, alpha_max]`.
|
| 391 |
+
|
| 392 |
+
**Raises** `ValueError` on `total_steps <= 0`, `step < 0`, or unknown `schedule`.
|
| 393 |
+
|
| 394 |
+
```python
|
| 395 |
+
from composer_replication.distillation.taid import taid_alpha_schedule
|
| 396 |
+
a = taid_alpha_schedule(step=500, total_steps=10000, schedule="cosine") # 0.012...
|
| 397 |
+
```
|
| 398 |
+
|
| 399 |
+
### `taid_blended_logits(student_init_logits, teacher_logits, alpha) -> torch.Tensor`
|
| 400 |
+
|
| 401 |
+
```python
|
| 402 |
+
def taid_blended_logits(
|
| 403 |
+
student_init_logits: torch.Tensor,
|
| 404 |
+
teacher_logits: torch.Tensor,
|
| 405 |
+
alpha: float,
|
| 406 |
+
) -> torch.Tensor
|
| 407 |
+
```
|
| 408 |
+
|
| 409 |
+
Return logits whose softmax is `(1-α)·P_student_init + α·P_teacher`. Mixes in probability space then `log()`.
|
| 410 |
+
|
| 411 |
+
**Raises** `ValueError` if `alpha` ∉ `[0,1]` or shapes differ.
|
| 412 |
+
|
| 413 |
+
```python
|
| 414 |
+
from composer_replication.distillation.taid import taid_blended_logits
|
| 415 |
+
blended = taid_blended_logits(init_logits, teacher_logits, alpha=0.3)
|
| 416 |
+
```
|
| 417 |
+
|
| 418 |
+
### `entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor`
|
| 419 |
+
|
| 420 |
+
```python
|
| 421 |
+
def entropy_aware_opd_loss(
|
| 422 |
+
student_logits: torch.Tensor,
|
| 423 |
+
teacher_logits: torch.Tensor,
|
| 424 |
+
*,
|
| 425 |
+
labels: torch.Tensor | None = None,
|
| 426 |
+
h_max: float | None = None,
|
| 427 |
+
temperature: float = 1.0,
|
| 428 |
+
reduction: str = "batchmean",
|
| 429 |
+
) -> torch.Tensor
|
| 430 |
+
```
|
| 431 |
+
|
| 432 |
+
Per-token mixture of forward and reverse KL gated by teacher entropy: `w(t) = clamp(H_teacher(t)/h_max, 0, 1)`. High-entropy tokens use forward KL (mode-covering), low-entropy tokens use reverse KL (mode-seeking).
|
| 433 |
+
|
| 434 |
+
**Parameters**
|
| 435 |
+
|
| 436 |
+
| Name | Type | Default | Meaning |
|
| 437 |
+
|---|---|---|---|
|
| 438 |
+
| `student_logits` | `Tensor (B,T,V)` | — | Student logits (grad). |
|
| 439 |
+
| `teacher_logits` | `Tensor (B,T,V)` | — | Teacher logits (no grad). |
|
| 440 |
+
| `labels` | `Tensor (B,T) \| None` | `None` | 0/1 mask, applied multiplicatively after the per-token mix. |
|
| 441 |
+
| `h_max` | `float \| None` | `None` ⇒ `log(V)` | Max-entropy normalizer. |
|
| 442 |
+
| `temperature` | `float` | `1.0` | Softmax temperature on both. |
|
| 443 |
+
| `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
|
| 444 |
+
|
| 445 |
+
**Raises** `ValueError` on shape mismatch (student vs teacher; labels vs per-token loss) or unknown `reduction`.
|
| 446 |
+
|
| 447 |
+
```python
|
| 448 |
+
from composer_replication.distillation import entropy_aware_opd_loss
|
| 449 |
+
loss = entropy_aware_opd_loss(s_logits, t_logits, temperature=1.0)
|
| 450 |
+
loss.backward()
|
| 451 |
+
```
|
| 452 |
+
|
| 453 |
+
### `teacher_entropy(teacher_logits) -> torch.Tensor`
|
| 454 |
+
|
| 455 |
+
⚠️ UNTESTED-CONTRACT (helper exposed from `entropy_aware_opd.py`'s `__all__` but not directly asserted).
|
| 456 |
+
|
| 457 |
+
Per-token entropy in nats. Input `(B,T,V)`, output `(B,T)`.
|
| 458 |
+
|
| 459 |
+
```python
|
| 460 |
+
from composer_replication.distillation.entropy_aware_opd import teacher_entropy
|
| 461 |
+
H = teacher_entropy(teacher_logits) # (B, T)
|
| 462 |
+
```
|
| 463 |
+
|
| 464 |
+
---
|
| 465 |
+
|
| 466 |
+
## 6. `composer_replication.teacher_replay`
|
| 467 |
+
|
| 468 |
+
N-teacher OpenRouter parallel client + DPO-pair extractor. `httpx` is lazy-imported inside `replay_trace`; the deterministic local logic is testable without it.
|
| 469 |
+
|
| 470 |
+
### `DEFAULT_TEACHERS: list[TeacherSpec]`
|
| 471 |
+
|
| 472 |
+
Three-teacher default set: `anthropic/claude-opus-4.7`, `openai/gpt-5`, `deepseek/deepseek-v4-pro` with paper-baseline OpenRouter pricing.
|
| 473 |
+
|
| 474 |
+
```python
|
| 475 |
+
from composer_replication.teacher_replay import DEFAULT_TEACHERS
|
| 476 |
+
print([t["slug"] for t in DEFAULT_TEACHERS])
|
| 477 |
+
```
|
| 478 |
+
|
| 479 |
+
### `class TeacherSpec(TypedDict)`
|
| 480 |
+
|
| 481 |
+
```python
|
| 482 |
+
class TeacherSpec(TypedDict):
|
| 483 |
+
slug: str
|
| 484 |
+
input_per_mtok: float
|
| 485 |
+
output_per_mtok: float
|
| 486 |
+
```
|
| 487 |
+
|
| 488 |
+
OpenRouter model slug + per-million-token pricing.
|
| 489 |
+
|
| 490 |
+
```python
|
| 491 |
+
spec: TeacherSpec = {"slug": "openai/gpt-5",
|
| 492 |
+
"input_per_mtok": 1.25, "output_per_mtok": 10.0}
|
| 493 |
+
```
|
| 494 |
+
|
| 495 |
+
### `class TraceState(TypedDict)`
|
| 496 |
+
|
| 497 |
+
```python
|
| 498 |
+
class TraceState(TypedDict):
|
| 499 |
+
state_id: str # unique within the trace
|
| 500 |
+
messages: list[dict] # OpenAI-style chat history up to (and incl.) this user prompt
|
| 501 |
+
student_action: str # what the student actually did at this step
|
| 502 |
+
```
|
| 503 |
+
|
| 504 |
+
One step of a frozen agentic trace. `student_action` is the raw text emitted by the student; teachers are queried with `messages` and asked to predict the assistant's next action.
|
| 505 |
+
|
| 506 |
+
```python
|
| 507 |
+
state: TraceState = {"state_id": "ex001::0042",
|
| 508 |
+
"messages": [{"role": "user", "content": "..."}],
|
| 509 |
+
"student_action": "[TOOL_USE] name=Read input={...}"}
|
| 510 |
+
```
|
| 511 |
+
|
| 512 |
+
### `class TeacherCallResult(TypedDict)`
|
| 513 |
+
|
| 514 |
+
```python
|
| 515 |
+
class TeacherCallResult(TypedDict):
|
| 516 |
+
state_id: str
|
| 517 |
+
teacher_slug: str
|
| 518 |
+
response_text: str | None # None on error
|
| 519 |
+
latency_s: float
|
| 520 |
+
prompt_tokens: int
|
| 521 |
+
completion_tokens: int
|
| 522 |
+
cost_usd: float
|
| 523 |
+
error: str | None # None on success
|
| 524 |
+
```
|
| 525 |
+
|
| 526 |
+
One row of N×T results from `replay_trace`.
|
| 527 |
+
|
| 528 |
+
```python
|
| 529 |
+
r: TeacherCallResult = {"state_id": "x", "teacher_slug": "openai/gpt-5",
|
| 530 |
+
"response_text": "ok", "latency_s": 1.2, "prompt_tokens": 100,
|
| 531 |
+
"completion_tokens": 5, "cost_usd": 0.001, "error": None}
|
| 532 |
+
```
|
| 533 |
+
|
| 534 |
+
### `class DPOPair(TypedDict)`
|
| 535 |
+
|
| 536 |
+
```python
|
| 537 |
+
class DPOPair(TypedDict):
|
| 538 |
+
state_id: str
|
| 539 |
+
state_messages: list[dict]
|
| 540 |
+
chosen: str # teacher-consensus action
|
| 541 |
+
rejected: str # student action
|
| 542 |
+
n_teachers_agreeing: int
|
| 543 |
+
```
|
| 544 |
+
|
| 545 |
+
One preference pair extracted from teacher-vs-student disagreement.
|
| 546 |
+
|
| 547 |
+
```python
|
| 548 |
+
p: DPOPair = {"state_id": "x", "state_messages": [...], "chosen": "...",
|
| 549 |
+
"rejected": "...", "n_teachers_agreeing": 2}
|
| 550 |
+
```
|
| 551 |
+
|
| 552 |
+
### `async replay_trace(states, teachers=DEFAULT_TEACHERS, max_total_usd=5.0, api_key=None) -> list[TeacherCallResult]`
|
| 553 |
+
|
| 554 |
+
```python
|
| 555 |
+
async def replay_trace(
|
| 556 |
+
states: Sequence[TraceState],
|
| 557 |
+
teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
|
| 558 |
+
max_total_usd: float = 5.0,
|
| 559 |
+
api_key: str | None = None,
|
| 560 |
+
) -> list[TeacherCallResult]
|
| 561 |
+
```
|
| 562 |
+
|
| 563 |
+
For each state, fan-out one parallel call per teacher via OpenRouter. Hard-caps cumulative spend at `max_total_usd` (stops after the offending state completes).
|
| 564 |
+
|
| 565 |
+
**Parameters**
|
| 566 |
+
|
| 567 |
+
| Name | Type | Default | Meaning |
|
| 568 |
+
|---|---|---|---|
|
| 569 |
+
| `states` | `Sequence[TraceState]` | — | Frozen trace, one entry per assistant turn. |
|
| 570 |
+
| `teachers` | `Sequence[TeacherSpec]` | `DEFAULT_TEACHERS` | Models to query in parallel. |
|
| 571 |
+
| `max_total_usd` | `float` | `5.0` | Cumulative spend cap. |
|
| 572 |
+
| `api_key` | `str \| None` | `None` | OpenRouter key; defaults to `OPENROUTER_API_KEY` env or `~/.hermes/.env`. |
|
| 573 |
+
|
| 574 |
+
**Returns** flat list of `TeacherCallResult`s (length `len(states) * len(teachers)` modulo budget cutoff).
|
| 575 |
+
|
| 576 |
+
**Raises** `RuntimeError` if `OPENROUTER_API_KEY` is not findable; `ImportError` if `httpx` is missing at call time.
|
| 577 |
+
|
| 578 |
+
```python
|
| 579 |
+
import asyncio
|
| 580 |
+
from composer_replication import replay_trace
|
| 581 |
+
results = asyncio.run(replay_trace(states=my_trace, max_total_usd=1.0))
|
| 582 |
+
```
|
| 583 |
+
|
| 584 |
+
### `extract_dpo_pairs(states, teacher_actions, agreement_threshold=2) -> list[DPOPair]`
|
| 585 |
+
|
| 586 |
+
```python
|
| 587 |
+
def extract_dpo_pairs(
|
| 588 |
+
states: Sequence[TraceState],
|
| 589 |
+
teacher_actions: Sequence[TeacherCallResult],
|
| 590 |
+
agreement_threshold: int = 2,
|
| 591 |
+
) -> list[DPOPair]
|
| 592 |
+
```
|
| 593 |
+
|
| 594 |
+
Group teacher_actions by `state_id`, normalize whitespace, and emit one `DPOPair` per state where ≥`agreement_threshold` teachers agreed on an action that differs from the student's. `chosen` is the original (un-normalized) teacher response text.
|
| 595 |
+
|
| 596 |
+
**Parameters**
|
| 597 |
+
|
| 598 |
+
| Name | Type | Default | Meaning |
|
| 599 |
+
|---|---|---|---|
|
| 600 |
+
| `states` | `Sequence[TraceState]` | — | Same as passed to `replay_trace`. |
|
| 601 |
+
| `teacher_actions` | `Sequence[TeacherCallResult]` | — | Output of `replay_trace`. |
|
| 602 |
+
| `agreement_threshold` | `int` | `2` | Min teachers that must agree for a pair to fire. |
|
| 603 |
+
|
| 604 |
+
**Returns** list of `DPOPair`. At most one pair per state (the most-agreed-upon action wins).
|
| 605 |
+
|
| 606 |
+
```python
|
| 607 |
+
from composer_replication import extract_dpo_pairs
|
| 608 |
+
pairs = extract_dpo_pairs(my_states, results, agreement_threshold=2)
|
| 609 |
+
```
|
| 610 |
+
|
| 611 |
+
### `save_pairs(pairs, path) -> None`
|
| 612 |
+
|
| 613 |
+
⚠️ UNTESTED-CONTRACT.
|
| 614 |
+
|
| 615 |
+
```python
|
| 616 |
+
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None
|
| 617 |
+
```
|
| 618 |
+
|
| 619 |
+
Write pairs to JSONL (one dict per line). Creates parent dirs.
|
| 620 |
+
|
| 621 |
+
```python
|
| 622 |
+
from composer_replication.teacher_replay import save_pairs
|
| 623 |
+
save_pairs(pairs, "/tmp/dpo_pairs.jsonl")
|
| 624 |
+
```
|
| 625 |
+
|
| 626 |
+
---
|
| 627 |
+
|
| 628 |
+
## 7. `composer_replication.replaysim`
|
| 629 |
+
|
| 630 |
+
ADR-004 normalization layer over `teacher_replay`. Re-exports `DPOPair`, `TeacherCallResult`, `extract_dpo_pairs`, `replay_trace` from `teacher_replay`.
|
| 631 |
+
|
| 632 |
+
### `class NormalizedDPOPair`
|
| 633 |
+
|
| 634 |
+
```python
|
| 635 |
+
@dataclass
|
| 636 |
+
class NormalizedDPOPair:
|
| 637 |
+
state_id: str
|
| 638 |
+
state_messages: list[dict[str, Any]]
|
| 639 |
+
chosen_messages: list[dict[str, Any]]
|
| 640 |
+
rejected_messages: list[dict[str, Any]]
|
| 641 |
+
n_teachers_agreeing: int
|
| 642 |
+
metadata: dict[str, Any]
|
| 643 |
+
```
|
| 644 |
+
|
| 645 |
+
Post-normalization shape. `chosen_messages`/`rejected_messages` are chat-format (`[{"role": "assistant", "content": ...}]`). `metadata` carries op-graph provenance, including `{"skipped": True}` when the normalizer was bypassed (`skip_dj=True`).
|
| 646 |
+
|
| 647 |
+
```python
|
| 648 |
+
from composer_replication.replaysim import NormalizedDPOPair
|
| 649 |
+
n = NormalizedDPOPair(state_id="x", state_messages=[],
|
| 650 |
+
chosen_messages=[{"role": "assistant", "content": "ok"}],
|
| 651 |
+
rejected_messages=[{"role": "assistant", "content": "no"}],
|
| 652 |
+
n_teachers_agreeing=2, metadata={})
|
| 653 |
+
```
|
| 654 |
+
|
| 655 |
+
### `class DJNormalizer`
|
| 656 |
+
|
| 657 |
+
```python
|
| 658 |
+
class DJNormalizer:
|
| 659 |
+
DEFAULT_RECIPE: ClassVar[Path] # composer_replication/recipes/replaysim/default.yaml
|
| 660 |
+
|
| 661 |
+
def __init__(
|
| 662 |
+
self,
|
| 663 |
+
recipe_path: str | os.PathLike[str] | None = None,
|
| 664 |
+
*,
|
| 665 |
+
skip_dj: bool = False,
|
| 666 |
+
) -> None: ...
|
| 667 |
+
|
| 668 |
+
def normalize(
|
| 669 |
+
self,
|
| 670 |
+
pairs: Iterable[DPOPair | dict[str, Any]],
|
| 671 |
+
) -> list[NormalizedDPOPair]: ...
|
| 672 |
+
```
|
| 673 |
+
|
| 674 |
+
`data-juicer`-backed normalizer. Pipeline: each `DPOPair` → JSONL record → `data_juicer.core.DefaultExecutor.run()` against the recipe → JSONL → `NormalizedDPOPair`.
|
| 675 |
+
|
| 676 |
+
**Constructor parameters**
|
| 677 |
+
|
| 678 |
+
| Name | Type | Default | Meaning |
|
| 679 |
+
|---|---|---|---|
|
| 680 |
+
| `recipe_path` | `str \| PathLike \| None` | `None` ⇒ default recipe | data-juicer YAML recipe path. |
|
| 681 |
+
| `skip_dj` | `bool` (kw-only) | `False` | If True: passthrough; records get `metadata={"skipped": True}` and no ops run. |
|
| 682 |
+
|
| 683 |
+
**`normalize(pairs) -> list[NormalizedDPOPair]`** runs the op-graph. Output may be shorter than input if filter ops drop records.
|
| 684 |
+
|
| 685 |
+
**Raises** `RuntimeError` at construction time if `skip_dj=False` and `data_juicer` is not importable. `FileNotFoundError` if `recipe_path` (default or explicit) is missing and `skip_dj=False`.
|
| 686 |
+
|
| 687 |
+
```python
|
| 688 |
+
from composer_replication.replaysim import DJNormalizer
|
| 689 |
+
norm = DJNormalizer(skip_dj=True)
|
| 690 |
+
out = norm.normalize(my_pairs)
|
| 691 |
+
```
|
| 692 |
+
|
| 693 |
+
### `async replay_and_normalize_trace(*, states, teachers=None, agreement_threshold=2, max_total_usd=5.0, normalizer=None, **replay_kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]`
|
| 694 |
+
|
| 695 |
+
```python
|
| 696 |
+
async def replay_and_normalize_trace(
|
| 697 |
+
*,
|
| 698 |
+
states: Any,
|
| 699 |
+
teachers: Any = None,
|
| 700 |
+
agreement_threshold: int = 2,
|
| 701 |
+
max_total_usd: float = 5.0,
|
| 702 |
+
normalizer: DJNormalizer | None = None,
|
| 703 |
+
**replay_kwargs: Any,
|
| 704 |
+
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
|
| 705 |
+
```
|
| 706 |
+
|
| 707 |
+
End-to-end async: replay → extract pairs → normalize.
|
| 708 |
+
|
| 709 |
+
**Parameters**
|
| 710 |
+
|
| 711 |
+
| Name | Type | Default | Meaning |
|
| 712 |
+
|---|---|---|---|
|
| 713 |
+
| `states` | `Sequence[TraceState]` | — | Frozen trace. |
|
| 714 |
+
| `teachers` | `Sequence[TeacherSpec] \| None` | `None` ⇒ defaults | Forwarded to `replay_trace`. |
|
| 715 |
+
| `agreement_threshold` | `int` | `2` | Forwarded to `extract_dpo_pairs`. |
|
| 716 |
+
| `max_total_usd` | `float` | `5.0` | Spend cap. |
|
| 717 |
+
| `normalizer` | `DJNormalizer \| None` | `None` ⇒ `DJNormalizer()` | Pass `DJNormalizer(skip_dj=True)` to bypass. |
|
| 718 |
+
| `**replay_kwargs` | `Any` | — | Forwarded to `replay_trace` (e.g. `api_key`). |
|
| 719 |
+
|
| 720 |
+
**Returns** `(raw_teacher_actions, normalized_pairs)`.
|
| 721 |
+
|
| 722 |
+
```python
|
| 723 |
+
import asyncio
|
| 724 |
+
from composer_replication.replaysim import replay_and_normalize_trace, DJNormalizer
|
| 725 |
+
raw, norm = asyncio.run(replay_and_normalize_trace(
|
| 726 |
+
states=my_states, normalizer=DJNormalizer(skip_dj=True)))
|
| 727 |
+
```
|
| 728 |
+
|
| 729 |
+
### `replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]`
|
| 730 |
+
|
| 731 |
+
⚠️ UNTESTED-CONTRACT (sync wrapper around the async function; tests call the async form via `asyncio.run`).
|
| 732 |
+
|
| 733 |
+
```python
|
| 734 |
+
def replay_and_normalize_trace_sync(*args, **kwargs) -> ...
|
| 735 |
+
```
|
| 736 |
+
|
| 737 |
+
Sync convenience wrapping `asyncio.run(replay_and_normalize_trace(...))`.
|
| 738 |
+
|
| 739 |
+
```python
|
| 740 |
+
from composer_replication.replaysim.normalize import replay_and_normalize_trace_sync
|
| 741 |
+
raw, norm = replay_and_normalize_trace_sync(states=my_states)
|
| 742 |
+
```
|
| 743 |
+
|
| 744 |
+
---
|
| 745 |
+
|
| 746 |
+
## 8. `composer_replication.ingestion` & `composer_replication.ingestion.claude_code`
|
| 747 |
+
|
| 748 |
+
Trace-source adapters (ADR-002). v0.1 supports Claude Code session JSONL.
|
| 749 |
+
|
| 750 |
+
### `SYSTEM_PROMPT: str`
|
| 751 |
+
|
| 752 |
+
Default synthetic system prompt injected at `messages[0]` for ingested traces (most Claude Code sessions don't write one). Truncated head: `"You are a senior software engineer working as a coding agent in a terminal environment..."`.
|
| 753 |
+
|
| 754 |
+
```python
|
| 755 |
+
from composer_replication import SYSTEM_PROMPT
|
| 756 |
+
print(SYSTEM_PROMPT[:60])
|
| 757 |
+
```
|
| 758 |
+
|
| 759 |
+
### `class IngestionStats`
|
| 760 |
+
|
| 761 |
+
```python
|
| 762 |
+
@dataclass
|
| 763 |
+
class IngestionStats:
|
| 764 |
+
n_records_total: int = 0
|
| 765 |
+
n_records_skipped: int = 0
|
| 766 |
+
n_states_emitted: int = 0
|
| 767 |
+
n_assistant_turns: int = 0
|
| 768 |
+
n_tool_use_blocks: int = 0
|
| 769 |
+
n_text_blocks: int = 0
|
| 770 |
+
skipped_subagent: int = 0
|
| 771 |
+
skipped_summary: int = 0
|
| 772 |
+
skipped_truncated_lines: int = 0
|
| 773 |
+
version_warnings: list[str] | None = None # initialized to [] in __post_init__
|
| 774 |
+
```
|
| 775 |
+
|
| 776 |
+
Counters populated by `ClaudeCodeIngester.ingest()` and exposed as `ingester.last_stats`.
|
| 777 |
+
|
| 778 |
+
```python
|
| 779 |
+
from composer_replication import IngestionStats
|
| 780 |
+
s = IngestionStats(n_records_total=5)
|
| 781 |
+
print(s.version_warnings) # []
|
| 782 |
+
```
|
| 783 |
+
|
| 784 |
+
### `class ClaudeCodeIngester`
|
| 785 |
+
|
| 786 |
+
```python
|
| 787 |
+
class ClaudeCodeIngester:
|
| 788 |
+
def __init__(
|
| 789 |
+
self,
|
| 790 |
+
*,
|
| 791 |
+
system_prompt: str = SYSTEM_PROMPT,
|
| 792 |
+
skip_sidechain: bool = True,
|
| 793 |
+
strip_thinking: bool = True,
|
| 794 |
+
max_history_tokens: int | None = None,
|
| 795 |
+
) -> None: ...
|
| 796 |
+
|
| 797 |
+
def ingest(self, path: Path) -> Iterator[TraceState]: ...
|
| 798 |
+
```
|
| 799 |
+
|
| 800 |
+
Convert a Claude Code session JSONL to a stream of `TraceState`s — one per assistant TURN (not per `tool_use` block).
|
| 801 |
+
|
| 802 |
+
**Constructor parameters**
|
| 803 |
+
|
| 804 |
+
| Name | Type | Default | Meaning |
|
| 805 |
+
|---|---|---|---|
|
| 806 |
+
| `system_prompt` | `str` | `SYSTEM_PROMPT` | Synthetic system message injected at history[0]. |
|
| 807 |
+
| `skip_sidechain` | `bool` | `True` | Skip subagent files (`agent-*.jsonl`) and records with `isSidechain=True`. |
|
| 808 |
+
| `strip_thinking` | `bool` | `True` | Remove `[THINKING]` blocks from history handed to teachers (kept inside `student_action`). |
|
| 809 |
+
| `max_history_tokens` | `int \| None` | `None` | ⚠️ UNTESTED-CONTRACT — accepted but currently not used to truncate. |
|
| 810 |
+
|
| 811 |
+
**`ingest(path) -> Iterator[TraceState]`**: generator over `TraceState` objects. Each turn's `state_id` is `f"{path.stem}::{idx:04d}"`. Side effect: replaces `self.last_stats` with a fresh `IngestionStats` and updates it as records stream.
|
| 812 |
+
|
| 813 |
+
```python
|
| 814 |
+
from pathlib import Path
|
| 815 |
+
from composer_replication import ClaudeCodeIngester
|
| 816 |
+
ing = ClaudeCodeIngester()
|
| 817 |
+
for state in ing.ingest(Path("session.jsonl")):
|
| 818 |
+
print(state["state_id"])
|
| 819 |
+
print(ing.last_stats.n_states_emitted)
|
| 820 |
+
```
|
| 821 |
+
|
| 822 |
+
---
|
| 823 |
+
|
| 824 |
+
## 9. `composer_replication.hint_generator`
|
| 825 |
+
|
| 826 |
+
⚠️ UNTESTED-CONTRACT (entire module — used by the data collator config but not pinned by a test).
|
| 827 |
+
|
| 828 |
+
Template-based hint registry for SDPO error-site injection.
|
| 829 |
+
|
| 830 |
+
### `class HintContext(TypedDict, total=False)`
|
| 831 |
+
|
| 832 |
+
```python
|
| 833 |
+
class HintContext(TypedDict, total=False):
|
| 834 |
+
error_kind: str
|
| 835 |
+
error_message: str
|
| 836 |
+
available_tools: list[str]
|
| 837 |
+
tool_name: str
|
| 838 |
+
tool_schema: dict
|
| 839 |
+
intent: str
|
| 840 |
+
```
|
| 841 |
+
|
| 842 |
+
Per-error context dict consumed by hint templates.
|
| 843 |
+
|
| 844 |
+
### `HINT_TEMPLATES: dict[str, Callable[[HintContext], str]]`
|
| 845 |
+
|
| 846 |
+
Default registry keys: `"tool_not_found"`, `"json_decode"`, `"type_error"`, `"runtime_error"`, `"repeated_failure"`.
|
| 847 |
+
|
| 848 |
+
### `dispatch(error_kind, ctx=None) -> str | None`
|
| 849 |
+
|
| 850 |
+
```python
|
| 851 |
+
def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None
|
| 852 |
+
```
|
| 853 |
+
|
| 854 |
+
Look up `error_kind` in `HINT_TEMPLATES`. Returns the template's hint text, or `None` if the kind is unknown.
|
| 855 |
+
|
| 856 |
+
```python
|
| 857 |
+
from composer_replication.hint_generator import dispatch
|
| 858 |
+
hint = dispatch("json_decode") # "Reminder: tool arguments must be valid JSON. ..."
|
| 859 |
+
```
|
| 860 |
+
|
| 861 |
+
### `register(error_kind, fn) -> None`
|
| 862 |
+
|
| 863 |
+
```python
|
| 864 |
+
def register(error_kind: str, fn: Callable[[HintContext], str]) -> None
|
| 865 |
+
```
|
| 866 |
+
|
| 867 |
+
Add or override a custom hint template.
|
| 868 |
+
|
| 869 |
+
```python
|
| 870 |
+
from composer_replication.hint_generator import register
|
| 871 |
+
register("my_error", lambda ctx: "Reminder: try X.")
|
| 872 |
+
```
|
| 873 |
+
|
| 874 |
+
### Individual template functions
|
| 875 |
+
|
| 876 |
+
⚠️ UNTESTED-CONTRACT — exported only via `HINT_TEMPLATES`, useful as building blocks:
|
| 877 |
+
|
| 878 |
+
- `hint_tool_not_found(ctx) -> str`
|
| 879 |
+
- `hint_json_decode(ctx) -> str`
|
| 880 |
+
- `hint_type_error(ctx) -> str`
|
| 881 |
+
- `hint_runtime_error(ctx) -> str`
|
| 882 |
+
- `hint_repeated_failure(ctx) -> str`
|
| 883 |
+
|
| 884 |
+
Each accepts a `HintContext` and returns hint text. Signatures are uniform: `Callable[[HintContext], str]`.
|
| 885 |
+
|
| 886 |
+
```python
|
| 887 |
+
from composer_replication.hint_generator import hint_tool_not_found
|
| 888 |
+
text = hint_tool_not_found({"available_tools": ["Read", "Write"]})
|
| 889 |
+
```
|
| 890 |
+
|
| 891 |
+
---
|
| 892 |
+
|
| 893 |
+
## 10. `composer_replication.trainer` & sub-modules
|
| 894 |
+
|
| 895 |
+
Production trainer (TRL `GRPOTrainer` subclass) plus data collator.
|
| 896 |
+
|
| 897 |
+
### `class ComposerReplicationTrainer`
|
| 898 |
+
|
| 899 |
+
```python
|
| 900 |
+
class ComposerReplicationTrainer(GRPOTrainer):
|
| 901 |
+
def __init__(
|
| 902 |
+
self,
|
| 903 |
+
*args: Any,
|
| 904 |
+
alpha_sdpo: float = 0.1,
|
| 905 |
+
beta_replay: float = 0.05,
|
| 906 |
+
sdpo_jsd_beta: float = 0.5,
|
| 907 |
+
sdpo_temperature: float = 1.0,
|
| 908 |
+
sdpo_token_clip: float | None = None,
|
| 909 |
+
replay_dpo_beta: float = 0.1,
|
| 910 |
+
**kwargs: Any,
|
| 911 |
+
) -> None: ...
|
| 912 |
+
|
| 913 |
+
def _compute_loss(
|
| 914 |
+
self,
|
| 915 |
+
model: torch.nn.Module,
|
| 916 |
+
inputs: dict[str, torch.Tensor],
|
| 917 |
+
) -> torch.Tensor: ...
|
| 918 |
+
```
|
| 919 |
+
|
| 920 |
+
`trl.GRPOTrainer` subclass that overrides `_compute_loss(model, inputs)` to compose `total = grpo + α·sdpo + β·trace_replay_dpo`. When `trl` is not installed, the parent class falls back to `object` so the module imports — but instantiation will fail because the parent's GRPO machinery is missing.
|
| 921 |
+
|
| 922 |
+
**Constructor (kw-only beyond GRPOTrainer's own `*args, **kwargs`)**
|
| 923 |
+
|
| 924 |
+
| Name | Type | Default | Meaning |
|
| 925 |
+
|---|---|---|---|
|
| 926 |
+
| `alpha_sdpo` | `float` | `0.1` | Channel-2 weight. |
|
| 927 |
+
| `beta_replay` | `float` | `0.05` | Channel-3 weight. |
|
| 928 |
+
| `sdpo_jsd_beta` | `float` | `0.5` | β for `generalized_jsd_loss`. |
|
| 929 |
+
| `sdpo_temperature` | `float` | `1.0` | SDPO softmax temperature. |
|
| 930 |
+
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clip. |
|
| 931 |
+
| `replay_dpo_beta` | `float` | `0.1` | DPO β. |
|
| 932 |
+
|
| 933 |
+
**`_compute_loss(model, inputs) -> torch.Tensor`** — overrides `GRPOTrainer._compute_loss`. Calls `super()._compute_loss` for channel 1, then `_compute_sdpo_loss` and `_compute_trace_replay_loss`, then composes. Logs per-channel components every `args.logging_steps` (default 50). **Raises** whatever `super()` raises (TRL-shaped errors).
|
| 934 |
+
|
| 935 |
+
**Internal methods (publicly accessible, exercised by spike tests)**
|
| 936 |
+
|
| 937 |
+
- ⚠️ UNTESTED-CONTRACT `_compute_sdpo_loss(model, inputs) -> torch.Tensor` — generalized-JSD between student forward and `ctx_teacher_input_ids` forward. Returns `0.0` (with grad) when `alpha_sdpo == 0`, the key is missing, or shapes mismatch. Logs a warning on shape mismatch.
|
| 938 |
+
- ⚠️ UNTESTED-CONTRACT `_compute_trace_replay_loss(model, inputs) -> torch.Tensor` — standard DPO over `dpo_chosen_*` and `dpo_rejected_*`, using precomputed `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`.
|
| 939 |
+
- ⚠️ UNTESTED-CONTRACT `@staticmethod _sequence_logprobs(model, input_ids, response_mask) -> torch.Tensor` — sum logprobs over response tokens; standard DPO accounting.
|
| 940 |
+
|
| 941 |
+
```python
|
| 942 |
+
from composer_replication import ComposerReplicationTrainer
|
| 943 |
+
trainer = ComposerReplicationTrainer(
|
| 944 |
+
model=my_model, args=my_grpo_args, train_dataset=ds,
|
| 945 |
+
data_collator=my_collator, alpha_sdpo=0.1, beta_replay=0.05,
|
| 946 |
+
)
|
| 947 |
+
# trainer.train() # uses overridden _compute_loss
|
| 948 |
+
```
|
| 949 |
+
|
| 950 |
+
### `class TraceTurn(TypedDict, total=False)` — `trainer.data_collator`
|
| 951 |
+
|
| 952 |
+
```python
|
| 953 |
+
class TraceTurn(TypedDict, total=False):
|
| 954 |
+
role: str # "user" | "assistant" | "tool"
|
| 955 |
+
content: str
|
| 956 |
+
tool_call: dict | None
|
| 957 |
+
tool_error: str | None
|
| 958 |
+
error_meta: dict
|
| 959 |
+
```
|
| 960 |
+
|
| 961 |
+
One turn of an agentic trace as consumed by `ComposerDataCollator`.
|
| 962 |
+
|
| 963 |
+
### `class TraceExample(TypedDict, total=False)` — `trainer.data_collator`
|
| 964 |
+
|
| 965 |
+
```python
|
| 966 |
+
class TraceExample(TypedDict, total=False):
|
| 967 |
+
trace_id: str
|
| 968 |
+
turns: list[TraceTurn]
|
| 969 |
+
final_reward: float
|
| 970 |
+
dpo_pairs: list[dict] | None
|
| 971 |
+
```
|
| 972 |
+
|
| 973 |
+
One training example: `(turns, optional dpo_pairs)`. `dpo_pairs` shape matches `DPOPair`.
|
| 974 |
+
|
| 975 |
+
### `class TokenizerLike` — `trainer.data_collator`
|
| 976 |
+
|
| 977 |
+
⚠️ UNTESTED-CONTRACT (duck-typed protocol; used as a type hint).
|
| 978 |
+
|
| 979 |
+
```python
|
| 980 |
+
class TokenizerLike:
|
| 981 |
+
pad_token_id: int
|
| 982 |
+
def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: ...
|
| 983 |
+
def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> str | list[int]: ...
|
| 984 |
+
```
|
| 985 |
+
|
| 986 |
+
Minimal protocol the collator needs. Compatible with HF `AutoTokenizer`.
|
| 987 |
+
|
| 988 |
+
### `class CollatorConfig` — `trainer.data_collator`
|
| 989 |
+
|
| 990 |
+
```python
|
| 991 |
+
@dataclass
|
| 992 |
+
class CollatorConfig:
|
| 993 |
+
max_seq_len: int = 4096
|
| 994 |
+
max_dpo_seq_len: int = 2048
|
| 995 |
+
pad_token_id: int = 0
|
| 996 |
+
ignore_index: int = -100
|
| 997 |
+
enable_sdpo: bool = True
|
| 998 |
+
hint_generator: Callable[[str, dict], str | None] | None = None
|
| 999 |
+
enable_replay_dpo: bool = True
|
| 1000 |
+
rlvr_reward_key: str = "final_reward"
|
| 1001 |
+
```
|
| 1002 |
+
|
| 1003 |
+
Tunables for `ComposerDataCollator`.
|
| 1004 |
+
|
| 1005 |
+
| Field | Default | Meaning |
|
| 1006 |
+
|---|---|---|
|
| 1007 |
+
| `max_seq_len` | `4096` | Truncation cap for student/teacher sequences. |
|
| 1008 |
+
| `max_dpo_seq_len` | `2048` | Truncation cap for DPO chosen/rejected sequences. |
|
| 1009 |
+
| `pad_token_id` | `0` | Padding token id. |
|
| 1010 |
+
| `ignore_index` | `-100` | HF "ignore in loss" sentinel for SDPO mask. |
|
| 1011 |
+
| `enable_sdpo` | `True` | Toggle channel-2 fields. |
|
| 1012 |
+
| `hint_generator` | `Callable[[str, dict], str \| None] \| None` (`None`) | `(error_kind, error_meta) -> hint_text`. SDPO is no-op without this. |
|
| 1013 |
+
| `enable_replay_dpo` | `True` | Toggle channel-3 fields. |
|
| 1014 |
+
| `rlvr_reward_key` | `"final_reward"` | Key in `TraceExample` to read scalar reward. |
|
| 1015 |
+
|
| 1016 |
+
```python
|
| 1017 |
+
from composer_replication.trainer.data_collator import CollatorConfig
|
| 1018 |
+
cfg = CollatorConfig(max_seq_len=2048, hint_generator=my_dispatch)
|
| 1019 |
+
```
|
| 1020 |
+
|
| 1021 |
+
### `class ComposerDataCollator` — `trainer.data_collator`
|
| 1022 |
+
|
| 1023 |
+
```python
|
| 1024 |
+
@dataclass
|
| 1025 |
+
class ComposerDataCollator:
|
| 1026 |
+
tokenizer: TokenizerLike
|
| 1027 |
+
config: CollatorConfig = field(default_factory=CollatorConfig)
|
| 1028 |
+
|
| 1029 |
+
def __call__(
|
| 1030 |
+
self, batch: Sequence[TraceExample]
|
| 1031 |
+
) -> dict[str, torch.Tensor]: ...
|
| 1032 |
+
```
|
| 1033 |
+
|
| 1034 |
+
Build trainer-ready batches from raw traces + optional DPO pairs.
|
| 1035 |
+
|
| 1036 |
+
**Output dict keys** (tested in `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`):
|
| 1037 |
+
|
| 1038 |
+
- Channel 1 (always): `input_ids`, `attention_mask`, `response_mask`, `rewards`.
|
| 1039 |
+
- Channel 2 (when `enable_sdpo=True` AND batch has at least one error site AND `hint_generator` is set): `ctx_teacher_input_ids`, `sdpo_loss_mask`.
|
| 1040 |
+
- Channel 3 (when `enable_replay_dpo=True` AND batch has at least one `dpo_pair`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`. (Reference logprobs are NOT computed here — the trainer does that pass.)
|
| 1041 |
+
|
| 1042 |
+
```python
|
| 1043 |
+
from composer_replication.trainer.data_collator import (
|
| 1044 |
+
ComposerDataCollator, CollatorConfig)
|
| 1045 |
+
collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
|
| 1046 |
+
batch = collator([{"trace_id": "x", "turns": [...], "final_reward": 1.0}])
|
| 1047 |
+
```
|
| 1048 |
+
|
| 1049 |
+
---
|
| 1050 |
+
|
| 1051 |
+
## 11. `composer_replication.diloco`
|
| 1052 |
+
|
| 1053 |
+
DiLoCo outer-loop wrapper around `torchft.local_sgd.DiLoCo`. Optional dep — when `torchft` is missing the package re-export `composer_replication.make_diloco_outer_loop` is `None`.
|
| 1054 |
+
|
| 1055 |
+
### Module-level attributes
|
| 1056 |
+
|
| 1057 |
+
- `DiLoCo: Any` — `torchft.local_sgd.DiLoCo` if importable else `None`.
|
| 1058 |
+
- `Manager: Any` — `torchft.manager.Manager` if importable else `None`.
|
| 1059 |
+
- `_DummyWork: Any` — `torchft.work._DummyWork` if importable else `None`.
|
| 1060 |
+
- `_TORCHFT_AVAILABLE: bool` — whether the imports succeeded.
|
| 1061 |
+
|
| 1062 |
+
```python
|
| 1063 |
+
from composer_replication.diloco import _TORCHFT_AVAILABLE, DiLoCo
|
| 1064 |
+
```
|
| 1065 |
+
|
| 1066 |
+
### `make_diloco_outer_loop(manager, model_fragments, inner_optimizer, *, ...) -> torchft.local_sgd.DiLoCo`
|
| 1067 |
+
|
| 1068 |
+
```python
|
| 1069 |
+
def make_diloco_outer_loop(
|
| 1070 |
+
manager: Any,
|
| 1071 |
+
model_fragments: list[torch.nn.Module],
|
| 1072 |
+
inner_optimizer: torch.optim.Optimizer,
|
| 1073 |
+
*,
|
| 1074 |
+
outer_lr: float = 0.7,
|
| 1075 |
+
outer_momentum: float = 0.9,
|
| 1076 |
+
nesterov: bool = True,
|
| 1077 |
+
sync_every: int = 100,
|
| 1078 |
+
fragment_sync_delay: int = 0,
|
| 1079 |
+
fragment_update_alpha: float = 0.0,
|
| 1080 |
+
) -> Any
|
| 1081 |
+
```
|
| 1082 |
+
|
| 1083 |
+
Construct a `torchft.DiLoCo` configured with framework-default hyperparams (DiLoCo paper §3.2: `lr=0.7, momentum=0.9, Nesterov`).
|
| 1084 |
+
|
| 1085 |
+
**Parameters**
|
| 1086 |
+
|
| 1087 |
+
| Name | Type | Default | Meaning |
|
| 1088 |
+
|---|---|---|---|
|
| 1089 |
+
| `manager` | `torchft.Manager` (or duck-typed `MockManager`) | — | Provides `allreduce`, `should_commit`, `current_step`, `start_quorum`, etc. |
|
| 1090 |
+
| `model_fragments` | `list[torch.nn.Module]` | — | One module for vanilla DiLoCo; N modules for Streaming DiLoCo. |
|
| 1091 |
+
| `inner_optimizer` | `torch.optim.Optimizer` | — | Inner-step optimizer (steps every batch). |
|
| 1092 |
+
| `outer_lr` | `float` | `0.7` | Outer SGD lr. |
|
| 1093 |
+
| `outer_momentum` | `float` | `0.9` | Outer SGD momentum. |
|
| 1094 |
+
| `nesterov` | `bool` | `True` | Nesterov momentum on outer SGD. |
|
| 1095 |
+
| `sync_every` | `int` | `100` | Inner steps per outer round. |
|
| 1096 |
+
| `fragment_sync_delay` | `int` | `0` | 0 = vanilla; >0 = Streaming DiLoCo (requires CUDA streams). |
|
| 1097 |
+
| `fragment_update_alpha` | `float` | `0.0` | 0 = full replacement on sync; >0 = exponential mix. |
|
| 1098 |
+
|
| 1099 |
+
**Returns** a `torchft.local_sgd.DiLoCo` instance — usable as a context manager.
|
| 1100 |
+
|
| 1101 |
+
**Raises** `RuntimeError` if `torchft` is not installed.
|
| 1102 |
+
|
| 1103 |
+
```python
|
| 1104 |
+
import torch
|
| 1105 |
+
from composer_replication.diloco import make_diloco_outer_loop
|
| 1106 |
+
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
|
| 1107 |
+
outer = make_diloco_outer_loop(manager=mgr, model_fragments=[model],
|
| 1108 |
+
inner_optimizer=opt, sync_every=100)
|
| 1109 |
+
with outer:
|
| 1110 |
+
for _ in range(N):
|
| 1111 |
+
opt.zero_grad(); loss.backward(); opt.step()
|
| 1112 |
+
```
|
| 1113 |
+
|
| 1114 |
+
---
|
| 1115 |
+
|
| 1116 |
+
## 12. `composer_replication.diloco.serverless`
|
| 1117 |
+
|
| 1118 |
+
ADR-005 serverless DiLoCo executors + object-store all-reduce.
|
| 1119 |
+
|
| 1120 |
+
### `class ReplicaHandle` — `serverless.executor`
|
| 1121 |
+
|
| 1122 |
+
```python
|
| 1123 |
+
@dataclass
|
| 1124 |
+
class ReplicaHandle:
|
| 1125 |
+
rank: int
|
| 1126 |
+
backend_name: str
|
| 1127 |
+
metadata: dict[str, Any] = field(default_factory=dict)
|
| 1128 |
+
```
|
| 1129 |
+
|
| 1130 |
+
Opaque handle returned by `ServerlessExecutor.launch_replicas`. `metadata` is backend-specific.
|
| 1131 |
+
|
| 1132 |
+
```python
|
| 1133 |
+
from composer_replication.diloco.serverless import ReplicaHandle
|
| 1134 |
+
h = ReplicaHandle(rank=0, backend_name="local_process",
|
| 1135 |
+
metadata={"pid": 12345})
|
| 1136 |
+
```
|
| 1137 |
+
|
| 1138 |
+
### `class ServerlessExecutor` (Protocol) — `serverless.executor`
|
| 1139 |
+
|
| 1140 |
+
```python
|
| 1141 |
+
@runtime_checkable
|
| 1142 |
+
class ServerlessExecutor(Protocol):
|
| 1143 |
+
backend_name: str
|
| 1144 |
+
supports_inter_replica_network: bool
|
| 1145 |
+
|
| 1146 |
+
def launch_replicas(
|
| 1147 |
+
self,
|
| 1148 |
+
n_replicas: int,
|
| 1149 |
+
entrypoint: str | Callable[..., Any],
|
| 1150 |
+
entrypoint_args: Mapping[str, Any],
|
| 1151 |
+
*,
|
| 1152 |
+
gpu: str | None = None,
|
| 1153 |
+
timeout: int = 3600,
|
| 1154 |
+
) -> list[ReplicaHandle]: ...
|
| 1155 |
+
|
| 1156 |
+
def poll(self, handle: ReplicaHandle) -> str: ...
|
| 1157 |
+
def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: ...
|
| 1158 |
+
def cancel(self, handle: ReplicaHandle) -> None: ...
|
| 1159 |
+
def collect(
|
| 1160 |
+
self, handles: list[ReplicaHandle], *, timeout: int | None = None,
|
| 1161 |
+
) -> list[dict[str, Any]]: ...
|
| 1162 |
+
```
|
| 1163 |
+
|
| 1164 |
+
Structural protocol for serverless backends.
|
| 1165 |
+
|
| 1166 |
+
- `launch_replicas(...)` returns `list[ReplicaHandle]` of length `n_replicas` in rank order. `entrypoint` is either an importable module path (uses `main()`) or a `module.function` path or a `Callable` (Local executor only). `entrypoint_args` may include `rank_env` (default `"REPLICA_RANK"`).
|
| 1167 |
+
- `poll(handle) -> str`: one of `"pending"`, `"running"`, `"succeeded"`, `"failed"`, `"cancelled"`.
|
| 1168 |
+
- `stream_logs(handle, n_lines=200) -> str`: best-effort recent stdout/stderr.
|
| 1169 |
+
- `cancel(handle) -> None`: best-effort.
|
| 1170 |
+
- `collect(handles, timeout=None) -> list[dict]`: blocks; each result dict has `rank`, `status`, `exit_code`, `error` (and `result` from `LocalProcessExecutor`).
|
| 1171 |
+
|
| 1172 |
+
```python
|
| 1173 |
+
from composer_replication.diloco.serverless import ServerlessExecutor
|
| 1174 |
+
def supports(x: ServerlessExecutor) -> bool:
|
| 1175 |
+
return isinstance(x, ServerlessExecutor) # runtime_checkable
|
| 1176 |
+
```
|
| 1177 |
+
|
| 1178 |
+
### `class LocalProcessExecutor` — `serverless.executor`
|
| 1179 |
+
|
| 1180 |
+
```python
|
| 1181 |
+
class LocalProcessExecutor:
|
| 1182 |
+
backend_name = "local_process"
|
| 1183 |
+
supports_inter_replica_network = True
|
| 1184 |
+
|
| 1185 |
+
def __init__(self) -> None: ...
|
| 1186 |
+
# implements ServerlessExecutor protocol
|
| 1187 |
+
```
|
| 1188 |
+
|
| 1189 |
+
Reference implementation using Python `multiprocessing` (`spawn` context). Used for tests, CI smokes, and local development with `file://` rendezvous.
|
| 1190 |
+
|
| 1191 |
+
`launch_replicas(...)`: emits a soft warning on `gpu != None` (local processes share whatever GPUs are visible). `metadata = {"pid": ..., "start_ts": ...}`.
|
| 1192 |
+
|
| 1193 |
+
```python
|
| 1194 |
+
from composer_replication.diloco.serverless import LocalProcessExecutor
|
| 1195 |
+
ex = LocalProcessExecutor()
|
| 1196 |
+
handles = ex.launch_replicas(
|
| 1197 |
+
n_replicas=2,
|
| 1198 |
+
entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
|
| 1199 |
+
entrypoint_args={"rendezvous_uri": "/tmp/run/", "world_size": 2,
|
| 1200 |
+
"trainer_module": "my.trainer"},
|
| 1201 |
+
)
|
| 1202 |
+
results = ex.collect(handles, timeout=60)
|
| 1203 |
+
```
|
| 1204 |
+
|
| 1205 |
+
### `class ObjectStoreAllReduce` — `serverless.allreduce`
|
| 1206 |
+
|
| 1207 |
+
```python
|
| 1208 |
+
class ObjectStoreAllReduce:
|
| 1209 |
+
def __init__(
|
| 1210 |
+
self,
|
| 1211 |
+
uri: str,
|
| 1212 |
+
rank: int,
|
| 1213 |
+
world_size: int,
|
| 1214 |
+
*,
|
| 1215 |
+
round_id: int | None = None,
|
| 1216 |
+
timeout_s: float = 1800.0,
|
| 1217 |
+
poll_interval_s: float = 1.0,
|
| 1218 |
+
) -> None: ...
|
| 1219 |
+
|
| 1220 |
+
@property
|
| 1221 |
+
def round_id(self) -> int: ...
|
| 1222 |
+
|
| 1223 |
+
def allreduce(
|
| 1224 |
+
self, tensor: torch.Tensor, *, name: str | None = None,
|
| 1225 |
+
) -> torch.Tensor: ...
|
| 1226 |
+
```
|
| 1227 |
+
|
| 1228 |
+
fsspec-backed pseudo-gradient rendezvous. `uri` accepts `s3://`, `gs://`, `az://`, `hf://`, `file://`, or a plain local path.
|
| 1229 |
+
|
| 1230 |
+
**Constructor parameters**
|
| 1231 |
+
|
| 1232 |
+
| Name | Type | Default | Meaning |
|
| 1233 |
+
|---|---|---|---|
|
| 1234 |
+
| `uri` | `str` | — | fsspec URI or local path. Trailing `/` enforced. |
|
| 1235 |
+
| `rank` | `int` | — | This replica's rank. |
|
| 1236 |
+
| `world_size` | `int` | — | Total replicas. |
|
| 1237 |
+
| `round_id` | `int \| None` (kw-only) | `None` ⇒ start at 0 | Initial round counter. |
|
| 1238 |
+
| `timeout_s` | `float` (kw-only) | `1800.0` | Per-`allreduce` timeout. |
|
| 1239 |
+
| `poll_interval_s` | `float` (kw-only) | `1.0` | Sleep between peer-file existence checks. |
|
| 1240 |
+
|
| 1241 |
+
**`allreduce(tensor, name=None) -> torch.Tensor`**: serializes `tensor.detach().cpu()` to `round_NNNNNN/rank_RRRR.pt`, blocks until all peers post, then averages. **Modifies `tensor` in place** AND returns it. Increments the internal `_round_counter`.
|
| 1242 |
+
|
| 1243 |
+
**Raises** `ValueError` on invalid `rank`, `RuntimeError` if non-local URI is requested without `fsspec` installed, `TimeoutError` if peers don't show up before `timeout_s`.
|
| 1244 |
+
|
| 1245 |
+
```python
|
| 1246 |
+
from composer_replication.diloco.serverless import ObjectStoreAllReduce
|
| 1247 |
+
import torch
|
| 1248 |
+
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
|
| 1249 |
+
g = torch.zeros(10)
|
| 1250 |
+
store.allreduce(g) # blocks for rank 1
|
| 1251 |
+
```
|
| 1252 |
+
|
| 1253 |
+
### `class MockManager` — `serverless.allreduce`
|
| 1254 |
+
|
| 1255 |
+
```python
|
| 1256 |
+
class MockManager:
|
| 1257 |
+
def __init__(self, store: ObjectStoreAllReduce) -> None: ...
|
| 1258 |
+
|
| 1259 |
+
# torchft.Manager-shaped surface:
|
| 1260 |
+
num_participants: int
|
| 1261 |
+
rank: int
|
| 1262 |
+
_use_async_quorum: bool # always False
|
| 1263 |
+
_step: int
|
| 1264 |
+
_state_dict_fns: dict[str, tuple[Any, Any]]
|
| 1265 |
+
|
| 1266 |
+
def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_ImmediateWork": ...
|
| 1267 |
+
def should_commit(self) -> bool: ...
|
| 1268 |
+
def start_quorum(self) -> None: ...
|
| 1269 |
+
def wait_quorum(self) -> int: ...
|
| 1270 |
+
def current_step(self) -> int: ...
|
| 1271 |
+
def allow_state_dict_read(self) -> None: ...
|
| 1272 |
+
def disallow_state_dict_read(self) -> None: ...
|
| 1273 |
+
def register_state_dict_fn(self, key: str, load_fn: Any, save_fn: Any) -> None: ...
|
| 1274 |
+
def is_leader(self) -> bool: ...
|
| 1275 |
+
```
|
| 1276 |
+
|
| 1277 |
+
Drop-in replacement for `torchft.Manager` that routes `allreduce` through `ObjectStoreAllReduce`. All other methods are no-ops or simple counters appropriate for single-shot serverless DiLoCo.
|
| 1278 |
+
|
| 1279 |
+
- `allreduce(tensor)` returns an `_ImmediateWork` whose `.wait()` is a no-op (the tensor is already averaged).
|
| 1280 |
+
- `should_commit()` always `True` (no fault-tolerance failover).
|
| 1281 |
+
- `start_quorum()` bumps `_step`.
|
| 1282 |
+
- `is_leader()` returns `rank == 0`.
|
| 1283 |
+
|
| 1284 |
+
```python
|
| 1285 |
+
from composer_replication.diloco.serverless import MockManager, ObjectStoreAllReduce
|
| 1286 |
+
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
|
| 1287 |
+
mgr = MockManager(store)
|
| 1288 |
+
# pass mgr into make_diloco_outer_loop(manager=mgr, ...)
|
| 1289 |
+
```
|
| 1290 |
+
|
| 1291 |
+
### `class _ImmediateWork` — `serverless.allreduce`
|
| 1292 |
+
|
| 1293 |
+
⚠️ UNTESTED-CONTRACT internal helper exported from `__all__`. `Work`-shaped wrapper with `.wait() -> True` and `.get_future() -> torch.futures.Future`. Consumed by torchft DiLoCo's `perform_sync`.
|
| 1294 |
+
|
| 1295 |
+
```python
|
| 1296 |
+
from composer_replication.diloco.serverless.allreduce import _ImmediateWork
|
| 1297 |
+
```
|
| 1298 |
+
|
| 1299 |
+
### `class ModalExecutor` — `serverless.modal`
|
| 1300 |
+
|
| 1301 |
+
🟡 SKELETON — raises `NotImplementedError`; see ADR-005. Class body documents the v0 implementation pattern (Modal `app.function` + `function.spawn(rank=...)`).
|
| 1302 |
+
|
| 1303 |
+
```python
|
| 1304 |
+
from composer_replication.diloco.serverless.modal import ModalExecutor
|
| 1305 |
+
# ModalExecutor() # would NotImplementedError when instantiated
|
| 1306 |
+
```
|
| 1307 |
+
|
| 1308 |
+
### `class HFJobsExecutor` — `serverless.hf_jobs`
|
| 1309 |
+
|
| 1310 |
+
🟡 SKELETON — raises `NotImplementedError`; see ADR-005. Class body documents the v0 pattern using `huggingface_hub.run_job` against `hf://datasets/.../` rendezvous.
|
| 1311 |
+
|
| 1312 |
+
```python
|
| 1313 |
+
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
|
| 1314 |
+
# instantiation will fail until v0 implementation lands
|
| 1315 |
+
```
|
| 1316 |
+
|
| 1317 |
+
### `replica_entrypoint.main(...)` — `serverless.replica_entrypoint`
|
| 1318 |
+
|
| 1319 |
+
```python
|
| 1320 |
+
def main(
|
| 1321 |
+
rendezvous_uri: str,
|
| 1322 |
+
world_size: int,
|
| 1323 |
+
trainer_module: str,
|
| 1324 |
+
trainer_fn: str = "train",
|
| 1325 |
+
trainer_kwargs: dict[str, Any] | None = None,
|
| 1326 |
+
) -> Any
|
| 1327 |
+
```
|
| 1328 |
+
|
| 1329 |
+
Script run by every replica. Reads `REPLICA_RANK` env var, builds `ObjectStoreAllReduce` + `MockManager`, imports `trainer_module`, and calls `getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...)`. Returns whatever the train fn returns.
|
| 1330 |
+
|
| 1331 |
+
**Raises** `RuntimeError` if `REPLICA_RANK` env var is missing; `ValueError` if rank ∉ `[0, world_size)`.
|
| 1332 |
+
|
| 1333 |
+
The `if __name__ == "__main__"` block accepts CLI flags `--rendezvous`, `--world-size`, `--trainer-module`, `--trainer-fn`, `--trainer-kwargs-json`.
|
| 1334 |
+
|
| 1335 |
+
```python
|
| 1336 |
+
# In-process invocation
|
| 1337 |
+
import os
|
| 1338 |
+
os.environ["REPLICA_RANK"] = "0"
|
| 1339 |
+
from composer_replication.diloco.serverless.replica_entrypoint import main
|
| 1340 |
+
result = main(rendezvous_uri="/tmp/run/", world_size=1,
|
| 1341 |
+
trainer_module="my.trainer", trainer_fn="train")
|
| 1342 |
+
```
|
| 1343 |
+
|
| 1344 |
+
---
|
| 1345 |
+
|
| 1346 |
+
## 13. `composer_replication.recipes.prime_rl.composer_loss`
|
| 1347 |
+
|
| 1348 |
+
PRIME-RL adapter (ADR-006). Maps PRIME-RL's `LossInputs` struct onto channel 1 (DPPO + KL on the importance ratio, mirroring PRIME-RL's upstream `default_loss_fn` at `prime_rl/trainer/rl/loss.py` lines 116-165). Channel 2 raises `NotImplementedError`; channel 3 is out of scope.
|
| 1349 |
+
|
| 1350 |
+
### `loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3) -> torch.Tensor`
|
| 1351 |
+
|
| 1352 |
+
```python
|
| 1353 |
+
def loss_fn(
|
| 1354 |
+
inputs: Any, # PRIME-RL's LossInputs (duck-typed)
|
| 1355 |
+
*,
|
| 1356 |
+
alpha_sdpo: float = 0.0,
|
| 1357 |
+
beta_dpo: float = 0.0,
|
| 1358 |
+
dppo_mask_high: float = 0.2,
|
| 1359 |
+
dppo_mask_low: float = 0.2,
|
| 1360 |
+
adv_tau: float = 1.0,
|
| 1361 |
+
kl_tau: float = 1e-3,
|
| 1362 |
+
) -> Any # torch.Tensor scalar
|
| 1363 |
+
```
|
| 1364 |
+
|
| 1365 |
+
PRIME-RL passes per-sample **1-D `(seq,)` tensors** (not batched). The function mirrors PRIME-RL's upstream DPPO+KL formula:
|
| 1366 |
+
|
| 1367 |
+
- Mask gate is on **probability-space** `probs_diff = exp(trainer_lp) - exp(inference_lp)` (NOT on the log-ratio).
|
| 1368 |
+
- A token is dropped iff its advantage sign matches the offending bound: positive-advantage tokens are dropped when `probs_diff > dppo_mask_high`, negative-advantage tokens when `probs_diff < -dppo_mask_low`. (PRIME-RL stores both bounds with `Field(..., ge=0)` and applies the sign internally.)
|
| 1369 |
+
- The PG term is `keep * (adv_tau * advantages) * exp(trainer_lp - inference_lp)` (importance-ratio corrected, not REINFORCE).
|
| 1370 |
+
- A KL penalty `kl_tau * log_importance_ratio**2` is added on the full `loss_mask` (DPPO masking does not gate it).
|
| 1371 |
+
- Reduction is a plain `sum()`; PRIME-RL's outer `compute_loss` divides by `loss_scale`.
|
| 1372 |
+
|
| 1373 |
+
**Parameters**
|
| 1374 |
+
|
| 1375 |
+
| Name | Type | Default | Meaning |
|
| 1376 |
+
|---|---|---|---|
|
| 1377 |
+
| `inputs` | PRIME-RL `LossInputs` (duck-typed) | — | Must expose `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` (all 1-D), and optionally `teacher_logprobs`. |
|
| 1378 |
+
| `alpha_sdpo` | `float` (kw-only) | `0.0` | Channel-2 weight. Must be `0` in v0; >0 → `NotImplementedError`. |
|
| 1379 |
+
| `beta_dpo` | `float` (kw-only) | `0.0` | Channel-3 weight. Non-zero emits a `UserWarning`. |
|
| 1380 |
+
| `dppo_mask_high` | `float` (kw-only), `>= 0` | `0.2` | Upper probability-diff threshold. PRIME-RL `DefaultLossConfig` default. |
|
| 1381 |
+
| `dppo_mask_low` | `float` (kw-only), `>= 0` | `0.2` | Magnitude of lower probability-diff threshold (sign flipped internally). PRIME-RL default. |
|
| 1382 |
+
| `adv_tau` | `float` (kw-only), `>= 0` | `1.0` | Advantage temperature. PRIME-RL default. |
|
| 1383 |
+
| `kl_tau` | `float` (kw-only), `>= 0` | `1e-3` | KL term temperature. PRIME-RL default. |
|
| 1384 |
+
|
| 1385 |
+
**Returns** scalar `torch.Tensor` (PRIME-RL's trainer calls `.backward()`).
|
| 1386 |
+
|
| 1387 |
+
**Raises** `ValueError` if any of `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` is not 1-D, or any of the four `>=0`-constrained knobs is negative. `NotImplementedError` if `alpha_sdpo > 0` (channel 2 deferred).
|
| 1388 |
+
|
| 1389 |
+
```python
|
| 1390 |
+
from composer_replication.recipes.prime_rl.composer_loss import loss_fn
|
| 1391 |
+
# In PRIME-RL config:
|
| 1392 |
+
# loss:
|
| 1393 |
+
# custom:
|
| 1394 |
+
# import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
|
| 1395 |
+
# kwargs:
|
| 1396 |
+
# dppo_mask_high: 0.2
|
| 1397 |
+
# dppo_mask_low: 0.2
|
| 1398 |
+
# adv_tau: 1.0
|
| 1399 |
+
# kl_tau: 1.0e-3
|
| 1400 |
+
```
|
| 1401 |
+
|
| 1402 |
+
---
|
| 1403 |
+
|
| 1404 |
+
## 14. `composer_replication.recipes.monarch.actors`
|
| 1405 |
+
|
| 1406 |
+
🟡 SKELETON module per ADR-006. Importable; classes raise `NotImplementedError` on instantiation. Documents the actor signatures so the recipe matrix is complete.
|
| 1407 |
+
|
| 1408 |
+
### `class TrainerActor` 🟡
|
| 1409 |
+
|
| 1410 |
+
```python
|
| 1411 |
+
class TrainerActor:
|
| 1412 |
+
backend = "monarch"
|
| 1413 |
+
role = "trainer"
|
| 1414 |
+
|
| 1415 |
+
def __init__(self) -> None: raise NotImplementedError(...)
|
| 1416 |
+
async def train_outer_step(self, batch_id: int) -> dict[str, Any]: raise NotImplementedError
|
| 1417 |
+
```
|
| 1418 |
+
|
| 1419 |
+
Hosts the framework's 3-channel composer trainer. Real impl deferred to v0.2+.
|
| 1420 |
+
|
| 1421 |
+
### `class GeneratorActor` 🟡
|
| 1422 |
+
|
| 1423 |
+
```python
|
| 1424 |
+
class GeneratorActor:
|
| 1425 |
+
backend = "monarch"
|
| 1426 |
+
role = "generator"
|
| 1427 |
+
def __init__(self) -> None: raise NotImplementedError(...)
|
| 1428 |
+
async def rollout(self, prompts: list[str]) -> list[str]: raise NotImplementedError
|
| 1429 |
+
```
|
| 1430 |
+
|
| 1431 |
+
vLLM-backed rollout actor.
|
| 1432 |
+
|
| 1433 |
+
### `class RewarderActor` 🟡
|
| 1434 |
+
|
| 1435 |
+
```python
|
| 1436 |
+
class RewarderActor:
|
| 1437 |
+
backend = "monarch"
|
| 1438 |
+
role = "rewarder"
|
| 1439 |
+
def __init__(self) -> None: raise NotImplementedError(...)
|
| 1440 |
+
async def score(self, completions: list[str]) -> list[float]: raise NotImplementedError
|
| 1441 |
+
```
|
| 1442 |
+
|
| 1443 |
+
verifiers-protocol rewarder.
|
| 1444 |
+
|
| 1445 |
+
### `class TeacherPoolActor` 🟡
|
| 1446 |
+
|
| 1447 |
+
```python
|
| 1448 |
+
class TeacherPoolActor:
|
| 1449 |
+
backend = "monarch"
|
| 1450 |
+
role = "teacher_pool"
|
| 1451 |
+
def __init__(self) -> None: raise NotImplementedError(...)
|
| 1452 |
+
```
|
| 1453 |
+
|
| 1454 |
+
Channel-3 teacher pool wrapping `composer_replication.teacher_replay`.
|
| 1455 |
+
|
| 1456 |
+
```python
|
| 1457 |
+
# All Monarch actors raise on instantiation in v0:
|
| 1458 |
+
from composer_replication.recipes.monarch.actors import TrainerActor
|
| 1459 |
+
# TrainerActor() # NotImplementedError
|
| 1460 |
+
```
|
| 1461 |
+
|
| 1462 |
+
---
|
| 1463 |
+
|
| 1464 |
+
## Notes on test coverage
|
| 1465 |
+
|
| 1466 |
+
Tested contracts (referenced spike/test paths):
|
| 1467 |
+
|
| 1468 |
+
- `compose_loss` + `LossComponents` + `build_batch`: `composer_replication/tests/test_compose_loss_integration.py`, `spikes/006-real-hf-model-smoke/tests/`.
|
| 1469 |
+
- `generalized_jsd_loss`: `spikes/005-integrated-trainer-skeleton/tests/test_opsd_loss.py`.
|
| 1470 |
+
- `simpo_loss`, `taid_loss`, `taid_alpha_schedule`, `taid_blended_logits`, `entropy_aware_opd_loss`: `composer_replication/distillation/tests/test_distillation_losses.py`.
|
| 1471 |
+
- `replay_trace`, `extract_dpo_pairs`, `DPOPair`, `TraceState`, `TeacherCallResult`, `TeacherSpec`, `DEFAULT_TEACHERS`: `spikes/005-integrated-trainer-skeleton/tests/test_teacher_replay.py`.
|
| 1472 |
+
- `DJNormalizer`, `NormalizedDPOPair`, `replay_and_normalize_trace`: `composer_replication/replaysim/tests/test_replaysim.py`.
|
| 1473 |
+
- `ClaudeCodeIngester`, `IngestionStats`, `SYSTEM_PROMPT`: `spikes/007-real-trace-ingestion/tests/`.
|
| 1474 |
+
- `ComposerDataCollator`, `CollatorConfig`, `TraceTurn`, `TraceExample`: `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`.
|
| 1475 |
+
- `ComposerReplicationTrainer._compute_loss` (composition arithmetic): `spikes/005-integrated-trainer-skeleton/tests/test_loss_composition_smoke.py`.
|
| 1476 |
+
- `make_diloco_outer_loop` + sign convention: `spikes/008-streaming-diloco/tests/test_diloco_smoke.py`.
|
| 1477 |
+
- `ObjectStoreAllReduce`, `MockManager`, `LocalProcessExecutor`, `ReplicaHandle`, `ServerlessExecutor`, `replica_entrypoint.main`: `composer_replication/diloco/serverless/tests/test_serverless_local.py`, `test_serverless_diloco_integration.py`.
|
| 1478 |
+
- `recipes.prime_rl.composer_loss.loss_fn`: `composer_replication/recipes/prime_rl/tests/test_composer_loss.py`.
|
| 1479 |
+
|
| 1480 |
+
Untested-contract symbols (⚠️) and skeletons (🟡) are flagged inline above.
|
| 1481 |
+
|
| 1482 |
+
---
|
| 1483 |
+
|
| 1484 |
+
**Document path**: `/mnt/e/CS/HF/composer-replication-framework/docs/API_REFERENCE.md`
|
|
@@ -0,0 +1,998 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# INTEGRATION_RECIPES.md — Wiring the 3-channel composer loss into your RL stack
|
| 2 |
+
|
| 3 |
+
> **Status:** Wave 14 release reference. Supersedes the historical
|
| 4 |
+
> [`docs/INTEGRATION_ARCHITECTURE.md`](INTEGRATION_ARCHITECTURE.md) (Recipes
|
| 5 |
+
> A–D), which is retained as background reading for the original
|
| 6 |
+
> mechanism-level diagrams.
|
| 7 |
+
>
|
| 8 |
+
> **Companion docs:**
|
| 9 |
+
> - [`docs/USER_GUIDE.md`](USER_GUIDE.md) — narrative walk-through, sections 1–8
|
| 10 |
+
> - [`docs/API_REFERENCE.md`](API_REFERENCE.md) — exact kwarg signatures
|
| 11 |
+
> - [`docs/TROUBLESHOOTING.md`](TROUBLESHOOTING.md) — error → fix index
|
| 12 |
+
> - [`docs/V3_SUBSTRATE_COVERAGE.md`](V3_SUBSTRATE_COVERAGE.md) — what each
|
| 13 |
+
> substrate covers
|
| 14 |
+
> - [`docs/adrs/ADR-006-rl-frameworks.md`](adrs/ADR-006-rl-frameworks.md) —
|
| 15 |
+
> why these five recipes and not others
|
| 16 |
+
|
| 17 |
+
This document is the canonical answer to **"how do I plug the 3-channel
|
| 18 |
+
composer loss into framework X?"** for the five frameworks the project
|
| 19 |
+
supports as of Wave 14:
|
| 20 |
+
|
| 21 |
+
1. [TRL `GRPOTrainer` subclass](#recipe-1--trl-grpotrainer-subclass)
|
| 22 |
+
2. [VeRL custom `adv_estimator` + DataProto extension](#recipe-2--verl-custom-adv_estimator--dataproto-extension)
|
| 23 |
+
3. [PRIME-RL custom-loss config](#recipe-3--prime-rl-customlossconfig)
|
| 24 |
+
4. [Serverless Decoupled DiLoCo (Modal / HF Jobs / SageMaker)](#recipe-4--serverless-decoupled-diloco)
|
| 25 |
+
5. [Monarch actor mesh (TorchForge-style topology)](#recipe-5--monarch-actor-mesh)
|
| 26 |
+
|
| 27 |
+
Each recipe follows the same seven-part template:
|
| 28 |
+
|
| 29 |
+
1. **When to use it** — decision criteria.
|
| 30 |
+
2. **Install command** — which optional extras of `composer-replication`.
|
| 31 |
+
3. **Minimum-viable Python script** — copy-pasteable, ≤ 60 lines.
|
| 32 |
+
4. **Decoupled DiLoCo wiring** — how `ServerlessExecutor` +
|
| 33 |
+
`ObjectStoreAllReduce` + `MockManager` layer on top.
|
| 34 |
+
5. **Distillation-loss wiring** — how to switch DPO → SimPO and add TAID
|
| 35 |
+
via `compose_loss(..., dpo_variant=..., sdpo_wrapper=...)` or the
|
| 36 |
+
recipe's own loss-config field.
|
| 37 |
+
6. **Cost ballpark** — GPU $/hr + API spend, sourced from
|
| 38 |
+
[`docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md`](research/DILOCO_SERVERLESS_RECONNAISSANCE.md).
|
| 39 |
+
7. **Known limitations as of Wave 14**.
|
| 40 |
+
|
| 41 |
+
A cross-recipe [comparison matrix](#comparison-matrix) closes the doc.
|
| 42 |
+
|
| 43 |
+
## TL;DR — the unified loss
|
| 44 |
+
|
| 45 |
+
For any of the five recipes, the v0.1 trainer step computes:
|
| 46 |
+
|
| 47 |
+
```
|
| 48 |
+
total_loss = grpo_loss
|
| 49 |
+
+ α * sdpo_kl_loss (channel 2 — Composer hint-distill;
|
| 50 |
+
optional TAID or Entropy-OPD wrapper)
|
| 51 |
+
+ β * trace_replay_loss (channel 3 — N-teacher DPO;
|
| 52 |
+
switchable to SimPO)
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
This is implemented once, in
|
| 56 |
+
[`composer_replication/loss.py::compose_loss`](../composer_replication/loss.py),
|
| 57 |
+
and re-used by every recipe via the kwargs documented in
|
| 58 |
+
[`API_REFERENCE.md`](API_REFERENCE.md). The verified surface is:
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
def compose_loss(
|
| 62 |
+
model,
|
| 63 |
+
inputs,
|
| 64 |
+
*,
|
| 65 |
+
alpha_sdpo: float = 0.1,
|
| 66 |
+
beta_replay: float = 0.05,
|
| 67 |
+
sdpo_jsd_beta: float = 0.5,
|
| 68 |
+
sdpo_temperature: float = 1.0,
|
| 69 |
+
sdpo_token_clip: float | None = None,
|
| 70 |
+
replay_dpo_beta: float = 0.1,
|
| 71 |
+
# ADR-007 extensions
|
| 72 |
+
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 73 |
+
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 74 |
+
taid_schedule_step: int | None = None,
|
| 75 |
+
taid_total_steps: int | None = None,
|
| 76 |
+
simpo_beta: float = 2.0,
|
| 77 |
+
simpo_gamma: float = 1.0,
|
| 78 |
+
taid_schedule: str = "linear",
|
| 79 |
+
taid_alpha_max: float = 1.0,
|
| 80 |
+
entropy_opd_h_max: float | None = None,
|
| 81 |
+
) -> torch.Tensor: ...
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
All five recipes below either call `compose_loss` directly or call a
|
| 85 |
+
thin per-framework adapter that forwards these kwargs unchanged.
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## Recipe 1 — TRL `GRPOTrainer` subclass
|
| 90 |
+
|
| 91 |
+
### 1. When to use it
|
| 92 |
+
|
| 93 |
+
This is the **default v0.0/v0.1 path** and the one we recommend for
|
| 94 |
+
~99% of users today. Pick TRL when:
|
| 95 |
+
|
| 96 |
+
- Your model fits on ≤ 32 GPUs (typically ≤ 70B-param FSDP).
|
| 97 |
+
- You already have a HuggingFace `model` + `tokenizer` + `datasets` flow.
|
| 98 |
+
- You want minimum integration cost — `ComposerReplicationTrainer` is a
|
| 99 |
+
single subclass override of `_compute_loss` over `trl.GRPOTrainer`,
|
| 100 |
+
no Ray, no actor mesh.
|
| 101 |
+
- You're doing single-host (one node, possibly multi-GPU FSDP) training.
|
| 102 |
+
|
| 103 |
+
Don't pick TRL when you need >100 B-param scale, when you must async-decouple
|
| 104 |
+
tool calls from the GPU loop, or when a Ray cluster is already in your stack
|
| 105 |
+
(in which case Recipe 2 is cheaper).
|
| 106 |
+
|
| 107 |
+
### 2. Install command
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
pip install -e ".[train,replaysim]"
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
The `train` extra pulls `trl>=0.12`, `peft`, `accelerate`, and `datasets`.
|
| 114 |
+
The `replaysim` extra pulls `data-juicer` for CPU-side DPO normalization
|
| 115 |
+
(channel 3 cleaning step). Add `[serverless]` if you also want Decoupled
|
| 116 |
+
DiLoCo (see step 4).
|
| 117 |
+
|
| 118 |
+
### 3. Minimum-viable Python script
|
| 119 |
+
|
| 120 |
+
```python
|
| 121 |
+
# train_trl.py — minimum viable Recipe 1
|
| 122 |
+
from datasets import load_dataset
|
| 123 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 124 |
+
from composer_replication import ComposerReplicationTrainer
|
| 125 |
+
|
| 126 |
+
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct" # swap for 7B once it works
|
| 127 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
| 128 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 129 |
+
dataset = load_dataset("trl-lib/tldr", split="train[:512]")
|
| 130 |
+
|
| 131 |
+
def reward_length(completions, **_):
|
| 132 |
+
return [-abs(len(c) - 64) for c in completions]
|
| 133 |
+
|
| 134 |
+
trainer = ComposerReplicationTrainer(
|
| 135 |
+
model = model,
|
| 136 |
+
processing_class = tokenizer,
|
| 137 |
+
reward_funcs = [reward_length],
|
| 138 |
+
train_dataset = dataset,
|
| 139 |
+
# Composer extras (defaults shown):
|
| 140 |
+
alpha_sdpo = 0.1,
|
| 141 |
+
beta_replay = 0.05,
|
| 142 |
+
sdpo_jsd_beta = 0.5,
|
| 143 |
+
sdpo_temperature = 1.0,
|
| 144 |
+
sdpo_token_clip = None,
|
| 145 |
+
replay_dpo_beta = 0.1,
|
| 146 |
+
)
|
| 147 |
+
trainer.train()
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
Channels 2 and 3 **auto-disable per step** when their inputs aren't
|
| 151 |
+
present in the batch (e.g. batches with no error sites get
|
| 152 |
+
`sdpo_kl=0`). Set `alpha_sdpo=0` / `beta_replay=0` to disable globally
|
| 153 |
+
for ablations.
|
| 154 |
+
|
| 155 |
+
### 4. Decoupled DiLoCo wiring
|
| 156 |
+
|
| 157 |
+
`ComposerReplicationTrainer` is a single-process trainer. To run N
|
| 158 |
+
replicas of it under Decoupled DiLoCo, layer the serverless stack on the
|
| 159 |
+
outside: each replica runs the script above; `MockManager` stands in for
|
| 160 |
+
`torchft.Manager` on the inner loop and `ObjectStoreAllReduce` runs the
|
| 161 |
+
outer-loop pseudo-gradient exchange:
|
| 162 |
+
|
| 163 |
+
```python
|
| 164 |
+
# diloco_replica.py — what each of the N replicas runs
|
| 165 |
+
import os
|
| 166 |
+
from composer_replication.diloco import make_diloco_outer_loop
|
| 167 |
+
from composer_replication.diloco.serverless import (
|
| 168 |
+
LocalProcessExecutor, ObjectStoreAllReduce, MockManager,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
rendezvous = ObjectStoreAllReduce(
|
| 172 |
+
uri = "s3://my-bucket/diloco-runs/run42/",
|
| 173 |
+
world_size = 4,
|
| 174 |
+
rank = int(os.environ["REPLICA_RANK"]),
|
| 175 |
+
)
|
| 176 |
+
manager = MockManager(allreduce=rendezvous)
|
| 177 |
+
# trainer.optimizer is the *inner* optimizer; the outer is built here:
|
| 178 |
+
outer = make_diloco_outer_loop(
|
| 179 |
+
inner_optimizer = trainer.optimizer,
|
| 180 |
+
manager = manager,
|
| 181 |
+
sync_every_h = 500,
|
| 182 |
+
)
|
| 183 |
+
trainer.add_callback(outer.callback()) # syncs every H inner steps
|
| 184 |
+
trainer.train()
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
The driver process spins these up with any `ServerlessExecutor`:
|
| 188 |
+
|
| 189 |
+
```python
|
| 190 |
+
# Wave 14: ModalExecutor / HFJobsExecutor are skeletons (raise NotImplementedError);
|
| 191 |
+
# use LocalProcessExecutor for testing. Swap once the cloud backends land.
|
| 192 |
+
executor = LocalProcessExecutor()
|
| 193 |
+
handles = executor.launch_replicas(
|
| 194 |
+
n_replicas = 4,
|
| 195 |
+
entrypoint = "diloco_replica.py",
|
| 196 |
+
entrypoint_args = {"rendezvous": rendezvous.uri,
|
| 197 |
+
"rank_env": "REPLICA_RANK"},
|
| 198 |
+
)
|
| 199 |
+
result = executor.collect(handles, timeout=3600)
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
### 5. Distillation-loss wiring
|
| 203 |
+
|
| 204 |
+
`ComposerReplicationTrainer` exposes the new ADR-007 channels via the
|
| 205 |
+
shared `compose_loss` kwargs — pass them through `**kwargs` on the
|
| 206 |
+
trainer and they're forwarded to `compose_loss`:
|
| 207 |
+
|
| 208 |
+
```python
|
| 209 |
+
trainer = ComposerReplicationTrainer(
|
| 210 |
+
model = model, processing_class = tokenizer,
|
| 211 |
+
reward_funcs = [reward_length], train_dataset = dataset,
|
| 212 |
+
# SimPO instead of DPO for channel 3:
|
| 213 |
+
dpo_variant = "simpo",
|
| 214 |
+
simpo_beta = 2.0,
|
| 215 |
+
simpo_gamma = 1.0,
|
| 216 |
+
# TAID-wrapped SDPO for channel 2:
|
| 217 |
+
sdpo_wrapper = "taid",
|
| 218 |
+
taid_schedule = "linear",
|
| 219 |
+
taid_schedule_step = 0, # bumped each call by your callback
|
| 220 |
+
taid_total_steps = 10_000,
|
| 221 |
+
taid_alpha_max = 1.0,
|
| 222 |
+
)
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
Or, equivalently, drop `entropy_opd` in for `taid` if you want
|
| 226 |
+
per-token entropy-gated forward/reverse KL instead of the
|
| 227 |
+
linear-blend interpolation. SimPO does **not** require reference
|
| 228 |
+
log-probs (channel 3 batches with `dpo_chosen_ref_logprobs` /
|
| 229 |
+
`dpo_rejected_ref_logprobs` set are silently ignored).
|
| 230 |
+
|
| 231 |
+
### 6. Cost ballpark
|
| 232 |
+
|
| 233 |
+
- **GPU**: single host, `g5.12xlarge` ($5.67/hr) or RunPod 4×A100-80GB
|
| 234 |
+
(~$5–9/hr) gets you Qwen2.5-7B at moderate throughput. For Qwen2.5-72B
|
| 235 |
+
you'll want 2–4× H100 — `p5.48xlarge` (~$98/hr on AWS, ~$25–30/hr on
|
| 236 |
+
Lambda Cloud / RunPod community).
|
| 237 |
+
- **API**: channel 3 teacher replay via OpenRouter — verified
|
| 238 |
+
~$0.98/trace at 50 steps × 3 teachers (spike 001). For a 100-trace
|
| 239 |
+
curriculum that's ~$100 in teacher tokens.
|
| 240 |
+
- **Storage**: negligible until you turn on DiLoCo (then see Recipe 4).
|
| 241 |
+
|
| 242 |
+
### 7. Known limitations as of Wave 14
|
| 243 |
+
|
| 244 |
+
- **Tool calls block the GPU.** TRL's rollout is synchronous; long
|
| 245 |
+
tool-call latency idles the trainer. Async-decouple via Recipe 2/3/5
|
| 246 |
+
if this matters.
|
| 247 |
+
- **No native multi-node.** TRL is single-process; multi-host scaling is
|
| 248 |
+
via Decoupled DiLoCo (Recipe 4) on top, not via TRL itself.
|
| 249 |
+
- **vLLM weight sync is co-located** — no resharding between FSDP and TP.
|
| 250 |
+
At 70B+ this becomes the bottleneck and you should move to Recipe 2.
|
| 251 |
+
- **`reward_funcs` must be Python callables** that return `list[float]`;
|
| 252 |
+
shell-out reward graders need a wrapper.
|
| 253 |
+
|
| 254 |
+
---
|
| 255 |
+
|
| 256 |
+
## Recipe 2 — VeRL custom `adv_estimator` + DataProto extension
|
| 257 |
+
|
| 258 |
+
### 1. When to use it
|
| 259 |
+
|
| 260 |
+
Pick VeRL when:
|
| 261 |
+
|
| 262 |
+
- You need >70B-param scale or >32-GPU multi-host, *and* a Ray cluster
|
| 263 |
+
is acceptable in your stack.
|
| 264 |
+
- You're already using or willing to adopt **3D-HybridEngine** for
|
| 265 |
+
efficient FSDP↔TP weight resharding (verified ~5× weight-sync speed-up
|
| 266 |
+
vs co-located vLLM at 70B+).
|
| 267 |
+
- You need async multi-turn rollouts where tool-call latency must not
|
| 268 |
+
block the GPU loop. VeRL's `AsyncServer` + `AgentLoop` is the
|
| 269 |
+
best-in-class option here.
|
| 270 |
+
- You want extension points the framework's authors *expect* third
|
| 271 |
+
parties to use — the `@register_adv_est("...")` decorator and the
|
| 272 |
+
`DataProto` extension contract are first-class APIs.
|
| 273 |
+
|
| 274 |
+
Don't pick VeRL if you're <7B-param or single-host (overkill —
|
| 275 |
+
Recipe 1's Trainer subclass is one file, not a Ray cluster).
|
| 276 |
+
|
| 277 |
+
### 2. Install command
|
| 278 |
+
|
| 279 |
+
```bash
|
| 280 |
+
pip install -e ".[replaysim]"
|
| 281 |
+
pip install verl # not packaged as an extra; pinned at >=0.3
|
| 282 |
+
# Optional, for the Composer adapter:
|
| 283 |
+
pip install -e ".[serverless]" # for Decoupled DiLoCo on top
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
The framework's verl adapter lives at
|
| 287 |
+
`composer_replication.recipes.verl` (currently shape-only — see
|
| 288 |
+
[Limitations](#7-known-limitations-as-of-wave-14-2) below).
|
| 289 |
+
|
| 290 |
+
### 3. Minimum-viable Python script
|
| 291 |
+
|
| 292 |
+
VeRL's actual entry point is a Hydra/YAML config + `verl.trainer.main_ppo`
|
| 293 |
+
CLI; the pythonic surface looks like this:
|
| 294 |
+
|
| 295 |
+
```python
|
| 296 |
+
# train_verl.py — minimum viable Recipe 2 sketch
|
| 297 |
+
from verl.trainer.ppo import core_algos
|
| 298 |
+
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
|
| 299 |
+
from composer_replication.loss import compose_loss
|
| 300 |
+
|
| 301 |
+
@core_algos.register_adv_est("grpo_composer")
|
| 302 |
+
def composer_advantage(data, **kwargs):
|
| 303 |
+
"""Custom adv-estimator that adds SDPO + DPO channels to GRPO.
|
| 304 |
+
|
| 305 |
+
Reads three extra DataProto keys (populated by the data prep step):
|
| 306 |
+
- data.batch["sdpo_teacher_logits"] (channel 2)
|
| 307 |
+
- data.non_tensor_batch["teacher_actions"] (channel 3)
|
| 308 |
+
and returns the standard (advantages, returns) tuple plus a stashed
|
| 309 |
+
composer-loss term consumed by the critic worker.
|
| 310 |
+
"""
|
| 311 |
+
advantages, returns = core_algos.compute_grpo_outcome_advantage(data, **kwargs)
|
| 312 |
+
composer_term = compose_loss(
|
| 313 |
+
model = kwargs["actor_module"],
|
| 314 |
+
inputs = data.batch,
|
| 315 |
+
alpha_sdpo = 0.1,
|
| 316 |
+
beta_replay = 0.05,
|
| 317 |
+
dpo_variant = "dpo",
|
| 318 |
+
sdpo_wrapper = "none",
|
| 319 |
+
)
|
| 320 |
+
data.meta_info["composer_loss"] = composer_term
|
| 321 |
+
return advantages, returns
|
| 322 |
+
|
| 323 |
+
# Then in your YAML:
|
| 324 |
+
# algorithm:
|
| 325 |
+
# adv_estimator: grpo_composer
|
| 326 |
+
# and run: python -m verl.trainer.main_ppo --config-name composer_grpo
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
The full driver wires `RayPPOTrainer` against your config; consult VeRL's
|
| 330 |
+
own quickstart for the Ray-cluster boilerplate. The composer-specific
|
| 331 |
+
piece is just the registered estimator above.
|
| 332 |
+
|
| 333 |
+
### 4. Decoupled DiLoCo wiring
|
| 334 |
+
|
| 335 |
+
VeRL's actor workers run in Ray; DiLoCo replicates the **whole VeRL job**.
|
| 336 |
+
Each "replica" is one Ray cluster running Recipe 2 end-to-end; the outer
|
| 337 |
+
loop is independent of Ray and just exchanges pseudo-gradients via the
|
| 338 |
+
object store between Ray-job invocations:
|
| 339 |
+
|
| 340 |
+
```python
|
| 341 |
+
from composer_replication.diloco.serverless import (
|
| 342 |
+
LocalProcessExecutor, ObjectStoreAllReduce,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
rendezvous = ObjectStoreAllReduce(
|
| 346 |
+
uri = "s3://verl-diloco/run/",
|
| 347 |
+
world_size = 4,
|
| 348 |
+
)
|
| 349 |
+
executor = LocalProcessExecutor() # Wave 14: ModalExecutor is a skeleton (raises NotImplementedError) — keep LocalProcessExecutor for now
|
| 350 |
+
handles = executor.launch_replicas(
|
| 351 |
+
n_replicas = 4,
|
| 352 |
+
entrypoint = "verl.trainer.main_ppo",
|
| 353 |
+
entrypoint_args = {
|
| 354 |
+
"+algorithm.adv_estimator": "grpo_composer",
|
| 355 |
+
"+algorithm.diloco.rendezvous": rendezvous.uri,
|
| 356 |
+
"+algorithm.diloco.sync_every_h": 500,
|
| 357 |
+
},
|
| 358 |
+
)
|
| 359 |
+
executor.collect(handles, timeout=24 * 3600)
|
| 360 |
+
```
|
| 361 |
+
|
| 362 |
+
The Ray cluster inside each replica handles intra-replica scaling
|
| 363 |
+
(FSDP / TP / vLLM); the object-store exchange handles cross-replica
|
| 364 |
+
sync. Bandwidth is identical to Recipe 1 (~2 GB / 30 min per replica
|
| 365 |
+
for a 7B-param model in bf16) and well within S3 free-tier.
|
| 366 |
+
|
| 367 |
+
### 5. Distillation-loss wiring
|
| 368 |
+
|
| 369 |
+
The custom `adv_estimator` from step 3 already calls `compose_loss`;
|
| 370 |
+
flip the kwargs there to switch DPO → SimPO or add TAID:
|
| 371 |
+
|
| 372 |
+
```python
|
| 373 |
+
composer_term = compose_loss(
|
| 374 |
+
model = kwargs["actor_module"],
|
| 375 |
+
inputs = data.batch,
|
| 376 |
+
alpha_sdpo = 0.1,
|
| 377 |
+
beta_replay = 0.05,
|
| 378 |
+
dpo_variant = "simpo", # ← SimPO swap
|
| 379 |
+
simpo_beta = 2.0,
|
| 380 |
+
simpo_gamma = 1.0,
|
| 381 |
+
sdpo_wrapper = "taid", # ← TAID wrap
|
| 382 |
+
taid_schedule_step = data.meta_info.get("global_step", 0),
|
| 383 |
+
taid_total_steps = 10_000,
|
| 384 |
+
)
|
| 385 |
+
```
|
| 386 |
+
|
| 387 |
+
VeRL's `data.meta_info` carries the global step automatically, which is
|
| 388 |
+
exactly what TAID's interpolation schedule needs. Channel 2 batches
|
| 389 |
+
without `student_init_logits` / `student_init_input_ids` are auto-skipped
|
| 390 |
+
(returns 0 for that step).
|
| 391 |
+
|
| 392 |
+
### 6. Cost ballpark
|
| 393 |
+
|
| 394 |
+
- **GPU**: 8× H100 (`p5.48xlarge` ~$98/hr on AWS, ~$25/hr on Lambda or
|
| 395 |
+
RunPod community) is the entry point for 70B-class. Expect 32–256
|
| 396 |
+
H100 for full 671B (matches DeepSeek's reported VeRL config).
|
| 397 |
+
- **API**: same ~$0.98/trace as Recipe 1 (channel 3 is a Python helper,
|
| 398 |
+
not a VeRL primitive — costs are framework-independent).
|
| 399 |
+
- **Ray cluster overhead**: head node + redis + dashboard adds ~1
|
| 400 |
+
CPU-instance ($0.10–0.50/hr) per cluster, negligible at GPU scale.
|
| 401 |
+
|
| 402 |
+
### 7. Known limitations as of Wave 14
|
| 403 |
+
|
| 404 |
+
- **`composer_replication.recipes.verl` is shape-only.** The decorator
|
| 405 |
+
registration and DataProto extension are documented but not yet shipped
|
| 406 |
+
as a runnable adapter — Wave 14 release exposes the *contract*, not the
|
| 407 |
+
glue. Expect this to land in a v0.2 follow-up spike.
|
| 408 |
+
- **Ray dependency.** Adds a heavyweight runtime; debugging
|
| 409 |
+
cross-actor crashes can be painful. Use VeRL's `--debug` mode early.
|
| 410 |
+
- **Custom-`adv_estimator` LOC**: writing your own takes ~50–150 LOC
|
| 411 |
+
including DataProto plumbing. Not a one-liner.
|
| 412 |
+
- **No first-class TAID hook in VeRL itself** — we route TAID through
|
| 413 |
+
the meta_info channel; this works but means you can't use VeRL's
|
| 414 |
+
built-in checkpoint-replay tooling without re-stamping `taid_schedule_step`
|
| 415 |
+
on each replay.
|
| 416 |
+
|
| 417 |
+
---
|
| 418 |
+
|
| 419 |
+
## Recipe 3 — PRIME-RL `CustomLossConfig`
|
| 420 |
+
|
| 421 |
+
### 1. When to use it
|
| 422 |
+
|
| 423 |
+
Pick PRIME-RL when:
|
| 424 |
+
|
| 425 |
+
- You're operating in the **PRIME-Intellect / decentralized training**
|
| 426 |
+
universe and want INTELLECT-style scaling on a long-horizon training
|
| 427 |
+
run.
|
| 428 |
+
- You need **DPPO importance-ratio masking** (the rationale most users
|
| 429 |
+
arrive with) — PRIME-RL's headline contribution is the
|
| 430 |
+
out-of-band-token *mask* (not clip) on `log_ratio = trainer_lp -
|
| 431 |
+
inference_lp`, with defaults `low=-4.0, high=4.0`.
|
| 432 |
+
- You want a **first-class custom-loss surface**: PRIME-RL ships
|
| 433 |
+
`CustomLossConfig` that takes an importable Python function and a
|
| 434 |
+
`LossInputs` struct exposing exactly the tensors we need
|
| 435 |
+
(`trainer_logprobs`, `inference_logprobs`, `teacher_logprobs`,
|
| 436 |
+
`advantages`, `loss_mask`). No fork, no Trainer subclass, no monkey-patch.
|
| 437 |
+
- You have access to multi-node infrastructure that PRIME-RL's
|
| 438 |
+
trainer/inference/orchestrator split is designed for.
|
| 439 |
+
|
| 440 |
+
Don't pick PRIME-RL if you need full vocab logits (channel 2 SDPO
|
| 441 |
+
requires logits not log-probs — see Limitations).
|
| 442 |
+
|
| 443 |
+
### 2. Install command
|
| 444 |
+
|
| 445 |
+
```bash
|
| 446 |
+
pip install -e ".[prime-rl,replaysim]"
|
| 447 |
+
# pulls prime-rl>=0.5
|
| 448 |
+
```
|
| 449 |
+
|
| 450 |
+
### 3. Minimum-viable Python script
|
| 451 |
+
|
| 452 |
+
PRIME-RL drives via YAML config; the only Python you write is the
|
| 453 |
+
custom-loss function (already shipped at
|
| 454 |
+
`composer_replication/recipes/prime_rl/composer_loss.py`). Wire it in:
|
| 455 |
+
|
| 456 |
+
```yaml
|
| 457 |
+
# prime_rl_config.yaml — point at the framework's adapter
|
| 458 |
+
loss:
|
| 459 |
+
custom:
|
| 460 |
+
import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
|
| 461 |
+
kwargs:
|
| 462 |
+
alpha_sdpo: 0.0 # channel 2 deferred in v0 (see below)
|
| 463 |
+
beta_dpo: 0.0 # channel 3 emits a warning if non-zero
|
| 464 |
+
dppo_mask_high: 4.0 # PRIME-RL DPPO mask bounds
|
| 465 |
+
dppo_mask_low: -4.0
|
| 466 |
+
epsilon: 1.0e-6
|
| 467 |
+
|
| 468 |
+
trainer:
|
| 469 |
+
model: Qwen/Qwen2.5-7B-Instruct
|
| 470 |
+
... # standard PRIME-RL fields
|
| 471 |
+
```
|
| 472 |
+
|
| 473 |
+
The shipped `loss_fn` signature is fixed by PRIME-RL's contract:
|
| 474 |
+
|
| 475 |
+
```python
|
| 476 |
+
def loss_fn(
|
| 477 |
+
inputs: LossInputs,
|
| 478 |
+
*,
|
| 479 |
+
alpha_sdpo: float = 0.0,
|
| 480 |
+
beta_dpo: float = 0.0,
|
| 481 |
+
dppo_mask_high: float = 4.0,
|
| 482 |
+
dppo_mask_low: float = -4.0,
|
| 483 |
+
epsilon: float = 1e-6,
|
| 484 |
+
) -> torch.Tensor:
|
| 485 |
+
log_ratio = inputs.trainer_logprobs - inputs.inference_logprobs
|
| 486 |
+
dppo_invalid = (log_ratio > dppo_mask_high) | (log_ratio < dppo_mask_low)
|
| 487 |
+
keep_mask = inputs.loss_mask & ~dppo_invalid
|
| 488 |
+
grpo = -(inputs.advantages * inputs.trainer_logprobs * keep_mask).sum() \
|
| 489 |
+
/ keep_mask.sum().clamp_min(epsilon)
|
| 490 |
+
if alpha_sdpo != 0.0:
|
| 491 |
+
raise NotImplementedError(
|
| 492 |
+
"Channel 2 SDPO requires full-vocab logits; PRIME-RL v0.5 "
|
| 493 |
+
"exposes only log-probs. Deferred to v0.2."
|
| 494 |
+
)
|
| 495 |
+
if beta_dpo != 0.0:
|
| 496 |
+
import warnings; warnings.warn(
|
| 497 |
+
"Channel 3 trace-replay DPO is out-of-scope for PRIME-RL recipe v0",
|
| 498 |
+
stacklevel=2,
|
| 499 |
+
)
|
| 500 |
+
return grpo
|
| 501 |
+
```
|
| 502 |
+
|
| 503 |
+
**Shape note** (caught in the Wave 13 cross-model review): PRIME-RL
|
| 504 |
+
calls the loss function **once per sample**; tensors are 1-D `(seq,)`,
|
| 505 |
+
*not* batched `(B, T)`. The 10 unit tests in
|
| 506 |
+
`composer_replication/recipes/prime_rl/tests/test_composer_loss.py`
|
| 507 |
+
cover this plus DPPO mask edges.
|
| 508 |
+
|
| 509 |
+
### 4. Decoupled DiLoCo wiring
|
| 510 |
+
|
| 511 |
+
PRIME-RL was designed for decentralized training and ships its own
|
| 512 |
+
weight-sync primitives. Stack DiLoCo on top via the
|
| 513 |
+
`ServerlessExecutor` Protocol — each replica runs an independent
|
| 514 |
+
PRIME-RL job pointing at the same `composer_loss:loss_fn`:
|
| 515 |
+
|
| 516 |
+
```python
|
| 517 |
+
from composer_replication.diloco.serverless import (
|
| 518 |
+
LocalProcessExecutor, ObjectStoreAllReduce,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
rendezvous = ObjectStoreAllReduce(
|
| 522 |
+
uri = "s3://prime-rl-diloco/run/",
|
| 523 |
+
world_size = 4,
|
| 524 |
+
)
|
| 525 |
+
# Wave 14: ModalExecutor is a skeleton (raises NotImplementedError until v0.x).
|
| 526 |
+
# Use LocalProcessExecutor for the inner-replica wiring; swap to the cloud
|
| 527 |
+
# executor once it lands. The DiLoCo + rendezvous code below is identical.
|
| 528 |
+
executor = LocalProcessExecutor()
|
| 529 |
+
handles = executor.launch_replicas(
|
| 530 |
+
n_replicas = 4,
|
| 531 |
+
entrypoint = "prime_rl.cli:main",
|
| 532 |
+
entrypoint_args = {
|
| 533 |
+
"config": "prime_rl_config.yaml",
|
| 534 |
+
"+diloco.rendezvous": rendezvous.uri,
|
| 535 |
+
"+diloco.sync_every_h": 500,
|
| 536 |
+
},
|
| 537 |
+
)
|
| 538 |
+
executor.collect(handles, timeout=24 * 3600)
|
| 539 |
+
```
|
| 540 |
+
|
| 541 |
+
Note PRIME-RL's own multi-node story (the trainer / inference /
|
| 542 |
+
orchestrator split) is **orthogonal** to Decoupled DiLoCo: PRIME-RL
|
| 543 |
+
multi-node = single replica scaled across many GPUs; DiLoCo = N
|
| 544 |
+
independent replicas synchronizing via object store. Combine both for
|
| 545 |
+
"big PRIME-RL job × N replicas".
|
| 546 |
+
|
| 547 |
+
### 5. Distillation-loss wiring
|
| 548 |
+
|
| 549 |
+
Channel 2 (SDPO + TAID + Entropy-OPD) is **deferred** in v0 because
|
| 550 |
+
PRIME-RL's `LossInputs` exposes log-probs not full vocab logits. The
|
| 551 |
+
SimPO swap on channel 3 is also gated by the same shape constraint, but
|
| 552 |
+
DPPO-clip itself doesn't change. To get TAID/SimPO into a PRIME-RL job
|
| 553 |
+
today you must:
|
| 554 |
+
|
| 555 |
+
1. Switch to Recipe 1 or 2 for the SFT/distill phase.
|
| 556 |
+
2. Use PRIME-RL only for the on-policy GRPO+DPPO phase.
|
| 557 |
+
|
| 558 |
+
The v0.2 plan (per ADR-007) is to extend `LossInputs` with a
|
| 559 |
+
`teacher_logits` field; the loss adapter is already shape-ready.
|
| 560 |
+
|
| 561 |
+
### 6. Cost ballpark
|
| 562 |
+
|
| 563 |
+
- **GPU**: similar profile to Recipe 2 — 8–32 H100 typical, scales to
|
| 564 |
+
hundreds for INTELLECT-class runs. Lambda Cloud or RunPod community
|
| 565 |
+
H100 community pricing (~$2–4/hr per H100) is most cost-effective.
|
| 566 |
+
- **API**: channel 3 is gated, so the only OpenRouter spend is from the
|
| 567 |
+
*offline data-prep* spike (using the verifier harness in Recipe 1 to
|
| 568 |
+
pre-bake DPO pairs), not from the training loop itself. Order of
|
| 569 |
+
magnitude: $50–500 for a curriculum-bake one-time, then $0/run.
|
| 570 |
+
- **Network**: PRIME-RL's own decentralized weight sync uses substantial
|
| 571 |
+
bandwidth between training replicas (one of its design constraints);
|
| 572 |
+
this is *separate* from the Decoupled DiLoCo bandwidth and shows up
|
| 573 |
+
as a ceiling on cross-region replica placement.
|
| 574 |
+
|
| 575 |
+
### 7. Known limitations as of Wave 14
|
| 576 |
+
|
| 577 |
+
- **Channel 2 deferred** — see step 5. `alpha_sdpo > 0` raises
|
| 578 |
+
`NotImplementedError`.
|
| 579 |
+
- **Channel 3 emits a warning** if `beta_dpo != 0`; trace-replay DPO
|
| 580 |
+
pairs must be folded into the *training data* (offline) rather than
|
| 581 |
+
the *loss* (online) until v0.2.
|
| 582 |
+
- **PRIME-RL ≥ 0.5 required.** Earlier versions don't ship
|
| 583 |
+
`CustomLossConfig`.
|
| 584 |
+
- **Smoke test deferred.** Per `prime_rl_recipe.md`, the runtime smoke
|
| 585 |
+
test requires a CUDA box + `prime-rl >= 0.5` install and is gated
|
| 586 |
+
to a follow-up spike. The 10 unit tests run cleanly without GPU.
|
| 587 |
+
- **DPPO defaults are PRIME-RL's, not ours.** We pin `low=-4.0,
|
| 588 |
+
high=4.0` to match. If you change them, you're now diverging from
|
| 589 |
+
PRIME-RL's example configs.
|
| 590 |
+
|
| 591 |
+
---
|
| 592 |
+
|
| 593 |
+
## Recipe 4 — Serverless Decoupled DiLoCo
|
| 594 |
+
|
| 595 |
+
### 1. When to use it
|
| 596 |
+
|
| 597 |
+
Pick Decoupled DiLoCo when:
|
| 598 |
+
|
| 599 |
+
- You have **N independent training replicas** that should sync
|
| 600 |
+
occasionally but can't (or shouldn't) cross-talk on every step.
|
| 601 |
+
- The cost or operational burden of an always-on multi-node cluster is
|
| 602 |
+
unacceptable, but you're happy paying for 4× independent **serverless
|
| 603 |
+
jobs**.
|
| 604 |
+
- Your inner trainer is one of Recipes 1–3 — DiLoCo wraps any inner
|
| 605 |
+
optimizer; it's *purely outer-loop*.
|
| 606 |
+
- You need **failure isolation**: if one replica crashes, the others
|
| 607 |
+
keep training; on restart it picks up from the last outer round.
|
| 608 |
+
|
| 609 |
+
DiLoCo's design rests on two abstractions (per ADR-005):
|
| 610 |
+
|
| 611 |
+
1. **`ServerlessExecutor` Protocol** — uniform interface for spinning up
|
| 612 |
+
N replicas across cloud backends (Modal / HF Jobs / SageMaker / k8s).
|
| 613 |
+
2. **`ObjectStoreAllReduce`** — fsspec-backed pseudo-gradient exchange
|
| 614 |
+
that replaces the in-process `torchft.Manager.allreduce` call.
|
| 615 |
+
|
| 616 |
+
The communication pattern is `S3 PutObject + N GetObjects` once per
|
| 617 |
+
inner-H steps, matching DiLoCo paper §3.2 (arXiv:2311.08105). For
|
| 618 |
+
1B-param bf16 that's ~2 GB / 30 min per replica — well within S3
|
| 619 |
+
free-tier.
|
| 620 |
+
|
| 621 |
+
### 2. Install command
|
| 622 |
+
|
| 623 |
+
```bash
|
| 624 |
+
pip install -e ".[diloco,serverless]"
|
| 625 |
+
# also one of the inner-trainer extras:
|
| 626 |
+
pip install -e ".[train]" # if the inner trainer is Recipe 1
|
| 627 |
+
# OR pip install verl # if the inner trainer is Recipe 2
|
| 628 |
+
# OR pip install -e ".[prime-rl]" # if the inner trainer is Recipe 3
|
| 629 |
+
```
|
| 630 |
+
|
| 631 |
+
### 3. Minimum-viable Python script
|
| 632 |
+
|
| 633 |
+
This pattern is independent of the inner trainer — pick any of Recipes
|
| 634 |
+
1/2/3 and wrap it with a `ServerlessExecutor`. The replica entrypoint
|
| 635 |
+
runs the inner trainer; the driver launches N of them and waits.
|
| 636 |
+
|
| 637 |
+
```python
|
| 638 |
+
# diloco_driver.py — driver that launches N replicas
|
| 639 |
+
from composer_replication.diloco.serverless import (
|
| 640 |
+
LocalProcessExecutor, # for dev — runs replicas as local subprocesses
|
| 641 |
+
ObjectStoreAllReduce,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
rendezvous = ObjectStoreAllReduce(
|
| 645 |
+
uri = "s3://my-bucket/diloco-runs/run42/", # or file:// for local
|
| 646 |
+
world_size = 4,
|
| 647 |
+
)
|
| 648 |
+
executor = LocalProcessExecutor() # Wave 14: ModalExecutor skeleton raises NotImplementedError; swap once cloud backend lands
|
| 649 |
+
handles = executor.launch_replicas(
|
| 650 |
+
n_replicas = 4,
|
| 651 |
+
entrypoint = "diloco_replica.py", # (script below)
|
| 652 |
+
entrypoint_args = {
|
| 653 |
+
"rendezvous": rendezvous.uri,
|
| 654 |
+
"rank_env": "REPLICA_RANK",
|
| 655 |
+
},
|
| 656 |
+
)
|
| 657 |
+
result = executor.collect(handles, timeout=3600)
|
| 658 |
+
print({h.replica_id: h.exit_code for h in result})
|
| 659 |
+
```
|
| 660 |
+
|
| 661 |
+
```python
|
| 662 |
+
# diloco_replica.py — runs inside each replica
|
| 663 |
+
import os
|
| 664 |
+
from composer_replication.diloco import make_diloco_outer_loop
|
| 665 |
+
from composer_replication.diloco.serverless import (
|
| 666 |
+
ObjectStoreAllReduce, MockManager,
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
# Build inner trainer (Recipe 1 example):
|
| 670 |
+
from train_trl import trainer
|
| 671 |
+
|
| 672 |
+
rendezvous = ObjectStoreAllReduce(
|
| 673 |
+
uri = os.environ["DILOCO_RENDEZVOUS"],
|
| 674 |
+
world_size = 4,
|
| 675 |
+
rank = int(os.environ["REPLICA_RANK"]),
|
| 676 |
+
)
|
| 677 |
+
manager = MockManager(allreduce=rendezvous)
|
| 678 |
+
outer = make_diloco_outer_loop(
|
| 679 |
+
inner_optimizer = trainer.optimizer,
|
| 680 |
+
manager = manager,
|
| 681 |
+
sync_every_h = 500,
|
| 682 |
+
)
|
| 683 |
+
trainer.add_callback(outer.callback())
|
| 684 |
+
trainer.train()
|
| 685 |
+
```
|
| 686 |
+
|
| 687 |
+
### 4. Decoupled DiLoCo wiring
|
| 688 |
+
|
| 689 |
+
This recipe **is** the DiLoCo wiring — see step 3. The available
|
| 690 |
+
executor adapters are:
|
| 691 |
+
|
| 692 |
+
| Executor | Status | Use case |
|
| 693 |
+
|---------------------------|-------------------------------|--------------------------------------|
|
| 694 |
+
| `LocalProcessExecutor` | Production-ready | Dev loop — N subprocesses on one box |
|
| 695 |
+
| `ModalExecutor` | Skeleton (modal-client gated) | Modal cloud, $/sec billing |
|
| 696 |
+
| `HFJobsExecutor` | Skeleton (hf-hub gated) | HuggingFace Jobs, transformer-shop |
|
| 697 |
+
| `SageMakerExecutor` | Roadmap (post-v0.2) | AWS, warm-pool ~10s cold start |
|
| 698 |
+
| `K8sExecutor` | Roadmap | KubeRay / Volcano gang scheduling |
|
| 699 |
+
|
| 700 |
+
Cross-cloud replica placement (e.g. 2× Modal + 2× HF Jobs) is supported
|
| 701 |
+
in principle — they all read/write the same S3 / GCS / HF rendezvous —
|
| 702 |
+
but treat as experimental.
|
| 703 |
+
|
| 704 |
+
### 5. Distillation-loss wiring
|
| 705 |
+
|
| 706 |
+
DiLoCo is loss-agnostic — it operates purely on inner-optimizer state.
|
| 707 |
+
Whichever inner trainer you're running (Recipe 1, 2, or 3) handles
|
| 708 |
+
distillation kwargs as documented in that recipe's step 5. The only
|
| 709 |
+
DiLoCo-specific knob worth knowing: TAID's `taid_schedule_step` is a
|
| 710 |
+
*global* counter, but each replica increments it independently. If you
|
| 711 |
+
care about replicas all reading the same α at outer-sync time, set
|
| 712 |
+
`taid_schedule_step = trainer.state.global_step + replica_offset` and
|
| 713 |
+
let the outer-loop sync average them out.
|
| 714 |
+
|
| 715 |
+
### 6. Cost ballpark
|
| 716 |
+
|
| 717 |
+
Pulled from
|
| 718 |
+
[`docs/research/DILOCO_SERVERLESS_RECONNAISSANCE.md`](research/DILOCO_SERVERLESS_RECONNAISSANCE.md):
|
| 719 |
+
|
| 720 |
+
| Backend | A100-80GB $/hr | H100 $/hr | Cold-start | Notes |
|
| 721 |
+
|---------------|----------------|-----------|------------|------------------------------------------|
|
| 722 |
+
| Modal | $1.39/sec → 4× ≈ $20/hr per A100 | ~$8/hr per H100 | 1–60s warm, 60–120s first-run | $/sec billing; no minimum |
|
| 723 |
+
| AWS SageMaker | $4.10/A100·hr | $12.29/hr | 2–5 min cold, ~10s warm pool | Min 60min on warm pool |
|
| 724 |
+
| GCP Vertex | $3.67/A100·hr | $11/hr | 2–6 min cold | 30–50% premium over raw GPU |
|
| 725 |
+
| Azure ML | ~$3.67/A100·hr | ~$12.25/hr | 3–8 min cold | Use curated env to cut cold-start |
|
| 726 |
+
| RunPod | $1.19/hr (community), $2.17 (secure) | $1.99/hr (community), $4.18 (secure) | seconds | No federation; same-DC only |
|
| 727 |
+
| HF Jobs | comparable to Modal | ~$8–12/hr | 30–90s | Best DX for HF-shop |
|
| 728 |
+
|
| 729 |
+
**Object-store cost.** ~$0.02/GB-month for S3 standard, ~$0/free-tier.
|
| 730 |
+
Pseudo-gradients are ~2 GB per replica per outer round; for a 24-hour
|
| 731 |
+
4-replica run at H=500 that's ~50 outer rounds × 2 GB × 4 replicas = ~400
|
| 732 |
+
GB written. Free-tier blows through fast — budget $10–20 in storage.
|
| 733 |
+
|
| 734 |
+
### 7. Known limitations as of Wave 14
|
| 735 |
+
|
| 736 |
+
- **`ModalExecutor` and `HFJobsExecutor` are skeletons.** They check
|
| 737 |
+
`import modal` / `import huggingface_hub` at *adapter init* time and
|
| 738 |
+
raise; the actual `launch_replicas` is shape-only until the relevant
|
| 739 |
+
spike lands. Use `LocalProcessExecutor` for dev.
|
| 740 |
+
- **`ObjectStoreAllReduce(world_size=1)`** must passthrough cleanly —
|
| 741 |
+
the unit test `test_object_store_allreduce_world_size_1_passthrough`
|
| 742 |
+
is the regression guard. Don't override unless you've read it.
|
| 743 |
+
- **Rank validation is mandatory.** Tests assert
|
| 744 |
+
`ObjectStoreAllReduce(rank=N, world_size=N)` raises (rank must be
|
| 745 |
+
`< world_size`); silent corruption otherwise.
|
| 746 |
+
- **`MockManager` is *not* feature-complete.** It implements the
|
| 747 |
+
`Manager.allreduce` surface that DiLoCo's outer-loop needs, but
|
| 748 |
+
not the full `torchft.Manager` API (no fault-tolerance, no
|
| 749 |
+
membership protocol). Don't use it as a drop-in for live torchft.
|
| 750 |
+
- **No native heterogeneous compute** — all replicas are assumed to
|
| 751 |
+
have the same compute shape. Mixed A100+H100 placements work but
|
| 752 |
+
the slow replica gates outer-loop progress.
|
| 753 |
+
|
| 754 |
+
---
|
| 755 |
+
|
| 756 |
+
## Recipe 5 — Monarch actor mesh
|
| 757 |
+
|
| 758 |
+
### 1. When to use it
|
| 759 |
+
|
| 760 |
+
Pick Monarch when:
|
| 761 |
+
|
| 762 |
+
- You're at **TorchForge-style topology scale**: trainer / generator /
|
| 763 |
+
rewarder / N-teachers all want to be independent, asynchronously
|
| 764 |
+
scheduled, fault-tolerant actors on a typed mesh.
|
| 765 |
+
- You want **heterogeneous executor support** — different actors run
|
| 766 |
+
in different clouds (e.g. `TrainerActor` on Modal A100s,
|
| 767 |
+
`GeneratorActor` on dedicated H100s, `TeacherPoolActor` as 0-GPU CPU
|
| 768 |
+
pods on k8s).
|
| 769 |
+
- You need **hot-swap of actor implementations** — replace
|
| 770 |
+
"OpenRouter teachers" with "local vLLM teachers" by changing one
|
| 771 |
+
Monarch binding, no trainer code change.
|
| 772 |
+
- You're prepared to track **upstream Monarch** (v0.4.1 stable, v0.5
|
| 773 |
+
dev daily); the API is moving and v0 of this recipe is intentionally
|
| 774 |
+
deferred per ADR-006.
|
| 775 |
+
|
| 776 |
+
Don't pick Monarch in Wave 14 unless you're explicitly scoping a
|
| 777 |
+
v0.2+ pilot. The framework ships *skeleton* actors that fail-fast on
|
| 778 |
+
instantiation; this is a reference-pattern reading exercise, not a
|
| 779 |
+
production target.
|
| 780 |
+
|
| 781 |
+
### 2. Install command
|
| 782 |
+
|
| 783 |
+
```bash
|
| 784 |
+
pip install -e ".[prime-rl,monarch]"
|
| 785 |
+
# pulls monarch>=0.4.1 plus the PRIME-RL trainer used inside actors
|
| 786 |
+
```
|
| 787 |
+
|
| 788 |
+
### 3. Minimum-viable Python script
|
| 789 |
+
|
| 790 |
+
The framework ships skeleton actor definitions at
|
| 791 |
+
`composer_replication/recipes/monarch/actors.py`; they raise
|
| 792 |
+
`NotImplementedError` on instantiation in Wave 14. The shape of the
|
| 793 |
+
final answer:
|
| 794 |
+
|
| 795 |
+
```python
|
| 796 |
+
# monarch_train.py — what v0.2+ usage will look like
|
| 797 |
+
from monarch import Actor, mesh, endpoint
|
| 798 |
+
from composer_replication.recipes.monarch.actors import (
|
| 799 |
+
TrainerActor, GeneratorActor, RewarderActor, TeacherPoolActor,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
# Topology
|
| 803 |
+
trainers = mesh.spawn(TrainerActor, n=4, gpu="A100")
|
| 804 |
+
generator = mesh.spawn(GeneratorActor, n=1, gpu="A100")
|
| 805 |
+
rewarder = mesh.spawn(RewarderActor, n=1, gpu=None)
|
| 806 |
+
teachers = mesh.spawn(TeacherPoolActor, n=1, gpu=None)
|
| 807 |
+
|
| 808 |
+
# Wire endpoints
|
| 809 |
+
async def outer_step(batch_id: int):
|
| 810 |
+
prompts = await trainers[0].sample_prompts.call(batch_id)
|
| 811 |
+
rollouts = await generator.rollout.call(prompts)
|
| 812 |
+
rewards = await rewarder.score.call(rollouts)
|
| 813 |
+
teacher_acts = await teachers.replay.call([
|
| 814 |
+
{"state": r["state"]} for r in rollouts
|
| 815 |
+
])
|
| 816 |
+
await trainers.train_outer_step.call(
|
| 817 |
+
batch_id, rollouts=rollouts, rewards=rewards,
|
| 818 |
+
teacher_actions=teacher_acts,
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
# Run
|
| 822 |
+
import asyncio
|
| 823 |
+
for batch_id in range(1000):
|
| 824 |
+
asyncio.run(outer_step(batch_id))
|
| 825 |
+
```
|
| 826 |
+
|
| 827 |
+
The Composer 3-channel loss lives inside `TrainerActor.train_outer_step`,
|
| 828 |
+
which calls `compose_loss(...)` exactly as Recipe 1 does. The
|
| 829 |
+
*orchestration* changes; the *loss math* doesn't.
|
| 830 |
+
|
| 831 |
+
### 4. Decoupled DiLoCo wiring
|
| 832 |
+
|
| 833 |
+
Monarch + Decoupled DiLoCo compose naturally: each `TrainerActor` is a
|
| 834 |
+
DiLoCo replica, and Monarch's supervision tree handles the failure
|
| 835 |
+
recovery that ADR-005 lists as a DiLoCo design constraint. The wire-up
|
| 836 |
+
is identical to Recipe 4's `LocalProcessExecutor` pattern, just running
|
| 837 |
+
inside Monarch instead of `subprocess`:
|
| 838 |
+
|
| 839 |
+
```python
|
| 840 |
+
from composer_replication.diloco.serverless import (
|
| 841 |
+
ObjectStoreAllReduce, MockManager,
|
| 842 |
+
)
|
| 843 |
+
|
| 844 |
+
class TrainerActor(Actor):
|
| 845 |
+
def __init__(self, rendezvous_uri: str, rank: int, world_size: int):
|
| 846 |
+
self.rendezvous = ObjectStoreAllReduce(
|
| 847 |
+
uri=rendezvous_uri, rank=rank, world_size=world_size,
|
| 848 |
+
)
|
| 849 |
+
self.manager = MockManager(allreduce=self.rendezvous)
|
| 850 |
+
# ... build inner ComposerReplicationTrainer ...
|
| 851 |
+
|
| 852 |
+
@endpoint
|
| 853 |
+
async def train_outer_step(self, batch_id: int, **kw):
|
| 854 |
+
# Inner H steps locally, then sync via self.rendezvous
|
| 855 |
+
...
|
| 856 |
+
```
|
| 857 |
+
|
| 858 |
+
The "object store" is the cross-actor synchronization point that
|
| 859 |
+
*doesn't* go through Monarch's RDMA data plane — by design, slow
|
| 860 |
+
syncs (S3) and fast syncs (RDMA for in-actor weight broadcast) live on
|
| 861 |
+
different planes.
|
| 862 |
+
|
| 863 |
+
### 5. Distillation-loss wiring
|
| 864 |
+
|
| 865 |
+
Monarch sees the loss as opaque: it lives inside `TrainerActor` and
|
| 866 |
+
takes the same `compose_loss` kwargs as Recipe 1. The mesh-level
|
| 867 |
+
benefit is **swap-by-binding**: you can replace `TeacherPoolActor`
|
| 868 |
+
("OpenRouter") with a `LocalVLLMTeacherActor` to switch the
|
| 869 |
+
*supplier* of teacher log-probs without touching the loss config.
|
| 870 |
+
|
| 871 |
+
```python
|
| 872 |
+
# Original binding — channel 3 via OpenRouter
|
| 873 |
+
teachers = mesh.spawn(TeacherPoolActor, n=1, gpu=None)
|
| 874 |
+
|
| 875 |
+
# Swap binding — channel 3 via local vLLM
|
| 876 |
+
teachers = mesh.spawn(LocalVLLMTeacherActor, n=1, gpu="A100",
|
| 877 |
+
model_id="Qwen/Qwen2.5-72B-Instruct")
|
| 878 |
+
|
| 879 |
+
# Trainer config unchanged:
|
| 880 |
+
trainer.compose_loss_kwargs = dict(
|
| 881 |
+
dpo_variant = "simpo", # same as before
|
| 882 |
+
sdpo_wrapper = "taid",
|
| 883 |
+
taid_schedule_step = batch_id,
|
| 884 |
+
taid_total_steps = 10_000,
|
| 885 |
+
)
|
| 886 |
+
```
|
| 887 |
+
|
| 888 |
+
### 6. Cost ballpark
|
| 889 |
+
|
| 890 |
+
In Wave 14: $0 (skeleton fails fast; no compute used). Projected for v0.2+:
|
| 891 |
+
|
| 892 |
+
- **Mesh overhead**: Monarch's coordination plane is light — typically
|
| 893 |
+
<1% of total compute even at 4-actor scale. The dominant cost is
|
| 894 |
+
whatever the actors run.
|
| 895 |
+
- **Heterogeneous placement** is the cost lever: e.g. a 4-trainer mesh
|
| 896 |
+
with `TeacherPoolActor` on 0-GPU CPU pods can cut total $/hr by
|
| 897 |
+
~10–20% vs forcing all actors onto GPU nodes.
|
| 898 |
+
- **Cluster bring-up**: Monarch v0.5's Slurm backend is stable; k8s
|
| 899 |
+
backend is dev-track; bare-metal SSH backend is documented.
|
| 900 |
+
|
| 901 |
+
### 7. Known limitations as of Wave 14
|
| 902 |
+
|
| 903 |
+
- **Skeleton only, fails fast.** Importing `actors.py` is fine;
|
| 904 |
+
instantiating `TrainerActor(...)` raises `NotImplementedError("v0
|
| 905 |
+
skeleton; deferred to v0.2 per ADR-006")`. By design.
|
| 906 |
+
- **Upstream Monarch API is moving.** v0.4.1 stable + v0.5 dev daily
|
| 907 |
+
means breaking changes are expected. Pin to a Monarch hash if you
|
| 908 |
+
prototype.
|
| 909 |
+
- **TorchForge is paused.** Per its own repo banner — don't take
|
| 910 |
+
TorchForge's recipes as production patterns. Monarch alone is
|
| 911 |
+
active; Forge as a layered framework is reference reading.
|
| 912 |
+
- **Open question (deferred):** does Monarch v0.5's Slurm backend
|
| 913 |
+
hand-shake cleanly with HF Jobs lifecycle? See
|
| 914 |
+
`monarch_actor_layout.md` for the open-questions list.
|
| 915 |
+
- **Open question (deferred):** can `TrainerActor` host
|
| 916 |
+
`ComposerReplicationTrainer` unmodified, or does it need a
|
| 917 |
+
`step_init` / `step_compute` split for Monarch's async actor model?
|
| 918 |
+
|
| 919 |
+
---
|
| 920 |
+
|
| 921 |
+
## Comparison matrix
|
| 922 |
+
|
| 923 |
+
| Dimension | Recipe 1 — TRL | Recipe 2 — VeRL | Recipe 3 — PRIME-RL | Recipe 4 — Serverless DiLoCo | Recipe 5 — Monarch |
|
| 924 |
+
|------------------------------------|-----------------------------|----------------------------------|-----------------------------------|------------------------------------|-------------------------------------|
|
| 925 |
+
| **Maturity (Wave 14)** | Production-ready | Production-ready (adapter shape-only) | Recipe ready, runtime smoke deferred | `LocalProcessExecutor` ready; cloud adapters skeleton | Skeleton only; v0.2+ scope |
|
| 926 |
+
| **Supports DAPO / GRPO** | GRPO ✅; DAPO via TRL master | GRPO ✅; DAPO ✅ (built-in) | GRPO+DPPO ✅ (DAPO mask is the headline) | Inherits from inner trainer | Inherits from inner trainer |
|
| 927 |
+
| **Custom-loss extension cost (LOC)** | ~30 LOC (subclass override) | ~50–150 LOC (registered estimator) | ~20 LOC (single Python fn) | 0 (transparent wrapper) | ~30 LOC (loss inside actor) |
|
| 928 |
+
| **OpenEnv-compatible** | ✅ (HF datasets layer) | ✅ (DataProto extension) | ✅ (rollout JSONL contract) | ✅ (orthogonal) | ✅ (RewarderActor binding) |
|
| 929 |
+
| **Native multi-node** | ❌ (single-host FSDP only) | ✅ (Ray cluster + 3D-HybridEngine) | ✅ (trainer/inference/orchestrator split) | ✅ (the *whole point*) | ✅ (mesh of actors) |
|
| 930 |
+
| **Native Decoupled DiLoCo** | ❌ — wrap with Recipe 4 | ❌ — wrap with Recipe 4 | ❌ — wrap with Recipe 4 | ✅ (this *is* it) | ✅ (compose with Recipe 4 inside actor) |
|
| 931 |
+
| **License** | Apache 2.0 (TRL) | Apache 2.0 (VeRL) | Apache 2.0 (PRIME-RL) | Apache 2.0 (this repo) | BSD-3 (Monarch) |
|
| 932 |
+
| **Our recommendation (Wave 14)** | **Default for ≤ 70B / single-host** | Pick at >70B *if* Ray is acceptable | Pick if PRIME-Intellect / DPPO mask is required | Stack on top of 1/2/3 for N replicas | Reference pattern only — revisit v0.2 |
|
| 933 |
+
|
| 934 |
+
---
|
| 935 |
+
|
| 936 |
+
## Cross-recipe checklist
|
| 937 |
+
|
| 938 |
+
Regardless of which recipe you pick, these invariants are tested across
|
| 939 |
+
the 124-test suite and should be true of your wired-up system:
|
| 940 |
+
|
| 941 |
+
- **`alpha_sdpo=0`** must reproduce the channel-1-only baseline
|
| 942 |
+
bit-exact (`test_compose_loss_integration.py`).
|
| 943 |
+
- **`beta_replay=0`** must reproduce the no-channel-3 baseline
|
| 944 |
+
bit-exact.
|
| 945 |
+
- **`sdpo_wrapper="taid"` without `taid_schedule_step`** must `ValueError`
|
| 946 |
+
at first step (`test_compose_loss_integration.py`).
|
| 947 |
+
- **`sdpo_wrapper="taid"` at `taid_schedule_step / taid_total_steps = 0`**
|
| 948 |
+
must ignore the teacher signal (`test_taid_loss_alpha_zero_ignores_teacher`).
|
| 949 |
+
- **`sdpo_wrapper="taid"` at `taid_schedule_step / taid_total_steps = 1`**
|
| 950 |
+
must equal plain SDPO (`test_taid_blended_logits_endpoints`).
|
| 951 |
+
- **`dpo_variant="simpo"`** must be differentiable through the
|
| 952 |
+
`loss-of-sigmoid` path (`test_simpo_loss_differentiable`).
|
| 953 |
+
- **`sdpo_wrapper="entropy_opd"`** must zero out when student ≡ teacher
|
| 954 |
+
(`test_entropy_aware_opd_zero_when_distributions_match`).
|
| 955 |
+
- **`ObjectStoreAllReduce(world_size=1)`** must passthrough cleanly
|
| 956 |
+
(`test_object_store_allreduce_world_size_1_passthrough`).
|
| 957 |
+
|
| 958 |
+
If any of these fail in your wired-up system, run the corresponding
|
| 959 |
+
unit test to localize: most break because a kwarg got dropped at the
|
| 960 |
+
adapter boundary, not because the loss math is wrong.
|
| 961 |
+
|
| 962 |
+
---
|
| 963 |
+
|
| 964 |
+
## Picking a recipe — decision flow
|
| 965 |
+
|
| 966 |
+
1. **Piloting Monarch (v0.2+)?** → Recipe 5.
|
| 967 |
+
2. **Else, need >70B / multi-host?** → Recipe 2 (VeRL) if Ray is OK,
|
| 968 |
+
Recipe 3 (PRIME-RL) if you're in the PRIME-Intellect / DPPO universe,
|
| 969 |
+
otherwise wait for Recipe 5.
|
| 970 |
+
3. **Else** → Recipe 1 (TRL) is the v0.0/v0.1 default.
|
| 971 |
+
4. **At any of 1–3, need N independent replicas / failure isolation?**
|
| 972 |
+
→ Stack Recipe 4 (Decoupled DiLoCo) on top.
|
| 973 |
+
|
| 974 |
+
---
|
| 975 |
+
|
| 976 |
+
## Pointers to source
|
| 977 |
+
|
| 978 |
+
- Loss core: [`composer_replication/loss.py`](../composer_replication/loss.py)
|
| 979 |
+
- TRL trainer: [`composer_replication/trainer/composer_trainer.py`](../composer_replication/trainer/composer_trainer.py)
|
| 980 |
+
- PRIME-RL adapter:
|
| 981 |
+
[`composer_replication/recipes/prime_rl/composer_loss.py`](../composer_replication/recipes/prime_rl/composer_loss.py),
|
| 982 |
+
recipe doc:
|
| 983 |
+
[`composer_replication/recipes/prime_rl/prime_rl_recipe.md`](../composer_replication/recipes/prime_rl/prime_rl_recipe.md)
|
| 984 |
+
- Monarch skeleton:
|
| 985 |
+
[`composer_replication/recipes/monarch/actors.py`](../composer_replication/recipes/monarch/actors.py),
|
| 986 |
+
layout doc:
|
| 987 |
+
[`composer_replication/recipes/monarch/monarch_actor_layout.md`](../composer_replication/recipes/monarch/monarch_actor_layout.md)
|
| 988 |
+
- Serverless DiLoCo:
|
| 989 |
+
[`composer_replication/diloco/serverless/`](../composer_replication/diloco/serverless/)
|
| 990 |
+
- VeRL adapter (shape-only): `composer_replication/recipes/verl/`
|
| 991 |
+
- ADRs:
|
| 992 |
+
[`docs/adrs/ADR-005-serverless-diloco.md`](adrs/ADR-005-serverless-diloco.md),
|
| 993 |
+
[`docs/adrs/ADR-006-rl-frameworks.md`](adrs/ADR-006-rl-frameworks.md),
|
| 994 |
+
[`docs/adrs/ADR-007-distillation-losses.md`](adrs/ADR-007-distillation-losses.md)
|
| 995 |
+
|
| 996 |
+
---
|
| 997 |
+
|
| 998 |
+
**File path:** `/mnt/e/CS/HF/composer-replication-framework/docs/INTEGRATION_RECIPES.md`
|
|
@@ -0,0 +1,823 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TROUBLESHOOTING — Wave 14
|
| 2 |
+
|
| 3 |
+
This document catalogs every Wave-14-known failure mode in the Composer
|
| 4 |
+
Replication Framework, along with how to diagnose, fix, and verify each
|
| 5 |
+
one. It is intentionally surgical: the surface area added in Waves 12–14
|
| 6 |
+
(SimPO/TAID/Entropy-OPD distillation kwargs, the PRIME-RL composer-loss
|
| 7 |
+
adapter, the serverless DiLoCo `MockManager` + `ObjectStoreAllReduce`
|
| 8 |
+
path, and the data-juicer-backed replaysim normalizer) introduced new
|
| 9 |
+
ways for users to trip themselves up. Each failure mode here is something
|
| 10 |
+
a maintainer has actually seen or anticipated during the cross-model
|
| 11 |
+
review of Wave 14.
|
| 12 |
+
|
| 13 |
+
If you hit something not covered below, jump to the
|
| 14 |
+
[How to file a bug report](#how-to-file-a-bug-report) section at the end —
|
| 15 |
+
the template there gives a maintainer everything they need to reproduce.
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## Common things to check first
|
| 20 |
+
|
| 21 |
+
Before reading any further, run through this checklist. ~80% of "framework
|
| 22 |
+
broken" reports turn out to be one of these:
|
| 23 |
+
|
| 24 |
+
1. **Python version.** The framework targets Python 3.10–3.12. The
|
| 25 |
+
`pyproject.toml` `target-version` is `py310`. If you are on 3.13+,
|
| 26 |
+
transitive deps (notably Ray, pulled in by data-juicer) may not yet
|
| 27 |
+
ship wheels and will try to build from source. Run `python --version`.
|
| 28 |
+
|
| 29 |
+
2. **Fresh virtual environment.** Mixing the framework into an existing
|
| 30 |
+
environment that already has `torch`, `transformers`, `trl`, or
|
| 31 |
+
`torchft` pinned to incompatible versions is the #1 source of import-
|
| 32 |
+
time errors. Create a new venv: `python -m venv .venv && source
|
| 33 |
+
.venv/bin/activate && pip install -e .[dev]`.
|
| 34 |
+
|
| 35 |
+
3. **Editable install.** Most contributors run `pip install -e .` so
|
| 36 |
+
that local edits to `composer_replication/` are picked up. If you
|
| 37 |
+
`pip install composer-replication` from a registry instead, your
|
| 38 |
+
edits to the source tree will be ignored. Confirm with
|
| 39 |
+
`pip show composer-replication | grep Location`.
|
| 40 |
+
|
| 41 |
+
4. **Optional extras.** Several modules are optional-dep gated:
|
| 42 |
+
- `[replay]` — adds `pyyaml`, the OpenAI/Anthropic/Together SDKs.
|
| 43 |
+
- `[replaysim]` — adds `data-juicer` (and via it, Ray as a transitive).
|
| 44 |
+
- `[serverless]` — adds `fsspec`. For non-local rendezvous URIs you
|
| 45 |
+
also need a backend-specific fsspec adapter (see Failure Mode 5).
|
| 46 |
+
- `[dev]` — adds `pytest`, `ruff`, etc.
|
| 47 |
+
If you see `ModuleNotFoundError: No module named 'data_juicer'`, you
|
| 48 |
+
forgot the extra. Install with `pip install -e .[replaysim]`.
|
| 49 |
+
|
| 50 |
+
5. **Run the test suite first.** Before debugging anything, run the
|
| 51 |
+
subset of tests touching the area you care about:
|
| 52 |
+
```
|
| 53 |
+
pytest composer_replication/tests/ # core compose_loss
|
| 54 |
+
pytest composer_replication/distillation/tests/ # SimPO / TAID / OPD
|
| 55 |
+
pytest composer_replication/recipes/prime_rl/tests/ # PRIME-RL adapter
|
| 56 |
+
pytest composer_replication/diloco/serverless/tests/ # MockManager + DiLoCo
|
| 57 |
+
pytest composer_replication/replaysim/tests/ # data-juicer normalizer
|
| 58 |
+
```
|
| 59 |
+
If any green test fails for you locally, the problem is environmental
|
| 60 |
+
— fix that before digging into your own code.
|
| 61 |
+
|
| 62 |
+
6. **Read the docstring of the symbol you're calling.** Wave 14
|
| 63 |
+
docstrings are written to be the first line of documentation. The
|
| 64 |
+
`compose_loss` docstring (`composer_replication/loss.py`) lists every
|
| 65 |
+
required and optional input key. The `MockManager` docstring
|
| 66 |
+
enumerates the torchft surface methods it implements.
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## Failure modes
|
| 71 |
+
|
| 72 |
+
### 1. `pip install -e .[replaysim]` hangs or fails on Python 3.12 with a Ray-related path error
|
| 73 |
+
|
| 74 |
+
**SYMPTOM.** Installing the `[replaysim]` extra (which pulls
|
| 75 |
+
`data-juicer`) triggers a transitive install of Ray. On Python 3.12, the
|
| 76 |
+
first `import ray` (often during `pip` build hooks or the first time
|
| 77 |
+
data-juicer is loaded) fails with messages mentioning
|
| 78 |
+
`/tmp/ray/session_*` paths, missing `pyarrow` symbols, or `OSError:
|
| 79 |
+
[Errno 2] No such file or directory: '/dev/shm/ray-...'` inside Docker.
|
| 80 |
+
|
| 81 |
+
**DIAGNOSIS.** `data-juicer` declares `ray` as a transitive dependency.
|
| 82 |
+
On Python 3.12 the wheel matrix is incomplete for some Ray versions, and
|
| 83 |
+
Ray's first-import probes `/dev/shm` and `/tmp/ray` for its session
|
| 84 |
+
state. In a sandboxed container, restricted CI runner, or WSL
|
| 85 |
+
environment with a non-default `/tmp`, those probes fail. Wave 14
|
| 86 |
+
subagent T2 hit this in CI and worked around it by pinning Ray and by
|
| 87 |
+
making sure `/tmp` exists and is writable.
|
| 88 |
+
|
| 89 |
+
**FIX.**
|
| 90 |
+
- Prefer Python 3.11 if you're on 3.12+ and don't need 3.12 features.
|
| 91 |
+
- If you must stay on 3.12, ensure `/tmp` is writable and pre-create the
|
| 92 |
+
session directory: `mkdir -p /tmp/ray && chmod 1777 /tmp/ray`.
|
| 93 |
+
- In Docker, mount a real tmpfs at `/dev/shm`:
|
| 94 |
+
`docker run --shm-size=2g …`.
|
| 95 |
+
- If you don't need replaysim normalization, you can skip the extra
|
| 96 |
+
entirely. The `DJNormalizer(skip_dj=True)` passthrough (see
|
| 97 |
+
`composer_replication/replaysim/normalize.py:165`) does not import
|
| 98 |
+
`data_juicer` and therefore does not import Ray.
|
| 99 |
+
|
| 100 |
+
**VERIFICATION.** The skip-dj passthrough is exercised by
|
| 101 |
+
`test_dj_normalizer_skip_dj_passthrough` and
|
| 102 |
+
`test_dj_normalizer_skip_dj_preserves_count` in
|
| 103 |
+
`composer_replication/replaysim/tests/test_replaysim.py`. Both run
|
| 104 |
+
without `data_juicer` installed:
|
| 105 |
+
|
| 106 |
+
```
|
| 107 |
+
pytest composer_replication/replaysim/tests/test_replaysim.py::test_dj_normalizer_skip_dj_passthrough -xvs
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
If that passes in your environment, your `[replaysim]`-less install is
|
| 111 |
+
healthy — only the full data-juicer code path requires Ray.
|
| 112 |
+
|
| 113 |
+
---
|
| 114 |
+
|
| 115 |
+
### 2. `compose_loss` produces wrong-looking numbers when combining new kwargs
|
| 116 |
+
|
| 117 |
+
**SYMPTOM.** You pass several Wave-14 distillation kwargs to
|
| 118 |
+
`compose_loss` (e.g. `dpo_variant="simpo"`, `sdpo_wrapper="taid"`,
|
| 119 |
+
`taid_schedule_step=0`, `simpo_beta=2.0`, `entropy_opd_h_max=…`), and
|
| 120 |
+
the loss curve looks wrong: NaNs, identically-zero `sdpo_jsd` channel,
|
| 121 |
+
or a `total` that is bit-different from your reference run with no
|
| 122 |
+
distillation kwargs at all.
|
| 123 |
+
|
| 124 |
+
**DIAGNOSIS.** `compose_loss` now has 13 keyword arguments and the
|
| 125 |
+
contract between them is non-trivial. Subagent T1's review identified
|
| 126 |
+
three combinations that look reasonable but are unsupported:
|
| 127 |
+
- Passing `taid_schedule_step` without `taid_total_steps` (or vice
|
| 128 |
+
versa). The function raises `ValueError` clearly, but the message can
|
| 129 |
+
scroll past in noisy logs.
|
| 130 |
+
- Passing `dpo_variant="simpo"` while still supplying
|
| 131 |
+
`dpo_chosen_ref_logprobs`. Those keys are **silently ignored** —
|
| 132 |
+
SimPO is reference-free.
|
| 133 |
+
- Passing `sdpo_wrapper="taid"` without supplying either
|
| 134 |
+
`student_init_logits` OR `student_init_input_ids` in `inputs`. The
|
| 135 |
+
function will fall back to a forward pass through the (possibly
|
| 136 |
+
drifted) live model, which is a footgun late in training (see Failure
|
| 137 |
+
Mode 8).
|
| 138 |
+
|
| 139 |
+
**FIX.** Read the docstring at the top of
|
| 140 |
+
`composer_replication/loss.py` (lines 25–39 list the three pluggable
|
| 141 |
+
losses and their preconditions). The general rule:
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
from composer_replication import compose_loss
|
| 145 |
+
|
| 146 |
+
# Defaults (no distillation knobs) reproduce legacy 3-channel composition bit-exact.
|
| 147 |
+
out = compose_loss(model, inputs)
|
| 148 |
+
|
| 149 |
+
# To opt into SimPO, pass dpo_variant ONLY. Do not pass ref-logprob keys.
|
| 150 |
+
out = compose_loss(model, inputs, dpo_variant="simpo",
|
| 151 |
+
simpo_beta=2.0, simpo_gamma=1.0)
|
| 152 |
+
|
| 153 |
+
# To opt into TAID, pass BOTH schedule_step AND total_steps, AND make sure
|
| 154 |
+
# inputs["student_init_logits"] is populated (see Failure Mode 8).
|
| 155 |
+
out = compose_loss(model, inputs, sdpo_wrapper="taid",
|
| 156 |
+
taid_schedule_step=step, taid_total_steps=total_steps)
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
Setting all 13 kwargs to their defaults is **bit-exact equivalent** to
|
| 160 |
+
the pre-Wave-13 3-channel loss; if your defaults call gives different
|
| 161 |
+
numbers than your old code, file a bug.
|
| 162 |
+
|
| 163 |
+
**VERIFICATION.** The bit-exact equivalence and every supported
|
| 164 |
+
combination is locked in by the 11 integration tests in
|
| 165 |
+
`composer_replication/tests/test_compose_loss_integration.py`. The most
|
| 166 |
+
important ones:
|
| 167 |
+
- `test_defaults_bit_exact_with_legacy_kwargs` — passing the new kwargs
|
| 168 |
+
at their defaults is identical to legacy.
|
| 169 |
+
- `test_simpo_does_not_require_ref_logprobs` — SimPO works with the
|
| 170 |
+
ref-logprob keys absent from `inputs`.
|
| 171 |
+
- `test_taid_alpha_one_recovers_sdpo` — TAID with `alpha_min=alpha_max=1`
|
| 172 |
+
reproduces standard SDPO.
|
| 173 |
+
- `test_taid_requires_schedule_step` / `test_taid_requires_total_steps` —
|
| 174 |
+
the partial-config error path.
|
| 175 |
+
|
| 176 |
+
```
|
| 177 |
+
pytest composer_replication/tests/test_compose_loss_integration.py -xvs
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
---
|
| 181 |
+
|
| 182 |
+
### 3. `MockManager` works today but silently breaks after a torchft upgrade
|
| 183 |
+
|
| 184 |
+
**SYMPTOM.** Your serverless DiLoCo run starts, the first outer round
|
| 185 |
+
completes, and then `torchft.DiLoCo` raises an `AttributeError` on
|
| 186 |
+
something like `_use_async_quorum`, `should_commit`, or
|
| 187 |
+
`current_step` — or worse, it silently uses the wrong sync semantics.
|
| 188 |
+
|
| 189 |
+
**DIAGNOSIS.** `MockManager` is a duck-typed shim that mirrors
|
| 190 |
+
`torchft.Manager` rather than subclassing it. The surface it implements
|
| 191 |
+
is enumerated in the docstring at
|
| 192 |
+
`composer_replication/diloco/serverless/allreduce.py:215`:
|
| 193 |
+
|
| 194 |
+
> Methods/attributes DiLoCo touches: `allreduce`, `should_commit`,
|
| 195 |
+
> `start_quorum`, `current_step`, `disallow_state_dict_read`,
|
| 196 |
+
> `allow_state_dict_read`, `register_state_dict_fn`, `_use_async_quorum`
|
| 197 |
+
> (attribute), `num_participants`, `rank`.
|
| 198 |
+
|
| 199 |
+
The two **private** members in that list — `_use_async_quorum` and the
|
| 200 |
+
internal `current_step` counter — are private torchft API that may be
|
| 201 |
+
renamed without notice in any torchft minor release. Wave 14 subagent
|
| 202 |
+
T3 specifically called this out: "If torchft renames `_use_async_quorum`
|
| 203 |
+
to anything else, MockManager silently breaks because there is nothing
|
| 204 |
+
holding the contract beyond a string."
|
| 205 |
+
|
| 206 |
+
**FIX.**
|
| 207 |
+
- **Pin torchft.** In `pyproject.toml` keep your torchft version pinned
|
| 208 |
+
to a known-good range (e.g. `torchft>=0.2,<0.4`). When you need to
|
| 209 |
+
upgrade, do so deliberately and re-run the integration tests below
|
| 210 |
+
before merging.
|
| 211 |
+
- **Watch the deprecation warning.** Wave 14 sets up a clear path to
|
| 212 |
+
warn if `_use_async_quorum` is read on a fresh instance — see the
|
| 213 |
+
comment at `allreduce.py:255`.
|
| 214 |
+
- **Don't pass an arbitrary torchft branch.** If you've patched torchft
|
| 215 |
+
locally, the `MockManager` may need updating in lockstep. The
|
| 216 |
+
surface-compatibility tests below will catch this in CI.
|
| 217 |
+
|
| 218 |
+
**VERIFICATION.** The full DiLoCo × MockManager surface is exercised by:
|
| 219 |
+
- `test_mock_manager_shape_compat` in
|
| 220 |
+
`composer_replication/diloco/serverless/tests/test_serverless_local.py`
|
| 221 |
+
— sanity check that all expected methods/attributes exist.
|
| 222 |
+
- `test_mockmanager_has_full_diloco_call_surface` in
|
| 223 |
+
`composer_replication/diloco/serverless/tests/test_serverless_diloco_integration.py`
|
| 224 |
+
— runs an end-to-end outer round through real torchft `DiLoCo`,
|
| 225 |
+
hitting every method on the surface list above.
|
| 226 |
+
- `test_mockmanager_diloco_outer_round_completes` — full one-round
|
| 227 |
+
smoke ending in a successful outer SGD step.
|
| 228 |
+
|
| 229 |
+
If any of these tests turn red after a torchft bump, **do not ship**:
|
| 230 |
+
inspect the new torchft Manager surface and update `MockManager`
|
| 231 |
+
to match.
|
| 232 |
+
|
| 233 |
+
```
|
| 234 |
+
pytest composer_replication/diloco/serverless/tests/test_serverless_diloco_integration.py -xvs
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
---
|
| 238 |
+
|
| 239 |
+
### 4. SimPO loss curve looks like noise
|
| 240 |
+
|
| 241 |
+
**SYMPTOM.** You wired in `dpo_variant="simpo"`, the run starts, and
|
| 242 |
+
the `trace_replay_dpo` channel either drifts to large negative values
|
| 243 |
+
(→ `total` blows up) or oscillates with much higher variance than
|
| 244 |
+
standard DPO. The loss curve "looks like noise."
|
| 245 |
+
|
| 246 |
+
**DIAGNOSIS.** SimPO uses **average per-token log-probability**
|
| 247 |
+
(`Σ logπ(c_t) / |c|`), not sum log-prob. From the SimPO docstring
|
| 248 |
+
(`composer_replication/distillation/simpo.py:11–18`):
|
| 249 |
+
|
| 250 |
+
> SimPO drops the reference-policy term, replaces it with a target
|
| 251 |
+
> margin γ, and uses **average sequence log-probability instead of
|
| 252 |
+
> sum**. […] L_SimPO = -log σ( β · [avg_logπ(c) - avg_logπ(r)] - γ )
|
| 253 |
+
|
| 254 |
+
If you compute `chosen_logprobs.sum()` (or any unmasked aggregation) and
|
| 255 |
+
hand it to SimPO as `chosen_avg_logprobs`, the loss is undefined: β=2.0
|
| 256 |
+
times a sum-log-prob is on a totally different scale than β=2.0 times an
|
| 257 |
+
average. The result looks plausible per-batch but the optimum is
|
| 258 |
+
nowhere near the dataset's true preference signal.
|
| 259 |
+
|
| 260 |
+
**FIX.** Use the helper
|
| 261 |
+
`composer_replication.distillation.simpo.avg_sequence_logprob`:
|
| 262 |
+
|
| 263 |
+
```python
|
| 264 |
+
from composer_replication.distillation.simpo import (
|
| 265 |
+
simpo_loss, avg_sequence_logprob,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
chosen_avg = avg_sequence_logprob(chosen_logprobs, chosen_response_mask)
|
| 269 |
+
rejected_avg = avg_sequence_logprob(rejected_logprobs, rejected_response_mask)
|
| 270 |
+
loss = simpo_loss(chosen_avg, rejected_avg, beta=2.0, gamma=1.0)
|
| 271 |
+
```
|
| 272 |
+
|
| 273 |
+
The mask is **1 on response tokens, 0 on prompt+padding** — same
|
| 274 |
+
convention as the rest of the framework. If you must roll your own
|
| 275 |
+
aggregation, divide by `response_mask.sum(dim=-1).clamp_min(1.0)`,
|
| 276 |
+
not by `response_mask.shape[-1]`.
|
| 277 |
+
|
| 278 |
+
**VERIFICATION.** The avg-vs-sum semantics are pinned by
|
| 279 |
+
`test_avg_sequence_logprob` in
|
| 280 |
+
`composer_replication/distillation/tests/test_distillation_losses.py`,
|
| 281 |
+
which constructs known per-token log-probs and asserts the helper
|
| 282 |
+
returns the correct per-sequence average. The end-to-end SimPO
|
| 283 |
+
loss-shape check is `test_simpo_loss_returns_scalar` in the same file.
|
| 284 |
+
|
| 285 |
+
```
|
| 286 |
+
pytest composer_replication/distillation/tests/test_distillation_losses.py::test_avg_sequence_logprob -xvs
|
| 287 |
+
pytest composer_replication/distillation/tests/test_distillation_losses.py::test_simpo_loss_lower_for_better_separation -xvs
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
---
|
| 291 |
+
|
| 292 |
+
### 5. `ObjectStoreAllReduce` works locally but fails on `s3://` at first allreduce
|
| 293 |
+
|
| 294 |
+
**SYMPTOM.** You construct
|
| 295 |
+
`ObjectStoreAllReduce(uri="s3://my-bucket/run42/", rank=0,
|
| 296 |
+
world_size=4)`. The constructor succeeds. The first call to
|
| 297 |
+
`allreduce(tensor, name="...")` raises `ImportError: Install s3fs to
|
| 298 |
+
access S3` or `botocore.exceptions.NoCredentialsError: Unable to locate
|
| 299 |
+
credentials`.
|
| 300 |
+
|
| 301 |
+
**DIAGNOSIS.** `ObjectStoreAllReduce` uses fsspec to reach the
|
| 302 |
+
backend, but **fsspec only ships protocol stubs, not adapters**. The
|
| 303 |
+
constructor doesn't know which protocol you'll use and doesn't
|
| 304 |
+
eagerly validate, so it accepts any URI. The `s3://` adapter requires:
|
| 305 |
+
1. The `s3fs` package (`pip install s3fs`), which is **not** in the
|
| 306 |
+
default `[serverless]` extra.
|
| 307 |
+
2. Working AWS credentials (env vars, `~/.aws/credentials`, IAM role,
|
| 308 |
+
or whatever your environment normally provides to boto3).
|
| 309 |
+
|
| 310 |
+
The same is true for `gs://` (`gcsfs`), `az://` (`adlfs`), and
|
| 311 |
+
`hf://` (`huggingface_hub`'s fsspec integration, which is included if
|
| 312 |
+
you have `huggingface_hub` installed).
|
| 313 |
+
|
| 314 |
+
**FIX.**
|
| 315 |
+
- Install the right adapter alongside the framework:
|
| 316 |
+
```
|
| 317 |
+
pip install s3fs # for s3://
|
| 318 |
+
pip install gcsfs # for gs://
|
| 319 |
+
pip install adlfs # for az://
|
| 320 |
+
```
|
| 321 |
+
- Verify credentials work outside the framework first:
|
| 322 |
+
```
|
| 323 |
+
python -c "import s3fs; print(s3fs.S3FileSystem().ls('my-bucket'))"
|
| 324 |
+
```
|
| 325 |
+
- If you're running on Modal/HF Jobs, set the credentials as Modal
|
| 326 |
+
secrets / HF Jobs env vars in the executor config — not in your
|
| 327 |
+
local shell.
|
| 328 |
+
|
| 329 |
+
The constructor could in principle perform an eager probe (e.g. a
|
| 330 |
+
`HEAD` on the rendezvous prefix) to fail fast at init time. Wave 14
|
| 331 |
+
deliberately did not add this because it adds a network round-trip on
|
| 332 |
+
every replica startup. If you want pre-flight validation in your
|
| 333 |
+
training script, call `fsspec.filesystem(protocol).ls(uri)` yourself
|
| 334 |
+
before constructing the manager.
|
| 335 |
+
|
| 336 |
+
**VERIFICATION.** The `file://` and bare-path code paths — the only
|
| 337 |
+
ones that don't need an extra adapter — are exercised by:
|
| 338 |
+
- `test_object_store_allreduce_local_paths_create_dir`
|
| 339 |
+
- `test_object_store_allreduce_world_size_1_passthrough`
|
| 340 |
+
- `test_object_store_allreduce_round_id_increments`
|
| 341 |
+
|
| 342 |
+
…all in
|
| 343 |
+
`composer_replication/diloco/serverless/tests/test_serverless_local.py`.
|
| 344 |
+
If those pass and your `s3://` URI fails, the framework is fine and
|
| 345 |
+
your fsspec adapter or credentials are the problem.
|
| 346 |
+
|
| 347 |
+
```
|
| 348 |
+
pytest composer_replication/diloco/serverless/tests/test_serverless_local.py -xvs
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
---
|
| 352 |
+
|
| 353 |
+
### 6. Custom replaysim recipe drops every record (or crashes data-juicer)
|
| 354 |
+
|
| 355 |
+
**SYMPTOM.** You wrote a custom replaysim YAML recipe modeled on
|
| 356 |
+
`composer_replication/recipes/replaysim/default.yaml`. It loads
|
| 357 |
+
without error, but every input DPO pair is dropped, OR data-juicer
|
| 358 |
+
raises `KeyError: 'text_key'`, OR it raises a complaint about
|
| 359 |
+
"expected str, got list" inside one of the filters.
|
| 360 |
+
|
| 361 |
+
**DIAGNOSIS.** Wave 14 fixed two related bugs in the *default* recipe
|
| 362 |
+
that custom-recipe authors will hit again. Both are documented in the
|
| 363 |
+
header comment at
|
| 364 |
+
`composer_replication/recipes/replaysim/default.yaml:21–35`:
|
| 365 |
+
|
| 366 |
+
1. **`text_keys` plural vs `text_key` singular.** The top-level
|
| 367 |
+
dataset contract uses `text_keys: chosen` (plural). Each individual
|
| 368 |
+
op uses `text_key: chosen` (singular). They are not interchangeable.
|
| 369 |
+
data-juicer's dataset loader validates that the `text_keys` field
|
| 370 |
+
exists on every record before any op runs; an op that uses
|
| 371 |
+
`text_keys` instead of `text_key` is silently misconfigured.
|
| 372 |
+
|
| 373 |
+
2. **`chosen` / `rejected` as strings vs as list-of-dicts.**
|
| 374 |
+
data-juicer ops like `text_length_filter`, `words_num_filter`,
|
| 375 |
+
`special_characters_filter`, and `document_deduplicator` read a
|
| 376 |
+
single string field. Pointing them at the chat-messages list
|
| 377 |
+
(`chosen_messages`, `rejected_messages`) crashes or silently
|
| 378 |
+
no-ops. The framework's `_dpo_pair_to_dj_record` keeps **both**
|
| 379 |
+
shapes side-by-side: `chosen`/`rejected` (strings) for filter ops,
|
| 380 |
+
and `chosen_messages`/`rejected_messages` (chat-messages list) for
|
| 381 |
+
chat-aware ops + the `NormalizedDPOPair` round-trip.
|
| 382 |
+
|
| 383 |
+
**FIX.** Treat the default recipe as your starting template. Concretely:
|
| 384 |
+
- Always declare `text_keys: chosen` at the top.
|
| 385 |
+
- For every length/word/special-char op you add, duplicate it: once
|
| 386 |
+
with `text_key: chosen`, once with `text_key: rejected`. (Each op
|
| 387 |
+
takes only one `text_key` — see comment at lines 31–35 of
|
| 388 |
+
`default.yaml`.)
|
| 389 |
+
- Never point a filter op at `chosen_messages` or `rejected_messages`.
|
| 390 |
+
Those are list-of-dicts; only chat-aware ops accept that shape.
|
| 391 |
+
|
| 392 |
+
**VERIFICATION.** The two-shape contract is locked in by:
|
| 393 |
+
- `test_record_chosen_rejected_are_flat_strings_for_dj_text_ops` —
|
| 394 |
+
asserts `chosen` and `rejected` are bare strings on every record
|
| 395 |
+
produced by `_dpo_pair_to_dj_record`.
|
| 396 |
+
- `test_record_chosen_rejected_messages_carry_chat_shape` — asserts
|
| 397 |
+
`chosen_messages` / `rejected_messages` exist as list-of-dicts.
|
| 398 |
+
- `test_dj_normalizer_e2e_default_recipe(tmp_path)` — runs the actual
|
| 399 |
+
default recipe through real data-juicer end-to-end (skipped if
|
| 400 |
+
`data_juicer` isn't importable).
|
| 401 |
+
|
| 402 |
+
…all in
|
| 403 |
+
`composer_replication/replaysim/tests/test_replaysim.py`. If those
|
| 404 |
+
pass and your custom recipe still drops everything, diff your YAML
|
| 405 |
+
against `default.yaml` until the two shapes align.
|
| 406 |
+
|
| 407 |
+
```
|
| 408 |
+
pytest composer_replication/replaysim/tests/test_replaysim.py -xvs
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
---
|
| 412 |
+
|
| 413 |
+
### 7. `ValueError: expected (seq,) shape, got (B, T)` from PRIME-RL composer_loss
|
| 414 |
+
|
| 415 |
+
**SYMPTOM.** You wired the PRIME-RL recipe into a training loop you
|
| 416 |
+
adapted from another framework (TRL, openrlhf, etc.), and on the very
|
| 417 |
+
first `loss_fn` call you get a `ValueError` mentioning shape
|
| 418 |
+
`(seq,)` versus `(B, T)`.
|
| 419 |
+
|
| 420 |
+
**DIAGNOSIS.** PRIME-RL calls its loss function **one sample at a
|
| 421 |
+
time**, with 1-D `(seq,)` tensors — not batched `(B, T)` tensors. The
|
| 422 |
+
recipe's docstring spells this out at
|
| 423 |
+
`composer_replication/recipes/prime_rl/composer_loss.py:16–30`:
|
| 424 |
+
|
| 425 |
+
> Note the **per-sample (seq,) shape** — PRIME-RL's runner calls the
|
| 426 |
+
> loss function one sample at a time, not on a batched (B, T) tensor.
|
| 427 |
+
|
| 428 |
+
Wave 14 fixed an earlier draft of the recipe that incorrectly assumed
|
| 429 |
+
`(B, T)`. The new version raises a clear `ValueError` if you hand it
|
| 430 |
+
the wrong shape, instead of silently broadcasting and producing
|
| 431 |
+
nonsense gradients. Users who are used to TRL or openrlhf — both of
|
| 432 |
+
which call the loss with batched tensors — see this on day one.
|
| 433 |
+
|
| 434 |
+
**FIX.**
|
| 435 |
+
- If you are running inside PRIME-RL via its `CustomLossConfig`, you
|
| 436 |
+
don't need to do anything: PRIME-RL's runner produces `(seq,)`
|
| 437 |
+
tensors and the recipe accepts them.
|
| 438 |
+
- If you are calling the recipe directly from your own runner, slice
|
| 439 |
+
your batch into per-sample 1-D tensors before each call:
|
| 440 |
+
```python
|
| 441 |
+
for b in range(B):
|
| 442 |
+
inputs_b = LossInputs(
|
| 443 |
+
trainer_logprobs=batched.trainer_logprobs[b],
|
| 444 |
+
inference_logprobs=batched.inference_logprobs[b],
|
| 445 |
+
advantages=batched.advantages[b],
|
| 446 |
+
loss_mask=batched.loss_mask[b],
|
| 447 |
+
teacher_logprobs=None if batched.teacher_logprobs is None
|
| 448 |
+
else batched.teacher_logprobs[b],
|
| 449 |
+
)
|
| 450 |
+
loss = loss_fn(inputs_b, ...)
|
| 451 |
+
```
|
| 452 |
+
- If you genuinely need a batched API, write a thin wrapper around
|
| 453 |
+
`loss_fn`. Don't patch the recipe — its shape contract is dictated
|
| 454 |
+
by PRIME-RL, not by us.
|
| 455 |
+
|
| 456 |
+
**VERIFICATION.** The shape contract is pinned by two tests in
|
| 457 |
+
`composer_replication/recipes/prime_rl/tests/test_composer_loss.py`:
|
| 458 |
+
- `test_advantages_shape_validates_seq_accepted` — `(seq,)` succeeds.
|
| 459 |
+
- `test_advantages_shape_validates_bt_rejected` — `(B, T)` raises
|
| 460 |
+
`ValueError`.
|
| 461 |
+
|
| 462 |
+
```
|
| 463 |
+
pytest composer_replication/recipes/prime_rl/tests/test_composer_loss.py -xvs
|
| 464 |
+
```
|
| 465 |
+
|
| 466 |
+
---
|
| 467 |
+
|
| 468 |
+
### 8. TAID can't run mid-training because `student_init_logits` is missing
|
| 469 |
+
|
| 470 |
+
**SYMPTOM.** You decide partway through a training run to enable
|
| 471 |
+
`sdpo_wrapper="taid"` (e.g. you read the TAID paper after step 2000
|
| 472 |
+
and want to retrofit). The next training step blows up — either with
|
| 473 |
+
a `KeyError` for `student_init_logits` / `student_init_input_ids`, or
|
| 474 |
+
with a strange-looking loss because the framework fell back to
|
| 475 |
+
re-running a forward pass through the *current* (drifted) model
|
| 476 |
+
instead of the init model.
|
| 477 |
+
|
| 478 |
+
**DIAGNOSIS.** TAID interpolates between the **student's distribution
|
| 479 |
+
at step 0** and the teacher's distribution. From the TAID docstring at
|
| 480 |
+
`composer_replication/distillation/taid.py:10–24`:
|
| 481 |
+
|
| 482 |
+
> TAID interpolates between an "identity" target (the student's own
|
| 483 |
+
> distribution at step 0) and the teacher's distribution, with the
|
| 484 |
+
> interpolation coefficient annealed from 0 → 1 over training.
|
| 485 |
+
|
| 486 |
+
That step-0 reference target has to come from somewhere. The framework
|
| 487 |
+
accepts it via either:
|
| 488 |
+
1. `inputs["student_init_logits"]` — a precomputed `(B, T, V)` tensor
|
| 489 |
+
captured at training start (preferred for production), OR
|
| 490 |
+
2. `inputs["student_init_input_ids"]` — input ids for a frozen forward
|
| 491 |
+
pass through `model`. **This assumes `model` has not yet drifted
|
| 492 |
+
from init.** It is correct only at step 0 or in tests; in
|
| 493 |
+
production it silently produces the wrong target.
|
| 494 |
+
|
| 495 |
+
If you forgot to capture the init logits at step 0, you cannot
|
| 496 |
+
faithfully use TAID mid-run.
|
| 497 |
+
|
| 498 |
+
**FIX.** Capture init logits at step 0 and persist them:
|
| 499 |
+
|
| 500 |
+
```python
|
| 501 |
+
# At step 0, before any optimizer.step() call:
|
| 502 |
+
with torch.no_grad():
|
| 503 |
+
init_logits = model(input_ids=batch["input_ids"]).logits
|
| 504 |
+
# Save to disk if you'll need them across restarts:
|
| 505 |
+
torch.save(init_logits, "checkpoints/init_logits_batch0.pt")
|
| 506 |
+
inputs["student_init_logits"] = init_logits
|
| 507 |
+
|
| 508 |
+
# Or, if you have a fixed eval probe set, capture init logits once
|
| 509 |
+
# for that fixed set and reuse them every step:
|
| 510 |
+
inputs["student_init_logits"] = cached_init_logits
|
| 511 |
+
```
|
| 512 |
+
|
| 513 |
+
If you genuinely have no step-0 snapshot, **TAID is not retrofittable**
|
| 514 |
+
to your run. Your options are:
|
| 515 |
+
- Restart from a checkpoint that *was* the step-0 model.
|
| 516 |
+
- Use a different distillation wrapper (`sdpo_wrapper="entropy_opd"`)
|
| 517 |
+
that doesn't need init logits.
|
| 518 |
+
- Accept the bias from the live-model fallback path. Don't.
|
| 519 |
+
|
| 520 |
+
**VERIFICATION.** The precomputed-vs-live-fallback contract is exercised by:
|
| 521 |
+
- `test_taid_accepts_precomputed_student_init_logits` in
|
| 522 |
+
`composer_replication/tests/test_compose_loss_integration.py` —
|
| 523 |
+
passes precomputed logits and asserts the TAID-wrapped channel uses
|
| 524 |
+
them.
|
| 525 |
+
- `test_taid_alpha_one_recovers_sdpo` — asserts that with
|
| 526 |
+
`alpha_min=alpha_max=1.0` (i.e. pure teacher target, init logits
|
| 527 |
+
ignored) TAID reproduces standard SDPO. If your training ignores
|
| 528 |
+
init logits silently, *this* is the test that would have failed.
|
| 529 |
+
|
| 530 |
+
```
|
| 531 |
+
pytest composer_replication/tests/test_compose_loss_integration.py::test_taid_accepts_precomputed_student_init_logits -xvs
|
| 532 |
+
```
|
| 533 |
+
|
| 534 |
+
---
|
| 535 |
+
|
| 536 |
+
### 9. `ModalExecutor()` or `HFJobsExecutor()` raises `NotImplementedError` at construction
|
| 537 |
+
|
| 538 |
+
**SYMPTOM.** You write
|
| 539 |
+
`executor = ModalExecutor(app_name="my-app")` (or the HF Jobs
|
| 540 |
+
equivalent) in a production script and the constructor immediately
|
| 541 |
+
raises:
|
| 542 |
+
|
| 543 |
+
```
|
| 544 |
+
NotImplementedError: ModalExecutor is a v0 skeleton; full implementation pending.
|
| 545 |
+
Use LocalProcessExecutor for testing.
|
| 546 |
+
```
|
| 547 |
+
|
| 548 |
+
Same for `HFJobsExecutor`. This is at *init time*, not at the first
|
| 549 |
+
`launch_replicas` call.
|
| 550 |
+
|
| 551 |
+
**DIAGNOSIS.** Per ADR-005 the v0 release ships only the
|
| 552 |
+
`ServerlessExecutor` Protocol and the reference `LocalProcessExecutor`.
|
| 553 |
+
The Modal and HF Jobs implementations are **import-safe skeletons** —
|
| 554 |
+
the classes exist and you can `from … import ModalExecutor`, but
|
| 555 |
+
`__init__` raises `NotImplementedError` to prevent silent partial
|
| 556 |
+
behavior. See `modal.py:64` and `hf_jobs.py:64`.
|
| 557 |
+
|
| 558 |
+
This is intentional. We didn't want to ship a half-working Modal
|
| 559 |
+
executor that succeeds at `launch_replicas` and then silently fails
|
| 560 |
+
two-thirds of the way through `collect`.
|
| 561 |
+
|
| 562 |
+
**FIX.**
|
| 563 |
+
- Use `LocalProcessExecutor` for development, CI, and any single-host
|
| 564 |
+
multi-process testing.
|
| 565 |
+
- For real cloud deployment in the v0 era, run your training script
|
| 566 |
+
directly in Modal/HF Jobs by hand: write your own thin Modal
|
| 567 |
+
function that constructs `MockManager(ObjectStoreAllReduce(uri,
|
| 568 |
+
rank, world_size))` and runs the training loop. The skeleton
|
| 569 |
+
docstrings at `modal.py:24–48` and `hf_jobs.py:26–49` show exactly
|
| 570 |
+
the pattern.
|
| 571 |
+
- Watch the `BACKLOG.md` for v0 polish — the real implementations are
|
| 572 |
+
scheduled.
|
| 573 |
+
|
| 574 |
+
**VERIFICATION.** That `LocalProcessExecutor` is fully functional and
|
| 575 |
+
correctly implements the Protocol is locked in by:
|
| 576 |
+
- `test_local_executor_runs_allreduce_across_replicas` in
|
| 577 |
+
`composer_replication/diloco/serverless/tests/test_serverless_local.py`
|
| 578 |
+
— runs N replicas locally, performs an allreduce across them.
|
| 579 |
+
- `test_local_executor_handles_multiple_rounds`
|
| 580 |
+
- `test_local_executor_reports_failed_replicas`
|
| 581 |
+
|
| 582 |
+
If those tests pass, your serverless DiLoCo machinery works — only the
|
| 583 |
+
specific cloud adapters are missing. The skeletons themselves are not
|
| 584 |
+
under test (raising in `__init__` is the contract).
|
| 585 |
+
|
| 586 |
+
```
|
| 587 |
+
pytest composer_replication/diloco/serverless/tests/test_serverless_local.py -xvs
|
| 588 |
+
```
|
| 589 |
+
|
| 590 |
+
---
|
| 591 |
+
|
| 592 |
+
### 10. DPPO mask drops every token — "loss became 0" or "no gradients"
|
| 593 |
+
|
| 594 |
+
**SYMPTOM.** You ported a PPO config from another framework (KL
|
| 595 |
+
penalty + clip ε=0.2 + value loss), wired it into the PRIME-RL recipe
|
| 596 |
+
with the default `dppo_mask_high=0.2` / `dppo_mask_low=0.2`, and the
|
| 597 |
+
training loss is suspiciously close to zero. Inspecting the recipe's
|
| 598 |
+
internal `keep_mask` shows nearly every token is being masked out.
|
| 599 |
+
|
| 600 |
+
**DIAGNOSIS.** PRIME-RL's "DPPO mask" is **not** the same as PPO
|
| 601 |
+
clipping, and not even the same as a log-ratio threshold. From the
|
| 602 |
+
recipe docstring at
|
| 603 |
+
`composer_replication/recipes/prime_rl/composer_loss.py` (mirroring
|
| 604 |
+
PRIME-RL upstream `prime_rl/trainer/rl/loss.py` lines 137-148):
|
| 605 |
+
|
| 606 |
+
> The mask gate is on **probability-space**
|
| 607 |
+
> `probs_diff = exp(trainer_lp) - exp(inference_lp)`, NOT on the
|
| 608 |
+
> log-ratio. A positive-advantage token is dropped iff
|
| 609 |
+
> `probs_diff > dppo_mask_high`; a negative-advantage token iff
|
| 610 |
+
> `probs_diff < -dppo_mask_low`. Masked tokens are **dropped from the
|
| 611 |
+
> policy-gradient term** but still contribute to the KL penalty.
|
| 612 |
+
|
| 613 |
+
The defaults `dppo_mask_high=dppo_mask_low=0.2` match PRIME-RL's
|
| 614 |
+
`DefaultLossConfig`. Because the gate is on probability-space, the
|
| 615 |
+
"in-band" zone is
|
| 616 |
+
`exp(trainer_lp) ∈ [exp(inference_lp) - 0.2, exp(inference_lp) + 0.2]`.
|
| 617 |
+
For a token with inference probability ~0.5 this is a fairly tight
|
| 618 |
+
band; for tokens at probability ~0.001 or ~0.999 the same threshold
|
| 619 |
+
behaves very differently from a log-ratio bound. This is by design —
|
| 620 |
+
PRIME-RL is bounding the absolute change in token probability, not the
|
| 621 |
+
multiplicative change.
|
| 622 |
+
|
| 623 |
+
The two failure modes:
|
| 624 |
+
|
| 625 |
+
1. **All tokens masked.** Trainer and inference engines disagree
|
| 626 |
+
sharply (fp16 vs bf16, stale rollout cache, mismatched chat
|
| 627 |
+
templates) and `probs_diff` exceeds 0.2 almost everywhere.
|
| 628 |
+
2. **No tokens masked.** Trainer ≈ inference (e.g. you forgot to step
|
| 629 |
+
the optimizer between rollouts) so the bound is never binding and
|
| 630 |
+
the policy never sees any DPPO regularization.
|
| 631 |
+
|
| 632 |
+
**FIX.** Inspect the empirical `probs_diff` distribution before
|
| 633 |
+
tuning:
|
| 634 |
+
|
| 635 |
+
```python
|
| 636 |
+
# In your training loop:
|
| 637 |
+
probs_diff = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs)
|
| 638 |
+
print(torch.quantile(probs_diff.abs(), torch.tensor([0.5, 0.9, 0.99])))
|
| 639 |
+
```
|
| 640 |
+
|
| 641 |
+
For a healthy on-policy run with bf16 trainer + bf16 inference and
|
| 642 |
+
fresh rollouts, the central 99% of `|probs_diff|` should sit well
|
| 643 |
+
below `0.2`. If yours doesn't, the upstream divergence is the
|
| 644 |
+
problem, not the bound. Bumping `dppo_mask_high/low` to 0.5 or 1.0 is
|
| 645 |
+
a workaround but it disables the trust-region intent of DPPO.
|
| 646 |
+
|
| 647 |
+
**Do not** translate PPO ε=0.2 directly. PPO ε=0.2 is a multiplicative
|
| 648 |
+
log-ratio bound (`|log_ratio| < log(1.2) ≈ 0.18`); DPPO's 0.2 is an
|
| 649 |
+
**additive probability-space** bound. The semantics are different and
|
| 650 |
+
the defaults are deliberately tight in probability space.
|
| 651 |
+
|
| 652 |
+
If you genuinely want to disable the mask (e.g. for bug-isolation),
|
| 653 |
+
pass `dppo_mask_high=1e6, dppo_mask_low=1e6` (both are
|
| 654 |
+
`Field(..., ge=0)` upstream — negative values are rejected by
|
| 655 |
+
both PRIME-RL and our adapter). There is a regression test for
|
| 656 |
+
exactly this knob.
|
| 657 |
+
|
| 658 |
+
**VERIFICATION.**
|
| 659 |
+
- `test_dppo_mask_high_drops_positive_advantage_outliers` and
|
| 660 |
+
`test_dppo_mask_low_drops_negative_advantage_outliers` in
|
| 661 |
+
`composer_replication/recipes/prime_rl/tests/test_composer_loss.py`
|
| 662 |
+
— assert that out-of-bound tokens are dropped from the
|
| 663 |
+
policy-gradient term (with the upstream sign-of-advantage gate).
|
| 664 |
+
- `test_dppo_mask_sign_conditioned_on_advantage` — asserts that a
|
| 665 |
+
positive-advantage token with a large *negative* probs_diff is NOT
|
| 666 |
+
dropped (PRIME-RL only checks the upper bound for positive-advantage
|
| 667 |
+
tokens).
|
| 668 |
+
- `test_dppo_bounds_can_be_disabled` — asserts that very wide bounds
|
| 669 |
+
(`1e6`) pass every token through.
|
| 670 |
+
- `test_parity_with_prime_rl_default_loss_fn` — when `prime-rl` is
|
| 671 |
+
installed, runs identical inputs through PRIME-RL upstream and our
|
| 672 |
+
adapter and asserts the loss matches.
|
| 673 |
+
|
| 674 |
+
```
|
| 675 |
+
pytest composer_replication/recipes/prime_rl/tests/test_composer_loss.py -xvs
|
| 676 |
+
```
|
| 677 |
+
|
| 678 |
+
---
|
| 679 |
+
|
| 680 |
+
### 11. `compose_loss` runs but the GRPO channel doesn't behave like real GRPO
|
| 681 |
+
|
| 682 |
+
**SYMPTOM.** You read the README, saw the "3-channel composition: GRPO
|
| 683 |
+
+ SDPO + trace-replay DPO" tagline, called `compose_loss(model,
|
| 684 |
+
inputs)` directly in your training loop, and your reward curve never
|
| 685 |
+
moves the way it would in a real GRPO trainer. Or: you compared
|
| 686 |
+
against a TRL `GRPOTrainer` baseline and `compose_loss` produces
|
| 687 |
+
totally different numbers.
|
| 688 |
+
|
| 689 |
+
**DIAGNOSIS.** From the docstring at the top of
|
| 690 |
+
`composer_replication/loss.py:1–16`:
|
| 691 |
+
|
| 692 |
+
> This is a verification-harness mirror of
|
| 693 |
+
> `ComposerReplicationTrainer._compute_loss` that does NOT depend on
|
| 694 |
+
> TRL's GRPOTrainer parent. The GRPO channel is replaced with standard
|
| 695 |
+
> LM next-token-prediction cross-entropy, which is the limit GRPO
|
| 696 |
+
> converges to under deterministic rewards.
|
| 697 |
+
>
|
| 698 |
+
> Use it for: CPU smokes on real HF models, unit tests of loss
|
| 699 |
+
> composition without spinning up TRL, anywhere we want to verify
|
| 700 |
+
> gradient flow through the 3-channel sum without paying TRL's full
|
| 701 |
+
> machinery cost.
|
| 702 |
+
>
|
| 703 |
+
> **Do NOT use it as the production training loss.** Production =
|
| 704 |
+
> ComposerReplicationTrainer (a real GRPOTrainer subclass).
|
| 705 |
+
|
| 706 |
+
The `lm_ce` channel labelled "GRPO" in the LossComponents dataclass is
|
| 707 |
+
a **stub**: it is plain language-modeling cross-entropy. It is the
|
| 708 |
+
correct channel for verification (gradient flow, channel weighting,
|
| 709 |
+
distillation wiring), but it is not GRPO's surrogate objective and
|
| 710 |
+
will never produce the same numbers as real GRPO under stochastic
|
| 711 |
+
rewards.
|
| 712 |
+
|
| 713 |
+
Real GRPO requires:
|
| 714 |
+
- A reward model or rule-based reward,
|
| 715 |
+
- Per-prompt advantage estimation across G samples,
|
| 716 |
+
- An importance-sampling-ratio clip / mask.
|
| 717 |
+
|
| 718 |
+
Those live in TRL's `GRPOTrainer`, in our PRIME-RL recipe at
|
| 719 |
+
`composer_replication/recipes/prime_rl/composer_loss.py`, or (when
|
| 720 |
+
shipped) in a future VeRL recipe.
|
| 721 |
+
|
| 722 |
+
**FIX.**
|
| 723 |
+
- For production GRPO training, do **not** call `compose_loss` directly.
|
| 724 |
+
Instead use one of:
|
| 725 |
+
- `composer_replication.trainer.composer_trainer.ComposerReplicationTrainer`
|
| 726 |
+
— TRL `GRPOTrainer` subclass, full machinery.
|
| 727 |
+
- `composer_replication.recipes.prime_rl.composer_loss.loss_fn` —
|
| 728 |
+
PRIME-RL's `CustomLossConfig` adapter (channel 1 is real DPPO-clipped GRPO).
|
| 729 |
+
- For ablations, smokes, and unit tests, `compose_loss` is the right
|
| 730 |
+
tool — but log the `lm_ce` channel as `lm_ce`, not as `grpo`. The
|
| 731 |
+
`LossComponents` dataclass already names the field correctly; if
|
| 732 |
+
your wandb logger relabels it as "GRPO loss", fix the label.
|
| 733 |
+
|
| 734 |
+
**VERIFICATION.**
|
| 735 |
+
- The 11-test integration suite at
|
| 736 |
+
`composer_replication/tests/test_compose_loss_integration.py` only
|
| 737 |
+
asserts gradient flow + bit-exact composition; it deliberately does
|
| 738 |
+
not assert any GRPO-specific property of `compose_loss`. That's the
|
| 739 |
+
contract.
|
| 740 |
+
- The PRIME-RL recipe's real DPPO+KL behavior is asserted by
|
| 741 |
+
`test_returns_finite_scalar`,
|
| 742 |
+
`test_dppo_mask_high_drops_positive_advantage_outliers`,
|
| 743 |
+
`test_dppo_mask_sign_conditioned_on_advantage`, and
|
| 744 |
+
`test_parity_with_prime_rl_default_loss_fn` (skip-marked when
|
| 745 |
+
`prime-rl` is not installed)
|
| 746 |
+
in `composer_replication/recipes/prime_rl/tests/test_composer_loss.py`.
|
| 747 |
+
Those tests verify a real importance-sampling-ratio gradient with
|
| 748 |
+
PRIME-RL's advantage-conditioned mask, which `compose_loss` would
|
| 749 |
+
not pass.
|
| 750 |
+
|
| 751 |
+
If you find yourself wanting `compose_loss` to behave like real GRPO,
|
| 752 |
+
that is the signal to switch to one of the production paths above.
|
| 753 |
+
|
| 754 |
+
```
|
| 755 |
+
pytest composer_replication/tests/test_compose_loss_integration.py::test_defaults_bit_exact_with_legacy_kwargs -xvs
|
| 756 |
+
pytest composer_replication/recipes/prime_rl/tests/test_composer_loss.py::test_returns_finite_scalar -xvs
|
| 757 |
+
```
|
| 758 |
+
|
| 759 |
+
---
|
| 760 |
+
|
| 761 |
+
## How to file a bug report
|
| 762 |
+
|
| 763 |
+
If you've read the relevant section above and your problem persists,
|
| 764 |
+
file a bug. Include **all** sections of the template below — the most
|
| 765 |
+
common reason a maintainer can't repro is a missing piece of
|
| 766 |
+
environmental context.
|
| 767 |
+
|
| 768 |
+
```markdown
|
| 769 |
+
### What I expected vs what happened
|
| 770 |
+
(One paragraph.)
|
| 771 |
+
|
| 772 |
+
### Repro steps
|
| 773 |
+
1. ...
|
| 774 |
+
2. ...
|
| 775 |
+
3. ...
|
| 776 |
+
|
| 777 |
+
Minimal self-contained snippet (no `from my_local_thing import …`):
|
| 778 |
+
|
| 779 |
+
```python
|
| 780 |
+
# repro.py
|
| 781 |
+
from composer_replication import compose_loss
|
| 782 |
+
...
|
| 783 |
+
```
|
| 784 |
+
|
| 785 |
+
### Environment
|
| 786 |
+
- OS: (uname -a or `ver` on Windows)
|
| 787 |
+
- Python: (python --version)
|
| 788 |
+
- composer-replication: (pip show composer-replication | head -3)
|
| 789 |
+
- torch: (python -c "import torch; print(torch.__version__)")
|
| 790 |
+
- torchft: (python -c "import torchft; print(torchft.__version__)" || echo "n/a")
|
| 791 |
+
- transformers / trl: (versions, or "not installed")
|
| 792 |
+
- data-juicer / fsspec: (versions, or "not installed")
|
| 793 |
+
- s3fs / gcsfs / adlfs: (versions if relevant)
|
| 794 |
+
- GPU: (nvidia-smi -L or "CPU only")
|
| 795 |
+
- Install method: pip install -e . / wheel / other
|
| 796 |
+
- Extras installed: [replay] [replaysim] [serverless] [dev]
|
| 797 |
+
|
| 798 |
+
### What you've already tried
|
| 799 |
+
- [ ] Read the relevant Failure Mode section of docs/TROUBLESHOOTING.md
|
| 800 |
+
(which one: ___)
|
| 801 |
+
- [ ] Ran `pytest <relevant test path>` and confirmed those tests pass
|
| 802 |
+
- [ ] Ran the repro snippet in a fresh venv
|
| 803 |
+
- [ ] Confirmed it reproduces on Python 3.11 (if you were on 3.12 / 3.13)
|
| 804 |
+
|
| 805 |
+
### Logs
|
| 806 |
+
(Full traceback. If it's a wrong-loss-curve rather than an exception,
|
| 807 |
+
paste loss values for the first 10 steps and link any wandb/tb run.)
|
| 808 |
+
|
| 809 |
+
### Hypothesis
|
| 810 |
+
(Optional. If you have a guess at where the bug is, name the file +
|
| 811 |
+
line number. We'll look there first.)
|
| 812 |
+
```
|
| 813 |
+
|
| 814 |
+
A few rules:
|
| 815 |
+
- **Do not** paste API keys, AWS credentials, or HuggingFace tokens.
|
| 816 |
+
- **Do** include the failing test name if you've narrowed it to one.
|
| 817 |
+
- **Do** distinguish "never worked" from "regressed between commit X
|
| 818 |
+
and Y." A regression-bisect goes straight to the front of the queue.
|
| 819 |
+
- **One bug per issue.** Multi-headed reports lose items in triage.
|
| 820 |
+
|
| 821 |
+
The Wave-14 surface area is large, but the test suite covers it
|
| 822 |
+
densely — every section above corresponds to a green test that proves
|
| 823 |
+
the fix worked.
|
|
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Composer Replication Framework — User Guide
|
| 2 |
+
|
| 3 |
+
A zero-to-training walkthrough for the open replication of Cursor Composer 2.5.
|
| 4 |
+
Pace: an ML engineer who knows GRPO/DPO at a textbook level but has never
|
| 5 |
+
opened this repo. Every step references real code, and every kwarg name
|
| 6 |
+
listed below has been imported and verified against
|
| 7 |
+
`composer_replication/` source.
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## 1. What is this framework?
|
| 12 |
+
|
| 13 |
+
A pure-PyTorch replication of the **3-channel composer loss** that powers
|
| 14 |
+
agentic-coding model training. One model, one optimizer, three additive
|
| 15 |
+
loss terms — composed every step:
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
┌────────────────────────────────────────────┐
|
| 19 |
+
│ compose_loss(model, batch) │
|
| 20 |
+
└────────────────────────────────────────────┘
|
| 21 |
+
│
|
| 22 |
+
┌─────────────────────────────────┼─────────────────────────────────┐
|
| 23 |
+
▼ ▼ ▼
|
| 24 |
+
┌───────────────────┐ ┌──────────────────────┐ ┌──────────────────────┐
|
| 25 |
+
│ Channel 1 (RL) │ │ Channel 2 (SDPO) │ │ Channel 3 (replay) │
|
| 26 |
+
│ GRPO │ │ hint-distillation │ │ multi-teacher DPO │
|
| 27 |
+
│ → lm_ce stub in │ │ generalized JSD │ │ on (chosen, │
|
| 28 |
+
│ verification │ │ student vs teacher │ │ rejected) pairs │
|
| 29 |
+
│ harness │ │ (hint-conditioned) │ │ from N teachers │
|
| 30 |
+
└─────────┬─────────┘ └──────────┬───────────┘ └──────────┬───────────┘
|
| 31 |
+
│ weight = 1 (always on) │ alpha_sdpo beta_replay │
|
| 32 |
+
└────────────────┬──────────────┴─────────────────┬──────────────────┘
|
| 33 |
+
▼ ▼
|
| 34 |
+
total = lm_ce + α·sdpo_jsd + β·trace_replay_dpo
|
| 35 |
+
(channel auto-disables if its weight=0 OR its inputs are missing)
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Two API surfaces, on purpose:
|
| 39 |
+
|
| 40 |
+
- **Verification harness** — `compose_loss(model, batch, ...)` is a free
|
| 41 |
+
function (channel 1 = LM cross-entropy, the GRPO limit under deterministic
|
| 42 |
+
rewards). Use it for CPU smokes, unit tests, and gradient-flow debugging.
|
| 43 |
+
- **Production trainer** — `ComposerReplicationTrainer` is a `trl.GRPOTrainer`
|
| 44 |
+
subclass that overrides `_compute_loss` with the same 3 channels on top of
|
| 45 |
+
TRL's real reward + advantage machinery.
|
| 46 |
+
|
| 47 |
+
The verification harness is what you'll use for sections 2–6; the production
|
| 48 |
+
trainer (and its alternates VeRL/PRIME-RL/Monarch) is section 8.
|
| 49 |
+
|
| 50 |
+
Source of truth: `composer_replication/loss.py` for `compose_loss`,
|
| 51 |
+
`composer_replication/trainer/composer_trainer.py` for the trainer subclass.
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
## 2. Install — which extras to pick
|
| 56 |
+
|
| 57 |
+
Always start with the core install:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
git clone https://huggingface.co/Codeseys/composer-replication-framework
|
| 61 |
+
cd composer-replication-framework
|
| 62 |
+
pip install -e .
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
That gets you `torch>=2.0` + `transformers>=4.46` and is enough for the
|
| 66 |
+
verification harness on CPU (sections 3, 5, 6).
|
| 67 |
+
|
| 68 |
+
The seven optional extras are declared in `pyproject.toml` `[project.optional-dependencies]`:
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
Do you need …
|
| 72 |
+
│
|
| 73 |
+
┌──────────────────────────┼──────────────────────────┐
|
| 74 |
+
▼ ▼ ▼
|
| 75 |
+
real teacher calls DiLoCo on production
|
| 76 |
+
over OpenRouter? >1 GPU? GRPO training?
|
| 77 |
+
│ │ │
|
| 78 |
+
│ yes │ yes │ yes
|
| 79 |
+
▼ ▼ ▼
|
| 80 |
+
pip install -e ".[replay]" pip install -e ".[diloco]" pip install -e ".[train]"
|
| 81 |
+
(httpx) (torchft-nightly) (trl, peft, accelerate, datasets)
|
| 82 |
+
│ │ │
|
| 83 |
+
│ + want CPU-side │ + scaling beyond a │ + want PRIME-RL
|
| 84 |
+
│ DPO normalization? │ single host? │ (Recipe C)?
|
| 85 |
+
▼ ▼ ▼
|
| 86 |
+
pip install -e \".[replaysim]\" pip install -e \".[serverless]\" pip install -e \".[prime-rl]\"
|
| 87 |
+
(data-juicer; depends (fsspec, huggingface_hub) (prime-rl>=0.5)
|
| 88 |
+
on [replay])
|
| 89 |
+
│ + Monarch actor mesh?
|
| 90 |
+
▼
|
| 91 |
+
pip install -e \".[monarch]\"
|
| 92 |
+
(monarch>=0.4.1)
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Quick decision table:
|
| 96 |
+
|
| 97 |
+
| Goal | Install |
|
| 98 |
+
|-------------------------------------------------------|------------------------------------------|
|
| 99 |
+
| CPU smoke / verification (sections 3, 5, 6) | `pip install -e .` |
|
| 100 |
+
| Section 4 (replaysim DJNormalizer) | `pip install -e ".[replaysim]"` |
|
| 101 |
+
| Section 7 dev loop (LocalProcessExecutor + file://) | `pip install -e ".[serverless]"` |
|
| 102 |
+
| Real DiLoCo outer-loop | `pip install -e ".[diloco,serverless]"` |
|
| 103 |
+
| Section 8 Recipe A (TRL GRPO) | `pip install -e ".[train]"` |
|
| 104 |
+
| Section 8 Recipe C (PRIME-RL) | `pip install -e ".[prime-rl]"` |
|
| 105 |
+
| Section 8 Recipe C+D (PRIME-RL + Monarch) | `pip install -e ".[prime-rl,monarch]"` |
|
| 106 |
+
| Everything for development | `pip install -e ".[dev]"` |
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## 3. Quickstart: `examples/qwen_05b_quickstart` end-to-end on CPU
|
| 111 |
+
|
| 112 |
+
The fastest way to convince yourself the framework works on a real HF model.
|
| 113 |
+
~3–5 min wall-clock on CPU, ~1 GB disk for Qwen2.5-0.5B weights.
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
pip install -e .
|
| 117 |
+
python examples/qwen_05b_quickstart/run.py
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
What the script does (read the source at
|
| 121 |
+
`examples/qwen_05b_quickstart/run.py`):
|
| 122 |
+
|
| 123 |
+
1. Pin RNG (`random.seed(42)`, `torch.manual_seed(42)`) so the per-step
|
| 124 |
+
numbers below are reproducible.
|
| 125 |
+
2. Load `Qwen/Qwen2.5-0.5B-Instruct` on CPU in fp32, set `model.train()`.
|
| 126 |
+
3. `batch = build_batch(tokenizer, device="cpu")` — a real chat-template-formatted
|
| 127 |
+
batch with all keys the 3-channel composer might consume.
|
| 128 |
+
4. Five backward steps with `compose_loss(model, batch, alpha_sdpo=0.1,
|
| 129 |
+
beta_replay=0.05)`; an `AdamW(lr=1e-5)` optimizer; finite-grad check
|
| 130 |
+
after each step.
|
| 131 |
+
|
| 132 |
+
Expected output (transcribed from `examples/qwen_05b_quickstart/run.log`):
|
| 133 |
+
|
| 134 |
+
```
|
| 135 |
+
[quickstart] loading Qwen/Qwen2.5-0.5B-Instruct (CPU, fp32) ...
|
| 136 |
+
[quickstart] loaded — 0.494B params
|
| 137 |
+
[quickstart] building real chat-template batch ...
|
| 138 |
+
[quickstart] running 5 backward steps ...
|
| 139 |
+
step 0: total=0.7390 lm_ce=0.7358 sdpo=0.0000 dpo=0.0639 finite=True
|
| 140 |
+
step 1: total=0.0379 lm_ce=0.0351 sdpo=0.0000 dpo=0.0563 finite=True
|
| 141 |
+
step 2: total=0.0122 lm_ce=0.0110 sdpo=0.0000 dpo=0.0240 finite=True
|
| 142 |
+
step 3: total=0.0060 lm_ce=0.0055 sdpo=0.0000 dpo=0.0098 finite=True
|
| 143 |
+
step 4: total=0.0031 lm_ce=0.0029 sdpo=0.0000 dpo=0.0044 finite=True
|
| 144 |
+
========================================================
|
| 145 |
+
Initial loss: 0.7390 → Final loss: 0.0031 → Reduction: 99.6%
|
| 146 |
+
Verdict: PASS
|
| 147 |
+
========================================================
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
How to read this:
|
| 151 |
+
|
| 152 |
+
- **`total` collapses by ~99%.** The model successfully memorizes the
|
| 153 |
+
single batch — exactly what you expect from an SGD pass on a 0.5B model
|
| 154 |
+
with one fixed input. This is a wiring check, not a generalization claim.
|
| 155 |
+
- **`lm_ce` carries almost all the magnitude.** Channel 1 (the GRPO stub)
|
| 156 |
+
is doing the work — the response tokens are short and have low entropy
|
| 157 |
+
under the trained model.
|
| 158 |
+
- **`sdpo=0.0000` on every step.** Channel 2 has auto-disabled because the
|
| 159 |
+
default `build_batch` does not include `ctx_teacher_input_ids`. Compare
|
| 160 |
+
the conditional in `compose_loss`:
|
| 161 |
+
```python
|
| 162 |
+
if (alpha_sdpo > 0.0
|
| 163 |
+
and "ctx_teacher_input_ids" in inputs
|
| 164 |
+
and inputs["ctx_teacher_input_ids"].numel() > 0):
|
| 165 |
+
```
|
| 166 |
+
— channel auto-off if either the weight or the inputs are missing.
|
| 167 |
+
- **`dpo > 0` and trending down.** The batch *does* include
|
| 168 |
+
`dpo_chosen_input_ids`, `dpo_chosen_response_mask`,
|
| 169 |
+
`dpo_chosen_ref_logprobs` (and the rejected counterparts), so channel 3
|
| 170 |
+
is live.
|
| 171 |
+
- **`finite=True`** — every step's `p.grad` was finite for every parameter.
|
| 172 |
+
This is the wiring contract; if it ever flips to `False` the smoke fails.
|
| 173 |
+
|
| 174 |
+
If you see `Verdict: PASS`, the framework is correctly installed and
|
| 175 |
+
gradients flow through all live channels. You are ready for section 4.
|
| 176 |
+
|
| 177 |
+
---
|
| 178 |
+
|
| 179 |
+
## 4. Adding the trace-replay channel
|
| 180 |
+
|
| 181 |
+
The quickstart batch *had* DPO inputs, but they were synthetic — the
|
| 182 |
+
`build_batch` helper bakes them in. To get **real** DPO pairs from
|
| 183 |
+
multi-teacher disagreement, use the replaysim package.
|
| 184 |
+
|
| 185 |
+
### 4a. Spin up `replay_trace`
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
import asyncio
|
| 189 |
+
from composer_replication import (
|
| 190 |
+
DEFAULT_TEACHERS, replay_trace, extract_dpo_pairs,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# Trace must be a list[TraceState]; see composer_replication/teacher_replay.py
|
| 194 |
+
# for the exact TypedDict shape. Each state holds a chat-messages prefix +
|
| 195 |
+
# the student's actual action at that step.
|
| 196 |
+
states = [...] # your frozen agentic trace; see spike 001 for a 50-step example
|
| 197 |
+
|
| 198 |
+
teacher_actions = asyncio.run(
|
| 199 |
+
replay_trace(
|
| 200 |
+
states=states,
|
| 201 |
+
teachers=DEFAULT_TEACHERS, # claude-opus-4.7 + gpt-5 + deepseek-v4-pro
|
| 202 |
+
max_total_usd=10.0, # hard ceiling (spike 001 measured $0.98/trace mean)
|
| 203 |
+
)
|
| 204 |
+
)
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
The 3 teachers are queried in parallel via OpenRouter
|
| 208 |
+
(`OPENROUTER_API_KEY` in env or `~/.hermes/.env`), latencies recorded,
|
| 209 |
+
costs tracked.
|
| 210 |
+
|
| 211 |
+
### 4b. Get `DPOPair`s from disagreement
|
| 212 |
+
|
| 213 |
+
```python
|
| 214 |
+
pairs = extract_dpo_pairs(
|
| 215 |
+
states=states,
|
| 216 |
+
teacher_actions=teacher_actions,
|
| 217 |
+
agreement_threshold=2, # at least 2/3 teachers must agree on the chosen action
|
| 218 |
+
)
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
Each pair is a `DPOPair` TypedDict with the exact shape the
|
| 222 |
+
`DJNormalizer` and downstream training expects:
|
| 223 |
+
|
| 224 |
+
```python
|
| 225 |
+
class DPOPair(TypedDict):
|
| 226 |
+
state_id: str
|
| 227 |
+
state_messages: list[dict] # conversation context
|
| 228 |
+
chosen: str # teacher-consensus action
|
| 229 |
+
rejected: str # student action
|
| 230 |
+
n_teachers_agreeing: int
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
(verified in `composer_replication/teacher_replay.py:99–105`).
|
| 234 |
+
|
| 235 |
+
### 4c. Run `DJNormalizer` with `default.yaml`
|
| 236 |
+
|
| 237 |
+
```python
|
| 238 |
+
from composer_replication.replaysim import DJNormalizer
|
| 239 |
+
|
| 240 |
+
normalizer = DJNormalizer() # uses recipes/replaysim/default.yaml
|
| 241 |
+
normalized = normalizer.normalize(pairs)
|
| 242 |
+
# → list[NormalizedDPOPair]
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
`DJNormalizer` shells out to data-juicer's `DefaultExecutor` under the hood
|
| 246 |
+
(file-in / file-out contract). The default recipe at
|
| 247 |
+
`composer_replication/recipes/replaysim/default.yaml` runs four CPU-only ops
|
| 248 |
+
in order:
|
| 249 |
+
|
| 250 |
+
1. `text_length_filter` (8 ≤ chars ≤ 32000) on `chosen` and `rejected`
|
| 251 |
+
2. `words_num_filter` (2 ≤ words ≤ 4096) on both
|
| 252 |
+
3. `special_characters_filter` (≤50% non-alpha) on both
|
| 253 |
+
4. `document_deduplicator` (per-batch hashing, lowercase, ignore non-character) on `chosen`
|
| 254 |
+
|
| 255 |
+
Records carry **two parallel shapes** for `chosen`/`rejected`:
|
| 256 |
+
- flat strings (`chosen`, `rejected`) → consumed by data-juicer's text_key-based filters
|
| 257 |
+
- chat-messages lists (`chosen_messages`, `rejected_messages`) → preserved for chat-aware ops + round-trip
|
| 258 |
+
|
| 259 |
+
This dual-shape design (verified in the test
|
| 260 |
+
`test_dpo_pair_to_dj_record_shape`,
|
| 261 |
+
`composer_replication/replaysim/tests/test_replaysim.py:44`) is what
|
| 262 |
+
unblocked the data-juicer integration in Wave 14.
|
| 263 |
+
|
| 264 |
+
### 4d. The 3-record fixture from spike 001
|
| 265 |
+
|
| 266 |
+
The fixture lives at
|
| 267 |
+
`spikes/001-teacher-replay-cost/states.jsonl` (50 states) and
|
| 268 |
+
`spikes/001-teacher-replay-cost/results.jsonl` (the teacher responses, all
|
| 269 |
+
priced and timed). The first 3 states are:
|
| 270 |
+
|
| 271 |
+
```jsonl
|
| 272 |
+
{"id": "state-000", "task": "Fix the failing test in tests/test_auth.py::test_login_with_email", ...}
|
| 273 |
+
{"id": "state-001", "task": "Add rate-limiting middleware to the Flask app", ...}
|
| 274 |
+
{"id": "state-002", "task": "Refactor the parse_config function — it's 200 lines and has 3 responsibilities", ...}
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
For each, all 3 teachers answered (claude-opus-4.7, gpt-5, deepseek-v4-pro);
|
| 278 |
+
agreement on the `(c)` choice for state-000 and state-001 (read more
|
| 279 |
+
files / check schema first) drives a clean DPO pair where the student's
|
| 280 |
+
action becomes the rejected. For state-002, all 3 agreed on `(c)` (write
|
| 281 |
+
characterization tests first) → another clean pair. These three records
|
| 282 |
+
pass through the `DJNormalizer` default recipe unchanged (length, words,
|
| 283 |
+
special-char ratios all in bounds; no duplicates).
|
| 284 |
+
|
| 285 |
+
The full 50-state trace cost **$0.98 mean** end-to-end across all three
|
| 286 |
+
teachers (spike 001 verdict). The framework's cost ceiling
|
| 287 |
+
(`max_total_usd`) and VOI gating drop this to ~$0.30/trace projected.
|
| 288 |
+
|
| 289 |
+
### 4e. End-to-end one-liner
|
| 290 |
+
|
| 291 |
+
```python
|
| 292 |
+
from composer_replication.replaysim import replay_and_normalize_trace
|
| 293 |
+
|
| 294 |
+
teacher_actions, normalized_pairs = await replay_and_normalize_trace(
|
| 295 |
+
states=states,
|
| 296 |
+
teachers=DEFAULT_TEACHERS,
|
| 297 |
+
agreement_threshold=2,
|
| 298 |
+
max_total_usd=10.0,
|
| 299 |
+
)
|
| 300 |
+
```
|
| 301 |
+
|
| 302 |
+
(`async def`; for sync callers use the sibling `replay_and_normalize_trace_sync`
|
| 303 |
+
in `composer_replication.replaysim.normalize`.)
|
| 304 |
+
|
| 305 |
+
---
|
| 306 |
+
|
| 307 |
+
## 5. Switching DPO → SimPO: one kwarg
|
| 308 |
+
|
| 309 |
+
```python
|
| 310 |
+
components = compose_loss(
|
| 311 |
+
model, batch,
|
| 312 |
+
alpha_sdpo=0.1,
|
| 313 |
+
beta_replay=0.05,
|
| 314 |
+
dpo_variant="simpo", # ← the only line that changes
|
| 315 |
+
simpo_beta=2.0, # paper default
|
| 316 |
+
simpo_gamma=1.0, # paper default
|
| 317 |
+
)
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
The kwarg is verified in `compose_loss`'s signature
|
| 321 |
+
(`composer_replication/loss.py:81`):
|
| 322 |
+
|
| 323 |
+
```python
|
| 324 |
+
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
### What changes in the loss curve
|
| 328 |
+
|
| 329 |
+
- **Channel 3 input requirements drop.** `compose_loss` no longer reads
|
| 330 |
+
`dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`. Reference-model
|
| 331 |
+
VRAM cost goes to zero. (Source: `composer_replication/loss.py:111–113`
|
| 332 |
+
and `composer_replication/distillation/simpo.py:23–27`.)
|
| 333 |
+
- **Loss scale shifts.** Standard DPO is
|
| 334 |
+
`-logsigmoid(β·[(logπ(c) - logπ_ref(c)) - (logπ(r) - logπ_ref(r))])`.
|
| 335 |
+
SimPO is `-logsigmoid(β·[avg_logπ(c) - avg_logπ(r)] - γ)` — average
|
| 336 |
+
per-token log-prob (length-normalized) and a constant target margin γ.
|
| 337 |
+
- **Loss is ≤ DPO loss when chosen/rejected separation is large.** The
|
| 338 |
+
unit test `test_simpo_loss_lower_for_better_separation`
|
| 339 |
+
(`composer_replication/distillation/tests/test_distillation_losses.py:35`)
|
| 340 |
+
verifies that a wider chosen-vs-rejected gap drives lower SimPO loss —
|
| 341 |
+
meaning, in practice, SimPO curves are *steeper* than DPO when the
|
| 342 |
+
preference signal is strong, and *flatter* when it's weak.
|
| 343 |
+
- **No KL-against-reference regularization.** This is both the upside (no
|
| 344 |
+
ref-model serving) and the risk (more tendency to drift). Watch for
|
| 345 |
+
reward-hacking-style degeneracies if your preference data has noise.
|
| 346 |
+
|
| 347 |
+
### When to use SimPO
|
| 348 |
+
|
| 349 |
+
- **GPU-poor.** You can't afford to keep a frozen reference policy resident
|
| 350 |
+
alongside the trainer.
|
| 351 |
+
- **Cold-start preference data.** Length-normalization (avg_logπ vs sum)
|
| 352 |
+
helps when chosen/rejected lengths are wildly imbalanced — common in
|
| 353 |
+
agentic traces where the student's failed attempt is short and the
|
| 354 |
+
teacher's correction is long.
|
| 355 |
+
- **You don't have ref logprobs precomputed.** SimPO needs nothing from
|
| 356 |
+
the reference policy.
|
| 357 |
+
|
| 358 |
+
When to **stay on DPO**: when you need the explicit KL anchor against
|
| 359 |
+
a known-good reference (e.g., when training over a long horizon and you
|
| 360 |
+
want to bound the drift), or when your preference data is very noisy and
|
| 361 |
+
the reference acts as a regularizer.
|
| 362 |
+
|
| 363 |
+
---
|
| 364 |
+
|
| 365 |
+
## 6. Adding TAID / Entropy-Aware OPD wrappers
|
| 366 |
+
|
| 367 |
+
Channel 2 (SDPO/OPSD) can be wrapped by **TAID** (Sakana AI,
|
| 368 |
+
arXiv:2501.16937) for capacity-gap distillation, or replaced by
|
| 369 |
+
**Entropy-Aware OPD** (ICLR 2026 Spotlight) for per-token forward/reverse-KL
|
| 370 |
+
gating. Both are verified in the public `compose_loss` kwargs:
|
| 371 |
+
|
| 372 |
+
```python
|
| 373 |
+
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 374 |
+
taid_schedule_step: int | None = None,
|
| 375 |
+
taid_total_steps: int | None = None,
|
| 376 |
+
taid_schedule: str = "linear", # "linear" | "cosine" | "exp"
|
| 377 |
+
taid_alpha_min: float = 0.0,
|
| 378 |
+
taid_alpha_max: float = 1.0,
|
| 379 |
+
entropy_opd_h_max: float | None = None,
|
| 380 |
+
```
|
| 381 |
+
|
| 382 |
+
(verified at `composer_replication/loss.py:82–93`.)
|
| 383 |
+
|
| 384 |
+
### TAID schedule kwargs explained
|
| 385 |
+
|
| 386 |
+
TAID interpolates between the **student's own distribution at step 0**
|
| 387 |
+
(`P_student_init`) and the teacher distribution:
|
| 388 |
+
|
| 389 |
+
```
|
| 390 |
+
P_target(t) = (1 - α(t)) · P_student_init + α(t) · P_teacher
|
| 391 |
+
```
|
| 392 |
+
|
| 393 |
+
where `α(t)` is a schedule controlled by:
|
| 394 |
+
- **`taid_schedule_step`** — the current global step. Required when
|
| 395 |
+
`sdpo_wrapper="taid"`; `compose_loss` raises `ValueError` if you forget it.
|
| 396 |
+
- **`taid_total_steps`** — total planned training steps. Same.
|
| 397 |
+
- **`taid_schedule`** — `"linear"`, `"cosine"`, or `"exp"` (paper default
|
| 398 |
+
exp uses `1 - exp(-5·progress)`).
|
| 399 |
+
- **`taid_alpha_min`** / **`taid_alpha_max`** — schedule range. Default
|
| 400 |
+
`[0, 1]`. Pin both to `1.0` to recover plain SDPO; pin both to `0.0` to
|
| 401 |
+
pin the loss against `P_student_init` (a regularizer that ignores the
|
| 402 |
+
teacher entirely — see proof below).
|
| 403 |
+
|
| 404 |
+
To use TAID, also provide the frozen-init logits via either:
|
| 405 |
+
- `inputs["student_init_logits"]` (precomputed snapshot — preferred), OR
|
| 406 |
+
- `inputs["student_init_input_ids"]` (frozen forward fallback; only valid
|
| 407 |
+
early in training before the model has drifted).
|
| 408 |
+
|
| 409 |
+
If neither is provided, `_resolve_student_init_logits` raises
|
| 410 |
+
`ValueError` with a clear message
|
| 411 |
+
(`composer_replication/loss.py:351–392`).
|
| 412 |
+
|
| 413 |
+
### Entropy-Aware OPD
|
| 414 |
+
|
| 415 |
+
Drop-in for channel 2 — gates between forward KL (mode-covering) and
|
| 416 |
+
reverse KL (mode-seeking) per token, weighted by the teacher's entropy:
|
| 417 |
+
|
| 418 |
+
```
|
| 419 |
+
L = Σ_t w(t) · KL_fwd_t + (1 - w(t)) · KL_rev_t
|
| 420 |
+
w(t) = clamp(H_teacher(t) / h_max, 0, 1)
|
| 421 |
+
```
|
| 422 |
+
|
| 423 |
+
`entropy_opd_h_max=None` (the default) auto-sets `h_max = log(V)` (the
|
| 424 |
+
maximum-entropy bound for a vocab-V softmax).
|
| 425 |
+
|
| 426 |
+
### Boundary-condition unit test (proof of correctness)
|
| 427 |
+
|
| 428 |
+
The test `test_taid_loss_alpha_zero_ignores_teacher`
|
| 429 |
+
(`composer_replication/distillation/tests/test_distillation_losses.py:153`)
|
| 430 |
+
pins the most important TAID invariant — at `α=0` the teacher is
|
| 431 |
+
*completely* hidden from the gradient:
|
| 432 |
+
|
| 433 |
+
```python
|
| 434 |
+
def test_taid_loss_alpha_zero_ignores_teacher():
|
| 435 |
+
"""At alpha=0, teacher gradient should not flow through to student."""
|
| 436 |
+
B, T, V = 1, 2, 4
|
| 437 |
+
student_init = torch.randn(B, T, V)
|
| 438 |
+
s1 = torch.randn(B, T, V, requires_grad=True)
|
| 439 |
+
teacher_a = torch.zeros(B, T, V); teacher_a[..., 0] = 10.0
|
| 440 |
+
teacher_b = torch.zeros(B, T, V); teacher_b[..., 3] = 10.0
|
| 441 |
+
# alpha pinned to 0 → blended target = student_init regardless of teacher
|
| 442 |
+
loss_a = taid_loss(s1, teacher_a, student_init, schedule_step=0,
|
| 443 |
+
total_steps=100, alpha_min=0.0, alpha_max=0.0)
|
| 444 |
+
loss_b = taid_loss(s1, teacher_b, student_init, schedule_step=0,
|
| 445 |
+
total_steps=100, alpha_min=0.0, alpha_max=0.0)
|
| 446 |
+
# Two completely different teachers must give the same loss.
|
| 447 |
+
assert abs(float(loss_a) - float(loss_b)) < 1e-4
|
| 448 |
+
```
|
| 449 |
+
|
| 450 |
+
This is the load-bearing test for TAID: if the schedule's α=0 endpoint
|
| 451 |
+
ever leaks teacher signal into the gradient, this test fires and the
|
| 452 |
+
contract is broken. Companion tests
|
| 453 |
+
(`test_taid_alpha_schedule_endpoints` line 86,
|
| 454 |
+
`test_taid_blended_logits_endpoints` line 115) pin the schedule's
|
| 455 |
+
endpoints (α=0 → student_init, α=1 → teacher) and the half-way mixing
|
| 456 |
+
behavior.
|
| 457 |
+
|
| 458 |
+
For Entropy-OPD, the boundary test is
|
| 459 |
+
`test_entropy_aware_opd_zero_when_distributions_match` (line 217): when
|
| 460 |
+
student logits ≡ teacher logits, both KLs are 0 and the loss must be 0
|
| 461 |
+
to numerical precision.
|
| 462 |
+
|
| 463 |
+
---
|
| 464 |
+
|
| 465 |
+
## 7. Going multi-replica with serverless DiLoCo
|
| 466 |
+
|
| 467 |
+
DiLoCo is the outer-loop optimizer that lets you run N replicas in
|
| 468 |
+
parallel, sync them every H inner steps, and tolerate slow links — see
|
| 469 |
+
`docs/adrs/ADR-005-serverless-diloco.md` for the design. The framework
|
| 470 |
+
gives you three increasingly-distant deployments:
|
| 471 |
+
|
| 472 |
+
### Step 1 — `LocalProcessExecutor` for development
|
| 473 |
+
|
| 474 |
+
```python
|
| 475 |
+
from composer_replication.diloco.serverless import (
|
| 476 |
+
LocalProcessExecutor, ObjectStoreAllReduce,
|
| 477 |
+
)
|
| 478 |
+
import tempfile
|
| 479 |
+
|
| 480 |
+
with tempfile.TemporaryDirectory() as td:
|
| 481 |
+
rendezvous = ObjectStoreAllReduce(td, rank=0, world_size=4)
|
| 482 |
+
executor = LocalProcessExecutor()
|
| 483 |
+
handles = executor.launch_replicas(
|
| 484 |
+
n_replicas=4,
|
| 485 |
+
entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
|
| 486 |
+
entrypoint_args={"rendezvous_uri": td, "rank_env": "REPLICA_RANK"},
|
| 487 |
+
)
|
| 488 |
+
results = executor.collect(handles, timeout=600)
|
| 489 |
+
```
|
| 490 |
+
|
| 491 |
+
`LocalProcessExecutor` (`composer_replication/diloco/serverless/executor.py:160`)
|
| 492 |
+
spawns N child processes via `multiprocessing.get_context("spawn")` and
|
| 493 |
+
sets `REPLICA_RANK={0..N-1}` in each child's env. It satisfies the
|
| 494 |
+
`ServerlessExecutor` Protocol (line 35) — the same Protocol the cloud
|
| 495 |
+
adapters implement. So the dev-loop code is byte-identical to the cloud
|
| 496 |
+
deploy: only the executor instance changes.
|
| 497 |
+
|
| 498 |
+
### Step 2 — `ObjectStoreAllReduce` as the rendezvous
|
| 499 |
+
|
| 500 |
+
```python
|
| 501 |
+
# Local file:// for tests
|
| 502 |
+
rendezvous = ObjectStoreAllReduce("/tmp/diloco-runs/run42/", rank=0, world_size=4)
|
| 503 |
+
|
| 504 |
+
# Real S3 (after `pip install -e .[serverless]`)
|
| 505 |
+
rendezvous = ObjectStoreAllReduce(
|
| 506 |
+
"s3://my-bucket/diloco-runs/run42/",
|
| 507 |
+
rank=0, world_size=4,
|
| 508 |
+
timeout_s=1800.0,
|
| 509 |
+
)
|
| 510 |
+
```
|
| 511 |
+
|
| 512 |
+
The communication pattern is `S3 PutObject + N GetObjects` once per
|
| 513 |
+
inner H steps (matches DiLoCo's actual sync cadence,
|
| 514 |
+
arXiv:2311.08105 §3.2). For 1B-param bf16, that's ~2 GB / 30 minutes
|
| 515 |
+
per replica — well within S3 free-tier. On the inner side the framework
|
| 516 |
+
exposes a `MockManager` that drops into the `torchft.Manager` slot, so
|
| 517 |
+
you can validate the rendezvous logic before plugging in real torchft
|
| 518 |
+
(verified by `test_serverless_diloco_integration.py`).
|
| 519 |
+
|
| 520 |
+
### Step 3 — point at `ModalExecutor` / `HFJobsExecutor`
|
| 521 |
+
|
| 522 |
+
```python
|
| 523 |
+
# Modal (skeleton at composer_replication/diloco/serverless/modal.py)
|
| 524 |
+
from composer_replication.diloco.serverless.modal import ModalExecutor
|
| 525 |
+
executor = ModalExecutor(image="modal:python3.11", gpu="A100")
|
| 526 |
+
|
| 527 |
+
# HuggingFace Jobs (skeleton at composer_replication/diloco/serverless/hf_jobs.py)
|
| 528 |
+
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
|
| 529 |
+
executor = HFJobsExecutor(hardware="a10g-large")
|
| 530 |
+
|
| 531 |
+
# Same Protocol — same launch_replicas / poll / collect calls as Local
|
| 532 |
+
handles = executor.launch_replicas(n_replicas=4, ...)
|
| 533 |
+
```
|
| 534 |
+
|
| 535 |
+
Both adapters check their cloud SDK at `__init__` time (not at module
|
| 536 |
+
import) so they don't break the package if you don't have `modal` or
|
| 537 |
+
`huggingface_hub` installed. Production maturity: dev-ready for cloud
|
| 538 |
+
trial; per ADR-005, full HA-cluster fan-out lives in v0.2+.
|
| 539 |
+
|
| 540 |
+
---
|
| 541 |
+
|
| 542 |
+
## 8. Picking an RL backend
|
| 543 |
+
|
| 544 |
+
Four canonical recipes, each tied to an upstream framework. Source:
|
| 545 |
+
`docs/INTEGRATION_ARCHITECTURE.md` Recipes A–D plus
|
| 546 |
+
`docs/adrs/ADR-006-rl-frameworks.md`.
|
| 547 |
+
|
| 548 |
+
### Recipe A — TRL `GRPOTrainer` subclass
|
| 549 |
+
|
| 550 |
+
`ComposerReplicationTrainer` is a `trl.GRPOTrainer` subclass that
|
| 551 |
+
overrides `_compute_loss(model, inputs)` to compose the same 3 channels
|
| 552 |
+
on top of TRL's real reward + advantage machinery. Install:
|
| 553 |
+
`pip install -e ".[train]"`. Then:
|
| 554 |
+
|
| 555 |
+
```python
|
| 556 |
+
from composer_replication import ComposerReplicationTrainer
|
| 557 |
+
trainer = ComposerReplicationTrainer(model=..., reward_funcs=[...], ...)
|
| 558 |
+
trainer.train()
|
| 559 |
+
```
|
| 560 |
+
|
| 561 |
+
**When to use it:** This is the v0.0/v0.1 recommended path. You want
|
| 562 |
+
real GRPO with rewards, you have HF model + dataset + (mostly) standard
|
| 563 |
+
GRPO infrastructure, and you don't need >100B-param scale. TRL is
|
| 564 |
+
mature, the trainer is a small subclass, and the same `compose_loss`
|
| 565 |
+
math runs in both the verification harness and in production with no
|
| 566 |
+
re-coding.
|
| 567 |
+
|
| 568 |
+
→ See `docs/INTEGRATION_ARCHITECTURE.md` § "Recipe A: TRL `GRPOTrainer`
|
| 569 |
+
subclass" (line 205).
|
| 570 |
+
|
| 571 |
+
### Recipe B — VeRL custom `adv_estimator` + DataProto extension
|
| 572 |
+
|
| 573 |
+
VeRL replaces TRL's reward+advantage machinery with a Ray-driven actor
|
| 574 |
+
graph that's specifically optimized for distributed RL training of
|
| 575 |
+
large LMs. Composition with the framework: extend `DataProto` with the
|
| 576 |
+
hint-conditioned columns + DPO pair fields, register a custom
|
| 577 |
+
`adv_estimator` that calls the same `compose_loss` body.
|
| 578 |
+
|
| 579 |
+
**When to use it:** You're past 7B-param, you have multi-host setup
|
| 580 |
+
(Ray cluster), and TRL's single-process trainer is the bottleneck. VeRL
|
| 581 |
+
is the recommended scale path for v0.2+. Trade-off: the extension surface
|
| 582 |
+
is larger and the docs are sparser than TRL's.
|
| 583 |
+
|
| 584 |
+
→ See `docs/INTEGRATION_ARCHITECTURE.md` § "Recipe B: VeRL custom
|
| 585 |
+
`adv_estimator`" (line 289).
|
| 586 |
+
|
| 587 |
+
### Recipe C — PRIME-RL with DPPO-clip details
|
| 588 |
+
|
| 589 |
+
`composer_replication/recipes/prime_rl/composer_loss.py` ships a
|
| 590 |
+
`loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2,
|
| 591 |
+
dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3)` adapter that maps
|
| 592 |
+
PRIME-RL's `LossInputs` struct (1-D per-sample tensors:
|
| 593 |
+
`trainer_logprobs`, `inference_logprobs`, `teacher_logprobs`,
|
| 594 |
+
`advantages`, `loss_mask`) into our 3-channel composition.
|
| 595 |
+
|
| 596 |
+
The DPPO+KL bit is what makes PRIME-RL distinctive — and we mirror
|
| 597 |
+
PRIME-RL's upstream `default_loss_fn` exactly (verified against
|
| 598 |
+
`prime_rl/trainer/rl/loss.py` lines 116-165):
|
| 599 |
+
|
| 600 |
+
```python
|
| 601 |
+
log_ir = trainer_logprobs - inference_logprobs
|
| 602 |
+
ir = exp(log_ir) # importance ratio
|
| 603 |
+
probs_diff = exp(trainer_logprobs) - exp(inference_logprobs)
|
| 604 |
+
invalid_high = probs_diff > dppo_mask_high # for positive-advantage tokens
|
| 605 |
+
invalid_low = probs_diff < -dppo_mask_low # for negative-advantage tokens
|
| 606 |
+
invalid = where(advantages > 0, invalid_high, invalid_low)
|
| 607 |
+
keep = loss_mask & ~invalid
|
| 608 |
+
pg_loss = keep * (adv_tau * advantages) * ir
|
| 609 |
+
kl_loss = loss_mask * log_ir**2
|
| 610 |
+
loss = (-pg_loss + kl_tau * kl_loss).sum()
|
| 611 |
+
```
|
| 612 |
+
|
| 613 |
+
Three things to remember: (1) the mask gate is on **probability-space**
|
| 614 |
+
`exp(trainer_lp) - exp(inference_lp)`, not on the log-ratio; (2) the
|
| 615 |
+
policy-gradient term is multiplied by the importance ratio
|
| 616 |
+
`exp(trainer_lp - inference_lp)`, not by `trainer_lp` directly (proper
|
| 617 |
+
IS-corrected gradient, not REINFORCE); (3) the mask is **conditioned on
|
| 618 |
+
the sign of the advantage** — positive-advantage tokens are dropped on
|
| 619 |
+
the upper bound, negative-advantage tokens on the lower. Defaults
|
| 620 |
+
`dppo_mask_high=dppo_mask_low=0.2` and `adv_tau=1.0, kl_tau=1e-3`
|
| 621 |
+
match PRIME-RL's `DefaultLossConfig` (all fields `Field(..., ge=0)`).
|
| 622 |
+
SDPO (channel 2) is gated `NotImplementedError` in v0 because PRIME-RL
|
| 623 |
+
exposes log-probs, not full vocab logits; channel 3 (trace-replay DPO)
|
| 624 |
+
emits a warning if `beta_dpo != 0`.
|
| 625 |
+
|
| 626 |
+
**When to use it:** You're already in the PRIME-Intellect / decentralized
|
| 627 |
+
training universe, you want INTELLECT-style scaling on a long-horizon
|
| 628 |
+
training run, and DPPO masking is part of your existing reward+advantage
|
| 629 |
+
recipe. Install: `pip install -e ".[prime-rl]"`.
|
| 630 |
+
|
| 631 |
+
→ See `composer_replication/recipes/prime_rl/prime_rl_recipe.md` and
|
| 632 |
+
`docs/INTEGRATION_ARCHITECTURE.md` § "Recipe C: TorchForge + Monarch"
|
| 633 |
+
(line 356).
|
| 634 |
+
|
| 635 |
+
### Recipe C+D — Monarch as actor mesh
|
| 636 |
+
|
| 637 |
+
Monarch (the actor framework underpinning TorchForge) hosts the
|
| 638 |
+
trainer/generator/manager actors in a topology-aware mesh. The framework
|
| 639 |
+
ships *skeleton* actor definitions at
|
| 640 |
+
`composer_replication/recipes/monarch/actors.py` (TrainerActor,
|
| 641 |
+
GeneratorActor) and a layout doc at `monarch_actor_layout.md`. v0
|
| 642 |
+
intentionally *fails fast* if you try to instantiate them
|
| 643 |
+
(`raise NotImplementedError("v0 skeleton; deferred to v0.2 per ADR-006")`)
|
| 644 |
+
because the upstream Monarch API is still moving.
|
| 645 |
+
|
| 646 |
+
**When to use it:** Reference-pattern reading only in v0. Decision point
|
| 647 |
+
is v0.2 once the upstream actor API stabilizes. Treat the skeleton as
|
| 648 |
+
shape-of-the-final-answer documentation, not as a production target.
|
| 649 |
+
Install: `pip install -e ".[prime-rl,monarch]"` for the full surface.
|
| 650 |
+
|
| 651 |
+
→ See `composer_replication/recipes/monarch/monarch_actor_layout.md`
|
| 652 |
+
and `docs/adrs/ADR-006-rl-frameworks.md`.
|
| 653 |
+
|
| 654 |
+
---
|
| 655 |
+
|
| 656 |
+
## Common pitfalls + what tests catch them
|
| 657 |
+
|
| 658 |
+
The framework's 124-test suite is structured so each pitfall has a
|
| 659 |
+
specific test-file home. If you hit one of these in production, the
|
| 660 |
+
corresponding test is your fastest reproducer.
|
| 661 |
+
|
| 662 |
+
| Pitfall | Symptom | Test file (catches it) |
|
| 663 |
+
|-----------------------------------------------------------------------------------------------|------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------|
|
| 664 |
+
| Forgetting `taid_schedule_step` when `sdpo_wrapper="taid"` | `ValueError` at first step | `composer_replication/tests/test_compose_loss_integration.py` (kwarg validation) |
|
| 665 |
+
| TAID α=0 endpoint leaks teacher signal | Teacher swap changes the loss when α should be 0 | `test_taid_loss_alpha_zero_ignores_teacher` in `composer_replication/distillation/tests/test_distillation_losses.py:153` |
|
| 666 |
+
| TAID α=1 endpoint differs from plain SDPO | Bit-difference vs reference SDPO at the schedule end | `test_taid_blended_logits_endpoints` in `composer_replication/distillation/tests/test_distillation_losses.py:115` |
|
| 667 |
+
| SimPO loss not differentiable through the loss-of-sigmoid path | `chosen.grad is None` after backward | `test_simpo_loss_differentiable` in `composer_replication/distillation/tests/test_distillation_losses.py:50` |
|
| 668 |
+
| SimPO shape-mismatch slips through silently | Broadcasting bug, NaN downstream | `test_simpo_loss_shape_mismatch_raises` in `composer_replication/distillation/tests/test_distillation_losses.py:61` |
|
| 669 |
+
| Entropy-OPD failing to zero out when distributions match | Loss > 0 when student≡teacher | `test_entropy_aware_opd_zero_when_distributions_match` in `composer_replication/distillation/tests/test_distillation_losses.py:217` |
|
| 670 |
+
| Entropy of one-hot ≠ 0 / entropy of uniform ≠ log(V) | Wrong gating weights w(t) | `test_teacher_entropy_one_hot_is_zero` and `test_teacher_entropy_uniform_is_log_v` in `composer_replication/distillation/tests/test_distillation_losses.py:175,183` |
|
| 671 |
+
| `DJNormalizer` records missing the chat-messages shape | Filters silently no-op or crash | `test_dpo_pair_to_dj_record_shape` in `composer_replication/replaysim/tests/test_replaysim.py:44` |
|
| 672 |
+
| `DJNormalizer` round-trip drops `state_messages` / metadata | Lost provenance | `test_dj_record_to_normalized_roundtrip` and `test_dj_record_to_normalized_preserves_state_messages` in `composer_replication/replaysim/tests/test_replaysim.py` |
|
| 673 |
+
| `ObjectStoreAllReduce` accepts an out-of-bounds rank | Silent corruption of the all-reduce average | `test_object_store_allreduce_init_validates_rank` in `composer_replication/diloco/serverless/tests/test_serverless_local.py:31` |
|
| 674 |
+
| `ObjectStoreAllReduce(world_size=1)` doesn't passthrough cleanly | False all-reduce on single replica | `test_object_store_allreduce_world_size_1_passthrough` in `composer_replication/diloco/serverless/tests/test_serverless_local.py:46` |
|
| 675 |
+
| `LocalProcessExecutor` doesn't propagate child failures to `collect()` | Silent test pass when a replica crashed | `test_serverless_diloco_integration.py` in `composer_replication/diloco/serverless/tests/` |
|
| 676 |
+
| PRIME-RL adapter accidentally uses `(B, T)` shape instead of per-sample `(seq,)` | Shape mismatch / wrong reduction | `composer_replication/recipes/prime_rl/tests/test_composer_loss.py` (10 tests covering shape and DPPO mask edges) |
|
| 677 |
+
| Channel 2/3 fails to auto-disable when its inputs are absent | Crash on missing key, not graceful skip | `composer_replication/tests/test_compose_loss_integration.py` (`(a) defaults reproduce existing compose_loss output bit-exact`) |
|
| 678 |
+
|
| 679 |
+
Run the full suite with `pytest` from the repo root.
|
| 680 |
+
|
| 681 |
+
---
|
| 682 |
+
|
| 683 |
+
**File path:** `/mnt/e/CS/HF/composer-replication-framework/docs/USER_GUIDE.md`
|
|
@@ -101,28 +101,52 @@ pluggable distillation module:**
|
|
| 101 |
differentiable, returns scalar, matches paper formulas at boundary
|
| 102 |
conditions)
|
| 103 |
|
| 104 |
-
### Wave 14
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
`
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
```python
|
| 128 |
from composer_replication.distillation import simpo_loss, taid_loss
|
|
|
|
| 101 |
differentiable, returns scalar, matches paper formulas at boundary
|
| 102 |
conditions)
|
| 103 |
|
| 104 |
+
### Closed in Wave 14 — `compose_loss` integration landed
|
| 105 |
+
|
| 106 |
+
**Status (Wave 14, T1):** the four kwargs (`dpo_variant`,
|
| 107 |
+
`sdpo_wrapper`, `taid_schedule_step`, `taid_total_steps`) have been
|
| 108 |
+
added to `composer_replication.compose_loss` and are exercised by
|
| 109 |
+
`composer_replication/tests/test_compose_loss_integration.py`. The
|
| 110 |
+
gap flagged by the Wave 13 cross-model review
|
| 111 |
+
(`docs/research/WAVE_13_FINAL_REVIEW.md` Finding 2) is closed:
|
| 112 |
+
|
| 113 |
+
- `compose_loss(model, batch, dpo_variant="simpo", sdpo_wrapper="taid",
|
| 114 |
+
taid_schedule_step=step, taid_total_steps=max_steps, ...)` is now a
|
| 115 |
+
valid call signature.
|
| 116 |
+
- All three standalone losses (`simpo_loss`, `taid_loss`,
|
| 117 |
+
`entropy_aware_opd_loss`) remain importable and unit-tested as
|
| 118 |
+
before — the Wave 14 work was purely the kwarg surface + composition
|
| 119 |
+
glue, not a loss-formula change.
|
| 120 |
+
|
| 121 |
+
The historical sections below are preserved verbatim for context but
|
| 122 |
+
**describe the pre-Wave-14 state** and are superseded by the closed
|
| 123 |
+
status above.
|
| 124 |
+
|
| 125 |
+
---
|
| 126 |
+
|
| 127 |
+
#### Superseded — pre-Wave-14 wording (kept for history)
|
| 128 |
+
|
| 129 |
+
> An earlier draft of this ADR claimed `composer_replication.compose_loss`
|
| 130 |
+
> would receive new kwargs (`dpo_variant`, `sdpo_wrapper`, `taid_schedule_step`,
|
| 131 |
+
> `taid_total_steps`). **The Wave 13 cross-model review
|
| 132 |
+
> (docs/research/WAVE_13_FINAL_REVIEW.md Finding 2) flagged that those
|
| 133 |
+
> kwargs were never actually added to `compose_loss`** — the standalone
|
| 134 |
+
> losses landed but the integration into the framework's loss composition
|
| 135 |
+
> is not done. To stay honest:
|
| 136 |
+
>
|
| 137 |
+
> - **What works in Wave 13**: `from composer_replication.distillation
|
| 138 |
+
> import simpo_loss, taid_loss, entropy_aware_opd_loss` — all three are
|
| 139 |
+
> importable, type-checked, unit-tested, and ready to be called directly.
|
| 140 |
+
> - **What does NOT work in Wave 13**: passing
|
| 141 |
+
> `compose_loss(model, batch, dpo_variant="simpo", sdpo_wrapper="taid", ...)`.
|
| 142 |
+
> That call signature does not exist; it would raise `TypeError`.
|
| 143 |
+
> - **Wave 14 plan**: add the four kwargs to `compose_loss` with a small
|
| 144 |
+
> integration test exercising at least one combination (SDPO+TAID + plain
|
| 145 |
+
> DPO would suffice). Estimated ~30 LOC + 2-3 tests.
|
| 146 |
+
|
| 147 |
+
Users wanting the new losses as standalone callables can still use them
|
| 148 |
+
directly in their own loss-composition code (this path is unchanged by
|
| 149 |
+
the Wave 14 integration):
|
| 150 |
|
| 151 |
```python
|
| 152 |
from composer_replication.distillation import simpo_loss, taid_loss
|
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Wave 14 Adversarial Cross-Model Review
|
| 2 |
+
|
| 3 |
+
**Reviewer:** Claude Opus 4.7 (sub-agent via delegate_task)
|
| 4 |
+
**Date:** 2026-05-26
|
| 5 |
+
**Method:** Read every Wave 13 finding, every Wave 14 closure, all 4 doc files, **cloned PRIME-RL upstream to verify T4 claims**, ran 61 wave-relevant tests.
|
| 6 |
+
|
| 7 |
+
## Top-line verdict
|
| 8 |
+
|
| 9 |
+
**CONDITIONAL PASS with 1 BLOCKER + 4 SUGGESTIONs + 2 NITs.** Wave 14
|
| 10 |
+
closes Wave 13 BLOCKER 2 (T1 — compose_loss kwargs) and Suggestion 3
|
| 11 |
+
(T2 — replaysim) cleanly. T3 (MockManager surface audit) is solid but
|
| 12 |
+
only tests `world_size=1`. **T4 (PRIME-RL "real GRPO + DPPO") does not
|
| 13 |
+
match PRIME-RL's actual `default_loss_fn`** despite claiming to mirror
|
| 14 |
+
it; that error has been pasted into USER_GUIDE.md, INTEGRATION_RECIPES.md,
|
| 15 |
+
and API_REFERENCE.md.
|
| 16 |
+
|
| 17 |
+
Same signal-to-noise as Wave 11 + Wave 13 reviewers: 1 genuine BLOCKER.
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## Finding 1 — BLOCKER: T4 PRIME-RL "DPPO importance-sampling-ratio clip" is neither importance sampling nor matches PRIME-RL.
|
| 22 |
+
|
| 23 |
+
**Severity:** BLOCKER
|
| 24 |
+
**Evidence:** `composer_replication/recipes/prime_rl/composer_loss.py:118-131`.
|
| 25 |
+
The implementation computes
|
| 26 |
+
```python
|
| 27 |
+
grpo_loss = -(advantages * trainer_lp * keep_mask).sum() / keep_mask.sum()
|
| 28 |
+
```
|
| 29 |
+
That's **pure REINFORCE-with-advantage** — the masking gate is the only
|
| 30 |
+
nod toward DPPO; there is no importance-sampling ratio multiplication
|
| 31 |
+
anywhere.
|
| 32 |
+
|
| 33 |
+
**Real PRIME-RL** (`/tmp/prime-rl-clone/src/prime_rl/trainer/rl/loss.py:128-153`,
|
| 34 |
+
the `default_loss_fn` on `main` as of 2026-05-26):
|
| 35 |
+
```python
|
| 36 |
+
log_importance_ratio = trainer_logprobs - inference_logprobs
|
| 37 |
+
importance_ratio = torch.exp(log_importance_ratio)
|
| 38 |
+
probs_diff = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs)
|
| 39 |
+
positive_advantages = advantages > 0
|
| 40 |
+
dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high
|
| 41 |
+
dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low
|
| 42 |
+
dppo_invalid_mask = torch.where(positive_advantages, dppo_invalid_mask_high, dppo_invalid_mask_low)
|
| 43 |
+
keep_mask = loss_mask & ~dppo_invalid_mask
|
| 44 |
+
pg_loss = -(keep_mask * advantages * importance_ratio).sum() # NO division
|
| 45 |
+
kl_loss = adv_tau * (log_importance_ratio**2 * keep_mask).sum() # KL term
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
**Three concrete divergences from Wave 14's implementation:**
|
| 49 |
+
|
| 50 |
+
1. **Mask gate is on `probs_diff`** (a probability-space quantity), NOT
|
| 51 |
+
`log_ratio` (a log-space quantity). These have different magnitudes:
|
| 52 |
+
`probs_diff=0.2` corresponds to `log_ratio≈log(1.2)≈0.18` for a
|
| 53 |
+
trainer prob of 1.0 vs inference prob of 0.8. With our `log_ratio>4.0`
|
| 54 |
+
gate, the mask never fires for normal training distributions; PRIME-RL's
|
| 55 |
+
`probs_diff>0.2` gate fires routinely.
|
| 56 |
+
|
| 57 |
+
2. **PRIME-RL multiplies by `importance_ratio = exp(log_ratio)`**;
|
| 58 |
+
Wave 14 multiplies by `trainer_lp` directly. This is the difference
|
| 59 |
+
between actual policy-gradient correction (PRIME-RL) and naive
|
| 60 |
+
REINFORCE.
|
| 61 |
+
|
| 62 |
+
3. **PRIME-RL's mask is sign-conditioned on advantage** (positive
|
| 63 |
+
advantages clipped against `dppo_mask_high`, negative against
|
| 64 |
+
`-dppo_mask_low`); Wave 14 ORs them together unconditionally.
|
| 65 |
+
|
| 66 |
+
**Plus:** the KL term is missing entirely.
|
| 67 |
+
|
| 68 |
+
**Plus:** the defaults claimed as "PRIME-RL's defaults" — `dppo_mask_high=4.0,
|
| 69 |
+
dppo_mask_low=-4.0` — are wrong. PRIME-RL's `DefaultLossConfig`
|
| 70 |
+
(`configs/trainer.py:412-424`) sets `dppo_mask_high=0.2, dppo_mask_low=0.2`
|
| 71 |
+
with `Field(..., ge=0)` validation that would *reject* a negative value.
|
| 72 |
+
PRIME-RL's code negates at use site: `probs_diff < -loss_config.dppo_mask_low`.
|
| 73 |
+
|
| 74 |
+
**Plus:** the docstring (`composer_loss.py:32-49`), USER_GUIDE.md:599-608,
|
| 75 |
+
INTEGRATION_RECIPES.md:426-429 + 482-487, and API_REFERENCE.md:1364 all
|
| 76 |
+
repeat the wrong formula and the wrong "matches PRIME-RL" claim.
|
| 77 |
+
|
| 78 |
+
**Fix direction:** Either (a) actually mirror `default_loss_fn` (mask on
|
| 79 |
+
`probs_diff`, multiply by `importance_ratio`, add KL term, advantage-
|
| 80 |
+
conditioned mask, `.sum()` reduction with token-count returned for
|
| 81 |
+
caller-side scaling), or (b) drop the "matches PRIME-RL" framing and
|
| 82 |
+
rename to "REINFORCE-with-advantage stub + log-ratio mask" everywhere.
|
| 83 |
+
|
| 84 |
+
Wave 13 Finding 6 is **not actually closed** by Wave 14.
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## Finding 2 — SUGGESTION: ADR-007 still says Wave 14 hasn't done the integration.
|
| 89 |
+
|
| 90 |
+
**Severity:** SUGGESTION
|
| 91 |
+
**Evidence:** `docs/adrs/ADR-007-self-distillation-losses.md:104-122` reads:
|
| 92 |
+
> **Wave 14+ work — `compose_loss` integration is NOT in this wave**
|
| 93 |
+
> ... Wave 14 plan: add the four kwargs ...
|
| 94 |
+
|
| 95 |
+
But Wave 14 *did* add them (verified — `loss.py:80-93`). The ADR was
|
| 96 |
+
written defensively after Wave 13 review and never updated when T1 landed.
|
| 97 |
+
|
| 98 |
+
**Net effect:** a user reading ADR-007 is told the SimPO/TAID kwargs
|
| 99 |
+
don't work; a user reading USER_GUIDE/API_REFERENCE is told they do.
|
| 100 |
+
|
| 101 |
+
**Fix direction:** flip ADR-007 status section to "Closed in Wave 14 —
|
| 102 |
+
see test_compose_loss_integration.py".
|
| 103 |
+
|
| 104 |
+
---
|
| 105 |
+
|
| 106 |
+
## Finding 3 — SUGGESTION: ModalExecutor instantiation example in INTEGRATION_RECIPES is dead code.
|
| 107 |
+
|
| 108 |
+
**Severity:** SUGGESTION
|
| 109 |
+
**Evidence:** `docs/INTEGRATION_RECIPES.md:519-533` shows
|
| 110 |
+
```python
|
| 111 |
+
executor = ModalExecutor(app="composer-prime-rl")
|
| 112 |
+
executor.launch_replicas(...)
|
| 113 |
+
```
|
| 114 |
+
But `composer_replication/diloco/serverless/modal.py:64-66` raises
|
| 115 |
+
`NotImplementedError` from `__init__`. Same pattern in `HFJobsExecutor`.
|
| 116 |
+
The recipe doc warns about skeleton-status much further down (line 731),
|
| 117 |
+
but the inline code example at line 519 will break the moment a reader
|
| 118 |
+
copy-pastes it.
|
| 119 |
+
|
| 120 |
+
Wave 13 Finding 7 noted this softness; Wave 14 made it worse by writing
|
| 121 |
+
example code that calls a constructor that always raises.
|
| 122 |
+
|
| 123 |
+
**Fix direction:** in every code block that calls `ModalExecutor(...)`,
|
| 124 |
+
prepend a comment `# Wave 14: skeleton — raises NotImplementedError`
|
| 125 |
+
or flip examples to `LocalProcessExecutor`.
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
|
| 129 |
+
## Finding 4 — SUGGESTION: MockManager + DiLoCo integration test only exercises `world_size=1`.
|
| 130 |
+
|
| 131 |
+
**Severity:** SUGGESTION
|
| 132 |
+
**Evidence:** `composer_replication/diloco/serverless/tests/test_serverless_diloco_integration.py:44-51`,
|
| 133 |
+
`:108-109`, `:161`. Both `test_mockmanager_diloco_outer_round_completes`
|
| 134 |
+
and `test_mockmanager_diloco_two_outer_rounds_step_counter` use
|
| 135 |
+
`world_size=1`.
|
| 136 |
+
|
| 137 |
+
With one replica, `ObjectStoreAllReduce.allreduce` returns the tensor
|
| 138 |
+
unchanged (its own mean), so an averaging bug in the multi-replica path
|
| 139 |
+
could not be caught by this test. The pseudo-gradient sign convention
|
| 140 |
+
is pinned by the unrelated spike-008 test, but **no test combines
|
| 141 |
+
MockManager + DiLoCo + multi-process** — i.e. the actual deployment
|
| 142 |
+
scenario is unverified end-to-end.
|
| 143 |
+
|
| 144 |
+
Wave 13 Finding 4 is closed in spirit (call surface is now exhaustive)
|
| 145 |
+
but not in the deepest sense.
|
| 146 |
+
|
| 147 |
+
**Fix direction:** add one multi-process test that spawns `n_replicas`
|
| 148 |
+
subprocesses, each constructing `MockManager(store) → make_diloco_outer_loop`,
|
| 149 |
+
and asserts that after one outer round all replicas converge to the same
|
| 150 |
+
parameter values (i.e. averaging actually happened).
|
| 151 |
+
|
| 152 |
+
---
|
| 153 |
+
|
| 154 |
+
## Finding 5 — SUGGESTION: T4 unit tests pin the wrong implementation as ground truth.
|
| 155 |
+
|
| 156 |
+
**Severity:** SUGGESTION
|
| 157 |
+
**Evidence:** `composer_replication/recipes/prime_rl/tests/test_composer_loss.py:90-128`
|
| 158 |
+
(`test_dppo_mask_clips_extreme_ratios`). The expected value `1.5/3` is
|
| 159 |
+
computed against the buggy formula (Finding 1).
|
| 160 |
+
|
| 161 |
+
The 10 PRIME-RL tests all pass — but they're testing self-consistency,
|
| 162 |
+
not parity with PRIME-RL. A reader looking at "10 unit tests, all green"
|
| 163 |
+
infers correctness; correctness is not what they verify. This is the
|
| 164 |
+
kind of test honesty failure that Wave 11 + Wave 13 reviewers found in
|
| 165 |
+
different forms.
|
| 166 |
+
|
| 167 |
+
**Fix direction:** add at least one test whose expected value is
|
| 168 |
+
hand-computed from `default_loss_fn` in PRIME-RL (or import + invoke
|
| 169 |
+
`default_loss_fn` if the dependency is available, mark the test
|
| 170 |
+
`@pytest.mark.skipif(not _HAS_PRIME_RL)`).
|
| 171 |
+
|
| 172 |
+
---
|
| 173 |
+
|
| 174 |
+
## Finding 6 — NIT: README/test-count drift.
|
| 175 |
+
|
| 176 |
+
Wave 14 task description claims "124 tests passing as of Wave 14"; actual
|
| 177 |
+
`pytest --collect-only` reports **134 collected**. Of those, the 61-test
|
| 178 |
+
wave-relevant subset all pass. Not a defect, but the headline number is
|
| 179 |
+
now off in the same way Wave 13's "9 multi-process tests" was off.
|
| 180 |
+
|
| 181 |
+
---
|
| 182 |
+
|
| 183 |
+
## Finding 7 — NIT: `loss_fn` docstring claims "DPPO importance-sampling-ratio clipping — implemented" (`composer_loss.py:9`).
|
| 184 |
+
|
| 185 |
+
Implementation contains no importance-ratio multiplication anywhere.
|
| 186 |
+
Even if Finding 1 is rejected and the team decides "PRIME-RL match isn't
|
| 187 |
+
a goal", the docstring is internally false: it announces ISR clipping
|
| 188 |
+
in a function that does not multiply by `exp(log_ratio)`.
|
| 189 |
+
|
| 190 |
+
---
|
| 191 |
+
|
| 192 |
+
## Cross-cutting
|
| 193 |
+
|
| 194 |
+
The four doc subagents wrote internally consistent text but inherited
|
| 195 |
+
T4's mathematical error. **Three of the four doc files repeat the same
|
| 196 |
+
wrong formula verbatim.** This is exactly the failure mode Wave 11/13
|
| 197 |
+
reviewers flagged: parallel subagents cross-citing each other rather
|
| 198 |
+
than the upstream source of truth.
|
| 199 |
+
|
| 200 |
+
The 61 tests in the Wave-14-touched dirs pass cleanly. T1, T2, and T3
|
| 201 |
+
are real closures with real coverage. The framework is in a **better**
|
| 202 |
+
state than end-of-Wave-13 — but it has not actually closed Wave 13
|
| 203 |
+
Finding 6, and it has propagated a subtler version of the same
|
| 204 |
+
mathematical-mismatch bug into the user-facing documentation.
|
| 205 |
+
|
| 206 |
+
---
|
| 207 |
+
|
| 208 |
+
## Summary scorecard
|
| 209 |
+
|
| 210 |
+
| Wave 13 Finding | Wave 14 status | Verdict |
|
| 211 |
+
|---|---|---|
|
| 212 |
+
| BLOCKER 1 (PRIME-RL SDPO degenerate) | Fixed parent-side; channel 2 raises NotImplementedError | ✅ closed |
|
| 213 |
+
| BLOCKER 2 (compose_loss kwargs not added) | T1 added them + 11 integration tests | ✅ closed |
|
| 214 |
+
| Suggestion 3 (replaysim YAML field types) | T2 dual-shape reshape + real DJ e2e + caught related bug | ✅ closed |
|
| 215 |
+
| Suggestion 4 (MockManager → DiLoCo gap) | T3 surface audit + integration test | 🟡 closed for `world_size=1`; multi-process unverified |
|
| 216 |
+
| Suggestion 5 ("9 multi-process tests" inflated count) | Not addressed | 🟡 carried over |
|
| 217 |
+
| Suggestion 6 (PRIME-RL channel 1 REINFORCE not GRPO) | T4 thought it closed this | ❌ **NOT closed — mathematically wrong** |
|
| 218 |
+
| Suggestion 7 (Modal/HFJobs skeleton clarity) | Made worse by INTEGRATION_RECIPES dead code | 🟡 regression |
|
| 219 |
+
| NIT 8 (SimPO test positive log-probs) | Not addressed | 🟡 carried over |
|
| 220 |
+
|
| 221 |
+
## Wave 14b follow-up (2026-05-26)
|
| 222 |
+
|
| 223 |
+
After Wave 14b closed Finding 1 by re-reading PRIME-RL upstream and
|
| 224 |
+
matching `default_loss_fn` byte-for-byte, the Wave 14b subagent flagged
|
| 225 |
+
a **new** structural issue not in the Wave 14 review:
|
| 226 |
+
|
| 227 |
+
**PRIME-RL's `setup_loss_fns` (upstream `loss.py:320-327`) expects the
|
| 228 |
+
custom loss function to return `LossOutputs(loss, metrics={...})`, not
|
| 229 |
+
a bare scalar tensor.** Our recipe still returns a bare scalar. This
|
| 230 |
+
predates Wave 14 (it's been wrong since the recipe was first written in
|
| 231 |
+
Wave 13) but was never caught because no test runs against actual
|
| 232 |
+
PRIME-RL.
|
| 233 |
+
|
| 234 |
+
**Status:** documented; deferred to Wave 15. Not blocking for Wave 14b's
|
| 235 |
+
closure of Finding 1, because the formula now matches upstream — the
|
| 236 |
+
return-shape mismatch is a separate adapter-level issue. Tests still
|
| 237 |
+
pass because they invoke our `loss_fn` directly without going through
|
| 238 |
+
PRIME-RL's `compute_loss` pipeline.
|
| 239 |
+
|
| 240 |
+
**Fix direction (Wave 15):** wrap the return value in a duck-typed
|
| 241 |
+
`LossOutputs` (provided by PRIME-RL when installed; substituted with a
|
| 242 |
+
NamedTuple shim when not). Add an integration smoke test against PRIME-RL
|
| 243 |
+
to catch this and similar adapter-shape regressions.
|
| 244 |
+
|
| 245 |
+
## Final Wave 14 + 14b status
|
| 246 |
+
|
| 247 |
+
| Wave 13 / 14 finding | Wave 14b status |
|
| 248 |
+
|---|---|
|
| 249 |
+
| W13 BLOCKER 1: PRIME-RL SDPO degenerate | ✅ closed (parent, channel 2 deferred) |
|
| 250 |
+
| W13 BLOCKER 2: compose_loss kwargs not added | ✅ closed (Wave 14 T1) |
|
| 251 |
+
| W13 Suggestion 3: replaysim YAML field types | ✅ closed (Wave 14 T2) |
|
| 252 |
+
| W13 Suggestion 4: MockManager → DiLoCo gap | ✅ closed (Wave 14 T3 + Wave 14b multi-process test) |
|
| 253 |
+
| W13 Suggestion 6: PRIME-RL channel 1 REINFORCE not GRPO | ✅ **closed in Wave 14b** (matches upstream `default_loss_fn`) |
|
| 254 |
+
| W14 Finding 1: PRIME-RL impl wrong | ✅ closed in Wave 14b |
|
| 255 |
+
| W14 Finding 2: ADR-007 stale | ✅ closed in Wave 14b |
|
| 256 |
+
| W14 Finding 3: ModalExecutor dead code | ✅ closed in Wave 14b |
|
| 257 |
+
| W14 Finding 4: world_size=1 only | ✅ closed in Wave 14b (multi-process convergence test) |
|
| 258 |
+
| W14 Finding 5: tests pin wrong impl as ground truth | ✅ closed in Wave 14b (parity test added) |
|
| 259 |
+
| W14 NIT 6: test count drift | 🟡 carried |
|
| 260 |
+
| W14 NIT 7: docstring claims ISR clipping | ✅ closed in Wave 14b (real ISR now implemented) |
|
| 261 |
+
| **NEW (Wave 14b)**: PRIME-RL `LossOutputs` return shape | 🟡 deferred to Wave 15 |
|
| 262 |
+
|
| 263 |
+
**Test count post-Wave-14b: 130 passing + 1 skip-marked (PRIME-RL
|
| 264 |
+
parity test, runs when prime-rl is installed).**
|