"""SimPO loss — reference-free DPO replacement.

Paper: "SimPO: Simple Preference Optimization with a Reference-Free Reward"
Meng et al., NeurIPS 2024 (arXiv:2405.14734)
License: MIT (https://github.com/princeton-nlp/SimPO)

Standard DPO requires log-probabilities of the chosen response c and the
rejected response r under both the policy and a reference policy:

    L_DPO = -log σ( β·[(logπ(c) - logπ_ref(c)) - (logπ(r) - logπ_ref(r))] )

SimPO drops the reference-policy terms, adds a target margin γ, and uses the
average (length-normalized) sequence log-probability instead of the sum.
With no reference model to load or evaluate, the memory it would otherwise
occupy during training is freed:

    L_SimPO = -log σ( β·[avg_logπ(c) - avg_logπ(r)] - γ )

Where:
    - avg_logπ(c) = (1/|c|) · Σ_t logπ(c_t | c_<t, prompt), and likewise for r
    - β: scaling factor (paper default: 2.0)
    - γ: target margin (paper default: 1.0)

To compose with the framework, use this loss in place of channel-3
`_compute_trace_replay_loss` when `dpo_variant="simpo"` is passed to
`compose_loss`. The inputs change accordingly: SimPO does NOT consume
`dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`, so both become
unused.
"""
from __future__ import annotations

import torch
import torch.nn.functional as F


def simpo_loss(
    chosen_avg_logprobs: torch.Tensor,
    rejected_avg_logprobs: torch.Tensor,
    *,
    beta: float = 2.0,
    gamma: float = 1.0,
) -> torch.Tensor:
    """SimPO loss: reference-free DPO with a target margin.

    Args:
        chosen_avg_logprobs: (B,) average per-token log-prob of the chosen
            response under the policy, i.e. the sum of response-token
            log-probs divided by the response length (see
            `avg_sequence_logprob`).
        rejected_avg_logprobs: (B,) same for the rejected response.
        beta: scaling factor (paper default 2.0)
        gamma: target margin (paper default 1.0)

    Returns:
        Scalar loss (mean over the batch); lower is better.

    Reference: arXiv:2405.14734 Eq. (5).
    """
    if chosen_avg_logprobs.shape != rejected_avg_logprobs.shape:
        raise ValueError(
            "chosen and rejected avg-logprob tensors must have the same shape, "
            f"got chosen={chosen_avg_logprobs.shape}, "
            f"rejected={rejected_avg_logprobs.shape}"
        )
    # Margin-shifted, beta-scaled difference of length-normalized log-probs.
    logits = beta * (chosen_avg_logprobs - rejected_avg_logprobs) - gamma
    # -log σ(x), computed with the numerically stable logsigmoid.
    return -F.logsigmoid(logits).mean()
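

def avg_sequence_logprob(
    model_logprobs: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """Helper: convert (B, T) per-token log-probs + (B, T) response mask into
    (B,) per-sequence AVERAGE log-probability over response tokens.

    SimPO uses the average (not the sum) so that long sequences aren't
    penalized for having many tokens. The mask should be 1 on response
    tokens and 0 on prompt+padding.
    """
    if model_logprobs.shape != response_mask.shape:
        raise ValueError(
            "model_logprobs and response_mask must have the same shape, "
            f"got {model_logprobs.shape} and {response_mask.shape}"
        )
    keep = response_mask.bool()
    # Zero out masked positions with masked_fill rather than multiplying by 0,
    # so a -inf or NaN log-prob on a prompt/padding token cannot leak into the sum.
    masked = model_logprobs.float().masked_fill(~keep, 0.0)
    # Convert to float before clamping so bool/int masks are handled uniformly;
    # clamp_min(1.0) avoids division by zero for rows with no response tokens.
    n_tokens = keep.float().sum(dim=-1).clamp_min(1.0)
    return masked.sum(dim=-1) / n_tokens

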
__all__ = ["simpo_loss", "avg_sequence_logprob"]
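

# Usage sketch (illustrative only, not part of this module's API). It inlines
# the same arithmetic as `avg_sequence_logprob` and `simpo_loss` so that it
# runs standalone. `per_token_logprobs` is a hypothetical helper, and it
# assumes the caller has already applied the next-token shift so that
# logits[:, t] scores labels[:, t]. The toy shapes and values are made up.

```python
import math

import torch
import torch.nn.functional as F


def per_token_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, T, V) logits + (B, T) target ids -> (B, T) log-probs."""
    logp = F.log_softmax(logits.float(), dim=-1)
    # Pick out the log-prob assigned to each target token.
    return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)


# Toy batch: B=2 sequences, T=4 positions, vocabulary of 5 tokens.
torch.manual_seed(0)
logits = torch.randn(2, 4, 5)
labels = torch.randint(0, 5, (2, 4))
# 1 on response tokens, 0 on prompt/padding (the second row has a shorter response).
response_mask = torch.tensor([[0, 1, 1, 1], [0, 1, 1, 0]])

token_logp = per_token_logprobs(logits, labels)                # (B, T)
n_tokens = response_mask.float().sum(dim=-1).clamp_min(1.0)    # (B,)
avg_logp = (token_logp * response_mask.float()).sum(dim=-1) / n_tokens

# Worked arithmetic for the loss with beta=2.0 and gamma=1.0:
# averages of -0.5 (chosen) and -1.5 (rejected) give
# logits = 2.0 * (-0.5 - (-1.5)) - 1.0 = 1.0, so loss = -log(sigmoid(1.0)).
chosen_avg = torch.tensor([-0.5])
rejected_avg = torch.tensor([-1.5])
loss = -F.logsigmoid(2.0 * (chosen_avg - rejected_avg) - 1.0).mean()
expected = math.log1p(math.exp(-1.0))  # closed form of -log(sigmoid(1.0))
```

# In real training, `chosen_avg` and `rejected_avg` would each come from one
# policy forward pass over the chosen and rejected sequences; no reference-model
# pass is needed.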