"""TAID loss — Temporally Adaptive Interpolated Distillation.

Paper: "TAID: Temporally Adaptive Interpolated Distillation for Efficient
        Knowledge Transfer in Language Models"
       Sakana AI, arXiv:2501.16937
License: Apache-2.0 (https://github.com/SakanaAI/TAID)

This module is a faithful port of the reference implementation at
``SakanaAI/TAID/src/distil_losses/taid.py``. **The previous in-tree
implementation was algorithmically different from the paper** (it mixed in
probability space against a frozen step-0 student snapshot and wrapped a
symmetric JSD criterion). This rewrite replaces it with the upstream
algorithm:

    p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
    loss = - mean_token  Σ_v  p_t(v) · log_softmax(student_logits)(v)

That is:
    1. Mix in **logit space**, not probability space.
    2. Anchor against the **current student detached** (re-evaluated each
       step), not a frozen step-0 snapshot.
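    3. Distillation criterion is **forward KL** (Hinton-style soft target),
       not symmetric JSD. It is computed as the cross-entropy against
       ``p_t``, which differs from ``KL(p_t || q)`` only by the entropy of
       ``p_t``; that term is constant with respect to the student's
       gradient because the student enters ``p_t`` only through a detach.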

Schedule
--------
The upstream implementation embeds an adaptive momentum-based schedule
inside the loss object; here it is factored out into the optional
:class:`TAIDScheduler` so that the loss function itself is pure (it takes a
single ``t`` in [0, 1]). Callers either:

- Pass a fixed ``t`` for ablations / fixed schedules.
- Drive ``t`` via :class:`TAIDScheduler` (paper-default adaptive scheme).
- Drive ``t`` via any custom schedule of their choosing.

Backward-incompatible change
----------------------------
The previous public signature was:

    taid_loss(student_logits, teacher_logits, student_init_logits, *,
              schedule_step, total_steps, schedule, alpha_min, alpha_max,
              jsd_beta, temperature, reduction)

The new signature is:

    taid_loss(student_logits, teacher_logits, mask=None, *, t)

Removed kwargs (``student_init_logits``, ``schedule_step``, ``total_steps``,
``schedule``, ``alpha_min``, ``alpha_max``, ``jsd_beta``, ``temperature``,
``reduction``) have no upstream analogue. Pass ``t`` directly; if you need
a schedule, use :class:`TAIDScheduler` or compute ``t`` yourself.

Reference: arXiv:2501.16937; ``SakanaAI/TAID`` commit history.
"""
from __future__ import annotations

import torch
import torch.nn.functional as F


def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor:
    """TAID forward-KL loss against a logit-space-interpolated target.

    Faithful port of ``SakanaAI/TAID/src/distil_losses/taid.py:compute_loss``
    composed with ``fkl.forward_kl``.

    Pseudocode::

        p_t = softmax( (1 - t) · student_logits.detach() + t · teacher_logits )
        log_q = log_softmax( student_logits )
        per_token = - Σ_v  p_t(v) · log_q(v)            # forward KL token-wise
        loss = sum(per_token · mask) / sum(mask)

    Args:
        student_logits: ``(B, T, V)`` current student logits, with grad.
        teacher_logits: ``(B, T, V)`` teacher logits. They are not detached
            here (only the student copy inside the interpolation is), which
            matches upstream; compute them under ``torch.no_grad()`` if the
            teacher should receive no gradient.
        mask: ``(B, T)`` token mask (1 = include, 0 = ignore). Required by
            upstream; defaults to all-ones if omitted for convenience.
        t: interpolation coefficient in ``[0, 1]``. Scalar Python float or
            0-d torch.Tensor. ``t=0`` makes the target match the (detached)
            student — a regularizer with zero gradient signal. ``t=1`` makes
            the target the teacher — pure forward-KL distillation.

    Returns:
        Scalar loss (token-mean, in float32 dtype matching upstream).

    Raises:
        ValueError: shape mismatch between student/teacher, or invalid mask
            shape.

    Reference: arXiv:2501.16937 §3.1 + Eq. (4); upstream commit at
        ``SakanaAI/TAID@main:src/distil_losses/taid.py``.
    """
    if student_logits.shape != teacher_logits.shape:
        raise ValueError(
            f"student/teacher logits shape mismatch: "
            f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}"
        )
    if mask is None:
        mask = student_logits.new_ones(student_logits.shape[:-1])
    elif mask.shape != student_logits.shape[:-1]:
        raise ValueError(
            f"mask shape {tuple(mask.shape)} does not match logits prefix "
            f"{tuple(student_logits.shape[:-1])}"
        )

    # 1. Logit-space mix with student detached (anchor = current student, no grad).
    blended_logits = (1 - t) * student_logits.detach() + t * teacher_logits

    # 2. Target distribution in float32 for numerical stability (upstream choice).
    p_t = F.softmax(blended_logits, dim=-1, dtype=torch.float32)

    # 3. Forward KL: the gradient flows ONLY through student log-softmax.
    student_logprobs = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)

    # 4. Zero the terms at infinite student logits (torch.isinf matches both
    #    -inf and +inf), where 0 * -inf would otherwise yield NaN (upstream guard).
    inf_mask = torch.isinf(student_logits)
    prod = torch.masked_fill(p_t * student_logprobs, inf_mask, 0.0)

    # 5. Per-token cross-entropy = -sum_v p_t(v) * log_q(v); reduce over vocab.
    per_token = -prod.sum(dim=-1).reshape(-1)
    flat_mask = mask.reshape(-1).to(per_token.dtype)
    denom = flat_mask.sum().clamp_min(1.0)
    loss = (per_token * flat_mask).sum() / denom
    return loss
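

# Illustrative sketch (not part of the upstream port): a dependency-free,
# single-token reference for the per-token quantity that taid_loss averages.
# The helper name `_reference_token_loss` and its list-based signature are
# assumptions made for this example; it may be handy for checking the tensor
# code on tiny hand-written inputs.
def _reference_token_loss(student, teacher, t):
    """Return -sum_v p_t(v) * log q(v) for one token, given plain float lists."""
    import math

    def log_softmax(logits):
        # Subtract the max before exponentiating for numerical stability.
        peak = max(logits)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
        return [x - log_z for x in logits]

    # Target: softmax of the logit-space mix of student and teacher.
    blended = [(1 - t) * s + t * z for s, z in zip(student, teacher)]
    p_t = [math.exp(lp) for lp in log_softmax(blended)]
    log_q = log_softmax(student)
    return -sum(p * lq for p, lq in zip(p_t, log_q))


# Usage: with uniform two-way logits the target is [0.5, 0.5], so
# _reference_token_loss([0.0, 0.0], [0.0, 0.0], t=0.5) is log(2) up to rounding.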


class TAIDScheduler:
    """Adaptive momentum-based schedule for TAID's interpolation coefficient ``t``.

    Stateful, mirrors ``SakanaAI/TAID/src/distil_losses/taid.py:TAID.update_t``.

    Usage::

        sched = TAIDScheduler(num_train_steps=10_000)
        for step in range(sched.num_train_steps):
            t = sched.t                         # current t (float)
            loss = taid_loss(s_logits, t_logits, mask, t=t)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            sched.update_t(loss.detach(), global_step=step)

    With ``t_start <= t_end`` and a non-decreasing ``global_step``, ``t`` never
    decreases. Each update computes the linear floor
    ``t_target = t_start + progress · (t_end - t_start)`` and the adaptive
    candidate ``t + alpha · σ(momentum) · (1 - t)``, where ``momentum`` is an
    EMA (decay ``beta``) of the relative loss change, then sets
    ``t = min(t_end, max(t_target, candidate))``. ``disable_adaptive=True``
    collapses to the deterministic linear schedule ``t = t_target``.

    Args:
        num_train_steps: total planned training steps; required so the linear
            floor ``t_target`` is well-defined.
        t_start: initial ``t`` (paper default 0.4 — the student is already
            close to the teacher in this regime, so ``t=0`` would waste the
            warmup phase).
        t_end: terminal ``t`` (paper default 1.0).
        alpha: adaptive bump magnitude (paper default 5e-4).
        beta: EMA decay for the relative-loss-change momentum (paper default
            0.99).
        disable_adaptive: if True, fall back to deterministic linear schedule
            ``t_target = t_start + progress · (t_end - t_start)``.
        device: device to allocate state buffers on; default cpu.
    """

    def __init__(
        self,
        num_train_steps: int,
        *,
        t_start: float = 0.4,
        t_end: float = 1.0,
        alpha: float = 5e-4,
        beta: float = 0.99,
        disable_adaptive: bool = False,
        device: torch.device | str = "cpu",
    ) -> None:
        if not (0.0 <= t_start < 1.0):
            raise ValueError(f"t_start must be in [0, 1), got {t_start}")
        if not (0.0 < t_end <= 1.0):
            raise ValueError(f"t_end must be in (0, 1], got {t_end}")
        if not (0.0 <= alpha <= 1.0):
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        if num_train_steps <= 0:
            raise ValueError(f"num_train_steps must be > 0, got {num_train_steps}")

        self.t_start = t_start
        self.t_end = t_end
        self.alpha = alpha
        self.beta = beta
        self.disable_adaptive = disable_adaptive
        self.num_train_steps = num_train_steps

        self._t = torch.tensor(t_start, device=device, dtype=torch.float32)
        self._prev_loss = torch.tensor(
            float("inf"), device=device, dtype=torch.float32
        )
        self._momentum = torch.zeros([], device=device, dtype=torch.float32)

    @property
    def t(self) -> float:
        """Current interpolation coefficient as a Python float."""
        return float(self._t)

    def update_t(
        self,
        loss: torch.Tensor,
        global_step: int,
    ) -> torch.Tensor | None:
        """Update internal ``t`` given the current step's distillation loss.

        Mirrors the upstream update rule. While ``prev_loss`` is still unset
        (infinite), the call only seeds it and returns None. Later calls
        update the momentum and ``t`` and return the adaptive increment
        ``delta_t``.

        Args:
            loss: scalar loss tensor (caller should pass ``loss.detach()``).
            global_step: current global step (0-indexed).

        Returns:
            The adaptive increment ``delta_t = alpha · σ(momentum) · (1 - t)``,
            or None for a seeding call. The new ``t`` is
            ``min(t_end, max(t_target, t + delta_t))``, so the increment only
            takes effect when it lifts ``t`` above the linear floor; it is
            returned unchanged even when ``disable_adaptive`` is set.
        """
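        # Detach and move the loss once so scheduler state never retains an
        # autograd graph, even if the caller forgets ``loss.detach()``.
        loss = loss.detach().to(self._prev_loss)
        if torch.isinf(self._prev_loss):
            self._prev_loss = loss
            return None

        relative_change = (self._prev_loss - loss) / (self._prev_loss + 1e-15)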
        self._momentum = (
            self.beta * self._momentum + (1 - self.beta) * relative_change
        )

        adaptive_delta = torch.sigmoid(self._momentum)
        progress = global_step / self.num_train_steps
        t_target = self.t_start + (self.t_end - self.t_start) * progress
        delta_t = self.alpha * adaptive_delta * (1 - self._t)

        if self.disable_adaptive:
            new_t = t_target
        else:
            new_t = min(self.t_end, max(t_target, float(self._t + delta_t)))

        if not isinstance(new_t, torch.Tensor):
            new_t = torch.tensor(new_t, device=self._t.device, dtype=self._t.dtype)
        self._t = new_t
        self._prev_loss = loss.detach().to(self._prev_loss)
        return delta_t
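

# Illustrative sketch (not part of the upstream port): the linear floor that
# TAIDScheduler.update_t computes as ``t_target``, written as a stateless
# function. The name `_linear_t_floor` is an assumption for this example; the
# defaults mirror the scheduler's constructor defaults. It can serve as a
# quick way to preview what ``disable_adaptive=True`` would produce.
def _linear_t_floor(global_step, num_train_steps, t_start=0.4, t_end=1.0):
    """Return t_start + progress * (t_end - t_start), progress = step / total."""
    progress = global_step / num_train_steps
    return t_start + (t_end - t_start) * progress


# Usage: _linear_t_floor(0, 100) is t_start, and the value rises linearly
# towards t_end as global_step approaches num_train_steps.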


__all__ = ["taid_loss", "TAIDScheduler"]