Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
composer-replication-framework / spikes /005-integrated-trainer-skeleton /tests /test_data_collator.py
| """test_data_collator.py — verify ComposerDataCollator builds correct batches. | |
| Uses a deterministic stub tokenizer so we can write expected-token-count | |
| assertions without depending on a real HF tokenizer being installed. | |
| Coverage: | |
| - GRPO core fields (input_ids, response_mask, attention_mask, rewards) | |
| - SDPO fields are skipped when no error turns are present | |
| - SDPO fields are constructed when error turns are present + hint generator returns text | |
| - SDPO loss mask correctly marks post-hint tokens with 1, others with -100 | |
| - DPO fields are skipped when no DPO pairs are present | |
| - DPO fields tokenize chosen/rejected pairs with correct response masks | |
| - Padding to max_seq_len works | |
| - Truncation to max_seq_len works | |
| Run: pytest spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py -v | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| from pathlib import Path | |
| import pytest | |
| import torch | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from trl_path.data_collator import ( # noqa: E402 | |
| CollatorConfig, | |
| ComposerDataCollator, | |
| ) | |
| # ---------------------------------------------------------------------------- | |
| # Stub tokenizer — deterministic, one integer id per whitespace-separated word | |
| # ---------------------------------------------------------------------------- | |
class StubTokenizer:
    """Deterministic word-level tokenizer stand-in for the collator tests.

    Every distinct whitespace-separated word is assigned the next unused
    integer id the first time this instance sees it, so ids depend only on the
    order in which words are encountered. Ids 0, 1 and 2 are reserved for
    ``<pad>``, ``<bos>`` and ``<eos>``; the first real word therefore gets id 3.
    """

    pad_token_id = 0

    def __init__(self) -> None:
        self._vocab: dict[str, int] = {"<pad>": 0, "<bos>": 1, "<eos>": 2}

    def _id_for(self, word: str) -> int:
        """Return the id of ``word``, registering it on first sight."""
        # ``len(self._vocab)`` is evaluated before any insertion, so a new word
        # receives exactly the next free id; a known word keeps its id.
        return self._vocab.setdefault(word, len(self._vocab))

    def __call__(self, text: str | list[str], **_kwargs):
        """Tokenize one string, or each string of a list.

        Returns a dict with a single ``"input_ids"`` entry: a list of ids for a
        string input, a list of such lists for a list input. Extra keyword
        arguments are accepted and ignored.
        """
        if isinstance(text, list):
            return {"input_ids": [self._tokenize_one(t) for t in text]}
        return {"input_ids": self._tokenize_one(text)}

    def _tokenize_one(self, text: str) -> list[int]:
        """Split ``text`` on whitespace and map each word to its id."""
        return [self._id_for(w) for w in text.split()] if text else []

    def apply_chat_template(self, messages, tokenize=True, **_kwargs):
        """Join the ``content`` of all messages with spaces and tokenize the result.

        Messages whose ``content`` is missing or ``None`` contribute an empty
        string instead of raising. With ``tokenize=False`` the joined text is
        returned unmodified; otherwise the list of token ids is returned.
        Roles and any other keyword arguments are ignored.
        """
        # ``or ""`` also covers an explicit ``content: None``, which
        # ``m.get("content", "")`` would pass through to ``str.join`` and crash.
        joined = " ".join(m.get("content") or "" for m in messages)
        if not tokenize:
            return joined
        return self._tokenize_one(joined)
| # ---------------------------------------------------------------------------- | |
| # Fixtures | |
| # ---------------------------------------------------------------------------- | |
@pytest.fixture
def tok():
    """Provide a freshly constructed ``StubTokenizer``.

    Registered as a pytest fixture because the tests in this module request
    ``tok`` by parameter name.
    """
    return StubTokenizer()
@pytest.fixture
def hint_gen():
    """Provide a hint generator that only knows the ``tool_not_found`` error.

    The returned callable takes ``(error_kind, meta)`` and yields the fixed
    string ``"HINT use a real tool"`` for ``"tool_not_found"``; every other
    error kind produces ``None``. The ``meta`` argument is ignored.

    Registered as a pytest fixture because tests request ``hint_gen`` by
    parameter name.
    """

    def _gen(error_kind: str, _meta: dict) -> str | None:
        if error_kind == "tool_not_found":
            return "HINT use a real tool"
        return None

    return _gen
@pytest.fixture
def trace_no_errors():
    """Provide a clean two-turn trace (user, assistant) with no error sites.

    Registered as a pytest fixture because tests request ``trace_no_errors``
    by parameter name. A new dict is built on every call.
    """
    return {
        "trace_id": "ok-1",
        "turns": [
            {"role": "user", "content": "task one"},
            {"role": "assistant", "content": "answer one"},
        ],
        "final_reward": 1.0,
    }
@pytest.fixture
def trace_with_error():
    """Provide a four-turn trace whose first assistant turn is a tool-call error.

    The failing turn carries ``tool_error="tool_not_found"`` plus an
    ``error_meta`` dict; it is followed by a tool message and a second
    assistant attempt.

    Registered as a pytest fixture because tests request ``trace_with_error``
    by parameter name.
    """
    return {
        "trace_id": "err-1",
        "turns": [
            {"role": "user", "content": "task two"},
            {
                "role": "assistant",
                "content": "wrong attempt",
                "tool_error": "tool_not_found",
                "error_meta": {"available_tools": ["read", "write"]},
            },
            {"role": "tool", "content": "tool not found"},
            {"role": "assistant", "content": "fixed attempt"},
        ],
        "final_reward": 0.5,
    }
@pytest.fixture
def trace_with_dpo_pairs():
    """Provide a two-turn trace that carries one chosen/rejected DPO pair.

    The single pair shares the user prompt ``"decide"`` as its state and
    contrasts ``"option A"`` (chosen) with ``"option B"`` (rejected).

    Registered as a pytest fixture because tests request
    ``trace_with_dpo_pairs`` by parameter name.
    """
    return {
        "trace_id": "dpo-1",
        "turns": [
            {"role": "user", "content": "decide"},
            {"role": "assistant", "content": "option B"},
        ],
        "final_reward": 0.0,
        "dpo_pairs": [
            {
                "state_id": "decide-1",
                "state_messages": [{"role": "user", "content": "decide"}],
                "chosen": "option A",
                "rejected": "option B",
                "n_teachers_agreeing": 3,
            }
        ],
    }
| # ---------------------------------------------------------------------------- | |
| # Channel 1: GRPO core fields | |
| # ---------------------------------------------------------------------------- | |
def test_grpo_fields_shape_and_dtype(tok, trace_no_errors):
    """Core GRPO tensors have the expected dtypes and one common shape."""
    build_batch = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    out = build_batch([trace_no_errors])

    # Token-level tensors are integer typed; rewards are floating point.
    for key in ("input_ids", "attention_mask", "response_mask"):
        assert out[key].dtype == torch.long
    assert out["rewards"].dtype == torch.float

    ids_shape = out["input_ids"].shape
    resp_shape = out["response_mask"].shape
    assert ids_shape == resp_shape
    assert resp_shape == out["attention_mask"].shape
def test_grpo_response_mask_marks_assistant_only(tok, trace_no_errors):
    """Only the assistant's tokens are flagged in ``response_mask``."""
    build_batch = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    first_row = build_batch([trace_no_errors])["response_mask"][0].tolist()
    # "task one" -> two user tokens (0); "answer one" -> two assistant tokens (1).
    leading_four = first_row[:4]
    assert leading_four == [0, 0, 1, 1]
