"""Entropy-Aware OPD — token-wise gated forward/reverse KL. Paper: ICLR 2026 Spotlight "Entropy-Aware On-Policy Distillation" (OpenReview WSRQ37tzk1, code release pending as of 2026-05-26) Standard reverse-KL distillation (which SDPO/OPSD belongs to) has a known mode-seeking failure: when the teacher distribution has high entropy at some token positions (e.g. open-ended generation), reverse KL collapses the student onto a single mode, throwing away the teacher's diversity. Forward KL is mode-covering and would handle these positions correctly, but is mode-flattening in the long tail. Entropy-Aware OPD computes the per-token entropy of the teacher distribution and gates between forward and reverse KL on a per-token basis: high-entropy tokens use forward KL (preserve diversity), low-entropy tokens use reverse KL (sharpen toward the teacher's mode). L = Σ_t w(t) · KL_fwd(student || teacher)_t + (1 - w(t)) · KL_rev(student || teacher)_t Where w(t) = clamp(H_teacher(t) / H_max, 0, 1) — high entropy → forward KL weight near 1, low entropy → reverse KL weight near 1. This is a clean-room implementation from the paper's pseudocode pending the official code drop. License question for the official code is open; this implementation is MIT-compatible by construction. """ from __future__ import annotations import math import torch import torch.nn.functional as F def teacher_entropy(teacher_logits: torch.Tensor) -> torch.Tensor: """Per-token entropy of the teacher distribution. Returns: (B, T) entropy in nats. """ log_p = F.log_softmax(teacher_logits, dim=-1) p = log_p.exp() # Entropy = -Σ p log p return -(p * log_p).sum(dim=-1) def entropy_aware_opd_loss( student_logits: torch.Tensor, teacher_logits: torch.Tensor, *, labels: torch.Tensor | None = None, h_max: float | None = None, temperature: float = 1.0, reduction: str = "batchmean", ) -> torch.Tensor: """Entropy-aware mixture of forward and reverse KL. Args: student_logits: (B, T, V) student logits with grad teacher_logits: (B, T, V) teacher logits (no grad) labels: (B, T) optional 0/1 mask — only contribute loss on labels==1 positions. None means contribute everywhere. h_max: maximum-entropy normalizer. Defaults to log(V) (uniform- distribution entropy = the max possible entropy at vocab size V). temperature: temperature applied to BOTH student and teacher logits before softmax reduction: "batchmean" | "sum" | "mean" | "none" Returns: Scalar loss (or unreduced if `reduction="none"`). Reference: ICLR 2026 Spotlight WSRQ37tzk1 §3 (clean-room implementation). """ if student_logits.shape != teacher_logits.shape: raise ValueError( f"shape mismatch: student={student_logits.shape}, " f"teacher={teacher_logits.shape}" ) V = student_logits.size(-1) if h_max is None: h_max = math.log(V) s_log = F.log_softmax(student_logits / temperature, dim=-1) t_log = F.log_softmax(teacher_logits / temperature, dim=-1) s_p = s_log.exp() t_p = t_log.exp() # Forward KL (teacher || student): mode-covering # KL(t || s) = Σ t · (log t - log s) kl_fwd = (t_p * (t_log - s_log)).sum(dim=-1) # Reverse KL (student || teacher): mode-seeking (this is what SDPO uses) # KL(s || t) = Σ s · (log s - log t) kl_rev = (s_p * (s_log - t_log)).sum(dim=-1) # Per-token teacher entropy → gate weight H_t = teacher_entropy(teacher_logits) # (B, T) in nats w = (H_t / h_max).clamp(0.0, 1.0) # (B, T) in [0, 1] # Mix: high entropy → forward KL; low entropy → reverse KL per_token_loss = w * kl_fwd + (1 - w) * kl_rev # (B, T) if labels is not None: if labels.shape != per_token_loss.shape: raise ValueError( f"labels shape {labels.shape} != per-token-loss shape " f"{per_token_loss.shape}" ) per_token_loss = per_token_loss * labels.float() if reduction == "none": return per_token_loss if reduction == "sum": return per_token_loss.sum() if reduction == "mean": return per_token_loss.mean() if reduction == "batchmean": return per_token_loss.sum() / max(1, per_token_loss.shape[0]) raise ValueError(f"unknown reduction: {reduction!r}") __all__ = ["teacher_entropy", "entropy_aware_opd_loss"]