"""TAID loss — Temporally Adaptive Interpolated Distillation. Paper: "TAID: Temporally Adaptive Interpolated Distillation for Efficient Knowledge Transfer in Language Models" Sakana AI, arXiv:2501.16937 License: Apache-2.0 (https://github.com/SakanaAI/TAID) This module is a faithful port of the reference implementation at ``SakanaAI/TAID/src/distil_losses/taid.py``. **The previous in-tree implementation was algorithmically different from the paper** (it mixed in probability space against a frozen step-0 student snapshot and wrapped a symmetric JSD criterion). This rewrite replaces it with the upstream algorithm: p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits ) loss = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v) That is: 1. Mix in **logit space**, not probability space. 2. Anchor against the **current student detached** (re-evaluated each step), not a frozen step-0 snapshot. 3. Distillation criterion is **forward KL** (Hinton-style soft target), not symmetric JSD. Schedule -------- The original implementation embedded an adaptive momentum-based schedule inside the loss object; this is now factored out into the optional :class:`TAIDScheduler` so the loss function itself is pure (single ``t`` in [0, 1]). Callers either: - Pass a fixed ``t`` for ablations / fixed schedules. - Drive ``t`` via :class:`TAIDScheduler` (paper-default adaptive scheme). - Drive ``t`` via any custom schedule of their choosing. Backward-incompatible change ---------------------------- The previous public signature was: taid_loss(student_logits, teacher_logits, student_init_logits, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction) The new signature is: taid_loss(student_logits, teacher_logits, mask=None, *, t) Removed kwargs (``student_init_logits``, ``schedule_step``, ``total_steps``, ``schedule``, ``alpha_min``, ``alpha_max``, ``jsd_beta``, ``temperature``, ``reduction``) have no upstream analogue. 
Pass ``t`` directly; if you need a schedule, use :class:`TAIDScheduler` or
compute ``t`` yourself.

Reference: arXiv:2501.16937; ``SakanaAI/TAID`` commit history.
"""

from __future__ import annotations

import torch
import torch.nn.functional as F


def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor:
    """TAID forward-KL loss against a logit-space-interpolated target.

    Faithful port of ``SakanaAI/TAID/src/distil_losses/taid.py:compute_loss``
    composed with ``fkl.forward_kl``. Pseudocode::

        p_t       = softmax( (1 - t) · student_logits.detach() + t · teacher_logits )
        log_q     = log_softmax( student_logits )
        per_token = - Σ_v p_t(v) · log_q(v)    # forward KL token-wise
        loss      = sum(per_token · mask) / sum(mask)

    ``per_token`` is the cross-entropy ``H(p_t, q)``. It differs from
    ``KL(p_t || q)`` only by the entropy of ``p_t``, which carries no gradient
    when the teacher logits carry none, so both give the same student gradient.

    Args:
        student_logits: ``(B, T, V)`` current student logits, with grad.
        teacher_logits: ``(B, T, V)`` teacher logits. No gradient is expected.
            They are not detached here (matching upstream); only the student
            term of the interpolation is detached, so compute the teacher
            logits under ``torch.no_grad()`` if the teacher must stay frozen.
        mask: ``(B, T)`` token mask (1 = include, 0 = ignore). Required by
            upstream; defaults to all-ones if omitted for convenience.
        t: interpolation coefficient in ``[0, 1]``. Scalar Python float or
            0-d torch.Tensor. ``t=0`` makes the target match the (detached)
            student, which yields zero gradient signal. ``t=1`` makes the
            target the teacher, i.e. pure forward-KL distillation.

    Returns:
        Scalar loss (token-mean, in float32 dtype matching upstream).

    Raises:
        ValueError: shape mismatch between student/teacher, or invalid mask
            shape.

    Reference: arXiv:2501.16937 §3.1 + Eq. (4); upstream commit at
    ``SakanaAI/TAID@main:src/distil_losses/taid.py``.
    """
    if student_logits.shape != teacher_logits.shape:
        raise ValueError(
            f"student/teacher logits shape mismatch: "
            f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}"
        )
    if mask is None:
        mask = student_logits.new_ones(student_logits.shape[:-1])
    elif mask.shape != student_logits.shape[:-1]:
        raise ValueError(
            f"mask shape {tuple(mask.shape)} does not match logits prefix "
            f"{tuple(student_logits.shape[:-1])}"
        )

    # 1. Logit-space mix with student detached (anchor = current student, no grad).
    blended_logits = (1 - t) * student_logits.detach() + t * teacher_logits

    # 2. Target distribution in float32 for numerical stability (upstream choice).
    p_t = F.softmax(blended_logits, dim=-1, dtype=torch.float32)

    # 3. Forward KL: the gradient flows only through the student log-softmax,
    #    provided teacher_logits itself carries no gradient.
    student_logprobs = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)

    # 4. Zero the product wherever the student logit is infinite (upstream
    #    guard); otherwise 0 * -inf at those positions would produce NaN.
    inf_mask = torch.isinf(student_logits)
    prod = torch.masked_fill(p_t * student_logprobs, inf_mask, 0.0)

    # 5. Per-token cross-entropy = -sum_v p_t(v) * log_q(v); reduce over vocab,
    #    then take the mean over unmasked tokens.
    per_token = -prod.sum(dim=-1).reshape(-1)
    flat_mask = mask.reshape(-1).to(per_token.dtype)
    denom = flat_mask.sum().clamp_min(1.0)  # avoid 0/0 on a fully masked batch
    loss = (per_token * flat_mask).sum() / denom
    return loss


