# baladithyab
# Wave 4: data collator + loss composition smoke (38/38 tests pass)
# 157cdba
"""test_data_collator.py — verify ComposerDataCollator builds correct batches.
Uses a deterministic stub tokenizer so we can write expected-token-count
assertions without depending on a real HF tokenizer being installed.
Coverage:
- GRPO core fields (input_ids, response_mask, attention_mask, rewards)
- SDPO fields are skipped when no error turns are present
- SDPO fields are constructed when error turns are present + hint generator returns text
- SDPO loss mask correctly marks post-hint tokens with 1, others with -100
- DPO fields are skipped when no DPO pairs are present
- DPO fields tokenize chosen/rejected pairs with correct response masks
- Padding to max_seq_len works
- Truncation to max_seq_len works
Run: pytest spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py -v
"""
from __future__ import annotations
import sys
from pathlib import Path
import pytest
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from trl_path.data_collator import ( # noqa: E402
CollatorConfig,
ComposerDataCollator,
)
# ----------------------------------------------------------------------------
# Stub tokenizer — deterministic, one token per whitespace-separated word
# ----------------------------------------------------------------------------
class StubTokenizer:
    """Whitespace-word tokenizer stub whose vocabulary grows on demand.

    Ids 0, 1 and 2 are pre-assigned to ``<pad>``, ``<bos>`` and ``<eos>``.
    Every previously unseen word receives the next free integer id, so the
    mapping is deterministic for a given sequence of calls.
    """

    pad_token_id = 0

    def __init__(self) -> None:
        self._vocab: dict[str, int] = {"<pad>": 0, "<bos>": 1, "<eos>": 2}

    def _id_for(self, word: str) -> int:
        """Return the id of *word*, registering it first when it is new."""
        try:
            return self._vocab[word]
        except KeyError:
            fresh_id = len(self._vocab)
            self._vocab[word] = fresh_id
            return fresh_id

    def __call__(self, text: str | list[str], **_kwargs):
        """Tokenize a string, or each string of a list, into ``{"input_ids": ...}``."""
        if not isinstance(text, list):
            return {"input_ids": self._tokenize_one(text)}
        return {"input_ids": [self._tokenize_one(item) for item in text]}

    def _tokenize_one(self, text: str) -> list[int]:
        """Split *text* on whitespace and map every word to its id."""
        if not text:
            return []
        return [self._id_for(word) for word in text.split()]

    def apply_chat_template(self, messages, tokenize=True, **_kwargs):  # noqa: ARG002
        """Join the messages' contents with single spaces and return token ids.

        ``tokenize`` and any extra keyword arguments are accepted but ignored;
        a message without a ``content`` key contributes an empty string.
        """
        contents = [message.get("content", "") for message in messages]
        return self._tokenize_one(" ".join(contents))
# ----------------------------------------------------------------------------
# Fixtures
# ----------------------------------------------------------------------------
@pytest.fixture
def tok():
    """Provide a newly constructed ``StubTokenizer``."""
    tokenizer = StubTokenizer()
    return tokenizer
@pytest.fixture
def hint_gen():
    """Provide a hint generator that only recognizes ``tool_not_found``."""

    def _gen(error_kind: str, _meta: dict) -> str | None:
        # Any other error kind produces no hint at all.
        return "HINT use a real tool" if error_kind == "tool_not_found" else None

    return _gen
@pytest.fixture
def trace_no_errors():
    """Two-turn trace (user, then assistant) with no error markers and reward 1.0."""
    user_turn = {"role": "user", "content": "task one"}
    assistant_turn = {"role": "assistant", "content": "answer one"}
    return {
        "trace_id": "ok-1",
        "turns": [user_turn, assistant_turn],
        "final_reward": 1.0,
    }
@pytest.fixture
def trace_with_error():
    """Four-turn trace whose first assistant turn carries a ``tool_not_found`` error."""
    failed_attempt = {
        "role": "assistant",
        "content": "wrong attempt",
        "tool_error": "tool_not_found",
        "error_meta": {"available_tools": ["read", "write"]},
    }
    turns = [
        {"role": "user", "content": "task two"},
        failed_attempt,
        {"role": "tool", "content": "tool not found"},
        {"role": "assistant", "content": "fixed attempt"},
    ]
    return {"trace_id": "err-1", "turns": turns, "final_reward": 0.5}
@pytest.fixture
def trace_with_dpo_pairs():
    """Two-turn trace that additionally carries one chosen/rejected DPO pair."""
    preference_pair = {
        "state_id": "decide-1",
        "state_messages": [{"role": "user", "content": "decide"}],
        "chosen": "option A",
        "rejected": "option B",
        "n_teachers_agreeing": 3,
    }
    turns = [
        {"role": "user", "content": "decide"},
        {"role": "assistant", "content": "option B"},
    ]
    return {
        "trace_id": "dpo-1",
        "turns": turns,
        "final_reward": 0.0,
        "dpo_pairs": [preference_pair],
    }
