"""test_data_collator.py — verify ComposerDataCollator builds correct batches. Uses a deterministic stub tokenizer so we can write expected-token-count assertions without depending on a real HF tokenizer being installed. Coverage: - GRPO core fields (input_ids, response_mask, attention_mask, rewards) - SDPO fields are skipped when no error turns are present - SDPO fields are constructed when error turns are present + hint generator returns text - SDPO loss mask correctly marks post-hint tokens with 1, others with -100 - DPO fields are skipped when no DPO pairs are present - DPO fields tokenize chosen/rejected pairs with correct response masks - Padding to max_seq_len works - Truncation to max_seq_len works Run: pytest spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py -v """ from __future__ import annotations import sys from pathlib import Path import pytest import torch sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from trl_path.data_collator import ( # noqa: E402 CollatorConfig, ComposerDataCollator, ) # ---------------------------------------------------------------------------- # Stub tokenizer — deterministic, character-by-character ish # ---------------------------------------------------------------------------- class StubTokenizer: """Maps each unique whitespace-separated word to an integer id, deterministically. Reserves 0 = pad, 1 = bos, 2 = eos. 
""" pad_token_id = 0 def __init__(self) -> None: self._vocab: dict[str, int] = {"": 0, "": 1, "": 2} def _id_for(self, word: str) -> int: if word not in self._vocab: self._vocab[word] = len(self._vocab) return self._vocab[word] def __call__(self, text: str | list[str], **_kwargs): if isinstance(text, list): return {"input_ids": [self._tokenize_one(t) for t in text]} return {"input_ids": self._tokenize_one(text)} def _tokenize_one(self, text: str) -> list[int]: return [self._id_for(w) for w in text.split()] if text else [] def apply_chat_template(self, messages, tokenize=True, **_kwargs): # noqa: ARG002 joined = " ".join(m.get("content", "") for m in messages) return self._tokenize_one(joined) # ---------------------------------------------------------------------------- # Fixtures # ---------------------------------------------------------------------------- @pytest.fixture def tok(): return StubTokenizer() @pytest.fixture def hint_gen(): """Simple hint generator that returns a fixed hint for `tool_not_found`.""" def _gen(error_kind: str, _meta: dict) -> str | None: if error_kind == "tool_not_found": return "HINT use a real tool" return None return _gen @pytest.fixture def trace_no_errors(): """Clean trace, no error sites.""" return { "trace_id": "ok-1", "turns": [ {"role": "user", "content": "task one"}, {"role": "assistant", "content": "answer one"}, ], "final_reward": 1.0, } @pytest.fixture def trace_with_error(): """Trace with one tool-call error in the middle.""" return { "trace_id": "err-1", "turns": [ {"role": "user", "content": "task two"}, { "role": "assistant", "content": "wrong attempt", "tool_error": "tool_not_found", "error_meta": {"available_tools": ["read", "write"]}, }, {"role": "tool", "content": "tool not found"}, {"role": "assistant", "content": "fixed attempt"}, ], "final_reward": 0.5, } @pytest.fixture def trace_with_dpo_pairs(): return { "trace_id": "dpo-1", "turns": [ {"role": "user", "content": "decide"}, {"role": "assistant", "content": 
"option B"}, ], "final_reward": 0.0, "dpo_pairs": [ { "state_id": "decide-1", "state_messages": [{"role": "user", "content": "decide"}], "chosen": "option A", "rejected": "option B", "n_teachers_agreeing": 3, } ], } # ---------------------------------------------------------------------------- # Channel 1: GRPO core fields # ---------------------------------------------------------------------------- def test_grpo_fields_shape_and_dtype(tok, trace_no_errors): collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig()) batch = collator([trace_no_errors]) assert batch["input_ids"].dtype == torch.long assert batch["attention_mask"].dtype == torch.long assert batch["response_mask"].dtype == torch.long assert batch["rewards"].dtype == torch.float assert batch["input_ids"].shape == batch["response_mask"].shape == batch["attention_mask"].shape def test_grpo_response_mask_marks_assistant_only(tok, trace_no_errors): collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig()) batch = collator([trace_no_errors]) response_mask = batch["response_mask"][0] # "task one" = 2 user tokens (mask 0), "answer one" = 2 asst tokens (mask 1) assert response_mask.tolist()[:4] == [0, 0, 1, 1] def test_grpo_rewards_match_input(tok, trace_no_errors, trace_with_error): collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig()) batch = collator([trace_no_errors, trace_with_error]) assert batch["rewards"].tolist() == [1.0, 0.5] # ---------------------------------------------------------------------------- # Channel 2: SDPO hint-distill fields # ---------------------------------------------------------------------------- def test_sdpo_skipped_when_no_hint_generator_configured(tok, trace_with_error): """Even with error turns, no hint generator → no SDPO fields emitted.""" cfg = CollatorConfig(hint_generator=None) collator = ComposerDataCollator(tokenizer=tok, config=cfg) batch = collator([trace_with_error]) assert "ctx_teacher_input_ids" not in batch assert 
"sdpo_loss_mask" not in batch def test_sdpo_skipped_when_no_error_turns(tok, hint_gen, trace_no_errors): cfg = CollatorConfig(hint_generator=hint_gen) collator = ComposerDataCollator(tokenizer=tok, config=cfg) batch = collator([trace_no_errors]) assert "ctx_teacher_input_ids" not in batch assert "sdpo_loss_mask" not in batch def test_sdpo_emitted_when_error_turn_present(tok, hint_gen, trace_with_error): cfg = CollatorConfig(hint_generator=hint_gen) collator = ComposerDataCollator(tokenizer=tok, config=cfg) batch = collator([trace_with_error]) assert "ctx_teacher_input_ids" in batch assert "sdpo_loss_mask" in batch assert batch["ctx_teacher_input_ids"].dtype == torch.long assert batch["sdpo_loss_mask"].dtype == torch.long assert batch["ctx_teacher_input_ids"].shape == batch["sdpo_loss_mask"].shape def test_sdpo_loss_mask_marks_post_hint_tokens_only(tok, hint_gen, trace_with_error): """The mask should be 1 at post-hint tokens, -100 (ignore_index) elsewhere.""" cfg = CollatorConfig(hint_generator=hint_gen) collator = ComposerDataCollator(tokenizer=tok, config=cfg) batch = collator([trace_with_error]) mask = batch["sdpo_loss_mask"][0].tolist() # At least one position should be loss-active assert any(m == 1 for m in mask), f"Expected ≥1 loss-active position, got {mask}" # All non-loss positions should be ignore_index (-100), not 0 assert all(m in (1, -100) for m in mask), f"Mask must be {{1, -100}} only, got {set(mask)}" def test_sdpo_skipped_when_hint_generator_returns_none(tok, trace_with_error): """Hint generator returns None → SDPO fields not emitted (no signal to add).""" cfg = CollatorConfig(hint_generator=lambda _kind, _meta: None) collator = ComposerDataCollator(tokenizer=tok, config=cfg) batch = collator([trace_with_error]) assert "ctx_teacher_input_ids" not in batch # ---------------------------------------------------------------------------- # Channel 3: trace-replay DPO fields # ---------------------------------------------------------------------------- 
def test_dpo_skipped_when_no_pairs(tok, trace_no_errors):
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([trace_no_errors])
    assert "dpo_chosen_input_ids" not in batch


def test_dpo_emitted_when_pairs_present(tok, trace_with_dpo_pairs):
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([trace_with_dpo_pairs])
    assert "dpo_chosen_input_ids" in batch
    assert "dpo_rejected_input_ids" in batch
    assert "dpo_chosen_response_mask" in batch
    assert "dpo_rejected_response_mask" in batch
    # Same number of pairs in chosen and rejected
    assert batch["dpo_chosen_input_ids"].shape[0] == batch["dpo_rejected_input_ids"].shape[0]


def test_dpo_response_mask_zeros_prompt_ones_response(tok, trace_with_dpo_pairs):
    """The chosen response mask is 0 over the prompt and 1 over the response."""
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([trace_with_dpo_pairs])
    chosen_ids = batch["dpo_chosen_input_ids"][0]
    chosen_mask = batch["dpo_chosen_response_mask"][0]
    # Prompt = "decide" (1 token), chosen = "option A" (2 tokens).
    # The mask should be [0, 1, 1] before any padding.
    # Padding is dropped by input id rather than by mask value: a 0/1 value
    # filter cannot tell a padded 0 from a prompt 0.
    real = chosen_mask[chosen_ids != tok.pad_token_id].tolist()
    assert real, "Chosen sequence should contain at least one non-pad token"
    assert real[0] == 0, "First token (prompt) should be 0 in response mask"
    assert sum(real) >= 1, "At least one response token should be marked 1"


# ----------------------------------------------------------------------------
# Padding / truncation
# ----------------------------------------------------------------------------


def test_padding_to_max_len(tok, trace_no_errors):
    """When traces have different lengths, all are padded to a common T dimension."""
    short = trace_no_errors  # 4 tokens
    long_trace = {
        "trace_id": "long",
        "turns": [
            {"role": "user", "content": "a b c d e f"},
            {"role": "assistant", "content": "x y z"},
        ],
        "final_reward": 1.0,
    }  # 9 tokens
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([short, long_trace])
    assert batch["input_ids"].shape[0] == 2
    assert batch["input_ids"].shape == batch["response_mask"].shape == batch["attention_mask"].shape
    # Tensor stacking alone guarantees a shared T. Also check that padding
    # really happened: the shorter row must have masked-out slots holding
    # the pad id.
    short_pad = batch["attention_mask"][0] == 0
    assert short_pad.any(), "Shorter trace should be padded"
    assert (batch["input_ids"][0][short_pad] == tok.pad_token_id).all()


def test_truncation_to_max_seq_len(tok):
    """Traces longer than max_seq_len are truncated."""
    long_text = " ".join(f"w{i}" for i in range(50))
    trace = {
        "trace_id": "trunc",
        "turns": [{"role": "assistant", "content": long_text}],
        "final_reward": 0.0,
    }
    cfg = CollatorConfig(max_seq_len=10)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([trace])
    assert batch["input_ids"].shape[1] == 10
    # Per-token masks must be truncated in lockstep with the ids.
    assert batch["response_mask"].shape == batch["input_ids"].shape


# ----------------------------------------------------------------------------
# Multi-example batches
# ----------------------------------------------------------------------------


def test_mixed_batch_some_with_errors_some_without(tok, hint_gen, trace_no_errors, trace_with_error):
    """SDPO should fire when at least one example has error turns."""
    cfg = CollatorConfig(hint_generator=hint_gen)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([trace_no_errors, trace_with_error])
    assert "ctx_teacher_input_ids" in batch
    # One ctx_teacher row per example, including the one without error turns
    assert batch["ctx_teacher_input_ids"].shape[0] == 2


def test_attention_mask_zeros_padding(tok, trace_no_errors, trace_with_error):
    """attention_mask must be 0 where input_ids is the pad token, 1 elsewhere."""
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    # Mix a short and a longer trace so the batch is guaranteed to contain
    # padding. With a single trace there may be no pad positions, and the
    # check below would pass vacuously on an empty selection.
    batch = collator([trace_no_errors, trace_with_error])
    am = batch["attention_mask"]
    ids = batch["input_ids"]
    pad_positions = ids == tok.pad_token_id
    assert pad_positions.any(), "Mixed-length batch should contain padding"
    # At every padding position, attention_mask must be 0
    assert (am[pad_positions] == 0).all()
    non_pad_positions = ~pad_positions
    assert (am[non_pad_positions] == 1).all()