# Last commit e5add15 (Codeseys): Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
"""composer_replication.distillation — pluggable self-distillation losses.
Per ADR-007, this package provides three losses in addition to the
framework's existing SDPO/OPSD loss (`generalized_jsd_loss`):
- SimPO: reference-free DPO replacement (channel 3 alternative)
- TAID: annealed teacher interpolation (wraps generalized_jsd_loss for channel 2)
- Entropy-Aware OPD: token-wise gated forward/reverse KL (alternative
  channel-2 wrapper, following an ICLR 2026 Spotlight paper)
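
For orientation, here is a minimal sketch of two of the objectives above. It assumes `[B]`-shaped summed sequence log-probs for SimPO and `[B, T, V]` logits for the KL gate. `simpo_sketch` follows the published SimPO form: a beta-scaled average log-likelihood reward with a target margin `gamma`. `entropy_gated_kl_sketch` uses a hypothetical hard gate on teacher entropy (forward KL above `threshold`, reverse KL below) as one possible reading of "token-wise gated forward/reverse KL". Neither function is this package's `simpo_loss` or `entropy_aware_opd_loss`, whose signatures and gate rule may differ. All names and defaults here are illustrative.

```python
import torch
import torch.nn.functional as F


def simpo_sketch(chosen_logps, chosen_lens, rejected_logps, rejected_lens,
                 beta=2.0, gamma=1.0):
    # chosen_logps / rejected_logps: summed token log-probs per response, shape [B]
    # chosen_lens / rejected_lens: response lengths in tokens, shape [B]
    # Implicit reward = beta * average token log-prob; no reference model.
    chosen_reward = beta * chosen_logps / chosen_lens
    rejected_reward = beta * rejected_logps / rejected_lens
    # Bradley-Terry loss with a target reward margin gamma.
    return -F.logsigmoid(chosen_reward - rejected_reward - gamma).mean()


def entropy_gated_kl_sketch(student_logits, teacher_logits, threshold=1.0):
    # student_logits / teacher_logits: [B, T, V]; returns a scalar mean over tokens.
    s_logp = F.log_softmax(student_logits, dim=-1)
    t_logp = F.log_softmax(teacher_logits.detach(), dim=-1)  # teacher is fixed
    s_p, t_p = s_logp.exp(), t_logp.exp()
    forward_kl = (t_p * (t_logp - s_logp)).sum(-1)  # KL(teacher || student)
    reverse_kl = (s_p * (s_logp - t_logp)).sum(-1)  # KL(student || teacher)
    # Hypothetical hard gate: forward KL where the teacher's entropy is high,
    # reverse KL everywhere else.
    teacher_entropy = -(t_p * t_logp).sum(-1)
    gate = (teacher_entropy > threshold).to(s_logp.dtype)
    return (gate * forward_kl + (1.0 - gate) * reverse_kl).mean()


logps = torch.tensor([-10.0, -20.0])
lens = torch.tensor([10.0, 20.0])
# Equal rewards leave only the margin: log(1 + e**1.0), about 1.3133.
demo_loss = simpo_sketch(logps, lens, logps, lens)
```

The gate depends only on the detached teacher, so it adds no gradient path of its own. A soft sigmoid gate on entropy would be a drop-in alternative if hard switching proves noisy.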
All three are pure PyTorch with no external dependencies, so they ship in
the core package and require no optional extras.
Selecting these losses through `compose_loss`:
>>> from composer_replication import compose_loss
>>> components = compose_loss(
...     model, batch,
...     dpo_variant="simpo",  # channel 3: DPO -> SimPO
...     sdpo_wrapper="taid",  # channel 2: SDPO -> TAID
...     taid_t=0.4,           # current TAID interpolation coefficient t
... )
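
What `taid_t` controls can be sketched as follows. This assumes the logit-space interpolation described in the TAID paper: the target is softmax((1 - t) * student_logits + t * teacher_logits) with the student side detached, and the student is trained toward it with KL(p_t || q_theta). `taid_loss_sketch` and `linear_t_schedule` are hypothetical names. The package's `taid_loss` wraps `generalized_jsd_loss` rather than using a plain forward KL, and `TAIDScheduler` may update t adaptively rather than linearly.

```python
import torch
import torch.nn.functional as F


def taid_loss_sketch(student_logits, teacher_logits, t):
    # Intermediate target p_t: interpolate in logit space between the detached
    # student and the teacher, then normalize. t = 0 -> student, t = 1 -> teacher.
    mixed = (1.0 - t) * student_logits.detach() + t * teacher_logits.detach()
    target_logp = F.log_softmax(mixed, dim=-1)
    student_logp = F.log_softmax(student_logits, dim=-1)
    # KL(p_t || q_theta) per token, averaged over batch and sequence positions.
    return (target_logp.exp() * (target_logp - student_logp)).sum(-1).mean()


def linear_t_schedule(step, total_steps, t_start=0.0, t_end=1.0):
    # Hypothetical linear ramp from t_start to t_end; progress clamped to [0, 1].
    frac = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return t_start + (t_end - t_start) * frac


torch.manual_seed(0)
student = torch.randn(2, 3, 8, requires_grad=True)  # [B, T, V] logits
teacher = torch.randn(2, 3, 8)
loss = taid_loss_sketch(student, teacher, t=linear_t_schedule(30, 100))
loss.backward()  # gradients reach the student only through student_logp
```

At t = 0 the target equals the detached student and the loss is zero. At t = 1 it reduces to forward KL from the teacher. Ramping t therefore moves the student toward the teacher gradually instead of asking it to match the teacher from the first step.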
"""
from __future__ import annotations
from composer_replication.distillation.simpo import simpo_loss
from composer_replication.distillation.taid import TAIDScheduler, taid_loss
from composer_replication.distillation.entropy_aware_opd import entropy_aware_opd_loss
__all__ = [
    "simpo_loss",
    "taid_loss",
    "TAIDScheduler",
    "entropy_aware_opd_loss",
]