File size: 4,544 Bytes
b266c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Entropy-Aware OPD — token-wise gated forward/reverse KL.

Paper: ICLR 2026 Spotlight "Entropy-Aware On-Policy Distillation"
       (OpenReview WSRQ37tzk1, code release pending as of 2026-05-26)

Standard reverse-KL distillation (which SDPO/OPSD belongs to) has a known
mode-seeking failure: when the teacher distribution has high entropy at
some token positions (e.g. open-ended generation), reverse KL collapses
the student onto a single mode, throwing away the teacher's diversity.

Forward KL is mode-covering and would handle these positions correctly,
but at low-entropy positions it tends to flatten the student over the
teacher's long tail instead of sharpening onto the teacher's mode.

Entropy-Aware OPD computes the per-token entropy of the teacher
distribution and gates between forward and reverse KL on a per-token
basis: high-entropy tokens use forward KL (preserve diversity),
low-entropy tokens use reverse KL (sharpen toward the teacher's mode).

    L = Σ_t  w(t) · KL(teacher || student)_t         (forward KL)
            + (1 - w(t)) · KL(student || teacher)_t  (reverse KL)

where w(t) = clamp(H_teacher(t) / H_max, 0, 1): high entropy → forward-KL
weight near 1, low entropy → reverse-KL weight near 1.

This is a clean-room implementation from the paper's pseudocode pending
the official code drop. License question for the official code is open;
this implementation is MIT-compatible by construction.
"""
from __future__ import annotations

import math

import torch
import torch.nn.functional as F


def teacher_entropy(teacher_logits: torch.Tensor) -> torch.Tensor:
    """Per-token Shannon entropy of softmax(teacher_logits).

    Args:
        teacher_logits: (..., V) unnormalized logits; the softmax is taken
            over the last dimension. ``-inf`` entries (e.g. masked vocabulary
            slots) are allowed and contribute 0, per the convention
            0 · log 0 = 0.

    Returns:
        Entropy in nats with the last dimension reduced — (B, T) for a
        (B, T, V) input.
    """
    log_p = F.log_softmax(teacher_logits, dim=-1)
    p = log_p.exp()
    # Entropy = -Σ p log p. Where p == 0 (a -inf logit) the raw product would
    # be 0 * -inf = NaN; zeroing log_p there yields the limit value 0 instead.
    # Entries with p > 0 are untouched, so finite inputs give the same result.
    return -(p * log_p.masked_fill(p == 0, 0.0)).sum(dim=-1)


def entropy_aware_opd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    *,
    labels: torch.Tensor | None = None,
    h_max: float | None = None,
    temperature: float = 1.0,
    reduction: str = "batchmean",
) -> torch.Tensor:
    """Entropy-aware mixture of forward and reverse KL.

    Args:
        student_logits: (B, T, V) student logits with grad
        teacher_logits: (B, T, V) teacher logits (no grad)
        labels: (B, T) optional 0/1 mask — only contribute loss on
            labels==1 positions. None means contribute everywhere.
        h_max: maximum-entropy normalizer. Defaults to log(V) (uniform-
            distribution entropy = the max possible entropy at vocab size V).
        temperature: temperature applied to BOTH student and teacher logits
            before softmax
        reduction: "batchmean" | "sum" | "mean" | "none"

    Returns:
        Scalar loss (or unreduced if `reduction="none"`).

    Reference: ICLR 2026 Spotlight WSRQ37tzk1 §3 (clean-room implementation).
    """
    if student_logits.shape != teacher_logits.shape:
        raise ValueError(
            f"shape mismatch: student={student_logits.shape}, "
            f"teacher={teacher_logits.shape}"
        )

    V = student_logits.size(-1)
    if h_max is None:
        h_max = math.log(V)

    s_log = F.log_softmax(student_logits / temperature, dim=-1)
    t_log = F.log_softmax(teacher_logits / temperature, dim=-1)

    s_p = s_log.exp()
    t_p = t_log.exp()

    # Forward KL (teacher || student): mode-covering
    # KL(t || s) = Σ t · (log t - log s)
    kl_fwd = (t_p * (t_log - s_log)).sum(dim=-1)

    # Reverse KL (student || teacher): mode-seeking (this is what SDPO uses)
    # KL(s || t) = Σ s · (log s - log t)
    kl_rev = (s_p * (s_log - t_log)).sum(dim=-1)

    # Per-token teacher entropy → gate weight
    H_t = teacher_entropy(teacher_logits)  # (B, T) in nats
    w = (H_t / h_max).clamp(0.0, 1.0)  # (B, T) in [0, 1]

    # Mix: high entropy → forward KL; low entropy → reverse KL
    per_token_loss = w * kl_fwd + (1 - w) * kl_rev  # (B, T)

    if labels is not None:
        if labels.shape != per_token_loss.shape:
            raise ValueError(
                f"labels shape {labels.shape} != per-token-loss shape "
                f"{per_token_loss.shape}"
            )
        per_token_loss = per_token_loss * labels.float()

    if reduction == "none":
        return per_token_loss
    if reduction == "sum":
        return per_token_loss.sum()
    if reduction == "mean":
        return per_token_loss.mean()
    if reduction == "batchmean":
        return per_token_loss.sum() / max(1, per_token_loss.shape[0])
    raise ValueError(f"unknown reduction: {reduction!r}")


# Public API of this module: the teacher-entropy helper and the gated loss.
__all__ = [
    "teacher_entropy",
    "entropy_aware_opd_loss",
]