class TAIDScheduler:
    """Adaptive momentum-based schedule for TAID's interpolation coefficient ``t``.

    Stateful, mirrors ``SakanaAI/TAID/src/distil_losses/taid.py:TAID.update_t``.

    Usage::

        num_train_steps = 10_000
        sched = TAIDScheduler(num_train_steps=num_train_steps)
        for step in range(num_train_steps):
            t = sched.t                        # current t (float)
            loss = taid_loss(s_logits, t_logits, mask, t=t)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            sched.update_t(loss.detach(), global_step=step)

    The schedule is monotone non-decreasing (given ``t_start <= t_end``). At
    each step an adaptive bump ``alpha · σ(momentum) · (1 - t)`` is added to
    the current ``t``, where ``momentum`` tracks the relative loss change with
    EMA decay ``beta``. The result is floored by the linear schedule
    ``t_target = t_start + progress · (t_end - t_start)`` and capped at
    ``t_end``. ``disable_adaptive=True`` collapses to the deterministic linear
    schedule.

    Args:
        num_train_steps: total planned training steps; required so the linear
            floor ``t_target`` is well-defined.
        t_start: initial ``t`` (paper default 0.4; the student is already
            close to the teacher in this regime, so ``t=0`` would waste the
            warmup phase).
        t_end: terminal ``t`` (paper default 1.0).
        alpha: adaptive bump magnitude (paper default 5e-4).
        beta: EMA decay for the relative-loss-change momentum (paper default
            0.99).
        disable_adaptive: if True, fall back to deterministic linear schedule
            ``t_target = t_start + progress · (t_end - t_start)``.
        device: device to allocate state buffers on; default cpu.
    """

    def __init__(
        self,
        num_train_steps: int,
        *,
        t_start: float = 0.4,
        t_end: float = 1.0,
        alpha: float = 5e-4,
        beta: float = 0.99,
        disable_adaptive: bool = False,
        device: torch.device | str = "cpu",
    ) -> None:
        if not (0.0 <= t_start < 1.0):
            raise ValueError(f"t_start must be in [0, 1), got {t_start}")
        if not (0.0 < t_end <= 1.0):
            raise ValueError(f"t_end must be in (0, 1], got {t_end}")
        if not (0.0 <= alpha <= 1.0):
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        if num_train_steps <= 0:
            raise ValueError(f"num_train_steps must be > 0, got {num_train_steps}")

        self.t_start = t_start
        self.t_end = t_end
        self.alpha = alpha
        self.beta = beta
        self.disable_adaptive = disable_adaptive
        self.num_train_steps = num_train_steps

        # State buffers are 0-d float32 tensors on ``device``.
        self._t = torch.tensor(t_start, device=device, dtype=torch.float32)
        self._prev_loss = torch.tensor(
            float("inf"), device=device, dtype=torch.float32
        )
        self._momentum = torch.zeros([], device=device, dtype=torch.float32)

    @property
    def t(self) -> float:
        """Current interpolation coefficient as a Python float."""
        return float(self._t)

    def update_t(
        self,
        loss: torch.Tensor,
        global_step: int,
    ) -> torch.Tensor | None:
        """Update internal ``t`` given the current step's distillation loss.

        Mirrors the upstream update rule. The first call only seeds
        ``prev_loss`` and returns None. Subsequent calls update the momentum
        and ``t``, and return the non-negative adaptive increment ``delta_t``
        computed for this step. The change actually applied to ``t`` can
        differ: ``t`` is floored by the linear schedule and capped at
        ``t_end``, and ``delta_t`` is ignored when ``disable_adaptive=True``.

        Args:
            loss: scalar loss tensor (caller should pass ``loss.detach()``).
            global_step: current global step (0-indexed).

        Returns:
            The adaptive ``delta_t`` computed for this step, or None if this
            was the seeding call.
        """
        # Drop any autograd graph and match the state buffers' device/dtype.
        loss = loss.detach().to(self._prev_loss)

        if torch.isinf(self._prev_loss):
            # Seeding call: there is no previous loss to compare against yet.
            self._prev_loss = loss
            return None

        # Positive when the loss decreased relative to the previous step.
        relative_change = (self._prev_loss - loss) / (self._prev_loss + 1e-15)
        self._momentum = (
            self.beta * self._momentum + (1 - self.beta) * relative_change
        )
        adaptive_delta = torch.sigmoid(self._momentum)

        # Linear floor and adaptive increment on top of the current t.
        progress = global_step / self.num_train_steps
        t_target = self.t_start + (self.t_end - self.t_start) * progress
        delta_t = self.alpha * adaptive_delta * (1 - self._t)

        if self.disable_adaptive:
            new_t = t_target
        else:
            new_t = min(self.t_end, max(t_target, float(self._t + delta_t)))
        if not isinstance(new_t, torch.Tensor):
            new_t = torch.tensor(new_t, device=self._t.device, dtype=self._t.dtype)
        self._t = new_t
        self._prev_loss = loss
        return delta_t


__all__ = ["taid_loss", "TAIDScheduler"]