Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
composer-replication-framework/spikes/005-integrated-trainer-skeleton/trl_path/data_collator.py
| """data_collator.py — ComposerDataCollator: raw trace → trainer-ready batch. | |
| Pipeline: | |
| 1. Take a frozen agentic trace + N-teacher DPO pairs (from spike 002 + 003). | |
| 2. Tokenize each turn of the trace. | |
| 3. Detect error sites (turns where a tool call failed) using a configurable predicate. | |
| 4. At each error site, build ctx_teacher = ctx_student with hint inserted at the error-turn boundary. | |
| 5. Pad/align ctx_student and ctx_teacher so SDPO logits compare position-by-position. | |
| 6. Construct sdpo_loss_mask = 1 at post-hint tokens of the error turn, ignore_index (-100 by default) elsewhere. | |
| 7. Tokenize DPO chosen/rejected pairs, build response masks, leave ref_logprobs as a precompute step. | |
| The output dict is what `ComposerReplicationTrainer._compute_loss` expects in its | |
| `inputs` argument. See `trl_path/composer_trainer.py` for the consumer side. | |
| Architectural note (verified via spike 005 test_opsd_loss.py): generalized_jsd_loss | |
| requires student_logits and teacher_logits to have the SAME (B, T, V) shape — that's | |
| why we pad/align here rather than inside the loss function. The post-hint section of | |
| ctx_teacher must have token-by-token alignment with the same section of ctx_student. | |
| """ | |
from __future__ import annotations

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, TypedDict

import torch
| # --------------------------------------------------------------------------- | |
| # Types | |
| # --------------------------------------------------------------------------- | |
class TraceTurn(TypedDict, total=False):
    """A single turn of an agentic trace. Every key is optional (``total=False``).

    Keys:
        role: who produced the turn — "user", "assistant", or "tool".
        content: the message text, or the tool's result.
        tool_call: the parsed tool call, when the assistant issued one.
        tool_error: error_kind reported by the environment, e.g. "tool_not_found".
        error_meta: extra context for the hint generator (available_tools, etc.).
    """

    role: str
    content: str
    tool_call: dict | None
    tool_error: str | None
    error_meta: dict
class TraceExample(TypedDict, total=False):
    """One training example: a trace plus optional DPO pairs. Every key is optional.

    Keys:
        trace_id: identifier of the trace.
        turns: the ordered turns of the trace.
        final_reward: RLVR scalar (test-pass etc.) observed at trajectory end.
        dpo_pairs: pairs produced by teacher_replay.extract_dpo_pairs, if any.
    """

    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: list[dict] | None
