"""composer_trainer.py — TRL GRPOTrainer subclass with SDPO + trace-replay channels.

Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A".
Verified extension point: GRPOTrainer._compute_loss(model, inputs)
(DeepWiki audit of huggingface/trl, 2026-05-25).

Total loss:

    total_loss = grpo_loss
               + alpha_sdpo * sdpo_kl_at_error_turns
               + beta_replay * trace_replay_dpo_loss

Where:
  - grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches).
  - sdpo_kl_at_error_turns is generalized_jsd_loss between the student's
    logits and the teacher's (= same-model-with-hint-context) logits,
    masked to error-turn tokens only.
  - trace_replay_dpo_loss is the DPO loss over (chosen, rejected) pairs
    derived from cases where the N external teachers disagree with the
    student.

The data collator (data_collator.py) is responsible for:
  - Detecting error sites in the rollout and constructing
    ctx_teacher = ctx_student + hint.
  - Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere).
  - Loading DPO pairs from the trace-replay output (see teacher_replay.py).
  - Precomputing reference-policy logprobs for DPO.
"""

from __future__ import annotations

import logging
from typing import Any

import torch
import torch.nn.functional as F

# These imports work when TRL is installed — they're not skeleton imports.
# When TRL is missing we fall back to `object` so the module still imports
# (e.g. for documentation generation) but raise a clear ImportError at
# instantiation time rather than the cryptic `object.__init__()` error.
try:
    from trl import GRPOTrainer  # type: ignore

    _TRL_AVAILABLE = True
except ImportError:  # pragma: no cover — only hit in unit-test stubs without TRL
    GRPOTrainer = object  # type: ignore — fallback so module imports without TRL
    _TRL_AVAILABLE = False

from composer_replication.opsd import generalized_jsd_loss

logger = logging.getLogger(__name__)


