"""TAID loss: Temporally Adaptive Interpolated Distillation.

Paper: "TAID: Temporally Adaptive Interpolated Distillation for Efficient
Knowledge Transfer in Language Models"
Sakana AI, arXiv:2501.16937
License: Apache-2.0 (https://github.com/SakanaAI/TAID)

This module is a faithful port of the reference implementation at
``SakanaAI/TAID/src/distil_losses/taid.py``. **The previous in-tree
implementation was algorithmically different from the paper** (it mixed in
probability space against a frozen step-0 student snapshot and wrapped a
symmetric JSD criterion). This rewrite replaces it with the upstream
algorithm:

    p_t  = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
    loss = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)

That is:

1. Mix in **logit space**, not probability space.
2. Anchor against the **current student detached** (re-evaluated each
   step), not a frozen step-0 snapshot.
3. Distillation criterion is **forward KL** (Hinton-style soft target),
   not symmetric JSD. The loss is computed as the cross-entropy
   ``H(p_t, q)``, which differs from ``KL(p_t || q)`` only by the entropy
   of ``p_t``. No student gradient flows through ``p_t``, so both have the
   same gradient with respect to the student.

Schedule
--------
The upstream implementation embeds an adaptive momentum-based schedule
inside the loss object; here it is factored out into the optional
:class:`TAIDScheduler` so the loss function itself is pure (single ``t``
in [0, 1]). Callers either:

- Pass a fixed ``t`` for ablations / fixed schedules.
- Drive ``t`` via :class:`TAIDScheduler` (paper-default adaptive scheme).
- Drive ``t`` via any custom schedule of their choosing.

Backward-incompatible change
----------------------------
The previous public signature was:

    taid_loss(student_logits, teacher_logits, student_init_logits, *,
              schedule_step, total_steps, schedule, alpha_min, alpha_max,
              jsd_beta, temperature, reduction)

The new signature is:

    taid_loss(student_logits, teacher_logits, mask=None, *, t)

Removed kwargs (``student_init_logits``, ``schedule_step``, ``total_steps``,
``schedule``, ``alpha_min``, ``alpha_max``, ``jsd_beta``, ``temperature``,
``reduction``) have no upstream analogue. Pass ``t`` directly; if you need
a schedule, use :class:`TAIDScheduler` or compute ``t`` yourself.

Reference: arXiv:2501.16937; ``SakanaAI/TAID`` commit history.
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor:
    """TAID forward-KL loss against a logit-space-interpolated target.

    Faithful port of ``SakanaAI/TAID/src/distil_losses/taid.py:compute_loss``
    composed with ``fkl.forward_kl``.

    Pseudocode::

        p_t       = softmax( (1 - t) · student_logits.detach() + t · teacher_logits )
        log_q     = log_softmax( student_logits )
        per_token = - Σ_v p_t(v) · log_q(v)   # cross-entropy; forward KL up to H(p_t)
        loss      = sum(per_token · mask) / sum(mask)

    Args:
        student_logits: ``(B, T, V)`` current student logits, with grad.
        teacher_logits: ``(B, T, V)`` teacher logits, expected to carry no
            grad. They are not detached here (only the student term of the
            interpolation is), so any teacher gradient is left untouched,
            matching upstream.
        mask: ``(B, T)`` token mask (1 = include, 0 = ignore). Required by
            upstream; defaults to all-ones if omitted for convenience.
        t: interpolation coefficient in ``[0, 1]``. Scalar Python float or
            0-d torch.Tensor. ``t=0`` makes the target match the (detached)
            student, which yields zero gradient signal. ``t=1`` makes the
            target the teacher: pure forward-KL distillation.

    Returns:
        Scalar float32 loss, averaged over unmasked tokens (matching
        upstream).

    Raises:
        ValueError: shape mismatch between student/teacher, or invalid mask
            shape.

    Reference: arXiv:2501.16937 §3.1 + Eq. (4); upstream commit at
    ``SakanaAI/TAID@main:src/distil_losses/taid.py``.
    """
    if student_logits.shape != teacher_logits.shape:
        raise ValueError(
            f"student/teacher logits shape mismatch: "
            f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}"
        )
    if mask is None:
        mask = student_logits.new_ones(student_logits.shape[:-1])
    elif mask.shape != student_logits.shape[:-1]:
        raise ValueError(
            f"mask shape {tuple(mask.shape)} does not match logits prefix "
            f"{tuple(student_logits.shape[:-1])}"
        )

    # 1. Logit-space mix with student detached (anchor = current student, no grad).
    blended_logits = (1 - t) * student_logits.detach() + t * teacher_logits

    # 2. Target distribution in float32 for numerical stability (upstream choice).
    p_t = F.softmax(blended_logits, dim=-1, dtype=torch.float32)

    # 3. Forward KL: the gradient flows ONLY through student log-softmax.
    student_logprobs = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)

    # 4. Mask out -inf positions in the student logits (upstream guard).
    inf_mask = torch.isinf(student_logits)
    prod = torch.masked_fill(p_t * student_logprobs, inf_mask, 0.0)

    # 5. Per-token cross-entropy = -sum_v p_t(v) * log_q(v); reduce over vocab.
    per_token = -prod.sum(dim=-1).reshape(-1)
    flat_mask = mask.reshape(-1).to(per_token.dtype)
    denom = flat_mask.sum().clamp_min(1.0)
    loss = (per_token * flat_mask).sum() / denom
    return loss


class TAIDScheduler:
    """Adaptive momentum-based schedule for TAID's interpolation coefficient ``t``.

    Stateful, mirrors ``SakanaAI/TAID/src/distil_losses/taid.py:TAID.update_t``.

    Usage::

        num_train_steps = 10_000
        sched = TAIDScheduler(num_train_steps=num_train_steps)
        for step in range(num_train_steps):
            t = sched.t                         # current t (float)
            loss = taid_loss(s_logits, t_logits, mask, t=t)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            sched.update_t(loss.detach(), global_step=step)

    The schedule is monotone non-decreasing (given ``t_start <= t_end``).
    Each update computes the linear floor
    ``t_target = t_start + progress · (t_end - t_start)`` and an adaptive
    bump ``delta_t = alpha · σ(momentum) · (1 - t)``, where ``momentum``
    tracks the relative loss change with EMA decay ``beta``. The new value
    is ``min(t_end, max(t_target, t + delta_t))``: the bump is added to the
    previous ``t``, and the linear floor wins whenever it is larger.
    ``disable_adaptive=True`` collapses to the deterministic linear
    schedule.

    Args:
        num_train_steps: total planned training steps; required so the linear
            floor ``t_target`` is well-defined.
        t_start: initial ``t`` (paper default 0.4; the student is already
            close to the teacher in this regime, so ``t=0`` would waste the
            warmup phase).
        t_end: terminal ``t`` (paper default 1.0).
        alpha: adaptive bump magnitude (paper default 5e-4).
        beta: EMA decay for the relative-loss-change momentum (paper default
            0.99).
        disable_adaptive: if True, fall back to deterministic linear schedule
            ``t_target = t_start + progress · (t_end - t_start)``.
        device: device to allocate state buffers on; default cpu.
    """

    def __init__(
        self,
        num_train_steps: int,
        *,
        t_start: float = 0.4,
        t_end: float = 1.0,
        alpha: float = 5e-4,
        beta: float = 0.99,
        disable_adaptive: bool = False,
        device: torch.device | str = "cpu",
    ) -> None:
        if not (0.0 <= t_start < 1.0):
            raise ValueError(f"t_start must be in [0, 1), got {t_start}")
        if not (0.0 < t_end <= 1.0):
            raise ValueError(f"t_end must be in (0, 1], got {t_end}")
        if not (0.0 <= alpha <= 1.0):
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        if num_train_steps <= 0:
            raise ValueError(f"num_train_steps must be > 0, got {num_train_steps}")

        self.t_start = t_start
        self.t_end = t_end
        self.alpha = alpha
        self.beta = beta
        self.disable_adaptive = disable_adaptive
        self.num_train_steps = num_train_steps

        self._t = torch.tensor(t_start, device=device, dtype=torch.float32)
        self._prev_loss = torch.tensor(
            float("inf"), device=device, dtype=torch.float32
        )
        self._momentum = torch.zeros([], device=device, dtype=torch.float32)

    @property
    def t(self) -> float:
        """Current interpolation coefficient as a Python float."""
        return float(self._t)

    def update_t(
        self,
        loss: torch.Tensor,
        global_step: int,
    ) -> torch.Tensor | None:
        """Update internal ``t`` given the current step's distillation loss.

        Mirrors upstream verbatim. The first call with a finite loss only
        seeds ``prev_loss`` and returns None. Subsequent calls update the
        momentum and ``t``, and return the adaptive increment
        ``delta_t = alpha · σ(momentum) · (1 - t)`` computed from the
        pre-update ``t``. The new ``t`` is
        ``min(t_end, max(t_target, t + delta_t))``, so ``delta_t`` is
        returned even when the linear floor (or ``disable_adaptive``)
        determines the new value.

        Args:
            loss: scalar loss tensor (caller should pass ``loss.detach()``).
            global_step: current global step (0-indexed).

        Returns:
            The adaptive ``delta_t`` computed for this step (a 0-d tensor),
            or None if this was the seeding call.
        """
        if torch.isinf(self._prev_loss):
            self._prev_loss = loss.detach().to(self._prev_loss)
            return None

        relative_change = (self._prev_loss - loss) / (self._prev_loss + 1e-15)
        self._momentum = (
            self.beta * self._momentum + (1 - self.beta) * relative_change
        )

        adaptive_delta = torch.sigmoid(self._momentum)
        progress = global_step / self.num_train_steps
        t_target = self.t_start + (self.t_end - self.t_start) * progress
        delta_t = self.alpha * adaptive_delta * (1 - self._t)

        if self.disable_adaptive:
            new_t = t_target
        else:
            new_t = min(self.t_end, max(t_target, float(self._t + delta_t)))

        if not isinstance(new_t, torch.Tensor):
            new_t = torch.tensor(new_t, device=self._t.device, dtype=self._t.dtype)
        self._t = new_t
        self._prev_loss = loss.detach().to(self._prev_loss)
        return delta_t
__all__ = ["taid_loss", "TAIDScheduler"]
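
The two endpoints described in the ``taid_loss`` docstring (``t=0`` gives zero gradient, ``t=1`` is plain soft-target distillation) can be checked by hand for a single token. The following is a minimal, dependency-free sketch rather than the module's tensor implementation: ``softmax`` and ``taid_token`` are hypothetical helpers written for this illustration, the logit values are made up, and the gradient uses the closed form ``q - p_t`` that holds when the target is treated as a constant.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def taid_token(student, teacher, t):
    """Return (loss, gradient w.r.t. student logits) for one token position.

    The target p_t is treated as a constant, mirroring ``.detach()`` on the
    student term of the interpolation.
    """
    blended = [(1 - t) * s + t * u for s, u in zip(student, teacher)]
    p_t = softmax(blended)
    q = softmax(student)
    loss = -sum(p * math.log(qv) for p, qv in zip(p_t, q))
    # With p_t fixed, d(loss)/d(student_v) = q(v) - p_t(v).
    grad = [qv - p for qv, p in zip(q, p_t)]
    return loss, grad


# Made-up logits over a 3-token vocabulary.
student = [2.0, 0.5, -1.0]
teacher = [0.1, 3.0, -0.5]

loss0, grad0 = taid_token(student, teacher, t=0.0)
loss1, grad1 = taid_token(student, teacher, t=1.0)

# t = 0: the target equals the student distribution, so the gradient vanishes.
print("max |grad| at t=0:", max(abs(g) for g in grad0))
# t = 1: the target is the teacher distribution (soft-target cross-entropy).
print("loss at t=1:", loss1)
```

Intermediate values of ``t`` move the target smoothly between these two cases, which is why the schedule for ``t`` controls how hard the distillation signal is at each step.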
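
The update rule in ``TAIDScheduler.update_t`` can also be replayed with plain floats to see how the linear floor and the adaptive bump interact. This is a sketch under stated assumptions: ``run_schedule`` is a hypothetical helper, not part of the module, it uses ``None`` instead of an infinite tensor to mark the seeding call, and the exponentially decaying loss curve is invented for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def run_schedule(losses, num_train_steps, t_start=0.4, t_end=1.0,
                 alpha=5e-4, beta=0.99):
    """Replay the adaptive update on per-step losses; return the t used at each step."""
    t, momentum, prev_loss = t_start, 0.0, None
    history = []
    for step, loss in enumerate(losses):
        history.append(t)  # value of t used to compute this step's loss
        if prev_loss is None:
            prev_loss = loss  # seeding call: t is left unchanged
            continue
        relative_change = (prev_loss - loss) / (prev_loss + 1e-15)
        momentum = beta * momentum + (1 - beta) * relative_change
        t_target = t_start + (t_end - t_start) * step / num_train_steps
        delta_t = alpha * sigmoid(momentum) * (1 - t)
        # The bump is added to the previous t; the linear floor wins if larger.
        t = min(t_end, max(t_target, t + delta_t))
        prev_loss = loss
    return history


# Invented loss curve: exponential decay towards 0.5 over 1000 steps.
losses = [2.0 * math.exp(-0.005 * k) + 0.5 for k in range(1000)]
ts = run_schedule(losses, num_train_steps=1000)
print(ts[0], ts[len(ts) // 2], ts[-1])
```

Because ``delta_t`` is never negative while ``t <= 1`` and the floor only rises, the replayed ``t`` never decreases and never exceeds ``t_end``.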