# Source: composer-replication-framework/composer_replication/tests/test_compose_loss_integration.py
# Uploaded by: Codeseys
# Commit message: Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
# Commit: e5add15
"""Integration tests for ADR-007 distillation kwargs in compose_loss.
These tests exercise the wiring between `compose_loss` and the three
pluggable losses (SimPO, TAID, Entropy-Aware OPD). They use a tiny
hand-rolled language model wrapper (no HF, no TRL) so the tests run
in <1s on CPU and are isolated from external library churn.
Coverage requirements:
(a) defaults reproduce existing compose_loss output bit-exact
(b) dpo_variant='simpo' produces a different total than dpo
(c) sdpo_wrapper='taid' with t=1 reproduces upstream forward-KL
(d) sdpo_wrapper='taid' with t=0 differs from t=1 (interpolation works)
(e) sdpo_wrapper='entropy_opd' returns a finite differentiable scalar
(f) error case: sdpo_wrapper='taid' without taid_t raises ValueError
"""
from __future__ import annotations
import pytest
import torch
import torch.nn as nn
from composer_replication import LossComponents, compose_loss
# ----------------------------------------------------------------------
# Tiny LM stand-in
# ----------------------------------------------------------------------
class TinyLM(nn.Module):
    """Minimal `nn.Module` with the HF-style `model(input_ids=...).logits` API.

    Architecture: embedding -> linear -> tanh -> linear head. With the
    defaults (vocab=32, hidden=16) it is tiny enough that all tests run in
    milliseconds on CPU.

    Args:
        vocab: Number of token ids; also the size of the last logits axis.
        hidden: Width of the embedding and of the hidden layer.
        seed: Passed to `torch.manual_seed` right before the layers are
            created, so two instances built with the same seed start with
            identical weights. NOTE: this reseeds torch's *global* RNG as a
            side effect.
    """

    class _Output:
        """Bare result holder exposing `.logits`, like an HF model output.

        Defined once at class level so every forward pass returns the same
        type instead of minting a brand-new class object per call.
        """

        def __init__(self, logits: torch.Tensor) -> None:
            self.logits = logits

    def __init__(self, vocab: int = 32, hidden: int = 16, seed: int = 0):
        super().__init__()
        # Seed first: the three layers below must be created in this order
        # for the initial weights to be reproducible for a given seed.
        torch.manual_seed(seed)
        self.embed = nn.Embedding(vocab, hidden)
        self.fc = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, vocab)

    def forward(self, input_ids: torch.Tensor) -> "TinyLM._Output":
        """Return an object whose `.logits` has shape `input_ids.shape + (vocab,)`."""
        h = torch.tanh(self.fc(self.embed(input_ids)))
        return self._Output(self.head(h))