def test_grpo_rewards_match_input(tok, trace_no_errors, trace_with_error):
    """Per-trace ``final_reward`` values come through in batch order."""
    traces = [trace_no_errors, trace_with_error]
    out = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())(traces)
    rewards = out["rewards"].tolist()
    assert rewards == [1.0, 0.5]
| # ---------------------------------------------------------------------------- | |
| # Channel 2: SDPO hint-distill fields | |
| # ---------------------------------------------------------------------------- | |
def test_sdpo_skipped_when_no_hint_generator_configured(tok, trace_with_error):
    """Without a hint generator, an error turn alone must not produce SDPO fields."""
    no_hint_cfg = CollatorConfig(hint_generator=None)
    out = ComposerDataCollator(tokenizer=tok, config=no_hint_cfg)([trace_with_error])
    for sdpo_key in ("ctx_teacher_input_ids", "sdpo_loss_mask"):
        assert sdpo_key not in out
def test_sdpo_skipped_when_no_error_turns(tok, hint_gen, trace_no_errors):
    """A configured hint generator emits nothing when the trace has no error turn."""
    hinted_cfg = CollatorConfig(hint_generator=hint_gen)
    out = ComposerDataCollator(tokenizer=tok, config=hinted_cfg)([trace_no_errors])
    for sdpo_key in ("ctx_teacher_input_ids", "sdpo_loss_mask"):
        assert sdpo_key not in out
