"""data_collator.py — ComposerDataCollator: raw trace → trainer-ready batch. Pipeline: 1. Take a frozen agentic trace + N-teacher DPO pairs (from spike 002 + 003). 2. Tokenize each turn of the trace. 3. Detect error sites (turns where a tool call failed) using a configurable predicate. 4. At each error site, build ctx_teacher = ctx_student with hint inserted at the error-turn boundary. 5. Pad/align ctx_student and ctx_teacher so SDPO logits compare position-by-position. 6. Construct sdpo_loss_mask = 1 at post-hint tokens of the error turn, 0 elsewhere. 7. Tokenize DPO chosen/rejected pairs, build response masks, leave ref_logprobs as a precompute step. The output dict is what `ComposerReplicationTrainer._compute_loss` expects in its `inputs` argument. See `trl_path/composer_trainer.py` for the consumer side. Architectural note (verified via spike 005 test_opsd_loss.py): generalized_jsd_loss requires student_logits and teacher_logits to have the SAME (B, T, V) shape — that's why we pad/align here rather than inside the loss function. The post-hint section of ctx_teacher must have token-by-token alignment with the same section of ctx_student. """ from __future__ import annotations from collections.abc import Callable, Sequence from dataclasses import dataclass, field from typing import Any, TypedDict import torch # --------------------------------------------------------------------------- # Types # --------------------------------------------------------------------------- class TraceTurn(TypedDict, total=False): """One turn of an agentic trace.""" role: str # "user" | "assistant" | "tool" content: str # text or tool result tool_call: dict | None # parsed tool call, if assistant-issued tool_error: str | None # error_kind from the env, e.g. "tool_not_found" error_meta: dict # extra info for hint generator (available_tools, etc.) class TraceExample(TypedDict, total=False): """One training example: a (trace, optional DPO pairs) tuple.""" trace_id: str turns: list[TraceTurn] final_reward: float # RLVR scalar (test-pass etc.) at trajectory end dpo_pairs: list[dict] | None # from teacher_replay.extract_dpo_pairs # --------------------------------------------------------------------------- # Tokenizer protocol — duck-typed against HF AutoTokenizer # --------------------------------------------------------------------------- class TokenizerLike: """Minimal protocol the collator needs from a tokenizer. Compatible with HuggingFace `AutoTokenizer` instances (the typical case), but also satisfiable by simpler stubs for unit-testing. """ pad_token_id: int def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: # pragma: no cover ... def apply_chat_template( # pragma: no cover self, messages: list[dict], **kwargs: Any ) -> str | list[int]: ... # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- @dataclass class CollatorConfig: """Tunables for ComposerDataCollator.""" max_seq_len: int = 4096 max_dpo_seq_len: int = 2048 pad_token_id: int = 0 ignore_index: int = -100 # standard HF "ignore in loss" sentinel # SDPO behavior enable_sdpo: bool = True hint_generator: Callable[[str, dict], str | None] | None = None """Callable error_kind, error_meta -> hint_text (or None to skip).""" # Trace-replay DPO behavior enable_replay_dpo: bool = True # Reward shaping rlvr_reward_key: str = "final_reward" # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _is_error_turn(turn: TraceTurn) -> bool: """Predicate: is this turn an error site that should trigger SDPO?""" return turn.get("tool_error") is not None def _build_chat_messages(turns: Sequence[TraceTurn]) -> list[dict]: """Convert TraceTurns to OpenAI-style chat messages for tokenizer.apply_chat_template.""" return [ {"role": t["role"], "content": t["content"]} for t in turns if t.get("content") ] def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]: """Right-pad with pad_id, or right-truncate to target_len.""" if len(seq) >= target_len: return seq[:target_len] return seq + [pad_id] * (target_len - len(seq)) # --------------------------------------------------------------------------- # The collator # --------------------------------------------------------------------------- @dataclass class ComposerDataCollator: """Build trainer-ready batches from raw traces + optional DPO pairs. Usage: collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig()) batch = collator([trace_example_0, trace_example_1, ...]) # batch is a dict[str, torch.Tensor] ready for ComposerReplicationTrainer The dict contains: # Channel 1 (GRPO/RLVR — handled by the parent GRPOTrainer) - input_ids: (B, T_max) - attention_mask: (B, T_max) - response_mask: (B, T_max) - rewards: (B,) # Channel 2 (SDPO hint-distill) — present when any example has error turns - ctx_teacher_input_ids: (B, T_max) - sdpo_loss_mask: (B, T_max), 1 at post-hint error-turn tokens # Channel 3 (trace-replay DPO) — present when any example has dpo_pairs - dpo_chosen_input_ids: (B', T_dpo) - dpo_chosen_response_mask: (B', T_dpo) - dpo_rejected_input_ids: (B', T_dpo) - dpo_rejected_response_mask: (B', T_dpo) # ref_logprobs are NOT computed here — the trainer's reference-policy # forward pass at training time produces them. """ tokenizer: TokenizerLike config: CollatorConfig = field(default_factory=CollatorConfig) def __call__(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]: out: dict[str, torch.Tensor] = {} # --- Channel 1: GRPO core fields --- out.update(self._build_grpo_fields(batch)) # --- Channel 2: SDPO hint-distill fields --- if self.config.enable_sdpo: sdpo = self._build_sdpo_fields(batch) if sdpo is not None: out.update(sdpo) # --- Channel 3: trace-replay DPO fields --- if self.config.enable_replay_dpo: dpo = self._build_dpo_fields(batch) if dpo is not None: out.update(dpo) return out # ---------------------------------------------------------------------- # Channel 1: standard GRPO inputs # ---------------------------------------------------------------------- def _build_grpo_fields(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]: input_ids_list: list[list[int]] = [] response_masks_list: list[list[int]] = [] rewards: list[float] = [] for ex in batch: ids, resp_mask = self._tokenize_trace(ex["turns"]) input_ids_list.append(ids) response_masks_list.append(resp_mask) rewards.append(float(ex.get(self.config.rlvr_reward_key, 0.0))) max_len = min(self.config.max_seq_len, max(len(s) for s in input_ids_list)) input_ids = torch.tensor( [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in input_ids_list], dtype=torch.long, ) response_mask = torch.tensor( [_pad_or_truncate(m, max_len, 0) for m in response_masks_list], dtype=torch.long, ) attention_mask = (input_ids != self.config.pad_token_id).long() return { "input_ids": input_ids, "attention_mask": attention_mask, "response_mask": response_mask, "rewards": torch.tensor(rewards, dtype=torch.float), } # ---------------------------------------------------------------------- # Channel 2: SDPO hint-distill inputs # ---------------------------------------------------------------------- def _build_sdpo_fields( self, batch: Sequence[TraceExample] ) -> dict[str, torch.Tensor] | None: """Build ctx_teacher + sdpo_loss_mask, aligned to ctx_student length.""" if self.config.hint_generator is None: return None # nothing to do without a hint generator ctx_teacher_list: list[list[int]] = [] sdpo_mask_list: list[list[int]] = [] any_error_sites = False for ex in batch: ctx_teacher_ids, sdpo_mask, has_errors = self._build_hint_injected_trace(ex["turns"]) ctx_teacher_list.append(ctx_teacher_ids) sdpo_mask_list.append(sdpo_mask) any_error_sites = any_error_sites or has_errors if not any_error_sites: return None # batch has no error sites — SDPO is a no-op for this step max_len = min(self.config.max_seq_len, max(len(s) for s in ctx_teacher_list)) ctx_teacher = torch.tensor( [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in ctx_teacher_list], dtype=torch.long, ) sdpo_mask = torch.tensor( [_pad_or_truncate(m, max_len, self.config.ignore_index) for m in sdpo_mask_list], dtype=torch.long, ) return { "ctx_teacher_input_ids": ctx_teacher, "sdpo_loss_mask": sdpo_mask, } def _build_hint_injected_trace( self, turns: Sequence[TraceTurn] ) -> tuple[list[int], list[int], bool]: """Walk the trace; at each error-turn boundary, inject a hint and mark the post-hint tokens as in-loss. Returns: (ctx_teacher_ids, sdpo_loss_mask, any_error_sites) """ if self.config.hint_generator is None: # Caller responsibility — short-circuited by the dispatch. empty: list[int] = [] return empty, empty, False teacher_messages: list[dict] = [] teacher_loss_segments: list[tuple[bool, str]] = [] # (is_loss_segment, text) any_errors = False for turn in turns: if _is_error_turn(turn): hint_text = self.config.hint_generator( turn.get("tool_error", "unknown"), turn.get("error_meta", {}), ) if hint_text: any_errors = True # Inject hint as a system-style addendum BEFORE the assistant's response teacher_messages.append({"role": "system", "content": hint_text}) teacher_loss_segments.append((False, hint_text)) if turn.get("content"): teacher_messages.append({ "role": turn.get("role", "assistant"), "content": turn["content"], }) teacher_loss_segments.append((True, turn["content"])) # post-hint tokens = loss continue # Non-error turn (or hint generator returned None) — passthrough if turn.get("content"): teacher_messages.append({ "role": turn.get("role", "assistant"), "content": turn["content"], }) teacher_loss_segments.append((False, turn["content"])) # Tokenize the full teacher conversation teacher_ids = self._tokenize_messages(teacher_messages) # Build the per-token loss mask by tokenizing each segment and concatenating sdpo_mask = self._build_segment_mask(teacher_loss_segments) # Truncate mask to teacher_ids length if tokenization round-tripped slightly differently sdpo_mask = sdpo_mask[: len(teacher_ids)] if len(sdpo_mask) < len(teacher_ids): sdpo_mask = sdpo_mask + [self.config.ignore_index] * (len(teacher_ids) - len(sdpo_mask)) return teacher_ids, sdpo_mask, any_errors def _build_segment_mask( self, segments: Sequence[tuple[bool, str]] ) -> list[int]: """For each (is_loss, text) segment, tokenize and emit per-token mask values. Loss-active tokens get 1; non-loss tokens get -100 (ignore_index). """ out: list[int] = [] for is_loss, text in segments: seg_ids = self._tokenize_text(text) mask_value = 1 if is_loss else self.config.ignore_index out.extend([mask_value] * len(seg_ids)) return out # ---------------------------------------------------------------------- # Channel 3: trace-replay DPO inputs # ---------------------------------------------------------------------- def _build_dpo_fields( self, batch: Sequence[TraceExample] ) -> dict[str, torch.Tensor] | None: """Tokenize chosen/rejected pairs from teacher disagreement. DPO accounting requires: - chosen_input_ids = prompt + chosen_response - rejected_input_ids = prompt + rejected_response - response_masks indicating which tokens are response (loss-bearing) vs prompt (no loss) """ all_chosen: list[list[int]] = [] all_rejected: list[list[int]] = [] all_chosen_resp_mask: list[list[int]] = [] all_rejected_resp_mask: list[list[int]] = [] for ex in batch: for pair in ex.get("dpo_pairs") or []: prompt_msgs = pair.get("state_messages", []) prompt_ids = self._tokenize_messages(prompt_msgs) chosen_ids = self._tokenize_text(pair["chosen"]) rejected_ids = self._tokenize_text(pair["rejected"]) chosen_full = prompt_ids + chosen_ids rejected_full = prompt_ids + rejected_ids # response_mask is 0 over prompt, 1 over response chosen_mask = [0] * len(prompt_ids) + [1] * len(chosen_ids) rejected_mask = [0] * len(prompt_ids) + [1] * len(rejected_ids) all_chosen.append(chosen_full) all_rejected.append(rejected_full) all_chosen_resp_mask.append(chosen_mask) all_rejected_resp_mask.append(rejected_mask) if not all_chosen: return None # no DPO pairs in this batch cap = self.config.max_dpo_seq_len max_len = min(cap, max(len(s) for s in (*all_chosen, *all_rejected))) return { "dpo_chosen_input_ids": torch.tensor( [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in all_chosen], dtype=torch.long, ), "dpo_chosen_response_mask": torch.tensor( [_pad_or_truncate(m, max_len, 0) for m in all_chosen_resp_mask], dtype=torch.long, ), "dpo_rejected_input_ids": torch.tensor( [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in all_rejected], dtype=torch.long, ), "dpo_rejected_response_mask": torch.tensor( [_pad_or_truncate(m, max_len, 0) for m in all_rejected_resp_mask], dtype=torch.long, ), } # ---------------------------------------------------------------------- # Tokenization helpers # ---------------------------------------------------------------------- def _tokenize_trace(self, turns: Sequence[TraceTurn]) -> tuple[list[int], list[int]]: """Tokenize an entire trace; return (ids, response_mask). response_mask = 1 over assistant turns (those are the loss-bearing tokens for GRPO), 0 over user/tool turns (prompt context). """ all_ids: list[int] = [] resp_mask: list[int] = [] for turn in turns: if not turn.get("content"): continue ids = self._tokenize_text(turn["content"]) mask_value = 1 if turn.get("role") == "assistant" else 0 all_ids.extend(ids) resp_mask.extend([mask_value] * len(ids)) return all_ids, resp_mask def _tokenize_text(self, text: str) -> list[int]: """Tokenize plain text via the tokenizer's __call__.""" result = self.tokenizer(text, add_special_tokens=False) ids = result["input_ids"] if hasattr(ids, "tolist"): ids = ids.tolist() # HF tokenizers often return list[list[int]] when batch-shaped; flatten if so if ids and isinstance(ids[0], list): ids = ids[0] return list(ids) def _tokenize_messages(self, messages: Sequence[dict]) -> list[int]: """Tokenize a chat-formatted list of messages. Tries apply_chat_template first; falls back to concatenated content if not available. """ if not messages: return [] try: ids = self.tokenizer.apply_chat_template( list(messages), tokenize=True, add_generation_prompt=False ) if hasattr(ids, "tolist"): ids = ids.tolist() return list(ids) except (AttributeError, NotImplementedError, TypeError): # Stub tokenizer or no chat template defined — fall back to concatenated content text = "\n".join(m.get("content", "") for m in messages) return self._tokenize_text(text) __all__ = [ "ComposerDataCollator", "CollatorConfig", "TraceTurn", "TraceExample", "TokenizerLike", ]