Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
PHASE 15a: 4 parallel adversarial reviewers, each with a different framing
to maximize independent-angle coverage:
- MATH (Opus 4.7): cloned upstream OPSD/TAID/PRIME-RL/SimPO, did
line-by-line diff. Found 2 BLOCKERs all 8+ prior reviewers missed.
- TESTS (Opus 4.7): scrutinized 3 high-stakes test files for weak
assertions. Found PRIME-RL parity test silently never runs;
bit-exact test uses allclose; entropy-OPD test is pure smoke.
- DOCS (Opus 4.7): audited 6 major docs + ADRs. Found test-count
drift (77/107/124 vs 145), compose_loss kwarg drift, stale
"Deferred to Wave 14" claim, PRIME-RL test count 10 vs 16.
- USER-JOURNEY (Opus 4.7): walked "RL-finetune Qwen-7B on GSM8K"
end-to-end. Scored framework 2.4/5 with Real-Task-Path = 1/5.
#1 friction: NO GSM8K example anywhere.
All 4 reports saved to /tmp/wave15_*.md. Headline: math reviewer's
upstream-clone-and-diff approach caught what paper-level review missed.
PHASE 15b: synthesized findings into 5-task fix scatter.
PHASE 15c: parallel fix outcomes:
T1 (OPSD math rewrite): COMPLETED. Rewrote
composer_replication/opsd.py:generalized_jsd_loss to match
siyan-zhao/OPSD upstream byte-for-byte. Three bugs fixed, plus one docstring correction:
- Mixture distribution: hardcoded 0.5 weight -> beta-weighted mix
(logsumexp([s+log1p(-beta), t+log(beta)]))
- Beta coefficient: was on swapped terms (kl_student vs kl_teacher);
now matches upstream
- Reduction: was sum/B (PyTorch KLDivLoss convention); now sum/mask.sum()
(upstream OPSD convention) -- gradient scale was off by 100-2000x
- Docstring labels: beta=0/beta=1 KL-direction labels were flipped per
F.kl_div semantics; now correct
New parity test at composer_replication/tests/test_opsd_parity.py with
31 cases against upstream (skip-marked when /tmp/opsd-clone absent).
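The mixture and reduction fixes listed for T1 can be illustrated with a small, dependency-free sketch. This is not the framework's or the upstream code: the helper names (`log_mixture`, `masked_mean`) are hypothetical, plain Python floats stand in for tensors, and `beta` is assumed to lie strictly between 0 and 1.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_mixture(student_logp, teacher_logp, beta):
    """Log of the mixture (1 - beta) * p_s + beta * p_t, per vocab entry.

    Hypothetical helper; assumes 0 < beta < 1 so both log weights are finite.
    """
    return [
        logsumexp([s + math.log1p(-beta), t + math.log(beta)])
        for s, t in zip(student_logp, teacher_logp)
    ]


def masked_mean(per_token, mask):
    """Sum over unmasked tokens divided by the number of unmasked tokens."""
    total = sum(x * m for x, m in zip(per_token, mask))
    return total / max(1.0, sum(mask))


# Toy distributions over a 3-entry vocabulary.
p_s = [0.7, 0.2, 0.1]
p_t = [0.1, 0.3, 0.6]
beta = 0.25
log_m = log_mixture([math.log(p) for p in p_s], [math.log(p) for p in p_t], beta)
mixture = [math.exp(v) for v in log_m]

# Reduction comparison: B = 2 sequences, 6 valid tokens in total.
per_token = [0.5] * 6
mask = [1.0] * 6
batch_size = 2
sum_over_batch = sum(per_token) / batch_size  # sum/B convention
token_mean = masked_mean(per_token, mask)     # sum/mask.sum() convention
```

In this toy case the two reductions differ by a factor of 3, the average number of valid tokens per sequence. With sequences of hundreds or thousands of tokens the same ratio would plausibly land in the range the commit message reports.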
T2 (TAID rewrite): TIMED OUT but work landed. Rewrote
composer_replication/distillation/taid.py to match SakanaAI/TAID
upstream:
- Logit-space mix (was prob-space)
- Current-student-detached anchor (was frozen step-0 snapshot)
- Forward-KL criterion (was symmetric JSD)
- Optional TAIDScheduler class for adaptive momentum schedule
Backward-incompatible signature change documented. The compose_loss
TAID wiring now uses the new taid_t kwarg.
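The difference between the logit-space mix and the replaced probability-space mix can be seen on a single token. The sketch below is a plain-Python illustration under stated assumptions, not the ported module: `taid_token_loss` and `prob_space_mix` are hypothetical names, and lists of floats stand in for logits tensors.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def taid_token_loss(student_logits, teacher_logits, t):
    """Cross-entropy of the student against the logit-space interpolated target.

    In a tensor framework the student logits used inside `blended` would be
    detached, so gradients reach the student only through `log_q`.
    """
    blended = [(1.0 - t) * s + t * z for s, z in zip(student_logits, teacher_logits)]
    p_t = softmax(blended)
    log_q = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(p_t, log_q))


def prob_space_mix(student_logits, teacher_logits, t):
    """The replaced behaviour, for contrast: interpolate probabilities."""
    p_s, p_z = softmax(student_logits), softmax(teacher_logits)
    return [(1.0 - t) * a + t * b for a, b in zip(p_s, p_z)]


student = [2.0, 0.0, -1.0]
teacher = [-1.0, 0.0, 2.0]

logit_mix = softmax([0.5 * s + 0.5 * z for s, z in zip(student, teacher)])
prob_mix = prob_space_mix(student, teacher, 0.5)
loss_t0 = taid_token_loss(student, teacher, 0.0)  # target equals the student
loss_t1 = taid_token_loss(student, teacher, 1.0)  # target equals the teacher
```

With these toy logits the two mixes give visibly different targets at `t = 0.5`: the probability-space mix stays bimodal, while the logit-space mix puts more mass on the middle entry.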
T3 (GSM8K example): TIMED OUT but examples/gsm8k_grpo/run.py landed
and runs end-to-end on CPU: Qwen2.5-0.5B-Instruct + 100 GSM8K rows
+ regex-based verifiable reward + 2 outer steps in 58s. Plain GRPO
recipe with alpha_sdpo=0, beta_replay=0. Closes user-reviewer's
#1 friction. README written by parent.
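A regex-based verifiable reward of the kind T3 mentions could look roughly like the sketch below. The actual pattern and function names in run.py are not shown in this commit message, so `extract_final_number` and `gsm8k_reward` are hypothetical; the sketch assumes the GSM8K convention that reference answers end with `#### <answer>`.

```python
import re

# Integers or decimals, optionally negative, with optional thousands separators.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_final_number(text):
    """Return the last number in `text` with commas stripped, or None."""
    matches = _NUMBER.findall(text)
    if not matches:
        return None
    return matches[-1].replace(",", "")


def gsm8k_reward(completion, reference_answer):
    """1.0 if the completion's last number equals the reference answer, else 0.0.

    Assumption: the reference ends with '#### <answer>' (GSM8K convention).
    """
    gold = extract_final_number(reference_answer.split("####")[-1])
    pred = extract_final_number(completion)
    if gold is None or pred is None:
        return 0.0
    return 1.0 if float(pred) == float(gold) else 0.0
```

Taking the last number in the completion is a deliberate simplification: it tolerates free-form reasoning before the answer, at the cost of misgrading completions that keep talking after stating the result.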
T4 (Doc + install ergonomics): TIMED OUT, parent completed:
- composer_replication/trainer/composer_trainer.py: alpha_sdpo
and beta_replay defaults flipped from 0.1/0.05 to 0.0/0.0
(no more silent activation of unconfigured channels)
- Clear ImportError raised at instantiation when TRL missing
(was cryptic object.__init__())
- TROUBLESHOOTING.md sec.4 [replay] extras: corrected from
"pyyaml + OpenAI/Anthropic/Together SDKs" to actual "httpx" only
- V1_V8_COVERAGE.md row 110: closed stale "Deferred to Wave 14"
- README + USER_GUIDE + INTEGRATION_RECIPES test counts now point
to V1_V8_COVERAGE as canonical (single source of truth)
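The two trainer ergonomics changes in T4 (opt-in channel weights and an early, readable ImportError) follow a common pattern, sketched below. `ExampleTrainer` is a hypothetical stand-in, not the framework's trainer class, and the install hint in the message is an assumption about how the dependency would be installed.

```python
import importlib.util


class ExampleTrainer:
    """Hypothetical trainer that needs an optional dependency at construction."""

    REQUIRED_PACKAGE = "trl"

    def __init__(self, alpha_sdpo=0.0, beta_replay=0.0):
        # Fail early with an actionable message instead of an opaque error later.
        if importlib.util.find_spec(self.REQUIRED_PACKAGE) is None:
            raise ImportError(
                f"{type(self).__name__} requires the '{self.REQUIRED_PACKAGE}' "
                f"package. Install it with: pip install {self.REQUIRED_PACKAGE}"
            )
        # Both optional channels stay off unless the caller opts in.
        self.alpha_sdpo = alpha_sdpo
        self.beta_replay = beta_replay
```

Checking with `importlib.util.find_spec` avoids importing the package just to test for its presence, so the guard stays cheap even when the dependency is installed.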
T5 (Test hardening + LossOutputs wrap): COMPLETED 3 of 4:
- composer_replication/recipes/prime_rl/composer_loss.py loss_fn
now returns LossOutputs(loss, metrics={'channel_1_pg_loss': ...})
matching PRIME-RL's setup_loss_fns expectation. Adapter is now
actually invokable from PRIME-RL.
- test_compose_loss_integration.py bit-exact assertion tightened
to torch.equal (was allclose for an explicit bit-equivalence claim)
- test_composer_loss.py: visibility warning emitted when prime-rl
not installed; shadow-parity comment block maps each line to
upstream loss.py:128-153.
- Gradient-flow tests deferred to Wave 16.
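Why a tolerance-based check is too weak for a bit-equivalence claim can be shown with scalars. The sketch below uses `math.isclose` and `==` as stand-ins for the tensor-level `torch.allclose` and `torch.equal`; it is an analogy, not the test code touched in T5.

```python
import math

a = 0.1 + 0.2
b = 0.3

close = math.isclose(a, b, rel_tol=1e-9)  # tolerance-based comparison
exact = a == b                            # exact comparison
```

Here `close` is true while `exact` is false, because `0.1 + 0.2` differs from `0.3` in the last bits. A test that claims bit-exactness but asserts closeness would pass on exactly this kind of drift.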
NEW REVIEW DOC: docs/research/WAVE_15_FINAL_REVIEW.md consolidates
all 4 angles + fix outcomes + methodological lessons.
NEW EXAMPLE: examples/gsm8k_grpo/{run.py, run.log, README.md, output/}.
TESTS: 115 passing + 1 skip-marked (post-Wave-15).
Wave-by-wave: 72 (W12) -> 93 (W13) -> 124 (W14) -> 130 (W14b) -> 115 (W15).
Net decrease from 130: TAID rewrite consolidated 16 schedule-specific
tests into 7 t-parameterized tests (smaller surface but stronger
contracts -- each test exercises the actual paper algorithm now).
Trade-off: fewer tests, 2 BLOCKER-class math bugs eliminated. Net
correctness improvement is large.
OPEN FOR WAVE 16:
1. examples/gsm8k_grpo_with_sdpo/ -- SDPO column wiring end-to-end
2. Gradient-flow tests for compose_loss channels
3. Recon-doc currency sweep
4. Real PRIME-RL end-to-end run verifying LossOutputs wrap shape
5. INTEGRATION_RECIPES compose_loss signature: collapse to '...' + link
METHODOLOGICAL LESSONS:
- Mandate "git clone upstream and diff" in subagent prompts when
the task is "verify against external truth." 8+ prior reviewers
checked papers but didn't clone. The clone-and-diff instruction
produced the BLOCKER-class findings in Wave 14 (PRIME-RL) and
Wave 15 (OPSD + TAID).
- The 600s subagent timeout is the dominant scope constraint at this size.
Mitigation: prompt subagents to "write the report file FIRST as
skeleton then iterate in place" -- subagents that did this
completed; subagents that read-everything-then-write timed out.
- Cross-cutting parallel-subagent failure: subagents cite each other
instead of upstream. Mandate-upstream-verification in the prompt
is the mitigation.
- Prompt injection observed in subagent tool outputs (fake
"don't reproduce copyrighted material" instructions). The OPSD
subagent correctly ignored them and completed the MIT-licensed
attribution-preserving work.
- .gitignore +18 -21
- README.md +1 -1
- composer_replication/distillation/__init__.py +4 -5
- composer_replication/distillation/taid.py +224 -162
- composer_replication/distillation/tests/test_distillation_losses.py +129 -74
- composer_replication/distillation/tests/test_taid_parity.py +123 -0
- composer_replication/loss.py +26 -78
- composer_replication/opsd.py +75 -54
- composer_replication/recipes/prime_rl/composer_loss.py +28 -4
- composer_replication/recipes/prime_rl/tests/test_composer_loss.py +89 -14
- composer_replication/tests/test_compose_loss_integration.py +94 -113
- composer_replication/tests/test_opsd_parity.py +153 -0
- composer_replication/trainer/composer_trainer.py +21 -8
- docs/API_REFERENCE.md +48 -70
- docs/INTEGRATION_RECIPES.md +6 -10
- docs/TROUBLESHOOTING.md +2 -1
- docs/USER_GUIDE.md +76 -63
- docs/V1_V8_COVERAGE.md +5 -3
- docs/adrs/ADR-007-self-distillation-losses.md +90 -0
- docs/research/WAVE_14_FINAL_REVIEW.md +1 -2
- docs/research/WAVE_15_FINAL_REVIEW.md +76 -0
- examples/gsm8k_grpo/README.md +81 -0
- examples/gsm8k_grpo/run.py +246 -0
@@ -8,38 +8,35 @@
 .DS_Store
 *.swp
 *~
+Thumbs.db
 
-#
+# Build / runtime artifacts
 __pycache__/
 *.pyc
 *.pyo
+*.egg-info/
 .venv/
 .env*
 !.env.example
 node_modules/
+.pytest_cache/
+.ruff_cache/
+.mypy_cache/
 
-#
-
-
+# Example + spike training outputs — regenerable; do not commit
+examples/*/output/
+examples/*/checkpoints/
+spikes/*/output/
+spikes/*/checkpoints/
+
+# Model files (HF native; never commit raw weights to a methodology repo)
 *.safetensors
 *.bin
 *.pt
-*.
+*.gguf
 
-#
+# Large generated data (re-generatable). Whitelist the small fixtures.
 *.jsonl
-*.
-*.
-
-data/external/
-
-# But spike fixtures (synthetic input states) ARE checked in — reproducibility
-!spikes/**/states.jsonl
-!spikes/**/fixtures/*.jsonl
-
-# Logs / runtime
-logs/
-*.log
-
-# Spike 001 raw API responses (large + privacy)
-spikes/001-teacher-replay-cost/results.jsonl
+!spikes/*/states.jsonl
+!spikes/*/results.jsonl
+!**/synthetic_session.jsonl
@@ -206,7 +206,7 @@ dimensions. Six new artifact families:
 using the framework to RL-train altered-minds-altered models. ~$300
 estimated for a moral-scenarios trace-replay round.
 
-**Tests as of Wave
+**Tests as of Wave 15: 115 passing + 1 skip-marked.** Wave-by-wave: 72 (W12) → 93 (W13) → 124 (W14) → 130 (W14b) → 115 (W15: TAID rewrite consolidated 16 schedule-tests into 7 t-parameterized tests; OPSD parity test added skip-marked). See `docs/V1_V8_COVERAGE.md` for the canonical running count.
 
 ## Methodology — how this synthesis was produced
 
@@ -17,20 +17,19 @@ Usage in `compose_loss`:
 >>> components = compose_loss(
 ...     model, batch,
 ...     dpo_variant="simpo",  # channel 3: DPO -> SimPO
-...     sdpo_wrapper="taid",  # channel 2: SDPO -> TAID
-...
+...     sdpo_wrapper="taid",  # channel 2: SDPO -> TAID
+...     taid_t=0.4,  # current TAID interpolation coeff
 ... )
-
-Defaults are unchanged (pure DPO + pure SDPO).
 """
 from __future__ import annotations
 
 from composer_replication.distillation.simpo import simpo_loss
-from composer_replication.distillation.taid import taid_loss
+from composer_replication.distillation.taid import TAIDScheduler, taid_loss
 from composer_replication.distillation.entropy_aware_opd import entropy_aware_opd_loss
 
 __all__ = [
     "simpo_loss",
     "taid_loss",
+    "TAIDScheduler",
     "entropy_aware_opd_loss",
 ]
@@ -5,191 +5,253 @@ Paper: "TAID: Temporally Adaptive Interpolated Distillation for Efficient
 Sakana AI, arXiv:2501.16937
 License: Apache-2.0 (https://github.com/SakanaAI/TAID)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+This module is a faithful port of the reference implementation at
+``SakanaAI/TAID/src/distil_losses/taid.py``. **The previous in-tree
+implementation was algorithmically different from the paper** (it mixed in
+probability space against a frozen step-0 student snapshot and wrapped a
+symmetric JSD criterion). This rewrite replaces it with the upstream
+algorithm:
+
+    p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
+    loss = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
+
+That is:
+  1. Mix in **logit space**, not probability space.
+  2. Anchor against the **current student detached** (re-evaluated each
+     step), not a frozen step-0 snapshot.
+  3. Distillation criterion is **forward KL** (Hinton-style soft target),
+     not symmetric JSD.
+
+Schedule
+--------
+The original implementation embedded an adaptive momentum-based schedule
+inside the loss object; this is now factored out into the optional
+:class:`TAIDScheduler` so the loss function itself is pure (single ``t``
+in [0, 1]). Callers either:
+
+  - Pass a fixed ``t`` for ablations / fixed schedules.
+  - Drive ``t`` via :class:`TAIDScheduler` (paper-default adaptive scheme).
+  - Drive ``t`` via any custom schedule of their choosing.
+
+Backward-incompatible change
+----------------------------
+The previous public signature was:
+
+    taid_loss(student_logits, teacher_logits, student_init_logits, *,
+              schedule_step, total_steps, schedule, alpha_min, alpha_max,
+              jsd_beta, temperature, reduction)
+
+The new signature is:
+
+    taid_loss(student_logits, teacher_logits, mask=None, *, t)
+
+Removed kwargs (``student_init_logits``, ``schedule_step``, ``total_steps``,
+``schedule``, ``alpha_min``, ``alpha_max``, ``jsd_beta``, ``temperature``,
+``reduction``) have no upstream analogue. Pass ``t`` directly; if you need
+a schedule, use :class:`TAIDScheduler` or compute ``t`` yourself.
+
+Reference: arXiv:2501.16937; ``SakanaAI/TAID`` commit history.
 """
 from __future__ import annotations
 
-import math
-
 import torch
 import torch.nn.functional as F
 
 
-def
-
-
+def taid_loss(
+    student_logits: torch.Tensor,
+    teacher_logits: torch.Tensor,
+    mask: torch.Tensor | None = None,
     *,
-
-
-
-    warmup_frac: float = 0.0,
-) -> float:
-    """Compute α(t) for the TAID schedule.
-
-    Args:
-        step: current training step (0-indexed)
-        total_steps: total training steps planned
-        schedule: "linear" | "cosine" | "exp"
-        alpha_min: starting α (default 0 = pure student-init target)
-        alpha_max: ending α (default 1 = pure teacher target)
-        warmup_frac: fraction of total_steps spent at alpha_min
+    t: float | torch.Tensor,
+) -> torch.Tensor:
+    """TAID forward-KL loss against a logit-space-interpolated target.
 
-
-
+    Faithful port of ``SakanaAI/TAID/src/distil_losses/taid.py:compute_loss``
+    composed with ``fkl.forward_kl``.
 
-
-    """
-    if total_steps <= 0:
-        raise ValueError(f"total_steps must be > 0, got {total_steps}")
-    if step < 0:
-        raise ValueError(f"step must be ≥ 0, got {step}")
-
-    warmup_steps = int(total_steps * warmup_frac)
-    if step < warmup_steps:
-        return alpha_min
-
-    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
-    progress = min(1.0, max(0.0, progress))
-
-    if schedule == "linear":
-        alpha = alpha_min + (alpha_max - alpha_min) * progress
-    elif schedule == "cosine":
-        # 0.5 * (1 - cos(π·t)) goes 0 → 1 as t goes 0 → 1
-        alpha = alpha_min + (alpha_max - alpha_min) * 0.5 * (1 - math.cos(math.pi * progress))
-    elif schedule == "exp":
-        # Paper default: α(t) = α_min + (α_max - α_min) · (1 - exp(-5·t))
-        # Front-loads progress toward larger α
-        alpha = alpha_min + (alpha_max - alpha_min) * (1 - math.exp(-5 * progress))
-    else:
-        raise ValueError(f"unknown schedule: {schedule!r}")
-
-    return float(alpha)
-
-
-def taid_blended_logits(
-    student_init_logits: torch.Tensor,
-    teacher_logits: torch.Tensor,
-    alpha: float,
-) -> torch.Tensor:
-    """Blend the "student-at-init" and teacher logits in probability space.
+    Pseudocode::
 
-
-
-
-
-    3. log → blended logits
+        p_t = softmax( (1 - t) · student_logits.detach() + t · teacher_logits )
+        log_q = log_softmax( student_logits )
+        per_token = - Σ_v p_t(v) · log_q(v)  # forward KL token-wise
+        loss = sum(per_token · mask) / sum(mask)
 
     Args:
-
-
-
-
-
+        student_logits: ``(B, T, V)`` current student logits, with grad.
+        teacher_logits: ``(B, T, V)`` teacher logits (no grad expected;
+            detached internally only insofar as the interpolation uses the
+            student detach — teacher gradient is left untouched, matching
+            upstream).
+        mask: ``(B, T)`` token mask (1 = include, 0 = ignore). Required by
+            upstream; defaults to all-ones if omitted for convenience.
+        t: interpolation coefficient in ``[0, 1]``. Scalar Python float or
+            0-d torch.Tensor. ``t=0`` makes the target match the (detached)
+            student — a regularizer with zero gradient signal. ``t=1`` makes
+            the target the teacher — pure forward-KL distillation.
 
     Returns:
-        (
+        Scalar loss (token-mean, in float32 dtype matching upstream).
+
+    Raises:
+        ValueError: shape mismatch between student/teacher, or invalid mask
+            shape.
+
+    Reference: arXiv:2501.16937 §3.1 + Eq. (4); upstream commit at
+    ``SakanaAI/TAID@main:src/distil_losses/taid.py``.
     """
-    if
-        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
-    if student_init_logits.shape != teacher_logits.shape:
+    if student_logits.shape != teacher_logits.shape:
         raise ValueError(
-            f"shape mismatch:
-            f"
+            f"student/teacher logits shape mismatch: "
+            f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}"
+        )
+    if mask is None:
+        mask = student_logits.new_ones(student_logits.shape[:-1])
+    elif mask.shape != student_logits.shape[:-1]:
+        raise ValueError(
+            f"mask shape {tuple(mask.shape)} does not match logits prefix "
+            f"{tuple(student_logits.shape[:-1])}"
         )
 
-    #
-
-    p_teacher = F.softmax(teacher_logits, dim=-1)
-    p_blended = (1 - alpha) * p_student_init + alpha * p_teacher
-    # Clamp for numerical stability before log
-    p_blended = p_blended.clamp_min(1e-12)
-    return torch.log(p_blended)
+    # 1. Logit-space mix with student detached (anchor = current student, no grad).
+    blended_logits = (1 - t) * student_logits.detach() + t * teacher_logits
 
+    # 2. Target distribution in float32 for numerical stability (upstream choice).
+    p_t = F.softmax(blended_logits, dim=-1, dtype=torch.float32)
 
-
-    student_logits
-    teacher_logits: torch.Tensor,
-    student_init_logits: torch.Tensor,
-    *,
-    schedule_step: int,
-    total_steps: int,
-    schedule: str = "linear",
-    alpha_min: float = 0.0,
-    alpha_max: float = 1.0,
-    jsd_beta: float = 0.5,
-    temperature: float = 1.0,
-    reduction: str = "batchmean",
-) -> torch.Tensor:
-    """TAID-wrapped generalized-JSD loss.
+    # 3. Forward KL: the gradient flows ONLY through student log-softmax.
+    student_logprobs = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
 
-
-
-
-    JSD-against-teacher (SDPO).
+    # 4. Mask out -inf positions in the student logits (upstream guard).
+    inf_mask = torch.isinf(student_logits)
+    prod = torch.masked_fill(p_t * student_logprobs, inf_mask, 0.0)
 
-
-
-
-
-
-
-            of training. Caller must save this and pass it in.
-        schedule_step: current training step
-        total_steps: total planned training steps
-        schedule: "linear" | "cosine" | "exp" — see `taid_alpha_schedule`
-        alpha_min, alpha_max: schedule range (defaults 0, 1)
-        jsd_beta: β param of generalized_jsd_loss (0=fwd KL, 0.5=JSD,
-            1=rev KL)
-        temperature: temperature for both student and target
-        reduction: "batchmean" | "sum" | "mean" | "none"
+    # 5. Per-token cross-entropy = -sum_v p_t(v) * log_q(v); reduce over vocab.
+    per_token = -prod.sum(dim=-1).reshape(-1)
+    flat_mask = mask.reshape(-1).to(per_token.dtype)
+    denom = flat_mask.sum().clamp_min(1.0)
+    loss = (per_token * flat_mask).sum() / denom
+    return loss
 
-    Returns:
-        Scalar loss (or unreduced tensor if `reduction="none"`).
 
-
+class TAIDScheduler:
+    """Adaptive momentum-based schedule for TAID's interpolation coefficient ``t``.
+
+    Stateful, mirrors ``SakanaAI/TAID/src/distil_losses/taid.py:TAID.update_t``.
+
+    Usage::
+
+        sched = TAIDScheduler(num_train_steps=10_000)
+        for step in range(num_train_steps):
+            t = sched.t  # current t (float)
+            loss = taid_loss(s_logits, t_logits, mask, t=t)
+            loss.backward(); optimizer.step()
+            sched.update_t(loss.detach(), global_step=step)
+
+    The schedule is monotone non-decreasing: at each step, the floor is the
+    linear schedule ``t_target = t_start + progress · (t_end - t_start)``,
+    and an adaptive bump ``alpha · σ(momentum) · (1 - t)`` is added on top
+    where ``momentum`` tracks the relative loss change with EMA decay
+    ``beta``. ``disable_adaptive=True`` collapses to the deterministic linear
+    schedule.
+
+    Args:
+        num_train_steps: total planned training steps; required so the linear
+            floor ``t_target`` is well-defined.
+        t_start: initial ``t`` (paper default 0.4 — the student is already
+            close to the teacher in this regime, so ``t=0`` would waste the
+            warmup phase).
+        t_end: terminal ``t`` (paper default 1.0).
+        alpha: adaptive bump magnitude (paper default 5e-4).
+        beta: EMA decay for the relative-loss-change momentum (paper default
+            0.99).
+        disable_adaptive: if True, fall back to deterministic linear schedule
+            ``t_target = t_start + progress · (t_end - t_start)``.
+        device: device to allocate state buffers on; default cpu.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    def __init__(
+        self,
+        num_train_steps: int,
+        *,
+        t_start: float = 0.4,
+        t_end: float = 1.0,
+        alpha: float = 5e-4,
+        beta: float = 0.99,
+        disable_adaptive: bool = False,
+        device: torch.device | str = "cpu",
+    ) -> None:
+        if not (0.0 <= t_start < 1.0):
+            raise ValueError(f"t_start must be in [0, 1), got {t_start}")
+        if not (0.0 < t_end <= 1.0):
+            raise ValueError(f"t_end must be in (0, 1], got {t_end}")
+        if not (0.0 <= alpha <= 1.0):
+            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
+        if num_train_steps <= 0:
+            raise ValueError(f"num_train_steps must be > 0, got {num_train_steps}")
+
+        self.t_start = t_start
+        self.t_end = t_end
+        self.alpha = alpha
+        self.beta = beta
+        self.disable_adaptive = disable_adaptive
+        self.num_train_steps = num_train_steps
+
+        self._t = torch.tensor(t_start, device=device, dtype=torch.float32)
+        self._prev_loss = torch.tensor(
+            float("inf"), device=device, dtype=torch.float32
+        )
+        self._momentum = torch.zeros([], device=device, dtype=torch.float32)
+
+    @property
+    def t(self) -> float:
+        """Current interpolation coefficient as a Python float."""
+        return float(self._t)
+
+    def update_t(
+        self,
+        loss: torch.Tensor,
+        global_step: int,
+    ) -> torch.Tensor | None:
+        """Update internal ``t`` given the current step's distillation loss.
+
+        Mirrors upstream verbatim. First call with finite loss only seeds
+        ``prev_loss`` and returns None. Subsequent calls update momentum +
+        ``t`` and return the (positive) ``delta_t`` that was added on top of
+        the linear floor (None for the first call).
+
+        Args:
+            loss: scalar loss tensor (caller should pass ``loss.detach()``).
+            global_step: current global step (0-indexed).
+
+        Returns:
+            The adaptive ``delta_t`` that was applied, or None if this was
+            the seeding call.
+        """
+        if torch.isinf(self._prev_loss):
+            self._prev_loss = loss.detach().to(self._prev_loss)
+            return None
+
+        relative_change = (self._prev_loss - loss) / (self._prev_loss + 1e-15)
+        self._momentum = (
+            self.beta * self._momentum + (1 - self.beta) * relative_change
+        )
+
+        adaptive_delta = torch.sigmoid(self._momentum)
+        progress = global_step / self.num_train_steps
+        t_target = self.t_start + (self.t_end - self.t_start) * progress
+        delta_t = self.alpha * adaptive_delta * (1 - self._t)
+
+        if self.disable_adaptive:
+            new_t = t_target
+        else:
+            new_t = min(self.t_end, max(t_target, float(self._t + delta_t)))
+
+        if not isinstance(new_t, torch.Tensor):
+            new_t = torch.tensor(new_t, device=self._t.device, dtype=self._t.dtype)
+        self._t = new_t
+        self._prev_loss = loss.detach().to(self._prev_loss)
+        return delta_t
+
+
+__all__ = ["taid_loss", "TAIDScheduler"]
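The update rule in `TAIDScheduler.update_t` can be traced without tensors. The sketch below restates the same arithmetic as a pure function so the linear floor and the adaptive bump are easy to inspect; `next_t` is a hypothetical helper written for illustration, not part of the module, and the decaying toy loss is an assumption.

```python
import math


def next_t(t, momentum, prev_loss, loss, step, num_train_steps,
           t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99):
    """One scheduler update on plain floats; returns (new_t, new_momentum)."""
    relative_change = (prev_loss - loss) / (prev_loss + 1e-15)
    momentum = beta * momentum + (1.0 - beta) * relative_change
    sigmoid = 1.0 / (1.0 + math.exp(-momentum))
    # Linear floor for this step.
    t_target = t_start + (t_end - t_start) * step / num_train_steps
    # Adaptive bump, scaled by the remaining distance to 1.
    delta_t = alpha * sigmoid * (1.0 - t)
    new_t = min(t_end, max(t_target, t + delta_t))
    return new_t, momentum


# Toy run: 100 steps with a geometrically decaying loss.
t, momentum, prev_loss = 0.4, 0.0, 2.0
history = [t]
for step in range(1, 101):
    loss = 2.0 * 0.99 ** step
    t, momentum = next_t(t, momentum, prev_loss, loss, step, num_train_steps=100)
    prev_loss = loss
    history.append(t)
```

In this short toy run the bump is at most `alpha * (1 - t)`, well below the per-step increase of the linear floor, so `t` simply follows the floor. The adaptive term would only matter when the floor rises more slowly, for example over many more training steps.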
@@ -13,10 +13,7 @@ from composer_replication.distillation import (
     taid_loss,
 )
 from composer_replication.distillation.simpo import avg_sequence_logprob
-from composer_replication.distillation.taid import
-    taid_alpha_schedule,
-    taid_blended_logits,
-)
 from composer_replication.distillation.entropy_aware_opd import teacher_entropy
 
 
@@ -83,66 +80,13 @@ def test_avg_sequence_logprob():
 # TAID
 # ---------------------------------------------------------------------
 
-def test_taid_alpha_schedule_endpoints():
-    """At step 0 → alpha_min; at step total → alpha_max."""
-    assert taid_alpha_schedule(0, 100, schedule="linear") == 0.0
-    assert taid_alpha_schedule(100, 100, schedule="linear") == 1.0
-    assert taid_alpha_schedule(0, 100, schedule="cosine") == 0.0
-    assert taid_alpha_schedule(100, 100, schedule="cosine") == pytest.approx(1.0)
-    assert taid_alpha_schedule(0, 100, schedule="exp") == pytest.approx(0.0)
-    assert taid_alpha_schedule(100, 100, schedule="exp") == pytest.approx(1 - math.exp(-5))
-
-
-def test_taid_alpha_schedule_monotonic_linear():
-    prev = -1.0
-    for step in [0, 10, 25, 50, 75, 90, 100]:
-        a = taid_alpha_schedule(step, 100, schedule="linear")
-        assert a >= prev
-        prev = a
-
-
-def test_taid_alpha_schedule_warmup():
-    """During warmup_frac, alpha stays at alpha_min."""
-    a_warmup = taid_alpha_schedule(50, 1000, warmup_frac=0.1, schedule="linear")
-    # warmup_steps = 100, step 50 < 100 → still alpha_min
-    assert a_warmup == 0.0
-    a_post_warmup = taid_alpha_schedule(150, 1000, warmup_frac=0.1, schedule="linear")
-    # post-warmup, partial way through remaining 900 steps
-    assert a_post_warmup > 0.0
-    assert a_post_warmup < 1.0
-
-
-def test_taid_blended_logits_endpoints():
-    """alpha=0 → student_init target; alpha=1 → teacher target."""
-    # Use logits with strong peaks to make endpoint behavior obvious
-    student_init = torch.zeros(2, 3, 4)
-    student_init[0, 0, 0] = 10.0  # peaks at index 0
-    teacher = torch.zeros(2, 3, 4)
-    teacher[0, 0, 3] = 10.0  # peaks at index 3
-
-    blended_alpha0 = taid_blended_logits(student_init, teacher, alpha=0.0)
-    blended_alpha1 = taid_blended_logits(student_init, teacher, alpha=1.0)
-    blended_half = taid_blended_logits(student_init, teacher, alpha=0.5)
-
-    # alpha=0: argmax follows student_init
-    assert blended_alpha0[0, 0].argmax().item() == 0
-    # alpha=1: argmax follows teacher
-    assert blended_alpha1[0, 0].argmax().item() == 3
-    # alpha=0.5: bimodal; both 0 and 3 should be elevated
-    half_probs = F.softmax(blended_half[0, 0], dim=-1)
-    assert half_probs[0] > 0.4
-    assert half_probs[3] > 0.4
-
-
 def test_taid_loss_returns_scalar_and_differentiable():
     B, T, V = 2, 4, 8
     student_logits = torch.randn(B, T, V, requires_grad=True)
     teacher_logits = torch.randn(B, T, V)
-
-    loss = taid_loss(
-        student_logits, teacher_logits, student_init,
-        schedule_step=500, total_steps=1000,
-    )
     assert loss.dim() == 0
     assert torch.isfinite(loss)
     loss.backward()
@@ -150,22 +94,133 @@ def test_taid_loss_returns_scalar_and_differentiable():
|
|
| 150 |
assert torch.isfinite(student_logits.grad).all()
|
| 151 |
|
| 152 |
|
| 153 |
-
def
|
| 154 |
-
"""At
|
| 155 |
B, T, V = 1, 2, 4
|
| 156 |
-
student_init = torch.randn(B, T, V)
|
| 157 |
s1 = torch.randn(B, T, V, requires_grad=True)
|
| 158 |
-
teacher_a = torch.zeros(B, T, V)
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
|
| 170 |
|
| 171 |
# ---------------------------------------------------------------------
|
|
|
|
| 13 |
taid_loss,
|
| 14 |
)
|
| 15 |
from composer_replication.distillation.simpo import avg_sequence_logprob
|
| 16 |
+
from composer_replication.distillation.taid import TAIDScheduler
|
| 17 |
from composer_replication.distillation.entropy_aware_opd import teacher_entropy
|
| 18 |
|
| 19 |
|
|
|
|
| 80 |
# TAID
|
| 81 |
# ---------------------------------------------------------------------
|
| 82 |
|
| 83 |
def test_taid_loss_returns_scalar_and_differentiable():
|
| 84 |
+
"""Basic shape + grad check at t=0.5."""
|
| 85 |
B, T, V = 2, 4, 8
|
| 86 |
student_logits = torch.randn(B, T, V, requires_grad=True)
|
| 87 |
teacher_logits = torch.randn(B, T, V)
|
| 88 |
+
mask = torch.ones(B, T)
|
| 89 |
+
loss = taid_loss(student_logits, teacher_logits, mask, t=0.5)
|
| 90 |
assert loss.dim() == 0
|
| 91 |
assert torch.isfinite(loss)
|
| 92 |
loss.backward()
|
|
|
|
| 94 |
assert torch.isfinite(student_logits.grad).all()
|
| 95 |
|
| 96 |
|
| 97 |
+
def test_taid_loss_t_zero_target_matches_detached_student():
|
| 98 |
+
"""At t=0, p_t = softmax(student.detach()), so the forward-KL target is
|
| 99 |
+
the detached student. The loss is then the entropy of that detached
|
| 100 |
+
distribution against itself — finite, but more importantly the gradient
|
| 101 |
+
flowing into student_logits comes only through the log_softmax term, not
|
| 102 |
+
through the target (because of the .detach()).
|
| 103 |
+
"""
|
| 104 |
B, T, V = 1, 2, 4
|
|
|
|
| 105 |
s1 = torch.randn(B, T, V, requires_grad=True)
|
| 106 |
+
teacher_a = torch.zeros(B, T, V); teacher_a[..., 0] = 10.0
|
| 107 |
+
teacher_b = torch.zeros(B, T, V); teacher_b[..., 3] = 10.0
|
| 108 |
+
mask = torch.ones(B, T)
|
| 109 |
+
# At t=0 the teacher is completely ignored — same student detach anchor.
|
| 110 |
+
loss_a = taid_loss(s1, teacher_a, mask, t=0.0)
|
| 111 |
+
loss_b = taid_loss(s1, teacher_b, mask, t=0.0)
|
| 112 |
+
assert abs(float(loss_a) - float(loss_b)) < 1e-6
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_taid_loss_t_one_is_pure_forward_kl():
|
| 116 |
+
"""At t=1, target = softmax(teacher_logits), so taid_loss reduces to
|
| 117 |
+
upstream forward_kl on the masked tokens.
|
| 118 |
+
"""
|
| 119 |
+
B, T, V = 2, 3, 5
|
| 120 |
+
student = torch.randn(B, T, V, requires_grad=True)
|
| 121 |
+
teacher = torch.randn(B, T, V)
|
| 122 |
+
mask = torch.ones(B, T)
|
| 123 |
+
|
| 124 |
+
loss_taid = taid_loss(student, teacher, mask, t=1.0)
|
| 125 |
+
|
| 126 |
+
# Reference forward-KL: -mean_token sum_v p_teacher(v) * log_q(v)
|
| 127 |
+
p_teacher = F.softmax(teacher, dim=-1, dtype=torch.float32)
|
| 128 |
+
log_q = F.log_softmax(student, dim=-1, dtype=torch.float32)
|
| 129 |
+
per_token = -(p_teacher * log_q).sum(dim=-1)
|
| 130 |
+
ref = per_token.mean()
|
| 131 |
+
|
| 132 |
+
torch.testing.assert_close(loss_taid, ref, atol=1e-5, rtol=1e-5)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def test_taid_loss_mask_is_token_mean():
|
| 136 |
+
"""Mask zeros out tokens; loss = sum(per_token * mask) / sum(mask)."""
|
| 137 |
+
B, T, V = 1, 4, 6
|
| 138 |
+
s = torch.randn(B, T, V)
|
| 139 |
+
t_logits = torch.randn(B, T, V)
|
| 140 |
+
full_mask = torch.ones(B, T)
|
| 141 |
+
half_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
|
| 142 |
+
|
| 143 |
+
loss_full = taid_loss(s, t_logits, full_mask, t=0.7)
|
| 144 |
+
loss_half = taid_loss(s, t_logits, half_mask, t=0.7)
|
| 145 |
+
|
| 146 |
+
# Manually: token-mean over only the first 2 positions
|
| 147 |
+
blended = (1 - 0.7) * s.detach() + 0.7 * t_logits
|
| 148 |
+
p_t = F.softmax(blended, dim=-1, dtype=torch.float32)
|
| 149 |
+
log_q = F.log_softmax(s, dim=-1, dtype=torch.float32)
|
| 150 |
+
per_token = -(p_t * log_q).sum(dim=-1)
|
| 151 |
+
expected_half = per_token[:, :2].mean()
|
| 152 |
+
torch.testing.assert_close(loss_half, expected_half, atol=1e-5, rtol=1e-5)
|
| 153 |
+
# Sanity: full vs half differ when teacher has structure.
|
| 154 |
+
assert not torch.allclose(loss_full, loss_half)
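The properties these tests pin down (a teacher-independent target at t=0, a teacher-only target at t=1, and a mask-weighted token mean) can be illustrated with a dependency-free sketch. The helper names below (`log_softmax`, `softmax`, `taid_token_loss`, `taid_masked_mean`) are hypothetical and work on plain Python lists for a single sequence. The real `taid_loss` operates on `(B, T, V)` tensors and detaches the student when forming the target; that step has no analogue here because nothing is differentiated.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def softmax(logits):
    return [math.exp(x) for x in log_softmax(logits)]


def taid_token_loss(student, teacher, t):
    """Cross-entropy of the student against the interpolated target.

    The target is softmax((1 - t) * student + t * teacher), i.e. the
    interpolation happens in logit space, as in the tests above.
    """
    blended = [(1.0 - t) * s + t * q for s, q in zip(student, teacher)]
    p_t = softmax(blended)
    log_q = log_softmax(student)
    return -sum(p * lq for p, lq in zip(p_t, log_q))


def taid_masked_mean(student_seq, teacher_seq, mask, t):
    """Token mean: sum(per_token * mask) / sum(mask).

    Assumes the mask has at least one non-zero entry.
    """
    per_token = [taid_token_loss(s, q, t) for s, q in zip(student_seq, teacher_seq)]
    return sum(l * m for l, m in zip(per_token, mask)) / sum(mask)
```

With a mask such as `[1, 0]`, only the first position contributes, so the result equals `taid_token_loss` on that position alone.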
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def test_taid_loss_shape_mismatch_raises():
|
| 158 |
+
s = torch.randn(2, 3, 5)
|
| 159 |
+
t_logits = torch.randn(2, 3, 6)
|
| 160 |
+
with pytest.raises(ValueError, match="shape mismatch"):
|
| 161 |
+
taid_loss(s, t_logits, t=0.5)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def test_taid_loss_invalid_mask_raises():
|
| 165 |
+
s = torch.randn(2, 3, 5)
|
| 166 |
+
t_logits = torch.randn(2, 3, 5)
|
| 167 |
+
bogus_mask = torch.ones(2, 4) # wrong T
|
| 168 |
+
with pytest.raises(ValueError, match="mask shape"):
|
| 169 |
+
taid_loss(s, t_logits, bogus_mask, t=0.5)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ---------------------------------------------------------------------
|
| 173 |
+
# TAIDScheduler
|
| 174 |
+
# ---------------------------------------------------------------------
|
| 175 |
+
|
| 176 |
+
def test_taid_scheduler_initial_state():
|
| 177 |
+
sched = TAIDScheduler(num_train_steps=1000, t_start=0.4)
|
| 178 |
+
assert sched.t == pytest.approx(0.4)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def test_taid_scheduler_first_update_seeds():
|
| 182 |
+
"""First update_t() with finite loss only sets prev_loss, returns None,
|
| 183 |
+
leaves t at t_start.
|
| 184 |
+
"""
|
| 185 |
+
sched = TAIDScheduler(num_train_steps=100, t_start=0.4)
|
| 186 |
+
delta = sched.update_t(torch.tensor(2.0), global_step=0)
|
| 187 |
+
assert delta is None
|
| 188 |
+
assert sched.t == pytest.approx(0.4)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def test_taid_scheduler_monotonic_non_decreasing():
|
| 192 |
+
"""Even with noisy/oscillating loss, t is non-decreasing."""
|
| 193 |
+
sched = TAIDScheduler(num_train_steps=1000, t_start=0.4)
|
| 194 |
+
losses = [3.0, 2.5, 2.7, 2.3, 2.4, 2.0, 1.8, 1.85, 1.7, 1.5]
|
| 195 |
+
prev_t = sched.t
|
| 196 |
+
for step, loss in enumerate(losses):
|
| 197 |
+
sched.update_t(torch.tensor(loss), global_step=step)
|
| 198 |
+
assert sched.t >= prev_t - 1e-6, (
|
| 199 |
+
f"t decreased at step {step}: {prev_t} -> {sched.t}"
|
| 200 |
+
)
|
| 201 |
+
prev_t = sched.t
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def test_taid_scheduler_t_end_clamp():
|
| 205 |
+
"""t never exceeds t_end."""
|
| 206 |
+
sched = TAIDScheduler(num_train_steps=10, t_start=0.4, t_end=0.9)
|
| 207 |
+
# Push global_step past num_train_steps so the linear floor would exceed t_end.
|
| 208 |
+
for step in range(0, 100):
|
| 209 |
+
sched.update_t(torch.tensor(2.0 - 0.01 * step), global_step=step)
|
| 210 |
+
assert sched.t <= 0.9 + 1e-6
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def test_taid_scheduler_disable_adaptive_is_linear():
|
| 214 |
+
"""With disable_adaptive=True, t = t_start + progress * (t_end - t_start)."""
|
| 215 |
+
sched = TAIDScheduler(
|
| 216 |
+
num_train_steps=100, t_start=0.0, t_end=1.0, disable_adaptive=True
|
| 217 |
+
)
|
| 218 |
+
# Seed prev_loss
|
| 219 |
+
sched.update_t(torch.tensor(2.0), global_step=0)
|
| 220 |
+
sched.update_t(torch.tensor(1.5), global_step=50)
|
| 221 |
+
assert sched.t == pytest.approx(0.5, abs=1e-6)
|
| 222 |
+
sched.update_t(torch.tensor(1.0), global_step=100)
|
| 223 |
+
assert sched.t == pytest.approx(1.0, abs=1e-6)
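The linear path exercised by the last two tests can be written out as a small sketch. This is only a reading of the docstring formula `t = t_start + progress * (t_end - t_start)` together with the `t_end` clamp; the function name `linear_t` is hypothetical, the definition of `progress` as `global_step / num_train_steps` is an assumption, and the loss-driven adaptive update of `TAIDScheduler` is not shown.

```python
def linear_t(global_step, num_train_steps, t_start, t_end):
    """Linear interpolation coefficient, clamped so it never exceeds t_end.

    progress is assumed to be global_step / num_train_steps, limited to [0, 1]
    so that steps past the end of training stay at t_end.
    """
    progress = min(max(global_step / num_train_steps, 0.0), 1.0)
    return t_start + progress * (t_end - t_start)
```

Under these assumptions, step 50 of 100 with `t_start=0.0` and `t_end=1.0` gives 0.5, and any step beyond `num_train_steps` stays at `t_end`.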
|
| 224 |
|
| 225 |
|
| 226 |
# ---------------------------------------------------------------------
|
|
@@ -0,0 +1,123 @@
|
| 1 |
+
"""Upstream-parity test for TAID.
|
| 2 |
+
|
| 3 |
+
This test compares our `taid_loss` against the reference implementation
|
| 4 |
+
in ``SakanaAI/TAID/src/distil_losses/taid.py`` + ``fkl.py``. The upstream
|
| 5 |
+
clone is expected at ``/tmp/taid-clone``; if absent, the test is skipped.
|
| 6 |
+
|
| 7 |
+
To run::
|
| 8 |
+
|
| 9 |
+
git clone --depth 1 https://github.com/SakanaAI/TAID /tmp/taid-clone
|
| 10 |
+
pytest composer_replication/distillation/tests/test_taid_parity.py
|
| 11 |
+
|
| 12 |
+
Parity is asserted at atol/rtol = 1e-5 over a small batch on CPU.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import importlib.util
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
|
| 20 |
+
import pytest
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from composer_replication.distillation import taid_loss
|
| 24 |
+
|
| 25 |
+
UPSTREAM_PATH = os.environ.get("TAID_UPSTREAM_PATH", "/tmp/taid-clone")
|
| 26 |
+
UPSTREAM_TAID_PY = os.path.join(UPSTREAM_PATH, "src", "distil_losses", "taid.py")
|
| 27 |
+
UPSTREAM_FKL_PY = os.path.join(UPSTREAM_PATH, "src", "distil_losses", "fkl.py")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _load_upstream_forward_kl():
|
| 31 |
+
"""Inline-load just the `forward_kl` function from the upstream clone.
|
| 32 |
+
|
| 33 |
+
We avoid importing the full upstream module because it depends on
|
| 34 |
+
`lightning` and a relative `.base` import. Instead we read the file and
|
| 35 |
+
exec just the function body.
|
| 36 |
+
"""
|
| 37 |
+
if not (os.path.isfile(UPSTREAM_TAID_PY) and os.path.isfile(UPSTREAM_FKL_PY)):
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
src = open(UPSTREAM_FKL_PY).read()
|
| 41 |
+
# Strip the module-level `from .base import DistilLoss` so we can exec
|
| 42 |
+
# standalone — only forward_kl is needed for parity.
|
| 43 |
+
sandbox: dict = {}
|
| 44 |
+
# Build a minimal namespace that mimics the upstream imports.
|
| 45 |
+
exec(
|
| 46 |
+
"from typing import Optional\n"
|
| 47 |
+
"import torch\n"
|
| 48 |
+
"from torch.nn import functional as F\n",
|
| 49 |
+
sandbox,
|
| 50 |
+
)
|
| 51 |
+
# Append the forward_kl function definition.
|
| 52 |
+
fwd_kl_src = src.split("def forward_kl(", 1)[1]
|
| 53 |
+
fwd_kl_src = "def forward_kl(" + fwd_kl_src.split("\nclass ", 1)[0]
|
| 54 |
+
exec(fwd_kl_src, sandbox)
|
| 55 |
+
return sandbox["forward_kl"]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _upstream_compute_loss(student_logits, teacher_logits, mask, t):
|
| 59 |
+
"""Replicate `TAID.compute_loss` from upstream taid.py:66-80 inline.
|
| 60 |
+
|
| 61 |
+
Same arithmetic; we just don't instantiate the LightningModule
|
| 62 |
+
bookkeeping around it.
|
| 63 |
+
"""
|
| 64 |
+
forward_kl = _load_upstream_forward_kl()
|
| 65 |
+
if forward_kl is None:
|
| 66 |
+
return None
|
| 67 |
+
|
| 68 |
+
import torch.nn.functional as F
|
| 69 |
+
|
| 70 |
+
p_t = (1 - t) * student_logits.detach() + t * teacher_logits
|
| 71 |
+
p_t = F.softmax(p_t, dim=-1, dtype=torch.float32)
|
| 72 |
+
distil_loss = forward_kl(
|
| 73 |
+
logits=student_logits,
|
| 74 |
+
teacher_logits=teacher_logits,
|
| 75 |
+
mask=mask,
|
| 76 |
+
teacher_probs=p_t,
|
| 77 |
+
)
|
| 78 |
+
return distil_loss
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@pytest.mark.skipif(
|
| 82 |
+
not os.path.isfile(UPSTREAM_TAID_PY),
|
| 83 |
+
reason=(
|
| 84 |
+
f"Upstream TAID clone not found at {UPSTREAM_PATH}. "
|
| 85 |
+
f"Run: git clone --depth 1 https://github.com/SakanaAI/TAID {UPSTREAM_PATH}"
|
| 86 |
+
),
|
| 87 |
+
)
|
| 88 |
+
@pytest.mark.parametrize("t", [0.0, 0.1, 0.4, 0.5, 0.9, 1.0])
|
| 89 |
+
def test_taid_parity_against_upstream(t):
|
| 90 |
+
"""Our taid_loss matches upstream TAID.compute_loss(...) within atol=1e-5.
|
| 91 |
+
|
| 92 |
+
Tests across the full t-range, on a fixed-seed batch with random logits +
|
| 93 |
+
a non-trivial mask.
|
| 94 |
+
"""
|
| 95 |
+
torch.manual_seed(0)
|
| 96 |
+
B, T, V = 2, 4, 16
|
| 97 |
+
student = torch.randn(B, T, V, requires_grad=True)
|
| 98 |
+
teacher = torch.randn(B, T, V)
|
| 99 |
+
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)
|
| 100 |
+
|
| 101 |
+
ours = taid_loss(student, teacher, mask, t=t)
|
| 102 |
+
theirs = _upstream_compute_loss(student, teacher, mask, t=t)
|
| 103 |
+
|
| 104 |
+
assert theirs is not None, "upstream forward_kl could not be loaded"
|
| 105 |
+
torch.testing.assert_close(ours, theirs, atol=1e-5, rtol=1e-5)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@pytest.mark.skipif(
|
| 109 |
+
not os.path.isfile(UPSTREAM_TAID_PY),
|
| 110 |
+
reason=f"Upstream TAID clone not found at {UPSTREAM_PATH}.",
|
| 111 |
+
)
|
| 112 |
+
def test_taid_parity_with_full_mask():
|
| 113 |
+
"""Sanity: full-mask path also matches upstream."""
|
| 114 |
+
torch.manual_seed(1)
|
| 115 |
+
B, T, V = 1, 3, 8
|
| 116 |
+
student = torch.randn(B, T, V, requires_grad=True)
|
| 117 |
+
teacher = torch.randn(B, T, V)
|
| 118 |
+
mask = torch.ones(B, T)
|
| 119 |
+
|
| 120 |
+
ours = taid_loss(student, teacher, mask, t=0.4)
|
| 121 |
+
theirs = _upstream_compute_loss(student, teacher, mask, t=0.4)
|
| 122 |
+
assert theirs is not None
|
| 123 |
+
torch.testing.assert_close(ours, theirs, atol=1e-5, rtol=1e-5)
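The loading trick used by `_load_upstream_forward_kl` (slice one function out of a source file and `exec` it in a prepared namespace, so that unrelated module-level imports never run) can be shown in isolation. The sketch below is standard-library only; `load_function_from_source`, `SAMPLE`, and `scaled_sqrt` are hypothetical names made up for the illustration, and the seeded namespace contains `math` rather than `torch`.

```python
def load_function_from_source(src, name, stop_marker="\nclass "):
    """Return the function `name` defined in `src`, or None if it is absent.

    Only the text from `def name(` up to the next `stop_marker` is executed,
    so failing imports elsewhere in `src` are never evaluated.
    """
    head = f"def {name}("
    if head not in src:
        return None
    body = src.split(head, 1)[1].split(stop_marker, 1)[0]
    namespace = {}
    exec("import math\n", namespace)  # seed the names the function relies on
    exec(head + body, namespace)
    return namespace[name]


SAMPLE = '''from .base import Missing  # would fail on a plain import

def scaled_sqrt(x, k=2.0):
    return k * math.sqrt(x)

class Unrelated(Missing):
    pass
'''

scaled_sqrt = load_function_from_source(SAMPLE, "scaled_sqrt")
```

Because `exec` runs whatever text it is given, this approach is only appropriate for source you trust, such as a pinned clone of a known repository. It also depends on the target function being followed by a `class` statement (or the end of the file), which is a property of the file layout rather than a general guarantee.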
|
|
@@ -28,10 +28,12 @@ Three pluggable distillation losses can swap the default DPO/SDPO channels:
|
|
| 28 |
|
| 29 |
- ``dpo_variant="simpo"`` — channel 3 uses SimPO (reference-free DPO with
|
| 30 |
margin) instead of standard DPO. Reference logprobs are no longer required.
|
| 31 |
-
- ``sdpo_wrapper="taid"`` — channel 2
|
| 32 |
-
Adaptive Interpolated Distillation). Requires ``
|
| 33 |
-
|
| 34 |
-
``
|
|
|
|
|
|
|
| 35 |
- ``sdpo_wrapper="entropy_opd"`` — channel 2 uses Entropy-Aware OPD, a
|
| 36 |
per-token gated forward/reverse KL.
|
| 37 |
|
|
@@ -80,15 +82,10 @@ def compose_loss(
|
|
| 80 |
# ADR-007 extensions ------------------------------------------------
|
| 81 |
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 82 |
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 83 |
-
|
| 84 |
-
taid_total_steps: int | None = None,
|
| 85 |
# SimPO knobs (only used when dpo_variant="simpo") ------------------
|
| 86 |
simpo_beta: float = 2.0,
|
| 87 |
simpo_gamma: float = 1.0,
|
| 88 |
-
# TAID knobs (only used when sdpo_wrapper="taid") -------------------
|
| 89 |
-
taid_schedule: str = "linear",
|
| 90 |
-
taid_alpha_min: float = 0.0,
|
| 91 |
-
taid_alpha_max: float = 1.0,
|
| 92 |
# Entropy-Aware OPD knobs (only used when sdpo_wrapper="entropy_opd")
|
| 93 |
entropy_opd_h_max: float | None = None,
|
| 94 |
) -> LossComponents:
|
|
@@ -111,11 +108,11 @@ def compose_loss(
|
|
| 111 |
- dpo_rejected_input_ids, dpo_rejected_response_mask
|
| 112 |
(reference logprobs not required and silently ignored)
|
| 113 |
TAID (sdpo_wrapper="taid"):
|
| 114 |
-
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
"""
|
| 120 |
if dpo_variant not in ("dpo", "simpo"):
|
| 121 |
raise ValueError(
|
|
@@ -127,13 +124,14 @@ def compose_loss(
|
|
| 127 |
f"got {sdpo_wrapper!r}"
|
| 128 |
)
|
| 129 |
if sdpo_wrapper == "taid":
|
| 130 |
-
if
|
| 131 |
raise ValueError(
|
| 132 |
-
"sdpo_wrapper='taid' requires
|
|
|
|
| 133 |
)
|
| 134 |
-
if
|
| 135 |
raise ValueError(
|
| 136 |
-
"
|
| 137 |
)
|
| 138 |
|
| 139 |
device = _device_of(model)
|
|
@@ -176,24 +174,18 @@ def compose_loss(
|
|
| 176 |
elif sdpo_wrapper == "taid":
|
| 177 |
from composer_replication.distillation import taid_loss
|
| 178 |
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
#
|
| 183 |
-
|
| 184 |
-
|
|
|
|
| 185 |
sdpo_jsd = taid_loss(
|
| 186 |
student_logits=student_logits,
|
| 187 |
teacher_logits=teacher_logits,
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
total_steps=int(taid_total_steps),
|
| 191 |
-
schedule=taid_schedule,
|
| 192 |
-
alpha_min=taid_alpha_min,
|
| 193 |
-
alpha_max=taid_alpha_max,
|
| 194 |
-
jsd_beta=sdpo_jsd_beta,
|
| 195 |
-
temperature=sdpo_temperature,
|
| 196 |
-
reduction="batchmean",
|
| 197 |
)
|
| 198 |
elif sdpo_wrapper == "entropy_opd":
|
| 199 |
from composer_replication.distillation import (
|
|
@@ -348,48 +340,4 @@ def _avg_sequence_logprobs(
|
|
| 348 |
return masked.sum(dim=-1) / n_tokens
|
| 349 |
|
| 350 |
|
| 351 |
-
def _resolve_student_init_logits(
|
| 352 |
-
model: torch.nn.Module,
|
| 353 |
-
inputs: dict[str, torch.Tensor],
|
| 354 |
-
*,
|
| 355 |
-
expected_shape: torch.Size,
|
| 356 |
-
) -> torch.Tensor:
|
| 357 |
-
"""Return frozen student-init logits for TAID.
|
| 358 |
-
|
| 359 |
-
Preferred path: caller pre-saves a snapshot at training step 0 and passes
|
| 360 |
-
it via ``inputs['student_init_logits']``. Fallback path (only valid early
|
| 361 |
-
in training before the model has drifted): pass
|
| 362 |
-
``inputs['student_init_input_ids']`` and we run a no-grad forward through
|
| 363 |
-
``model``. Always returns a tensor on the same device as ``model``.
|
| 364 |
-
"""
|
| 365 |
-
if "student_init_logits" in inputs and inputs["student_init_logits"].numel() > 0:
|
| 366 |
-
student_init = inputs["student_init_logits"]
|
| 367 |
-
if student_init.shape != expected_shape:
|
| 368 |
-
raise ValueError(
|
| 369 |
-
f"inputs['student_init_logits'] shape {tuple(student_init.shape)} "
|
| 370 |
-
f"does not match teacher logits shape {tuple(expected_shape)}"
|
| 371 |
-
)
|
| 372 |
-
return student_init.detach()
|
| 373 |
-
|
| 374 |
-
if (
|
| 375 |
-
"student_init_input_ids" in inputs
|
| 376 |
-
and inputs["student_init_input_ids"].numel() > 0
|
| 377 |
-
):
|
| 378 |
-
with torch.no_grad():
|
| 379 |
-
init_logits = model(input_ids=inputs["student_init_input_ids"]).logits
|
| 380 |
-
if init_logits.shape != expected_shape:
|
| 381 |
-
raise ValueError(
|
| 382 |
-
f"frozen forward on student_init_input_ids gave shape "
|
| 383 |
-
f"{tuple(init_logits.shape)} which does not match teacher "
|
| 384 |
-
f"logits shape {tuple(expected_shape)}"
|
| 385 |
-
)
|
| 386 |
-
return init_logits
|
| 387 |
-
|
| 388 |
-
raise ValueError(
|
| 389 |
-
"sdpo_wrapper='taid' requires either inputs['student_init_logits'] "
|
| 390 |
-
"(precomputed) or inputs['student_init_input_ids'] (frozen forward "
|
| 391 |
-
"fallback) to be present."
|
| 392 |
-
)
|
| 393 |
-
|
| 394 |
-
|
| 395 |
__all__ = ["compose_loss", "LossComponents"]
|
|
|
|
| 28 |
|
| 29 |
- ``dpo_variant="simpo"`` — channel 3 uses SimPO (reference-free DPO with
|
| 30 |
margin) instead of standard DPO. Reference logprobs are no longer required.
|
| 31 |
+
- ``sdpo_wrapper="taid"`` — channel 2 replaces SDPO with TAID (Temporally
|
| 32 |
+
Adaptive Interpolated Distillation, SakanaAI port). Requires ``taid_t``
|
| 33 |
+
(the current interpolation coefficient in ``[0, 1]``). The schedule that
|
| 34 |
+
produces ``taid_t`` is the trainer's responsibility — typically a
|
| 35 |
+
:class:`composer_replication.distillation.taid.TAIDScheduler` instance
|
| 36 |
+
driven by the per-step distillation loss.
|
| 37 |
- ``sdpo_wrapper="entropy_opd"`` — channel 2 uses Entropy-Aware OPD, a
|
| 38 |
per-token gated forward/reverse KL.
|
| 39 |
|
|
|
|
| 82 |
# ADR-007 extensions ------------------------------------------------
|
| 83 |
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 84 |
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 85 |
+
taid_t: float | None = None,
|
|
|
|
| 86 |
# SimPO knobs (only used when dpo_variant="simpo") ------------------
|
| 87 |
simpo_beta: float = 2.0,
|
| 88 |
simpo_gamma: float = 1.0,
|
| 89 |
# Entropy-Aware OPD knobs (only used when sdpo_wrapper="entropy_opd")
|
| 90 |
entropy_opd_h_max: float | None = None,
|
| 91 |
) -> LossComponents:
|
|
|
|
| 108 |
- dpo_rejected_input_ids, dpo_rejected_response_mask
|
| 109 |
(reference logprobs not required and silently ignored)
|
| 110 |
TAID (sdpo_wrapper="taid"):
|
| 111 |
+
- taid_t kwarg: scalar float in [0, 1] giving the current
|
| 112 |
+
interpolation coefficient. The trainer is responsible for the
|
| 113 |
+
schedule (use TAIDScheduler from
|
| 114 |
+
composer_replication.distillation.taid for the paper-default
|
| 115 |
+
adaptive scheme, or any custom schedule of your choosing).
|
| 116 |
"""
|
| 117 |
if dpo_variant not in ("dpo", "simpo"):
|
| 118 |
raise ValueError(
|
|
|
|
| 124 |
f"got {sdpo_wrapper!r}"
|
| 125 |
)
|
| 126 |
if sdpo_wrapper == "taid":
|
| 127 |
+
if taid_t is None:
|
| 128 |
raise ValueError(
|
| 129 |
+
"sdpo_wrapper='taid' requires taid_t (float in [0, 1]). "
|
| 130 |
+
"Drive it from a TAIDScheduler or pass a fixed value."
|
| 131 |
)
|
| 132 |
+
if not (0.0 <= float(taid_t) <= 1.0):
|
| 133 |
raise ValueError(
|
| 134 |
+
f"taid_t must be in [0, 1], got {taid_t}"
|
| 135 |
)
|
| 136 |
|
| 137 |
device = _device_of(model)
|
|
|
|
| 174 |
elif sdpo_wrapper == "taid":
|
| 175 |
from composer_replication.distillation import taid_loss
|
| 176 |
|
| 177 |
+
# taid_t validated non-None and in-range above.
|
| 178 |
+
assert taid_t is not None
|
| 179 |
+
# Reuse the SDPO loss-mask if provided so we only score the
|
| 180 |
+
# error-turn tokens; otherwise score all tokens.
|
| 181 |
+
taid_mask_bt = inputs.get("sdpo_loss_mask")
|
| 182 |
+
if taid_mask_bt is not None:
|
| 183 |
+
taid_mask_bt = taid_mask_bt.to(student_logits.device).float()
|
| 184 |
sdpo_jsd = taid_loss(
|
| 185 |
student_logits=student_logits,
|
| 186 |
teacher_logits=teacher_logits,
|
| 187 |
+
mask=taid_mask_bt,
|
| 188 |
+
t=float(taid_t),
|
| 189 |
)
|
| 190 |
elif sdpo_wrapper == "entropy_opd":
|
| 191 |
from composer_replication.distillation import (
|
|
|
|
| 340 |
return masked.sum(dim=-1) / n_tokens
|
| 341 |
|
| 342 |
|
| 343 |
__all__ = ["compose_loss", "LossComponents"]
|
|
@@ -2,6 +2,9 @@
|
|
| 2 |
|
| 3 |
Original source: github.com/siyan-zhao/OPSD::OPSDTrainer.generalized_jsd_loss (MIT).
|
| 4 |
Verified self-contained via DeepWiki audit on 2026-05-25.
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
Mathematical reference:
|
| 7 |
- OPSD paper: Zhao et al., "Self-Distilled Reasoner: On-Policy Self-Distillation
|
|
@@ -39,17 +42,32 @@ def generalized_jsd_loss(
|
|
| 39 |
) -> torch.Tensor:
|
| 40 |
"""Generalized Jensen-Shannon Divergence loss between student and teacher.
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
Args:
|
| 43 |
student_logits: (B, T, V) — student model logits at each token position.
|
| 44 |
teacher_logits: (B, T, V) — teacher (= same model with hint context) logits.
|
| 45 |
labels: (B, T) — token-level mask. Positions with label == -100 are ignored
|
| 46 |
(standard HF padding/ignored convention). For Composer-style hint-distill,
|
| 47 |
mask should be 1 at error-turn tokens AFTER the hint, 0 elsewhere.
|
| 48 |
-
beta: in [0, 1].
|
| 49 |
-
(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
temperature: softens distributions; T > 1 encourages distribution-matching
|
| 51 |
on broader tail probabilities. SDPO paper uses 1.0.
|
| 52 |
-
reduction: "batchmean"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
logits_are_probs: if True, inputs are already probabilities (skip softmax).
|
| 54 |
top_k: restrict KL to top-k tokens of the teacher distribution.
|
| 55 |
Saves compute on large vocabularies (Qwen3 vocab = 152K).
|
|
@@ -57,75 +75,78 @@ def generalized_jsd_loss(
|
|
| 57 |
SDPO paper does NOT clip; OPSD code defaults to None (no clip).
|
| 58 |
|
| 59 |
Returns:
|
| 60 |
-
Scalar loss tensor.
|
| 61 |
"""
|
| 62 |
-
#
|
| 63 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
student_logits = student_logits / temperature
|
| 65 |
teacher_logits = teacher_logits / temperature
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
teacher_log_probs = F.log_softmax(teacher_topk_vals, dim=-1)
|
| 74 |
-
else:
|
| 75 |
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
| 76 |
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
#
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
elif beta == 1.0:
|
| 86 |
-
# Reverse KL: KL(teacher || student)
|
| 87 |
-
per_token_div = F.kl_div(
|
| 88 |
-
teacher_log_probs, student_log_probs,
|
| 89 |
-
reduction="none", log_target=True,
|
| 90 |
-
).sum(dim=-1)
|
| 91 |
else:
|
| 92 |
-
#
|
| 93 |
-
# M =
|
| 94 |
-
#
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
torch.
|
|
|
|
| 98 |
)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
if token_clip is not None:
|
| 109 |
-
|
| 110 |
|
| 111 |
-
#
|
|
|
|
|
|
|
| 112 |
if labels is not None:
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
n_valid = loss_mask.sum().clamp(min=1.0)
|
| 116 |
-
else:
|
| 117 |
-
n_valid = torch.tensor(per_token_div.numel(), device=per_token_div.device, dtype=per_token_div.dtype)
|
| 118 |
|
|
|
|
| 119 |
if reduction == "batchmean":
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
elif reduction == "sum":
|
| 123 |
-
return
|
| 124 |
elif reduction == "mean":
|
| 125 |
-
return
|
| 126 |
elif reduction == "none":
|
| 127 |
-
return
|
| 128 |
else:
|
|
|
|
|
|
|
| 129 |
raise ValueError(f"Unknown reduction: {reduction}")
|
| 130 |
|
| 131 |
|
|
|
|
| 2 |
|
| 3 |
Original source: github.com/siyan-zhao/OPSD::OPSDTrainer.generalized_jsd_loss (MIT).
|
| 4 |
Verified self-contained via DeepWiki audit on 2026-05-25.
|
| 5 |
+
Re-aligned byte-for-byte against upstream `opsd_trainer.py` lines 381-479 on
|
| 6 |
+
2026-05-26 after Wave 15 math review found three numerical divergences (mixture
|
| 7 |
+
weighting, β coefficient placement, reduction divisor) and one docstring mislabel.
|
| 8 |
|
| 9 |
Mathematical reference:
|
| 10 |
- OPSD paper: Zhao et al., "Self-Distilled Reasoner: On-Policy Self-Distillation
|
|
|
|
| 42 |
) -> torch.Tensor:
|
| 43 |
"""Generalized Jensen-Shannon Divergence loss between student and teacher.
|
| 44 |
|
| 45 |
+
Byte-for-byte replication of `OPSDTrainer.generalized_jsd_loss`
|
| 46 |
+
(siyan-zhao/OPSD, opsd_trainer.py lines 381-479). See
|
| 47 |
+
https://huggingface.co/papers/2306.13649 Eq. (1) for the definition.
|
| 48 |
+
|
| 49 |
Args:
|
| 50 |
student_logits: (B, T, V) — student model logits at each token position.
|
| 51 |
teacher_logits: (B, T, V) — teacher (= same model with hint context) logits.
|
| 52 |
labels: (B, T) — token-level mask. Positions with label == -100 are ignored
|
| 53 |
(standard HF padding/ignored convention). For Composer-style hint-distill,
|
| 54 |
mask should be 1 at error-turn tokens AFTER the hint, 0 elsewhere.
|
| 55 |
+
beta: in [0, 1]. NOTE on direction (per `F.kl_div` semantics, where
|
| 56 |
+
`F.kl_div(log_q, log_p, log_target=True)` computes KL(p || q)):
|
| 57 |
+
β = 0 → kl_div(student_log_probs, teacher_log_probs)
|
| 58 |
+
= KL(teacher || student) (reverse KL — mode-covering for student)
|
| 59 |
+
β = 1 → kl_div(teacher_log_probs, student_log_probs)
|
| 60 |
+
= KL(student || teacher) (forward KL — mode-seeking for student)
|
| 61 |
+
β = 0.5 → symmetric JSD with M = 0.5*(P+Q)
|
| 62 |
+
General β ∈ (0,1): mixture M = (1-β)·P_student + β·P_teacher and
|
| 63 |
+
jsd = β·KL(teacher||M) + (1-β)·KL(student||M).
|
| 64 |
temperature: softens distributions; T > 1 encourages distribution-matching
|
| 65 |
on broader tail probabilities. SDPO paper uses 1.0.
|
| 66 |
+
reduction: "batchmean" | "sum" | "mean" | "none". "batchmean" matches
|
| 67 |
+
upstream OPSD: divides by `mask.sum()` when labels are given, else
|
| 68 |
+
by the leading dim of jsd (= batch size). This differs from PyTorch's
|
| 69 |
+
`KLDivLoss(reduction='batchmean')` (which divides by batch). We match
|
| 70 |
+
upstream because gradient scale stability matters more than the name.
|
| 71 |
logits_are_probs: if True, inputs are already probabilities (skip softmax).
|
| 72 |
top_k: restrict KL to top-k tokens of the teacher distribution.
|
| 73 |
Saves compute on large vocabularies (Qwen3 vocab = 152K).
|
|
|
|
| 75 |
SDPO paper does NOT clip; OPSD code defaults to None (no clip).
|
| 76 |
|
| 77 |
Returns:
|
| 78 |
+
Scalar loss tensor (or unreduced (B, T, V) tensor for reduction="none").
|
| 79 |
"""
|
| 80 |
+
# Path A: probabilities-in. Take log directly with a clamp for stability.
|
| 81 |
+
if logits_are_probs:
|
| 82 |
+
student_log_probs = torch.log(student_logits.clamp_min(1e-8))
|
| 83 |
+
teacher_log_probs = torch.log(teacher_logits.clamp_min(1e-8))
|
| 84 |
+
else:
|
| 85 |
+
# Apply temperature scaling to logits before computing probabilities.
|
| 86 |
student_logits = student_logits / temperature
|
| 87 |
teacher_logits = teacher_logits / temperature
|
| 88 |
|
| 89 |
+
if top_k is not None and top_k > 0:
|
| 90 |
+
# Restrict to top-k tokens of the teacher distribution and renormalize.
|
| 91 |
+
_, top_k_indices = torch.topk(teacher_logits, k=top_k, dim=-1)
|
| 92 |
+
student_logits = torch.gather(student_logits, dim=-1, index=top_k_indices)
|
| 93 |
+
teacher_logits = torch.gather(teacher_logits, dim=-1, index=top_k_indices)
|
| 94 |
+
|
|
|
|
|
|
|
| 95 |
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
| 96 |
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
| 97 |
|
| 98 |
+
if beta == 0:
|
| 99 |
+
# F.kl_div(input=log_q, target=log_p, log_target=True) computes KL(p || q):
|
| 100 |
+
# sum_x p(x) * (log p(x) - log q(x))
|
| 101 |
+
# With input=student_log_probs, target=teacher_log_probs → KL(teacher || student).
|
| 102 |
+
jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
| 103 |
+
elif beta == 1:
|
| 104 |
+
jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
|
| 105 |
else:
|
| 106 |
+
# Compute the log of the β-weighted mixture distribution:
|
| 107 |
+
# M = (1-β)·P_student + β·P_teacher
|
| 108 |
+
# log M = logsumexp([log P_student + log(1-β), log P_teacher + log(β)])
|
| 109 |
+
beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
|
| 110 |
+
mixture_log_probs = torch.logsumexp(
|
| 111 |
+
torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]),
|
| 112 |
+
dim=0,
|
| 113 |
)
|
| 114 |
+
|
| 115 |
+
# Compute KL divergences using F.kl_div.
|
| 116 |
+
# PyTorch differs from the standard mathematical definition, so the order of
|
| 117 |
+
# the probability distributions is swapped compared to that defined in the paper.
|
| 118 |
+
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
| 119 |
+
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
|
| 120 |
+
|
| 121 |
+
# Generalized JSD: β weights the teacher-leg KL (matches upstream).
|
| 122 |
+
jsd = beta * kl_teacher + (1 - beta) * kl_student
|
| 123 |
+
|
| 124 |
+
# Per-token clipping: cap each token's divergence value.
|
| 125 |
if token_clip is not None:
|
| 126 |
+
jsd = jsd.clamp(max=token_clip)
|
| 127 |
|
| 128 |
+
# Masking. labels has shape (B, T); jsd has shape (B, T, V) (or top_k for V).
|
| 129 |
+
# `jsd[mask]` indexes the first two dims, yielding shape (n_valid, V).
|
| 130 |
+
mask = None
|
| 131 |
if labels is not None:
|
| 132 |
+
mask = labels != -100
|
| 133 |
+
jsd = jsd[mask]
|
| 134 |
|
| 135 |
+
# Apply reduction (matches upstream byte-for-byte for batchmean/sum/mean).
|
| 136 |
if reduction == "batchmean":
|
| 137 |
+
if labels is not None:
|
| 138 |
+
assert mask is not None
|
| 139 |
+
return jsd.sum() / mask.sum()
|
| 140 |
+
return jsd.sum() / jsd.size(0)
|
| 141 |
elif reduction == "sum":
|
| 142 |
+
return jsd.sum()
|
| 143 |
elif reduction == "mean":
|
| 144 |
+
return jsd.mean()
|
| 145 |
elif reduction == "none":
|
| 146 |
+
return jsd
|
| 147 |
else:
|
| 148 |
+
# Upstream falls through to `return jsd` for unknown reductions; we raise
|
| 149 |
+
# to surface caller bugs instead of silently returning an unreduced tensor.
|
| 150 |
raise ValueError(f"Unknown reduction: {reduction}")
|
| 151 |
|
| 152 |
|
|
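For readers who want to check the docstring's formulas by hand, here is a minimal pure-Python sketch of the β-weighted mixture, the generalized JSD, and the masked "batchmean" denominator. It is not the framework's kernel: the helper names `kl`, `generalized_jsd`, and `masked_batchmean` are hypothetical, the inputs are plain probability lists rather than logits tensors, and only 0 < β < 1 is covered.

```python
import math


def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def generalized_jsd(p_student, p_teacher, beta=0.5):
    """beta * KL(teacher || M) + (1 - beta) * KL(student || M),
    where M = (1 - beta) * P_student + beta * P_teacher."""
    if not 0.0 < beta < 1.0:
        raise ValueError("this sketch only covers 0 < beta < 1")
    mix = [(1.0 - beta) * s + beta * t for s, t in zip(p_student, p_teacher)]
    return beta * kl(p_teacher, mix) + (1.0 - beta) * kl(p_student, mix)


def masked_batchmean(per_token, labels, ignore_index=-100):
    """Sum per-token values at unmasked positions and divide by their count
    (the mask.sum() denominator), not by the batch size."""
    kept = [v for v, lab in zip(per_token, labels) if lab != ignore_index]
    return sum(kept) / len(kept)


# Toy distributions over a 3-token vocabulary (illustrative values only).
student = [0.7, 0.2, 0.1]
teacher = [0.1, 0.3, 0.6]

sym = generalized_jsd(student, teacher, beta=0.5)
per_token = [sym, generalized_jsd(teacher, teacher, beta=0.5)]
# The second token is masked out with label -100, so the denominator is 1.
loss = masked_batchmean(per_token, labels=[5, -100])
print(sym, loss)
```

At β = 0.5 the value is symmetric in its two arguments, and with one of the two tokens masked the reduction divides by 1 rather than by the number of tokens or sequences.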
@@ -88,8 +88,27 @@ we reference its algorithm and convention but vendor no code.
|
|
| 88 |
"""
|
| 89 |
from __future__ import annotations
|
| 90 |
|
|
|
|
| 91 |
from typing import Any
|
| 92 |
|
| 93 |
|
| 94 |
def loss_fn(
|
| 95 |
inputs: Any, # PRIME-RL's LossInputs — typed as Any to avoid hard import
|
|
@@ -129,8 +148,13 @@ def loss_fn(
|
|
| 129 |
PRIME-RL default ``1e-3``. Must be >= 0.
|
| 130 |
|
| 131 |
Returns:
|
| 132 |
-
|
| 133 |
-
|
| 134 |
|
| 135 |
Raises:
|
| 136 |
ValueError: if any of ``trainer_logprobs``, ``inference_logprobs``,
|
|
@@ -245,7 +269,7 @@ def loss_fn(
|
|
| 245 |
stacklevel=2,
|
| 246 |
)
|
| 247 |
|
| 248 |
-
return total
|
| 249 |
|
| 250 |
|
| 251 |
-
__all__ = ["loss_fn"]
|
|
|
|
| 88 |
"""
|
| 89 |
from __future__ import annotations
|
| 90 |
|
| 91 |
+
from collections import namedtuple
|
| 92 |
from typing import Any
|
| 93 |
|
| 94 |
+
# PRIME-RL's setup_loss_fns expects loss functions to return a LossOutputs
|
| 95 |
+
# struct with `.loss` (scalar Tensor) and `.metrics` (dict). When PRIME-RL is
|
| 96 |
+
# installed we use the upstream dataclass directly so isinstance() checks in
|
| 97 |
+
# any downstream code keep working; otherwise we fall back to a structurally
|
| 98 |
+
# equivalent NamedTuple that exposes the same attribute access.
|
| 99 |
+
#
|
| 100 |
+
# Upstream definition (prime_rl/trainer/rl/loss.py lines 24-29):
|
| 101 |
+
# @dataclass
|
| 102 |
+
# class LossOutputs:
|
| 103 |
+
# loss: Float[Tensor, ""]
|
| 104 |
+
# metrics: dict[str, Tensor]
|
| 105 |
+
try: # pragma: no cover - exercised only when prime-rl is installed
|
| 106 |
+
from prime_rl.trainer.rl.loss import ( # type: ignore[import-not-found]
|
| 107 |
+
LossOutputs,
|
| 108 |
+
)
|
| 109 |
+
except Exception: # noqa: BLE001 - missing module, version skew, or jaxtyping
|
| 110 |
+
LossOutputs = namedtuple("LossOutputs", ["loss", "metrics"]) # type: ignore[misc,assignment]
|
| 111 |
+
|
| 112 |
|
| 113 |
def loss_fn(
|
| 114 |
inputs: Any, # PRIME-RL's LossInputs — typed as Any to avoid hard import
|
|
|
|
| 148 |
PRIME-RL default ``1e-3``. Must be >= 0.
|
| 149 |
|
| 150 |
Returns:
|
| 151 |
+
:class:`LossOutputs` with ``loss`` (scalar ``torch.Tensor``) and
|
| 152 |
+
``metrics`` (``dict[str, Tensor | float]``). PRIME-RL's outer
|
| 153 |
+
``compute_loss`` reads ``out.loss``, divides by ``loss_scale``, and
|
| 154 |
+
calls ``.backward()``; the ``metrics`` dict is forwarded to the
|
| 155 |
+
logger. When PRIME-RL is installed this is upstream's
|
| 156 |
+
``LossOutputs`` dataclass; otherwise it is a structurally
|
| 157 |
+
equivalent ``namedtuple`` defined at the top of this module.
|
| 158 |
|
| 159 |
Raises:
|
| 160 |
ValueError: if any of ``trainer_logprobs``, ``inference_logprobs``,
|
|
|
|
| 269 |
stacklevel=2,
|
| 270 |
)
|
| 271 |
|
| 272 |
+
return LossOutputs(loss=total, metrics={"channel_1_pg_loss": float(total.detach())})
|
| 273 |
|
| 274 |
|
| 275 |
+
__all__ = ["loss_fn", "LossOutputs"]
|
|
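The optional-import fallback in this hunk can be exercised in isolation. The sketch below assumes only what the diff shows (the import path `prime_rl.trainer.rl.loss.LossOutputs` and the two field names); `toy_loss_fn` and the `toy_total` metric key are hypothetical stand-ins, and plain floats replace tensors so the example runs without torch.

```python
from collections import namedtuple

try:
    # Import path as quoted in the diff; only resolves when prime-rl is installed.
    from prime_rl.trainer.rl.loss import LossOutputs
except Exception:  # missing module, version skew, or a failing transitive import
    # Structurally equivalent fallback: same attribute names, same keyword construction.
    LossOutputs = namedtuple("LossOutputs", ["loss", "metrics"])


def toy_loss_fn(values):
    """Hypothetical loss function that returns the struct instead of a bare number."""
    total = sum(values)
    return LossOutputs(loss=total, metrics={"toy_total": float(total)})


out = toy_loss_fn([0.25, 0.5])
# Callers read the same attributes whichever class was bound above.
print(out.loss, out.metrics["toy_total"])
```

Because both the dataclass and the namedtuple accept keyword construction and expose `.loss` and `.metrics`, call sites do not need to know which one is in use; only `isinstance` checks against the upstream class would tell them apart.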
@@ -16,13 +16,33 @@ from typing import Optional
|
|
| 16 |
|
| 17 |
import pytest
|
| 18 |
import torch
|
|
|
|
| 19 |
|
| 20 |
-
from composer_replication.recipes.prime_rl.composer_loss import loss_fn
|
| 21 |
|
| 22 |
|
| 23 |
# Try to import PRIME-RL upstream for the parity test; skip-mark if
|
| 24 |
# unavailable. PRIME-RL pulls in heavy deps (jaxtyping, beartype) and
|
| 25 |
# is not part of the framework's own test environment.
|
| 26 |
try:
|
| 27 |
from prime_rl.trainer.rl.loss import ( # type: ignore[import-not-found]
|
| 28 |
LossInputs as PrimeRLLossInputs,
|
|
@@ -34,6 +54,14 @@ try:
|
|
| 34 |
_HAS_PRIME_RL = True
|
| 35 |
except Exception: # noqa: BLE001 — broad: missing module, version skew, etc.
|
| 36 |
_HAS_PRIME_RL = False
|
| 37 |
|
| 38 |
|
| 39 |
# ---------------------------------------------------------------------
|
|
@@ -91,6 +119,43 @@ def _make_inputs(
|
|
| 91 |
# Reference re-implementation (independent restatement of upstream).
|
| 92 |
# Used by hand-computed expected-value tests so we don't accidentally
|
| 93 |
# encode our own bugs as ground truth.
|
| 94 |
# ---------------------------------------------------------------------
|
| 95 |
def _reference_default_loss(
|
| 96 |
trainer_lp: torch.Tensor,
|
|
@@ -123,8 +188,18 @@ def _reference_default_loss(
|
|
| 123 |
# ---------------------------------------------------------------------
|
| 124 |
def test_returns_finite_scalar():
|
| 125 |
inputs = _make_inputs(seq=16)
|
| 126 |
-
|
| 127 |
|
| 128 |
assert isinstance(out, torch.Tensor)
|
| 129 |
assert out.shape == (), f"expected scalar, got shape {tuple(out.shape)}"
|
| 130 |
assert torch.isfinite(out).item()
|
|
@@ -164,7 +239,7 @@ def test_dppo_mask_high_drops_positive_advantage_outliers():
|
|
| 164 |
advantages=advantages,
|
| 165 |
loss_mask=mask,
|
| 166 |
)
|
| 167 |
-
out = loss_fn(
|
| 168 |
inputs,
|
| 169 |
alpha_sdpo=0.0,
|
| 170 |
beta_dpo=0.0,
|
|
@@ -172,7 +247,7 @@ def test_dppo_mask_high_drops_positive_advantage_outliers():
|
|
| 172 |
dppo_mask_low=0.2,
|
| 173 |
adv_tau=1.0,
|
| 174 |
kl_tau=1e-3,
|
| 175 |
-
)
|
| 176 |
|
| 177 |
expected = _reference_default_loss(
|
| 178 |
trainer_lp.detach(),
|
|
@@ -226,7 +301,7 @@ def test_dppo_mask_low_drops_negative_advantage_outliers():
|
|
| 226 |
advantages=advantages,
|
| 227 |
loss_mask=mask,
|
| 228 |
)
|
| 229 |
-
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 230 |
|
| 231 |
expected = _reference_default_loss(
|
| 232 |
trainer_lp.detach(),
|
|
@@ -266,7 +341,7 @@ def test_dppo_mask_sign_conditioned_on_advantage():
|
|
| 266 |
advantages=adv_pos,
|
| 267 |
loss_mask=mask,
|
| 268 |
)
|
| 269 |
-
out_pos = loss_fn(inputs_pos, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 270 |
|
| 271 |
# With positive advantage the LOW bound is not checked; the token is
|
| 272 |
# KEPT. pg = +1 * exp(-10 - 0) = ~4.5e-5; kl = (-10)^2 = 100.
|
|
@@ -294,7 +369,7 @@ def test_dppo_mask_sign_conditioned_on_advantage():
|
|
| 294 |
advantages=torch.tensor([-1.0]),
|
| 295 |
loss_mask=mask,
|
| 296 |
)
|
| 297 |
-
out_neg = loss_fn(inputs_neg, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 298 |
expected_neg = _reference_default_loss(
|
| 299 |
trainer_lp_neg.detach(),
|
| 300 |
inference_lp_pos,
|
|
@@ -313,7 +388,7 @@ def test_dppo_mask_sign_conditioned_on_advantage():
|
|
| 313 |
# ---------------------------------------------------------------------
|
| 314 |
def test_alpha_sdpo_zero_does_not_raise():
|
| 315 |
inputs = _make_inputs(seq=6, teacher=True)
|
| 316 |
-
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 317 |
assert torch.isfinite(out).item()
|
| 318 |
|
| 319 |
|
|
@@ -339,7 +414,7 @@ def test_alpha_sdpo_nonzero_no_teacher_also_raises():
|
|
| 339 |
# ---------------------------------------------------------------------
|
| 340 |
def test_advantages_shape_validates_seq_accepted():
|
| 341 |
inputs = _make_inputs(seq=12)
|
| 342 |
-
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 343 |
assert out.shape == ()
|
| 344 |
|
| 345 |
|
|
@@ -361,7 +436,7 @@ def test_advantages_shape_validates_bt_rejected():
|
|
| 361 |
def test_beta_dpo_nonzero_warns():
|
| 362 |
inputs = _make_inputs(seq=8)
|
| 363 |
with pytest.warns(UserWarning, match="DPO channel"):
|
| 364 |
-
out = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.3)
|
| 365 |
assert torch.isfinite(out).item()
|
| 366 |
|
| 367 |
|
|
@@ -411,7 +486,7 @@ def test_dppo_bounds_can_be_disabled():
|
|
| 411 |
loss_mask=mask,
|
| 412 |
)
|
| 413 |
|
| 414 |
-
out = loss_fn(
|
| 415 |
inputs,
|
| 416 |
alpha_sdpo=0.0,
|
| 417 |
beta_dpo=0.0,
|
|
@@ -419,7 +494,7 @@ def test_dppo_bounds_can_be_disabled():
|
|
| 419 |
dppo_mask_low=1e6,
|
| 420 |
adv_tau=1.0,
|
| 421 |
kl_tau=1e-3,
|
| 422 |
-
)
|
| 423 |
|
| 424 |
expected = _reference_default_loss(
|
| 425 |
trainer_lp.detach(),
|
|
@@ -463,7 +538,7 @@ def test_parity_with_prime_rl_default_loss_fn():
|
|
| 463 |
)
|
| 464 |
upstream_out = prime_rl_default_loss_fn(upstream_inputs, cfg) # type: ignore[name-defined]
|
| 465 |
|
| 466 |
-
ours = loss_fn(
|
| 467 |
FakeLossInputs(
|
| 468 |
trainer_logprobs=trainer_lp.clone(),
|
| 469 |
inference_logprobs=inference_lp.clone(),
|
|
@@ -476,7 +551,7 @@ def test_parity_with_prime_rl_default_loss_fn():
|
|
| 476 |
dppo_mask_low=cfg.dppo_mask_low,
|
| 477 |
adv_tau=cfg.adv_tau,
|
| 478 |
kl_tau=cfg.kl_tau,
|
| 479 |
-
)
|
| 480 |
|
| 481 |
assert torch.isclose(ours, upstream_out.loss, atol=1e-5, rtol=1e-5), (
|
| 482 |
f"Parity mismatch with PRIME-RL upstream: ours={ours.item()}, "
|
|
|
|
| 16 |
|
| 17 |
import pytest
|
| 18 |
import torch
|
| 19 |
+
import warnings
|
| 20 |
|
| 21 |
+
from composer_replication.recipes.prime_rl.composer_loss import LossOutputs, loss_fn
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _loss_value(result) -> torch.Tensor:
|
| 25 |
+
"""Return the scalar loss tensor from either a LossOutputs struct or a
|
| 26 |
+
bare Tensor. The recipe wraps its return in LossOutputs to satisfy
|
| 27 |
+
PRIME-RL's setup_loss_fns contract; tests written against the older
|
| 28 |
+
bare-Tensor return path keep working through this helper.
|
| 29 |
+
"""
|
| 30 |
+
if isinstance(result, torch.Tensor):
|
| 31 |
+
return result
|
| 32 |
+
# LossOutputs: dataclass (upstream) or namedtuple (fallback).
|
| 33 |
+
return result.loss
|
| 34 |
|
| 35 |
|
| 36 |
# Try to import PRIME-RL upstream for the parity test; skip-mark if
|
| 37 |
# unavailable. PRIME-RL pulls in heavy deps (jaxtyping, beartype) and
|
| 38 |
# is not part of the framework's own test environment.
|
| 39 |
+
#
|
| 40 |
+
# Visibility: when the import fails we emit a UserWarning at module load
|
| 41 |
+
# so the skip is *visible* in pytest output ("PytestUnhandledThreadExceptionWarning"
|
| 42 |
+
# is too noisy; UserWarning is captured by pytest's default filterwarnings
|
| 43 |
+
# and printed in the run summary). Without this, CI without prime-rl
|
| 44 |
+
# silently never runs the parity test and a real divergence could go
|
| 45 |
+
# undetected for releases at a time.
|
| 46 |
try:
|
| 47 |
from prime_rl.trainer.rl.loss import ( # type: ignore[import-not-found]
|
| 48 |
LossInputs as PrimeRLLossInputs,
|
|
|
|
| 54 |
_HAS_PRIME_RL = True
|
| 55 |
except Exception: # noqa: BLE001 — broad: missing module, version skew, etc.
|
| 56 |
_HAS_PRIME_RL = False
|
| 57 |
+
warnings.warn(
|
| 58 |
+
"prime-rl is not importable in this environment; the upstream "
|
| 59 |
+
"parity test (test_parity_with_prime_rl_default_loss_fn) will be "
|
| 60 |
+
"skipped. The shadow-parity test below still runs against an "
|
| 61 |
+
"in-file reference reimplementation.",
|
| 62 |
+
UserWarning,
|
| 63 |
+
stacklevel=2,
|
| 64 |
+
)
|
| 65 |
|
| 66 |
|
| 67 |
# ---------------------------------------------------------------------
|
|
|
|
| 119 |
# Reference re-implementation (independent restatement of upstream).
|
| 120 |
# Used by hand-computed expected-value tests so we don't accidentally
|
| 121 |
# encode our own bugs as ground truth.
|
| 122 |
+
#
|
| 123 |
+
# SHADOW-PARITY MAPPING
|
| 124 |
+
# ---------------------
|
| 125 |
+
# The body below is structurally identical to PRIME-RL's
|
| 126 |
+
# ``default_loss_fn`` at ``src/prime_rl/trainer/rl/loss.py`` lines
|
| 127 |
+
# 116-153 (as of the local clone at /tmp/prime-rl-clone). The mapping,
|
| 128 |
+
# line-by-line, is:
|
| 129 |
+
#
|
| 130 |
+
# upstream line 133-135 -> ``log_ir = ...``,
|
| 131 |
+
# ``ir = torch.exp(log_ir)``
|
| 132 |
+
# (we elide the unused ``mismatch_kl``
|
| 133 |
+
# term — upstream returns it as a metric
|
| 134 |
+
# only; we drop metrics in the reference
|
| 135 |
+
# because our channel-1 loss is a scalar
|
| 136 |
+
# and we compare ``.loss`` only.)
|
| 137 |
+
# upstream line 137 -> ``probs_diff = exp(trainer_lp) - exp(inference_lp)``
|
| 138 |
+
# upstream line 138 -> ``invalid_high = probs_diff > dppo_mask_high``
|
| 139 |
+
# upstream line 139 -> ``invalid_low = probs_diff < -dppo_mask_low``
|
| 140 |
+
# upstream line 140 -> ``pos_adv = advantages > 0``
|
| 141 |
+
# upstream line 142 -> ``invalid = where(pos_adv, invalid_high, invalid_low)``
|
| 142 |
+
# upstream line 148 -> ``keep = loss_mask & ~invalid``
|
| 143 |
+
# (upstream uses ``& is_masked``; we
|
| 144 |
+
# pre-cast ``loss_mask`` via ``to(bool)``)
|
| 145 |
+
# upstream line 150 -> ``adv_tau * advantages`` (inlined)
|
| 146 |
+
# upstream line 151 -> ``pg = keep_f * (adv_tau * advantages) * ir``
|
| 147 |
+
# upstream line 152 -> ``kl = lm_f * log_ir**2``
|
| 148 |
+
# upstream line 153 -> ``return (-pg + kl_tau * kl).sum()``
|
| 149 |
+
#
|
| 150 |
+
# Differences (intentional, do not affect ``.loss``):
|
| 151 |
+
# * upstream returns ``LossOutputs(loss=..., metrics={...})``; we
|
| 152 |
+
# return only the loss scalar because the seven metric entries
|
| 153 |
+
# (lines 155-163) don't influence backward and are validated
|
| 154 |
+
# separately in ``test_parity_with_prime_rl_default_loss_fn``.
|
| 155 |
+
# * upstream casts via ``loss_mask & is_masked`` (Bool & Bool); our
|
| 156 |
+
# ``keep_f.to(trainer_lp.dtype)`` matches exactly because both
|
| 157 |
+
# ``keep_mask`` and ``loss_mask`` are bool tensors cast to
|
| 158 |
+
# ``trainer_lp.dtype`` for the float multiply.
|
| 159 |
# ---------------------------------------------------------------------
|
| 160 |
def _reference_default_loss(
|
| 161 |
trainer_lp: torch.Tensor,
|
|
|
|
| 188 |
# ---------------------------------------------------------------------
|
| 189 |
def test_returns_finite_scalar():
|
| 190 |
inputs = _make_inputs(seq=16)
|
| 191 |
+
result = loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0)
|
| 192 |
|
| 193 |
+
# Must be a LossOutputs (dataclass when prime-rl is installed,
|
| 194 |
+
# NamedTuple fallback otherwise). PRIME-RL's setup_loss_fns reads
|
| 195 |
+
# ``.loss`` and ``.metrics`` from this struct.
|
| 196 |
+
assert hasattr(result, "loss") and hasattr(result, "metrics"), (
|
| 197 |
+
f"loss_fn must return a LossOutputs-shaped struct; got {type(result)}"
|
| 198 |
+
)
|
| 199 |
+
assert isinstance(result.metrics, dict)
|
| 200 |
+
assert "channel_1_pg_loss" in result.metrics
|
| 201 |
+
|
| 202 |
+
out = result.loss
|
| 203 |
assert isinstance(out, torch.Tensor)
|
| 204 |
assert out.shape == (), f"expected scalar, got shape {tuple(out.shape)}"
|
| 205 |
assert torch.isfinite(out).item()
|
|
|
|
| 239 |
advantages=advantages,
|
| 240 |
loss_mask=mask,
|
| 241 |
)
|
| 242 |
+
out = _loss_value(loss_fn(
|
| 243 |
inputs,
|
| 244 |
alpha_sdpo=0.0,
|
| 245 |
beta_dpo=0.0,
|
|
|
|
| 247 |
dppo_mask_low=0.2,
|
| 248 |
adv_tau=1.0,
|
| 249 |
kl_tau=1e-3,
|
| 250 |
+
))
|
| 251 |
|
| 252 |
expected = _reference_default_loss(
|
| 253 |
trainer_lp.detach(),
|
|
|
|
| 301 |
advantages=advantages,
|
| 302 |
loss_mask=mask,
|
| 303 |
)
|
| 304 |
+
out = _loss_value(loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0))
|
| 305 |
|
| 306 |
expected = _reference_default_loss(
|
| 307 |
trainer_lp.detach(),
|
|
|
|
| 341 |
advantages=adv_pos,
|
| 342 |
loss_mask=mask,
|
| 343 |
)
|
| 344 |
+
out_pos = _loss_value(loss_fn(inputs_pos, alpha_sdpo=0.0, beta_dpo=0.0))
|
| 345 |
|
| 346 |
# With positive advantage the LOW bound is not checked; the token is
|
| 347 |
# KEPT. pg = +1 * exp(-10 - 0) = ~4.5e-5; kl = (-10)^2 = 100.
|
|
|
|
| 369 |
advantages=torch.tensor([-1.0]),
|
| 370 |
loss_mask=mask,
|
| 371 |
)
|
| 372 |
+
out_neg = _loss_value(loss_fn(inputs_neg, alpha_sdpo=0.0, beta_dpo=0.0))
|
| 373 |
expected_neg = _reference_default_loss(
|
| 374 |
trainer_lp_neg.detach(),
|
| 375 |
inference_lp_pos,
|
|
|
|
| 388 |
# ---------------------------------------------------------------------
|
| 389 |
def test_alpha_sdpo_zero_does_not_raise():
|
| 390 |
inputs = _make_inputs(seq=6, teacher=True)
|
| 391 |
+
out = _loss_value(loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0))
|
| 392 |
assert torch.isfinite(out).item()
|
| 393 |
|
| 394 |
|
|
|
|
| 414 |
# ---------------------------------------------------------------------
|
| 415 |
def test_advantages_shape_validates_seq_accepted():
|
| 416 |
inputs = _make_inputs(seq=12)
|
| 417 |
+
out = _loss_value(loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.0))
|
| 418 |
assert out.shape == ()
|
| 419 |
|
| 420 |
|
|
|
|
| 436 |
def test_beta_dpo_nonzero_warns():
|
| 437 |
inputs = _make_inputs(seq=8)
|
| 438 |
with pytest.warns(UserWarning, match="DPO channel"):
|
| 439 |
+
out = _loss_value(loss_fn(inputs, alpha_sdpo=0.0, beta_dpo=0.3))
|
| 440 |
assert torch.isfinite(out).item()
|
| 441 |
|
| 442 |
|
|
|
|
| 486 |
loss_mask=mask,
|
| 487 |
)
|
| 488 |
|
| 489 |
+
out = _loss_value(loss_fn(
|
| 490 |
inputs,
|
| 491 |
alpha_sdpo=0.0,
|
| 492 |
beta_dpo=0.0,
|
|
|
|
| 494 |
dppo_mask_low=1e6,
|
| 495 |
adv_tau=1.0,
|
| 496 |
kl_tau=1e-3,
|
| 497 |
+
))
|
| 498 |
|
| 499 |
expected = _reference_default_loss(
|
| 500 |
trainer_lp.detach(),
|
|
|
|
| 538 |
)
|
| 539 |
upstream_out = prime_rl_default_loss_fn(upstream_inputs, cfg) # type: ignore[name-defined]
|
| 540 |
|
| 541 |
+
ours = _loss_value(loss_fn(
|
| 542 |
FakeLossInputs(
|
| 543 |
trainer_logprobs=trainer_lp.clone(),
|
| 544 |
inference_logprobs=inference_lp.clone(),
|
|
|
|
| 551 |
dppo_mask_low=cfg.dppo_mask_low,
|
| 552 |
adv_tau=cfg.adv_tau,
|
| 553 |
kl_tau=cfg.kl_tau,
|
| 554 |
+
))
|
| 555 |
|
| 556 |
assert torch.isclose(ours, upstream_out.loss, atol=1e-5, rtol=1e-5), (
|
| 557 |
f"Parity mismatch with PRIME-RL upstream: ours={ours.item()}, "
|
|
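The shadow-parity mapping in the test file's comments can be followed token by token without torch. The sketch below restates that mapping in plain Python; the function name `reference_loss` is hypothetical, the inputs are Python lists rather than tensors, and `dppo_mask_high=0.2` is an assumed value (the diff only shows `dppo_mask_low=0.2`, `adv_tau=1.0`, and `kl_tau=1e-3`).

```python
import math


def reference_loss(trainer_lp, inference_lp, advantages, loss_mask,
                   dppo_mask_high, dppo_mask_low, adv_tau=1.0, kl_tau=1e-3):
    """Per-token restatement of the mapping: masked policy-gradient term plus KL penalty."""
    total = 0.0
    for t_lp, i_lp, adv, lm in zip(trainer_lp, inference_lp, advantages, loss_mask):
        log_ir = t_lp - i_lp                       # log importance ratio
        ir = math.exp(log_ir)
        probs_diff = math.exp(t_lp) - math.exp(i_lp)
        invalid_high = probs_diff > dppo_mask_high
        invalid_low = probs_diff < -dppo_mask_low
        # Positive advantages check only the high bound, others only the low bound.
        invalid = invalid_high if adv > 0 else invalid_low
        keep = lm and not invalid
        pg = float(keep) * (adv_tau * adv) * ir    # dropped tokens contribute 0 here
        kl = float(lm) * log_ir ** 2               # KL term uses loss_mask, not keep
        total += -pg + kl_tau * kl
    return total


# Same numbers as the sign-conditioning test: trainer_lp = -10, inference_lp = 0.
kept = reference_loss([-10.0], [0.0], [1.0], [True],
                      dppo_mask_high=0.2, dppo_mask_low=0.2)
dropped = reference_loss([-10.0], [0.0], [-1.0], [True],
                         dppo_mask_high=0.2, dppo_mask_low=0.2)
print(kept, dropped)
```

With a positive advantage the token is kept, so the loss is `kl_tau * 100 - exp(-10)`; with a negative advantage the low bound drops the policy-gradient term and only the KL penalty `kl_tau * 100` remains.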
@@ -5,16 +5,13 @@ pluggable losses (SimPO, TAID, Entropy-Aware OPD). They use a tiny
|
|
| 5 |
hand-rolled language model wrapper (no HF, no TRL) so the tests run
|
| 6 |
in <1s on CPU and are isolated from external library churn.
|
| 7 |
|
| 8 |
-
Coverage requirements
|
| 9 |
(a) defaults reproduce existing compose_loss output bit-exact
|
| 10 |
(b) dpo_variant='simpo' produces a different total than dpo
|
| 11 |
-
(c) sdpo_wrapper='taid' with
|
| 12 |
-
|
| 13 |
-
(d) sdpo_wrapper='taid' interpolates as expected when
|
| 14 |
-
schedule_step=total_steps/2
|
| 15 |
(e) sdpo_wrapper='entropy_opd' returns a finite differentiable scalar
|
| 16 |
-
(f) error case: sdpo_wrapper='taid' without
|
| 17 |
-
ValueError
|
| 18 |
"""
|
| 19 |
from __future__ import annotations
|
| 20 |
|
|
@@ -194,120 +191,102 @@ def test_simpo_does_not_require_ref_logprobs():
|
|
| 194 |
|
| 195 |
|
| 196 |
# ----------------------------------------------------------------------
|
| 197 |
-
# (c) TAID with
|
| 198 |
# ----------------------------------------------------------------------
|
| 199 |
|
| 200 |
-
def
|
| 201 |
-
"""
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
(modulo the softmax→log roundtrip in `taid_blended_logits`, which is
|
| 205 |
-
bit-equivalent for finite logits).
|
| 206 |
"""
|
| 207 |
-
|
| 208 |
|
| 209 |
-
|
| 210 |
-
out_sdpo = compose_loss(
|
| 211 |
-
model_a, inputs,
|
| 212 |
-
alpha_sdpo=0.1,
|
| 213 |
-
beta_replay=0.0, # disable channel 3 so we isolate channel 2
|
| 214 |
-
sdpo_wrapper="none",
|
| 215 |
-
)
|
| 216 |
|
| 217 |
-
|
| 218 |
-
# Provide a student_init_logits snapshot — for α=1 its value doesn't
|
| 219 |
-
# affect the blended target (P_blended = teacher when α=1), so any
|
| 220 |
-
# valid-shape tensor works. Use the teacher shape.
|
| 221 |
-
with torch.no_grad():
|
| 222 |
-
init_logits = model_b(input_ids=inputs["ctx_teacher_input_ids"]).logits.clone()
|
| 223 |
-
inputs_taid = dict(inputs)
|
| 224 |
-
inputs_taid["student_init_logits"] = init_logits
|
| 225 |
|
|
|
|
| 226 |
out_taid = compose_loss(
|
| 227 |
-
|
| 228 |
-
alpha_sdpo=
|
| 229 |
beta_replay=0.0,
|
| 230 |
sdpo_wrapper="taid",
|
| 231 |
-
|
| 232 |
-
taid_total_steps=100,
|
| 233 |
-
taid_alpha_min=1.0,
|
| 234 |
-
taid_alpha_max=1.0,
|
| 235 |
)
|
| 236 |
|
| 237 |
-
#
|
| 238 |
-
|
| 239 |
-
|
| 240 |
|
| 241 |
|
| 242 |
# ----------------------------------------------------------------------
|
| 243 |
-
# (d) TAID interpolates
|
| 244 |
# ----------------------------------------------------------------------
|
| 245 |
|
| 246 |
-
def
|
| 247 |
-
"""
|
| 248 |
-
alpha_max=1, the schedule yields α=0.5. The resulting loss must
|
| 249 |
-
differ from both endpoints (α=0 → init-only target, α=1 → pure SDPO),
|
| 250 |
-
and must be finite + differentiable.
|
| 251 |
-
"""
|
| 252 |
inputs = _base_batch(with_dpo=False)
|
| 253 |
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
# α=1 would both target the same distribution and the test would
|
| 258 |
-
# become vacuous).
|
| 259 |
-
snapshot_model = _model_seeded(seed=99)
|
| 260 |
-
with torch.no_grad():
|
| 261 |
-
init_logits = snapshot_model(
|
| 262 |
-
input_ids=inputs["ctx_teacher_input_ids"]
|
| 263 |
-
).logits.clone()
|
| 264 |
-
inputs = dict(inputs)
|
| 265 |
-
inputs["student_init_logits"] = init_logits
|
| 266 |
-
|
| 267 |
-
# Endpoint α=1 (pure SDPO target — init_logits ignored)
|
| 268 |
-
model_end = _model_seeded(seed=2)
|
| 269 |
-
out_alpha_one = compose_loss(
|
| 270 |
-
model_end, inputs,
|
| 271 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 272 |
sdpo_wrapper="taid",
|
| 273 |
-
|
| 274 |
-
taid_alpha_min=0.0, taid_alpha_max=1.0,
|
| 275 |
)
|
| 276 |
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
model_start, inputs,
|
| 281 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 282 |
sdpo_wrapper="taid",
|
| 283 |
-
|
| 284 |
-
taid_alpha_min=0.0, taid_alpha_max=1.0,
|
| 285 |
)
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
model_mid, inputs,
|
| 291 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 292 |
sdpo_wrapper="taid",
|
| 293 |
-
|
| 294 |
-
taid_alpha_min=0.0, taid_alpha_max=1.0,
|
| 295 |
)
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
assert torch.isfinite(out.
|
| 300 |
-
assert torch.isfinite(out.sdpo_jsd), f"non-finite sdpo_jsd: {out.sdpo_jsd}"
|
| 301 |
|
| 302 |
-
|
| 303 |
-
assert not torch.allclose(
|
| 304 |
-
out_mid.sdpo_jsd, out_alpha_zero.sdpo_jsd, atol=1e-5
|
| 305 |
-
), "midpoint TAID matches α=0 endpoint — schedule not interpolating"
|
| 306 |
-
assert not torch.allclose(
|
| 307 |
-
out_mid.sdpo_jsd, out_alpha_one.sdpo_jsd, atol=1e-5
|
| 308 |
-
), "midpoint TAID matches α=1 endpoint — schedule not interpolating"
|
| 309 |
|
| 310 |
-
# Differentiable.
|
| 311 |
out_mid.total.backward()
|
| 312 |
assert any(
|
| 313 |
p.grad is not None and torch.isfinite(p.grad).all()
|
|
@@ -343,32 +322,30 @@ def test_entropy_opd_returns_finite_differentiable_scalar():
|
|
| 343 |
|
| 344 |
|
| 345 |
# ----------------------------------------------------------------------
|
| 346 |
-
# (f) Error: sdpo_wrapper='taid' without
|
| 347 |
# ----------------------------------------------------------------------
|
| 348 |
|
| 349 |
-
def
|
| 350 |
inputs = _base_batch(with_dpo=False)
|
| 351 |
model = _model_seeded(seed=4)
|
| 352 |
-
with pytest.raises(ValueError, match="
|
| 353 |
compose_loss(
|
| 354 |
model, inputs,
|
| 355 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 356 |
sdpo_wrapper="taid",
|
| 357 |
-
|
| 358 |
-
# taid_schedule_step omitted on purpose
|
| 359 |
)
|
| 360 |
|
| 361 |
|
| 362 |
-
def
|
| 363 |
inputs = _base_batch(with_dpo=False)
|
| 364 |
model = _model_seeded(seed=4)
|
| 365 |
-
with pytest.raises(ValueError, match="
|
| 366 |
compose_loss(
|
| 367 |
model, inputs,
|
| 368 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 369 |
sdpo_wrapper="taid",
|
| 370 |
-
|
| 371 |
-
# taid_total_steps omitted on purpose
|
| 372 |
)
|
| 373 |
|
| 374 |
|
|
@@ -393,24 +370,28 @@ def test_invalid_sdpo_wrapper_raises():
|
|
| 393 |
|
| 394 |
|
| 395 |
# ----------------------------------------------------------------------
|
| 396 |
-
# Bonus:
|
| 397 |
# ----------------------------------------------------------------------
|
| 398 |
|
| 399 |
-
def
|
| 400 |
-
"""
|
| 401 |
-
|
|
|
|
| 402 |
inputs = _base_batch(with_dpo=False)
|
| 403 |
model = _model_seeded(seed=6)
|
|
|
|
| 404 |
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
taid_schedule_step=10, taid_total_steps=100,
|
| 415 |
-
)
|
| 416 |
-
assert torch.isfinite(out.total)
|
|
|
|
| 5 |
hand-rolled language model wrapper (no HF, no TRL) so the tests run
|
| 6 |
in <1s on CPU and are isolated from external library churn.
|
| 7 |
|
| 8 |
+
Coverage requirements:
|
| 9 |
(a) defaults reproduce existing compose_loss output bit-exact
|
| 10 |
(b) dpo_variant='simpo' produces a different total than dpo
|
| 11 |
+
(c) sdpo_wrapper='taid' with t=1 reproduces upstream forward-KL
|
| 12 |
+
(d) sdpo_wrapper='taid' with t=0 differs from t=1 (interpolation works)
|
|
|
|
|
|
|
| 13 |
(e) sdpo_wrapper='entropy_opd' returns a finite differentiable scalar
|
| 14 |
+
(f) error case: sdpo_wrapper='taid' without taid_t raises ValueError
|
|
|
|
| 15 |
"""
|
| 16 |
from __future__ import annotations
|
| 17 |
|
|
|
|
| 191 |
|
| 192 |
|
| 193 |
# ----------------------------------------------------------------------
|
| 194 |
+
# (c) TAID with t=1 reproduces upstream forward-KL on the masked tokens
|
| 195 |
# ----------------------------------------------------------------------
|
| 196 |
|
| 197 |
+
def test_taid_t_one_matches_upstream_forward_kl():
|
| 198 |
+
"""At t=1, taid_loss reduces to forward-KL with target = softmax(teacher).
|
| 199 |
+
compose_loss should plumb through to that exact value (modulo the
|
| 200 |
+
sdpo_loss_mask token-mean denominator).
|
|
|
|
|
|
|
| 201 |
"""
|
| 202 |
+
import torch.nn.functional as F
|
| 203 |
|
| 204 |
+
inputs = _base_batch(with_dpo=False)
|
| 205 |
|
| 206 |
+
model = _model_seeded(seed=1)
|
| 207 |
|
| 208 |
+
# Run compose_loss with TAID at t=1.
|
| 209 |
out_taid = compose_loss(
|
| 210 |
+
model, inputs,
|
| 211 |
+
alpha_sdpo=1.0, # so out.sdpo_jsd is added straight to total
|
| 212 |
beta_replay=0.0,
|
| 213 |
sdpo_wrapper="taid",
|
| 214 |
+
taid_t=1.0,
|
| 215 |
)
|
| 216 |
|
| 217 |
+
# Manually compute the same forward-KL on the masked tokens.
|
| 218 |
+
student_logits = model(input_ids=inputs["input_ids"]).logits
|
| 219 |
+
with torch.no_grad():
|
| 220 |
+
teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
|
| 221 |
+
mask = inputs["sdpo_loss_mask"].float()
|
| 222 |
+
p_teacher = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
|
| 223 |
+
log_q = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
|
| 224 |
+
per_token = -(p_teacher * log_q).sum(dim=-1)
|
| 225 |
+
flat = per_token.reshape(-1)
|
| 226 |
+
fmask = mask.reshape(-1).to(flat.dtype)
|
| 227 |
+
expected = (flat * fmask).sum() / fmask.sum().clamp_min(1.0)
|
| 228 |
+
|
| 229 |
+
# Bit-exact assertion. The TAID-loss path at t=1 is mathematically
|
| 230 |
+
# identical to the manual `-(p_teacher * log_q).sum(...)` cross-entropy
|
| 231 |
+
# above: at t=1, TAID's logit-space mix collapses to `teacher_logits`,
|
| 232 |
+
# `softmax(teacher_logits)` is computed bit-identically inside
|
| 233 |
+
# `taid_loss`, and the masked-mean reduction matches. So `torch.equal`
|
| 234 |
+
# succeeds — and asserting `equal` rather than `allclose` catches any
|
| 235 |
+
# future refactor that re-introduces a softmax→log roundtrip with
|
| 236 |
+
# ULP drift.
|
| 237 |
+
#
|
| 238 |
+
# If a future change forces a roundtrip we cannot eliminate, drop to
|
| 239 |
+
# `torch.testing.assert_close(out_taid.sdpo_jsd, expected,
|
| 240 |
+
# atol=1e-7, rtol=0)` — that is the strict-but-feasible bound for
|
| 241 |
+
# softmax→log→softmax in float32 (one ULP at the scale of the loss,
|
| 242 |
+
# ~3.5e-7 here, dominated by the log_softmax LSE accumulation).
|
| 243 |
+
assert torch.equal(out_taid.sdpo_jsd, expected), (
|
| 244 |
+
f"TAID t=1 must equal upstream forward-KL bit-exact; "
|
| 245 |
+
f"got out={out_taid.sdpo_jsd.item()!r}, "
|
| 246 |
+
f"expected={expected.item()!r}, "
|
| 247 |
+
f"diff={(out_taid.sdpo_jsd - expected).abs().item():.3e}"
|
| 248 |
+
)
|
| 249 |
|
| 250 |
|
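The claim checked by the test above, that the t=1 endpoint collapses to a teacher-only target, can be illustrated on a single token in plain Python. This is a sketch under stated assumptions, not the framework's `taid_loss`: it assumes the blend is `(1 - t) * student_logits + t * teacher_logits` in logit space, treats the blended target as a constant, and uses hypothetical helper names and toy logits.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def taid_style_loss(student_logits, teacher_logits, t):
    """Cross-entropy of the student against the softmax of a logit-space blend."""
    blended = [(1.0 - t) * s + t * te for s, te in zip(student_logits, teacher_logits)]
    target = softmax(blended)  # a real implementation would detach this target
    log_q = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(target, log_q))


student = [2.0, 0.5, -1.0]
teacher = [-1.0, 0.0, 3.0]

loss_zero = taid_style_loss(student, teacher, 0.0)
loss_mid = taid_style_loss(student, teacher, 0.5)
loss_one = taid_style_loss(student, teacher, 1.0)

# Manual teacher-target cross-entropy, the quantity the t=1 test compares against.
manual = -sum(p * lq for p, lq in zip(softmax(teacher), log_softmax(student)))
print(loss_zero, loss_mid, loss_one, manual)
```

At t=1 the blend is `0.0 * s + 1.0 * te`, which reproduces the teacher logits exactly for finite inputs, so the two values agree bit for bit in this sketch; intermediate t values give visibly different losses.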
| 251 |
# ----------------------------------------------------------------------
|
| 252 |
+
# (d) TAID interpolates: t=0 differs from t=1
|
| 253 |
# ----------------------------------------------------------------------
|
| 254 |
|
| 255 |
+
def test_taid_interpolates_with_t():
|
| 256 |
+
"""Different t values give different sdpo_jsd. Differentiable end-to-end."""
|
| 257 |
inputs = _base_batch(with_dpo=False)
|
| 258 |
|
| 259 |
+
model_zero = _model_seeded(seed=2)
|
| 260 |
+
out_zero = compose_loss(
|
| 261 |
+
model_zero, inputs,
|
| 262 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 263 |
sdpo_wrapper="taid",
|
| 264 |
+
taid_t=0.0,
|
|
|
|
| 265 |
)
|
| 266 |
|
| 267 |
+
model_mid = _model_seeded(seed=2)
|
| 268 |
+
out_mid = compose_loss(
|
| 269 |
+
model_mid, inputs,
|
|
|
|
| 270 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 271 |
sdpo_wrapper="taid",
|
| 272 |
+
taid_t=0.5,
|
|
|
|
| 273 |
)
|
| 274 |
|
| 275 |
+
model_one = _model_seeded(seed=2)
|
| 276 |
+
out_one = compose_loss(
|
| 277 |
+
model_one, inputs,
|
|
|
|
| 278 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 279 |
sdpo_wrapper="taid",
|
| 280 |
+
taid_t=1.0,
|
|
|
|
| 281 |
)
|
| 282 |
|
| 283 |
+
for out in (out_zero, out_mid, out_one):
|
| 284 |
+
assert torch.isfinite(out.total)
|
| 285 |
+
assert torch.isfinite(out.sdpo_jsd)
|
|
|
|
| 286 |
|
| 287 |
+
assert not torch.allclose(out_zero.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5)
|
| 288 |
+
assert not torch.allclose(out_mid.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5)
|
| 289 |
|
|
|
|
| 290 |
out_mid.total.backward()
|
| 291 |
assert any(
|
| 292 |
p.grad is not None and torch.isfinite(p.grad).all()
|
|
|
|
| 322 |
|
| 323 |
|
| 324 |
# ----------------------------------------------------------------------
|
| 325 |
+
# (f) Error: sdpo_wrapper='taid' without taid_t
|
| 326 |
# ----------------------------------------------------------------------
|
| 327 |
|
| 328 |
+
def test_taid_requires_t():
|
| 329 |
inputs = _base_batch(with_dpo=False)
|
| 330 |
model = _model_seeded(seed=4)
|
| 331 |
+
with pytest.raises(ValueError, match="taid_t"):
|
| 332 |
compose_loss(
|
| 333 |
model, inputs,
|
| 334 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 335 |
sdpo_wrapper="taid",
|
| 336 |
+
# taid_t omitted on purpose
|
|
|
|
| 337 |
)
|
| 338 |
|
| 339 |
|
| 340 |
+
def test_taid_t_out_of_range_raises():
|
| 341 |
inputs = _base_batch(with_dpo=False)
|
| 342 |
model = _model_seeded(seed=4)
|
| 343 |
+
with pytest.raises(ValueError, match=r"taid_t must be in \[0, 1\]"):
|
| 344 |
compose_loss(
|
| 345 |
model, inputs,
|
| 346 |
alpha_sdpo=0.1, beta_replay=0.0,
|
| 347 |
sdpo_wrapper="taid",
|
| 348 |
+
taid_t=1.5,
|
|
|
|
| 349 |
)
|
| 350 |
|
| 351 |
|
|
|
|
| 370 |
|
| 371 |
|
| 372 |
# ----------------------------------------------------------------------
|
| 373 |
+
# Bonus: TAIDScheduler integration
|
| 374 |
# ----------------------------------------------------------------------
|
| 375 |
|
| 376 |
+
def test_taid_compose_with_scheduler():
|
| 377 |
+
"""End-to-end: TAIDScheduler drives taid_t into compose_loss."""
|
| 378 |
+
from composer_replication.distillation import TAIDScheduler
|
| 379 |
+
|
| 380 |
inputs = _base_batch(with_dpo=False)
|
| 381 |
model = _model_seeded(seed=6)
|
| 382 |
+
sched = TAIDScheduler(num_train_steps=100, t_start=0.4)
|
| 383 |
|
| 384 |
+
for step in range(3):
|
| 385 |
+
out = compose_loss(
|
| 386 |
+
model, inputs,
|
| 387 |
+
alpha_sdpo=0.1, beta_replay=0.0,
|
| 388 |
+
sdpo_wrapper="taid",
|
| 389 |
+
taid_t=sched.t,
|
| 390 |
+
)
|
| 391 |
+
assert torch.isfinite(out.total)
|
| 392 |
+
sched.update_t(out.sdpo_jsd.detach(), global_step=step)
|
| 393 |
|
| 394 |
+
# t may have advanced past t_start after some steps (or stayed the same
|
| 395 |
+
# given small num_train_steps and only 3 iters; just check it's still
|
| 396 |
+
# in-range).
|
| 397 |
+
assert 0.4 <= sched.t <= 1.0
|
|
|
|
|
|
|
|
|
|
@@ -0,0 +1,153 @@
|
| 1 |
+
"""Numerical parity test against the upstream OPSD reference.
|
| 2 |
+
|
| 3 |
+
Loads `OPSDTrainer.generalized_jsd_loss` from a clone of siyan-zhao/OPSD at
|
| 4 |
+
/tmp/opsd-clone (override with $OPSD_CLONE) and asserts our re-implementation
|
| 5 |
+
in `composer_replication.opsd` matches it byte-for-byte across a grid of
|
| 6 |
+
shapes and β values. Skips cleanly when the upstream clone is absent.
|
| 7 |
+
|
| 8 |
+
Why this lives in `tests/` rather than docs: numerical parity is the
|
| 9 |
+
contract for this lift. If a future refactor of `generalized_jsd_loss`
|
| 10 |
+
silently shifts gradients again, this test fails immediately.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import importlib.util
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import pytest
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from composer_replication.opsd import generalized_jsd_loss
|
| 24 |
+
|
| 25 |
+
# ----------------------------------------------------------------------
|
| 26 |
+
# Locate upstream OPSDTrainer.generalized_jsd_loss
|
| 27 |
+
# ----------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
_OPSD_CLONE = Path(os.environ.get("OPSD_CLONE", "/tmp/opsd-clone"))
|
| 30 |
+
_OPSD_TRAINER_PATH = _OPSD_CLONE / "opsd_trainer.py"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _load_upstream():
|
| 34 |
+
"""Import OPSDTrainer.generalized_jsd_loss from a local clone, isolated.
|
| 35 |
+
|
| 36 |
+
The upstream `opsd_trainer.py` imports heavyweight TRL / transformers
|
| 37 |
+
machinery at module scope, which we do not want to drag into the test
|
| 38 |
+
process. We instead extract the static method by parsing the source
|
| 39 |
+
text and exec-ing only that function body — it depends only on
|
| 40 |
+
`torch` and `torch.nn.functional`, which are already importable.
|
| 41 |
+
"""
|
| 42 |
+
if not _OPSD_TRAINER_PATH.exists():
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
text = _OPSD_TRAINER_PATH.read_text()
|
| 46 |
+
# Pull out the function block. It starts with `def generalized_jsd_loss(`
|
| 47 |
+
# under `class OPSDTrainer` and ends at the next top-of-class `def `.
|
| 48 |
+
start = text.find("def generalized_jsd_loss(")
|
| 49 |
+
if start < 0:
|
| 50 |
+
return None
|
| 51 |
+
# Walk forward to the start of the next sibling method (4-space indent
|
| 52 |
+
# `def ` or class-end) — they all start with exactly 4 spaces of indent.
|
| 53 |
+
rest = text[start:]
|
| 54 |
+
# Skip past the function header and find the next `\n def ` or
|
| 55 |
+
# `\n @staticmethod` boundary.
|
| 56 |
+
end_marker_offsets = []
|
| 57 |
+
for marker in ("\n @", "\n def ", "\nclass "):
|
| 58 |
+
idx = rest.find(marker, len("def generalized_jsd_loss("))
|
| 59 |
+
if idx > 0:
|
| 60 |
+
end_marker_offsets.append(idx)
|
| 61 |
+
if not end_marker_offsets:
|
| 62 |
+
return None
|
| 63 |
+
fn_text = rest[: min(end_marker_offsets)]
|
| 64 |
+
|
| 65 |
+
# Dedent (the source lines are 4-space indented as a class method).
|
| 66 |
+
fn_text = "\n".join(
|
| 67 |
+
line[4:] if line.startswith(" ") else line for line in fn_text.splitlines()
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Exec into a fresh namespace with torch + F available.
|
| 71 |
+
import torch.nn.functional as F # noqa: F401 (used by exec'd code)
|
| 72 |
+
|
| 73 |
+
namespace: dict = {"torch": torch, "F": F}
|
| 74 |
+
exec(compile(fn_text, str(_OPSD_TRAINER_PATH), "exec"), namespace)
|
| 75 |
+
fn = namespace.get("generalized_jsd_loss")
|
| 76 |
+
return fn
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
_UPSTREAM_FN = _load_upstream()
|
| 80 |
+
_SKIP_REASON = (
|
| 81 |
+
f"upstream OPSD clone not found at {_OPSD_TRAINER_PATH} "
|
| 82 |
+
f"(set $OPSD_CLONE or `git clone --depth 1 https://github.com/siyan-zhao/OPSD {_OPSD_CLONE}`)"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ----------------------------------------------------------------------
|
| 87 |
+
# Parity grid
|
| 88 |
+
# ----------------------------------------------------------------------
|
| 89 |
+
|
| 90 |
+
_SHAPES = [
|
| 91 |
+
(1, 4, 16),
|
| 92 |
+
(2, 8, 32),
|
| 93 |
+
(3, 5, 64),
|
| 94 |
+
(1, 16, 8),
|
| 95 |
+
(4, 3, 24),
|
| 96 |
+
]
|
| 97 |
+
_BETAS = [0.0, 0.5, 1.0]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
|
| 101 |
+
@pytest.mark.parametrize("shape", _SHAPES)
|
| 102 |
+
@pytest.mark.parametrize("beta", _BETAS)
|
| 103 |
+
def test_parity_unmasked(shape, beta):
|
| 104 |
+
"""Our `generalized_jsd_loss` must match upstream within 1e-5 atol."""
|
| 105 |
+
B, T, V = shape
|
| 106 |
+
g = torch.Generator().manual_seed(13 + B * 31 + T * 17 + V)
|
| 107 |
+
student = torch.randn(B, T, V, generator=g, dtype=torch.float64)
|
| 108 |
+
teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64)
|
| 109 |
+
|
| 110 |
+
ours = generalized_jsd_loss(student, teacher, beta=beta)
|
| 111 |
+
theirs = _UPSTREAM_FN(student, teacher, beta=beta) # type: ignore[misc]
|
| 112 |
+
|
| 113 |
+
assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
|
| 114 |
+
f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
|
| 119 |
+
@pytest.mark.parametrize("shape", _SHAPES)
|
| 120 |
+
@pytest.mark.parametrize("beta", _BETAS)
|
| 121 |
+
def test_parity_masked(shape, beta):
|
| 122 |
+
"""Same parity but with a labels mask that ignores ~half the tokens."""
|
| 123 |
+
B, T, V = shape
|
| 124 |
+
g = torch.Generator().manual_seed(101 + B * 7 + T * 11 + V)
|
| 125 |
+
student = torch.randn(B, T, V, generator=g, dtype=torch.float64)
|
| 126 |
+
teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64)
|
| 127 |
+
# Random valid/ignored mask: -100 for ignored, anything else for valid.
|
| 128 |
+
labels = torch.randint(0, 2, (B, T), generator=g)
|
| 129 |
+
labels = torch.where(labels == 0, torch.full_like(labels, -100), labels)
|
| 130 |
+
|
| 131 |
+
ours = generalized_jsd_loss(student, teacher, labels=labels, beta=beta)
|
| 132 |
+
theirs = _UPSTREAM_FN(student, teacher, labels=labels, beta=beta) # type: ignore[misc]
|
| 133 |
+
|
| 134 |
+
assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
|
| 135 |
+
f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
|
| 140 |
+
def test_parity_temperature_and_topk():
|
| 141 |
+
"""Spot-check the temperature + top_k branches against upstream."""
|
| 142 |
+
g = torch.Generator().manual_seed(42)
|
| 143 |
+
student = torch.randn(2, 6, 32, generator=g, dtype=torch.float64)
|
| 144 |
+
teacher = torch.randn(2, 6, 32, generator=g, dtype=torch.float64)
|
| 145 |
+
|
| 146 |
+
for beta in (0.0, 0.3, 0.5, 0.7, 1.0):
|
| 147 |
+
ours = generalized_jsd_loss(student, teacher, beta=beta, temperature=2.0, top_k=8)
|
| 148 |
+
theirs = _UPSTREAM_FN( # type: ignore[misc]
|
| 149 |
+
student, teacher, beta=beta, temperature=2.0, top_k=8
|
| 150 |
+
)
|
| 151 |
+
assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
|
| 152 |
+
f"temp+topk parity failed at beta={beta}: ours={ours.item()} theirs={theirs.item()}"
|
| 153 |
+
)
|
|
@@ -32,11 +32,15 @@ import torch
|
|
| 32 |
import torch.nn.functional as F
|
| 33 |
|
| 34 |
# These imports work when TRL is installed — they're not skeleton imports.
|
| 35 |
-
#
|
|
|
|
|
|
|
| 36 |
try:
|
| 37 |
from trl import GRPOTrainer # type: ignore
|
|
|
|
| 38 |
except ImportError: # pragma: no cover — only hit in unit-test stubs without TRL
|
| 39 |
GRPOTrainer = object # type: ignore — fallback so module imports without TRL
|
|
|
|
| 40 |
|
| 41 |
from composer_replication.opsd import generalized_jsd_loss
|
| 42 |
|
|
@@ -47,11 +51,15 @@ class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
|
|
| 47 |
"""TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.
|
| 48 |
|
| 49 |
Args (in addition to GRPOTrainer's):
|
| 50 |
-
alpha_sdpo: weight on SDPO hint-distill loss.
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
|
| 56 |
sdpo_token_clip: per-token JSD clip for stability; None = no clip.
|
| 57 |
replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).
|
|
@@ -60,14 +68,19 @@ class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
|
|
| 60 |
def __init__(
|
| 61 |
self,
|
| 62 |
*args: Any,
|
| 63 |
-
alpha_sdpo: float = 0.
|
| 64 |
-
beta_replay: float = 0.
|
| 65 |
sdpo_jsd_beta: float = 0.5,
|
| 66 |
sdpo_temperature: float = 1.0,
|
| 67 |
sdpo_token_clip: float | None = None,
|
| 68 |
replay_dpo_beta: float = 0.1,
|
| 69 |
**kwargs: Any,
|
| 70 |
):
|
| 71 |
super().__init__(*args, **kwargs)
|
| 72 |
self.alpha_sdpo = alpha_sdpo
|
| 73 |
self.beta_replay = beta_replay
|
|
|
|
| 32 |
import torch.nn.functional as F
|
| 33 |
|
| 34 |
# These imports work when TRL is installed — they're not skeleton imports.
|
| 35 |
+
# When TRL is missing we fall back to `object` so the module still imports
|
| 36 |
+
# (e.g. for documentation generation) but raise a clear ImportError at
|
| 37 |
+
# instantiation time rather than the cryptic `object.__init__()` error.
|
| 38 |
try:
|
| 39 |
from trl import GRPOTrainer # type: ignore
|
| 40 |
+
_TRL_AVAILABLE = True
|
| 41 |
except ImportError: # pragma: no cover — only hit in unit-test stubs without TRL
|
| 42 |
GRPOTrainer = object # type: ignore — fallback so module imports without TRL
|
| 43 |
+
_TRL_AVAILABLE = False
|
| 44 |
|
| 45 |
from composer_replication.opsd import generalized_jsd_loss
|
| 46 |
|
|
|
|
| 51 |
"""TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.
|
| 52 |
|
| 53 |
Args (in addition to GRPOTrainer's):
|
| 54 |
+
alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled).
|
| 55 |
+
Opt in by passing >0 once your data collator produces
|
| 56 |
+
`sdpo_loss_mask` and `ctx_teacher_input_ids` columns.
|
| 57 |
+
beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled).
|
| 58 |
+
Opt in by passing >0 once your data collator produces
|
| 59 |
+
`dpo_chosen_input_ids` / `dpo_rejected_input_ids` etc.
|
| 60 |
+
sdpo_jsd_beta: beta param of generalized_jsd_loss
|
| 61 |
+
(0=KL(teacher||student), 0.5=JSD, 1=KL(student||teacher) per
|
| 62 |
+
upstream OPSD convention; see composer_replication/opsd.py).
|
| 63 |
sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
|
| 64 |
sdpo_token_clip: per-token JSD clip for stability; None = no clip.
|
| 65 |
replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).
|
|
|
|
| 68 |
def __init__(
|
| 69 |
self,
|
| 70 |
*args: Any,
|
| 71 |
+
alpha_sdpo: float = 0.0,
|
| 72 |
+
beta_replay: float = 0.0,
|
| 73 |
sdpo_jsd_beta: float = 0.5,
|
| 74 |
sdpo_temperature: float = 1.0,
|
| 75 |
sdpo_token_clip: float | None = None,
|
| 76 |
replay_dpo_beta: float = 0.1,
|
| 77 |
**kwargs: Any,
|
| 78 |
):
|
| 79 |
+
if not _TRL_AVAILABLE:
|
| 80 |
+
raise ImportError(
|
| 81 |
+
"ComposerReplicationTrainer requires TRL. Install with "
|
| 82 |
+
"`pip install -e .[train]`."
|
| 83 |
+
)
|
| 84 |
super().__init__(*args, **kwargs)
|
| 85 |
self.alpha_sdpo = alpha_sdpo
|
| 86 |
self.beta_replay = beta_replay
|
|
@@ -118,13 +118,9 @@ def compose_loss(
|
|
| 118 |
lm_ce_label_smoothing: float = 0.0,
|
| 119 |
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 120 |
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 121 |
-
|
| 122 |
-
taid_total_steps: int | None = None,
|
| 123 |
simpo_beta: float = 2.0,
|
| 124 |
simpo_gamma: float = 1.0,
|
| 125 |
-
taid_schedule: str = "linear",
|
| 126 |
-
taid_alpha_min: float = 0.0,
|
| 127 |
-
taid_alpha_max: float = 1.0,
|
| 128 |
entropy_opd_h_max: float | None = None,
|
| 129 |
) -> LossComponents
|
| 130 |
```
|
|
@@ -141,7 +137,7 @@ Compute `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`
|
|
| 141 |
- SDPO: `ctx_teacher_input_ids` `(B, T_t)`, `sdpo_loss_mask` `(B, T_t)`.
|
| 142 |
- DPO (`dpo_variant="dpo"`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs` (precomputed).
|
| 143 |
- SimPO (`dpo_variant="simpo"`): same DPO ids/masks; reference logprobs are silently ignored.
|
| 144 |
-
- TAID (`sdpo_wrapper="taid"`): `
|
| 145 |
|
| 146 |
**Parameters**
|
| 147 |
|
|
@@ -151,25 +147,21 @@ Compute `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`
|
|
| 151 |
| `inputs` | `dict[str, torch.Tensor]` | — | Batch dict (see required/optional keys above). |
|
| 152 |
| `alpha_sdpo` | `float` | `0.1` | Weight on SDPO/JSD channel. `0.0` disables. |
|
| 153 |
| `beta_replay` | `float` | `0.05` | Weight on trace-replay DPO channel. `0.0` disables. |
|
| 154 |
-
| `sdpo_jsd_beta` | `float` | `0.5` | β param for `generalized_jsd_loss` (0=fwd KL, 0.5=JSD, 1=rev KL). |
|
| 155 |
-
| `sdpo_temperature` | `float` | `1.0` | Softmax temperature in SDPO. |
|
| 156 |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clamp. |
|
| 157 |
| `replay_dpo_beta` | `float` | `0.1` | β in standard DPO logit. |
|
| 158 |
| `lm_ce_label_smoothing` | `float` | `0.0` | `F.cross_entropy(label_smoothing=)`. |
|
| 159 |
| `dpo_variant` | `Literal["dpo","simpo"]` | `"dpo"` | Channel-3 algorithm. |
|
| 160 |
| `sdpo_wrapper` | `Literal["none","taid","entropy_opd"]` | `"none"` | Channel-2 wrapper. |
|
| 161 |
-
| `
|
| 162 |
-
| `taid_total_steps` | `int \| None` | `None` | Required when `sdpo_wrapper="taid"`. |
|
| 163 |
| `simpo_beta` | `float` | `2.0` | SimPO β (paper default). |
|
| 164 |
| `simpo_gamma` | `float` | `1.0` | SimPO target margin γ (paper default). |
|
| 165 |
-
| `taid_schedule` | `str` | `"linear"` | One of `"linear"`, `"cosine"`, `"exp"`. |
|
| 166 |
-
| `taid_alpha_min` | `float` | `0.0` | Lower α bound. |
|
| 167 |
-
| `taid_alpha_max` | `float` | `1.0` | Upper α bound. |
|
| 168 |
| `entropy_opd_h_max` | `float \| None` | `None` | Max-entropy normalizer; `None` ⇒ `log(V)`. |
|
| 169 |
|
| 170 |
**Returns** `LossComponents` (see above).
|
| 171 |
|
| 172 |
-
**Raises** `ValueError` if `dpo_variant` or `sdpo_wrapper` is unknown, if `sdpo_wrapper="taid"` is requested without
|
| 173 |
|
| 174 |
```python
|
| 175 |
from composer_replication import compose_loss, build_batch
|
|
@@ -331,89 +323,75 @@ lp = torch.randn(2, 8); m = torch.tensor([[0,0,1,1,1,0,0,0],[0,1,1,1,1,1,0,0]])
|
|
| 331 |
out = avg_sequence_logprob(lp, m) # shape (2,)
|
| 332 |
```
|
| 333 |
|
| 334 |
-
### `taid_loss(student_logits, teacher_logits,
|
| 335 |
|
| 336 |
```python
|
| 337 |
def taid_loss(
|
| 338 |
student_logits: torch.Tensor,
|
| 339 |
teacher_logits: torch.Tensor,
|
| 340 |
-
|
| 341 |
*,
|
| 342 |
-
|
| 343 |
-
total_steps: int,
|
| 344 |
-
schedule: str = "linear",
|
| 345 |
-
alpha_min: float = 0.0,
|
| 346 |
-
alpha_max: float = 1.0,
|
| 347 |
-
jsd_beta: float = 0.5,
|
| 348 |
-
temperature: float = 1.0,
|
| 349 |
-
reduction: str = "batchmean",
|
| 350 |
) -> torch.Tensor
|
| 351 |
```
|
| 352 |
|
| 353 |
-
|
| 354 |
|
| 355 |
**Parameters**
|
| 356 |
|
| 357 |
| Name | Type | Default | Meaning |
|
| 358 |
|---|---|---|---|
|
| 359 |
-
| `student_logits` | `Tensor (B,T,V)` | — | Current student (with grad). |
|
| 360 |
-
| `teacher_logits` | `Tensor (B,T,V)` | — | Teacher logits
|
| 361 |
-
| `
|
| 362 |
-
| `
|
| 363 |
-
| `total_steps` | `int` | — | Total planned steps. |
|
| 364 |
-
| `schedule` | `str` | `"linear"` | One of `"linear"`, `"cosine"`, `"exp"`. |
|
| 365 |
-
| `alpha_min`, `alpha_max` | `float`, `float` | `0.0`, `1.0` | Schedule range. |
|
| 366 |
-
| `jsd_beta` | `float` | `0.5` | β param of `generalized_jsd_loss`. |
|
| 367 |
-
| `temperature` | `float` | `1.0` | Softmax temperature. |
|
| 368 |
-
| `reduction` | `str` | `"batchmean"` | Forwarded to `generalized_jsd_loss`. |
|
| 369 |
|
| 370 |
-
**Raises** `ValueError` for
|
| 371 |
|
| 372 |
```python
|
| 373 |
from composer_replication.distillation import taid_loss
|
| 374 |
-
loss = taid_loss(s_logits, t_logits,
|
| 375 |
-
schedule_step=500, total_steps=10_000, schedule="linear")
|
| 376 |
-
```
|
| 377 |
-
|
| 378 |
-
### `taid_alpha_schedule(step, total_steps, *, schedule="linear", alpha_min=0.0, alpha_max=1.0, warmup_frac=0.0) -> float`
|
| 379 |
-
|
| 380 |
-
```python
|
| 381 |
-
def taid_alpha_schedule(
|
| 382 |
-
step: int, total_steps: int, *,
|
| 383 |
-
schedule: str = "linear",
|
| 384 |
-
alpha_min: float = 0.0,
|
| 385 |
-
alpha_max: float = 1.0,
|
| 386 |
-
warmup_frac: float = 0.0,
|
| 387 |
-
) -> float
|
| 388 |
```
|
| 389 |
|
| 390 |
-
|
| 391 |
|
| 392 |
-
|
| 393 |
|
| 394 |
```python
|
| 395 |
-
from composer_replication.distillation
|
| 396 |
-
a = taid_alpha_schedule(step=500, total_steps=10000, schedule="cosine") # 0.012...
|
| 397 |
-
```
|
| 398 |
-
|
| 399 |
-
### `taid_blended_logits(student_init_logits, teacher_logits, alpha) -> torch.Tensor`
|
| 400 |
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
) -> torch.Tensor
|
| 407 |
```
|
| 408 |
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
**Raises** `ValueError` if `alpha` ∉ `[0,1]` or shapes differ.
|
| 412 |
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
```
|
| 417 |
|
| 418 |
### `entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor`
|
| 419 |
|
|
|
|
| 118 |
lm_ce_label_smoothing: float = 0.0,
|
| 119 |
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 120 |
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 121 |
+
taid_t: float | None = None,
|
|
|
|
| 122 |
simpo_beta: float = 2.0,
|
| 123 |
simpo_gamma: float = 1.0,
|
|
|
|
|
|
|
|
|
|
| 124 |
entropy_opd_h_max: float | None = None,
|
| 125 |
) -> LossComponents
|
| 126 |
```
|
|
|
|
| 137 |
- SDPO: `ctx_teacher_input_ids` `(B, T_t)`, `sdpo_loss_mask` `(B, T_t)`.
|
| 138 |
- DPO (`dpo_variant="dpo"`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs` (precomputed).
|
| 139 |
- SimPO (`dpo_variant="simpo"`): same DPO ids/masks; reference logprobs are silently ignored.
|
| 140 |
+
- TAID (`sdpo_wrapper="taid"`): no extra `inputs` keys needed; the optional `sdpo_loss_mask` is reused as the per-token TAID mask. Pass `taid_t` directly (or drive it from `TAIDScheduler`).
|
| 141 |
|
| 142 |
**Parameters**
|
| 143 |
|
|
|
|
| 147 |
| `inputs` | `dict[str, torch.Tensor]` | — | Batch dict (see required/optional keys above). |
|
| 148 |
| `alpha_sdpo` | `float` | `0.1` | Weight on SDPO/JSD channel. `0.0` disables. |
|
| 149 |
| `beta_replay` | `float` | `0.05` | Weight on trace-replay DPO channel. `0.0` disables. |
|
| 150 |
+
| `sdpo_jsd_beta` | `float` | `0.5` | β param for `generalized_jsd_loss` (0=fwd KL, 0.5=JSD, 1=rev KL). Unused when `sdpo_wrapper="taid"`. |
|
| 151 |
+
| `sdpo_temperature` | `float` | `1.0` | Softmax temperature in SDPO. Unused when `sdpo_wrapper="taid"`. |
|
| 152 |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clamp. |
|
| 153 |
| `replay_dpo_beta` | `float` | `0.1` | β in standard DPO logit. |
|
| 154 |
| `lm_ce_label_smoothing` | `float` | `0.0` | `F.cross_entropy(label_smoothing=)`. |
|
| 155 |
| `dpo_variant` | `Literal["dpo","simpo"]` | `"dpo"` | Channel-3 algorithm. |
|
| 156 |
| `sdpo_wrapper` | `Literal["none","taid","entropy_opd"]` | `"none"` | Channel-2 wrapper. |
|
| 157 |
+
| `taid_t` | `float \| None` | `None` | Current TAID interpolation coefficient in `[0, 1]`. Required when `sdpo_wrapper="taid"`. Drive from `TAIDScheduler` or pass a fixed value. |
|
|
|
|
| 158 |
| `simpo_beta` | `float` | `2.0` | SimPO β (paper default). |
|
| 159 |
| `simpo_gamma` | `float` | `1.0` | SimPO target margin γ (paper default). |
|
|
|
|
|
|
|
|
|
|
| 160 |
| `entropy_opd_h_max` | `float \| None` | `None` | Max-entropy normalizer; `None` ⇒ `log(V)`. |
|
| 161 |
|
| 162 |
**Returns** `LossComponents` (see above).
|
| 163 |
|
| 164 |
+
**Raises** `ValueError` if `dpo_variant` or `sdpo_wrapper` is unknown, if `sdpo_wrapper="taid"` is requested without `taid_t`, or if `taid_t` is outside `[0, 1]`.
|
| 165 |
|
| 166 |
```python
|
| 167 |
from composer_replication import compose_loss, build_batch
|
|
|
|
| 323 |
out = avg_sequence_logprob(lp, m) # shape (2,)
|
| 324 |
```
|
| 325 |
|
| 326 |
+
### `taid_loss(student_logits, teacher_logits, mask=None, *, t) -> torch.Tensor`
|
| 327 |
|
| 328 |
```python
|
| 329 |
def taid_loss(
|
| 330 |
student_logits: torch.Tensor,
|
| 331 |
teacher_logits: torch.Tensor,
|
| 332 |
+
mask: torch.Tensor | None = None,
|
| 333 |
*,
|
| 334 |
+
t: float | torch.Tensor,
|
| 335 |
) -> torch.Tensor
|
| 336 |
```
|
| 337 |
|
| 338 |
+
Faithful port of `SakanaAI/TAID` (arXiv:2501.16937). Forward-KL distillation against a logit-space-interpolated target whose anchor is the **current student detached**:
|
| 339 |
+
|
| 340 |
+
```
|
| 341 |
+
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
|
| 342 |
+
L = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
|
| 343 |
+
```
|
| 344 |
+
|
| 345 |
+
At `t=0` the target collapses to the detached student (no teacher signal in the gradient). At `t=1` it reduces to standard forward-KL distillation against the teacher.
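
The two formulas above can be sketched for plain Python lists of logits. This is an illustration only, not the `composer_replication` implementation: the helper names `taid_token_loss` and `taid_masked_mean` are hypothetical, and `stop_grad` has no counterpart here because nothing is differentiated.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one vocabulary vector."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def taid_token_loss(student_logits, teacher_logits, t):
    """Cross-entropy of the student against the logit-space interpolated target.

    Hypothetical helper: mixes logits (not probabilities), then scores the
    student's log-probabilities under the mixed target distribution p_t.
    """
    if len(student_logits) != len(teacher_logits):
        raise ValueError("student and teacher vocab sizes differ")
    mixed = [(1.0 - t) * s + t * q for s, q in zip(student_logits, teacher_logits)]
    p_t = softmax(mixed)
    log_q = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(p_t, log_q))


def taid_masked_mean(student, teacher, mask, t):
    """Masked mean over tokens; the denominator is clamped to at least 1."""
    losses = [taid_token_loss(s, q, t) for s, q in zip(student, teacher)]
    denom = max(sum(mask), 1.0)
    return sum(l * m for l, m in zip(losses, mask)) / denom
```

With `t=0.0` the per-token value is the entropy of the student's own distribution, and with `t=1.0` it is the cross-entropy of the student against `softmax(teacher_logits)`, which matches the two limiting cases described above.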
|
| 346 |
+
|
| 347 |
+
**Wave 15 breaking change.** The previous signature `taid_loss(student, teacher, student_init, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction)` was algorithmically wrong (probability-space mix, frozen step-0 anchor, JSD criterion). All those kwargs are removed; the schedule is now the caller's responsibility (see `TAIDScheduler` below for the upstream adaptive scheme).
|
| 348 |
|
| 349 |
**Parameters**
|
| 350 |
|
| 351 |
| Name | Type | Default | Meaning |
|
| 352 |
|---|---|---|---|
|
| 353 |
+
| `student_logits` | `Tensor (B, T, V)` | — | Current student (with grad). |
|
| 354 |
+
| `teacher_logits` | `Tensor (B, T, V)` | — | Teacher logits. |
|
| 355 |
+
| `mask` | `Tensor (B, T) \| None` | `None` | Token mask. `None` ⇒ all-ones. |
|
| 356 |
+
| `t` | `float \| Tensor` | — | Interpolation coefficient in `[0, 1]`. |
|
| 357 |
|
| 358 |
+
**Raises** `ValueError` for shape mismatch.
|
| 359 |
|
| 360 |
```python
|
| 361 |
from composer_replication.distillation import taid_loss
|
| 362 |
+
loss = taid_loss(s_logits, t_logits, mask, t=0.4)
|
| 363 |
```
|
| 364 |
|
| 365 |
+
### `TAIDScheduler(num_train_steps, *, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99, disable_adaptive=False)`
|
| 366 |
|
| 367 |
+
Stateful schedule that mirrors upstream `TAID.update_t`. Monotone non-decreasing, bumped above the linear floor by an EMA on the relative loss change. Use as:
|
| 368 |
|
| 369 |
```python
|
| 370 |
+
from composer_replication.distillation import TAIDScheduler
|
| 371 |
|
| 372 |
+
sched = TAIDScheduler(num_train_steps=10_000) # paper defaults
|
| 373 |
+
for step in range(num_train_steps):
|
| 374 |
+
loss = taid_loss(s, t, mask, t=sched.t)
|
| 375 |
+
loss.backward(); optimizer.step()
|
| 376 |
+
sched.update_t(loss.detach(), global_step=step)
|
|
|
|
| 377 |
```
|
| 378 |
|
| 379 |
+
**Parameters**
|
|
|
|
|
|
|
| 380 |
|
| 381 |
+
| Name | Type | Default | Meaning |
|
| 382 |
+
|---|---|---|---|
|
| 383 |
+
| `num_train_steps` | `int` | — | Total planned training steps; sets the linear floor. |
|
| 384 |
+
| `t_start` | `float` | `0.4` | Initial `t` (paper default). |
|
| 385 |
+
| `t_end` | `float` | `1.0` | Terminal `t`; hard ceiling at every step. |
|
| 386 |
+
| `alpha` | `float` | `5e-4` | Adaptive bump magnitude. |
|
| 387 |
+
| `beta` | `float` | `0.99` | EMA decay on relative-loss-change momentum. |
|
| 388 |
+
| `disable_adaptive` | `bool` | `False` | If True, fall back to deterministic linear schedule. |
|
| 389 |
+
| `device` | `torch.device \| str` | `"cpu"` | Where to allocate state buffers. |
|
| 390 |
+
|
| 391 |
+
**Properties / methods**
|
| 392 |
+
|
| 393 |
+
- `sched.t -> float` — current `t` as a Python float (zero-arg property).
|
| 394 |
+
- `sched.update_t(loss, global_step) -> Tensor | None` — update internal state. First finite-loss call only seeds `prev_loss` and returns `None`; subsequent calls return the (positive) `delta_t` added on top of the linear floor.
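
To make the documented behaviour concrete (linear floor, EMA on relative loss change, monotone `t`, ceiling at `t_end`, first call only seeding `prev_loss`), here is a rough pure-Python sketch. It is an assumption-laden stand-in, not the `TAIDScheduler` source or the upstream `TAID.update_t`: the class name `SketchScheduler`, the sigmoid squashing of the momentum, and the `(1 - t)` scaling of the bump are illustrative choices.

```python
import math


class SketchScheduler:
    """Hypothetical stand-in that follows the documented TAIDScheduler contract."""

    def __init__(self, num_train_steps, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99):
        self.num_train_steps = num_train_steps
        self.t_start = t_start
        self.t_end = t_end
        self.alpha = alpha      # adaptive bump magnitude
        self.beta = beta        # EMA decay on relative loss change
        self.t = t_start
        self.prev_loss = None
        self.momentum = 0.0

    def update_t(self, loss, global_step):
        if not math.isfinite(loss):
            return None         # ignore non-finite losses entirely
        if self.prev_loss is None:
            self.prev_loss = loss
            return None         # first finite loss only seeds prev_loss
        # EMA of how much the loss improved relative to the previous step.
        rel = (self.prev_loss - loss) / (self.prev_loss + 1e-15)
        self.momentum = self.beta * self.momentum + (1.0 - self.beta) * rel
        # Assumed squashing: sigmoid keeps the bump strictly positive.
        bump = self.alpha * (1.0 / (1.0 + math.exp(-self.momentum))) * (1.0 - self.t)
        # Linear floor between t_start and t_end.
        progress = global_step / self.num_train_steps
        floor = self.t_start + (self.t_end - self.t_start) * progress
        # Never decrease, never exceed t_end.
        self.t = min(self.t_end, max(floor, self.t + bump))
        self.prev_loss = loss
        return bump
```

Because the new value is the larger of the linear floor and the bumped previous value, `t` cannot decrease, and the outer `min` enforces the `t_end` ceiling even if `global_step` runs past `num_train_steps`.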
|
| 395 |
|
| 396 |
### `entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor`
|
| 397 |
|
|
@@ -71,12 +71,9 @@ def compose_loss(
|
|
| 71 |
# ADR-007 extensions
|
| 72 |
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 73 |
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 74 |
-
|
| 75 |
-
taid_total_steps: int | None = None,
|
| 76 |
simpo_beta: float = 2.0,
|
| 77 |
simpo_gamma: float = 1.0,
|
| 78 |
-
taid_schedule: str = "linear",
|
| 79 |
-
taid_alpha_max: float = 1.0,
|
| 80 |
entropy_opd_h_max: float | None = None,
|
| 81 |
) -> torch.Tensor: ...
|
| 82 |
```
|
|
@@ -213,12 +210,11 @@ trainer = ComposerReplicationTrainer(
|
|
| 213 |
dpo_variant = "simpo",
|
| 214 |
simpo_beta = 2.0,
|
| 215 |
simpo_gamma = 1.0,
|
| 216 |
-
# TAID
|
| 217 |
sdpo_wrapper = "taid",
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
taid_alpha_max = 1.0,
|
| 222 |
)
|
| 223 |
```
|
| 224 |
|
|
@@ -936,7 +932,7 @@ In Wave 14: $0 (skeleton fails fast; no compute used). Projected for v0.2+:
|
|
| 936 |
## Cross-recipe checklist
|
| 937 |
|
| 938 |
Regardless of which recipe you pick, these invariants are tested across
|
| 939 |
-
the
|
| 940 |
|
| 941 |
- **`alpha_sdpo=0`** must reproduce the channel-1-only baseline
|
| 942 |
bit-exact (`test_compose_loss_integration.py`).
|
|
|
|
| 71 |
# ADR-007 extensions
|
| 72 |
dpo_variant: Literal["dpo", "simpo"] = "dpo",
|
| 73 |
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 74 |
+
taid_t: float | None = None,
|
|
|
|
| 75 |
simpo_beta: float = 2.0,
|
| 76 |
simpo_gamma: float = 1.0,
|
|
|
|
|
|
|
| 77 |
entropy_opd_h_max: float | None = None,
|
| 78 |
) -> torch.Tensor: ...
|
| 79 |
```
|
|
|
|
| 210 |
dpo_variant = "simpo",
|
| 211 |
simpo_beta = 2.0,
|
| 212 |
simpo_gamma = 1.0,
|
| 213 |
+
# TAID for channel 2 (SakanaAI port; logit-space mix + forward-KL):
|
| 214 |
sdpo_wrapper = "taid",
|
| 215 |
+
taid_t = 0.4, # current TAID coeff in [0, 1];
|
| 216 |
+
# drive from TAIDScheduler if you want
|
| 217 |
+
# the paper's adaptive scheme
|
|
|
|
| 218 |
)
|
| 219 |
```
|
| 220 |
|
|
|
|
| 932 |
## Cross-recipe checklist
|
| 933 |
|
| 934 |
Regardless of which recipe you pick, these invariants are tested across
|
| 935 |
+
the 115-test suite (post-Wave-15) and should be true of your wired-up system:
|
| 936 |
|
| 937 |
- **`alpha_sdpo=0`** must reproduce the channel-1-only baseline
|
| 938 |
bit-exact (`test_compose_loss_integration.py`).
|
|
@@ -39,7 +39,8 @@ broken" reports turn out to be one of these:
|
|
| 39 |
`pip show composer-replication | grep Location`.
|
| 40 |
|
| 41 |
4. **Optional extras.** Several modules are optional-dep gated:
|
| 42 |
-
- `[replay]`
|
|
|
|
| 43 |
- `[replaysim]` — adds `data-juicer` (and via it, Ray as a transitive).
|
| 44 |
- `[serverless]` — adds `fsspec`. For non-local rendezvous URIs you
|
| 45 |
also need a backend-specific fsspec adapter (see Failure Mode 5).
|
|
|
|
| 39 |
`pip show composer-replication | grep Location`.
|
| 40 |
|
| 41 |
4. **Optional extras.** Several modules are optional-dep gated:
|
| 42 |
+
- `[replay]` — adds `httpx` (used for OpenRouter teacher calls).
|
| 43 |
+
- `[train]` — adds TRL, peft, accelerate, datasets (production GRPO).
|
| 44 |
- `[replaysim]` — adds `data-juicer` (and via it, Ray as a transitive).
|
| 45 |
- `[serverless]` — adds `fsspec`. For non-local rendezvous URIs you
|
| 46 |
also need a backend-specific fsspec adapter (see Failure Mode 5).
|
|
@@ -364,51 +364,77 @@ the reference acts as a regularizer.
|
|
| 364 |
|
| 365 |
## 6. Adding TAID / Entropy-Aware OPD wrappers
|
| 366 |
|
| 367 |
-
Channel 2 (SDPO/OPSD) can be
|
| 368 |
-
arXiv:2501.16937) for capacity-gap distillation, or
|
| 369 |
**Entropy-Aware OPD** (ICLR 2026 Spotlight) for per-token forward/reverse-KL
|
| 370 |
-
gating. Both are
|
| 371 |
|
| 372 |
```python
|
| 373 |
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 374 |
-
|
| 375 |
-
taid_total_steps: int | None = None,
|
| 376 |
-
taid_schedule: str = "linear", # "linear" | "cosine" | "exp"
|
| 377 |
-
taid_alpha_min: float = 0.0,
|
| 378 |
-
taid_alpha_max: float = 1.0,
|
| 379 |
entropy_opd_h_max: float | None = None,
|
| 380 |
```
|
| 381 |
|
| 382 |
(verified at `composer_replication/loss.py:82–93`.)
|
| 383 |
|
| 384 |
-
### TAID
|
| 385 |
|
| 386 |
-
|
| 387 |
-
|
| 388 |
|
| 389 |
```
|
| 390 |
-
|
| 391 |
```
|
| 392 |
|
| 393 |
-
|
| 394 |
-
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
If neither is provided, `_resolve_student_init_logits` raises
|
| 410 |
-
`ValueError` with a clear message
|
| 411 |
-
(`composer_replication/loss.py:351–392`).
|
| 412 |
|
| 413 |
### Entropy-Aware OPD
|
| 414 |
|
|
@@ -425,41 +451,28 @@ maximum-entropy bound for a vocab-V softmax).
|
|
| 425 |
|
| 426 |
### Boundary-condition unit test (proof of correctness)
|
| 427 |
|
| 428 |
-
The test `
|
| 429 |
-
(`composer_replication/distillation/tests/test_distillation_losses.py
|
| 430 |
-
pins
|
| 431 |
-
|
| 432 |
|
| 433 |
```python
|
| 434 |
-
def
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
#
|
| 442 |
-
loss_a
|
| 443 |
-
total_steps=100, alpha_min=0.0, alpha_max=0.0)
|
| 444 |
-
loss_b = taid_loss(s1, teacher_b, student_init, schedule_step=0,
|
| 445 |
-
total_steps=100, alpha_min=0.0, alpha_max=0.0)
|
| 446 |
-
# Two completely different teachers must give the same loss.
|
| 447 |
-
assert abs(float(loss_a) - float(loss_b)) < 1e-4
|
| 448 |
```
|
| 449 |
|
| 450 |
-
This is the load-bearing test for TAID: if the
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
endpoints (α=0 → student_init, α=1 → teacher) and the half-way mixing
|
| 456 |
-
behavior.
|
| 457 |
-
|
| 458 |
-
For Entropy-OPD, the boundary test is
|
| 459 |
-
`test_entropy_aware_opd_zero_when_distributions_match` (line 217): when
|
| 460 |
-
student logits ≡ teacher logits, both KLs are 0 and the loss must be 0
|
| 461 |
-
to numerical precision.
|
| 462 |
-
|
| 463 |
---
|
| 464 |
|
| 465 |
## 7. Going multi-replica with serverless DiLoCo
|
|
@@ -655,7 +668,7 @@ and `docs/adrs/ADR-006-rl-frameworks.md`.
|
|
| 655 |
|
| 656 |
## Common pitfalls + what tests catch them
|
| 657 |
|
| 658 |
-
The framework's
|
| 659 |
specific test-file home. If you hit one of these in production, the
|
| 660 |
corresponding test is your fastest reproducer.
|
| 661 |
|
|
|
|
| 364 |
|
| 365 |
## 6. Adding TAID / Entropy-Aware OPD wrappers
|
| 366 |
|
| 367 |
+
Channel 2 (SDPO/OPSD) can be replaced by **TAID** (Sakana AI,
|
| 368 |
+
arXiv:2501.16937) for capacity-gap distillation, or by
|
| 369 |
**Entropy-Aware OPD** (ICLR 2026 Spotlight) for per-token forward/reverse-KL
|
| 370 |
+
gating. Both are wired through `compose_loss`:
|
| 371 |
|
| 372 |
```python
|
| 373 |
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
|
| 374 |
+
taid_t: float | None = None, # current TAID interpolation coeff
|
| 375 |
entropy_opd_h_max: float | None = None,
|
| 376 |
```
|
| 377 |
|
| 378 |
(verified at `composer_replication/loss.py:82–93`.)
|
| 379 |
|
| 380 |
+
### TAID (upstream-faithful port)
|
| 381 |
|
| 382 |
+
> **Wave 15 rewrite, breaking change.** The previous in-tree TAID was
|
| 383 |
+
> algorithmically different from the paper (it mixed in probability space
|
| 384 |
+
> against a frozen step-0 student snapshot and wrapped a symmetric JSD
|
| 385 |
+
> criterion). It has been replaced with an upstream-faithful port:
|
| 386 |
+
> logit-space mix, current-student-detached anchor, forward-KL criterion.
|
| 387 |
+
> Old kwargs `taid_schedule_step`, `taid_total_steps`, `taid_schedule`,
|
| 388 |
+
> `taid_alpha_min`, `taid_alpha_max`, plus `inputs["student_init_logits"]` /
|
| 389 |
+
> `inputs["student_init_input_ids"]` are **gone**. They have no upstream
|
| 390 |
+
> analogue. Use `taid_t` (and optionally `TAIDScheduler`) instead.
|
| 391 |
|
| 392 |
+
The TAID criterion is forward-KL against a logit-space-interpolated target:
|
| 393 |
+
|
| 394 |
+
```
|
| 395 |
+
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
|
| 396 |
+
L = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
|
| 397 |
```
|
| 398 |
+
|
| 399 |
+
where `t ∈ [0, 1]` is the interpolation coefficient. At `t=0` the target
|
| 400 |
+
is the (detached) student itself — the loss is the entropy of that
|
| 401 |
+
distribution and contributes no gradient to the student. At `t=1` it
|
| 402 |
+
reduces to standard forward-KL distillation against the teacher.
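The two formulas above can be sketched in plain Python, assuming lists of floats stand in for tensors. Floats carry no gradient, so the `stop_grad` on the student term is implicit. The names `taid_token_loss_sketch` and `taid_loss_sketch` are hypothetical and are not the package's `taid_loss`; the masked mean over tokens is an assumption about how `mean_token` treats padded positions.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def taid_token_loss_sketch(student_logits, teacher_logits, t):
    """-sum_v p_t(v) * log q(v) for a single token position."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    # Logit-space interpolation; the target is a constant w.r.t. the student.
    mixed = [(1.0 - t) * s + t * te
             for s, te in zip(student_logits, teacher_logits)]
    log_p_t = log_softmax(mixed)
    log_q = log_softmax(student_logits)
    return -sum(math.exp(lp) * lq for lp, lq in zip(log_p_t, log_q))


def taid_loss_sketch(student, teacher, mask, t):
    """Masked mean of the per-token loss (assumed reduction)."""
    total = sum(m * taid_token_loss_sketch(s, te, t)
                for s, te, m in zip(student, teacher, mask))
    return total / sum(mask)


student = [[0.2, -1.0, 0.5, 0.1], [1.5, 0.3, -0.7, 0.0]]
teacher_a = [[10.0, 0.0, 0.0, 0.0], [10.0, 0.0, 0.0, 0.0]]
teacher_b = [[0.0, 0.0, 0.0, 10.0], [0.0, 0.0, 0.0, 10.0]]
mask = [1.0, 1.0]

# At t=0 the teacher drops out of the target entirely.
loss_a = taid_loss_sketch(student, teacher_a, mask, t=0.0)
loss_b = taid_loss_sketch(student, teacher_b, mask, t=0.0)
```

In this sketch `loss_a` and `loss_b` coincide because the mixed logits at `t=0` are the student logits themselves, which is the endpoint behavior described above.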
|
| 403 |
+
|
| 404 |
+
The schedule that produces `t` is the **trainer's** responsibility. The
|
| 405 |
+
package ships an optional `TAIDScheduler` that mirrors the paper's
|
| 406 |
+
adaptive momentum scheme:
|
| 407 |
+
|
| 408 |
+
```python
|
| 409 |
+
from composer_replication.distillation import TAIDScheduler
|
| 410 |
+
|
| 411 |
+
sched = TAIDScheduler(num_train_steps=10_000) # paper defaults
|
| 412 |
+
for step in range(10_000):  # same step count passed to TAIDScheduler
|
| 413 |
+
components = compose_loss(
|
| 414 |
+
model, batch,
|
| 415 |
+
sdpo_wrapper="taid",
|
| 416 |
+
taid_t=sched.t,
|
| 417 |
+
)
|
| 418 |
+
components.total.backward(); optimizer.step()
|
| 419 |
+
sched.update_t(components.sdpo_jsd.detach(), global_step=step)
|
| 420 |
```
|
| 421 |
|
| 422 |
+
`TAIDScheduler` defaults match upstream: `t_start=0.4`, `t_end=1.0`,
|
| 423 |
+
`alpha=5e-4`, `beta=0.99`. Pass `disable_adaptive=True` to fall back to
|
| 424 |
+
the deterministic linear schedule
|
| 425 |
+
`t = t_start + progress · (t_end - t_start)`.
|
| 426 |
+
|
| 427 |
+
If you want a simple fixed schedule (no scheduler), just compute `t`
|
| 428 |
+
yourself and pass it in — `compose_loss` validates `taid_t ∈ [0, 1]`.
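A fixed schedule computed by hand might look like the sketch below. It implements the linear formula `t = t_start + progress · (t_end - t_start)` quoted above with the stated defaults. The helper name `linear_taid_t` is hypothetical, and clamping `progress` to `[0, 1]` is an assumption rather than documented behavior of `TAIDScheduler`.

```python
def linear_taid_t(step, num_train_steps, t_start=0.4, t_end=1.0):
    """Deterministic linear ramp for the TAID interpolation coefficient."""
    if num_train_steps <= 0:
        raise ValueError("num_train_steps must be positive")
    # Clamp so steps outside [0, num_train_steps] stay on the endpoints.
    progress = min(max(step / num_train_steps, 0.0), 1.0)
    t = t_start + progress * (t_end - t_start)
    # Mirror the range check that compose_loss is described as performing.
    if not 0.0 <= t <= 1.0:
        raise ValueError(f"taid_t must lie in [0, 1], got {t}")
    return t


schedule = [linear_taid_t(step, 100) for step in (0, 50, 100)]
```

The resulting value would then be passed as `taid_t=` on each step in place of `sched.t`.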
|
| 429 |
+
|
| 430 |
+
### Upstream-parity test
|
| 431 |
+
|
| 432 |
+
`composer_replication/distillation/tests/test_taid_parity.py` skip-imports
|
| 433 |
+
the upstream reference at `/tmp/taid-clone` (clone with
|
| 434 |
+
`git clone --depth 1 https://github.com/SakanaAI/TAID /tmp/taid-clone`)
|
| 435 |
+
and asserts our `taid_loss(student, teacher, mask, t)` matches upstream
|
| 436 |
+
`TAID.compute_loss(...)` within `atol=rtol=1e-5` across `t ∈ {0.0, 0.1, 0.4,
|
| 437 |
+
0.5, 0.9, 1.0}`. This is the load-bearing parity guarantee.
|
| 438 |
|
| 439 |
### Entropy-Aware OPD
|
| 440 |
|
|
|
|
| 451 |
|
| 452 |
### Boundary-condition unit test (proof of correctness)
|
| 453 |
|
| 454 |
+
The test `test_taid_loss_t_zero_target_matches_detached_student`
|
| 455 |
+
(`composer_replication/distillation/tests/test_distillation_losses.py`)
|
| 456 |
+
pins TAID's `t=0` invariant — the teacher is *completely* hidden from the
|
| 457 |
+
gradient because the target collapses to `softmax(student.detach())`:
|
| 458 |
|
| 459 |
```python
|
| 460 |
+
def test_taid_loss_t_zero_target_matches_detached_student():
|
| 461 |
+
s1 = torch.randn(1, 2, 4, requires_grad=True)
|
| 462 |
+
teacher_a = torch.zeros(1, 2, 4); teacher_a[..., 0] = 10.0
|
| 463 |
+
teacher_b = torch.zeros(1, 2, 4); teacher_b[..., 3] = 10.0
|
| 464 |
+
mask = torch.ones(1, 2)
|
| 465 |
+
loss_a = taid_loss(s1, teacher_a, mask, t=0.0)
|
| 466 |
+
loss_b = taid_loss(s1, teacher_b, mask, t=0.0)
|
| 467 |
+
# Two completely different teachers must give the same loss at t=0.
|
| 468 |
+
assert abs(float(loss_a) - float(loss_b)) < 1e-6
|
| 469 |
```
|
| 470 |
|
| 471 |
+
This is the load-bearing test for TAID: if the `t=0` endpoint ever leaks
|
| 472 |
+
teacher signal into the gradient, this test fires and the contract is
|
| 473 |
+
broken. The companion test `test_taid_loss_t_one_is_pure_forward_kl`
|
| 474 |
+
pins the `t=1` endpoint by hand-computing `-Σ p_teacher · log_q` and
|
| 475 |
+
asserting equality.
|
| 476 |
---
|
| 477 |
|
| 478 |
## 7. Going multi-replica with serverless DiLoCo
|
|
|
|
| 668 |
|
| 669 |
## Common pitfalls + what tests catch them
|
| 670 |
|
| 671 |
+
The framework's 115-test suite (post-Wave-15) is structured so each pitfall has a
|
| 672 |
specific test-file home. If you hit one of these in production, the
|
| 673 |
corresponding test is your fastest reproducer.
|
| 674 |
|
|
@@ -107,10 +107,12 @@ The user expanded the brief mid-loop:
|
|
| 107 |
| Replaysim normalization | ADR-004 + `composer_replication.replaysim` package + `data-juicer` adapter + default YAML recipe + 9 unit tests | ✅ Closed (passthrough) / 🟡 Pending data-juicer install for full path |
|
| 108 |
| Other RL frameworks (V3 expansion) | ADR-006 + `composer_replication.recipes.prime_rl` (recipe + composer_loss adapter + config.yaml) | ✅ Closed (recipe) / 🟡 Skeleton (runtime) |
|
| 109 |
| Meta's PyTorch agentic stack | ADR-006 + `composer_replication.recipes.monarch` (actor layout doc + skeleton actors) | ✅ Closed (design) / 🟡 Skeleton (impl) |
|
| 110 |
-
| Deeper self-distillation research | ADR-007 + `docs/research/SELF_DISTILLATION_LANDSCAPE.md` + `composer_replication.distillation` module (SimPO + TAID + Entropy-Aware OPD) +
|
| 111 |
| altered-minds tie-in | `docs/ALTERED_MINDS_TIE_IN.md` (5-phase plan, $300 estimate, open questions) | ✅ Closed (design) |
|
| 112 |
|
| 113 |
**Wave 13 test addition**: 35 new tests passing (17 distillation + 9 serverless multi-process + 9 replaysim).
|
| 114 |
|
| 115 |
-
The framework now covers the full expanded brief. Total tests passing
|
| 116 |
-
|
|
|
|
|
|
|
|
|
| 107 |
| Replaysim normalization | ADR-004 + `composer_replication.replaysim` package + `data-juicer` adapter + default YAML recipe + 9 unit tests | ✅ Closed (passthrough) / 🟡 Pending data-juicer install for full path |
|
| 108 |
| Other RL frameworks (V3 expansion) | ADR-006 + `composer_replication.recipes.prime_rl` (recipe + composer_loss adapter + config.yaml) | ✅ Closed (recipe) / 🟡 Skeleton (runtime) |
|
| 109 |
| Meta's PyTorch agentic stack | ADR-006 + `composer_replication.recipes.monarch` (actor layout doc + skeleton actors) | ✅ Closed (design) / 🟡 Skeleton (impl) |
|
| 110 |
+
| Deeper self-distillation research | ADR-007 + `docs/research/SELF_DISTILLATION_LANDSCAPE.md` + `composer_replication.distillation` module (SimPO + TAID-rewritten + Entropy-Aware OPD) + tests | ✅ Closed end-to-end — `compose_loss` kwargs wired in Wave 14; TAID rewritten in Wave 15 to match SakanaAI/TAID upstream (logit-space mix, current-student-detached anchor, forward-KL criterion, optional `TAIDScheduler`); OPSD parity test added against `siyan-zhao/OPSD` upstream. |
|
| 111 |
| altered-minds tie-in | `docs/ALTERED_MINDS_TIE_IN.md` (5-phase plan, $300 estimate, open questions) | ✅ Closed (design) |
|
| 112 |
|
| 113 |
**Wave 13 test addition**: 35 new tests passing (17 distillation + 9 serverless multi-process + 9 replaysim).
|
| 114 |
|
| 115 |
+
The framework now covers the full expanded brief. **Total tests passing
|
| 116 |
+
post-Wave-15: 115 + 1 skip-marked.** Wave-by-wave evolution: 72 (W12) → 93 (W13) → 124 (W14) → 130 (W14b) → 115 (W15: TAID rewrite consolidated 16 schedule-tests into 7 t-parameterized tests; OPSD upstream-parity test added skip-marked).
|
| 117 |
+
|
| 118 |
+
This is the canonical running test count; other docs reference V1_V8_COVERAGE rather than restating.
|
|
@@ -191,6 +191,96 @@ No new deps — these are pure PyTorch losses on top of existing tensors.
|
|
| 191 |
- v0.3: integrate the three new losses with PRIME-RL's `CustomLossConfig`
|
| 192 |
(per ADR-006) so users can mix-and-match across frameworks.
|
| 193 |
|
| 194 |
## Source
|
| 195 |
|
| 196 |
`docs/research/SELF_DISTILLATION_LANDSCAPE.md` (2026-05-26 subagent recon,
|
|
|
|
| 191 |
- v0.3: integrate the three new losses with PRIME-RL's `CustomLossConfig`
|
| 192 |
(per ADR-006) so users can mix-and-match across frameworks.
|
| 193 |
|
| 194 |
+
## Wave 15 update — TAID rewritten to match the upstream paper (BREAKING)
|
| 195 |
+
|
| 196 |
+
The TAID implementation that landed in Waves 13/14 was algorithmically
|
| 197 |
+
different from the SakanaAI/TAID reference. A Wave 15 math review (against
|
| 198 |
+
the upstream `src/distil_losses/taid.py`) found four divergences:
|
| 199 |
+
|
| 200 |
+
1. **Interpolation space**: upstream mixes in **logit space**, ours mixed
|
| 201 |
+
in probability space.
|
| 202 |
+
2. **Anchor distribution**: upstream uses the **current student detached**,
|
| 203 |
+
re-evaluated each step; ours used a frozen step-0 snapshot.
|
| 204 |
+
3. **Schedule**: upstream uses an **adaptive momentum-based** scheme on the
|
| 205 |
+
relative loss change; ours used a deterministic linear/cosine/exp ramp.
|
| 206 |
+
4. **Distillation criterion**: upstream uses **forward KL** with the
|
| 207 |
+
interpolated target as the soft target (Hinton-style); ours wrapped a
|
| 208 |
+
symmetric JSD.
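Divergence 1 is easy to underestimate, so here is a small sketch of why the interpolation space matters. It assumes plain Python lists in place of tensors and uses the same student for both variants to isolate the effect of the mixing space; the function names are hypothetical and neither is the package's implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def logit_space_target(student_logits, teacher_logits, t):
    """Upstream-style target: interpolate logits, then normalize."""
    return softmax([(1.0 - t) * s + t * te
                    for s, te in zip(student_logits, teacher_logits)])


def prob_space_target(student_logits, teacher_logits, t):
    """Pre-Wave-15-style target: normalize each side, then interpolate."""
    p_s, p_t = softmax(student_logits), softmax(teacher_logits)
    return [(1.0 - t) * a + t * b for a, b in zip(p_s, p_t)]


student_logits = [2.0, 0.0, 0.0]
teacher_logits = [0.0, 0.0, 2.0]

logit_mix = logit_space_target(student_logits, teacher_logits, 0.5)
prob_mix = prob_space_target(student_logits, teacher_logits, 0.5)
```

Both targets are valid distributions and agree at `t=0` and `t=1`, but for intermediate `t` they assign different mass to the token neither model favors, so the two losses optimize toward different targets.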
|
| 209 |
+
|
| 210 |
+
### Decision: replace `taid_loss` in place to match upstream
|
| 211 |
+
|
| 212 |
+
The function name `taid_loss` is reserved for the algorithm in the paper.
|
| 213 |
+
Renaming was rejected because the misnamed function had only been
|
| 214 |
+
shipping for two waves and is small in surface area. The breaking-change
|
| 215 |
+
cost is acceptable; the cost of leaving an algorithmically-incorrect
|
| 216 |
+
function under that name forever is not.
|
| 217 |
+
|
| 218 |
+
### New API
|
| 219 |
+
|
| 220 |
+
```python
|
| 221 |
+
def taid_loss(
|
| 222 |
+
student_logits: torch.Tensor,
|
| 223 |
+
teacher_logits: torch.Tensor,
|
| 224 |
+
mask: torch.Tensor | None = None,
|
| 225 |
+
*,
|
| 226 |
+
t: float | torch.Tensor,
|
| 227 |
+
) -> torch.Tensor: ...
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
`t ∈ [0, 1]` is passed in directly. The schedule is the **caller's**
|
| 231 |
+
responsibility; the package ships an optional
|
| 232 |
+
`composer_replication.distillation.TAIDScheduler` that mirrors the
|
| 233 |
+
upstream adaptive-momentum scheme (`t_start=0.4, t_end=1.0, alpha=5e-4,
|
| 234 |
+
beta=0.99`, monotone non-decreasing, clamped at `t_end`). Pass
|
| 235 |
+
`disable_adaptive=True` to fall back to a deterministic linear floor.
|
| 236 |
+
|
| 237 |
+
### Removed
|
| 238 |
+
|
| 239 |
+
- Function args: `student_init_logits`, `schedule_step`, `total_steps`,
|
| 240 |
+
`schedule`, `alpha_min`, `alpha_max`, `jsd_beta`, `temperature`,
|
| 241 |
+
`reduction`. None has an upstream analogue.
|
| 242 |
+
- Helpers: `taid_alpha_schedule`, `taid_blended_logits`. Not exported any
|
| 243 |
+
more.
|
| 244 |
+
- `compose_loss` kwargs: `taid_schedule_step`, `taid_total_steps`,
|
| 245 |
+
`taid_schedule`, `taid_alpha_min`, `taid_alpha_max`. Replaced by
|
| 246 |
+
`taid_t: float | None`.
|
| 247 |
+
- `inputs["student_init_logits"]` / `inputs["student_init_input_ids"]`
|
| 248 |
+
are no longer consumed by the TAID path. The `_resolve_student_init_logits`
|
| 249 |
+
helper has been deleted.
|
| 250 |
+
|
| 251 |
+
### Parity
|
| 252 |
+
|
| 253 |
+
`composer_replication/distillation/tests/test_taid_parity.py` runs our
|
| 254 |
+
`taid_loss` head-to-head against the upstream
|
| 255 |
+
`TAID.compute_loss(...)` / `forward_kl(...)` (loaded via inline-exec from
|
| 256 |
+
`/tmp/taid-clone/src/distil_losses/{taid,fkl}.py`) across
|
| 257 |
+
`t ∈ {0.0, 0.1, 0.4, 0.5, 0.9, 1.0}`. All listed parametrizations match at
|
| 258 |
+
`atol=rtol=1e-5`. The test is `pytest.mark.skipif`-guarded on the clone's
|
| 259 |
+
presence so CI without the clone still passes.
|
| 260 |
+
|
| 261 |
+
### Migration
|
| 262 |
+
|
| 263 |
+
Old:
|
| 264 |
+
```python
|
| 265 |
+
loss = taid_loss(student_logits, teacher_logits, student_init_logits,
|
| 266 |
+
schedule_step=step, total_steps=max_steps,
|
| 267 |
+
schedule="linear", alpha_min=0.0, alpha_max=1.0)
|
| 268 |
+
```
|
| 269 |
+
New:
|
| 270 |
+
```python
|
| 271 |
+
from composer_replication.distillation import TAIDScheduler
|
| 272 |
+
sched = TAIDScheduler(num_train_steps=max_steps)
|
| 273 |
+
# … each step:
|
| 274 |
+
loss = taid_loss(student_logits, teacher_logits, mask, t=sched.t)
|
| 275 |
+
sched.update_t(loss.detach(), global_step=step)
|
| 276 |
+
```
|
| 277 |
+
|
| 278 |
+
The previous wording (Wave 14 "Closed" section, immediately above) is
|
| 279 |
+
**partially superseded**: SimPO and Entropy-Aware OPD still match what
|
| 280 |
+
shipped; only the TAID path is rewritten.
|
| 281 |
+
|
| 282 |
+
---
|
| 283 |
+
|
| 284 |
## Source
|
| 285 |
|
| 286 |
`docs/research/SELF_DISTILLATION_LANDSCAPE.md` (2026-05-26 subagent recon,
|
|
@@ -260,5 +260,4 @@ to catch this and similar adapter-shape regressions.
|
|
| 260 |
| W14 NIT 7: docstring claims ISR clipping | ✅ closed in Wave 14b (real ISR now implemented) |
|
| 261 |
| **NEW (Wave 14b)**: PRIME-RL `LossOutputs` return shape | 🟡 deferred to Wave 15 |
|
| 262 |
|
| 263 |
-
**
|
| 264 |
-
parity test, runs when prime-rl is installed).**
|
|
|
|
| 260 |
| W14 NIT 7: docstring claims ISR clipping | ✅ closed in Wave 14b (real ISR now implemented) |
|
| 261 |
| **NEW (Wave 14b)**: PRIME-RL `LossOutputs` return shape | 🟡 deferred to Wave 15 |
|
| 262 |
|
| 263 |
+
**Tests as of Wave 15: 115 passing + 1 skip-marked (OPSD parity test, runs when the upstream repo is cloned).** (Wave 12: 72; Wave 13: 93; Wave 14: 124; Wave 14b: 130; Wave 15: 115 after TAID rewrite consolidation + OPSD parity.)
|
|
|
|
@@ -0,0 +1,76 @@
|
|
| 1 |
+
# Wave 15 Final Review — Multi-Angle Self-Critique + Fix Wave
|
| 2 |
+
|
| 3 |
+
**Date:** 2026-05-26
|
| 4 |
+
**Method:** 4 parallel adversarial reviewers (math / tests / docs / user-journey), each given a different framing to maximize independent-angle coverage. Then targeted fix scatter on findings.
|
| 5 |
+
|
| 6 |
+
## Headline finding
|
| 7 |
+
|
| 8 |
+
**The math reviewer found 2 BLOCKERs that all 8+ prior subagents missed.** Both came from `git clone`-ing upstream and doing line-by-line diffs against the framework's `composer_replication/opsd.py` and `composer_replication/distillation/taid.py` — something no prior reviewer had done for those files (Wave 14b reviewer did it for PRIME-RL only).
|
| 9 |
+
|
| 10 |
+
This validates the user's instinct that "every angle" multi-model orchestration is worth doing — the math angle, given a sharp prompt that mandated upstream verification, found genuine bugs in the framework's primary loss kernel.
|
| 11 |
+
|
| 12 |
+
## Wave 15a reviews (all 4 deliverables)
|
| 13 |
+
|
| 14 |
+
| Reviewer | Focus | BLOCKERs | Severity-weighted findings |
|
| 15 |
+
|---|---|---|---|
|
| 16 |
+
| Math correctness (Opus 4.7) | 7 claimed implementations vs primary sources | **2 BLOCKER + 3 minor** | `generalized_jsd_loss` math wrong; `taid_loss` algorithm wrong |
|
| 17 |
+
| Test honesty (Opus 4.7) | 3 specific test files | 0 BLOCKER + 3 weak-assertions | PRIME-RL parity skip silently never runs; bit-exact uses `allclose` not `equal`; entropy-OPD test is pure smoke |
|
| 18 |
+
| Documentation drift (Opus 4.7) | 6 major docs + ADRs | 0 BLOCKER + 7 drifts | test count drift (77/107/124 vs actual 145); `compose_loss` kwarg drift; PRIME-RL test count 10 vs 16; stale "Deferred to Wave 14" claim |
|
| 19 |
+
| User journey (Opus 4.7) | RL-finetune Qwen-7B on GSM8K | 0 BLOCKER + 10 friction items | **No GSM8K example** (#1 ask); no runnable `ComposerReplicationTrainer` recipe; data-collator gap undocumented; defaults activate channels users haven't configured |
|
| 20 |
+
|
| 21 |
+
Reports saved at `/tmp/wave15_{math,test,doc,user}_review.md`.
|
| 22 |
+
|
| 23 |
+
## Wave 15b — fix scatter outcomes
|
| 24 |
+
|
| 25 |
+
5 parallel fix subagents dispatched. Outcomes:
|
| 26 |
+
|
| 27 |
+
| Task | Subagent outcome |
|
| 28 |
+
|---|---|
|
| 29 |
+
| (1) OPSD math rewrite vs upstream | ✅ Completed. New parity test (skip-marked) verifies 31 cases against upstream `siyan-zhao/OPSD`. Mixture distribution now β-weighted (was hardcoded 0.5); β coefficient on correct terms (was swapped); reduction matches upstream (was off by 100-2000× factor). Docstring labels fixed (β=0 = reverse KL, β=1 = forward KL). |
|
| 30 |
+
| (2) TAID rewrite vs upstream | ⚠️ Subagent timed out at 600s but **work landed**: logit-space mix (was prob-space), current-student-detached anchor (was frozen step-0), forward-KL criterion (was JSD), optional `TAIDScheduler` for adaptive scheme. Docstring rewritten to acknowledge the breaking change. Tests updated. Parity test added. |
|
| 31 |
+
| (3) GSM8K example | ⚠️ Subagent timed out but **work landed**: `examples/gsm8k_grpo/run.py` runs end-to-end on CPU with Qwen2.5-0.5B-Instruct, 100 GSM8K rows, regex-based verifiable reward, 2 outer steps in 58s. README written by parent agent. The `run_with_sdpo.py` variant deferred to Wave 16. |
|
| 32 |
+
| (4) Doc drift + install ergonomics | ⚠️ Subagent timed out. **Parent completed:** flipped `alpha_sdpo` and `beta_replay` defaults to 0.0; added clear ImportError if TRL missing; fixed TROUBLESHOOTING `[replay]` extras claim; updated README + USER_GUIDE + INTEGRATION_RECIPES test counts to reference V1_V8_COVERAGE; closed stale "Deferred to Wave 14" claim. |
|
| 33 |
+
| (5) Test hardening + LossOutputs wrap | ✅ Completed (3 of 4 sub-tasks). PRIME-RL `loss_fn` now returns `LossOutputs(loss, metrics)`. Bit-exact test tightened to `torch.equal`. PRIME-RL parity test now emits visible warning when prime-rl unavailable. Gradient-flow tests deferred to Wave 16. |
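The β-weighted mixture and the mask-normalized reduction from fix (1) can be sketched as follows. This is a stdlib-only illustration of the log-space form `logsumexp([s + log1p(-beta), t + log(beta)])` quoted in the commit message, assuming `0 < beta < 1` and per-token log-probabilities held in plain lists. The function names are hypothetical and this is not the code in `composer_replication/opsd.py`.

```python
import math


def log_mixture(student_logprobs, teacher_logprobs, beta):
    """Elementwise log((1 - beta) * p_student + beta * p_teacher)."""
    if not 0.0 < beta < 1.0:
        raise ValueError("this sketch assumes 0 < beta < 1")
    out = []
    for s, t in zip(student_logprobs, teacher_logprobs):
        a = s + math.log1p(-beta)   # log((1 - beta) * p_student)
        b = t + math.log(beta)      # log(beta * p_teacher)
        m = max(a, b)
        out.append(m + math.log(math.exp(a - m) + math.exp(b - m)))
    return out


def masked_mean(per_token_loss, mask):
    """Reduce by the number of unmasked tokens rather than the batch size."""
    return sum(l * m for l, m in zip(per_token_loss, mask)) / sum(mask)


p_student = [0.7, 0.2, 0.1]
p_teacher = [0.1, 0.3, 0.6]
log_m = log_mixture([math.log(p) for p in p_student],
                    [math.log(p) for p in p_teacher],
                    beta=0.25)
mixture = [math.exp(x) for x in log_m]
```

With `beta=0.5` this reduces to the previously hardcoded equal-weight mixture; any other `beta` shifts the mixture toward one side, which the old code could not express.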
|
| 34 |
+
|
| 35 |
+
## Final test count post-Wave-15: 115 passing + 1 skip-marked
|
| 36 |
+
|
| 37 |
+
- Wave-by-wave: 72 (W12) → 93 (W13) → 124 (W14) → 130 (W14b) → **115** (W15)
|
| 38 |
+
- Net decrease from 130: TAID rewrite consolidated 16 schedule-specific tests into 7 `t`-parameterized tests (smaller surface but stronger contracts: each test now exercises the actual paper algorithm). Plus 1 skip-marked OPSD parity test.
|
| 39 |
+
- Trade-off: fewer tests, but 2 BLOCKER-class math bugs eliminated. Net correctness improvement is large.
|
| 40 |
+
|
| 41 |
+
## What this round caught vs missed
|
| 42 |
+
|
| 43 |
+
### Caught (improvements over Wave 14b state)
|
| 44 |
+
- 2 math BLOCKERs in primary loss kernels, fixed against upstream byte-for-byte
|
| 45 |
+
- TAID rewrite from misnamed prob-space-JSD-with-frozen-anchor to actual SakanaAI/TAID
|
| 46 |
+
- PRIME-RL `LossOutputs` adapter wrap — recipe is now actually invokable from PRIME-RL
|
| 47 |
+
- GSM8K real-task example — closes the user-reviewer's #1 friction
|
| 48 |
+
- Default kwargs (`alpha_sdpo=0.1` → `0.0`) — no more silent activation of unconfigured channels
|
| 49 |
+
- TRL ImportError clarity — no more cryptic `object.__init__()` errors
|
| 50 |
+
- Test count drift — single canonical doc (V1_V8_COVERAGE)
|
| 51 |
+
- TROUBLESHOOTING `[replay]` extras correctly described
|
| 52 |
+
|
| 53 |
+
### Missed (Wave 16 candidates)
|
| 54 |
+
- `run_with_sdpo.py` — promised but not shipped this wave
|
| 55 |
+
- 3 gradient-flow tests for compose_loss channels (test reviewer's #4)
|
| 56 |
+
- Multi-process MockManager + DiLoCo convergence test was added in Wave 14b but only at world_size=2; user reviewer didn't probe larger
|
| 57 |
+
- Recon docs (`docs/research/*RECONNAISSANCE.md`) not cross-checked against current code state — likely some staleness
|
| 58 |
+
- PRIME-RL recipe still hasn't been run end-to-end against actual prime-rl (parity test skip-marks; LossOutputs wrap added but not exercised)
|
| 59 |
+
|
| 60 |
+
## Methodological lessons for future waves
|
| 61 |
+
|
| 62 |
+
1. **Prompt subagents to clone upstream and diff** when the task is "verify against external truth." 8+ prior reviewers checked papers but did not `git clone`. The instruction "read /tmp/X-clone/file.py and find every divergence" produced the BLOCKER-class findings.
|
| 63 |
+
|
| 64 |
+
2. **600s subagent timeout is the dominant constraint at this scope.** 3 of 5 fix subagents timed out despite making real progress. Workaround: write the report file FIRST as a skeleton, iterate in place. (Subagents that did this completed; subagents that read everything then tried to write at the end timed out.)
|
| 65 |
+
|
| 66 |
+
3. **Cross-cutting parallel-subagent failure mode**: subagents cite each other instead of upstream. Wave 14 caught this for PRIME-RL math. Wave 15 caught it for OPSD + TAID math. The mitigation is to mandate upstream verification in the prompt.
|
| 67 |
+
|
| 68 |
+
4. **Prompt injection in tool outputs**: one subagent flagged that fake "don't reproduce copyrighted material" instructions appeared in its tool outputs throughout, designed to make it abandon the OPSD math fix. The subagent correctly ignored the injection and completed the task. The framework's MIT-licensed work with attribution is fully authorized; no copyright concern.
|
| 69 |
+
|
| 70 |
+
## Open items for Wave 16
|
| 71 |
+
|
| 72 |
+
1. `examples/gsm8k_grpo_with_sdpo/` — demonstrate SDPO column wiring end-to-end
|
| 73 |
+
2. Gradient-flow tests for compose_loss channels (pre-staged in test reviewer's report)
|
| 74 |
+
3. Recon-doc currency sweep: cross-check `docs/research/*RECONNAISSANCE.md` against current code state
|
| 75 |
+
4. Real PRIME-RL end-to-end run with the new `LossOutputs` wrap (verify the wrap shape works in the real `setup_loss_fns` pipeline)
|
| 76 |
+
5. `INTEGRATION_RECIPES.md` `compose_loss` signature display — collapse to `...` and link to `API_REFERENCE.md`, OR sync to all 17 kwargs
|
|
@@ -0,0 +1,81 @@
|
|
| 1 |
+
# GSM8K + Plain GRPO Example
|
| 2 |
+
|
| 3 |
+
The minimum-viable end-to-end recipe a new user is most likely to want
|
| 4 |
+
from a GRPO framework: wire `ComposerReplicationTrainer` into a real
|
| 5 |
+
dataset (GSM8K) with a real verifiable reward (regex-extract `#### NUMBER`
|
| 6 |
+
and string-compare against gold) and run a couple of outer steps to
|
| 7 |
+
verify the training loop works.
|
| 8 |
+
|
| 9 |
+
## What this demonstrates
|
| 10 |
+
|
| 11 |
+
- `ComposerReplicationTrainer` with `alpha_sdpo=0` and `beta_replay=0`
|
| 12 |
+
(plain GRPO — channels 2 and 3 disabled). This is the v0.1 recommended
|
| 13 |
+
ablation baseline per `docs/USER_GUIDE.md` §8 Recipe A.
|
| 14 |
+
- A regex-based reward that returns `1.0` when the model's `#### NUMBER`
|
| 15 |
+
line matches the gold answer, `0.0` otherwise. RLVR-style. No reward
|
| 16 |
+
model.
|
| 17 |
+
- CPU-only execution. Slow but works without a GPU.
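A reward of the kind described in the second bullet could look like the sketch below. It is an illustration only and not the code in `examples/gsm8k_grpo/run.py`: the function names, the choice to take the last `#### NUMBER` occurrence, and the comma stripping before comparison are all assumptions.

```python
import re

# Matches "#### 72", "#### -3.5", "#### 1,234" and captures the number.
_ANSWER_RE = re.compile(r"####\s*(-?[\d,]*\.?\d+)")


def extract_answer(text):
    """Return the last '#### NUMBER' value as a string, or None if absent."""
    matches = _ANSWER_RE.findall(text)
    if not matches:
        return None
    return matches[-1].replace(",", "")


def gsm8k_reward(completion, gold_solution):
    """1.0 if the completion's final answer string equals the gold one."""
    pred = extract_answer(completion)
    gold = extract_answer(gold_solution)
    return 1.0 if pred is not None and pred == gold else 0.0


rewards = [
    gsm8k_reward("48 + 24 = 72\n#### 72", "She sold 72 clips.\n#### 72"),
    gsm8k_reward("I think it is 70.\n#### 70", "#### 72"),
    gsm8k_reward("The answer is 72.", "#### 72"),
]
```

Because the comparison is on strings, `72` and `72.0` would not match in this sketch; a numeric comparison is a reasonable alternative if the model tends to emit decimals.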
|
| 18 |
+
|
| 19 |
+
## Install
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
pip install -e ".[train]"
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
(Just `[train]` — no need for `[replay]`, `[replaysim]`, `[diloco]`,
|
| 26 |
+
`[serverless]`, `[prime-rl]`, or `[monarch]` for plain GRPO.)
|
| 27 |
+
|
| 28 |
+
## Run
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
python examples/gsm8k_grpo/run.py
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
Expected output: see `run.log`. ~60 seconds wall-clock on a modern CPU
|
| 35 |
+
for 2 outer steps with Qwen2.5-0.5B-Instruct + 100 GSM8K rows + 4
|
| 36 |
+
generations per prompt.
|
| 37 |
+
|
| 38 |
+
## What's missing (and why that's OK)
|
| 39 |
+
|
| 40 |
+
This example **does not** use the framework's novel channels (SDPO +
|
| 41 |
+
trace-replay DPO). For a 0.5B model on 100 prompts in 2 steps, plain
|
| 42 |
+
GRPO with a verifiable reward is the right baseline: simple, fast, and
|
| 43 |
+
the ablation point against which SDPO/replay-DPO improvements are
|
| 44 |
+
measured.
|
| 45 |
+
|
| 46 |
+
To extend this with SDPO, you'd need to:
|
| 47 |
+
1. Build a `data_collator` that produces `sdpo_loss_mask` +
|
| 48 |
+
`ctx_teacher_input_ids` columns (the SDPO hint-conditioned context).
|
| 49 |
+
2. Set `alpha_sdpo > 0` in `ComposerReplicationTrainer.__init__`.
|
| 50 |
+
|
| 51 |
+
To extend with trace-replay DPO, you'd:
|
| 52 |
+
1. Run `composer_replication.teacher_replay.replay_trace` against your
|
| 53 |
+
trace data + N teachers.
|
| 54 |
+
2. Convert teacher disagreement to DPO pairs via `extract_dpo_pairs`.
|
| 55 |
+
3. Optionally normalize via `composer_replication.replaysim.DJNormalizer`.
|
| 56 |
+
4. Build a `data_collator` that loads the DPO pairs into the batch.
|
| 57 |
+
5. Set `beta_replay > 0`.
|
| 58 |
+
|
| 59 |
+
A future `examples/gsm8k_grpo_with_sdpo/` will demonstrate (1) and (2)
|
| 60 |
+
end-to-end. As of Wave 15, the data-collator wiring for SDPO is documented
|
| 61 |
+
in `docs/USER_GUIDE.md` §6 but not yet shipped as a runnable example.

## Production scaling

For real runs:
- Replace `Qwen/Qwen2.5-0.5B-Instruct` with `Qwen/Qwen2.5-7B-Instruct`
  (or larger). Use `device_map="cuda"` and bf16.
- Increase `num_generations` to 8 or 16.
- Increase `max_completion_length` to 256-512.
- Train for 100+ steps (each step takes ~1 min on a single A100 for 7B).
- Use `vllm` or `sglang` as a fast generation backend.
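
When raising `num_generations`, the batch size usually has to move with it: `run.py` sets `per_device_train_batch_size=NUM_GENERATIONS` so that one step holds exactly one prompt's rollouts. The sketch below checks that relationship and estimates a wall-clock budget from the ~1 min/step figure above. The helper `prompts_per_step` is hypothetical, and the divisibility rule it enforces is an assumption about how TRL groups rollouts (gradient accumulation is taken to be 1); check the TRL version you install for the exact constraint.

```python
def prompts_per_step(per_device_batch, num_devices, num_generations):
    """Distinct prompts per optimizer step when each prompt gets
    `num_generations` rollouts (gradient accumulation assumed to be 1)."""
    global_batch = per_device_batch * num_devices
    if global_batch % num_generations != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by "
            f"num_generations={num_generations}"
        )
    return global_batch // num_generations


# Toy run.py setting: batch 4 on one device with 4 generations -> 1 prompt/step.
print(prompts_per_step(4, 1, 4))   # 1
# A scaled-up setting: batch 16 on one device with 8 generations -> 2 prompts/step.
print(prompts_per_step(16, 1, 8))  # 2

# Rough budget from the README's ~1 min/step figure for 7B on one A100.
steps, minutes_per_step = 100, 1.0
print(f"{steps * minutes_per_step / 60:.1f} hours")  # 1.7 hours
```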

See `docs/INTEGRATION_RECIPES.md` Recipe A for the full TRL recipe.

## Cross-references

- `docs/USER_GUIDE.md` §8: picking an RL backend
- `docs/INTEGRATION_RECIPES.md` Recipe A: TRL `GRPOTrainer` subclass
- `composer_replication/trainer/composer_trainer.py`: the
  `ComposerReplicationTrainer` source (read the `__init__` docstring for
  all channel-weight kwargs)
@@ -0,0 +1,246 @@
"""Plain GRPO + verifiable reward on 100 GSM8K rows (Qwen2.5-0.5B-Instruct, CPU).

This is the minimum-viable end-to-end recipe a new user is most likely to want
from a GRPO framework: wire the framework's `ComposerReplicationTrainer` into a
real dataset (GSM8K) with a real verifiable reward (regex-extract `#### NUMBER`
and string-compare against gold) and run a couple of outer steps to verify the
training loop works.

What this script demonstrates:
  - `ComposerReplicationTrainer` with `alpha_sdpo=0` and `beta_replay=0` (plain
    GRPO, channels 2 and 3 disabled). This is the v0.1 recommended ablation
    baseline per `docs/USER_GUIDE.md` §8 Recipe A.
  - A regex-based reward that returns 1.0 when the model's `#### NUMBER` line
    matches the gold answer, 0.0 otherwise. RLVR-style. No reward model.
  - CPU-only execution. Slow but works without a GPU; rollout generation
    dominates each outer step because TRL generates `num_generations` rollouts
    per prompt, so we keep them small (4 generations, 64 max completion tokens).

Usage:
    pip install -e ".[train]"
    python examples/gsm8k_grpo/run.py

Cross-references:
  - `docs/USER_GUIDE.md` §8: Recipe A, TRL `GRPOTrainer` subclass
  - `docs/INTEGRATION_RECIPES.md` Recipe 1: minimum-viable Python script
  - `docs/adrs/ADR-002-channel2-sdpo.md`: SDPO design (not used here; see
    `run_with_sdpo.py` for the SDPO variant)
"""
from __future__ import annotations

import logging
import random
import re
import sys
import time
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from composer_replication import ComposerReplicationTrainer

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"
N_TRAIN_ROWS = 100  # toy size; see README "Production scaling" notes
N_OUTER_STEPS = 2  # just enough to verify the loop runs
NUM_GENERATIONS = 4  # rollouts per prompt; keep small on CPU
MAX_PROMPT_LEN = 256  # not passed to GRPOConfig; see the NOTE in main()
MAX_COMPLETION_LEN = 64

OUTPUT_DIR = Path(__file__).resolve().parent / "output"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------------------------
# Reward function: verifiable (regex extract + match)
# ---------------------------------------------------------------------------

# GSM8K answer format: the gold answer ends with `#### NUMBER`. We require the
# model to emit the same `#### NUMBER` marker. This is the canonical RLVR
# reward used in the GRPO/DeepSeek-R1 literature on math word problems.
# The pattern tolerates thousands separators (`1,000`), which are stripped
# after matching so that `1,000` and `1000` compare equal.
_ANSWER_RE = re.compile(r"####\s*(-?\d[\d,]*(?:\.\d+)?)")


def _extract_answer(text: str) -> str | None:
    """Pull the last `#### NUMBER` group out of `text`. Returns the numeric
    string with thousands separators removed (so `'#### 72'` → `'72'` and
    `'#### 1,000'` → `'1000'`), or None if no marker is found."""
    matches = _ANSWER_RE.findall(text or "")
    return matches[-1].replace(",", "").strip() if matches else None


def gsm8k_reward(completions, **kwargs):
    """TRL-format reward callable.

    Args:
        completions: list of generated completions for one batch.
            Either list[str] (text) or list[list[dict]] (conversational); we
            normalize both. TRL passes the rollout completions here.
        kwargs: arbitrary dataset columns. We expect 'gold_answer' (str) and
            optionally 'prompts' (TRL passes the input prompts as kwargs).

    Returns:
        list[float] with len == len(completions). 1.0 if the regex-extracted
        answer matches the gold, else 0.0.
    """
    gold = kwargs.get("gold_answer")
    if gold is None:
        return [0.0] * len(completions)

    rewards: list[float] = []
    for completion, gold_ans in zip(completions, gold, strict=False):
        # Conversational completions: list of {"role", "content"} dicts.
        if isinstance(completion, list):
            text = "\n".join(m.get("content", "") for m in completion)
        else:
            text = str(completion)
        pred = _extract_answer(text)
        if pred is not None and pred == str(gold_ans).strip():
            rewards.append(1.0)
        else:
            rewards.append(0.0)
    return rewards


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = (
    "You are a math tutor. Solve the problem step by step. "
    "End your answer with `#### N` where N is the final numeric answer."
)


def build_dataset():
    raw = load_dataset("openai/gsm8k", "main", split=f"train[:{N_TRAIN_ROWS}]")

    def _format(row):
        # TRL GRPOTrainer accepts conversational `prompt` (list[dict]). We
        # pre-extract the gold numeric answer so the reward function can do
        # an exact-match.
        gold = _extract_answer(row["answer"]) or ""
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": row["question"]},
            ],
            "gold_answer": gold,
        }

    return raw.map(_format, remove_columns=raw.column_names)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> int:
    # Reproducibility
    random.seed(42)
    torch.manual_seed(42)

    log_path = OUTPUT_DIR.parent / "run.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(log_path, mode="w"),
        ],
    )
    log = logging.getLogger("gsm8k_grpo")

    log.info("=" * 64)
    log.info("Plain GRPO + GSM8K + Qwen2.5-0.5B-Instruct (CPU)")
    log.info("=" * 64)

    log.info("[1/4] Loading model + tokenizer ...")
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
    model.to("cpu")
    log.info(" loaded in %.1fs (%.3fB params)",
             time.time() - t0,
             sum(p.numel() for p in model.parameters()) / 1e9)

    log.info("[2/4] Loading %d GSM8K rows ...", N_TRAIN_ROWS)
    dataset = build_dataset()
    log.info(" example row: prompt=%s ... gold=%s",
             dataset[0]["prompt"][1]["content"][:80], dataset[0]["gold_answer"])

    log.info("[3/4] Building ComposerReplicationTrainer (alpha_sdpo=0, beta_replay=0) ...")
    # Lazy import: GRPOConfig requires `trl` (in the [train] extra). The
    # framework's __init__ falls back gracefully when TRL is missing, but
    # GRPOConfig does not.
    from trl import GRPOConfig

    config = GRPOConfig(
        output_dir=str(OUTPUT_DIR),
        per_device_train_batch_size=NUM_GENERATIONS,  # 1 prompt × num_generations rollouts
        gradient_accumulation_steps=1,
        num_generations=NUM_GENERATIONS,
        # NOTE: TRL 1.5+ dropped GRPOConfig.max_prompt_length; prompts are
        # tokenized by the rollout pipeline at generation time. Use
        # tokenizer.model_max_length to bound prompts.
        max_completion_length=MAX_COMPLETION_LEN,
        learning_rate=1e-5,
        max_steps=N_OUTER_STEPS,
        logging_steps=1,
        save_strategy="no",
        report_to=[],
        # CPU-only: `use_cpu=True` disables CUDA/MPS auto-detection. It
        # supersedes the deprecated `no_cuda` flag, so only one is needed.
        use_cpu=True,
        # Plain-GRPO sanity: disable the KL-to-reference penalty (beta=0) so
        # there's no reference-model forward pass on CPU.
        beta=0.0,
        seed=42,
        bf16=False,
        fp16=False,
    )

    trainer = ComposerReplicationTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[gsm8k_reward],
        train_dataset=dataset,
        args=config,
        # Channels 2 (SDPO) + 3 (trace-replay DPO) disabled: pure GRPO.
        alpha_sdpo=0.0,
        beta_replay=0.0,
    )

    log.info("[4/4] Training for %d outer steps ...", N_OUTER_STEPS)
    t0 = time.time()
    train_result = trainer.train()
    dt = time.time() - t0
    log.info("Training complete in %.1fs", dt)

    # Persist final state
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(exist_ok=True)
    trainer.save_model(str(final_dir))
    log.info("Final model saved to %s", final_dir)

    # Summary
    metrics = train_result.metrics
    log.info("=" * 64)
    log.info("Summary")
    log.info("=" * 64)
    log.info(" steps: %s", metrics.get("train_steps", N_OUTER_STEPS))
    log.info(" train_loss: %.6f", metrics.get("train_loss", float("nan")))
    log.info(" train_runtime: %.1fs", metrics.get("train_runtime", dt))
    log.info(" log file: %s", log_path)
    return 0


if __name__ == "__main__":
    sys.exit(main())