File size: 2,932 Bytes
b266c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""SimPO loss — reference-free DPO replacement.

Paper: "SimPO: Simple Preference Optimization with a Reference-Free Reward"
       Meng et al., NeurIPS 2024 (arXiv:2405.14734)
License: MIT (https://github.com/princeton-nlp/SimPO)

Standard DPO requires log-probabilities under both the policy and a
reference policy:

    L_DPO = -log σ( β·[(logπ(c) - logπ_ref(c)) - (logπ(r) - logπ_ref(r))] )

SimPO drops the reference-policy term, replaces it with a target margin γ,
and uses average sequence log-probability instead of sum. This removes the
reference-model VRAM cost (which is a meaningful fraction of total
training-time memory).

    L_SimPO = -log σ( β·[avg_logπ(c) - avg_logπ(r)] - γ )

Where:
- avg_logπ(c) = (1/|c|) · Σ_t logπ(c_t | c_<t, prompt)
- β: scaling factor (paper default: 2.0)
- γ: target margin (paper default: 1.0)

Compose with the framework: replace channel-3 `_compute_trace_replay_loss`
when `dpo_variant="simpo"` is passed to `compose_loss`. Inputs change:
SimPO does NOT consume `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`
(those become unused).
"""
from __future__ import annotations

import torch
import torch.nn.functional as F


def simpo_loss(
    chosen_avg_logprobs: torch.Tensor,
    rejected_avg_logprobs: torch.Tensor,
    *,
    beta: float = 2.0,
    gamma: float = 1.0,
) -> torch.Tensor:
    """SimPO loss — reference-free DPO with target margin.

    Computes ``mean(-log σ(β·(chosen - rejected) - γ))`` over the batch.

    Args:
        chosen_avg_logprobs: (B,) average per-token log-prob of the chosen
            response under the policy. For each row this is the sum of that
            row's response-token log-probs divided by its response length,
            i.e. ``(logprobs * mask).sum(dim=-1) / mask.sum(dim=-1)`` (see
            ``avg_sequence_logprob``). The sum must be per row
            (``dim=-1``); a bare ``.sum()`` would sum over the whole batch.
        rejected_avg_logprobs: (B,) same for the rejected response; must have
            exactly the same shape as ``chosen_avg_logprobs``.
        beta: scaling factor applied to the log-prob difference
            (paper default 2.0).
        gamma: target reward margin subtracted after scaling
            (paper default 1.0).

    Returns:
        Scalar loss (mean over all elements); lower is better. An empty
        batch yields NaN, since it is the mean of an empty tensor.

    Raises:
        ValueError: if the chosen and rejected tensors' shapes differ.
            Broadcasting is refused on purpose: e.g. a (B,) vs (B, 1) mix-up
            would otherwise silently broadcast to a (B, B) loss matrix.

    Reference: arXiv:2405.14734 Eq. (5).
    """
    if chosen_avg_logprobs.shape != rejected_avg_logprobs.shape:
        raise ValueError(
            f"chosen and rejected avg-logprob tensors must have the same shape, "
            f"got chosen={chosen_avg_logprobs.shape}, "
            f"rejected={rejected_avg_logprobs.shape}"
        )
    # Margin-shifted, scaled preference logit: β·Δ(avg log-prob) − γ.
    logits = beta * (chosen_avg_logprobs - rejected_avg_logprobs) - gamma
    # -logsigmoid is the numerically stable form of -log(sigmoid(x)).
    return -F.logsigmoid(logits).mean()


def avg_sequence_logprob(
    model_logprobs: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """Helper: convert (B, T) per-token log-probs + (B, T) response mask into
    (B,) per-sequence AVERAGE log-probability over response tokens.

    SimPO uses the average (not sum) so that long sequences aren't
    penalized for having many tokens. The mask should be 1 on response
    tokens and 0 on prompt+padding.

    Args:
        model_logprobs: (B, T) per-token log-probabilities under the policy.
            Values at masked-out positions are ignored, even when they are
            ``-inf`` or NaN (e.g. log-probs gathered from masked logits).
        response_mask: (B, T) bool or numeric mask; nonzero (normally 1)
            marks response tokens, 0 marks prompt and padding.

    Returns:
        (B,) per-sequence mean log-probs, computed in at least float32.
        Rows with no response tokens yield 0.0, because the token count is
        clamped to 1 to avoid division by zero.
    """
    # Accumulate in at least float32 (the same promotion the original
    # `mask.float()` multiply produced) so bf16/fp16 sums keep precision.
    dtype = torch.promote_types(model_logprobs.dtype, torch.float32)
    weights = response_mask.to(dtype)
    # Multiplying by the mask alone is not enough: -inf * 0 (or NaN * 0) is
    # NaN, which would poison the whole row. Force masked-out positions to
    # exactly zero instead.
    masked = (model_logprobs.to(dtype) * weights).masked_fill(
        response_mask == 0, 0.0
    )
    n_tokens = weights.sum(dim=-1).clamp_min(1.0)
    return masked.sum(dim=-1) / n_tokens


# Public API of this module: the SimPO loss and the helper that builds its
# per-sequence average log-prob inputs.
__all__ = ["simpo_loss", "avg_sequence_logprob"]