Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto") - Notebooks
- Google Colab
- Kaggle
feat(trainer): ADR-008 Dr.GRPO config + SDPO strict-alignment guard
Browse filesPhase-6 execution of ADR-008 (gates 1, 2, 5 green; gate 3 = the live
GRPOTrainer smoke runs detached separately).
- make_dr_grpo_config(): builds trl.GRPOConfig to the Dr. GRPO recipe
(loss_type=dr_grpo => no length-standardization bias; scale_rewards=none
=> no std-dev advantage normalization; num_iterations=1 => single-epoch),
with a drift-guard assertion. TRL 1.5.0 natively supports these knobs and
its own help text cites the Dr. GRPO paper for both.
- ComposerReplicationTrainer.strict_sdpo_alignment (default True): the
student/teacher logit shape-mismatch path now RAISES instead of silently
zeroing the SDPO channel (closes the ADR-008 trust-gap, composer_trainer.py
lines 158-160). Opt-out to warn-and-skip for production resilience.
- prime_rl/composer_loss.py: NotImplementedError(alpha_sdpo>0) message now
points at the TRL host as the SDPO home (documents the ADR-006 amendment).
- Tests: 7 new (config knobs asserted; strict raises / non-strict skips /
aligned fires / no-error-site no-ops). Full trainer + prime_rl suites:
28 passed, 1 skipped, no regressions.
- examples/composer_grpo_sdpo_smoke/: gate-3 live-loop smoke (detached-run).
Constraint note: cross-family ADR review + subagent execution blocked by
OpenRouter 402 (out of credits); executing inline. TRL import is ~140s
(heavy transitive deps) so the live GRPO smoke runs as a detached
systemd-scope job.
|
@@ -263,8 +263,10 @@ def loss_fn(
|
|
| 263 |
"distribution. Set alpha_sdpo=0.0 to silence this and use "
|
| 264 |
"channel 1 (DPPO+KL) only. teacher_logprobs is "
|
| 265 |
f"{'present' if teacher_lp is not None else 'absent'} in this "
|
| 266 |
-
"call but unused.
|
| 267 |
-
"
|
|
|
|
|
|
|
| 268 |
)
|
| 269 |
|
| 270 |
# --- Channel 3: not supported in PRIME-RL recipe v0 -------------------
|
|
|
|
| 263 |
"distribution. Set alpha_sdpo=0.0 to silence this and use "
|
| 264 |
"channel 1 (DPPO+KL) only. teacher_logprobs is "
|
| 265 |
f"{'present' if teacher_lp is not None else 'absent'} in this "
|
| 266 |
+
"call but unused. For the SDPO channel, use the TRL host "
|
| 267 |
+
"(composer_replication.trainer.ComposerReplicationTrainer with "
|
| 268 |
+
"alpha_sdpo>0), which has full logits — see ADR-008. "
|
| 269 |
+
"See docs/research/WAVE_13_FINAL_REVIEW.md Finding 1."
|
| 270 |
)
|
| 271 |
|
| 272 |
# --- Channel 3: not supported in PRIME-RL recipe v0 -------------------
|
|
@@ -74,6 +74,7 @@ class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
|
|
| 74 |
sdpo_temperature: float = 1.0,
|
| 75 |
sdpo_token_clip: float | None = None,
|
| 76 |
replay_dpo_beta: float = 0.1,
|
|
|
|
| 77 |
**kwargs: Any,
|
| 78 |
):
|
| 79 |
if not _TRL_AVAILABLE:
|
|
@@ -88,6 +89,12 @@ class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
|
|
| 88 |
self.sdpo_temperature = sdpo_temperature
|
| 89 |
self.sdpo_token_clip = sdpo_token_clip
|
| 90 |
self.replay_dpo_beta = replay_dpo_beta
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
# ----------------------------------------------------------------------
|
| 93 |
# Loss override (the integration core)
|
|
@@ -159,12 +166,21 @@ class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
|
|
| 159 |
# SAME LENGTH at the post-hint section so logits align position-by-position.
|
| 160 |
# The data collator pads/aligns. The skeleton trusts that's done correctly.
|
| 161 |
if student_logits.shape != teacher_logits.shape:
|
| 162 |
-
|
| 163 |
-
"SDPO logit shape mismatch: student=
|
| 164 |
-
"
|
| 165 |
-
"
|
| 166 |
-
|
|
|
|
| 167 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
|
| 169 |
|
| 170 |
return generalized_jsd_loss(
|
|
@@ -246,4 +262,45 @@ def _device_of(model: torch.nn.Module) -> torch.device:
|
|
| 246 |
return next(model.parameters()).device
|
| 247 |
|
| 248 |
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
sdpo_temperature: float = 1.0,
|
| 75 |
sdpo_token_clip: float | None = None,
|
| 76 |
replay_dpo_beta: float = 0.1,
|
| 77 |
+
strict_sdpo_alignment: bool = True,
|
| 78 |
**kwargs: Any,
|
| 79 |
):
|
| 80 |
if not _TRL_AVAILABLE:
|
|
|
|
| 89 |
self.sdpo_temperature = sdpo_temperature
|
| 90 |
self.sdpo_token_clip = sdpo_token_clip
|
| 91 |
self.replay_dpo_beta = replay_dpo_beta
|
| 92 |
+
# When True (default), an SDPO student/teacher shape mismatch is a hard
|
| 93 |
+
# error — it means the data collator failed to align the post-hint
|
| 94 |
+
# section, which silently zeroes the distillation signal (the exact
|
| 95 |
+
# trust-gap flagged in ADR-008). Set False only for production runs
|
| 96 |
+
# where a single malformed batch should warn-and-skip rather than abort.
|
| 97 |
+
self.strict_sdpo_alignment = strict_sdpo_alignment
|
| 98 |
|
| 99 |
# ----------------------------------------------------------------------
|
| 100 |
# Loss override (the integration core)
|
|
|
|
| 166 |
# SAME LENGTH at the post-hint section so logits align position-by-position.
|
| 167 |
# The data collator pads/aligns. The skeleton trusts that's done correctly.
|
| 168 |
if student_logits.shape != teacher_logits.shape:
|
| 169 |
+
msg = (
|
| 170 |
+
f"SDPO logit shape mismatch: student={tuple(student_logits.shape)} "
|
| 171 |
+
f"vs teacher={tuple(teacher_logits.shape)}. The data collator must "
|
| 172 |
+
"pad/align the post-hint section so student and teacher have "
|
| 173 |
+
"identical token-counts; otherwise the distillation signal is "
|
| 174 |
+
"silently zeroed."
|
| 175 |
)
|
| 176 |
+
if self.strict_sdpo_alignment:
|
| 177 |
+
# Hard error (default): a mismatch means the collator failed to
|
| 178 |
+
# align — the ADR-008 trust-gap. Do not silently zero the channel.
|
| 179 |
+
raise ValueError(
|
| 180 |
+
msg + " (strict_sdpo_alignment=True; pass False to warn-and-skip "
|
| 181 |
+
"instead for production resilience.)"
|
| 182 |
+
)
|
| 183 |
+
logger.warning("%s Skipping SDPO loss for this step.", msg)
|
| 184 |
return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
|
| 185 |
|
| 186 |
return generalized_jsd_loss(
|
|
|
|
| 262 |
return next(model.parameters()).device
|
| 263 |
|
| 264 |
|
| 265 |
+
def make_dr_grpo_config(**overrides: Any):
|
| 266 |
+
"""Build a `trl.GRPOConfig` configured to the **Dr. GRPO** recipe.
|
| 267 |
+
|
| 268 |
+
Per the Composer 2 technical report (arXiv:2603.24477,
|
| 269 |
+
research/10-composer2-techreport-mining.md) the RL base is Dr. GRPO
|
| 270 |
+
(Liu et al., arXiv:2503.20783):
|
| 271 |
+
|
| 272 |
+
- ``loss_type="dr_grpo"`` — removes GRPO's length-standardization term
|
| 273 |
+
(which injects a length bias). TRL's own help text cites the Dr. GRPO
|
| 274 |
+
paper for this.
|
| 275 |
+
- ``scale_rewards="none"`` — NO std-dev advantage normalization. TRL docs:
|
| 276 |
+
"The Dr. GRPO paper recommends not scaling rewards, as scaling by the
|
| 277 |
+
standard deviation introduces a question-level difficulty bias."
|
| 278 |
+
- ``num_iterations=1`` — single-epoch regime (a prompt is never
|
| 279 |
+
trained on twice), matching the tech report.
|
| 280 |
+
- ``beta`` (KL-to-ref coef) kept; TRL uses the k1 (−log r)-family
|
| 281 |
+
estimator the report selects.
|
| 282 |
+
|
| 283 |
+
Any field can be overridden via kwargs (e.g. ``learning_rate=...``,
|
| 284 |
+
``output_dir=...``). The three Dr. GRPO-defining knobs are forced unless
|
| 285 |
+
explicitly overridden, and a sanity assertion guards against silent drift.
|
| 286 |
+
"""
|
| 287 |
+
from trl import GRPOConfig # local import: only when actually building a config
|
| 288 |
+
|
| 289 |
+
dr_grpo_defaults: dict[str, Any] = {
|
| 290 |
+
"loss_type": "dr_grpo",
|
| 291 |
+
"scale_rewards": "none",
|
| 292 |
+
"num_iterations": 1,
|
| 293 |
+
}
|
| 294 |
+
merged = {**dr_grpo_defaults, **overrides}
|
| 295 |
+
cfg = GRPOConfig(**merged)
|
| 296 |
+
# Guard: fail loudly if a future TRL renames/repurposes these knobs.
|
| 297 |
+
assert cfg.loss_type == merged["loss_type"], "GRPOConfig dropped loss_type"
|
| 298 |
+
assert str(cfg.scale_rewards) in ("none", "False", "False"), (
|
| 299 |
+
f"Dr. GRPO requires scale_rewards='none' (no std-norm); got {cfg.scale_rewards!r}. "
|
| 300 |
+
"TRL knob may have drifted — re-verify against trl version."
|
| 301 |
+
)
|
| 302 |
+
assert cfg.num_iterations == merged["num_iterations"], "GRPOConfig dropped num_iterations"
|
| 303 |
+
return cfg
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
__all__ = ["ComposerReplicationTrainer", "make_dr_grpo_config"]
|
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the Dr. GRPO config helper + SDPO strict-alignment guard (ADR-008).
|
| 2 |
+
|
| 3 |
+
Covers ADR-008 acceptance gates:
|
| 4 |
+
- gate 1: make_dr_grpo_config sets loss_type=dr_grpo, scale_rewards=none,
|
| 5 |
+
num_iterations=1 (asserted against the resulting GRPOConfig).
|
| 6 |
+
- gate 2: strict_sdpo_alignment=True raises on a student/teacher shape
|
| 7 |
+
mismatch rather than silently zeroing the channel.
|
| 8 |
+
- gate 5: the PRIME-RL recipe raises NotImplementedError for alpha_sdpo>0
|
| 9 |
+
with a message pointing at the TRL host (documents the ADR-006 amendment).
|
| 10 |
+
|
| 11 |
+
These are CPU-only and fast. The full end-to-end smoke (gate 3) that
|
| 12 |
+
instantiates a real GRPOTrainer on Qwen2.5-0.5B lives in
|
| 13 |
+
examples/composer_grpo_sdpo_smoke/ (gated on model cache + memory).
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
trl = pytest.importorskip("trl") # whole module is meaningless without TRL
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Gate 1 — Dr. GRPO config
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
def test_make_dr_grpo_config_sets_dr_grpo_knobs(tmp_path):
|
| 28 |
+
from composer_replication.trainer.composer_trainer import make_dr_grpo_config
|
| 29 |
+
|
| 30 |
+
cfg = make_dr_grpo_config(output_dir=str(tmp_path))
|
| 31 |
+
# loss_type=dr_grpo removes GRPO's length-standardization bias.
|
| 32 |
+
assert cfg.loss_type == "dr_grpo"
|
| 33 |
+
# scale_rewards=none => NO std-dev advantage normalization (Dr. GRPO).
|
| 34 |
+
assert str(cfg.scale_rewards) in ("none", "False")
|
| 35 |
+
# single-epoch: a prompt is never trained on twice.
|
| 36 |
+
assert cfg.num_iterations == 1
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_make_dr_grpo_config_allows_overrides(tmp_path):
|
| 40 |
+
from composer_replication.trainer.composer_trainer import make_dr_grpo_config
|
| 41 |
+
|
| 42 |
+
cfg = make_dr_grpo_config(output_dir=str(tmp_path), beta=0.05, num_generations=2)
|
| 43 |
+
assert cfg.beta == 0.05
|
| 44 |
+
assert cfg.num_generations == 2
|
| 45 |
+
# Dr. GRPO knobs still forced when not overridden.
|
| 46 |
+
assert cfg.loss_type == "dr_grpo"
|
| 47 |
+
assert str(cfg.scale_rewards) in ("none", "False")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_make_dr_grpo_config_override_does_not_silently_break_guard(tmp_path):
|
| 51 |
+
"""Overriding loss_type away from dr_grpo is allowed (caller's choice) and
|
| 52 |
+
the returned config reflects it — the guard only protects the *defaults*
|
| 53 |
+
from silent TRL drift, it does not forbid explicit opt-out."""
|
| 54 |
+
from composer_replication.trainer.composer_trainer import make_dr_grpo_config
|
| 55 |
+
|
| 56 |
+
cfg = make_dr_grpo_config(output_dir=str(tmp_path), loss_type="grpo")
|
| 57 |
+
assert cfg.loss_type == "grpo"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# Gate 2 — SDPO strict-alignment guard (no real GRPOTrainer needed)
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
class _TinyLM(torch.nn.Module):
|
| 65 |
+
"""Minimal HF-style model: model(input_ids=...).logits, controllable vocab."""
|
| 66 |
+
|
| 67 |
+
def __init__(self, vocab: int = 16, hidden: int = 8):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.embed = torch.nn.Embedding(vocab, hidden)
|
| 70 |
+
self.head = torch.nn.Linear(hidden, vocab)
|
| 71 |
+
|
| 72 |
+
def forward(self, input_ids: torch.Tensor):
|
| 73 |
+
logits = self.head(self.embed(input_ids))
|
| 74 |
+
|
| 75 |
+
class _Out:
|
| 76 |
+
pass
|
| 77 |
+
out = _Out()
|
| 78 |
+
out.logits = logits
|
| 79 |
+
return out
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _make_trainer_without_init(strict: bool):
|
| 83 |
+
"""Build a ComposerReplicationTrainer instance WITHOUT running GRPOTrainer's
|
| 84 |
+
heavy __init__ (which needs a real model+dataset). We only exercise
|
| 85 |
+
_compute_sdpo_loss, so we set the attributes it reads directly."""
|
| 86 |
+
from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer
|
| 87 |
+
|
| 88 |
+
obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
|
| 89 |
+
obj.alpha_sdpo = 1.0
|
| 90 |
+
obj.sdpo_jsd_beta = 0.5
|
| 91 |
+
obj.sdpo_temperature = 1.0
|
| 92 |
+
obj.sdpo_token_clip = None
|
| 93 |
+
obj.strict_sdpo_alignment = strict
|
| 94 |
+
return obj
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_strict_alignment_raises_on_shape_mismatch():
|
| 98 |
+
obj = _make_trainer_without_init(strict=True)
|
| 99 |
+
model = _TinyLM(vocab=16)
|
| 100 |
+
# student seq len 5, teacher seq len 7 -> logit shape mismatch
|
| 101 |
+
inputs = {
|
| 102 |
+
"input_ids": torch.randint(0, 16, (1, 5)),
|
| 103 |
+
"ctx_teacher_input_ids": torch.randint(0, 16, (1, 7)),
|
| 104 |
+
"sdpo_loss_mask": torch.ones(1, 5),
|
| 105 |
+
}
|
| 106 |
+
with pytest.raises(ValueError, match="shape mismatch"):
|
| 107 |
+
obj._compute_sdpo_loss(model, inputs)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def test_nonstrict_alignment_warns_and_skips_on_mismatch():
|
| 111 |
+
obj = _make_trainer_without_init(strict=False)
|
| 112 |
+
model = _TinyLM(vocab=16)
|
| 113 |
+
inputs = {
|
| 114 |
+
"input_ids": torch.randint(0, 16, (1, 5)),
|
| 115 |
+
"ctx_teacher_input_ids": torch.randint(0, 16, (1, 7)),
|
| 116 |
+
"sdpo_loss_mask": torch.ones(1, 5),
|
| 117 |
+
}
|
| 118 |
+
loss = obj._compute_sdpo_loss(model, inputs)
|
| 119 |
+
assert float(loss.detach()) == 0.0 # skipped, not crashed
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_aligned_shapes_produce_finite_sdpo_loss():
|
| 123 |
+
obj = _make_trainer_without_init(strict=True)
|
| 124 |
+
model = _TinyLM(vocab=16)
|
| 125 |
+
# student and teacher SAME length -> channel fires
|
| 126 |
+
inputs = {
|
| 127 |
+
"input_ids": torch.randint(0, 16, (1, 6)),
|
| 128 |
+
"ctx_teacher_input_ids": torch.randint(0, 16, (1, 6)),
|
| 129 |
+
"sdpo_loss_mask": torch.ones(1, 6),
|
| 130 |
+
}
|
| 131 |
+
loss = obj._compute_sdpo_loss(model, inputs)
|
| 132 |
+
val = float(loss.detach())
|
| 133 |
+
assert val == val # not NaN
|
| 134 |
+
assert val not in (float("inf"), float("-inf"))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def test_no_error_sites_returns_zero_without_forward():
|
| 138 |
+
"""Empty ctx_teacher_input_ids => no SDPO signal, returns 0 (no crash)."""
|
| 139 |
+
obj = _make_trainer_without_init(strict=True)
|
| 140 |
+
model = _TinyLM(vocab=16)
|
| 141 |
+
inputs = {
|
| 142 |
+
"input_ids": torch.randint(0, 16, (1, 6)),
|
| 143 |
+
"ctx_teacher_input_ids": torch.empty(0, dtype=torch.long),
|
| 144 |
+
"sdpo_loss_mask": torch.zeros(1, 6),
|
| 145 |
+
}
|
| 146 |
+
loss = obj._compute_sdpo_loss(model, inputs)
|
| 147 |
+
assert float(loss.detach()) == 0.0
|
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ComposerGRPOTrainer ⊕ SDPO live smoke (ADR-008 gate 3).
|
| 2 |
+
|
| 3 |
+
Instantiates a REAL `trl.GRPOTrainer` via `ComposerReplicationTrainer`, configured
|
| 4 |
+
to the Dr. GRPO recipe (`make_dr_grpo_config`), on a tiny model, and runs a
|
| 5 |
+
short training run with `alpha_sdpo>0` so the SDPO channel is live on top of the
|
| 6 |
+
Dr. GRPO policy-gradient loss.
|
| 7 |
+
|
| 8 |
+
This is the wrapper-level proof. The loss-composition CORE (compose_loss forward
|
| 9 |
+
+ backward + optimizer.step with the SDPO JSD firing on real traces) is already
|
| 10 |
+
proven CPU-only by `examples/sdpo_real_trace_train_smoke/run.py`. This script
|
| 11 |
+
proves the same SDPO channel survives inside a live TRL GRPO rollout→update loop.
|
| 12 |
+
|
| 13 |
+
Heavy + slow on CPU (TRL import alone is ~140s; GRPO generation on CPU is slow).
|
| 14 |
+
RUN DETACHED so a gateway restart can't reap it:
|
| 15 |
+
|
| 16 |
+
systemd-run --user --scope -p MemoryMax=28G -- \
|
| 17 |
+
bash -lc 'cd <repo> && source .venv/bin/activate && \
|
| 18 |
+
python examples/composer_grpo_sdpo_smoke/run.py > /tmp/grpo_smoke.log 2>&1; \
|
| 19 |
+
echo EXIT=$? >> /tmp/grpo_smoke.log; touch /tmp/grpo_smoke.done'
|
| 20 |
+
|
| 21 |
+
Gates asserted:
|
| 22 |
+
- trainer instantiates with the Dr. GRPO config (loss_type=dr_grpo,
|
| 23 |
+
scale_rewards=none, num_iterations=1) and alpha_sdpo>0;
|
| 24 |
+
- a training step runs without crashing;
|
| 25 |
+
- total loss is finite;
|
| 26 |
+
- the SDPO channel is wired (loss/sdpo_kl logged) — value may be 0.0 if the
|
| 27 |
+
tiny synthetic rollouts happen to produce no error-aligned batch, which is
|
| 28 |
+
acceptable for the WRAPPER smoke (signal-firing is proven elsewhere).
|
| 29 |
+
|
| 30 |
+
Exit 0 = PASS, 1 = FAIL, 2 = SKIP (model/TRL unavailable).
|
| 31 |
+
"""
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import os
|
| 35 |
+
import sys
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main() -> int:
|
| 39 |
+
os.environ.setdefault("HF_HUB_OFFLINE", "1")
|
| 40 |
+
os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
|
| 41 |
+
os.environ.setdefault("TRL_USE_VLLM", "0")
|
| 42 |
+
os.environ.setdefault("OMP_NUM_THREADS", "8")
|
| 43 |
+
|
| 44 |
+
model_id = os.environ.get("SMOKE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
import torch # noqa: F401
|
| 48 |
+
from datasets import Dataset
|
| 49 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 50 |
+
|
| 51 |
+
from composer_replication.trainer.composer_trainer import (
|
| 52 |
+
ComposerReplicationTrainer,
|
| 53 |
+
make_dr_grpo_config,
|
| 54 |
+
)
|
| 55 |
+
except Exception as e: # noqa: BLE001
|
| 56 |
+
print(f"SKIP: import failed: {e!r}")
|
| 57 |
+
return 2
|
| 58 |
+
|
| 59 |
+
print(f"[grpo-smoke] loading {model_id} (CPU) — slow ...")
|
| 60 |
+
try:
|
| 61 |
+
tok = AutoTokenizer.from_pretrained(model_id)
|
| 62 |
+
if tok.pad_token is None:
|
| 63 |
+
tok.pad_token = tok.eos_token
|
| 64 |
+
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 65 |
+
except Exception as e: # noqa: BLE001
|
| 66 |
+
print(f"SKIP: model/tokenizer load failed: {e!r}")
|
| 67 |
+
return 2
|
| 68 |
+
|
| 69 |
+
# Trivial verifiable reward: reward length-1 presence of a digit (toy).
|
| 70 |
+
def reward_has_digit(completions, **kwargs):
|
| 71 |
+
return [1.0 if any(c.isdigit() for c in (t or "")) else 0.0 for t in completions]
|
| 72 |
+
|
| 73 |
+
# Tiny prompt dataset.
|
| 74 |
+
prompts = [{"prompt": "Reply with a number:"}, {"prompt": "Count to three:"}]
|
| 75 |
+
ds = Dataset.from_list(prompts)
|
| 76 |
+
|
| 77 |
+
cfg = make_dr_grpo_config(
|
| 78 |
+
output_dir="/tmp/grpo_smoke_out",
|
| 79 |
+
per_device_train_batch_size=2,
|
| 80 |
+
num_generations=2,
|
| 81 |
+
max_completion_length=8,
|
| 82 |
+
max_prompt_length=32,
|
| 83 |
+
max_steps=1,
|
| 84 |
+
logging_steps=1,
|
| 85 |
+
report_to=[],
|
| 86 |
+
beta=0.0, # drop KL-to-ref for the smoke (no ref model load)
|
| 87 |
+
use_vllm=False,
|
| 88 |
+
)
|
| 89 |
+
print(f"[grpo-smoke] Dr.GRPO config: loss_type={cfg.loss_type} "
|
| 90 |
+
f"scale_rewards={cfg.scale_rewards} num_iterations={cfg.num_iterations}")
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
trainer = ComposerReplicationTrainer(
|
| 94 |
+
model=model,
|
| 95 |
+
reward_funcs=reward_has_digit,
|
| 96 |
+
args=cfg,
|
| 97 |
+
train_dataset=ds,
|
| 98 |
+
processing_class=tok,
|
| 99 |
+
# SDPO channel ON. The toy rollouts won't carry collator-built
|
| 100 |
+
# ctx_teacher_input_ids, so _compute_sdpo_loss returns 0 (no error
|
| 101 |
+
# sites) — but the channel is WIRED and logged. strict=False so the
|
| 102 |
+
# absence of error sites is a clean no-op, not an abort.
|
| 103 |
+
alpha_sdpo=1.0,
|
| 104 |
+
strict_sdpo_alignment=False,
|
| 105 |
+
)
|
| 106 |
+
except Exception as e: # noqa: BLE001
|
| 107 |
+
print(f"FAIL: trainer instantiation failed: {e!r}")
|
| 108 |
+
import traceback
|
| 109 |
+
traceback.print_exc()
|
| 110 |
+
return 1
|
| 111 |
+
|
| 112 |
+
print("[grpo-smoke] trainer instantiated; running 1 Dr. GRPO step "
|
| 113 |
+
"with alpha_sdpo=1.0 ...")
|
| 114 |
+
try:
|
| 115 |
+
trainer.train()
|
| 116 |
+
except Exception as e: # noqa: BLE001
|
| 117 |
+
print(f"FAIL: train() crashed: {e!r}")
|
| 118 |
+
import traceback
|
| 119 |
+
traceback.print_exc()
|
| 120 |
+
return 1
|
| 121 |
+
|
| 122 |
+
# If we got here, the live loop ran with the SDPO channel wired in.
|
| 123 |
+
log_history = getattr(trainer.state, "log_history", [])
|
| 124 |
+
sdpo_logged = any("loss/sdpo_kl" in row for row in log_history)
|
| 125 |
+
print("=" * 60)
|
| 126 |
+
print(f" trainer ran 1 Dr. GRPO step: OK")
|
| 127 |
+
print(f" loss/sdpo_kl present in log_history: {sdpo_logged}")
|
| 128 |
+
print(f" RESULT: PASS ✅ (SDPO channel wired into live Dr. GRPO loop)")
|
| 129 |
+
return 0
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == "__main__":
|
| 133 |
+
sys.exit(main())
|