# ----------------------------------------------------------------------------
# Channel 1: GRPO core fields
# ----------------------------------------------------------------------------
def test_grpo_fields_shape_and_dtype(tok, trace_no_errors):
    """GRPO tensors have the expected dtypes and the three id/mask tensors share a shape."""
    collate = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collate([trace_no_errors])

    for key in ("input_ids", "attention_mask", "response_mask"):
        assert batch[key].dtype == torch.long
    assert batch["rewards"].dtype == torch.float

    response_shape = batch["response_mask"].shape
    assert batch["input_ids"].shape == response_shape
    assert response_shape == batch["attention_mask"].shape
def test_grpo_response_mask_marks_assistant_only(tok, trace_no_errors):
    """The first four response-mask entries are 0, 0 (user) then 1, 1 (assistant)."""
    collate = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    first_row = collate([trace_no_errors])["response_mask"][0].tolist()
    # "task one" -> 2 user tokens (mask 0); "answer one" -> 2 assistant tokens (mask 1).
    leading = first_row[:4]
    assert leading == [0, 0, 1, 1]
def test_grpo_rewards_match_input(tok, trace_no_errors, trace_with_error):
    """Rewards come back in input order with each trace's ``final_reward``."""
    traces = [trace_no_errors, trace_with_error]
    collate = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    rewards = collate(traces)["rewards"].tolist()
    assert rewards == [1.0, 0.5]
# ----------------------------------------------------------------------------
# Channel 2: SDPO hint-distill fields
# ----------------------------------------------------------------------------
def test_sdpo_skipped_when_no_hint_generator_configured(tok, trace_with_error):
    """Without a hint generator, an error turn alone must not produce SDPO fields."""
    collate = ComposerDataCollator(
        tokenizer=tok,
        config=CollatorConfig(hint_generator=None),
    )
    batch = collate([trace_with_error])
    for sdpo_key in ("ctx_teacher_input_ids", "sdpo_loss_mask"):
        assert sdpo_key not in batch
def test_sdpo_skipped_when_no_error_turns(tok, hint_gen, trace_no_errors):
    """A hint generator alone must not produce SDPO fields for an error-free trace."""
    collate = ComposerDataCollator(
        tokenizer=tok,
        config=CollatorConfig(hint_generator=hint_gen),
    )
    batch = collate([trace_no_errors])
    for sdpo_key in ("ctx_teacher_input_ids", "sdpo_loss_mask"):
        assert sdpo_key not in batch
def test_sdpo_emitted_when_error_turn_present(tok, hint_gen, trace_with_error):
    """An error turn plus a matching hint yields both SDPO tensors, long and same-shaped."""
    collate = ComposerDataCollator(
        tokenizer=tok,
        config=CollatorConfig(hint_generator=hint_gen),
    )
    batch = collate([trace_with_error])

    # Presence is checked before indexing so a missing key fails as an assertion.
    assert "ctx_teacher_input_ids" in batch
    assert "sdpo_loss_mask" in batch

    teacher_ids = batch["ctx_teacher_input_ids"]
    loss_mask = batch["sdpo_loss_mask"]
    assert teacher_ids.dtype == torch.long
    assert loss_mask.dtype == torch.long
    assert teacher_ids.shape == loss_mask.shape
def test_sdpo_loss_mask_marks_post_hint_tokens_only(tok, hint_gen, trace_with_error):
    """The SDPO loss mask holds only 1 (loss-active) and -100 (ignore_index) values."""
    collate = ComposerDataCollator(
        tokenizer=tok,
        config=CollatorConfig(hint_generator=hint_gen),
    )
    mask = collate([trace_with_error])["sdpo_loss_mask"][0].tolist()

    # There has to be at least one loss-active position.
    assert 1 in mask, f"Expected ≥1 loss-active position, got {mask}"
    # Inactive positions use ignore_index (-100), never 0.
    permitted = {1, -100}
    assert set(mask) <= permitted, f"Mask must be {{1, -100}} only, got {set(mask)}"
def test_sdpo_skipped_when_hint_generator_returns_none(tok, trace_with_error):
    """Hint generator returns None → SDPO fields not emitted (no signal to add).

    Both SDPO keys are checked, matching the other "skipped" tests in this
    module, so a stray ``sdpo_loss_mask`` without teacher inputs is caught too.
    """
    cfg = CollatorConfig(hint_generator=lambda _kind, _meta: None)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([trace_with_error])
    assert "ctx_teacher_input_ids" not in batch
    assert "sdpo_loss_mask" not in batch
# ----------------------------------------------------------------------------
# Channel 3: trace-replay DPO fields
# ----------------------------------------------------------------------------
def test_dpo_skipped_when_no_pairs(tok, trace_no_errors):
    """A trace without ``dpo_pairs`` must not produce any DPO field.

    Checks the same four keys that the "emitted" DPO test expects to appear,
    so a partially populated DPO channel does not slip through.
    """
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([trace_no_errors])
    dpo_keys = (
        "dpo_chosen_input_ids",
        "dpo_rejected_input_ids",
        "dpo_chosen_response_mask",
        "dpo_rejected_response_mask",
    )
    for key in dpo_keys:
        assert key not in batch, f"Unexpected DPO field {key!r} in batch"
