"""compose_loss.py — free 3-channel loss composer for verification smokes. This is a verification-harness mirror of `ComposerReplicationTrainer._compute_loss` that does NOT depend on TRL's GRPOTrainer parent. The GRPO channel is replaced with standard LM next-token-prediction cross-entropy, which is the limit GRPO converges to under deterministic rewards. Use it for: - CPU smokes on real HF models (Spike 006) - Unit tests of loss composition without spinning up TRL - Anywhere we want to verify gradient flow through the 3-channel sum without paying TRL's full machinery cost Do NOT use it as the production training loss. Production = ComposerReplicationTrainer (a real GRPOTrainer subclass) which uses TRL's reward + advantage estimation. Total loss: total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo Channels: - lm_ce: standard cross-entropy on assistant-response tokens (GRPO stub) - sdpo_jsd: generalized JSD between student and hint-conditioned-teacher logits - trace_replay_dpo: DPO loss over (chosen, rejected) teacher-disagreement pairs ADR-007 extensions ------------------ Three pluggable distillation losses can swap the default DPO/SDPO channels: - ``dpo_variant="simpo"`` — channel 3 uses SimPO (reference-free DPO with margin) instead of standard DPO. Reference logprobs are no longer required. - ``sdpo_wrapper="taid"`` — channel 2 replaces SDPO with TAID (Temporally Adaptive Interpolated Distillation, SakanaAI port). Requires ``taid_t`` (the current interpolation coefficient in ``[0, 1]``). The schedule that produces ``taid_t`` is the trainer's responsibility — typically a :class:`composer_replication.distillation.taid.TAIDScheduler` instance driven by the per-step distillation loss. - ``sdpo_wrapper="entropy_opd"`` — channel 2 uses Entropy-Aware OPD, a per-token gated forward/reverse KL. All three default to off; passing the new kwargs at their defaults is bit-exact equivalent to the legacy 3-channel composition. """ from __future__ import annotations from dataclasses import dataclass from typing import Literal import torch import torch.nn.functional as F from composer_replication.opsd import generalized_jsd_loss @dataclass class LossComponents: """Per-channel breakdown of the total loss for logging + ablation.""" lm_ce: torch.Tensor sdpo_jsd: torch.Tensor trace_replay_dpo: torch.Tensor total: torch.Tensor def detached(self) -> dict[str, float]: return { "lm_ce": float(self.lm_ce.detach()), "sdpo_jsd": float(self.sdpo_jsd.detach()), "trace_replay_dpo": float(self.trace_replay_dpo.detach()), "total": float(self.total.detach()), } def compose_loss( model: torch.nn.Module, inputs: dict[str, torch.Tensor], *, alpha_sdpo: float = 0.1, beta_replay: float = 0.05, sdpo_jsd_beta: float = 0.5, sdpo_temperature: float = 1.0, sdpo_token_clip: float | None = None, replay_dpo_beta: float = 0.1, lm_ce_label_smoothing: float = 0.0, # ADR-007 extensions ------------------------------------------------ dpo_variant: Literal["dpo", "simpo"] = "dpo", sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none", taid_t: float | None = None, # SimPO knobs (only used when dpo_variant="simpo") ------------------ simpo_beta: float = 2.0, simpo_gamma: float = 1.0, # Entropy-Aware OPD knobs (only used when sdpo_wrapper="entropy_opd") entropy_opd_h_max: float | None = None, ) -> LossComponents: """Compute total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo. Required keys in `inputs`: - input_ids: (B, T_s) student rollout - response_mask: (B, T_s) 1 on assistant-response tokens, 0 elsewhere Optional keys (channel auto-disables if missing OR if its weight = 0): SDPO: - ctx_teacher_input_ids: (B, T_t) hint-conditioned context - sdpo_loss_mask: (B, T_t) 1 at error-turn tokens DPO (dpo_variant="dpo"): - dpo_chosen_input_ids, dpo_chosen_response_mask - dpo_rejected_input_ids, dpo_rejected_response_mask - dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs (precomputed) SimPO (dpo_variant="simpo"): - dpo_chosen_input_ids, dpo_chosen_response_mask - dpo_rejected_input_ids, dpo_rejected_response_mask (reference logprobs not required and silently ignored) TAID (sdpo_wrapper="taid"): - taid_t kwarg: scalar float in [0, 1] giving the current interpolation coefficient. The trainer is responsible for the schedule (use TAIDScheduler from composer_replication.distillation.taid for the paper-default adaptive scheme, or any custom schedule of your choosing). """ if dpo_variant not in ("dpo", "simpo"): raise ValueError( f"dpo_variant must be 'dpo' or 'simpo', got {dpo_variant!r}" ) if sdpo_wrapper not in ("none", "taid", "entropy_opd"): raise ValueError( f"sdpo_wrapper must be 'none', 'taid', or 'entropy_opd', " f"got {sdpo_wrapper!r}" ) if sdpo_wrapper == "taid": if taid_t is None: raise ValueError( "sdpo_wrapper='taid' requires taid_t (float in [0, 1]). " "Drive it from a TAIDScheduler or pass a fixed value." ) if not (0.0 <= float(taid_t) <= 1.0): raise ValueError( f"taid_t must be in [0, 1], got {taid_t}" ) device = _device_of(model) # ------------------------------------------------------------------ # Channel 1 (GRPO stub): LM cross-entropy on response tokens # ------------------------------------------------------------------ lm_ce = _lm_response_ce( model, inputs["input_ids"], inputs["response_mask"], label_smoothing=lm_ce_label_smoothing, ) # ------------------------------------------------------------------ # Channel 2 (SDPO): generalized JSD on hint-conditioned forward # Optionally wrapped by TAID or replaced by Entropy-Aware OPD. # ------------------------------------------------------------------ sdpo_jsd = _zero(device) if ( alpha_sdpo > 0.0 and "ctx_teacher_input_ids" in inputs and inputs["ctx_teacher_input_ids"].numel() > 0 ): student_logits = model(input_ids=inputs["input_ids"]).logits with torch.no_grad(): teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits if student_logits.shape == teacher_logits.shape: if sdpo_wrapper == "none": sdpo_jsd = generalized_jsd_loss( student_logits=student_logits, teacher_logits=teacher_logits, labels=inputs.get("sdpo_loss_mask"), beta=sdpo_jsd_beta, temperature=sdpo_temperature, token_clip=sdpo_token_clip, reduction="batchmean", ) elif sdpo_wrapper == "taid": from composer_replication.distillation import taid_loss # taid_t validated non-None and in-range above. assert taid_t is not None # Reuse the SDPO loss-mask if provided so we only score the # error-turn tokens; otherwise score all tokens. taid_mask_bt = inputs.get("sdpo_loss_mask") if taid_mask_bt is not None: taid_mask_bt = taid_mask_bt.to(student_logits.device).float() sdpo_jsd = taid_loss( student_logits=student_logits, teacher_logits=teacher_logits, mask=taid_mask_bt, t=float(taid_t), ) elif sdpo_wrapper == "entropy_opd": from composer_replication.distillation import ( entropy_aware_opd_loss, ) sdpo_jsd = entropy_aware_opd_loss( student_logits=student_logits, teacher_logits=teacher_logits, labels=inputs.get("sdpo_loss_mask"), h_max=entropy_opd_h_max, temperature=sdpo_temperature, reduction="batchmean", ) # else: silently zero — the data collator is responsible for shape # alignment in production. For the smoke we accept misalignment and # exercise the fallback path. # ------------------------------------------------------------------ # Channel 3 (trace-replay DPO): standard DPO loss on teacher-disagreement # pairs. With dpo_variant="simpo", swap to SimPO (reference-free). # ------------------------------------------------------------------ trace_replay_dpo = _zero(device) if ( beta_replay > 0.0 and "dpo_chosen_input_ids" in inputs and inputs["dpo_chosen_input_ids"].numel() > 0 ): if dpo_variant == "dpo": chosen_lp = _sequence_logprobs( model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"], ) rejected_lp = _sequence_logprobs( model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"], ) ref_chosen = inputs["dpo_chosen_ref_logprobs"] ref_rejected = inputs["dpo_rejected_ref_logprobs"] dpo_logits = replay_dpo_beta * ( (chosen_lp - ref_chosen) - (rejected_lp - ref_rejected) ) trace_replay_dpo = -F.logsigmoid(dpo_logits).mean() else: # dpo_variant == "simpo" from composer_replication.distillation import simpo_loss chosen_avg_lp = _avg_sequence_logprobs( model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"], ) rejected_avg_lp = _avg_sequence_logprobs( model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"], ) trace_replay_dpo = simpo_loss( chosen_avg_logprobs=chosen_avg_lp, rejected_avg_logprobs=rejected_avg_lp, beta=simpo_beta, gamma=simpo_gamma, ) total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo return LossComponents( lm_ce=lm_ce, sdpo_jsd=sdpo_jsd, trace_replay_dpo=trace_replay_dpo, total=total, ) # ---------------------------------------------------------------------- # Helpers # ---------------------------------------------------------------------- def _zero(device: torch.device) -> torch.Tensor: """Differentiable zero — safe to add into a sum without breaking backward.""" return torch.zeros(1, device=device, requires_grad=True).squeeze() def _device_of(model: torch.nn.Module) -> torch.device: return next(model.parameters()).device def _lm_response_ce( model: torch.nn.Module, input_ids: torch.Tensor, response_mask: torch.Tensor, *, label_smoothing: float = 0.0, ) -> torch.Tensor: """Standard next-token-prediction cross-entropy on response tokens only. Mirrors what GRPO converges to under deterministic rewards (the policy gradient devolves to behavior cloning of high-reward rollouts). """ outputs = model(input_ids=input_ids) # Shift: logits[t] predicts input_ids[t+1] logits = outputs.logits[:, :-1, :] targets = input_ids[:, 1:] mask = response_mask[:, 1:].float() loss_per_token = F.cross_entropy( logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none", label_smoothing=label_smoothing, ).view_as(targets) masked = loss_per_token * mask n_tokens = mask.sum().clamp_min(1.0) return masked.sum() / n_tokens def _sequence_logprobs( model: torch.nn.Module, input_ids: torch.Tensor, response_mask: torch.Tensor, ) -> torch.Tensor: """Sum of next-token logprobs over response tokens (standard DPO accounting).""" outputs = model(input_ids=input_ids) logits = outputs.logits[:, :-1, :] targets = input_ids[:, 1:] log_probs = F.log_softmax(logits, dim=-1) token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) masked = token_lp * response_mask[:, 1:].float() return masked.sum(dim=-1) def _avg_sequence_logprobs( model: torch.nn.Module, input_ids: torch.Tensor, response_mask: torch.Tensor, ) -> torch.Tensor: """Per-sequence AVERAGE next-token logprob over response tokens. SimPO accounting: divide the sum by the number of response tokens so long sequences aren't penalized for length. """ outputs = model(input_ids=input_ids) logits = outputs.logits[:, :-1, :] targets = input_ids[:, 1:] log_probs = F.log_softmax(logits, dim=-1) token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) mask = response_mask[:, 1:].float() masked = token_lp * mask n_tokens = mask.sum(dim=-1).clamp_min(1.0) return masked.sum(dim=-1) / n_tokens __all__ = ["compose_loss", "LossComponents"]