Commit aae66fa (1 parent: 7b34ebf), committed by Codeseys

feat(trainer): policy-optimization objective MENU (ADR-014)

Adds make_po_config(objective=..., **overrides) — RL's base objective is no
longer hardcoded to Dr.GRPO. Six selectable named presets over trl 1.5.0's
verified GRPOConfig knob-space (introspected the installed package, not a
GitHub snapshot):

  grpo    | vanilla GRPO (std-norm advantage)
  dr_grpo | DEFAULT; no length-std bias (Composer 2.5's base objective)
  bnpo    | batch-normalized
  dapo    | decoupled clip-higher (epsilon_high=0.28) + overlong mask + KL off
  gspo    | sequence-level importance ratio (Qwen3; long-CoT / MoE stable)
  cispo   | detached-IS REINFORCE (every token keeps a gradient; MiniMax-M1)
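
To make the "std-norm advantage" versus "no length-std bias" distinction in the list above concrete, here is a minimal sketch of group-relative advantages with and without the std division. It is an illustration only: `group_advantages`, its `eps` constant, and the use of the population standard deviation are assumptions for this example, not trl's implementation.

```python
from statistics import mean, pstdev


def group_advantages(rewards, scale="group", eps=1e-4):
    """Group-relative advantages for the completions of ONE prompt.

    scale="group": center, then divide by the group's std (vanilla GRPO style).
    scale="none":  center only (Dr.GRPO style, no std normalization).
    NOTE: hypothetical helper; pstdev and eps are assumptions of this sketch.
    """
    mu = mean(rewards)
    centered = [r - mu for r in rewards]
    if scale == "none":
        return centered
    sigma = pstdev(rewards)
    return [c / (sigma + eps) for c in centered]


rewards = [1.0, 0.0, 0.0, 0.0]  # one success out of four samples
print("group:", group_advantages(rewards, "group"))
print("none: ", group_advantages(rewards, "none"))
```

With std division, a group whose rewards barely differ gets its small differences amplified; centering only keeps the advantage on the same scale as the raw reward.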

Every preset is PURE CONFIG (trl already implements each loss_type branch +
importance_sampling_level/epsilon_high) — no custom _compute_loss. Drift guards
assert loss_type / IS-level / epsilon_high actually applied so a preset can't
silently degrade (e.g. GSPO with IS overridden back to token raises). 10 unit
tests green against real trl 1.5.0.
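
The preset-plus-drift-guard pattern described above can be sketched without trl installed. Everything here is a stand-in: `StubConfig` is a hypothetical replacement for `trl.GRPOConfig`, and `PRESETS` and `build` are illustrative names, not the commit's code.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StubConfig:
    """Hypothetical stand-in for trl.GRPOConfig (only three fields)."""
    loss_type: str = "dapo"
    importance_sampling_level: str = "token"
    epsilon_high: Optional[float] = None


PRESETS = {
    "gspo": {"loss_type": "grpo", "importance_sampling_level": "sequence"},
    "dapo": {"loss_type": "dapo", "epsilon_high": 0.28},
}


def build(objective: str, **overrides: Any) -> StubConfig:
    merged = {**PRESETS[objective], **overrides}  # overrides win over the preset
    cfg = StubConfig(**merged)
    # Drift guard: every knob we asked for must be visible on the built config.
    for name, wanted in merged.items():
        got = getattr(cfg, name)
        if got != wanted:
            raise AssertionError(f"{name} drifted: requested {wanted!r}, got {got!r}")
    # Semantic guard: GSPO is only GSPO with a sequence-level ratio.
    if objective == "gspo" and cfg.importance_sampling_level != "sequence":
        raise AssertionError("gspo requires importance_sampling_level='sequence'")
    return cfg


print(build("dapo"))
try:
    build("gspo", importance_sampling_level="token")
except AssertionError as exc:
    print("guard fired:", exc)
```

This sketch raises `AssertionError` explicitly instead of using `assert` statements, because Python drops `assert` statements when run with `-O`; whether that matters depends on how the trainer is launched.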

Research-grounded: Composer 2.5 = Dr.GRPO + on-policy self-distillation (= our
SDPO channel); its sources mention NO DPO/preference/multi-teacher, so the
trace-replay-DPO channel is documented as the framework's own addition, not
Composer's. make_dr_grpo_config preserved for back-compat (== dr_grpo preset).

Follow-up: thread objective= through the LMA ladder runners (A1 used dr_grpo).

composer_replication/trainer/composer_trainer.py CHANGED
@@ -401,4 +401,149 @@ def make_dr_grpo_config(**overrides: Any):
     return cfg
 
 
-__all__ = ["ComposerReplicationTrainer", "make_dr_grpo_config"]
+# ---------------------------------------------------------------------------
+# Policy-optimization objective MENU (ADR-014)
+# ---------------------------------------------------------------------------
+#
+# The base RL objective used to be hardcoded to Dr.GRPO (make_dr_grpo_config).
+# make_po_config gives RL a real menu: GRPO-family objectives selectable by name.
+# Verified against the installed trl==1.5.0 (introspected 2026-05-30): its
+# GRPOTrainer already implements these as `loss_type` branches + knobs, so EVERY
+# preset below is pure config; no custom _compute_loss override is needed.
+#
+# Knob-space each preset sets (all real GRPOConfig fields in trl 1.5.0):
+#   loss_type ∈ {grpo, dr_grpo, bnpo, dapo, cispo}  (gspo = grpo loss +
+#       importance_sampling_level="sequence"; trl has no literal "gspo")
+#   scale_rewards ∈ {"group"(std-norm), "batch", "none"(no std-norm, Dr.GRPO)}
+#   epsilon / epsilon_high: symmetric vs decoupled "clip-higher" (DAPO)
+#   importance_sampling_level ∈ {"token", "sequence"(GSPO)}
+#   beta: KL-to-ref coef (0.0 = reference-free)
+#   mask_truncated_completions: DAPO overlong masking
+#   num_iterations: on-policy reuse (1 = strict on-policy)
+
+#: Selectable base policy-optimization objectives (named presets over trl knobs).
+PO_OBJECTIVES: dict[str, dict[str, Any]] = {
+    # Vanilla GRPO (DeepSeekMath, arXiv 2402.03300): group-relative advantage
+    # WITH std normalization + per-sequence length normalization, KL on.
+    "grpo": {
+        "loss_type": "grpo",
+        "scale_rewards": "group",
+        "importance_sampling_level": "token",
+        "num_iterations": 1,
+    },
+    # Dr.GRPO (arXiv 2503.20783): remove length-std normalization bias (no
+    # advantage /std, length-independent aggregation). Framework's historical
+    # default (== make_dr_grpo_config). Composer 2.5's base objective.
+    "dr_grpo": {
+        "loss_type": "dr_grpo",
+        "scale_rewards": "none",
+        "importance_sampling_level": "token",
+        "num_iterations": 1,
+    },
+    # BNPO: batch-normalized variant (trl loss_type), std over the batch.
+    "bnpo": {
+        "loss_type": "bnpo",
+        "scale_rewards": "batch",
+        "importance_sampling_level": "token",
+        "num_iterations": 1,
+    },
+    # DAPO (arXiv 2503.14476): decoupled "clip-higher" (epsilon_high > epsilon)
+    # + token-level loss + overlong masking + KL removed. High-value, low-cost
+    # anti-entropy-collapse objective. epsilon_high=0.28 per the paper.
+    "dapo": {
+        "loss_type": "dapo",
+        "scale_rewards": "none",
+        "epsilon": 0.2,
+        "epsilon_high": 0.28,
+        "mask_truncated_completions": True,
+        "beta": 0.0,
+        "importance_sampling_level": "token",
+        "num_iterations": 1,
+    },
+    # GSPO (Qwen, arXiv 2507.18071): SEQUENCE-level importance ratio (one
+    # length-normalized ratio per response); stabilizes long-CoT and especially
+    # MoE RL. trl expresses this as the grpo loss +
+    # importance_sampling_level="sequence".
+    "gspo": {
+        "loss_type": "grpo",
+        "scale_rewards": "group",
+        "importance_sampling_level": "sequence",
+        "num_iterations": 1,
+    },
+    # CISPO (MiniMax-M1, arXiv 2506.13585): clip the IS weight and detach it as a
+    # constant coefficient on log π, so every token keeps a gradient (fixes the
+    # "rare reasoning tokens get zeroed by the clip" pathology). eps_max≈5 (ScaleRL).
+    "cispo": {
+        "loss_type": "cispo",
+        "scale_rewards": "none",
+        "epsilon_high": 5.0,
+        "importance_sampling_level": "token",
+        "num_iterations": 1,
+    },
+}
+
+
+def make_po_config(objective: str = "dr_grpo", **overrides: Any):
+    """Build a `trl.GRPOConfig` for a NAMED policy-optimization objective.
+
+    The menu that gives RL real options beyond the single hardcoded Dr.GRPO
+    recipe. ``objective`` selects a preset from ``PO_OBJECTIVES`` (grpo /
+    dr_grpo / bnpo / dapo / gspo / cispo); ``**overrides`` set or override any
+    GRPOConfig field on top (e.g. ``output_dir=...``, ``beta=...``,
+    ``learning_rate=...``).
+
+    All presets are PURE CONFIG over trl 1.5.0's GRPOTrainer (verified by
+    introspecting the installed package 2026-05-30): the trainer already
+    implements each ``loss_type`` branch and the ``importance_sampling_level`` /
+    ``epsilon_high`` knobs, so no custom ``_compute_loss`` is needed. See ADR-014.
+
+    Raises:
+        ValueError: unknown objective (lists the valid menu).
+        AssertionError: a requested knob silently failed to apply (drift guard).
+    """
+    from trl import GRPOConfig  # local import: only when actually building a config
+
+    key = (objective or "dr_grpo").lower()
+    if key not in PO_OBJECTIVES:
+        raise ValueError(
+            f"Unknown PO objective {objective!r}. Choose from: "
+            f"{sorted(PO_OBJECTIVES)}. (Each is a named preset over trl 1.5.0's "
+            f"GRPOConfig knobs; see PO_OBJECTIVES / ADR-014.)"
+        )
+
+    preset = dict(PO_OBJECTIVES[key])
+    merged = {**preset, **overrides}
+    cfg = GRPOConfig(**merged)
+
+    # Drift guards: fail loudly if a future trl renamed/repurposed a knob we set,
+    # so a preset can never silently degrade to a different objective.
+    if "loss_type" in merged:
+        assert str(cfg.loss_type) == str(merged["loss_type"]), (
+            f"GRPOConfig.loss_type drifted: requested {merged['loss_type']!r}, "
+            f"got {cfg.loss_type!r}; trl may have renamed the knob."
+        )
+    if "importance_sampling_level" in merged and hasattr(cfg, "importance_sampling_level"):
+        assert str(cfg.importance_sampling_level) == str(
+            merged["importance_sampling_level"]
+        ), (
+            f"importance_sampling_level drifted for objective {key!r}: requested "
+            f"{merged['importance_sampling_level']!r}, got {cfg.importance_sampling_level!r}."
+        )
+    if key == "gspo":
+        assert str(getattr(cfg, "importance_sampling_level", "token")) == "sequence", (
+            "GSPO requires importance_sampling_level='sequence'; it was overridden "
+            "to token, which silently degrades GSPO to GRPO. Drop that override."
+        )
+    if merged.get("epsilon_high") is not None:
+        assert abs(
+            float(getattr(cfg, "epsilon_high", merged["epsilon_high"]))
+            - float(merged["epsilon_high"])
+        ) < 1e-9, f"epsilon_high (decoupled clip) drifted for {key!r}."
+    return cfg
+
+
+__all__ = [
+    "ComposerReplicationTrainer",
+    "make_dr_grpo_config",
+    "make_po_config",
+    "PO_OBJECTIVES",
+]
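
The `dapo` and `cispo` preset comments in the file above describe how clipping decides which tokens receive a gradient. The sketch below works that out by hand for a single token, with no autograd. The function names are hypothetical, and the CISPO form (a one-sided cap `eps_max` applied directly to the ratio) is an assumption of this example rather than a statement about trl's internals.

```python
def ppo_clip_grad(ratio, adv, eps_low=0.2, eps_high=0.2):
    """d/d(logp) of min(r*A, clip(r, 1-eps_low, 1+eps_high)*A), r = exp(logp - logp_old).

    The clipped term is constant in logp, so the gradient is r*A when the
    unclipped term is the active minimum and 0 otherwise.
    """
    clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    if ratio * adv <= clipped * adv:
        return ratio * adv
    return 0.0


def cispo_grad(ratio, adv, eps_max=5.0):
    """d/d(logp) of stopgrad(min(r, eps_max)) * A * logp (assumed CISPO form).

    The importance weight is treated as a constant, so the token always
    keeps a gradient, scaled by the capped weight.
    """
    return min(ratio, eps_max) * adv


for r in (1.1, 1.25, 1.5):
    print(
        f"r={r}: symmetric={ppo_clip_grad(r, 1.0)}, "
        f"clip-higher={ppo_clip_grad(r, 1.0, eps_high=0.28)}, "
        f"cispo={cispo_grad(r, 1.0)}"
    )
```

For a positive advantage, raising only the upper bound lets tokens with a moderately increased probability keep contributing, while the detached-weight form never zeroes a token outright.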
composer_replication/trainer/tests/test_po_objective_menu.py ADDED
@@ -0,0 +1,87 @@
+"""Tests for the policy-optimization objective menu (make_po_config, ADR-014).
+
+These build real trl GRPOConfigs, so they require trl installed (the framework's
+.venv has trl==1.5.0). Skips cleanly if trl is absent.
+"""
+from __future__ import annotations
+
+import pytest
+
+trl = pytest.importorskip("trl")
+
+from composer_replication.trainer.composer_trainer import (  # noqa: E402
+    PO_OBJECTIVES,
+    make_po_config,
+)
+
+
+def test_menu_lists_expected_objectives():
+    assert set(PO_OBJECTIVES) == {"grpo", "dr_grpo", "bnpo", "dapo", "gspo", "cispo"}
+
+
+def test_unknown_objective_raises_with_menu(tmp_path):
+    with pytest.raises(ValueError) as ei:
+        make_po_config("nope", output_dir=str(tmp_path))
+    msg = str(ei.value)
+    assert "Unknown PO objective" in msg and "dapo" in msg and "gspo" in msg
+
+
+def test_grpo_preset(tmp_path):
+    cfg = make_po_config("grpo", output_dir=str(tmp_path))
+    assert str(cfg.loss_type) == "grpo"
+    assert str(cfg.importance_sampling_level) == "token"
+    # group scaling = std-normalized advantage (vanilla GRPO)
+    assert str(cfg.scale_rewards).lower() in ("group", "true")
+
+
+def test_dr_grpo_preset_matches_legacy(tmp_path):
+    cfg = make_po_config("dr_grpo", output_dir=str(tmp_path))
+    assert str(cfg.loss_type) == "dr_grpo"
+    # no std-normalization (the Dr.GRPO fix)
+    assert str(cfg.scale_rewards).lower() in ("none", "false")
+
+
+def test_dapo_preset_sets_decoupled_clip(tmp_path):
+    cfg = make_po_config("dapo", output_dir=str(tmp_path))
+    assert str(cfg.loss_type) == "dapo"
+    # clip-higher: epsilon_high strictly above epsilon
+    assert cfg.epsilon_high is not None
+    assert float(cfg.epsilon_high) > float(cfg.epsilon)
+    assert bool(cfg.mask_truncated_completions) is True
+    assert float(cfg.beta) == 0.0  # DAPO removes KL
+
+
+def test_gspo_is_sequence_level(tmp_path):
+    cfg = make_po_config("gspo", output_dir=str(tmp_path))
+    # GSPO = grpo loss + SEQUENCE-level importance ratio
+    assert str(cfg.loss_type) == "grpo"
+    assert str(cfg.importance_sampling_level) == "sequence"
+
+
+def test_gspo_guard_rejects_token_override(tmp_path):
+    # Overriding back to token-level would silently degrade GSPO to GRPO -> guard.
+    with pytest.raises(AssertionError):
+        make_po_config(
+            "gspo", output_dir=str(tmp_path), importance_sampling_level="token"
+        )
+
+
+def test_cispo_preset(tmp_path):
+    cfg = make_po_config("cispo", output_dir=str(tmp_path))
+    assert str(cfg.loss_type) == "cispo"
+    # eps_max (ScaleRL recommended 5.0) carried via epsilon_high
+    assert cfg.epsilon_high is not None and float(cfg.epsilon_high) >= 5.0
+
+
+def test_overrides_apply_on_top(tmp_path):
+    cfg = make_po_config(
+        "dr_grpo", output_dir=str(tmp_path), beta=0.05, num_generations=4
+    )
+    assert float(cfg.beta) == 0.05
+    assert int(cfg.num_generations) == 4
+    assert str(cfg.loss_type) == "dr_grpo"  # preset preserved under overrides
+
+
+def test_default_objective_is_dr_grpo(tmp_path):
+    cfg = make_po_config(output_dir=str(tmp_path))
+    assert str(cfg.loss_type) == "dr_grpo"
docs/adrs/ADR-014-policy-optimization-objective-menu.md ADDED
@@ -0,0 +1,110 @@
+---
+status: accepted
+date: 2026-05-30
+deciders: [Codeseys, ARIA]
+builds-on: [ADR-006 (RL frameworks), ADR-007 (distillation menu), ADR-008 (Dr.GRPO base)]
+---
+
+# ADR-014: Policy-optimization objective MENU (make RL's base objective selectable)
+
+## Context and Problem Statement
+
+The framework replicates Composer 2.5's recipe as **base RL objective + distillation
+channel + preference channel**. ADR-007 already gave the *distillation/preference*
+axis a menu (SimPO / TAID / Entropy-Aware-OPD via `compose_loss(dpo_variant=,
+sdpo_wrapper=)`). But the **base policy-optimization objective was hardcoded to
+Dr.GRPO** via `make_dr_grpo_config`: `loss_type="dr_grpo"`, `scale_rewards="none"`,
+`num_iterations=1`, no other options.
+
+User ask (2026-05-30): *"look at other policy optimization papers like SDPO and
+Dr.GRPO or OPSD ... allow for RL to have multiple options ... [SDPO / what Composer
+2.5 uses] is one of the best instances of setting up post-training to take models to
+massive performance gains."*
+
+So: give the base RL objective a real menu, grounded in the current PO frontier and
+in what Composer 2.5 actually does.
+
+## Research (primary-sourced, 2026-05-30)
+
+Three parallel cross-family research passes (reports in `docs/research/`):
+
+1. **Composer 2.5's actual recipe** (cursor.com/blog/composer-2-5 + Composer 2 report
+   arXiv:2603.24477): base objective is **Dr.GRPO** (no length-std normalization,
+   single-epoch, k1-discussed/k3-in-TRL KL, async RL with MoE router replay). Its
+   headline 2.5 technique, "targeted RL with textual feedback", is **on-policy
+   self-distillation** (= our SDPO channel ✓). **No DPO / preference pairs / multiple
+   teachers appear in any Composer source**; our trace-replay-DPO channel is the
+   framework's own addition, NOT Composer's. Recorded honestly.
+
+2. **GRPO-family PO landscape**: vanilla GRPO (2402.03300), Dr.GRPO (2503.20783),
+   DAPO (2503.14476, decoupled clip-higher + dynamic sampling + KL-off), GSPO
+   (2507.18071, sequence-level importance ratio; Qwen3, MoE-stable), CISPO
+   (MiniMax-M1 2506.13585, detached-IS REINFORCE so every token keeps a gradient),
+   GMPO (geometric-mean aggregation). Most are surgical edits to advantage-norm /
+   length-norm / clip / ratio-granularity / KL.
+
+3. **TRL 1.5.0 capability map**, then **verified by introspecting the installed
+   package** (not a GitHub snapshot): the installed trl 1.5.0 `GRPOTrainer` already
+   implements `loss_type ∈ {grpo, dr_grpo, bnpo, dapo, cispo, luspo, sapo, vespo}`
+   and exposes `epsilon`/`epsilon_high` (decoupled clip), `delta`, `beta`,
+   `scale_rewards ∈ {group,batch,none}`, `importance_sampling_level ∈ {token,
+   sequence}` (= GSPO), `mask_truncated_completions`, `num_iterations`. **Every
+   objective we want is therefore PURE CONFIG: no custom `_compute_loss` needed.**
+
+## Decision
+
+Add **`make_po_config(objective=..., **overrides)`** to `trainer/composer_trainer.py`,
+a named-preset factory over trl 1.5.0's verified GRPOConfig knob-space. Keep
+`make_dr_grpo_config` intact (back-compat); `dr_grpo` is now one preset among the menu.
+
+Menu (`PO_OBJECTIVES`):
+
+| objective | loss_type | key knobs | what it gives |
+|---|---|---|---|
+| `grpo` | grpo | scale_rewards=group, IS=token | vanilla GRPO (std-norm advantage) |
+| `dr_grpo` | dr_grpo | scale_rewards=none, IS=token | **default**; no length-std bias (Composer 2.5 base) |
+| `bnpo` | bnpo | scale_rewards=batch | batch-normalized variant |
+| `dapo` | dapo | epsilon_high=0.28, mask_truncated, beta=0 | decoupled clip-higher, anti-entropy-collapse |
+| `gspo` | grpo | IS=**sequence** | sequence-level ratio; long-CoT / MoE stable (Qwen3) |
+| `cispo` | cispo | epsilon_high=5.0 | detached-IS REINFORCE; every token keeps gradient |
+
+Each preset carries literature-recommended settings; any field is overridable via
+`**overrides`. **Drift guards** assert the requested `loss_type` /
+`importance_sampling_level` / `epsilon_high` actually applied, so a preset can never
+silently degrade (e.g. GSPO with `importance_sampling_level` overridden back to token
+raises rather than quietly becoming GRPO).
+
+### Consequences
+
+- **Positive**: RL now has 6 selectable base objectives at zero custom-loss cost; the
+  ladder/runners gain an `objective=` knob orthogonal to the existing SDPO/replay
+  channels and the SimPO/TAID/EA-OPD distillation menu. A user can run e.g.
+  `objective="dapo"` (clip-higher) or `objective="gspo"` (MoE-stable) instead of only
+  Dr.GRPO.
+- **Positive**: faithful to Composer 2.5 (dr_grpo default + self-distillation) while
+  exposing the stronger 2025-26 objectives the report predates.
+- **Neutral**: `gspo` is `grpo` loss + `importance_sampling_level="sequence"` (trl has
+  no literal "gspo"); documented in the preset + guarded.
+- **Negative / honest**: presets reflect *current* trl 1.5.0 field names; a trl upgrade
+  could rename knobs; the drift guards turn that into a loud failure, not silent
+  mis-training. `sapo`/`luspo`/`vespo` exist in this trl build but are NOT in the menu
+  yet (less-established; add later if validated).
+
+## Acceptance gate
+
+- [x] `make_po_config(objective)` with presets grpo/dr_grpo/bnpo/dapo/gspo/cispo,
+  built over trl-1.5.0-verified knobs (introspected, not assumed).
+- [x] `make_dr_grpo_config` preserved; `dr_grpo` preset is equivalent; default objective
+  is `dr_grpo`.
+- [x] Unit tests (10) green against real trl 1.5.0: each preset's defining knob asserted,
+  unknown-objective raises with the menu, GSPO token-override guard fires, overrides apply.
+- [x] Drift guards on loss_type / IS-level / epsilon_high.
+- [ ] Wire `objective=` through the LMA ladder runners (follow-up; the A1 run used
+  dr_grpo, re-runnable with `objective="dapo"` etc. once threaded).
+
+## More Information
+
+- `docs/research/SELF_DISTILLATION_LANDSCAPE.md`: the distillation/preference menu (ADR-007).
+- Composer 2.5 recipe report + GRPO-family survey + trl-1.5.0 capability map: research
+  pass 2026-05-30 (parallel cross-family).
+- Installed-trl introspection confirming the knob surface: 2026-05-30.
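
The ADR's `gspo` row hinges on the difference between token-level and sequence-level importance ratios. As a minimal numeric sketch (hypothetical helper names, made-up log-probabilities, no trl involved), the length-normalized sequence ratio is the geometric mean of the per-token ratios:

```python
import math


def token_ratios(logp_new, logp_old):
    """Per-token importance ratios pi_new / pi_old, one per generated token."""
    return [math.exp(n - o) for n, o in zip(logp_new, logp_old)]


def sequence_ratio(logp_new, logp_old):
    """Length-normalized sequence ratio: exp(mean over tokens of (logp_new - logp_old))."""
    diffs = [n - o for n, o in zip(logp_new, logp_old)]
    return math.exp(sum(diffs) / len(diffs))


# Illustrative values only: one token's probability moved a lot, the rest barely moved.
logp_old = [-1.0, -2.0, -0.5, -3.0]
logp_new = [-0.9, -2.0, -0.5, -1.5]

print("token ratios:  ", [round(r, 3) for r in token_ratios(logp_new, logp_old)])
print("sequence ratio:", round(sequence_ratio(logp_new, logp_old), 3))
```

A single outlier token dominates its own token-level ratio, but it is averaged with the rest of the response in the sequence-level ratio, which is one plausible reading of why the ADR calls this variant more stable for long responses.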