Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Integration tests for ADR-007 distillation kwargs in compose_loss. | |
| These tests exercise the wiring between `compose_loss` and the three | |
| pluggable losses (SimPO, TAID, Entropy-Aware OPD). They use a tiny | |
| hand-rolled language model wrapper (no HF, no TRL) so the tests run | |
| in <1s on CPU and are isolated from external library churn. | |
| Coverage requirements: | |
| (a) defaults reproduce existing compose_loss output bit-exact | |
| (b) dpo_variant='simpo' produces a different total than dpo | |
| (c) sdpo_wrapper='taid' with t=1 reproduces upstream forward-KL | |
| (d) sdpo_wrapper='taid' with t=0 differs from t=1 (interpolation works) | |
| (e) sdpo_wrapper='entropy_opd' returns a finite differentiable scalar | |
| (f) error case: sdpo_wrapper='taid' without taid_t raises ValueError | |
| """ | |
| from __future__ import annotations | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from composer_replication import LossComponents, compose_loss | |
| # ---------------------------------------------------------------------- | |
| # Tiny LM stand-in | |
| # ---------------------------------------------------------------------- | |
class TinyLM(nn.Module):
    """Toy language model exposing the HF-style ``model(input_ids=...).logits`` call.

    The network is an embedding table, a tanh-activated linear layer and a
    linear projection onto the vocabulary. Torch's global RNG is seeded with
    ``seed`` before the layers are created, so two instances built with the
    same arguments start from identical weights.
    """

    def __init__(self, vocab: int = 32, hidden: int = 16, seed: int = 0):
        super().__init__()
        # Seed first: the three layers below draw their initial weights from
        # the global RNG in exactly this order.
        torch.manual_seed(seed)
        self.embed = nn.Embedding(vocab, hidden)
        self.fc = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, vocab)

    def forward(self, input_ids: torch.Tensor):
        """Return an object whose ``logits`` attribute holds the vocab scores."""
        embedded = self.embed(input_ids)
        hidden_state = torch.tanh(self.fc(embedded))
        # A throwaway attribute bag stands in for an HF model-output object;
        # callers only ever read ``.logits`` from it.
        result = type("_Out", (), {})()
        result.logits = self.head(hidden_state)
        return result
| # ---------------------------------------------------------------------- | |
| # Batch fixtures | |
| # ---------------------------------------------------------------------- | |
VOCAB = 32
B = 2
T = 8


def _base_batch(seed: int = 7, *, with_dpo: bool = True) -> dict[str, torch.Tensor]:
    """Return a reproducible input batch for ``compose_loss``.

    Random tensors are drawn from a private ``torch.Generator`` seeded with
    ``seed``. The ``input_ids`` / ``response_mask`` /
    ``ctx_teacher_input_ids`` / ``sdpo_loss_mask`` keys are always present;
    the six ``dpo_*`` keys are added only when ``with_dpo`` is true.
    """
    rng = torch.Generator().manual_seed(seed)

    def _tokens() -> torch.Tensor:
        # Every call advances the shared generator, so call order matters.
        return torch.randint(0, VOCAB, (B, T), generator=rng)

    def _second_half_mask() -> torch.Tensor:
        # Only the second half of each sequence is flagged, so the masked
        # channels operate on a non-trivial subset of positions.
        mask = torch.zeros(B, T, dtype=torch.long)
        mask[:, T // 2:] = 1
        return mask

    input_ids = _tokens()
    teacher_ids = _tokens()
    batch: dict[str, torch.Tensor] = {
        "input_ids": input_ids,
        "response_mask": _second_half_mask(),
        "ctx_teacher_input_ids": teacher_ids,
        "sdpo_loss_mask": _second_half_mask(),
    }
    if not with_dpo:
        return batch

    batch["dpo_chosen_input_ids"] = _tokens()
    batch["dpo_chosen_response_mask"] = torch.ones(B, T, dtype=torch.long)
    batch["dpo_rejected_input_ids"] = _tokens()
    batch["dpo_rejected_response_mask"] = torch.ones(B, T, dtype=torch.long)
    # One reference log-prob per sequence. These feed the standard DPO
    # variant; the SimPO test in this file runs with them removed.
    batch["dpo_chosen_ref_logprobs"] = torch.randn(B, generator=rng)
    batch["dpo_rejected_ref_logprobs"] = torch.randn(B, generator=rng)
    return batch
def _model_seeded(seed: int = 0) -> TinyLM:
    """Build a ``TinyLM`` over ``VOCAB`` tokens, seeded with ``seed``, in eval mode."""
    # ``nn.Module.eval()`` returns the module itself, so construction and the
    # mode switch chain into one expression. TinyLM has no dropout layers;
    # eval mode simply keeps the forward pass deterministic by construction.
    return TinyLM(vocab=VOCAB, hidden=16, seed=seed).eval()
| # ---------------------------------------------------------------------- | |
| # (a) Defaults reproduce existing output bit-exact | |
| # ---------------------------------------------------------------------- | |
def test_defaults_bit_exact_with_legacy_kwargs():
    """Passing the new kwargs at their default values must change nothing.

    ``compose_loss`` is called once with only the legacy keyword arguments
    and once with ``dpo_variant="dpo"`` / ``sdpo_wrapper="none"`` added.
    Every channel and the total are compared with ``torch.equal``, i.e.
    bit-for-bit; the intent is that both calls take the same code path.
    """
    batch = _base_batch()
    legacy_kwargs = dict(
        alpha_sdpo=0.1,
        beta_replay=0.05,
        sdpo_jsd_beta=0.5,
        sdpo_temperature=1.0,
        replay_dpo_beta=0.1,
    )

    # Each call gets a freshly built model with the same seed, so the two
    # runs start from identical weights.
    out_legacy = compose_loss(_model_seeded(seed=0), batch, **legacy_kwargs)
    out_new = compose_loss(
        _model_seeded(seed=0),
        batch,
        **legacy_kwargs,
        dpo_variant="dpo",
        sdpo_wrapper="none",
    )

    assert isinstance(out_new, LossComponents)
    for field in ("lm_ce", "sdpo_jsd", "trace_replay_dpo", "total"):
        assert torch.equal(getattr(out_legacy, field), getattr(out_new, field))
| # ---------------------------------------------------------------------- | |
| # (b) dpo_variant='simpo' produces a different total than dpo | |
| # ---------------------------------------------------------------------- | |
def test_simpo_variant_changes_total():
    """Switching ``dpo_variant`` from 'dpo' to 'simpo' must change channel 3.

    Both runs use identically seeded models and the same batch, with
    ``alpha_sdpo=0.0`` (per the original note, this isolates channel 3).
    The SimPO outputs must be finite, must differ from the DPO outputs in
    both ``trace_replay_dpo`` and ``total``, and must back-propagate finite
    gradients into the model.
    """
    inputs = _base_batch()

    model_a = _model_seeded(seed=0)
    out_dpo = compose_loss(
        model_a, inputs,
        alpha_sdpo=0.0,  # isolate channel 3
        beta_replay=0.05,
        dpo_variant="dpo",
    )

    model_b = _model_seeded(seed=0)
    out_simpo = compose_loss(
        model_b, inputs,
        alpha_sdpo=0.0,
        beta_replay=0.05,
        dpo_variant="simpo",
    )

    assert torch.isfinite(out_simpo.total)
    assert torch.isfinite(out_simpo.trace_replay_dpo)

    # Different formulae => different values.
    assert not torch.allclose(
        out_dpo.trace_replay_dpo, out_simpo.trace_replay_dpo
    )
    assert not torch.allclose(out_dpo.total, out_simpo.total)

    # Gradient flow check. At least one gradient must exist and *every*
    # gradient that exists must be finite; an ``any(...)`` check would pass
    # with one clean gradient even if the others were NaN/inf.
    out_simpo.total.backward()
    grads = [p.grad for p in model_b.parameters() if p.grad is not None]
    assert len(grads) > 0
    assert all(torch.isfinite(g).all() for g in grads)
def test_simpo_does_not_require_ref_logprobs():
    """``dpo_variant='simpo'`` must work without the reference log-prob keys.

    The two ``dpo_*_ref_logprobs`` entries are removed from the batch before
    the call; the resulting ``total`` and ``trace_replay_dpo`` must still be
    finite.
    """
    batch = _base_batch()
    for ref_key in ("dpo_chosen_ref_logprobs", "dpo_rejected_ref_logprobs"):
        # ``del`` raises KeyError if the fixture ever stops providing the key.
        del batch[ref_key]

    student = _model_seeded(seed=0)
    result = compose_loss(
        student,
        batch,
        alpha_sdpo=0.0,
        beta_replay=0.05,
        dpo_variant="simpo",
    )

    assert torch.isfinite(result.total)
    assert torch.isfinite(result.trace_replay_dpo)
| # ---------------------------------------------------------------------- | |
| # (c) TAID with t=1 reproduces upstream forward-KL on the masked tokens | |
| # ---------------------------------------------------------------------- | |
def test_taid_t_one_matches_upstream_forward_kl():
    """Check that TAID at ``taid_t=1.0`` equals a hand-computed reference.

    The reference is the masked token-mean of ``-(p_teacher * log_q).sum(-1)``,
    i.e. the teacher->student cross-entropy (the student-dependent part of
    the forward KL). Student logits come from ``input_ids`` and teacher
    logits from ``ctx_teacher_input_ids``, both through the same model, with
    the teacher pass under ``torch.no_grad()``. ``out.sdpo_jsd`` must match
    the reference bit-for-bit.
    """
    import torch.nn.functional as F

    inputs = _base_batch(with_dpo=False)
    model = _model_seeded(seed=1)

    # Run compose_loss with TAID at t=1.
    out_taid = compose_loss(
        model, inputs,
        alpha_sdpo=1.0,  # unit weight on the SDPO channel; the assertion reads out.sdpo_jsd directly
        beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=1.0,
    )

    # Manually compute the reference value on the masked tokens.
    student_logits = model(input_ids=inputs["input_ids"]).logits
    with torch.no_grad():
        teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
    mask = inputs["sdpo_loss_mask"].float()
    # Both distributions are computed in float32 regardless of logit dtype.
    p_teacher = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    log_q = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    per_token = -(p_teacher * log_q).sum(dim=-1)
    flat = per_token.reshape(-1)
    fmask = mask.reshape(-1).to(flat.dtype)
    # Masked mean; the clamp guards against division by zero on an empty mask.
    expected = (flat * fmask).sum() / fmask.sum().clamp_min(1.0)

    # Bit-exact assertion. Per the original author's note, the TAID-loss
    # path at t=1 is mathematically identical to the manual
    # `-(p_teacher * log_q).sum(...)` cross-entropy computed above: at t=1,
    # TAID's logit-space mix collapses to `teacher_logits`,
    # `softmax(teacher_logits)` is computed bit-identically inside
    # `taid_loss`, and the masked-mean reduction matches. Asserting `equal`
    # rather than `allclose` is meant to catch any future refactor that
    # re-introduces a softmax->log roundtrip with ULP drift.
    # NOTE(review): these statements describe `taid_loss` internals that are
    # not visible in this file - confirm against the implementation.
    #
    # If a future change forces a roundtrip that cannot be eliminated, the
    # suggested fallback is `torch.testing.assert_close(out_taid.sdpo_jsd,
    # expected, atol=1e-7, rtol=0)` - described as the strict-but-feasible
    # bound for softmax->log->softmax in float32 (one ULP at the scale of
    # the loss, ~3.5e-7 here, dominated by the log_softmax LSE accumulation).
    assert torch.equal(out_taid.sdpo_jsd, expected), (
        f"TAID t=1 must equal upstream forward-KL bit-exact; "
        f"got out={out_taid.sdpo_jsd.item()!r}, "
        f"expected={expected.item()!r}, "
        f"diff={(out_taid.sdpo_jsd - expected).abs().item():.3e}"
    )
| # ---------------------------------------------------------------------- | |
| # (d) TAID interpolates: t=0 differs from t=1 | |
| # ---------------------------------------------------------------------- | |
def test_taid_interpolates_with_t():
    """Different ``taid_t`` values must give different ``sdpo_jsd`` values.

    ``compose_loss`` is run at t=0.0, t=0.5 and t=1.0, each time on a freshly
    built model with the same seed and the same batch. All outputs must be
    finite, the t=0 and t=0.5 results must each differ from the t=1 result,
    and the t=0.5 loss must back-propagate finite gradients.
    """
    inputs = _base_batch(with_dpo=False)

    model_zero = _model_seeded(seed=2)
    out_zero = compose_loss(
        model_zero, inputs,
        alpha_sdpo=0.1, beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=0.0,
    )

    model_mid = _model_seeded(seed=2)
    out_mid = compose_loss(
        model_mid, inputs,
        alpha_sdpo=0.1, beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=0.5,
    )

    model_one = _model_seeded(seed=2)
    out_one = compose_loss(
        model_one, inputs,
        alpha_sdpo=0.1, beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=1.0,
    )

    for out in (out_zero, out_mid, out_one):
        assert torch.isfinite(out.total)
        assert torch.isfinite(out.sdpo_jsd)

    assert not torch.allclose(out_zero.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5)
    assert not torch.allclose(out_mid.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5)

    # Differentiability check. At least one gradient must exist and *every*
    # gradient that exists must be finite; an ``any(...)`` check would pass
    # with one clean gradient even if the others were NaN/inf.
    out_mid.total.backward()
    grads = [p.grad for p in model_mid.parameters() if p.grad is not None]
    assert len(grads) > 0
    assert all(torch.isfinite(g).all() for g in grads)
| # ---------------------------------------------------------------------- | |
| # (e) Entropy-Aware OPD returns a finite differentiable scalar | |
| # ---------------------------------------------------------------------- | |
def test_entropy_opd_returns_finite_differentiable_scalar():
    """``sdpo_wrapper='entropy_opd'`` must yield a finite, differentiable scalar.

    The result has to be a ``LossComponents`` whose ``total`` is a
    zero-dimensional tensor that requires grad, with finite ``total`` and
    ``sdpo_jsd``; back-propagating ``total`` must produce at least one
    gradient, and every gradient produced must be finite.
    """
    batch = _base_batch(with_dpo=False)
    student = _model_seeded(seed=3)

    result = compose_loss(
        student,
        batch,
        alpha_sdpo=0.1,
        beta_replay=0.0,
        sdpo_wrapper="entropy_opd",
    )

    assert isinstance(result, LossComponents)
    total = result.total
    assert total.shape == ()
    for value in (total, result.sdpo_jsd):
        assert torch.isfinite(value)
    assert total.requires_grad

    total.backward()
    produced = [
        param.grad for param in student.parameters() if param.grad is not None
    ]
    assert len(produced) > 0
    for grad in produced:
        assert torch.isfinite(grad).all()
| # ---------------------------------------------------------------------- | |
| # (f) Error: sdpo_wrapper='taid' without taid_t | |
| # ---------------------------------------------------------------------- | |
def test_taid_requires_t():
    """Selecting the TAID wrapper without ``taid_t`` must raise ``ValueError``."""
    batch = _base_batch(with_dpo=False)
    student = _model_seeded(seed=4)
    # ``taid_t`` is deliberately absent from these keyword arguments.
    call_kwargs = dict(alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid")

    with pytest.raises(ValueError, match="taid_t"):
        compose_loss(student, batch, **call_kwargs)
def test_taid_t_out_of_range_raises():
    """A ``taid_t`` of 1.5 must be rejected with a ``[0, 1]`` range error."""
    batch = _base_batch(with_dpo=False)
    student = _model_seeded(seed=4)
    call_kwargs = dict(
        alpha_sdpo=0.1,
        beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=1.5,  # above the accepted upper bound
    )

    with pytest.raises(ValueError, match=r"taid_t must be in \[0, 1\]"):
        compose_loss(student, batch, **call_kwargs)
def test_invalid_dpo_variant_raises():
    """An unknown ``dpo_variant`` must raise a ``ValueError`` naming the kwarg."""
    batch = _base_batch()
    student = _model_seeded(seed=5)

    with pytest.raises(ValueError, match="dpo_variant"):
        # The bad literal is intentional, hence the type-checker suppression.
        compose_loss(student, batch, dpo_variant="bogus")  # type: ignore[arg-type]
def test_invalid_sdpo_wrapper_raises():
    """An unknown ``sdpo_wrapper`` must raise a ``ValueError`` naming the kwarg."""
    batch = _base_batch()
    student = _model_seeded(seed=5)

    with pytest.raises(ValueError, match="sdpo_wrapper"):
        # The bad literal is intentional, hence the type-checker suppression.
        compose_loss(student, batch, sdpo_wrapper="bogus")  # type: ignore[arg-type]
| # ---------------------------------------------------------------------- | |
| # Bonus: TAIDScheduler integration | |
| # ---------------------------------------------------------------------- | |
def test_taid_compose_with_scheduler():
    """Drive ``compose_loss`` for three steps with ``taid_t`` from a ``TAIDScheduler``.

    Each step reads the scheduler's current ``t``, checks the resulting
    ``total`` is finite, then feeds the detached ``sdpo_jsd`` back via
    ``update_t``. Afterwards ``t`` must lie between its start value and 1.0.
    """
    from composer_replication.distillation import TAIDScheduler

    t_start = 0.4
    batch = _base_batch(with_dpo=False)
    student = _model_seeded(seed=6)
    scheduler = TAIDScheduler(num_train_steps=100, t_start=t_start)

    for step_idx in range(3):
        result = compose_loss(
            student,
            batch,
            alpha_sdpo=0.1,
            beta_replay=0.0,
            sdpo_wrapper="taid",
            taid_t=scheduler.t,
        )
        assert torch.isfinite(result.total)
        scheduler.update_t(result.sdpo_jsd.detach(), global_step=step_idx)

    # Whether t has moved after only three updates is not asserted; the test
    # only requires that it stays within [t_start, 1.0].
    assert t_start <= scheduler.t <= 1.0