"""SimPO loss — reference-free DPO replacement. Paper: "SimPO: Simple Preference Optimization with a Reference-Free Reward" Meng et al., NeurIPS 2024 (arXiv:2405.14734) License: MIT (https://github.com/princeton-nlp/SimPO) Standard DPO requires log-probabilities under both the policy and a reference policy: L_DPO = -log σ( β·[(logπ(c) - logπ_ref(c)) - (logπ(r) - logπ_ref(r))] ) SimPO drops the reference-policy term, replaces it with a target margin γ, and uses average sequence log-probability instead of sum. This removes the reference-model VRAM cost (which is a meaningful fraction of total training-time memory). L_SimPO = -log σ( β·[avg_logπ(c) - avg_logπ(r)] - γ ) Where: - avg_logπ(c) = (1/|c|) · Σ_t logπ(c_t | c_ torch.Tensor: """SimPO loss — reference-free DPO with target margin. Args: chosen_avg_logprobs: (B,) average per-token log-prob of the chosen response under the policy. Computed as `chosen_logprobs.sum() / response_length`. rejected_avg_logprobs: (B,) same for rejected. beta: scaling factor (paper default 2.0) gamma: target margin (paper default 1.0) Returns: Scalar loss; lower is better. Reference: arXiv:2405.14734 Eq. (5). """ if chosen_avg_logprobs.shape != rejected_avg_logprobs.shape: raise ValueError( f"chosen and rejected avg-logprob tensors must have the same shape, " f"got chosen={chosen_avg_logprobs.shape}, " f"rejected={rejected_avg_logprobs.shape}" ) logits = beta * (chosen_avg_logprobs - rejected_avg_logprobs) - gamma return -F.logsigmoid(logits).mean() def avg_sequence_logprob( model_logprobs: torch.Tensor, response_mask: torch.Tensor, ) -> torch.Tensor: """Helper: convert (B, T) per-token log-probs + (B, T) response mask into (B,) per-sequence AVERAGE log-probability over response tokens. SimPO uses the average (not sum) so that long sequences aren't penalized for having many tokens. The mask should be 1 on response tokens and 0 on prompt+padding. """ masked = model_logprobs * response_mask.float() n_tokens = response_mask.sum(dim=-1).clamp_min(1.0).float() return masked.sum(dim=-1) / n_tokens __all__ = ["simpo_loss", "avg_sequence_logprob"]