| # --------------------------------------------------------------------------- | |
| # Tokenizer protocol — duck-typed against HF AutoTokenizer | |
| # --------------------------------------------------------------------------- | |
class TokenizerLike(Protocol):
    """Structural protocol for the tokenizer members the collator uses.

    HuggingFace ``AutoTokenizer`` instances satisfy it in the typical case, and
    lightweight stubs with the same members work for unit tests. Declaring it as a
    ``typing.Protocol`` lets static type checkers accept any object that has these
    members, without requiring it to inherit from this class.
    """

    pad_token_id: int

    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]:  # pragma: no cover
        """Tokenize ``text`` and return a mapping that contains an ``"input_ids"`` entry."""
        ...

    def apply_chat_template(  # pragma: no cover
        self, messages: list[dict], **kwargs: Any
    ) -> str | list[int]:
        """Render chat ``messages`` to text, or to token ids when asked to tokenize."""
        ...
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
@dataclass
class CollatorConfig:
    """Tunables for ComposerDataCollator.

    This is a dataclass, so callers can override any field by keyword, e.g.
    ``CollatorConfig(max_seq_len=8192, hint_generator=my_hints)``.
    """

    # Length cap for GRPO trace and SDPO teacher sequences (longer ones are right-truncated).
    max_seq_len: int = 4096
    # Length cap for each DPO chosen/rejected sequence (prompt + response).
    max_dpo_seq_len: int = 2048
    # Token id used to right-pad input-id tensors.
    pad_token_id: int = 0
    # Standard HF "ignore in loss" sentinel, written at non-loss positions of sdpo_loss_mask.
    ignore_index: int = -100

    # --- SDPO behavior ---
    enable_sdpo: bool = True
    # Callable (error_kind, error_meta) -> hint_text, or None to skip that error site.
    # No SDPO fields are produced while this is None.
    hint_generator: Callable[[str, dict], str | None] | None = None

    # --- Trace-replay DPO behavior ---
    enable_replay_dpo: bool = True

    # --- Reward shaping ---
    # TraceExample key read as the per-example RLVR reward.
    rlvr_reward_key: str = "final_reward"
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| def _is_error_turn(turn: TraceTurn) -> bool: | |
| """Predicate: is this turn an error site that should trigger SDPO?""" | |
| return turn.get("tool_error") is not None | |
| def _build_chat_messages(turns: Sequence[TraceTurn]) -> list[dict]: | |
| """Convert TraceTurns to OpenAI-style chat messages for tokenizer.apply_chat_template.""" | |
| return [ | |
| {"role": t["role"], "content": t["content"]} | |
| for t in turns if t.get("content") | |
| ] | |
| def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]: | |
| """Right-pad with pad_id, or right-truncate to target_len.""" | |
| if len(seq) >= target_len: | |
| return seq[:target_len] | |
| return seq + [pad_id] * (target_len - len(seq)) | |
| # --------------------------------------------------------------------------- | |
| # The collator | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ComposerDataCollator:
    """Build trainer-ready batches from raw traces + optional DPO pairs.

    Usage:
        collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
        batch = collator([trace_example_0, trace_example_1, ...])
        # batch is a dict[str, torch.Tensor] ready for ComposerReplicationTrainer

    The dict contains:
        # Channel 1 (GRPO/RLVR — handled by the parent GRPOTrainer)
        - input_ids: (B, T_max)
        - attention_mask: (B, T_max), 1 over real tokens, 0 over right padding
        - response_mask: (B, T_max), 1 over assistant-turn tokens
        - rewards: (B,)
        # Channel 2 (SDPO hint-distill) — present when enable_sdpo is set, a
        # hint_generator is configured, and at least one example has an error
        # site for which the generator returned a non-empty hint
        - ctx_teacher_input_ids: (B, T_teacher_max)
        - sdpo_loss_mask: (B, T_teacher_max), 1 at post-hint error-turn tokens,
          config.ignore_index elsewhere (including padding)
        # Channel 3 (trace-replay DPO) — present when any example has dpo_pairs
        - dpo_chosen_input_ids: (B', T_dpo)
        - dpo_chosen_response_mask: (B', T_dpo)
        - dpo_rejected_input_ids: (B', T_dpo)
        - dpo_rejected_response_mask: (B', T_dpo)
        # ref_logprobs are NOT computed here — the trainer's reference-policy
        # forward pass at training time produces them.

    Student (``input_ids``) and teacher (``ctx_teacher_input_ids``) rows are both
    built by concatenating per-turn ``_tokenize_text`` ids. The teacher is
    therefore the student with hint tokens spliced in before each hinted error
    turn. Because of those extra tokens, T_teacher_max can exceed T_max.
    """

    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Collate ``batch`` into the tensor dict described in the class docstring.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            # Otherwise max() over the per-example lengths fails with an opaque
            # "empty sequence" error.
            raise ValueError("ComposerDataCollator received an empty batch")

        out: dict[str, torch.Tensor] = {}
        # --- Channel 1: GRPO core fields (always present) ---
        out.update(self._build_grpo_fields(batch))
        # --- Channel 2: SDPO hint-distill fields ---
        if self.config.enable_sdpo:
            sdpo = self._build_sdpo_fields(batch)
            if sdpo is not None:
                out.update(sdpo)
        # --- Channel 3: trace-replay DPO fields ---
        if self.config.enable_replay_dpo:
            dpo = self._build_dpo_fields(batch)
            if dpo is not None:
                out.update(dpo)
        return out

    # ----------------------------------------------------------------------
    # Channel 1: standard GRPO inputs
    # ----------------------------------------------------------------------
    def _build_grpo_fields(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Tokenize each trace and pad the batch for the GRPO/RLVR channel.

        Rows are right-padded (or right-truncated) to
        ``min(config.max_seq_len, longest trace)``. An example that lacks the
        reward key gets a reward of 0.0.
        """
        input_ids_list: list[list[int]] = []
        response_masks_list: list[list[int]] = []
        rewards: list[float] = []
        for ex in batch:
            ids, resp_mask = self._tokenize_trace(ex.get("turns") or [])
            input_ids_list.append(ids)
            response_masks_list.append(resp_mask)
            rewards.append(float(ex.get(self.config.rlvr_reward_key, 0.0)))

        pad_id = self.config.pad_token_id
        max_len = min(self.config.max_seq_len, max(len(s) for s in input_ids_list))
        input_ids = torch.tensor(
            [_pad_or_truncate(s, max_len, pad_id) for s in input_ids_list],
            dtype=torch.long,
        )
        response_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, 0) for m in response_masks_list],
            dtype=torch.long,
        )
        # Derive the attention mask from each row's true length rather than from
        # `input_ids != pad_id`: a real token whose id equals pad_id (e.g. id 0
        # under the default config) must still be attended to.
        attention_mask = torch.tensor(
            [_pad_or_truncate([1] * len(s), max_len, 0) for s in input_ids_list],
            dtype=torch.long,
        )
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "response_mask": response_mask,
            "rewards": torch.tensor(rewards, dtype=torch.float),
        }

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill inputs
    # ----------------------------------------------------------------------
    def _build_sdpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Build ctx_teacher + sdpo_loss_mask for the SDPO hint-distill channel.

        Returns None when no hint generator is configured, or when no example in
        the batch received a hint. Otherwise every example, including error-free
        ones, contributes one teacher row. Rows are padded with
        ``config.pad_token_id`` (ids) and ``config.ignore_index`` (mask), or
        truncated, to ``min(config.max_seq_len, longest teacher row)``.
        """
        if self.config.hint_generator is None:
            return None  # nothing to do without a hint generator

        ctx_teacher_list: list[list[int]] = []
        sdpo_mask_list: list[list[int]] = []
        any_error_sites = False
        for ex in batch:
            ctx_teacher_ids, sdpo_mask, has_errors = self._build_hint_injected_trace(
                ex.get("turns") or []
            )
            ctx_teacher_list.append(ctx_teacher_ids)
            sdpo_mask_list.append(sdpo_mask)
            any_error_sites = any_error_sites or has_errors

        if not any_error_sites:
            return None  # batch has no error sites — SDPO is a no-op for this step

        max_len = min(self.config.max_seq_len, max(len(s) for s in ctx_teacher_list))
        ctx_teacher = torch.tensor(
            [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in ctx_teacher_list],
            dtype=torch.long,
        )
        sdpo_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, self.config.ignore_index) for m in sdpo_mask_list],
            dtype=torch.long,
        )
        return {
            "ctx_teacher_input_ids": ctx_teacher,
            "sdpo_loss_mask": sdpo_mask,
        }

    def _build_hint_injected_trace(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], bool]:
        """Rebuild the student token stream with a hint spliced in before each error turn.

        Each turn's content is tokenized the same way ``_tokenize_trace`` does it
        for the student: ``_tokenize_text`` per turn, with empty turns skipped.
        Outside the inserted hint tokens, the teacher ids therefore equal the
        student ids. Ids and mask values are appended together, so the mask is
        aligned with the ids token for token by construction.

        Mask values: 1 over the content tokens of an error turn that received a
        hint; ``config.ignore_index`` over the hint tokens and all other turns.

        Returns:
            (ctx_teacher_ids, sdpo_loss_mask, any_error_sites), where
            any_error_sites is True iff at least one hint was inserted.
        """
        hint_generator = self.config.hint_generator
        if hint_generator is None:
            # Caller responsibility — short-circuited by the dispatch.
            return [], [], False

        ignore = self.config.ignore_index
        teacher_ids: list[int] = []
        sdpo_mask: list[int] = []
        any_errors = False
        for turn in turns:
            content_is_loss = False
            if _is_error_turn(turn):
                hint_text = hint_generator(
                    turn.get("tool_error", "unknown"),
                    turn.get("error_meta", {}),
                )
                if hint_text:
                    any_errors = True
                    # The hint precedes the error turn and is context only (not in loss).
                    hint_ids = self._tokenize_text(hint_text)
                    teacher_ids.extend(hint_ids)
                    sdpo_mask.extend([ignore] * len(hint_ids))
                    content_is_loss = True  # post-hint tokens = loss
            # Non-error turns, and error turns whose hint was None/empty, pass through.
            content = turn.get("content")
            if content:
                content_ids = self._tokenize_text(content)
                teacher_ids.extend(content_ids)
                sdpo_mask.extend([1 if content_is_loss else ignore] * len(content_ids))
        return teacher_ids, sdpo_mask, any_errors

    def _build_segment_mask(
        self, segments: Sequence[tuple[bool, str]]
    ) -> list[int]:
        """For each (is_loss, text) segment, tokenize and emit per-token mask values.

        Loss-active tokens get 1; non-loss tokens get ``config.ignore_index``.
        """
        out: list[int] = []
        for is_loss, text in segments:
            seg_ids = self._tokenize_text(text)
            mask_value = 1 if is_loss else self.config.ignore_index
            out.extend([mask_value] * len(seg_ids))
        return out

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO inputs
    # ----------------------------------------------------------------------
    def _build_dpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Tokenize chosen/rejected pairs from teacher disagreement.

        DPO accounting requires:
          - chosen_input_ids = prompt + chosen_response
          - rejected_input_ids = prompt + rejected_response
          - response masks marking response tokens (1, loss-bearing) vs prompt (0)

        Returns None when the batch has no DPO pairs. All chosen and rejected rows
        share one length, ``min(config.max_dpo_seq_len, longest row)``.
        NOTE: truncation is from the right, so a prompt longer than the cap
        leaves a row with no response tokens (an all-zero response mask).
        """
        all_chosen: list[list[int]] = []
        all_rejected: list[list[int]] = []
        all_chosen_resp_mask: list[list[int]] = []
        all_rejected_resp_mask: list[list[int]] = []

        for ex in batch:
            for pair in ex.get("dpo_pairs") or []:
                prompt_ids = self._tokenize_messages(pair.get("state_messages", []))
                chosen_ids = self._tokenize_text(pair["chosen"])
                rejected_ids = self._tokenize_text(pair["rejected"])

                all_chosen.append(prompt_ids + chosen_ids)
                all_rejected.append(prompt_ids + rejected_ids)
                # response_mask is 0 over prompt, 1 over response
                all_chosen_resp_mask.append([0] * len(prompt_ids) + [1] * len(chosen_ids))
                all_rejected_resp_mask.append([0] * len(prompt_ids) + [1] * len(rejected_ids))

        if not all_chosen:
            return None  # no DPO pairs in this batch

        pad_id = self.config.pad_token_id
        max_len = min(
            self.config.max_dpo_seq_len,
            max(len(s) for s in (*all_chosen, *all_rejected)),
        )
        return {
            "dpo_chosen_input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, pad_id) for s in all_chosen],
                dtype=torch.long,
            ),
            "dpo_chosen_response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in all_chosen_resp_mask],
                dtype=torch.long,
            ),
            "dpo_rejected_input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, pad_id) for s in all_rejected],
                dtype=torch.long,
            ),
            "dpo_rejected_response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in all_rejected_resp_mask],
                dtype=torch.long,
            ),
        }

    # ----------------------------------------------------------------------
    # Tokenization helpers
    # ----------------------------------------------------------------------
    def _tokenize_trace(self, turns: Sequence[TraceTurn]) -> tuple[list[int], list[int]]:
        """Tokenize an entire trace turn by turn; return (ids, response_mask).

        response_mask is 1 over turns whose role is "assistant" (the loss-bearing
        tokens for GRPO) and 0 over every other role. Turns with missing or
        empty content are skipped.
        """
        all_ids: list[int] = []
        resp_mask: list[int] = []
        for turn in turns:
            if not turn.get("content"):
                continue
            ids = self._tokenize_text(turn["content"])
            mask_value = 1 if turn.get("role") == "assistant" else 0
            all_ids.extend(ids)
            resp_mask.extend([mask_value] * len(ids))
        return all_ids, resp_mask

    def _tokenize_text(self, text: str) -> list[int]:
        """Tokenize plain text via the tokenizer's __call__, without special tokens."""
        result = self.tokenizer(text, add_special_tokens=False)
        ids = result["input_ids"]
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # HF tokenizers often return list[list[int]] when batch-shaped; flatten if so
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)

    def _tokenize_messages(self, messages: Sequence[dict]) -> list[int]:
        """Tokenize a chat-formatted list of messages.

        Tries ``apply_chat_template(..., tokenize=True)`` first. Falls back to
        tokenizing the newline-joined message contents when no usable chat
        template exists: stubs without the method, or HF tokenizers with no
        template set, which raise ValueError.
        """
        if not messages:
            return []
        try:
            ids = self.tokenizer.apply_chat_template(
                list(messages), tokenize=True, add_generation_prompt=False
            )
            if isinstance(ids, str):
                # The template rendered text instead of ids; list(str) would yield
                # characters, so tokenize the rendered text instead.
                return self._tokenize_text(ids)
            # Some tokenizer versions return a dict-like encoding instead of bare ids.
            if hasattr(ids, "keys"):
                ids = ids["input_ids"]
            if hasattr(ids, "tolist"):
                ids = ids.tolist()
            # Flatten a batch-shaped single conversation, mirroring _tokenize_text.
            if ids and isinstance(ids[0], list):
                ids = ids[0]
            return list(ids)
        except (AttributeError, NotImplementedError, TypeError, ValueError):
            # Stub tokenizer or no chat template defined — fall back to concatenated content
            text = "\n".join(m.get("content", "") for m in messages)
            return self._tokenize_text(text)
# Names exported on star-import: the collator, its config, and the types
# describing its inputs.
__all__ = [
    "ComposerDataCollator", "CollatorConfig",
    "TraceTurn", "TraceExample", "TokenizerLike",
]