Codeseys's picture
Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
e5add15
"""TAID loss — Temporally Adaptive Interpolated Distillation.
Paper: "TAID: Temporally Adaptive Interpolated Distillation for Efficient
Knowledge Transfer in Language Models"
Sakana AI, arXiv:2501.16937
License: Apache-2.0 (https://github.com/SakanaAI/TAID)
This module is a faithful port of the reference implementation at
``SakanaAI/TAID/src/distil_losses/taid.py``. **The previous in-tree
implementation was algorithmically different from the paper** (it mixed in
probability space against a frozen step-0 student snapshot and wrapped a
symmetric JSD criterion). This rewrite replaces it with the upstream
algorithm:
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
loss = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
That is:
1. Mix in **logit space**, not probability space.
2. Anchor against the **current student detached** (re-evaluated each
step), not a frozen step-0 snapshot.
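3. Distillation criterion is **forward KL** (Hinton-style soft target),
not symmetric JSD. It is computed as the cross-entropy shown above, which
differs from the KL divergence only by the entropy of ``p_t`` — a term
that carries no gradient for the student because the student enters
``p_t`` through ``stop_grad``.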
Schedule
--------
The upstream reference implementation embeds its adaptive momentum-based
schedule inside the loss object; here it is factored out into the optional
:class:`TAIDScheduler` so the loss function itself is pure (it takes a
single ``t`` in [0, 1]). Callers either:
- Pass a fixed ``t`` for ablations / fixed schedules.
- Drive ``t`` via :class:`TAIDScheduler` (paper-default adaptive scheme).
- Drive ``t`` via any custom schedule of their choosing.
Backward-incompatible change
----------------------------
The previous public signature was:
taid_loss(student_logits, teacher_logits, student_init_logits, *,
schedule_step, total_steps, schedule, alpha_min, alpha_max,
jsd_beta, temperature, reduction)
The new signature is:
taid_loss(student_logits, teacher_logits, mask=None, *, t)
Removed kwargs (``student_init_logits``, ``schedule_step``, ``total_steps``,
``schedule``, ``alpha_min``, ``alpha_max``, ``jsd_beta``, ``temperature``,
``reduction``) have no upstream analogue. Pass ``t`` directly; if you need
a schedule, use :class:`TAIDScheduler` or compute ``t`` yourself.
Reference: arXiv:2501.16937; ``SakanaAI/TAID`` commit history.
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor:
    """TAID loss: student cross-entropy against a logit-interpolated target.

    Port of ``SakanaAI/TAID/src/distil_losses/taid.py:compute_loss`` composed
    with ``fkl.forward_kl``, with one numerical hardening (see *Notes*).

    Pseudocode::

        p_t       = softmax((1 - t) * student_logits.detach() + t * teacher_logits)
        log_q     = log_softmax(student_logits)
        per_token = -sum_v p_t(v) * log_q(v)
        loss      = sum(per_token * mask) / max(sum(mask), 1)

    Args:
        student_logits: ``(..., V)`` student logits (typically ``(B, T, V)``).
            Gradient reaches them only through ``log_softmax``; the copy used
            in the interpolation is detached.
        teacher_logits: Teacher logits, same shape as ``student_logits``.
            They are *not* detached here, so a teacher that requires grad
            receives gradient through the interpolated target.
        mask: Per-token weights with the shape of ``student_logits`` minus the
            vocabulary axis (1 = include, 0 = ignore). ``None`` means every
            token is included.
        t: Interpolation coefficient, expected in ``[0, 1]`` (not validated).
            A Python number or a tensor broadcastable against the logits.
            ``t=0`` makes the target equal to the detached student (zero
            gradient); ``t=1`` makes the target the teacher distribution.

    Returns:
        Scalar float32 tensor: the mask-weighted mean of the per-token
        cross-entropies. The denominator is clamped to at least 1, so an
        all-zero mask yields 0 instead of NaN.

    Raises:
        ValueError: if the student and teacher shapes differ, or if ``mask``
            does not match the logits' leading dimensions.

    Notes:
        ``0 * inf`` is NaN in IEEE-754 arithmetic. Without special handling,
        ``t == 1`` (or ``t == 0``) combined with ``-inf`` student (or teacher)
        logits would turn an entire softmax row — and hence the loss — into
        NaN. An endpoint whose weight is exactly zero is therefore dropped
        from the blend. Results for finite logits are unaffected.
    """
    if student_logits.shape != teacher_logits.shape:
        raise ValueError(
            f"student/teacher logits shape mismatch: "
            f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}"
        )
    token_shape = student_logits.shape[:-1]
    if mask is None:
        mask = student_logits.new_ones(token_shape)
    elif mask.shape != token_shape:
        raise ValueError(
            f"mask shape {tuple(mask.shape)} does not match logits prefix "
            f"{tuple(student_logits.shape[:-1])}"
        )

    # Anchor is the *current* student with gradient cut, re-evaluated per call.
    anchor = student_logits.detach()
    student_inf = torch.isinf(anchor)
    student_weight = 1 - t

    # Scale each endpoint; where the weight is exactly zero and the logit is
    # infinite the product is NaN, so replace it with the intended 0.
    student_term = (student_weight * anchor).masked_fill(
        student_inf & (student_weight == 0), 0.0
    )
    teacher_term = (t * teacher_logits).masked_fill(
        torch.isinf(teacher_logits) & (t == 0), 0.0
    )

    # Target and student log-probabilities are both computed in float32.
    p_t = F.softmax(student_term + teacher_term, dim=-1, dtype=torch.float32)
    student_logprobs = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)

    # Vocabulary entries the student has at +/-inf contribute nothing
    # (this also removes the 0 * -inf NaN that arises when p_t is 0 there).
    prod = torch.masked_fill(p_t * student_logprobs, student_inf, 0.0)

    # Reduce over the vocabulary, then take the mask-weighted token mean.
    per_token = -prod.sum(dim=-1).reshape(-1)
    flat_mask = mask.reshape(-1).to(per_token.dtype)
    denom = flat_mask.sum().clamp_min(1.0)
    return (per_token * flat_mask).sum() / denom
class TAIDScheduler:
    """Stateful schedule for TAID's interpolation coefficient ``t``.

    Follows ``SakanaAI/TAID/src/distil_losses/taid.py:TAID.update_t`` with a
    few defensive changes: the incoming loss is detached and moved to the
    scheduler's own dtype/device before use, non-finite losses are skipped,
    and the linear schedule's progress is clamped to ``[0, 1]``.

    Usage::

        sched = TAIDScheduler(num_train_steps=10_000)
        for step in range(num_train_steps):
            t = sched.t                       # current t (float)
            loss = taid_loss(s_logits, t_logits, mask, t=t)
            loss.backward(); optimizer.step()
            sched.update_t(loss.detach(), global_step=step)

    At each update the floor is the linear schedule
    ``t_target = t_start + progress * (t_end - t_start)`` and an adaptive bump
    ``alpha * sigmoid(momentum) * (1 - t)`` is added on top of the current
    ``t``; the larger of the two wins, capped at ``t_end``. ``momentum`` is an
    EMA (decay ``beta``) of the relative loss change between consecutive
    calls. With ``t_start <= t_end`` and a non-decreasing ``global_step`` the
    resulting ``t`` never decreases. ``disable_adaptive=True`` collapses to
    the deterministic linear schedule.

    Args:
        num_train_steps: total planned training steps; defines ``progress``
            for the linear floor.
        t_start: initial ``t`` (paper default 0.4).
        t_end: terminal ``t`` (paper default 1.0).
        alpha: adaptive bump magnitude (paper default 5e-4).
        beta: EMA decay for the relative-loss-change momentum (paper default
            0.99).
        disable_adaptive: if True, use only the linear schedule.
        device: device on which the state tensors are allocated; default cpu.

    Raises:
        ValueError: if ``t_start`` is outside ``[0, 1)``, ``t_end`` outside
            ``(0, 1]``, ``alpha`` or ``beta`` outside ``[0, 1]``, or
            ``num_train_steps`` is not positive.
    """

    def __init__(
        self,
        num_train_steps: int,
        *,
        t_start: float = 0.4,
        t_end: float = 1.0,
        alpha: float = 5e-4,
        beta: float = 0.99,
        disable_adaptive: bool = False,
        device: torch.device | str = "cpu",
    ) -> None:
        if not (0.0 <= t_start < 1.0):
            raise ValueError(f"t_start must be in [0, 1), got {t_start}")
        if not (0.0 < t_end <= 1.0):
            raise ValueError(f"t_end must be in (0, 1], got {t_end}")
        if not (0.0 <= alpha <= 1.0):
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        # An EMA decay outside [0, 1] is not an average and diverges.
        if not (0.0 <= beta <= 1.0):
            raise ValueError(f"beta must be in [0, 1], got {beta}")
        if num_train_steps <= 0:
            raise ValueError(f"num_train_steps must be > 0, got {num_train_steps}")

        self.t_start = t_start
        self.t_end = t_end
        self.alpha = alpha
        self.beta = beta
        self.disable_adaptive = disable_adaptive
        self.num_train_steps = num_train_steps

        # State lives in 0-d float32 tensors on ``device``.
        self._t = torch.tensor(t_start, device=device, dtype=torch.float32)
        # +inf marks "no previous loss seen yet" (checked in update_t).
        self._prev_loss = torch.tensor(
            float("inf"), device=device, dtype=torch.float32
        )
        self._momentum = torch.zeros([], device=device, dtype=torch.float32)

    @property
    def t(self) -> float:
        """Current interpolation coefficient as a Python float."""
        return float(self._t)

    def update_t(
        self,
        loss: torch.Tensor,
        global_step: int,
    ) -> torch.Tensor | None:
        """Advance ``t`` using the distillation loss of the current step.

        The first call that receives a finite loss only records it as the
        reference for the next call. Later calls update the momentum and
        ``t``.

        Args:
            loss: scalar loss tensor. It is detached internally, so passing a
                tensor that still carries autograd history is safe.
            global_step: current global step (0-indexed). Progress
                ``global_step / num_train_steps`` is clamped to ``[0, 1]`` so
                ``t`` cannot leave ``[t_start, t_end]`` on the linear
                schedule.

        Returns:
            The adaptive increment ``alpha * sigmoid(momentum) * (1 - t)``
            computed for this step (it is computed, but not applied, when
            ``disable_adaptive`` is set), or ``None`` if the call only seeded
            the previous loss or the loss was non-finite and was skipped.
        """
        # Detach and convert once: keeps autograd history out of the
        # long-lived momentum and pins state to its own dtype/device.
        current = loss.detach().to(self._prev_loss)

        # A NaN/inf loss would make the momentum NaN or -inf for good and
        # silently switch the adaptive term off; ignore such steps instead.
        if not torch.isfinite(current):
            return None

        if torch.isinf(self._prev_loss):
            self._prev_loss = current
            return None

        # Positive when the loss went down relative to the previous call.
        relative_change = (self._prev_loss - current) / (self._prev_loss + 1e-15)
        self._momentum = (
            self.beta * self._momentum + (1 - self.beta) * relative_change
        )
        adaptive_delta = torch.sigmoid(self._momentum)

        # Clamp so steps past num_train_steps (or negative steps) cannot push
        # the linear target outside [t_start, t_end].
        progress = min(max(global_step / self.num_train_steps, 0.0), 1.0)
        t_target = self.t_start + (self.t_end - self.t_start) * progress
        delta_t = self.alpha * adaptive_delta * (1 - self._t)

        if self.disable_adaptive:
            new_t = t_target
        else:
            new_t = min(self.t_end, max(t_target, float(self._t + delta_t)))

        self._t = torch.as_tensor(
            new_t, device=self._t.device, dtype=self._t.dtype
        )
        self._prev_loss = current
        return delta_t
# Names exported by a star-import of this module: the loss function and the
# optional scheduler that drives its ``t`` argument.
__all__ = ["taid_loss", "TAIDScheduler"]