| """TAID loss — Temporally Adaptive Interpolated Distillation. | |
| Paper: "TAID: Temporally Adaptive Interpolated Distillation for Efficient | |
| Knowledge Transfer in Language Models" | |
| Sakana AI, arXiv:2501.16937 | |
| License: Apache-2.0 (https://github.com/SakanaAI/TAID) | |
| This module is a faithful port of the reference implementation at | |
| ``SakanaAI/TAID/src/distil_losses/taid.py``. **The previous in-tree | |
| implementation was algorithmically different from the paper** (it mixed in | |
| probability space against a frozen step-0 student snapshot and wrapped a | |
| symmetric JSD criterion). This rewrite replaces it with the upstream | |
| algorithm: | |
| p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits ) | |
| loss = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v) | |
| That is: | |
| 1. Mix in **logit space**, not probability space. | |
| 2. Anchor against the **current student detached** (re-evaluated each | |
| step), not a frozen step-0 snapshot. | |
| 3. Distillation criterion is **forward KL** (Hinton-style soft target), | |
| not symmetric JSD. | |
| Schedule | |
| -------- | |
| The upstream implementation embeds an adaptive momentum-based schedule | |
| inside the loss object; here it is factored out into the optional | |
| :class:`TAIDScheduler` so the loss function itself is pure (it takes a single ``t`` | |
| in [0, 1]). Callers either: | |
| - Pass a fixed ``t`` for ablations / fixed schedules. | |
| - Drive ``t`` via :class:`TAIDScheduler` (paper-default adaptive scheme). | |
| - Drive ``t`` via any custom schedule of their choosing. | |
| Backward-incompatible change | |
| ---------------------------- | |
| The previous public signature was: | |
| taid_loss(student_logits, teacher_logits, student_init_logits, *, | |
| schedule_step, total_steps, schedule, alpha_min, alpha_max, | |
| jsd_beta, temperature, reduction) | |
| The new signature is: | |
| taid_loss(student_logits, teacher_logits, mask=None, *, t) | |
| Removed kwargs (``student_init_logits``, ``schedule_step``, ``total_steps``, | |
| ``schedule``, ``alpha_min``, ``alpha_max``, ``jsd_beta``, ``temperature``, | |
| ``reduction``) have no upstream analogue. Pass ``t`` directly; if you need | |
| a schedule, use :class:`TAIDScheduler` or compute ``t`` yourself. | |
| Reference: arXiv:2501.16937; ``SakanaAI/TAID`` commit history. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn.functional as F | |
| def taid_loss( | |
| student_logits: torch.Tensor, | |
| teacher_logits: torch.Tensor, | |
| mask: torch.Tensor | None = None, | |
| *, | |
| t: float | torch.Tensor, | |
| ) -> torch.Tensor: | |
| """TAID forward-KL loss against a logit-space-interpolated target. | |
| Faithful port of ``SakanaAI/TAID/src/distil_losses/taid.py:compute_loss`` | |
| composed with ``fkl.forward_kl``. | |
| Pseudocode:: | |
| p_t = softmax( (1 - t) · student_logits.detach() + t · teacher_logits ) | |
| log_q = log_softmax( student_logits ) | |
| per_token = - Σ_v p_t(v) · log_q(v) # cross-entropy; equals forward KL up to H(p_t), which carries no student gradient | |
| loss = sum(per_token · mask) / sum(mask) | |
| Args: | |
| student_logits: ``(B, T, V)`` current student logits, with grad. | |
| teacher_logits: ``(B, T, V)`` teacher logits. They are not detached | |
| here (only the student term of the interpolation is), matching | |
| upstream, so compute them under ``torch.no_grad()`` if the teacher | |
| should receive no gradient. | |
| mask: ``(B, T)`` token mask (1 = include, 0 = ignore). Required by | |
| upstream; defaults to all-ones if omitted for convenience. | |
| t: interpolation coefficient in ``[0, 1]``. Scalar Python float or | |
| 0-d torch.Tensor. ``t=0`` makes the target match the (detached) | |
| student — a regularizer with zero gradient signal. ``t=1`` makes | |
| the target the teacher — pure forward-KL distillation. | |
| Returns: | |
| Scalar loss (token-mean, in float32 dtype matching upstream). | |
| Raises: | |
| ValueError: shape mismatch between student/teacher, or invalid mask | |
| shape. | |
| Reference: arXiv:2501.16937 §3.1 + Eq. (4); upstream commit at | |
| ``SakanaAI/TAID@main:src/distil_losses/taid.py``. | |
| """ | |
| if student_logits.shape != teacher_logits.shape: | |
| raise ValueError( | |
| f"student/teacher logits shape mismatch: " | |
| f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}" | |
| ) | |
| if mask is None: | |
| mask = student_logits.new_ones(student_logits.shape[:-1]) | |
| elif mask.shape != student_logits.shape[:-1]: | |
| raise ValueError( | |
| f"mask shape {tuple(mask.shape)} does not match logits prefix " | |
| f"{tuple(student_logits.shape[:-1])}" | |
| ) | |
| # 1. Logit-space mix with student detached (anchor = current student, no grad). | |
| blended_logits = (1 - t) * student_logits.detach() + t * teacher_logits | |
| # 2. Target distribution in float32 for numerical stability (upstream choice). | |
| p_t = F.softmax(blended_logits, dim=-1, dtype=torch.float32) | |
| # 3. Forward KL: the gradient flows ONLY through student log-softmax. | |
| student_logprobs = F.log_softmax(student_logits, dim=-1, dtype=torch.float32) | |
| # 4. Zero out vocab entries whose student logit is infinite (upstream guard); there p_t * log_q can be NaN. | |
| inf_mask = torch.isinf(student_logits) | |
| prod = torch.masked_fill(p_t * student_logprobs, inf_mask, 0.0) | |
| # 5. Per-token cross-entropy = -sum_v p_t(v) * log_q(v); reduce over vocab. | |
| per_token = -prod.sum(dim=-1).reshape(-1) | |
| flat_mask = mask.reshape(-1).to(per_token.dtype) | |
| denom = flat_mask.sum().clamp_min(1.0) | |
| loss = (per_token * flat_mask).sum() / denom | |
| return loss | |
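The two endpoints described in the docstring can be checked numerically. The following is a minimal sketch, assuming only PyTorch; `taid_loss_sketch` is a hypothetical, compact restatement of the function above (without the shape checks and the infinity guard), and the tensor shapes and seed are arbitrary.

```python
import torch
import torch.nn.functional as F


def taid_loss_sketch(student_logits, teacher_logits, mask=None, *, t):
    """Compact restatement of the loss above (hypothetical helper name)."""
    if mask is None:
        mask = student_logits.new_ones(student_logits.shape[:-1])
    # Target: softmax of the logit-space blend; the student term carries no grad.
    blended = (1 - t) * student_logits.detach() + t * teacher_logits
    p_t = F.softmax(blended, dim=-1, dtype=torch.float32)
    log_q = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    per_token = -(p_t * log_q).sum(dim=-1)
    mask = mask.to(per_token.dtype)
    return (per_token * mask).sum() / mask.sum().clamp_min(1.0)


torch.manual_seed(0)
student = torch.randn(2, 3, 5, requires_grad=True)  # (B, T, V), arbitrary sizes
teacher = torch.randn(2, 3, 5)

# t = 0: the target is the detached student distribution, so the gradient vanishes.
(grad_t0,) = torch.autograd.grad(taid_loss_sketch(student, teacher, t=0.0), student)

# t = 1: the target is the teacher distribution, so the per-token gradient is
# (softmax(student) - softmax(teacher)) / number_of_unmasked_tokens (here 2 * 3 = 6).
(grad_t1,) = torch.autograd.grad(taid_loss_sketch(student, teacher, t=1.0), student)
expected_t1 = (F.softmax(student, -1) - F.softmax(teacher, -1)).detach() / 6

# Masked tokens (mask == 0) receive no gradient at all.
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
(grad_masked,) = torch.autograd.grad(
    taid_loss_sketch(student, teacher, mask, t=0.5), student
)
```

For intermediate `t` the same identity holds with `p_t` in place of the teacher distribution: the gradient with respect to the student logits is `softmax(student) - p_t`, scaled by the mask and the token count.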
| class TAIDScheduler: | |
| """Adaptive momentum-based schedule for TAID's interpolation coefficient ``t``. | |
| Stateful, mirrors ``SakanaAI/TAID/src/distil_losses/taid.py:TAID.update_t``. | |
| Usage:: | |
| sched = TAIDScheduler(num_train_steps=10_000) | |
| for step in range(num_train_steps): | |
| t = sched.t # current t (float) | |
| loss = taid_loss(s_logits, t_logits, mask, t=t) | |
| loss.backward(); optimizer.step() | |
| sched.update_t(loss.detach(), global_step=step) | |
| The schedule is monotone non-decreasing: at each step, ``t`` becomes the | |
| larger of the linear floor ``t_target = t_start + progress · (t_end - t_start)`` | |
| and the current ``t`` plus an adaptive bump ``alpha · σ(momentum) · (1 - t)``, | |
| capped at ``t_end``. Here ``momentum`` tracks the relative loss change with EMA decay | |
| ``beta``. ``disable_adaptive=True`` collapses to the deterministic linear | |
| schedule. | |
| Args: | |
| num_train_steps: total planned training steps; required so the linear | |
| floor ``t_target`` is well-defined. | |
| t_start: initial ``t`` (paper default 0.4 — the student is already | |
| close to the teacher in this regime, so ``t=0`` would waste the | |
| warmup phase). | |
| t_end: terminal ``t`` (paper default 1.0). | |
| alpha: adaptive bump magnitude (paper default 5e-4). | |
| beta: EMA decay for the relative-loss-change momentum (paper default | |
| 0.99). | |
| disable_adaptive: if True, fall back to deterministic linear schedule | |
| ``t_target = t_start + progress · (t_end - t_start)``. | |
| device: device to allocate state buffers on; default cpu. | |
| """ | |
| def __init__( | |
| self, | |
| num_train_steps: int, | |
| *, | |
| t_start: float = 0.4, | |
| t_end: float = 1.0, | |
| alpha: float = 5e-4, | |
| beta: float = 0.99, | |
| disable_adaptive: bool = False, | |
| device: torch.device | str = "cpu", | |
| ) -> None: | |
| if not (0.0 <= t_start < 1.0): | |
| raise ValueError(f"t_start must be in [0, 1), got {t_start}") | |
| if not (0.0 < t_end <= 1.0): | |
| raise ValueError(f"t_end must be in (0, 1], got {t_end}") | |
| if not (0.0 <= alpha <= 1.0): | |
| raise ValueError(f"alpha must be in [0, 1], got {alpha}") | |
| if num_train_steps <= 0: | |
| raise ValueError(f"num_train_steps must be > 0, got {num_train_steps}") | |
| self.t_start = t_start | |
| self.t_end = t_end | |
| self.alpha = alpha | |
| self.beta = beta | |
| self.disable_adaptive = disable_adaptive | |
| self.num_train_steps = num_train_steps | |
| self._t = torch.tensor(t_start, device=device, dtype=torch.float32) | |
| self._prev_loss = torch.tensor( | |
| float("inf"), device=device, dtype=torch.float32 | |
| ) | |
| self._momentum = torch.zeros([], device=device, dtype=torch.float32) | |
| @property | |
| def t(self) -> float: | |
| """Current interpolation coefficient as a Python float.""" | |
| return float(self._t) | |
| def update_t( | |
| self, | |
| loss: torch.Tensor, | |
| global_step: int, | |
| ) -> torch.Tensor | None: | |
| """Update internal ``t`` given the current step's distillation loss. | |
| Mirrors upstream verbatim. The first call with a finite loss only seeds | |
| ``prev_loss`` and returns None. Subsequent calls update the momentum and | |
| ``t``, then return the non-negative adaptive increment ``delta_t`` | |
| computed for this step. The new ``t`` is ``max(t_target, t + delta_t)`` | |
| capped at ``t_end``, so ``delta_t`` is added to the current ``t``, not to | |
| the linear floor. | |
| Args: | |
| loss: scalar loss tensor (caller should pass ``loss.detach()``). | |
| global_step: current global step (0-indexed). | |
| Returns: | |
| The adaptive ``delta_t`` computed this step (it is not applied when | |
| ``disable_adaptive=True``), or None if this was the seeding call. | |
| """ | |
| if torch.isinf(self._prev_loss): | |
| self._prev_loss = loss.detach().to(self._prev_loss) | |
| return None | |
| relative_change = (self._prev_loss - loss) / (self._prev_loss + 1e-15) | |
| self._momentum = ( | |
| self.beta * self._momentum + (1 - self.beta) * relative_change | |
| ) | |
| adaptive_delta = torch.sigmoid(self._momentum) | |
| progress = global_step / self.num_train_steps | |
| t_target = self.t_start + (self.t_end - self.t_start) * progress | |
| delta_t = self.alpha * adaptive_delta * (1 - self._t) | |
| if self.disable_adaptive: | |
| new_t = t_target | |
| else: | |
| new_t = min(self.t_end, max(t_target, float(self._t + delta_t))) | |
| if not isinstance(new_t, torch.Tensor): | |
| new_t = torch.tensor(new_t, device=self._t.device, dtype=self._t.dtype) | |
| self._t = new_t | |
| self._prev_loss = loss.detach().to(self._prev_loss) | |
| return delta_t | |
| __all__ = ["taid_loss", "TAIDScheduler"] | |
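The update rule in `TAIDScheduler.update_t` can also be traced on plain floats. This is a minimal sketch, assuming a made-up, steadily decaying loss curve; `next_t` is a hypothetical helper that restates one update step without tensors, using the defaults quoted in the class docstring.

```python
import math


def next_t(t, momentum, prev_loss, loss, step, *, num_train_steps,
           t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99):
    """One schedule update on plain floats (hypothetical helper name)."""
    # EMA of the relative loss improvement since the previous step.
    relative_change = (prev_loss - loss) / (prev_loss + 1e-15)
    momentum = beta * momentum + (1 - beta) * relative_change
    # Adaptive increment, shrinking as t approaches 1.
    sigmoid = 1.0 / (1.0 + math.exp(-momentum))
    delta_t = alpha * sigmoid * (1 - t)
    # Linear floor, then clamp the result into [t_target, t_end].
    t_target = t_start + (t_end - t_start) * step / num_train_steps
    new_t = min(t_end, max(t_target, t + delta_t))
    return new_t, momentum, delta_t


# Toy run: an assumed loss curve that decays by 2% per step over 100 steps.
num_train_steps = 100
losses = [2.0 * 0.98 ** i for i in range(num_train_steps)]
t, momentum, prev_loss = 0.4, 0.0, losses[0]  # the first loss only seeds prev_loss
trajectory = [t]
for step in range(1, num_train_steps):
    t, momentum, _ = next_t(t, momentum, prev_loss, losses[step], step,
                            num_train_steps=num_train_steps)
    prev_loss = losses[step]
    trajectory.append(t)
```

Because `delta_t` is non-negative and the result is clamped between the linear floor and `t_end`, the trajectory never decreases and ends at or above the floor for the final step.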