Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """composer_trainer.py — TRL GRPOTrainer subclass with SDPO + trace-replay channels. | |
| Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A". | |
| Verified extension point: GRPOTrainer._compute_loss(model, inputs) | |
| (DeepWiki audit of huggingface/trl, 2026-05-25). | |
| Total loss: | |
| total_loss = grpo_loss | |
| + alpha_sdpo * sdpo_kl_at_error_turns | |
| + beta_replay * trace_replay_dpo_loss | |
| Where: | |
| - grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches). | |
| - sdpo_kl_at_error_turns is generalized_jsd_loss between student's logits and | |
| teacher's (= same-model-with-hint-context) logits, masked to error-turn tokens only. | |
| - trace_replay_dpo_loss is the DPO loss over (chosen, rejected) pairs derived from | |
| disagreements between N external teachers and the student. | |
| The data collator (data_collator.py) is responsible for: | |
| - Detecting error sites in the rollout and constructing ctx_teacher = ctx_student + hint. | |
| - Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere). | |
| - Loading DPO pairs from the trace-replay output (see teacher_replay.py). | |
| - Precomputing reference-policy logprobs for DPO. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Any | |
| import torch | |
| import torch.nn.functional as F | |
| # These imports work when TRL is installed — they're not skeleton imports. | |
| # When TRL is missing we fall back to `object` so the module still imports | |
| # (e.g. for documentation generation) but raise a clear ImportError at | |
| # instantiation time rather than the cryptic `object.__init__()` error. | |
| try: | |
| from trl import GRPOTrainer # type: ignore | |
| _TRL_AVAILABLE = True | |
| except ImportError: # pragma: no cover — only hit in unit-test stubs without TRL | |
| GRPOTrainer = object # type: ignore — fallback so module imports without TRL | |
| _TRL_AVAILABLE = False | |
| from composer_replication.opsd import generalized_jsd_loss | |
| logger = logging.getLogger(__name__) | |
class ComposerReplicationTrainer(GRPOTrainer):  # type: ignore[misc, valid-type]
    """TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.

    Each call to ``_compute_loss`` combines three channels::

        total = grpo_loss + alpha_sdpo * sdpo_kl + beta_replay * replay_dpo

    A channel whose weight is 0.0, or whose input columns are absent/empty in
    the batch, contributes a zero tensor and triggers no extra forward pass.

    Args (in addition to GRPOTrainer's):
        alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `sdpo_loss_mask` and `ctx_teacher_input_ids` columns.
        beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `dpo_chosen_input_ids` / `dpo_rejected_input_ids`, the matching
            `dpo_*_response_mask` columns and the precomputed
            `dpo_*_ref_logprobs` columns.
        sdpo_jsd_beta: beta param of generalized_jsd_loss
            (0=KL(teacher||student), 0.5=JSD, 1=KL(student||teacher) per
            upstream OPSD convention; see composer_replication/opsd.py).
        sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
        sdpo_token_clip: per-token JSD clip for stability; None = no clip.
        replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).

    Raises:
        ImportError: at instantiation time when TRL is not installed.
    """

    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.0,
        beta_replay: float = 0.0,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ):
        """Forward ``*args``/``**kwargs`` to GRPOTrainer and store channel settings."""
        # Fail early with an actionable message instead of the cryptic
        # `object.__init__()` error the fallback base class would produce.
        if not _TRL_AVAILABLE:
            raise ImportError(
                "ComposerReplicationTrainer requires TRL. Install with "
                "`pip install -e .[train]`."
            )
        super().__init__(*args, **kwargs)
        self.alpha_sdpo = alpha_sdpo
        self.beta_replay = beta_replay
        self.sdpo_jsd_beta = sdpo_jsd_beta
        self.sdpo_temperature = sdpo_temperature
        self.sdpo_token_clip = sdpo_token_clip
        self.replay_dpo_beta = replay_dpo_beta

    # ----------------------------------------------------------------------
    # Loss override (the integration core)
    # ----------------------------------------------------------------------
    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Override: total_loss = grpo + α*sdpo + β*replay.

        Args:
            model: the policy being trained.
            inputs: the collated batch; handed unchanged to the parent loss
                and to both auxiliary channels.

        Returns:
            The weighted sum of the three channel losses.
        """
        # Channel 1: standard GRPO loss
        grpo_loss = super()._compute_loss(model, inputs)
        # Channel 2: SDPO hint-distill at error sites
        sdpo_kl = self._compute_sdpo_loss(model, inputs)
        # Channel 3: trace-replay DPO from teacher disagreement
        replay_dpo = self._compute_trace_replay_loss(model, inputs)

        total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo

        # Log per-channel components (so we can ablate post-hoc).
        if hasattr(self, "state") and getattr(self, "args", None) is not None:
            log_steps = getattr(self.args, "logging_steps", 50)
            # A falsy interval (0 / None) means "no step-based logging";
            # skipping here also avoids a ZeroDivisionError on the modulo.
            if log_steps and self.state.global_step % log_steps == 0:
                self.log({  # type: ignore[attr-defined]
                    "loss/grpo": float(grpo_loss.detach()),
                    "loss/sdpo_kl": float(sdpo_kl.detach()),
                    "loss/trace_replay_dpo": float(replay_dpo.detach()),
                    "loss/total": float(total.detach()),
                    "loss/alpha_sdpo": self.alpha_sdpo,
                    "loss/beta_replay": self.beta_replay,
                })
        return total

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill
    # ----------------------------------------------------------------------
    def _compute_sdpo_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Compute generalized_jsd_loss between student and hint-conditioned teacher.

        Both logit tensors come from the SAME model: the student pass runs on
        ``inputs["input_ids"]`` with gradients, the teacher pass runs on
        ``inputs["ctx_teacher_input_ids"]`` under ``torch.no_grad()``.

        Returns a zero tensor (on the model's device) when:
          - ``alpha_sdpo`` is 0.0,
          - the batch has no ``ctx_teacher_input_ids`` column or it is empty
            (the data collator emits an empty tensor when there are no error
            sites), or
          - student and teacher logits differ in shape (a warning is logged).
        """
        if (
            self.alpha_sdpo == 0.0
            or "ctx_teacher_input_ids" not in inputs
            or inputs["ctx_teacher_input_ids"].numel() == 0
        ):
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # Student forward (with grad, on the original-context input)
        student_logits = model(input_ids=inputs["input_ids"]).logits
        # Teacher forward (no grad — same model, hint-conditioned context)
        with torch.no_grad():
            teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits

        # NOTE: ctx_teacher and ctx_student must be the SAME LENGTH at the
        # post-hint section so logits align position-by-position. The data
        # collator is expected to pad/align; here we only verify the shapes
        # and skip the channel for this step when they disagree.
        if student_logits.shape != teacher_logits.shape:
            logger.warning(
                "SDPO logit shape mismatch: student=%s vs teacher=%s. "
                "Skipping SDPO loss for this step. Check the data collator's "
                "alignment — the post-hint section must have identical token-counts.",
                student_logits.shape, teacher_logits.shape,
            )
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        return generalized_jsd_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            labels=inputs.get("sdpo_loss_mask"),  # error-turn token mask (may be None)
            beta=self.sdpo_jsd_beta,
            temperature=self.sdpo_temperature,
            token_clip=self.sdpo_token_clip,
            reduction="batchmean",
        )

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO
    # ----------------------------------------------------------------------
    def _compute_trace_replay_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Standard DPO loss using (chosen, rejected) pairs from teacher disagreement.

        DPO loss formula (Rafailov et al. 2023):
            L = -log σ(β · (logπ(chosen) - logπ_ref(chosen)
                            - logπ(rejected) + logπ_ref(rejected)))

        Where logπ_ref are read from ``inputs["dpo_chosen_ref_logprobs"]`` and
        ``inputs["dpo_rejected_ref_logprobs"]`` (precomputed by the data
        collator using the reference policy), and β is ``replay_dpo_beta``.

        Returns a zero tensor when ``beta_replay`` is 0.0 or the batch carries
        no (or an empty) ``dpo_chosen_input_ids`` column; otherwise the mean
        loss over the pairs in the batch.
        """
        if (
            self.beta_replay == 0.0
            or "dpo_chosen_input_ids" not in inputs
            or inputs["dpo_chosen_input_ids"].numel() == 0
        ):
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # Forward passes for chosen and rejected, gather logprobs at response tokens
        chosen_logprobs = self._sequence_logprobs(
            model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"]
        )
        rejected_logprobs = self._sequence_logprobs(
            model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"]
        )
        ref_chosen_logprobs = inputs["dpo_chosen_ref_logprobs"]
        ref_rejected_logprobs = inputs["dpo_rejected_ref_logprobs"]

        logits = self.replay_dpo_beta * (
            (chosen_logprobs - ref_chosen_logprobs)
            - (rejected_logprobs - ref_rejected_logprobs)
        )
        return -F.logsigmoid(logits).mean()

    # Declared static: the helper takes `model` as its first argument and is
    # invoked as `self._sequence_logprobs(model, ...)`. Without the decorator
    # the bound call would pass `self` as `model` and fail with a TypeError.
    @staticmethod
    def _sequence_logprobs(
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        response_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Sum logprob of response tokens given the prompt prefix.

        Standard DPO accounting: we only score the response tokens (where
        response_mask == 1), not the prompt tokens.

        Args:
            model: callable returning an object with a ``.logits`` attribute
                when invoked as ``model(input_ids=...)``.
            input_ids: token ids, indexed as (batch, seq).
            response_mask: mask indexed like ``input_ids``; its first column is
                dropped to line up with the shifted targets.

        Returns:
            One summed log-probability per row of ``input_ids``.
        """
        outputs = model(input_ids=input_ids)
        # Shift for next-token prediction: logits[t] predicts input_ids[t+1]
        logits = outputs.logits[:, :-1, :]
        targets = input_ids[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Mask out prompt + padding; sum response-token logprobs
        masked = token_logprobs * response_mask[:, 1:].float()
        return masked.sum(dim=-1)
| def _device_of(model: torch.nn.Module) -> torch.device: | |
| """Return the device of any parameter of the model — robust to FSDP/DDP wrappers.""" | |
| return next(model.parameters()).device | |
| __all__ = ["ComposerReplicationTrainer"] | |