class ComposerReplicationTrainer(GRPOTrainer):  # type: ignore[misc, valid-type]
    """TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.

    The total loss returned by ``_compute_loss`` is
    ``grpo + alpha_sdpo * sdpo + beta_replay * replay``. Both extra channels
    default to a weight of 0.0 and are skipped entirely in that case.

    Args (in addition to GRPOTrainer's):
        alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `sdpo_loss_mask` and `ctx_teacher_input_ids` columns.
        beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `dpo_chosen_input_ids` / `dpo_rejected_input_ids` etc.
        sdpo_jsd_beta: beta param of generalized_jsd_loss
            (0=KL(teacher||student), 0.5=JSD, 1=KL(student||teacher) per
            upstream OPSD convention; see composer_replication/opsd.py).
        sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
        sdpo_token_clip: per-token JSD clip for stability; None = no clip.
        replay_dpo_beta: beta param of the DPO loss (β in the standard DPO
            formula).

    Raises:
        ImportError: at instantiation time if TRL is not installed.
    """

    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.0,
        beta_replay: float = 0.0,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ):
        # Fail early with an actionable message: without TRL the base class is
        # plain `object`, whose __init__ would reject the trainer arguments.
        if not _TRL_AVAILABLE:
            raise ImportError(
                "ComposerReplicationTrainer requires TRL. Install with "
                "`pip install -e .[train]`."
            )
        super().__init__(*args, **kwargs)
        self.alpha_sdpo = alpha_sdpo
        self.beta_replay = beta_replay
        self.sdpo_jsd_beta = sdpo_jsd_beta
        self.sdpo_temperature = sdpo_temperature
        self.sdpo_token_clip = sdpo_token_clip
        self.replay_dpo_beta = replay_dpo_beta

    # ----------------------------------------------------------------------
    # Loss override (the integration core)
    # ----------------------------------------------------------------------

    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Override: total_loss = grpo + α*sdpo + β*replay.

        Args:
            model: the policy being trained.
            inputs: the batch dict; forwarded unchanged to the parent loss and
                to both extra channels.

        Returns:
            Scalar loss combining the three channels.
        """
        # Channel 1: standard GRPO loss
        grpo_loss = super()._compute_loss(model, inputs)

        # Channel 2: SDPO hint-distill at error sites
        sdpo_kl = self._compute_sdpo_loss(model, inputs)

        # Channel 3: trace-replay DPO from teacher disagreement
        replay_dpo = self._compute_trace_replay_loss(model, inputs)

        # Compose
        total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo

        # Log per-channel components (so we can ablate post-hoc)
        if hasattr(self, "state") and getattr(self, "args", None) is not None:
            log_steps = getattr(self.args, "logging_steps", 50)
            # A falsy interval (0 / 0.0 / None) means "no periodic logging";
            # the truthiness check also keeps the modulo from raising
            # ZeroDivisionError / TypeError on such values.
            if log_steps and self.state.global_step % log_steps == 0:
                self.log({  # type: ignore[attr-defined]
                    "loss/grpo": float(grpo_loss.detach()),
                    "loss/sdpo_kl": float(sdpo_kl.detach()),
                    "loss/trace_replay_dpo": float(replay_dpo.detach()),
                    "loss/total": float(total.detach()),
                    "loss/alpha_sdpo": self.alpha_sdpo,
                    "loss/beta_replay": self.beta_replay,
                })

        return total

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill
    # ----------------------------------------------------------------------

    def _compute_sdpo_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Compute generalized_jsd_loss between student and hint-conditioned teacher.

        Both come from the SAME model — teacher just has hint inserted into
        context.

        Returns a zero scalar (no forward passes are run) when:
          - ``alpha_sdpo`` is 0.0, or
          - the batch has no ``ctx_teacher_input_ids`` key, or that tensor is
            empty (no error sites), or
          - student and teacher logits differ in shape (a warning is logged).
        """
        if (
            self.alpha_sdpo == 0.0
            or "ctx_teacher_input_ids" not in inputs
            or inputs["ctx_teacher_input_ids"].numel() == 0
        ):
            return _zero_loss(model)

        # Student forward (with grad, on the original-context input)
        student_logits = model(input_ids=inputs["input_ids"]).logits

        # Teacher forward (no grad — same model, hint-conditioned context)
        with torch.no_grad():
            teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits

        # NOTE: in real implementation, ctx_teacher and ctx_student must be the
        # SAME LENGTH at the post-hint section so logits align position-by-position.
        # The data collator pads/aligns. The skeleton trusts that's done correctly.
        if student_logits.shape != teacher_logits.shape:
            logger.warning(
                "SDPO logit shape mismatch: student=%s vs teacher=%s. "
                "Skipping SDPO loss for this step. Check the data collator's "
                "alignment — the post-hint section must have identical token-counts.",
                student_logits.shape,
                teacher_logits.shape,
            )
            return _zero_loss(model)

        # NOTE(review): the 0/1 error-turn mask is passed as `labels`. If
        # generalized_jsd_loss treats `labels` as token labels with an
        # ignore-index (rather than as a mask), nothing would be masked —
        # confirm against composer_replication/opsd.py.
        return generalized_jsd_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            labels=inputs.get("sdpo_loss_mask"),  # error-turn token mask
            beta=self.sdpo_jsd_beta,
            temperature=self.sdpo_temperature,
            token_clip=self.sdpo_token_clip,
            reduction="batchmean",
        )

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO
    # ----------------------------------------------------------------------

    def _compute_trace_replay_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Standard DPO loss using (chosen, rejected) pairs from teacher disagreement.

        DPO loss formula (Rafailov et al. 2023):
            L = -log σ(β · (logπ(chosen) - logπ_ref(chosen)
                            - logπ(rejected) + logπ_ref(rejected)))

        Where logπ_ref are precomputed by the data collator using the reference
        (init student) policy.

        Returns a zero scalar when ``beta_replay`` is 0.0 or the batch has no
        (or an empty) ``dpo_chosen_input_ids``. Otherwise the batch must also
        contain ``dpo_chosen_response_mask``, ``dpo_rejected_input_ids``,
        ``dpo_rejected_response_mask``, ``dpo_chosen_ref_logprobs`` and
        ``dpo_rejected_ref_logprobs``; a missing key raises ``KeyError``.
        """
        if (
            self.beta_replay == 0.0
            or "dpo_chosen_input_ids" not in inputs
            or inputs["dpo_chosen_input_ids"].numel() == 0
        ):
            return _zero_loss(model)

        # Forward passes for chosen and rejected, gather logprobs at response tokens
        chosen_logprobs = self._sequence_logprobs(
            model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"]
        )
        rejected_logprobs = self._sequence_logprobs(
            model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"]
        )

        ref_chosen_logprobs = inputs["dpo_chosen_ref_logprobs"]
        ref_rejected_logprobs = inputs["dpo_rejected_ref_logprobs"]

        # β-scaled difference of policy-vs-reference log-ratios (chosen − rejected).
        logits = self.replay_dpo_beta * (
            (chosen_logprobs - ref_chosen_logprobs)
            - (rejected_logprobs - ref_rejected_logprobs)
        )
        return -F.logsigmoid(logits).mean()

    @staticmethod
    def _sequence_logprobs(
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        response_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Sum logprob of response tokens given the prompt prefix.

        Standard DPO accounting: we only score the response tokens (where
        response_mask == 1), not the prompt tokens.

        Args:
            model: callable as ``model(input_ids=...)`` returning an object
                with a ``.logits`` attribute.
            input_ids: 2-D token ids, indexed as ``[row, position]``.
            response_mask: same layout as ``input_ids``; its first column is
                dropped to line up with the shifted targets.

        Returns:
            One summed log-probability per row of ``input_ids``.
        """
        outputs = model(input_ids=input_ids)
        # Shift for next-token prediction: logits[t] predicts input_ids[t+1]
        logits = outputs.logits[:, :-1, :]
        targets = input_ids[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Mask out prompt + padding; sum response-token logprobs
        masked = token_logprobs * response_mask[:, 1:].float()
        return masked.sum(dim=-1)


def _device_of(model: torch.nn.Module) -> torch.device:
    """Return the device of the model's first parameter.

    Falls back to the first buffer, and finally to CPU, so that a module
    without parameters does not raise ``StopIteration``.
    """
    first = next(model.parameters(), None)
    if first is None:
        first = next(model.buffers(), None)
    return first.device if first is not None else torch.device("cpu")


def _zero_loss(model: torch.nn.Module) -> torch.Tensor:
    """Return the scalar 0.0 placeholder used when a loss channel is skipped.

    The tensor lives on the model's device and has ``requires_grad=True``.
    """
    return torch.tensor(0.0, device=_device_of(model), requires_grad=True)


__all__ = ["ComposerReplicationTrainer"]