"""Entropy-Aware OPD: token-wise gated forward/reverse KL.

Paper: ICLR 2026 Spotlight "Entropy-Aware On-Policy Distillation"
(OpenReview WSRQ37tzk1, code release pending as of 2026-05-26)

Standard reverse-KL distillation (the family SDPO/OPSD belongs to) has a
known mode-seeking failure: when the teacher distribution has high entropy
at some token positions (e.g. open-ended generation), reverse KL collapses
the student onto a single mode, throwing away the teacher's diversity.
Forward KL is mode-covering and handles these positions correctly, but it
tends to flatten the student by forcing it to cover the teacher's long tail.

Entropy-Aware OPD computes the per-token entropy of the teacher
distribution and gates between forward and reverse KL on a per-token
basis: high-entropy tokens use forward KL (preserve diversity),
low-entropy tokens use reverse KL (sharpen toward the teacher's mode).

    L = Σ_t [ w(t) · KL_fwd(teacher || student)_t
              + (1 - w(t)) · KL_rev(student || teacher)_t ]

where w(t) = clamp(H_teacher(t) / H_max, 0, 1): high entropy gives a
forward-KL weight near 1, low entropy gives a reverse-KL weight near 1.

This is a clean-room implementation from the paper's pseudocode, pending
the official code drop. The license of the official code is still an open
question; this implementation is MIT-compatible by construction.
"""

from __future__ import annotations

import math

import torch
import torch.nn.functional as F


def teacher_entropy(teacher_logits: torch.Tensor) -> torch.Tensor:
    """Per-token entropy of the teacher distribution.

    Returns:
        (B, T) entropy in nats.
    """
    log_p = F.log_softmax(teacher_logits, dim=-1)
    p = log_p.exp()
    # Entropy = -Σ p log p
    return -(p * log_p).sum(dim=-1)


def entropy_aware_opd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    *,
    labels: torch.Tensor | None = None,
    h_max: float | None = None,
    temperature: float = 1.0,
    reduction: str = "batchmean",
) -> torch.Tensor:
    """Entropy-aware mixture of forward and reverse KL.

    Args:
        student_logits: (B, T, V) student logits with grad
        teacher_logits: (B, T, V) teacher logits (no grad)
        labels: (B, T) optional 0/1 mask; only positions with labels==1
            contribute to the loss. None means contribute everywhere.
        h_max: maximum-entropy normalizer. Defaults to log(V), the entropy
            of the uniform distribution and the maximum possible entropy
            at vocab size V.
        temperature: temperature applied to BOTH student and teacher logits
            before softmax
        reduction: "batchmean" (sum divided by batch size B) | "sum" |
            "mean" (average over all B*T positions, masked ones counted
            as zeros) | "none"

    Returns:
        Scalar loss, or the unreduced (B, T) per-token loss if
        `reduction="none"`.

    Reference: ICLR 2026 Spotlight WSRQ37tzk1 §3 (clean-room implementation).
    """
    if student_logits.shape != teacher_logits.shape:
        raise ValueError(
            f"shape mismatch: student={student_logits.shape}, "
            f"teacher={teacher_logits.shape}"
        )

    V = student_logits.size(-1)
    if h_max is None:
        h_max = math.log(V)

    s_log = F.log_softmax(student_logits / temperature, dim=-1)
    t_log = F.log_softmax(teacher_logits / temperature, dim=-1)
    s_p = s_log.exp()
    t_p = t_log.exp()

    # Forward KL (teacher || student): mode-covering
    # KL(t || s) = Σ t · (log t - log s)
    kl_fwd = (t_p * (t_log - s_log)).sum(dim=-1)

    # Reverse KL (student || teacher): mode-seeking (this is what SDPO uses)
    # KL(s || t) = Σ s · (log s - log t)
    kl_rev = (s_p * (s_log - t_log)).sum(dim=-1)

    # Per-token teacher entropy → gate weight.
    # Note: the entropy is taken on the raw teacher logits (no temperature).
    H_t = teacher_entropy(teacher_logits)  # (B, T) in nats
    w = (H_t / h_max).clamp(0.0, 1.0)  # (B, T) in [0, 1]

    # Mix: high entropy → forward KL; low entropy → reverse KL
    per_token_loss = w * kl_fwd + (1 - w) * kl_rev  # (B, T)

    if labels is not None:
        if labels.shape != per_token_loss.shape:
            raise ValueError(
                f"labels shape {labels.shape} != per-token-loss shape "
                f"{per_token_loss.shape}"
            )
        # Zero out positions that should not contribute to the loss.
        per_token_loss = per_token_loss * labels.float()

    if reduction == "none":
        return per_token_loss
    if reduction == "sum":
        return per_token_loss.sum()
    if reduction == "mean":
        return per_token_loss.mean()
    if reduction == "batchmean":
        return per_token_loss.sum() / max(1, per_token_loss.shape[0])
    raise ValueError(f"unknown reduction: {reduction!r}")


__all__ = ["teacher_entropy", "entropy_aware_opd_loss"]
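
To see how the gate behaves, the following is a minimal, dependency-free sketch of the same per-token computation for a single position, written with plain Python lists instead of tensors. The helper names (`softmax`, `entropy`, `kl`, `gated_token_loss`) and the toy logits are illustrative assumptions, not part of this module; the sketch assumes temperature 1 and the default `h_max = log(V)`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a plain list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats; zero-probability terms contribute 0."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)


def kl(p, q):
    """KL(p || q) in nats for two distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def gated_token_loss(student_logits, teacher_logits):
    """Return (w, loss) for one token position (hypothetical helper)."""
    s = softmax(student_logits)
    t = softmax(teacher_logits)
    h_max = math.log(len(t))                    # entropy of the uniform distribution
    w = min(max(entropy(t) / h_max, 0.0), 1.0)  # gate weight clamped to [0, 1]
    # Forward KL is KL(teacher || student); reverse KL is KL(student || teacher).
    loss = w * kl(t, s) + (1.0 - w) * kl(s, t)
    return w, loss


student = [2.0, 0.5, 0.0, -1.0]
w_flat, loss_flat = gated_token_loss(student, [0.0, 0.0, 0.0, 0.0])   # uniform teacher
w_peak, loss_peak = gated_token_loss(student, [10.0, 0.0, 0.0, 0.0])  # confident teacher
print(f"uniform teacher: w={w_flat:.3f}, loss={loss_flat:.4f}")
print(f"peaked teacher:  w={w_peak:.3f}, loss={loss_peak:.4f}")
```

With a uniform teacher the gate saturates at 1, so the position is trained with forward KL only; with a sharply peaked teacher the gate is close to 0 and reverse KL dominates. Real vocabularies are far larger, so teacher entropies rarely approach `log(V)` and most gate values will sit well below 1 unless a smaller `h_max` is passed.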