File size: 13,887 Bytes
d9dd3a5
 
 
 
 
 
 
e5add15
d9dd3a5
 
e5add15
 
d9dd3a5
e5add15
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5add15
d9dd3a5
 
e5add15
 
 
 
d9dd3a5
e5add15
d9dd3a5
e5add15
d9dd3a5
e5add15
d9dd3a5
e5add15
d9dd3a5
e5add15
 
d9dd3a5
 
e5add15
d9dd3a5
 
e5add15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
 
 
e5add15
d9dd3a5
 
e5add15
 
d9dd3a5
 
e5add15
 
 
d9dd3a5
 
e5add15
d9dd3a5
 
e5add15
 
 
d9dd3a5
 
e5add15
d9dd3a5
 
e5add15
 
 
d9dd3a5
 
e5add15
d9dd3a5
 
e5add15
 
 
d9dd3a5
e5add15
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5add15
d9dd3a5
 
e5add15
d9dd3a5
 
e5add15
d9dd3a5
 
 
 
e5add15
d9dd3a5
 
 
e5add15
d9dd3a5
 
e5add15
d9dd3a5
 
 
 
e5add15
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5add15
d9dd3a5
 
e5add15
 
 
 
d9dd3a5
 
e5add15
d9dd3a5
e5add15
 
 
 
 
 
 
 
 
d9dd3a5
e5add15
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
"""Integration tests for ADR-007 distillation kwargs in compose_loss.

These tests exercise the wiring between `compose_loss` and the three
pluggable losses (SimPO, TAID, Entropy-Aware OPD). They use a tiny
hand-rolled language model wrapper (no HF, no TRL) so the tests run
in <1s on CPU and are isolated from external library churn.

Coverage requirements:
    (a) defaults reproduce existing compose_loss output bit-exact
    (b) dpo_variant='simpo' produces a different total than dpo
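    (c) sdpo_wrapper='taid' with t=1 reproduces the upstream forward-KL
        objective (teacher-weighted cross-entropy on the masked tokens)
    (d) sdpo_wrapper='taid' with t=0 differs from t=1 (interpolation works)
    (e) sdpo_wrapper='entropy_opd' returns a finite differentiable scalar
    (f) error cases raise ValueError: sdpo_wrapper='taid' without taid_t,
        taid_t outside [0, 1], and unknown dpo_variant / sdpo_wrapper values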
"""
from __future__ import annotations

import pytest
import torch
import torch.nn as nn

from composer_replication import LossComponents, compose_loss


# ----------------------------------------------------------------------
# Tiny LM stand-in
# ----------------------------------------------------------------------

class TinyLM(nn.Module):
    """Minimal `nn.Module` with the HF-style `model(input_ids=...).logits` API.

    Architecture: token embedding -> Linear -> tanh -> Linear projection back
    to the vocabulary. With the defaults (vocab=32, hidden=16) it is tiny
    enough that every test runs in milliseconds on CPU.

    Weights are a pure function of ``seed``. Initialisation runs under
    ``torch.manual_seed(seed)`` inside a forked CPU RNG, so two instances
    built with the same seed are identical. Constructing a model also leaves
    the caller's CPU RNG stream untouched.
    """

    class _Output:
        """Bare result container exposing ``.logits`` like an HF ModelOutput."""

        __slots__ = ("logits",)

        def __init__(self, logits: torch.Tensor) -> None:
            self.logits = logits

    def __init__(self, vocab: int = 32, hidden: int = 16, seed: int = 0):
        super().__init__()
        # fork_rng(devices=[]) snapshots only the CPU generator and restores it
        # on exit. Seeding for deterministic init therefore does not leak into
        # random draws made after the model is built.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            self.embed = nn.Embedding(vocab, hidden)
            self.fc = nn.Linear(hidden, hidden)
            self.head = nn.Linear(hidden, vocab)

    def forward(self, input_ids: torch.Tensor) -> "TinyLM._Output":
        """Return an object whose ``.logits`` has shape ``input_ids.shape + (vocab,)``."""
        h = torch.tanh(self.fc(self.embed(input_ids)))
        return self._Output(self.head(h))


# ----------------------------------------------------------------------
# Batch fixtures
# ----------------------------------------------------------------------

# Toy problem dimensions shared by every fixture below:
# vocabulary size, batch size, and sequence length.
VOCAB, B, T = 32, 2, 8


def _base_batch(seed: int = 7, *, with_dpo: bool = True) -> dict[str, torch.Tensor]:
    """Return a reproducible (B, T) toy batch for compose_loss.

    Always present are ``input_ids`` and ``ctx_teacher_input_ids`` (random
    token ids). ``response_mask`` and ``sdpo_loss_mask`` are also always
    present; both mark the second half of every row. When ``with_dpo`` is
    true, the replay-DPO keys are added as well: chosen/rejected token ids
    with all-ones response masks, plus per-example reference logprobs (used by
    standard DPO; SimPO ignores them).

    Every random draw uses a private ``torch.Generator`` seeded with ``seed``,
    so the global torch RNG is not consumed.
    """
    rng = torch.Generator().manual_seed(seed)

    def draw_ids() -> torch.Tensor:
        return torch.randint(0, VOCAB, (B, T), generator=rng)

    def tail_half_mask() -> torch.Tensor:
        # Positions T//2 onward count as response tokens, which keeps the
        # LM-CE channel non-trivial.
        mask = torch.zeros(B, T, dtype=torch.long)
        mask[:, T // 2:] = 1
        return mask

    # Draw order (and therefore every sampled value) is part of the fixture's
    # determinism; keep the sequence of draw_ids()/randn calls stable.
    batch: dict[str, torch.Tensor] = {}
    batch["input_ids"] = draw_ids()
    batch["response_mask"] = tail_half_mask()
    batch["ctx_teacher_input_ids"] = draw_ids()
    batch["sdpo_loss_mask"] = tail_half_mask()

    if not with_dpo:
        return batch

    batch["dpo_chosen_input_ids"] = draw_ids()
    batch["dpo_chosen_response_mask"] = torch.ones(B, T, dtype=torch.long)
    batch["dpo_rejected_input_ids"] = draw_ids()
    batch["dpo_rejected_response_mask"] = torch.ones(B, T, dtype=torch.long)
    batch["dpo_chosen_ref_logprobs"] = torch.randn(B, generator=rng)
    batch["dpo_rejected_ref_logprobs"] = torch.randn(B, generator=rng)
    return batch


def _model_seeded(seed: int = 0) -> TinyLM:
    """Build a fixed-weight TinyLM (hidden=16) already switched to eval mode."""
    # Module.eval() returns the module itself, so it can be chained.
    return TinyLM(vocab=VOCAB, hidden=16, seed=seed).eval()


# ----------------------------------------------------------------------
# (a) Defaults reproduce existing output bit-exact
# ----------------------------------------------------------------------

def test_defaults_bit_exact_with_legacy_kwargs():
    """Spelling out the new kwargs at their defaults changes nothing.

    The call that passes dpo_variant="dpo" and sdpo_wrapper="none" explicitly
    must match the call with only the legacy kwargs exactly, channel by
    channel, using torch.equal (0 ULPs).
    """
    batch = _base_batch()
    legacy_kwargs = dict(
        alpha_sdpo=0.1,
        beta_replay=0.05,
        sdpo_jsd_beta=0.5,
        sdpo_temperature=1.0,
        replay_dpo_beta=0.1,
    )

    # Each call gets its own freshly seeded model, built right before use.
    out_legacy = compose_loss(_model_seeded(seed=0), batch, **legacy_kwargs)
    out_new = compose_loss(
        _model_seeded(seed=0),
        batch,
        **legacy_kwargs,
        dpo_variant="dpo",
        sdpo_wrapper="none",
    )

    assert isinstance(out_new, LossComponents)
    for channel in ("lm_ce", "sdpo_jsd", "trace_replay_dpo", "total"):
        assert torch.equal(getattr(out_legacy, channel), getattr(out_new, channel)), channel


# ----------------------------------------------------------------------
# (b) dpo_variant='simpo' produces a different total than dpo
# ----------------------------------------------------------------------

def test_simpo_variant_changes_total():
    """SimPO uses average-logprob and drops the reference subtraction, so
    it must produce a different (and finite) trace_replay_dpo + total.

    It must also back-propagate: at least one parameter receives a gradient,
    and every gradient produced is finite.
    """
    inputs = _base_batch()

    model_a = _model_seeded(seed=0)
    out_dpo = compose_loss(
        model_a, inputs,
        alpha_sdpo=0.0,  # isolate channel 3
        beta_replay=0.05,
        dpo_variant="dpo",
    )

    model_b = _model_seeded(seed=0)
    out_simpo = compose_loss(
        model_b, inputs,
        alpha_sdpo=0.0,
        beta_replay=0.05,
        dpo_variant="simpo",
    )

    assert torch.isfinite(out_simpo.total)
    assert torch.isfinite(out_simpo.trace_replay_dpo)
    # Different formulae => different values.
    assert not torch.allclose(
        out_dpo.trace_replay_dpo, out_simpo.trace_replay_dpo
    )
    assert not torch.allclose(out_dpo.total, out_simpo.total)

    # Gradient flow check. `any(...)` would pass as long as a single parameter
    # had a finite grad, even if others were NaN/Inf, so require all of them.
    out_simpo.total.backward()
    grads = [p.grad for p in model_b.parameters() if p.grad is not None]
    assert grads, "SimPO total produced no parameter gradients"
    assert all(torch.isfinite(g).all() for g in grads)


def test_simpo_does_not_require_ref_logprobs():
    """With dpo_variant='simpo', compose_loss needs no reference logprobs.

    Both ``*_ref_logprobs`` keys are removed from the batch. The call must
    still succeed and return finite values.
    """
    batch = _base_batch()
    for ref_key in ("dpo_chosen_ref_logprobs", "dpo_rejected_ref_logprobs"):
        del batch[ref_key]

    result = compose_loss(
        _model_seeded(seed=0),
        batch,
        alpha_sdpo=0.0,
        beta_replay=0.05,
        dpo_variant="simpo",
    )
    assert torch.isfinite(result.total)
    assert torch.isfinite(result.trace_replay_dpo)


# ----------------------------------------------------------------------
# (c) TAID with t=1 reproduces upstream forward-KL on the masked tokens
# ----------------------------------------------------------------------

def test_taid_t_one_matches_upstream_forward_kl():
    """At t=1, TAID's target collapses to softmax(teacher_logits).

    compose_loss must then report exactly the masked token-mean, over
    ``sdpo_loss_mask`` positions, of the teacher-weighted cross-entropy
    ``-sum_v p_teacher(v) * log q_student(v)``. That is the forward-KL
    objective KL(p_teacher || q_student) up to the teacher's entropy. The
    entropy term is computed from no-grad teacher logits, so it is constant
    with respect to the student. This test pins the cross-entropy form.
    """
    import torch.nn.functional as F

    inputs = _base_batch(with_dpo=False)

    model = _model_seeded(seed=1)

    # Run compose_loss with TAID at t=1.
    out_taid = compose_loss(
        model, inputs,
        alpha_sdpo=1.0,  # so out.sdpo_jsd is added straight to total
        beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=1.0,
    )

    # Manually compute the teacher-weighted cross-entropy on the masked tokens.
    student_logits = model(input_ids=inputs["input_ids"]).logits
    with torch.no_grad():
        teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
    mask = inputs["sdpo_loss_mask"].float()
    p_teacher = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    log_q = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    per_token = -(p_teacher * log_q).sum(dim=-1)
    flat = per_token.reshape(-1)
    fmask = mask.reshape(-1).to(flat.dtype)
    expected = (flat * fmask).sum() / fmask.sum().clamp_min(1.0)

    # Bit-exact assertion. At t=1 the TAID-loss path is meant to be
    # mathematically identical to the manual `-(p_teacher * log_q).sum(...)`
    # cross-entropy computed above. TAID's logit-space mix collapses to
    # `teacher_logits`, `softmax(teacher_logits)` is computed bit-identically
    # inside `taid_loss`, and the masked-mean reduction matches. Asserting
    # `equal` rather than `allclose` catches any future refactor that
    # re-introduces a softmax→log roundtrip with ULP drift.
    #
    # If a future change forces a roundtrip we cannot eliminate, drop to
    # `torch.testing.assert_close(out_taid.sdpo_jsd, expected,
    # atol=1e-7, rtol=0)`. That is the strict-but-feasible bound for
    # softmax→log→softmax in float32 (about one ULP at the scale of the loss,
    # dominated by the log_softmax LSE accumulation).
    assert torch.equal(out_taid.sdpo_jsd, expected), (
        f"TAID t=1 must equal upstream forward-KL bit-exact; "
        f"got out={out_taid.sdpo_jsd.item()!r}, "
        f"expected={expected.item()!r}, "
        f"diff={(out_taid.sdpo_jsd - expected).abs().item():.3e}"
    )


# ----------------------------------------------------------------------
# (d) TAID interpolates: t=0 differs from t=1
# ----------------------------------------------------------------------

def test_taid_interpolates_with_t():
    """Different t values give different sdpo_jsd. Differentiable end-to-end.

    t=0, t=0.5 and t=1 each run on a freshly seeded copy of the same model.
    All three must be finite, and t=0 and t=0.5 must both differ from t=1.
    The t=0.5 total must back-propagate: at least one grad, all finite.
    """
    inputs = _base_batch(with_dpo=False)

    model_zero = _model_seeded(seed=2)
    out_zero = compose_loss(
        model_zero, inputs,
        alpha_sdpo=0.1, beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=0.0,
    )

    model_mid = _model_seeded(seed=2)
    out_mid = compose_loss(
        model_mid, inputs,
        alpha_sdpo=0.1, beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=0.5,
    )

    model_one = _model_seeded(seed=2)
    out_one = compose_loss(
        model_one, inputs,
        alpha_sdpo=0.1, beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=1.0,
    )

    for out in (out_zero, out_mid, out_one):
        assert torch.isfinite(out.total)
        assert torch.isfinite(out.sdpo_jsd)

    assert not torch.allclose(out_zero.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5)
    assert not torch.allclose(out_mid.sdpo_jsd, out_one.sdpo_jsd, atol=1e-5)

    # Require every produced gradient to be finite, not just one of them.
    out_mid.total.backward()
    grads = [p.grad for p in model_mid.parameters() if p.grad is not None]
    assert grads, "TAID total produced no parameter gradients"
    assert all(torch.isfinite(g).all() for g in grads)


# ----------------------------------------------------------------------
# (e) Entropy-Aware OPD returns a finite differentiable scalar
# ----------------------------------------------------------------------

def test_entropy_opd_returns_finite_differentiable_scalar():
    """The entropy-aware OPD wrapper yields a finite 0-d loss with finite grads."""
    batch = _base_batch(with_dpo=False)
    lm = _model_seeded(seed=3)
    result = compose_loss(
        lm,
        batch,
        alpha_sdpo=0.1,
        beta_replay=0.0,
        sdpo_wrapper="entropy_opd",
    )

    assert isinstance(result, LossComponents)
    total = result.total
    assert total.shape == ()
    assert torch.isfinite(total)
    assert torch.isfinite(result.sdpo_jsd)
    assert total.requires_grad

    total.backward()
    finite_per_param = [
        torch.isfinite(param.grad).all()
        for param in lm.parameters()
        if param.grad is not None
    ]
    assert len(finite_per_param) > 0
    assert all(finite_per_param)


# ----------------------------------------------------------------------
# (f) Error: sdpo_wrapper='taid' without taid_t
# ----------------------------------------------------------------------

def test_taid_requires_t():
    """Selecting the TAID wrapper without supplying taid_t is a ValueError."""
    batch = _base_batch(with_dpo=False)
    lm = _model_seeded(seed=4)
    # taid_t is deliberately absent from these kwargs.
    kwargs = dict(alpha_sdpo=0.1, beta_replay=0.0, sdpo_wrapper="taid")
    with pytest.raises(ValueError, match="taid_t"):
        compose_loss(lm, batch, **kwargs)


def test_taid_t_out_of_range_raises():
    """A taid_t above 1 is rejected with a message naming the [0, 1] range."""
    batch = _base_batch(with_dpo=False)
    lm = _model_seeded(seed=4)
    kwargs = dict(
        alpha_sdpo=0.1,
        beta_replay=0.0,
        sdpo_wrapper="taid",
        taid_t=1.5,
    )
    with pytest.raises(ValueError, match=r"taid_t must be in \[0, 1\]"):
        compose_loss(lm, batch, **kwargs)


def test_invalid_dpo_variant_raises():
    """An unknown dpo_variant raises a ValueError mentioning 'dpo_variant'."""
    batch, lm = _base_batch(), _model_seeded(seed=5)
    with pytest.raises(ValueError, match="dpo_variant"):
        compose_loss(lm, batch, dpo_variant="bogus")  # type: ignore[arg-type]


def test_invalid_sdpo_wrapper_raises():
    """An unknown sdpo_wrapper raises a ValueError mentioning 'sdpo_wrapper'."""
    batch, lm = _base_batch(), _model_seeded(seed=5)
    with pytest.raises(ValueError, match="sdpo_wrapper"):
        compose_loss(lm, batch, sdpo_wrapper="bogus")  # type: ignore[arg-type]


# ----------------------------------------------------------------------
# Bonus: TAIDScheduler integration
# ----------------------------------------------------------------------

def test_taid_compose_with_scheduler():
    """TAIDScheduler's current t drives compose_loss across a few steps.

    Each step reads ``scheduler.t``, computes the TAID loss, and feeds the
    detached sdpo_jsd back via ``update_t``.
    """
    from composer_replication.distillation import TAIDScheduler

    batch = _base_batch(with_dpo=False)
    lm = _model_seeded(seed=6)
    scheduler = TAIDScheduler(num_train_steps=100, t_start=0.4)

    for global_step in range(3):
        components = compose_loss(
            lm,
            batch,
            alpha_sdpo=0.1,
            beta_replay=0.0,
            sdpo_wrapper="taid",
            taid_t=scheduler.t,
        )
        assert torch.isfinite(components.total)
        scheduler.update_t(components.sdpo_jsd.detach(), global_step=global_step)

    # Three updates out of 100 scheduled steps may or may not move t away from
    # t_start; either way it must still lie in [t_start, 1].
    assert 0.4 <= scheduler.t <= 1.0