"""Integration tests for ADR-007 distillation kwargs in compose_loss. These tests exercise the wiring between `compose_loss` and the three pluggable losses (SimPO, TAID, Entropy-Aware OPD). They use a tiny hand-rolled language model wrapper (no HF, no TRL) so the tests run in <1s on CPU and are isolated from external library churn. Coverage requirements: (a) defaults reproduce existing compose_loss output bit-exact (b) dpo_variant='simpo' produces a different total than dpo (c) sdpo_wrapper='taid' with t=0 differs from t=1 (interpolation works) (d) sdpo_wrapper='taid' with t=1 reproduces upstream forward-KL (e) sdpo_wrapper='entropy_opd' returns a finite differentiable scalar (f) error case: sdpo_wrapper='taid' without taid_t raises ValueError """ from __future__ import annotations import pytest import torch import torch.nn as nn from composer_replication import LossComponents, compose_loss # ---------------------------------------------------------------------- # Tiny LM stand-in # ---------------------------------------------------------------------- class TinyLM(nn.Module): """Minimal `nn.Module` with the HF-style `model(input_ids=...).logits` API. Vocab=32, hidden=16, two-layer MLP head. Tiny enough that all tests run in milliseconds on CPU. """ def __init__(self, vocab: int = 32, hidden: int = 16, seed: int = 0): super().__init__() torch.manual_seed(seed) self.embed = nn.Embedding(vocab, hidden) self.fc = nn.Linear(hidden, hidden) self.head = nn.Linear(hidden, vocab) def forward(self, input_ids: torch.Tensor): h = torch.tanh(self.fc(self.embed(input_ids))) logits = self.head(h) class _Out: pass out = _Out() out.logits = logits return out # ---------------------------------------------------------------------- # Batch fixtures # ---------------------------------------------------------------------- VOCAB = 32 B = 2 T = 8 def _base_batch(seed: int = 7, *, with_dpo: bool = True) -> dict[str, torch.Tensor]: """Build a deterministic input batch with all 3 channels populated.""" g = torch.Generator().manual_seed(seed) inputs: dict[str, torch.Tensor] = { "input_ids": torch.randint(0, VOCAB, (B, T), generator=g), "response_mask": torch.zeros(B, T, dtype=torch.long), "ctx_teacher_input_ids": torch.randint(0, VOCAB, (B, T), generator=g), "sdpo_loss_mask": torch.zeros(B, T, dtype=torch.long), } # Mark the second half as response tokens so the LM-CE channel is non-trivial. inputs["response_mask"][:, T // 2:] = 1 inputs["sdpo_loss_mask"][:, T // 2:] = 1 if with_dpo: inputs["dpo_chosen_input_ids"] = torch.randint(0, VOCAB, (B, T), generator=g) inputs["dpo_chosen_response_mask"] = torch.ones(B, T, dtype=torch.long) inputs["dpo_rejected_input_ids"] = torch.randint(0, VOCAB, (B, T), generator=g) inputs["dpo_rejected_response_mask"] = torch.ones(B, T, dtype=torch.long) # Standard DPO needs ref logprobs; SimPO ignores them. inputs["dpo_chosen_ref_logprobs"] = torch.randn(B, generator=g) inputs["dpo_rejected_ref_logprobs"] = torch.randn(B, generator=g) return inputs def _model_seeded(seed: int = 0) -> TinyLM: m = TinyLM(vocab=VOCAB, hidden=16, seed=seed) m.eval() # Deterministic forward — no dropout. return m # ---------------------------------------------------------------------- # (a) Defaults reproduce existing output bit-exact # ---------------------------------------------------------------------- def test_defaults_bit_exact_with_legacy_kwargs(): """Calling compose_loss with new kwargs at their defaults must equal calling it with only the legacy kwargs. Bit-exact: every channel + total agree to 0 ULPs because the code path is identical. """ inputs = _base_batch() model_a = _model_seeded(seed=0) out_legacy = compose_loss( model_a, inputs, alpha_sdpo=0.1, beta_replay=0.05, sdpo_jsd_beta=0.5, sdpo_temperature=1.0, replay_dpo_beta=0.1, ) model_b = _model_seeded(seed=0) out_new = compose_loss( model_b, inputs, alpha_sdpo=0.1, beta_replay=0.05, sdpo_jsd_beta=0.5, sdpo_temperature=1.0, replay_dpo_beta=0.1, dpo_variant="dpo", sdpo_wrapper="none", ) assert isinstance(out_new, LossComponents) assert torch.equal(out_legacy.lm_ce, out_new.lm_ce) assert torch.equal(out_legacy.sdpo_jsd, out_new.sdpo_jsd) assert torch.equal(out_legacy.trace_replay_dpo, out_new.trace_replay_dpo) assert torch.equal(out_legacy.total, out_new.total) # ---------------------------------------------------------------------- # (b) dpo_variant='simpo' produces a different total than dpo # ---------------------------------------------------------------------- def test_simpo_variant_changes_total(): """SimPO uses average-logprob and drops the reference subtraction, so it must produce a different (and finite) trace_replay_dpo + total.""" inputs = _base_batch() model_a = _model_seeded(seed=0) out_dpo = compose_loss( model_a, inputs, alpha_sdpo=0.0, # isolate channel 3 beta_replay=0.05, dpo_variant="dpo", ) model_b = _model_seeded(seed=0) out_simpo = compose_loss( model_b, inputs, alpha_sdpo=0.0, beta_replay=0.05, dpo_variant="simpo", ) assert torch.isfinite(out_simpo.total) assert torch.isfinite(out_simpo.trace_replay_dpo) # Different formulae => different values. assert not torch.allclose( out_dpo.trace_replay_dpo, out_simpo.trace_replay_dpo ) assert not torch.allclose(out_dpo.total, out_simpo.total) # Gradient flow check. out_simpo.total.backward() assert any( p.grad is not None and torch.isfinite(p.grad).all() for p in model_b.parameters() ) def test_simpo_does_not_require_ref_logprobs(): """SimPO is reference-free; compose_loss should run when those keys are absent from `inputs` (only when dpo_variant='simpo').""" inputs = _base_batch() inputs.pop("dpo_chosen_ref_logprobs") inputs.pop("dpo_rejected_ref_logprobs") model = _model_seeded(seed=0) out = compose_loss( model, inputs, alpha_sdpo=0.0, beta_replay=0.05, dpo_variant="simpo", ) assert torch.isfinite(out.total) assert torch.isfinite(out.trace_replay_dpo) # ---------------------------------------------------------------------- # (c) TAID with t=1 reproduces upstream forward-KL on the masked tokens # ---------------------------------------------------------------------- def test_taid_t_one_matches_upstream_forward_kl(): """At t=1, taid_loss reduces to forward-KL with target = softmax(teacher). compose_loss should plumb through to that exact value (modulo the sdpo_loss_mask token-mean denominator). """ import torch.nn.functional as F inputs = _base_batch(with_dpo=False) model = _model_seeded(seed=1) # Run compose_loss with TAID at t=1. out_taid = compose_loss( model, inputs, alpha_sdpo=1.0, # so out.sdpo_jsd is added straight to total beta_replay=0.0, sdpo_wrapper="taid", taid_t=1.0, ) # Manually compute the same forward-KL on the masked tokens. student_logits = model(input_ids=inputs["input_ids"]).logits with torch.no_grad(): teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits mask = inputs["sdpo_loss_mask"].float() p_teacher = F.softmax(teacher_logits, dim=-1, dtype=torch.float32) log_q = F.log_softmax(student_logits, dim=-1, dtype=torch.float32) per_token = -(p_teacher * log_q).sum(dim=-1) flat = per_token.reshape(-1) fmask = mask.reshape(-1).to(flat.dtype) expected = (flat * fmask).sum() / fmask.sum().clamp_min(1.0) # Bit-exact assertion. The TAID-loss path at t=1 is mathematically # identical to the manual `-(p_teacher * log_q).sum(...)` cross-entropy # below: at t=1, TAID's logit-space mix collapses to `teacher_logits`, # `softmax(teacher_logits)` is computed bit-identically inside # `taid_loss`, and the masked-mean reduction matches. So `torch.equal` # succeeds — and asserting `equal` rather than `allclose` catches any # future refactor that re-introduces a softmax→log roundtrip with # ULP drift. # # If a future change forces a roundtrip we cannot eliminate, drop to # `torch.testing.assert_close(out_taid.sdpo_jsd, expected, # atol=1e-7, rtol=0)` — that is the strict-but-feasible bound for # softmax→log→softmax in float32 (one ULP at the scale of the loss, # ~3.5e-7 here, dominated by the log_softmax LSE accumulation). assert torch.equal(out_taid.sdpo_jsd, expected), ( f"TAID t=1 must equal upstream forward-KL bit-exact; " f"got out={out_taid.sdpo_jsd.item()!r}, " f"expected={expected.item()!r}, " f"diff={(out_taid.sdpo_jsd - expected).abs().item():.3e}" ) # ---------------------------------------------------------------------- # (d) TAID interpolates: t=0 differs from t=1 # ---------------------------------------------------------------------- def test_taid_interpolates_with_t(): """Different t values give different sdpo_jsd. Differentiable end-to-end.""" inputs = _base_batch(with_dpo=False) model_zero = _model_seeded(seed=2) out_zero = compose_loss( model_zero, inputs, alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid", taid_t=0.0, ) model_mid = _model_seeded(seed=2) out_mid = compose_loss( model_mid, inputs, alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid", taid_t=0.5, ) model_one = _model_seeded(seed=2) out_one = compose_loss( model_one, inputs, alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid", taid_t=1.0, ) for out in (out_zero, out_mid, out_one): assert torch.isfinite(out.total) assert torch.isfinite(out.sdpo_jsd) assert not torch.allclose(out_zero.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5) assert not torch.allclose(out_mid.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5) out_mid.total.backward() assert any( p.grad is not None and torch.isfinite(p.grad).all() for p in model_mid.parameters() ) # ---------------------------------------------------------------------- # (e) Entropy-Aware OPD returns a finite differentiable scalar # ---------------------------------------------------------------------- def test_entropy_opd_returns_finite_differentiable_scalar(): inputs = _base_batch(with_dpo=False) model = _model_seeded(seed=3) out = compose_loss( model, inputs, alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="entropy_opd", ) assert isinstance(out, LossComponents) assert out.total.shape == () assert torch.isfinite(out.total) assert torch.isfinite(out.sdpo_jsd) assert out.total.requires_grad out.total.backward() grads = [p.grad for p in model.parameters() if p.grad is not None] assert len(grads) > 0 assert all(torch.isfinite(g).all() for g in grads) # ---------------------------------------------------------------------- # (f) Error: sdpo_wrapper='taid' without taid_t # ---------------------------------------------------------------------- def test_taid_requires_t(): inputs = _base_batch(with_dpo=False) model = _model_seeded(seed=4) with pytest.raises(ValueError, match="taid_t"): compose_loss( model, inputs, alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid", # taid_t omitted on purpose ) def test_taid_t_out_of_range_raises(): inputs = _base_batch(with_dpo=False) model = _model_seeded(seed=4) with pytest.raises(ValueError, match=r"taid_t must be in \[0, 1\]"): compose_loss( model, inputs, alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid", taid_t=1.5, ) def test_invalid_dpo_variant_raises(): inputs = _base_batch() model = _model_seeded(seed=5) with pytest.raises(ValueError, match="dpo_variant"): compose_loss( model, inputs, dpo_variant="bogus", # type: ignore[arg-type] ) def test_invalid_sdpo_wrapper_raises(): inputs = _base_batch() model = _model_seeded(seed=5) with pytest.raises(ValueError, match="sdpo_wrapper"): compose_loss( model, inputs, sdpo_wrapper="bogus", # type: ignore[arg-type] ) # ---------------------------------------------------------------------- # Bonus: TAIDScheduler integration # ---------------------------------------------------------------------- def test_taid_compose_with_scheduler(): """End-to-end: TAIDScheduler drives taid_t into compose_loss.""" from composer_replication.distillation import TAIDScheduler inputs = _base_batch(with_dpo=False) model = _model_seeded(seed=6) sched = TAIDScheduler(num_train_steps=100, t_start=0.4) for step in range(3): out = compose_loss( model, inputs, alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid", taid_t=sched.t, ) assert torch.isfinite(out.total) sched.update_t(out.sdpo_jsd.detach(), global_step=step) # t may have advanced past t_start after some steps (or stayed the same # given small num_train_steps and only 3 iters; just check it's still # in-range). assert 0.4 <= sched.t <= 1.0