def test_sdpo_emitted_when_error_turn_present(tok, hint_gen, trace_with_error):
    """An error turn plus a hint yields both SDPO tensors, integer typed and same shape."""
    hinted_cfg = CollatorConfig(hint_generator=hint_gen)
    out = ComposerDataCollator(tokenizer=tok, config=hinted_cfg)([trace_with_error])

    sdpo_keys = ("ctx_teacher_input_ids", "sdpo_loss_mask")
    # Presence first, so a missing key fails with a clean assertion rather than KeyError.
    for key in sdpo_keys:
        assert key in out
    for key in sdpo_keys:
        assert out[key].dtype == torch.long

    teacher_ids, loss_mask = (out[key] for key in sdpo_keys)
    assert teacher_ids.shape == loss_mask.shape
def test_sdpo_loss_mask_marks_post_hint_tokens_only(tok, hint_gen, trace_with_error):
    """``sdpo_loss_mask`` holds only 1 (loss-active) and -100 (ignore_index) values."""
    hinted_cfg = CollatorConfig(hint_generator=hint_gen)
    out = ComposerDataCollator(tokenizer=tok, config=hinted_cfg)([trace_with_error])
    mask = out["sdpo_loss_mask"][0].tolist()

    # Something must actually be trained on.
    has_active_position = 1 in mask
    assert has_active_position, f"Expected ≥1 loss-active position, got {mask}"

    # Inactive positions use ignore_index (-100); a plain 0 is not acceptable.
    only_allowed_values = set(mask) <= {1, -100}
    assert only_allowed_values, f"Mask must be {{1, -100}} only, got {set(mask)}"
def test_sdpo_skipped_when_hint_generator_returns_none(tok, trace_with_error):
    """Hint generator returns None → SDPO fields not emitted (no signal to add).

    Both SDPO keys are checked, matching the other "skipped" tests in this
    module, so a collator that drops the teacher ids but still emits a loss
    mask does not slip through.
    """
    cfg = CollatorConfig(hint_generator=lambda _kind, _meta: None)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([trace_with_error])
    assert "ctx_teacher_input_ids" not in batch
    assert "sdpo_loss_mask" not in batch
| # ---------------------------------------------------------------------------- | |
| # Channel 3: trace-replay DPO fields | |
| # ---------------------------------------------------------------------------- | |
def test_dpo_skipped_when_no_pairs(tok, trace_no_errors):
    """A trace without ``dpo_pairs`` must not produce chosen-side DPO ids."""
    build_batch = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    emitted_keys = build_batch([trace_no_errors])
    assert "dpo_chosen_input_ids" not in emitted_keys
def test_dpo_emitted_when_pairs_present(tok, trace_with_dpo_pairs):
    """All four DPO tensors appear, with matching chosen/rejected pair counts."""
    build_batch = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    out = build_batch([trace_with_dpo_pairs])

    expected_keys = (
        "dpo_chosen_input_ids",
        "dpo_rejected_input_ids",
        "dpo_chosen_response_mask",
        "dpo_rejected_response_mask",
    )
    for key in expected_keys:
        assert key in out

    # Every chosen row needs a rejected counterpart.
    n_chosen = out["dpo_chosen_input_ids"].shape[0]
    n_rejected = out["dpo_rejected_input_ids"].shape[0]
    assert n_chosen == n_rejected