def test_dpo_emitted_when_pairs_present(tok, trace_with_dpo_pairs):
    """A trace with DPO pairs yields all four DPO tensors with matching pair counts."""
    collate = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collate([trace_with_dpo_pairs])

    expected_keys = (
        "dpo_chosen_input_ids",
        "dpo_rejected_input_ids",
        "dpo_chosen_response_mask",
        "dpo_rejected_response_mask",
    )
    for key in expected_keys:
        assert key in batch

    # Chosen and rejected must contain the same number of pairs.
    n_chosen = batch["dpo_chosen_input_ids"].shape[0]
    n_rejected = batch["dpo_rejected_input_ids"].shape[0]
    assert n_chosen == n_rejected
def test_dpo_response_mask_zeros_prompt_ones_response(tok, trace_with_dpo_pairs):
    """The chosen response mask starts with 0 and marks at least one token with 1."""
    collate = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    chosen_mask = collate([trace_with_dpo_pairs])["dpo_chosen_response_mask"][0].tolist()

    # Prompt = "decide" (1 token), chosen = "option A" (2 tokens), so the
    # unpadded mask is expected to look like [0, 1, 1].
    binary_entries = list(filter(lambda value: value in (0, 1), chosen_mask))
    assert binary_entries[0] == 0, "First token (prompt) should be 0 in response mask"
    assert sum(binary_entries) >= 1, "At least one response token should be marked 1"
# ----------------------------------------------------------------------------
# Padding / truncation
# ----------------------------------------------------------------------------
def test_padding_to_max_len(tok, trace_no_errors):
    """Traces of different lengths are collated into one rectangular batch.

    Besides the batch dimension, the sequence dimension is checked so the test
    fails if the longer trace is cut short instead of the shorter one padded.
    """
    short = trace_no_errors  # 4 words
    long_trace = {
        "trace_id": "long",
        "turns": [
            {"role": "user", "content": "a b c d e f"},
            {"role": "assistant", "content": "x y z"},
        ],
        "final_reward": 1.0,
    }  # 6 + 3 = 9 words
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([short, long_trace])
    input_ids = batch["input_ids"]
    # Batch dimension: one row per trace.
    assert input_ids.shape[0] == 2
    # Sequence dimension: wide enough for the 9-word trace.
    # NOTE(review): assumes the default max_seq_len is at least 9 — confirm.
    assert input_ids.shape[1] >= 9
    # Every per-token tensor shares the padded shape.
    assert input_ids.shape == batch["response_mask"].shape
    assert input_ids.shape == batch["attention_mask"].shape
def test_truncation_to_max_seq_len(tok):
    """A 50-word assistant turn is cut down to ``max_seq_len`` = 10 positions."""
    words = [f"w{i}" for i in range(50)]
    trace = {
        "trace_id": "trunc",
        "turns": [{"role": "assistant", "content": " ".join(words)}],
        "final_reward": 0.0,
    }
    collate = ComposerDataCollator(
        tokenizer=tok,
        config=CollatorConfig(max_seq_len=10),
    )
    seq_len = collate([trace])["input_ids"].shape[1]
    assert seq_len == 10
# ----------------------------------------------------------------------------
# Multi-example batches
# ----------------------------------------------------------------------------
def test_mixed_batch_some_with_errors_some_without(tok, hint_gen, trace_no_errors, trace_with_error):
    """One error trace in a batch of two is enough for SDPO fields to appear."""
    collate = ComposerDataCollator(
        tokenizer=tok,
        config=CollatorConfig(hint_generator=hint_gen),
    )
    batch = collate([trace_no_errors, trace_with_error])
    assert "ctx_teacher_input_ids" in batch
    # The teacher tensor keeps one row per trace in the batch.
    n_rows = batch["ctx_teacher_input_ids"].shape[0]
    assert n_rows == 2
def test_attention_mask_zeros_padding(tok, trace_no_errors):
    """attention_mask must be 0 where input_ids is the pad token, 1 elsewhere.

    The batch mixes a short and a longer trace so that padding positions
    exist; with a single trace the padding assertion could pass vacuously.
    """
    longer_trace = {
        "trace_id": "long",
        "turns": [
            {"role": "user", "content": "a b c d e f"},
            {"role": "assistant", "content": "x y z"},
        ],
        "final_reward": 1.0,
    }
    collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
    batch = collator([trace_no_errors, longer_trace])
    am = batch["attention_mask"]
    ids = batch["input_ids"]
    # Use the tokenizer's pad id instead of a magic 0. The stub never assigns
    # id 0 to a real word, so id == pad identifies padding unambiguously.
    # NOTE(review): assumes the collator pads input_ids with
    # tokenizer.pad_token_id — confirm against the collator.
    pad_positions = ids == tok.pad_token_id
    assert pad_positions.any(), "Expected padding in a batch of unequal-length traces"
    # At every padding position, attention_mask must be 0
    assert (am[pad_positions] == 0).all()
    # ... and 1 at every real-token position.
    assert (am[~pad_positions] == 1).all()