"""composer_replication.distillation — pluggable self-distillation losses.

Per ADR-007, this package adds three losses alongside the framework's
existing SDPO/OPSD loss (`generalized_jsd_loss`):

- SimPO: reference-free DPO replacement (channel 3 alternative)
- TAID: annealed teacher interpolation (wraps generalized_jsd_loss for channel 2)
- Entropy-Aware OPD: token-wise gated forward/reverse KL (alternative
  channel-2 wrapper, per ICLR 2026 Spotlight)

All three are pure PyTorch, with no dependencies beyond PyTorch itself, so
they ship in the core package without optional extras.

Usage in `compose_loss`:

    >>> from composer_replication import compose_loss
    >>> components = compose_loss(
    ...     model, batch,
    ...     dpo_variant="simpo",          # channel 3: DPO -> SimPO
    ...     sdpo_wrapper="taid",          # channel 2: SDPO -> TAID
    ...     taid_t=0.4,                   # current TAID interpolation coeff
    ... )
"""
from __future__ import annotations

from composer_replication.distillation.simpo import simpo_loss
from composer_replication.distillation.taid import TAIDScheduler, taid_loss
from composer_replication.distillation.entropy_aware_opd import entropy_aware_opd_loss

__all__ = [
    "simpo_loss",
    "taid_loss",
    "TAIDScheduler",
    "entropy_aware_opd_loss",
]
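
# --- Illustrative sketch (not part of the public API) -----------------------
# The "reference-free" property of SimPO mentioned in the module docstring can
# be sketched roughly as follows. This is a minimal sketch under stated
# assumptions, not the implementation in
# `composer_replication.distillation.simpo`: the name `_simpo_loss_sketch`,
# its signature, and the default `beta` / `gamma` values are hypothetical and
# chosen only for illustration.
import torch
import torch.nn.functional as F


def _simpo_loss_sketch(
    chosen_logps: torch.Tensor,
    rejected_logps: torch.Tensor,
    chosen_mask: torch.Tensor,
    rejected_mask: torch.Tensor,
    beta: float = 2.0,
    gamma: float = 1.0,
) -> torch.Tensor:
    """Hypothetical SimPO loss over per-token log-probs of shape (batch, seq)."""
    # Length-normalised sequence log-probability acts as the implicit reward.
    chosen_avg = (chosen_logps * chosen_mask).sum(-1) / chosen_mask.sum(-1).clamp(min=1)
    rejected_avg = (rejected_logps * rejected_mask).sum(-1) / rejected_mask.sum(-1).clamp(min=1)
    # No reference-model term appears; gamma is the target reward margin.
    margin = beta * (chosen_avg - rejected_avg) - gamma
    return -F.logsigmoid(margin).mean()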