def test_dpo_response_mask_zeros_prompt_ones_response(tok, trace_with_dpo_pairs):
    """The chosen-side response mask starts on a prompt token and flags some response."""
    build_batch = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    out = build_batch([trace_with_dpo_pairs])
    row = out["dpo_chosen_response_mask"][0].tolist()

    # Prompt "decide" is 1 token and chosen "option A" is 2 tokens, so the
    # unpadded mask is expected to look like [0, 1, 1].
    binary_entries = list(filter(lambda value: value in (0, 1), row))
    first_entry = binary_entries[0]
    assert first_entry == 0, "First token (prompt) should be 0 in response mask"
    flagged = sum(binary_entries)
    assert flagged >= 1, "At least one response token should be marked 1"
| # ---------------------------------------------------------------------------- | |
| # Padding / truncation | |
| # ---------------------------------------------------------------------------- | |
def test_padding_to_max_len(tok, trace_no_errors):
    """Traces of different lengths are collated into one padded, rectangular batch.

    Besides checking that the tensors agree in shape, this verifies that the
    longer trace is not cut short and that the shorter row really is padded
    (it attends to fewer positions than the longer row).
    """
    short = trace_no_errors  # "task one" + "answer one" -> 4 words
    long_trace = {
        "trace_id": "long",
        "turns": [
            {"role": "user", "content": "a b c d e f"},
            {"role": "assistant", "content": "x y z"},
        ],
        "final_reward": 1.0,
    }  # 6 + 3 = 9 words
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([short, long_trace])

    ids = batch["input_ids"]
    attention = batch["attention_mask"]

    # Both rows share the same T dimension across all token-level tensors.
    assert ids.shape[0] == 2
    assert ids.shape == batch["response_mask"].shape == attention.shape

    # The 9-word trace must fit untruncated.
    # NOTE(review): assumes the default max_seq_len is at least 9 — confirm.
    assert ids.shape[1] >= 9

    # The short row is padded, so it has strictly fewer attended positions.
    assert attention[0].sum() < attention[1].sum()
def test_truncation_to_max_seq_len(tok):
    """A 50-word trace is cut down to the configured ``max_seq_len`` of 10."""
    fifty_words = [f"w{i}" for i in range(50)]
    oversized_trace = {
        "trace_id": "trunc",
        "turns": [{"role": "assistant", "content": " ".join(fifty_words)}],
        "final_reward": 0.0,
    }
    capped_cfg = CollatorConfig(max_seq_len=10)
    out = ComposerDataCollator(tokenizer=tok, config=capped_cfg)([oversized_trace])
    seq_len = out["input_ids"].shape[1]
    assert seq_len == 10
| # ---------------------------------------------------------------------------- | |
| # Multi-example batches | |
| # ---------------------------------------------------------------------------- | |
def test_mixed_batch_some_with_errors_some_without(tok, hint_gen, trace_no_errors, trace_with_error):
    """One error-bearing trace is enough for SDPO fields to cover the whole batch."""
    hinted_cfg = CollatorConfig(hint_generator=hint_gen)
    mixed_traces = [trace_no_errors, trace_with_error]
    out = ComposerDataCollator(tokenizer=tok, config=hinted_cfg)(mixed_traces)

    assert "ctx_teacher_input_ids" in out
    # One teacher-context row per input trace, including the error-free one.
    n_rows = out["ctx_teacher_input_ids"].shape[0]
    assert n_rows == 2
def test_attention_mask_zeros_padding(tok, trace_no_errors):
    """attention_mask must be 0 where input_ids is the pad token, and 1 elsewhere.

    The batch mixes a 4-word and a 9-word trace so that the shorter row is
    guaranteed to contain padding; a single-trace batch could have none, which
    would make the padding assertion pass vacuously.
    """
    longer_trace = {
        "trace_id": "pad-long",
        "turns": [
            {"role": "user", "content": "a b c d e f"},
            {"role": "assistant", "content": "x y z"},
        ],
        "final_reward": 1.0,
    }
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([trace_no_errors, longer_trace])
    am = batch["attention_mask"]
    ids = batch["input_ids"]

    # Use the tokenizer's declared pad id rather than a hard-coded 0.
    pad_positions = ids == tok.pad_token_id
    assert pad_positions.any(), "Expected padding in a batch of unequal-length traces"

    # At every padding position, attention_mask must be 0
    assert (am[pad_positions] == 0).all()
    # ...and every real token must be attended to.
    non_pad_positions = ~pad_positions
    assert (am[non_pad_positions] == 1).all()