# Wave 14 Adversarial Cross-Model Review
**Reviewer:** Claude Opus 4.7 (sub-agent via delegate_task)
**Date:** 2026-05-26
**Method:** Read every Wave 13 finding, every Wave 14 closure, all 4 doc files, **cloned PRIME-RL upstream to verify T4 claims**, ran 61 wave-relevant tests.
## Top-line verdict
**CONDITIONAL PASS with 1 BLOCKER + 4 SUGGESTIONs + 2 NITs.** Wave 14
closes Wave 13 BLOCKER 2 (T1 — compose_loss kwargs) and Suggestion 3
(T2 — replaysim) cleanly. T3 (MockManager surface audit) is solid but
only tests `world_size=1`. **T4 (PRIME-RL "real GRPO + DPPO") does not
match PRIME-RL's actual `default_loss_fn`** despite claiming to mirror
it; that error has been pasted into USER_GUIDE.md, INTEGRATION_RECIPES.md,
and API_REFERENCE.md.
Same signal-to-noise as Wave 11 + Wave 13 reviewers: 1 genuine BLOCKER.
---
## Finding 1 — BLOCKER: T4 PRIME-RL "DPPO importance-sampling-ratio clip" is neither importance sampling nor matches PRIME-RL.
**Severity:** BLOCKER
**Evidence:** `composer_replication/recipes/prime_rl/composer_loss.py:118-131`.
The implementation computes
```python
grpo_loss = -(advantages * trainer_lp * keep_mask).sum() / keep_mask.sum()
```
That's **pure REINFORCE-with-advantage** — the masking gate is the only
nod toward DPPO; there is no importance-sampling ratio multiplication
anywhere.
**Real PRIME-RL** (`/tmp/prime-rl-clone/src/prime_rl/trainer/rl/loss.py:128-153`,
the `default_loss_fn` on `main` as of 2026-05-26):
```python
log_importance_ratio = trainer_logprobs - inference_logprobs
importance_ratio = torch.exp(log_importance_ratio)
probs_diff = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs)
positive_advantages = advantages > 0
dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high
dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low
dppo_invalid_mask = torch.where(positive_advantages, dppo_invalid_mask_high, dppo_invalid_mask_low)
keep_mask = loss_mask & ~dppo_invalid_mask
pg_loss = -(keep_mask * advantages * importance_ratio).sum() # NO division
kl_loss = adv_tau * (log_importance_ratio**2 * keep_mask).sum() # KL term
```
**Three concrete divergences from Wave 14's implementation:**
1. **Mask gate is on `probs_diff`** (a probability-space quantity), NOT
`log_ratio` (a log-space quantity). These have different magnitudes:
`probs_diff=0.2` corresponds to `log_ratio=log(1.25)≈0.22` for a
trainer prob of 1.0 vs inference prob of 0.8. Our `log_ratio>4.0`
gate needs a probability ratio above `e^4≈54.6`, so the mask almost
never fires for normal training distributions; PRIME-RL's
`probs_diff>0.2` gate fires routinely.
2. **PRIME-RL multiplies by `importance_ratio = exp(log_ratio)`**;
Wave 14 multiplies by `trainer_lp` directly. This is the difference
between actual policy-gradient correction (PRIME-RL) and naive
REINFORCE.
3. **PRIME-RL's mask is sign-conditioned on advantage** (positive
advantages clipped against `dppo_mask_high`, negative against
`-dppo_mask_low`); Wave 14 ORs them together unconditionally.
**Plus:** the KL term is missing entirely.
**Plus:** the defaults claimed as "PRIME-RL's defaults" — `dppo_mask_high=4.0,
dppo_mask_low=-4.0` — are wrong. PRIME-RL's `DefaultLossConfig`
(`configs/trainer.py:412-424`) sets `dppo_mask_high=0.2, dppo_mask_low=0.2`
with `Field(..., ge=0)` validation that would *reject* a negative value.
PRIME-RL's code negates at use site: `probs_diff < -loss_config.dppo_mask_low`.
**Plus:** the docstring (`composer_loss.py:32-49`), USER_GUIDE.md:599-608,
INTEGRATION_RECIPES.md:426-429 + 482-487, and API_REFERENCE.md:1364 all
repeat the wrong formula and the wrong "matches PRIME-RL" claim.
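To make the scale gap in divergence 1 concrete, the sketch below evaluates both gate quantities for the same pair of token probabilities. The helper name `gate_quantities` and the example probabilities are illustrative only; the thresholds 0.2 and 4.0 are the ones discussed above.

```python
import math

def gate_quantities(trainer_prob: float, inference_prob: float) -> tuple[float, float]:
    """Return (probs_diff, log_ratio) for a single token."""
    probs_diff = trainer_prob - inference_prob
    log_ratio = math.log(trainer_prob) - math.log(inference_prob)
    return probs_diff, log_ratio

probs_diff, log_ratio = gate_quantities(1.0, 0.8)
# probs_diff is 0.2 (up to float rounding): right at the probability-space threshold.
# log_ratio is log(1.25), roughly 0.22: nowhere near a log-space threshold of 4.0.

# A log-space gate of 4.0 fires only once trainer_prob / inference_prob exceeds e^4.
min_ratio_for_log_gate = math.exp(4.0)  # roughly 54.6
```

The two gates therefore cannot be swapped by reusing the same numeric threshold; a value that is routine in probability space is extreme in log space.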
**Fix direction:** Either (a) actually mirror `default_loss_fn` (mask on
`probs_diff`, multiply by `importance_ratio`, add KL term, advantage-
conditioned mask, `.sum()` reduction with token-count returned for
caller-side scaling), or (b) drop the "matches PRIME-RL" framing and
rename to "REINFORCE-with-advantage stub + log-ratio mask" everywhere.
Wave 13 Finding 6 is **not actually closed** by Wave 14.
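As a reference for option (a), here is a minimal per-token sketch transcribed from the upstream excerpt quoted above. It uses plain Python floats instead of tensors so the arithmetic can be checked by hand. The function name `dppo_style_loss` and its argument layout are hypothetical, not PRIME-RL's API, and `adv_tau` is taken as an explicit parameter because the excerpt does not show where it is defined.

```python
import math
from typing import Sequence

def dppo_style_loss(
    trainer_logprobs: Sequence[float],
    inference_logprobs: Sequence[float],
    advantages: Sequence[float],
    loss_mask: Sequence[bool],
    adv_tau: float,
    dppo_mask_high: float = 0.2,
    dppo_mask_low: float = 0.2,
) -> tuple[float, float, int]:
    """Return (pg_loss, kl_loss, kept_tokens) as plain sums, with no division."""
    pg_loss, kl_loss, kept = 0.0, 0.0, 0
    for t_lp, i_lp, adv, valid in zip(
        trainer_logprobs, inference_logprobs, advantages, loss_mask
    ):
        log_ratio = t_lp - i_lp
        probs_diff = math.exp(t_lp) - math.exp(i_lp)
        # Sign-conditioned gate: positive advantages are checked against the
        # high bound, all others against the low bound (negated at use site).
        if adv > 0:
            invalid = probs_diff > dppo_mask_high
        else:
            invalid = probs_diff < -dppo_mask_low
        if not valid or invalid:
            continue
        kept += 1
        pg_loss -= adv * math.exp(log_ratio)   # importance-ratio weighting
        kl_loss += adv_tau * log_ratio ** 2    # squared log-ratio penalty
    return pg_loss, kl_loss, kept

# Three tokens: kept, masked (probs_diff = 0.4 > 0.2), kept.
pg_loss, kl_loss, kept = dppo_style_loss(
    trainer_logprobs=[math.log(0.5), math.log(0.9), math.log(0.3)],
    inference_logprobs=[math.log(0.4), math.log(0.5), math.log(0.4)],
    advantages=[1.0, 1.0, -1.0],
    loss_mask=[True, True, True],
    adv_tau=0.0,
)
```

Returning the kept-token count alongside the unnormalised sums leaves scaling to the caller, as the fix direction suggests.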
---
## Finding 2 — SUGGESTION: ADR-007 still says Wave 14 hasn't done the integration.
**Severity:** SUGGESTION
**Evidence:** `docs/adrs/ADR-007-self-distillation-losses.md:104-122` reads:
> **Wave 14+ work — `compose_loss` integration is NOT in this wave**
> ... Wave 14 plan: add the four kwargs ...
But Wave 14 *did* add them (verified — `loss.py:80-93`). The ADR was
written defensively after Wave 13 review and never updated when T1 landed.
**Net effect:** a user reading ADR-007 is told the SimPO/TAID kwargs
don't work; a user reading USER_GUIDE/API_REFERENCE is told they do.
**Fix direction:** flip ADR-007 status section to "Closed in Wave 14 —
see test_compose_loss_integration.py".
---
## Finding 3 — SUGGESTION: ModalExecutor instantiation example in INTEGRATION_RECIPES is dead code.
**Severity:** SUGGESTION
**Evidence:** `docs/INTEGRATION_RECIPES.md:519-533` shows
```python
executor = ModalExecutor(app="composer-prime-rl")
executor.launch_replicas(...)
```
But `composer_replication/diloco/serverless/modal.py:64-66` raises
`NotImplementedError` from `__init__`. Same pattern in `HFJobsExecutor`.
The recipe doc warns about skeleton-status much further down (line 731),
but the inline code example at line 519 will break the moment a reader
copy-pastes it.
Wave 13 Finding 7 noted this softness; Wave 14 made it worse by writing
example code that calls a constructor that always raises.
**Fix direction:** in every code block that calls `ModalExecutor(...)`,
prepend a comment `# Wave 14: skeleton — raises NotImplementedError`
or flip examples to `LocalProcessExecutor`.
---
## Finding 4 — SUGGESTION: MockManager + DiLoCo integration test only exercises `world_size=1`.
**Severity:** SUGGESTION
**Evidence:** `composer_replication/diloco/serverless/tests/test_serverless_diloco_integration.py:44-51`,
`:108-109`, `:161`. Both `test_mockmanager_diloco_outer_round_completes`
and `test_mockmanager_diloco_two_outer_rounds_step_counter` use
`world_size=1`.
With one replica, `ObjectStoreAllReduce.allreduce` returns the tensor
unchanged (its own mean), so an averaging bug in the multi-replica path
could not be caught by this test. The pseudo-gradient sign convention
is pinned by the unrelated spike-008 test, but **no test combines
MockManager + DiLoCo + multi-process** — i.e. the actual deployment
scenario is unverified end-to-end.
Wave 13 Finding 4 is closed in spirit (call surface is now exhaustive)
but not end-to-end: the multi-replica averaging path remains untested.
**Fix direction:** add one multi-process test that spawns `n_replicas`
subprocesses, each constructing `MockManager(store) → make_diloco_outer_loop`,
and asserts that after one outer round all replicas converge to the same
parameter values (i.e. averaging actually happened).
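The toy sketch below illustrates why a `world_size=1` test cannot detect an averaging bug. Both functions are hypothetical stand-ins written for this example; they are not the `ObjectStoreAllReduce` or `MockManager` API, and they run in a single process for brevity.

```python
def mean_allreduce(replica_values: list[list[float]], rank: int) -> list[float]:
    """Element-wise mean across replicas; every rank gets the same result."""
    n = len(replica_values)
    return [sum(column) / n for column in zip(*replica_values)]

def broken_allreduce(replica_values: list[list[float]], rank: int) -> list[float]:
    """Deliberately buggy stand-in: returns the caller's own values unchanged."""
    return list(replica_values[rank])

single = [[1.0, 2.0]]                 # world_size = 1
multi = [[1.0, 2.0], [3.0, 6.0]]      # world_size = 2

# With one replica the bug is invisible: the mean of one tensor is itself.
same_at_ws1 = mean_allreduce(single, 0) == broken_allreduce(single, 0)

# With two replicas the correct result is [2.0, 4.0] on every rank, while the
# buggy version leaves the replicas holding different values.
mean_at_ws2 = mean_allreduce(multi, 0)
broken_diverges = broken_allreduce(multi, 0) != broken_allreduce(multi, 1)
```

A convergence assertion across replicas, as proposed in the fix direction, is exactly the check that separates these two behaviours.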
---
## Finding 5 — SUGGESTION: T4 unit tests pin the wrong implementation as ground truth.
**Severity:** SUGGESTION
**Evidence:** `composer_replication/recipes/prime_rl/tests/test_composer_loss.py:90-128`
(`test_dppo_mask_clips_extreme_ratios`). The expected value `1.5/3` is
computed against the buggy formula (Finding 1).
The 10 PRIME-RL tests all pass — but they're testing self-consistency,
not parity with PRIME-RL. A reader looking at "10 unit tests, all green"
infers correctness; correctness is not what they verify. This is the
kind of test honesty failure that Wave 11 + Wave 13 reviewers found in
different forms.
**Fix direction:** add at least one test whose expected value is
hand-computed from `default_loss_fn` in PRIME-RL (or import + invoke
`default_loss_fn` if the dependency is available, mark the test
`@pytest.mark.skipif(not _HAS_PRIME_RL)`).
---
## Finding 6 — NIT: README/test-count drift.
Wave 14 task description claims "124 tests passing as of Wave 14"; actual
`pytest --collect-only` reports **134 collected**. Of those, the 61-test
wave-relevant subset all pass. Not a defect, but the headline number is
now off in the same way Wave 13's "9 multi-process tests" was off.
---
## Finding 7 — NIT: `loss_fn` docstring claims "DPPO importance-sampling-ratio clipping — implemented" (`composer_loss.py:9`).
Implementation contains no importance-ratio multiplication anywhere.
Even if Finding 1 is rejected and the team decides "PRIME-RL match isn't
a goal", the docstring is internally false: it announces ISR clipping
in a function that does not multiply by `exp(log_ratio)`.
---
## Cross-cutting
The four doc subagents wrote internally consistent text but inherited
T4's mathematical error. **Three of the four doc files repeat the same
wrong formula verbatim.** This is exactly the failure mode Wave 11/13
reviewers flagged: parallel subagents cross-citing each other rather
than the upstream source of truth.
The 61 tests in the Wave-14-touched dirs pass cleanly. T1, T2, and T3
are real closures with real coverage. The framework is in a **better**
state than end-of-Wave-13 — but it has not actually closed Wave 13
Finding 6, and it has propagated a subtler version of the same
mathematical-mismatch bug into the user-facing documentation.
---
## Summary scorecard
| Wave 13 Finding | Wave 14 status | Verdict |
|---|---|---|
| BLOCKER 1 (PRIME-RL SDPO degenerate) | Fixed parent-side; channel 2 raises NotImplementedError | ✅ closed |
| BLOCKER 2 (compose_loss kwargs not added) | T1 added them + 11 integration tests | ✅ closed |
| Suggestion 3 (replaysim YAML field types) | T2 dual-shape reshape + real DJ e2e + caught related bug | ✅ closed |
| Suggestion 4 (MockManager → DiLoCo gap) | T3 surface audit + integration test | 🟡 closed for `world_size=1`; multi-process unverified |
| Suggestion 5 ("9 multi-process tests" inflated count) | Not addressed | 🟡 carried over |
| Suggestion 6 (PRIME-RL channel 1 REINFORCE not GRPO) | T4 thought it closed this | ❌ **NOT closed — mathematically wrong** |
| Suggestion 7 (Modal/HFJobs skeleton clarity) | Made worse by INTEGRATION_RECIPES dead code | 🟡 regression |
| NIT 8 (SimPO test positive log-probs) | Not addressed | 🟡 carried over |
## Wave 14b follow-up (2026-05-26)
After Wave 14b closed Finding 1 by re-reading PRIME-RL upstream and
matching `default_loss_fn` byte-for-byte, the Wave 14b subagent flagged
a **new** structural issue not in the Wave 14 review:
**PRIME-RL's `setup_loss_fns` (upstream `loss.py:320-327`) expects the
custom loss function to return `LossOutputs(loss, metrics={...})`, not
a bare scalar tensor.** Our recipe still returns a bare scalar. This
predates Wave 14 (it's been wrong since the recipe was first written in
Wave 13) but was never caught because no test runs against actual
PRIME-RL.
**Status:** documented; deferred to Wave 15. Not blocking for Wave 14b's
closure of Finding 1, because the formula now matches upstream — the
return-shape mismatch is a separate adapter-level issue. Tests still
pass because they invoke our `loss_fn` directly without going through
PRIME-RL's `compute_loss` pipeline.
**Fix direction (Wave 15):** wrap the return value in a duck-typed
`LossOutputs` (provided by PRIME-RL when installed; substituted with a
NamedTuple shim when not). Add an integration smoke test against PRIME-RL
to catch this and similar adapter-shape regressions.
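One possible shape for that shim is sketched below. The import path is an assumption based on the upstream `loss.py` location cited earlier and may need adjusting; `wrap_loss` is a hypothetical helper name, and the fallback only mirrors the two fields named in this section (`loss` and `metrics`).

```python
from typing import Any, NamedTuple

try:
    # Assumed import path; adjust to wherever upstream defines LossOutputs.
    from prime_rl.trainer.rl.loss import LossOutputs
except ImportError:
    class LossOutputs(NamedTuple):
        """Fallback shim exposing the same two fields by name."""
        loss: Any
        metrics: dict

def wrap_loss(loss: Any, **metrics: Any) -> "LossOutputs":
    """Package a scalar loss and optional metrics in the expected return shape."""
    return LossOutputs(loss, metrics=dict(metrics))

# Example: the recipe would return this instead of the bare scalar.
out = wrap_loss(0.25, kept_tokens=2)
```

Because the shim is only duck-typed, it does not replace the proposed integration smoke test: only a run through PRIME-RL itself can confirm the real return contract.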
## Final Wave 14 + 14b status
| Wave 13 / 14 finding | Wave 14b status |
|---|---|
| W13 BLOCKER 1: PRIME-RL SDPO degenerate | ✅ closed (parent, channel 2 deferred) |
| W13 BLOCKER 2: compose_loss kwargs not added | ✅ closed (Wave 14 T1) |
| W13 Suggestion 3: replaysim YAML field types | ✅ closed (Wave 14 T2) |
| W13 Suggestion 4: MockManager → DiLoCo gap | ✅ closed (Wave 14 T3 + Wave 14b multi-process test) |
| W13 Suggestion 6: PRIME-RL channel 1 REINFORCE not GRPO | ✅ **closed in Wave 14b** (matches upstream `default_loss_fn`) |
| W14 Finding 1: PRIME-RL impl wrong | ✅ closed in Wave 14b |
| W14 Finding 2: ADR-007 stale | ✅ closed in Wave 14b |
| W14 Finding 3: ModalExecutor dead code | ✅ closed in Wave 14b |
| W14 Finding 4: world_size=1 only | ✅ closed in Wave 14b (multi-process convergence test) |
| W14 Finding 5: tests pin wrong impl as ground truth | ✅ closed in Wave 14b (parity test added) |
| W14 NIT 6: test count drift | 🟡 carried |
| W14 NIT 7: docstring claims ISR clipping | ✅ closed in Wave 14b (real ISR now implemented) |
| **NEW (Wave 14b)**: PRIME-RL `LossOutputs` return shape | 🟡 deferred to Wave 15 |
**Tests as of Wave 15: 115 passing + 1 skip-marked (OPSD parity test, runs when upstream is cloned).** (Wave 12: 72; Wave 13: 93; Wave 14: 124; Wave 14b: 130; Wave 15: 115 after TAID rewrite consolidation + OPSD parity.)