# Commit e5add15 (author: Codeseys):
#   Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary
#   loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
"""composer_trainer.py — TRL GRPOTrainer subclass with SDPO + trace-replay channels.
Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A".
Verified extension point: GRPOTrainer._compute_loss(model, inputs)
(DeepWiki audit of huggingface/trl, 2026-05-25).
Total loss:
total_loss = grpo_loss
+ alpha_sdpo * sdpo_kl_at_error_turns
+ beta_replay * trace_replay_dpo_loss
Where:
- grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches).
- sdpo_kl_at_error_turns is generalized_jsd_loss between student's logits and
teacher's (= same-model-with-hint-context) logits, masked to error-turn tokens only.
- trace_replay_dpo_loss is DPO loss over (chosen, rejected) pairs derived from
disagreements between N external teachers and the student.
The data collator (data_collator.py) is responsible for:
- Detecting error sites in the rollout and constructing ctx_teacher = ctx_student + hint.
- Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere).
- Loading DPO pairs from the trace-replay output (see teacher_replay.py).
- Precomputing reference-policy logprobs for DPO.
"""
from __future__ import annotations
import logging
from typing import Any
import torch
import torch.nn.functional as F
# These imports work when TRL is installed — they're not skeleton imports.
# When TRL is missing we fall back to `object` so the module still imports
# (e.g. for documentation generation) but raise a clear ImportError at
# instantiation time rather than the cryptic `object.__init__()` error.
try:
from trl import GRPOTrainer # type: ignore
_TRL_AVAILABLE = True
except ImportError: # pragma: no cover — only hit in unit-test stubs without TRL
GRPOTrainer = object # type: ignore — fallback so module imports without TRL
_TRL_AVAILABLE = False
from composer_replication.opsd import generalized_jsd_loss
logger = logging.getLogger(__name__)
class ComposerReplicationTrainer(GRPOTrainer):  # type: ignore[misc, valid-type]
    """TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.

    Per step::

        total = grpo_loss + alpha_sdpo * sdpo_kl + beta_replay * trace_replay_dpo

    Args (in addition to GRPOTrainer's):
        alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `sdpo_loss_mask` and `ctx_teacher_input_ids` columns.
        beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `dpo_chosen_input_ids` / `dpo_rejected_input_ids` etc.
        sdpo_jsd_beta: beta param of generalized_jsd_loss, must lie in [0, 1]
            (0=KL(teacher||student), 0.5=JSD, 1=KL(student||teacher) per
            upstream OPSD convention; see composer_replication/opsd.py).
        sdpo_temperature: temperature for SDPO loss, must be > 0; SDPO paper
            uses 1.0.
        sdpo_token_clip: per-token JSD clip for stability; None = no clip.
        replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).

    Optional batch columns (forwarded only when present):
        attention_mask, ctx_teacher_attention_mask,
        dpo_chosen_attention_mask, dpo_rejected_attention_mask.

    Raises:
        ImportError: if TRL is not installed.
        ValueError: if sdpo_jsd_beta is outside [0, 1] or sdpo_temperature <= 0.
    """

    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.0,
        beta_replay: float = 0.0,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ):
        if not _TRL_AVAILABLE:
            raise ImportError(
                "ComposerReplicationTrainer requires TRL. Install with "
                "`pip install -e .[train]`."
            )
        # Validate before the (expensive) parent init so a bad config fails fast:
        # beta is a mixture weight and the temperature rescales logits, so values
        # outside these ranges are meaningless.
        if not 0.0 <= sdpo_jsd_beta <= 1.0:
            raise ValueError(f"sdpo_jsd_beta must be in [0, 1], got {sdpo_jsd_beta!r}")
        if not sdpo_temperature > 0.0:
            raise ValueError(f"sdpo_temperature must be > 0, got {sdpo_temperature!r}")
        super().__init__(*args, **kwargs)
        self.alpha_sdpo = alpha_sdpo
        self.beta_replay = beta_replay
        self.sdpo_jsd_beta = sdpo_jsd_beta
        self.sdpo_temperature = sdpo_temperature
        self.sdpo_token_clip = sdpo_token_clip
        self.replay_dpo_beta = replay_dpo_beta
        # Global step of the last fallback `self.log()` call; used to emit at most
        # one per-channel log entry per global step.
        self._last_channel_log_step: int | None = None

    # ----------------------------------------------------------------------
    # Loss override (the integration core)
    # ----------------------------------------------------------------------
    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Override: total_loss = grpo + α*sdpo + β*replay.

        Each auxiliary channel returns a zero tensor when it is disabled or the
        batch carries no data for it, so the total then reduces to the GRPO loss.
        """
        # Channel 1: standard GRPO loss
        grpo_loss = super()._compute_loss(model, inputs)
        # Channel 2: SDPO hint-distill at error sites
        sdpo_kl = self._compute_sdpo_loss(model, inputs)
        # Channel 3: trace-replay DPO from teacher disagreement
        replay_dpo = self._compute_trace_replay_loss(model, inputs)

        total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo
        self._log_channel_losses(model, grpo_loss, sdpo_kl, replay_dpo, total)
        return total

    def _log_channel_losses(
        self,
        model: torch.nn.Module,
        grpo_loss: torch.Tensor,
        sdpo_kl: torch.Tensor,
        replay_dpo: torch.Tensor,
        total: torch.Tensor,
    ) -> None:
        """Record per-channel loss components so channels can be ablated post-hoc.

        Preferred path: append to the parent's ``_metrics[mode]`` buffers, which
        GRPOTrainer averages and emits from its own ``log()`` at logging steps.
        Calling ``self.log()`` from inside the loss instead routes through that
        same override, which (in recent TRL versions) also flushes and clears
        GRPO's partially accumulated metrics mid-interval. If no such buffer is
        found, fall back to ``self.log()`` at most once per logging step.
        """
        mode = "train" if model.training else "eval"
        metrics = getattr(self, "_metrics", None)
        # NOTE(review): assumes TRL's `_metrics` layout of mode -> {key: list};
        # confirm against the pinned TRL version.
        bucket = metrics.get(mode) if isinstance(metrics, dict) else None
        use_buffer = isinstance(bucket, dict)
        # Check the gate before converting tensors, to avoid needless device syncs.
        if not use_buffer and not self._is_channel_log_step():
            return

        values = {
            "loss/grpo": float(grpo_loss.detach()),
            "loss/sdpo_kl": float(sdpo_kl.detach()),
            "loss/trace_replay_dpo": float(replay_dpo.detach()),
            "loss/total": float(total.detach()),
            "loss/alpha_sdpo": self.alpha_sdpo,
            "loss/beta_replay": self.beta_replay,
        }
        if use_buffer:
            for key, value in values.items():
                bucket.setdefault(key, []).append(value)
        else:
            self._last_channel_log_step = self.state.global_step
            self.log(values)  # type: ignore[attr-defined]

    def _is_channel_log_step(self) -> bool:
        """Return True when the fallback logger should emit at the current step.

        True only on multiples of the logging interval, and at most once per
        global step: the loss may be computed several times per step, e.g.
        under gradient accumulation.
        """
        state = getattr(self, "state", None)
        args = getattr(self, "args", None)
        if state is None or args is None:
            return False
        log_steps = getattr(args, "logging_steps", 50)
        if isinstance(log_steps, float) and 0.0 < log_steps < 1.0:
            # A fractional logging_steps is a ratio of total training steps; the
            # Trainer stores the resolved absolute interval on its state.
            log_steps = getattr(state, "logging_steps", 0)
        log_steps = int(log_steps or 0)
        if log_steps <= 0:  # also guards the modulo below against division by zero
            return False
        step = state.global_step
        return step % log_steps == 0 and step != getattr(
            self, "_last_channel_log_step", None
        )

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill
    # ----------------------------------------------------------------------
    def _compute_sdpo_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Compute generalized_jsd_loss between student and hint-conditioned teacher.

        Both come from the SAME model. The teacher just has the hint inserted
        into its context and is evaluated without gradients.

        Returns a zero tensor (channel skipped) when alpha_sdpo is 0, the batch
        has no or empty ``ctx_teacher_input_ids``, or the student and teacher
        logits differ in shape.
        """
        teacher_ids = inputs.get("ctx_teacher_input_ids")
        if self.alpha_sdpo == 0.0 or teacher_ids is None or teacher_ids.numel() == 0:
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # Student forward (with grad, on the original-context input).
        # NOTE(review): assumes the batch carries `input_ids`; confirm this key
        # survives GRPOTrainer's input preparation.
        student_logits = self._forward_logits(
            model, inputs["input_ids"], inputs.get("attention_mask")
        )
        # Teacher forward: stop-gradient target from the same model. The model is
        # not switched to eval mode, so any active dropout also applies here.
        with torch.no_grad():
            teacher_logits = self._forward_logits(
                model, teacher_ids, inputs.get("ctx_teacher_attention_mask")
            )

        # ctx_teacher and ctx_student must be the SAME LENGTH at the post-hint
        # section so logits align position-by-position; the collator pads/aligns.
        if student_logits.shape != teacher_logits.shape:
            logger.warning(
                "SDPO logit shape mismatch: student=%s vs teacher=%s. "
                "Skipping SDPO loss for this step. Check the data collator's "
                "alignment — the post-hint section must have identical token-counts.",
                student_logits.shape, teacher_logits.shape,
            )
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # NOTE(review): the 0/1 error-turn mask is passed as `labels`. If
        # opsd.generalized_jsd_loss follows the TRL GKD convention (ignore where
        # labels == -100), a 0/1 mask leaves EVERY position active. Confirm the
        # convention opsd.py expects. If `sdpo_loss_mask` is absent, None is
        # passed and no masking is requested.
        return generalized_jsd_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            labels=inputs.get("sdpo_loss_mask"),  # error-turn token mask
            beta=self.sdpo_jsd_beta,
            temperature=self.sdpo_temperature,
            token_clip=self.sdpo_token_clip,
            reduction="batchmean",
        )

    @staticmethod
    def _forward_logits(
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run ``model`` on ``input_ids`` and return its ``.logits``.

        ``attention_mask`` is forwarded only when given, so models whose forward
        does not accept it keep working exactly as before.
        """
        if attention_mask is None:
            return model(input_ids=input_ids).logits
        return model(input_ids=input_ids, attention_mask=attention_mask).logits

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO
    # ----------------------------------------------------------------------
    def _compute_trace_replay_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Standard DPO loss using (chosen, rejected) pairs from teacher disagreement.

        DPO loss formula (Rafailov et al. 2023):
            L = -log σ(β · (logπ(chosen) - logπ_ref(chosen)
                            - logπ(rejected) + logπ_ref(rejected)))
        where logπ_ref are precomputed by the data collator using the
        reference (init student) policy.

        Returns a zero tensor when beta_replay is 0 or the batch has no or empty
        ``dpo_chosen_input_ids``.
        """
        chosen_ids = inputs.get("dpo_chosen_input_ids")
        if self.beta_replay == 0.0 or chosen_ids is None or chosen_ids.numel() == 0:
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # Forward passes for chosen and rejected; sum logprobs over response tokens.
        chosen_logprobs = self._sequence_logprobs(
            model,
            chosen_ids,
            inputs["dpo_chosen_response_mask"],
            attention_mask=inputs.get("dpo_chosen_attention_mask"),
        )
        rejected_logprobs = self._sequence_logprobs(
            model,
            inputs["dpo_rejected_input_ids"],
            inputs["dpo_rejected_response_mask"],
            attention_mask=inputs.get("dpo_rejected_attention_mask"),
        )
        # Reference logprobs come precomputed from the collator. Align them with
        # the policy logprobs in case they were left on another device (e.g. CPU).
        ref_chosen_logprobs = torch.as_tensor(
            inputs["dpo_chosen_ref_logprobs"],
            device=chosen_logprobs.device,
            dtype=chosen_logprobs.dtype,
        )
        ref_rejected_logprobs = torch.as_tensor(
            inputs["dpo_rejected_ref_logprobs"],
            device=rejected_logprobs.device,
            dtype=rejected_logprobs.dtype,
        )
        logits = self.replay_dpo_beta * (
            (chosen_logprobs - ref_chosen_logprobs)
            - (rejected_logprobs - ref_rejected_logprobs)
        )
        return -F.logsigmoid(logits).mean()

    @staticmethod
    def _sequence_logprobs(
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        response_mask: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Sum the logprobs of response tokens given the prompt prefix.

        Standard DPO accounting: only response tokens (nonzero response_mask)
        are scored, not prompt or padding tokens.

        Args:
            model: causal LM whose forward output exposes ``.logits``.
            input_ids: prompt + response token ids, presumably (batch, seq).
            response_mask: same layout as input_ids; nonzero at response tokens.
            attention_mask: optional padding mask forwarded to the model.

        Returns:
            One summed logprob per row, in at least float32 precision.
        """
        logits = ComposerReplicationTrainer._forward_logits(model, input_ids, attention_mask)
        # Shift for next-token prediction: logits[t] predicts input_ids[t+1].
        logits = logits[:, :-1, :]
        targets = input_ids[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Accumulate in at least float32 so long responses keep precision under
        # bf16/fp16 models.
        token_logprobs = token_logprobs.to(
            torch.promote_types(token_logprobs.dtype, torch.float32)
        )
        # Zero out prompt + padding with masked_fill instead of multiplying by
        # the mask: a -inf logprob at a masked position would become NaN (0 * -inf).
        keep = response_mask[:, 1:].bool()
        return token_logprobs.masked_fill(~keep, 0.0).sum(dim=-1)
def _device_of(model: torch.nn.Module) -> torch.device:
    """Return the device of the model's first parameter, falling back sensibly.

    ``parameters()`` / ``buffers()`` recurse into submodules, so wrapper modules
    resolve to the device of the wrapped weights. A module with no parameters
    falls back to its first buffer, then to CPU, instead of raising a cryptic
    ``StopIteration``.
    """
    param = next(model.parameters(), None)
    if param is not None:
        return param.device
    buffer = next(model.buffers(), None)
    if buffer is not None:
        return buffer.device
    return torch.device("cpu")
# Public API of this module: only the trainer is exported; helpers stay private.
__all__ = [
    "ComposerReplicationTrainer",
]