# ----------------------------------------------------------------------
# Batch fixtures
# ----------------------------------------------------------------------
# Vocabulary size: exclusive upper bound for the random token ids drawn in
# `_base_batch` and the `vocab` passed to `TinyLM` in `_model_seeded`.
VOCAB = 32
# Batch size (first dimension) of every fixture tensor.
B = 2
# Sequence length (second dimension) of every per-token fixture tensor.
T = 8
def _base_batch(seed: int = 7, *, with_dpo: bool = True) -> dict[str, torch.Tensor]:
    """Build a deterministic fixture batch for `compose_loss`.

    All randomness comes from a private `torch.Generator` seeded with
    `seed`, so the global RNG is untouched and repeated calls with the same
    arguments return equal tensors.

    Args:
        seed: Seed for the private generator.
        with_dpo: When true, also populate the `dpo_*` keys (chosen/rejected
            ids, all-ones response masks, and per-sequence reference
            log-probs). When false, only the four base keys are present.

    Returns:
        Dict of `(B, T)` long tensors, plus `(B,)` float reference
        log-probs when `with_dpo` is true.
    """
    rng = torch.Generator().manual_seed(seed)

    def random_ids() -> torch.Tensor:
        # One (B, T) draw of token ids from the shared generator.
        return torch.randint(0, VOCAB, (B, T), generator=rng)

    def second_half_mask() -> torch.Tensor:
        # Only the second half of each sequence is marked, so the masked
        # loss channels are non-trivial.
        mask = torch.zeros(B, T, dtype=torch.long)
        mask[:, T // 2:] = 1
        return mask

    # Draw order matters for reproducibility: input ids first, then the
    # teacher ids, then (optionally) the DPO tensors.
    input_ids = random_ids()
    teacher_ids = random_ids()
    batch: dict[str, torch.Tensor] = {
        "input_ids": input_ids,
        "response_mask": second_half_mask(),
        "ctx_teacher_input_ids": teacher_ids,
        "sdpo_loss_mask": second_half_mask(),
    }
    if not with_dpo:
        return batch

    batch["dpo_chosen_input_ids"] = random_ids()
    batch["dpo_chosen_response_mask"] = torch.ones(B, T, dtype=torch.long)
    batch["dpo_rejected_input_ids"] = random_ids()
    batch["dpo_rejected_response_mask"] = torch.ones(B, T, dtype=torch.long)
    # Standard DPO needs reference log-probs; SimPO ignores them.
    batch["dpo_chosen_ref_logprobs"] = torch.randn(B, generator=rng)
    batch["dpo_rejected_ref_logprobs"] = torch.randn(B, generator=rng)
    return batch
def _model_seeded(seed: int = 0) -> TinyLM:
    """Return a `TinyLM` over `VOCAB` tokens, seeded with `seed`, in eval mode.

    Eval mode keeps the forward pass deterministic (no dropout-style
    training behaviour).
    """
    return TinyLM(vocab=VOCAB, hidden=16, seed=seed).eval()
# ----------------------------------------------------------------------
# (a) Defaults reproduce existing output bit-exact
# ----------------------------------------------------------------------
def test_defaults_bit_exact_with_legacy_kwargs():
    """Spelling out the new kwargs at their defaults must change nothing.

    One call passes only the legacy kwargs; the other adds
    `dpo_variant="dpo"` and `sdpo_wrapper="none"`. Every channel and the
    total must match bit-for-bit (`torch.equal`, no tolerance).
    """
    inputs = _base_batch()
    legacy_kwargs = dict(
        alpha_sdpo=0.1,
        beta_replay=0.05,
        sdpo_jsd_beta=0.5,
        sdpo_temperature=1.0,
        replay_dpo_beta=0.1,
    )

    # Each call gets its own freshly built, identically seeded model.
    out_legacy = compose_loss(_model_seeded(seed=0), inputs, **legacy_kwargs)
    out_new = compose_loss(
        _model_seeded(seed=0),
        inputs,
        **legacy_kwargs,
        dpo_variant="dpo",
        sdpo_wrapper="none",
    )

    assert isinstance(out_new, LossComponents)
    for channel in ("lm_ce", "sdpo_jsd", "trace_replay_dpo", "total"):
        assert torch.equal(getattr(out_legacy, channel), getattr(out_new, channel))
# ----------------------------------------------------------------------
# (b) dpo_variant='simpo' produces a different total than dpo
# ----------------------------------------------------------------------
def test_simpo_variant_changes_total():
    """`dpo_variant="simpo"` must yield a different, finite channel 3 and total.

    SimPO uses average-logprob and drops the reference subtraction, so
    on the same batch and identically seeded models its
    `trace_replay_dpo` (and therefore `total`) must differ from standard
    DPO. The test also checks that the SimPO total back-propagates finite
    gradients.
    """
    inputs = _base_batch()
    # alpha_sdpo=0.0 switches the SDPO channel off so channel 3 is isolated.
    shared = dict(alpha_sdpo=0.0, beta_replay=0.05)

    model_a = _model_seeded(seed=0)
    out_dpo = compose_loss(model_a, inputs, dpo_variant="dpo", **shared)
    model_b = _model_seeded(seed=0)
    out_simpo = compose_loss(model_b, inputs, dpo_variant="simpo", **shared)

    assert torch.isfinite(out_simpo.total)
    assert torch.isfinite(out_simpo.trace_replay_dpo)
    # Different formulae => different values.
    assert not torch.allclose(
        out_dpo.trace_replay_dpo, out_simpo.trace_replay_dpo
    )
    assert not torch.allclose(out_dpo.total, out_simpo.total)

    # Gradient flow check. Every gradient that was produced must be finite;
    # an `any(...)` check would let a NaN/inf gradient on one parameter
    # slip through as long as some other parameter was fine.
    out_simpo.total.backward()
    grads = [p.grad for p in model_b.parameters() if p.grad is not None]
    assert len(grads) > 0
    assert all(torch.isfinite(g).all() for g in grads)
def test_simpo_does_not_require_ref_logprobs():
    """SimPO is reference-free, so the ref-logprob keys may be absent.

    With `dpo_variant="simpo"`, `compose_loss` must still return a finite
    total and channel 3 when both `dpo_*_ref_logprobs` keys are removed
    from `inputs`.
    """
    inputs = _base_batch()
    for ref_key in ("dpo_chosen_ref_logprobs", "dpo_rejected_ref_logprobs"):
        del inputs[ref_key]

    result = compose_loss(
        _model_seeded(seed=0),
        inputs,
        alpha_sdpo=0.0,
        beta_replay=0.05,
        dpo_variant="simpo",
    )

    assert torch.isfinite(result.total)
    assert torch.isfinite(result.trace_replay_dpo)
# ----------------------------------------------------------------------
# (c) TAID with t=1 reproduces upstream forward-KL on the masked tokens
# ----------------------------------------------------------------------
def test_taid_t_one_matches_upstream_forward_kl():
    """TAID at t=1 must equal the hand-computed teacher-target loss bit-exact.

    At t=1 the TAID target is `softmax(teacher_logits)`, so
    `compose_loss(..., sdpo_wrapper="taid", taid_t=1.0).sdpo_jsd` should
    equal the masked token-mean of `-(p_teacher * log_q).sum(-1)` computed
    here by hand, with `sdpo_loss_mask` providing the mean's denominator.
    """
    import torch.nn.functional as F

    inputs = _base_batch(with_dpo=False)
    model = _model_seeded(seed=1)

    # alpha_sdpo=1.0 means sdpo_jsd is added to the total unscaled;
    # beta_replay=0.0 switches the replay channel off.
    out_taid = compose_loss(
        model,
        inputs,
        alpha_sdpo=1.0,
        beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=1.0,
    )

    # Reference computation. The op sequence is kept exactly as is because
    # the final assertion is bit-exact.
    student_logits = model(input_ids=inputs["input_ids"]).logits
    with torch.no_grad():
        teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
    token_mask = inputs["sdpo_loss_mask"].float()
    teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    student_logprobs = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    token_loss = -(teacher_probs * student_logprobs).sum(dim=-1)
    flat_loss = token_loss.reshape(-1)
    flat_mask = token_mask.reshape(-1).to(flat_loss.dtype)
    expected = (flat_loss * flat_mask).sum() / flat_mask.sum().clamp_min(1.0)

    # `torch.equal` rather than `allclose` is deliberate. Per the original
    # author's note, at t=1 TAID's logit-space mix collapses to
    # `teacher_logits`, `taid_loss` computes `softmax(teacher_logits)`
    # bit-identically, and its masked-mean reduction matches the manual
    # cross-entropy above. Asserting exact equality therefore catches any
    # future refactor that re-introduces a softmax->log roundtrip with ULP
    # drift.
    #
    # If a future change forces a roundtrip that cannot be eliminated, drop
    # to `torch.testing.assert_close(out_taid.sdpo_jsd, expected,
    # atol=1e-7, rtol=0)`: the strict-but-feasible bound for
    # softmax->log->softmax in float32 (one ULP at the scale of the loss,
    # ~3.5e-7 here, dominated by the log_softmax LSE accumulation).
    assert torch.equal(out_taid.sdpo_jsd, expected), (
        f"TAID t=1 must equal upstream forward-KL bit-exact; "
        f"got out={out_taid.sdpo_jsd.item()!r}, "
        f"expected={expected.item()!r}, "
        f"diff={(out_taid.sdpo_jsd - expected).abs().item():.3e}"
    )
# ----------------------------------------------------------------------
# (d) TAID interpolates: t=0 differs from t=1
# ----------------------------------------------------------------------
def test_taid_interpolates_with_t():
    """Different `taid_t` values must give different, finite `sdpo_jsd`.

    Runs `compose_loss` with the TAID wrapper at t=0, t=0.5 and t=1 on the
    same batch, each time with a freshly built, identically seeded model,
    and checks that t=0 and t=0.5 both differ from t=1. The t=0.5 total is
    then back-propagated to confirm the path is differentiable.
    """
    inputs = _base_batch(with_dpo=False)

    def run_at(t: float):
        # Fresh model per run so the runs cannot influence each other.
        model = _model_seeded(seed=2)
        out = compose_loss(
            model, inputs,
            alpha_sdpo=0.1, beta_replay=0.0,
            sdpo_wrapper="taid",
            taid_t=t,
        )
        return model, out

    _, out_zero = run_at(0.0)
    model_mid, out_mid = run_at(0.5)
    _, out_one = run_at(1.0)

    for out in (out_zero, out_mid, out_one):
        assert torch.isfinite(out.total)
        assert torch.isfinite(out.sdpo_jsd)
    assert not torch.allclose(out_zero.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5)
    assert not torch.allclose(out_mid.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5)

    # Every gradient that was produced must be finite; an `any(...)` check
    # would let a NaN/inf gradient on one parameter slip through as long as
    # some other parameter was fine.
    out_mid.total.backward()
    grads = [p.grad for p in model_mid.parameters() if p.grad is not None]
    assert len(grads) > 0
    assert all(torch.isfinite(g).all() for g in grads)
# ----------------------------------------------------------------------
# (e) Entropy-Aware OPD returns a finite differentiable scalar
# ----------------------------------------------------------------------
def test_entropy_opd_returns_finite_differentiable_scalar():
    """`sdpo_wrapper="entropy_opd"` must give a finite 0-dim loss with finite grads.

    Checks the return type, that `total` is a scalar requiring grad, that
    `total` and `sdpo_jsd` are finite, and that back-propagation produces
    at least one gradient, all of them finite.
    """
    inputs = _base_batch(with_dpo=False)
    model = _model_seeded(seed=3)

    result = compose_loss(
        model,
        inputs,
        alpha_sdpo=0.1,
        beta_replay=0.0,
        sdpo_wrapper="entropy_opd",
    )

    assert isinstance(result, LossComponents)
    total = result.total
    assert total.shape == ()
    assert torch.isfinite(total)
    assert torch.isfinite(result.sdpo_jsd)
    assert total.requires_grad

    total.backward()
    produced = [param.grad for param in model.parameters() if param.grad is not None]
    assert len(produced) > 0
    for grad in produced:
        assert torch.isfinite(grad).all()
# ----------------------------------------------------------------------
# (f) Error: sdpo_wrapper='taid' without taid_t
# ----------------------------------------------------------------------
def test_taid_requires_t():
    """Selecting the TAID wrapper without `taid_t` must raise `ValueError`."""
    batch = _base_batch(with_dpo=False)
    lm = _model_seeded(seed=4)
    # `taid_t` is deliberately missing from these kwargs.
    kwargs_without_t = dict(alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid")
    with pytest.raises(ValueError, match="taid_t"):
        compose_loss(lm, batch, **kwargs_without_t)
def test_taid_t_out_of_range_raises():
    """A `taid_t` outside [0, 1] (here 1.5) must raise `ValueError`."""
    batch = _base_batch(with_dpo=False)
    lm = _model_seeded(seed=4)
    out_of_range_t = 1.5
    with pytest.raises(ValueError, match=r"taid_t must be in \[0, 1\]"):
        compose_loss(
            lm,
            batch,
            alpha_sdpo=0.1,
            beta_replay=0.0,
            sdpo_wrapper="taid",
            taid_t=out_of_range_t,
        )
def test_invalid_dpo_variant_raises():
    """An unknown `dpo_variant` must raise a `ValueError` naming the kwarg."""
    batch = _base_batch()
    lm = _model_seeded(seed=5)
    with pytest.raises(ValueError, match="dpo_variant"):
        compose_loss(lm, batch, dpo_variant="bogus")  # type: ignore[arg-type]
def test_invalid_sdpo_wrapper_raises():
    """An unknown `sdpo_wrapper` must raise a `ValueError` naming the kwarg."""
    batch = _base_batch()
    lm = _model_seeded(seed=5)
    with pytest.raises(ValueError, match="sdpo_wrapper"):
        compose_loss(lm, batch, sdpo_wrapper="bogus")  # type: ignore[arg-type]
# ----------------------------------------------------------------------
# Bonus: TAIDScheduler integration
# ----------------------------------------------------------------------
def test_taid_compose_with_scheduler():
    """End-to-end: a `TAIDScheduler` supplies `taid_t` to `compose_loss`.

    Runs three steps, feeding each step's detached `sdpo_jsd` back into
    the scheduler, and checks that the total stays finite and that the
    scheduler's `t` ends inside `[t_start, 1.0]`.
    """
    from composer_replication.distillation import TAIDScheduler

    t_start = 0.4
    inputs = _base_batch(with_dpo=False)
    model = _model_seeded(seed=6)
    scheduler = TAIDScheduler(num_train_steps=100, t_start=t_start)

    for step in range(3):
        result = compose_loss(
            model,
            inputs,
            alpha_sdpo=0.1,
            beta_replay=0.0,
            sdpo_wrapper="taid",
            taid_t=scheduler.t,
        )
        assert torch.isfinite(result.total)
        scheduler.update_t(result.sdpo_jsd.detach(), global_step=step)

    # After only three steps t may or may not have moved past t_start, so
    # the test only requires that it is still in range.
    assert t_start <= scheduler.t <= 1.0