Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """data_collator.py — ComposerDataCollator: raw trace → trainer-ready batch. | |
| Pipeline: | |
| 1. Take a frozen agentic trace + N-teacher DPO pairs (from spike 002 + 003). | |
| 2. Tokenize each turn of the trace. | |
| 3. Detect error sites (turns where a tool call failed) using a configurable predicate. | |
| 4. At each error site, build ctx_teacher = ctx_student with hint inserted at the error-turn boundary. | |
| 5. Pad/align ctx_student and ctx_teacher so SDPO logits compare position-by-position. | |
| 6. Construct sdpo_loss_mask = 1 at post-hint tokens of the error turn, ignore_index (-100) elsewhere. | |
| 7. Tokenize DPO chosen/rejected pairs, build response masks, leave ref_logprobs as a precompute step. | |
| The output dict is what `ComposerReplicationTrainer._compute_loss` expects in its | |
| `inputs` argument. See `trl_path/composer_trainer.py` for the consumer side. | |
| Architectural note (verified via spike 005 test_opsd_loss.py): generalized_jsd_loss | |
| requires student_logits and teacher_logits to have the SAME (B, T, V) shape — that's | |
| why we pad/align here rather than inside the loss function. The post-hint section of | |
| ctx_teacher must have token-by-token alignment with the same section of ctx_student. | |
| """ | |
from __future__ import annotations

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, TypedDict

import torch
| # --------------------------------------------------------------------------- | |
| # Types | |
| # --------------------------------------------------------------------------- | |
class TraceTurn(TypedDict, total=False):
    """A single turn of an agentic trace.

    Every key is optional (``total=False``):

    * ``role`` — who spoke: "user", "assistant" or "tool".
    * ``content`` — the message text, or the tool's result text.
    * ``tool_call`` — the parsed tool call, when the assistant issued one.
    * ``tool_error`` — error kind reported by the environment
      (e.g. "tool_not_found"); any non-None value makes the turn an
      error site for SDPO.
    * ``error_meta`` — extra context forwarded to the hint generator
      (e.g. available_tools).
    """

    role: str
    content: str
    tool_call: dict | None
    tool_error: str | None
    error_meta: dict
class TraceExample(TypedDict, total=False):
    """A single training example: a trace plus optional DPO pairs.

    Every key is optional (``total=False``):

    * ``trace_id`` — identifier of the trace.
    * ``turns`` — the ordered turns of the trace.
    * ``final_reward`` — scalar RLVR reward (test-pass etc.) at the end of
      the trajectory.
    * ``dpo_pairs`` — chosen/rejected pairs, as produced by
      ``teacher_replay.extract_dpo_pairs``.
    """

    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: list[dict] | None
| # --------------------------------------------------------------------------- | |
| # Tokenizer protocol — duck-typed against HF AutoTokenizer | |
| # --------------------------------------------------------------------------- | |
class TokenizerLike(Protocol):
    """Structural (duck-typed) interface the collator needs from a tokenizer.

    HuggingFace ``AutoTokenizer`` instances are the typical implementation,
    and simpler stubs satisfy it for unit tests. It is declared as a
    ``typing.Protocol`` so that static type checkers accept any object with
    these members. A plain base class would demand nominal inheritance,
    which HF tokenizers never have.
    """

    pad_token_id: int

    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]:  # pragma: no cover
        """Tokenize ``text``; the result must expose an ``input_ids`` entry."""
        ...

    def apply_chat_template(  # pragma: no cover
        self, messages: list[dict], **kwargs: Any
    ) -> str | list[int]:
        """Render (and optionally tokenize) a list of chat messages."""
        ...
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
@dataclass
class CollatorConfig:
    """Tunables for ComposerDataCollator.

    Declared as a dataclass so individual fields can be overridden by
    keyword, e.g. ``CollatorConfig(max_seq_len=8192, hint_generator=fn)``.
    Without the decorator the class has no keyword-accepting ``__init__``.
    """

    # Length caps, in tokens.
    max_seq_len: int = 4096  # cap for trace / teacher / aligned-student sequences
    max_dpo_seq_len: int = 2048  # cap for DPO chosen/rejected sequences
    # Id used for right-padding. NOTE: it is not read from the tokenizer;
    # keep it in sync with tokenizer.pad_token_id manually.
    pad_token_id: int = 0
    ignore_index: int = -100  # standard HF "ignore in loss" sentinel

    # SDPO behavior
    enable_sdpo: bool = True
    # Callable (error_kind, error_meta) -> hint_text, or None to skip that site.
    hint_generator: Callable[[str, dict], str | None] | None = None

    # Trace-replay DPO behavior
    enable_replay_dpo: bool = True

    # Reward shaping: key of the TraceExample entry read as the RLVR reward.
    rlvr_reward_key: str = "final_reward"
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| def _is_error_turn(turn: TraceTurn) -> bool: | |
| """Predicate: is this turn an error site that should trigger SDPO?""" | |
| return turn.get("tool_error") is not None | |
| def _build_chat_messages(turns: Sequence[TraceTurn]) -> list[dict]: | |
| """Convert TraceTurns to OpenAI-style chat messages for tokenizer.apply_chat_template.""" | |
| return [ | |
| {"role": t["role"], "content": t["content"]} | |
| for t in turns if t.get("content") | |
| ] | |
| def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]: | |
| """Right-pad with pad_id, or right-truncate to target_len.""" | |
| if len(seq) >= target_len: | |
| return seq[:target_len] | |
| return seq + [pad_id] * (target_len - len(seq)) | |
| # --------------------------------------------------------------------------- | |
| # The collator | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ComposerDataCollator:
    """Build trainer-ready batches from raw traces + optional DPO pairs.

    Usage:
        collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
        batch = collator([trace_example_0, trace_example_1, ...])
        # batch is a dict[str, torch.Tensor] ready for ComposerReplicationTrainer

    The dict contains:
        # Channel 1 (GRPO/RLVR — handled by the parent GRPOTrainer)
        - input_ids:      (B, T_max)
        - attention_mask: (B, T_max), 1 over real tokens, 0 over right-padding
        - response_mask:  (B, T_max)
        - rewards:        (B,)
        # Channel 2 (SDPO hint-distill) — present when any example has error turns
        - ctx_teacher_input_ids: (B, T_max)
        - sdpo_loss_mask:        (B, T_max), 1 at post-hint error-turn tokens,
                                 config.ignore_index elsewhere
        # Channel 3 (trace-replay DPO) — present when any example has dpo_pairs
        - dpo_chosen_input_ids:       (B', T_dpo)
        - dpo_chosen_response_mask:   (B', T_dpo)
        - dpo_rejected_input_ids:     (B', T_dpo)
        - dpo_rejected_response_mask: (B', T_dpo)
        # ref_logprobs are NOT computed here — the trainer's reference-policy
        # forward pass at training time produces them.

    When Channel 2 is present, input_ids / attention_mask / response_mask are
    replaced by a chat-templated student sequence that is position-aligned
    with ctx_teacher_input_ids (see `_build_aligned_student_for_sdpo`).
    """

    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Collate ``batch`` into the tensor dict described in the class docstring.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        out: dict[str, torch.Tensor] = {}

        # --- Channel 1: GRPO core fields ---
        out.update(self._build_grpo_fields(batch))

        # --- Channel 2: SDPO hint-distill fields ---
        if self.config.enable_sdpo and self.config.hint_generator is not None:
            # Generate hints ONCE and share them between the teacher and the
            # student builders. If each side called hint_generator on its
            # own, a stochastic generator could emit different hints. The
            # student placeholder would then differ in length from the
            # teacher hint, shifting every post-hint position.
            hints = [self._generate_hints(ex["turns"]) for ex in batch]
            sdpo = self._build_sdpo_fields(batch, hints=hints)
            if sdpo is not None:
                out.update(sdpo)
                # CRITICAL: hint injection adds tokens IN THE MIDDLE of the
                # teacher sequence, before the recovery turn. Right-padding
                # the Channel-1 student to teacher length would ALIAS PAD
                # TOKENS onto the sdpo_loss_mask region. That gives a
                # degenerate ~ln(2) JSD signal that looks healthy but is
                # meaningless (Gemini W19 R1 BLOCKER). The student is instead
                # rebuilt with a same-length placeholder where the teacher
                # has the hint, so post-hint tokens share indices.
                aligned = self._build_aligned_student_for_sdpo(
                    batch,
                    teacher_len=out["ctx_teacher_input_ids"].shape[1],
                    hints=hints,
                )
                if aligned is not None:
                    out.update(aligned)

        # --- Channel 3: trace-replay DPO fields ---
        if self.config.enable_replay_dpo:
            dpo = self._build_dpo_fields(batch)
            if dpo is not None:
                out.update(dpo)
        return out

    # ----------------------------------------------------------------------
    # Shared padding helper
    # ----------------------------------------------------------------------

    def _pad_batch(
        self, seqs: Sequence[list[int]], max_len: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Right-pad/truncate ``seqs`` to ``max_len``; return (input_ids, attention_mask).

        The attention mask comes from each sequence's true length, NOT from
        ``input_ids != pad_token_id``. The pad id (0 by default, or eos when
        eos doubles as pad) is commonly also a real vocabulary id, so
        comparing against it would silently mask genuine tokens.
        """
        pad_id = self.config.pad_token_id
        input_ids = torch.tensor(
            [_pad_or_truncate(s, max_len, pad_id) for s in seqs],
            dtype=torch.long,
        )
        attention_mask = torch.tensor(
            [_pad_or_truncate([1] * min(len(s), max_len), max_len, 0) for s in seqs],
            dtype=torch.long,
        )
        return input_ids, attention_mask

    # ----------------------------------------------------------------------
    # Channel 1: standard GRPO inputs
    # ----------------------------------------------------------------------

    def _build_grpo_fields(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Tokenize every trace's turn contents and pad them to one length.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("ComposerDataCollator received an empty batch")
        input_ids_list: list[list[int]] = []
        response_masks_list: list[list[int]] = []
        rewards: list[float] = []
        for ex in batch:
            ids, resp_mask = self._tokenize_trace(ex["turns"])
            input_ids_list.append(ids)
            response_masks_list.append(resp_mask)
            rewards.append(float(ex.get(self.config.rlvr_reward_key, 0.0)))

        max_len = min(self.config.max_seq_len, max(len(s) for s in input_ids_list))
        input_ids, attention_mask = self._pad_batch(input_ids_list, max_len)
        response_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, 0) for m in response_masks_list],
            dtype=torch.long,
        )
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "response_mask": response_mask,
            "rewards": torch.tensor(rewards, dtype=torch.float),
        }

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill inputs
    # ----------------------------------------------------------------------

    def _generate_hints(self, turns: Sequence[TraceTurn]) -> list[str | None]:
        """Return one entry per turn: the hint text for error turns, else None.

        A falsy hint_generator result (None or "") is normalised to None,
        meaning "no hint at this error site".
        """
        generator = self.config.hint_generator
        hints: list[str | None] = []
        for turn in turns:
            hint = None
            if generator is not None and _is_error_turn(turn):
                hint = generator(
                    turn.get("tool_error", "unknown"),
                    turn.get("error_meta", {}),
                ) or None
            hints.append(hint)
        return hints

    def _build_sdpo_fields(
        self,
        batch: Sequence[TraceExample],
        hints: Sequence[Sequence[str | None]] | None = None,
    ) -> dict[str, torch.Tensor] | None:
        """Build ctx_teacher + sdpo_loss_mask for the batch.

        Args:
            batch: the examples being collated.
            hints: optional per-example hint lists from `_generate_hints`;
                generated here when omitted.

        Returns:
            ``{"ctx_teacher_input_ids", "sdpo_loss_mask"}`` padded to the
            longest teacher sequence (capped at max_seq_len). Returns None
            when no hint generator is configured or no example has a hinted
            error site.
        """
        if self.config.hint_generator is None:
            return None  # nothing to do without a hint generator
        if hints is None:
            hints = [self._generate_hints(ex["turns"]) for ex in batch]

        ctx_teacher_list: list[list[int]] = []
        sdpo_mask_list: list[list[int]] = []
        any_error_sites = False
        for ex, ex_hints in zip(batch, hints):
            teacher_ids, sdpo_mask, has_errors = self._build_hint_injected_trace(
                ex["turns"], hints=ex_hints
            )
            ctx_teacher_list.append(teacher_ids)
            sdpo_mask_list.append(sdpo_mask)
            any_error_sites = any_error_sites or has_errors
        if not any_error_sites:
            return None  # batch has no error sites — SDPO is a no-op for this step

        max_len = min(self.config.max_seq_len, max(len(s) for s in ctx_teacher_list))
        ctx_teacher = torch.tensor(
            [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in ctx_teacher_list],
            dtype=torch.long,
        )
        sdpo_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, self.config.ignore_index) for m in sdpo_mask_list],
            dtype=torch.long,
        )
        return {
            "ctx_teacher_input_ids": ctx_teacher,
            "sdpo_loss_mask": sdpo_mask,
        }

    def _build_hint_injected_trace(
        self,
        turns: Sequence[TraceTurn],
        hints: Sequence[str | None] | None = None,
    ) -> tuple[list[int], list[int], bool]:
        """Build the TEACHER context: the trace with a hint system-message
        injected immediately before each hinted error turn.

        The content tokens of each hinted error turn are in-loss (1); every
        other position, including chat-template markers, gets
        config.ignore_index.

        Returns:
            (ctx_teacher_ids, sdpo_loss_mask, any_error_sites). The mask has
            exactly len(ctx_teacher_ids) entries.
        """
        if self.config.hint_generator is None:
            # Caller responsibility — short-circuited by the dispatch.
            return [], [], False
        if hints is None:
            hints = self._generate_hints(turns)

        messages: list[dict] = []
        in_loss: list[bool] = []
        any_errors = False
        for turn, hint in zip(turns, hints):
            if hint is not None:
                any_errors = True
                # Hint goes in as a system-style addendum BEFORE the error turn.
                messages.append({"role": "system", "content": hint})
                in_loss.append(False)
            if turn.get("content"):
                messages.append({
                    "role": turn.get("role", "assistant"),
                    "content": turn["content"],
                })
                # Only the content that follows an injected hint is distilled.
                in_loss.append(hint is not None)

        teacher_ids, spans = self._tokenize_with_message_spans(messages)
        sdpo_mask = self._mask_from_spans(
            teacher_ids, spans, messages, in_loss,
            on=1, off=self.config.ignore_index,
        )
        return teacher_ids, sdpo_mask, any_errors

    def _build_aligned_student_for_sdpo(
        self,
        batch: Sequence[TraceExample],
        teacher_len: int,
        hints: Sequence[Sequence[str | None]] | None = None,
    ) -> dict[str, torch.Tensor] | None:
        """Build student input_ids that align position-by-position with the
        hint-injected teacher sequence.

        For SDPO, the gate `student_logits.shape == teacher_logits.shape`
        must pass. The sdpo_loss_mask positions (built relative to the
        teacher) must also point to the SAME content tokens in the student.

        Strategy: the student messages mirror the teacher messages, except
        that each hint system-message becomes a placeholder system-message.
        The placeholder's content tokenizes to (approximately) the hint's
        length. Both sides go through `apply_chat_template`, so template
        markers are added identically and the recovery-turn tokens land at
        the same indices.

        Args:
            batch: the examples being collated.
            teacher_len: padded teacher length; the student is padded or
                truncated to exactly this.
            hints: optional per-example hint lists (must match the teacher's).

        Returns None if no error sites exist.
        """
        if self.config.hint_generator is None:
            return None
        if hints is None:
            hints = [self._generate_hints(ex["turns"]) for ex in batch]

        student_ids_list: list[list[int]] = []
        response_mask_list: list[list[int]] = []
        any_errors = False
        for ex, ex_hints in zip(batch, hints):
            ids, resp_mask, has_errors = self._build_aligned_student_one(
                ex["turns"], hints=ex_hints
            )
            student_ids_list.append(ids)
            response_mask_list.append(resp_mask)
            any_errors = any_errors or has_errors
        if not any_errors:
            return None

        # Match the teacher length exactly so the logits shapes agree.
        input_ids, attention_mask = self._pad_batch(student_ids_list, teacher_len)
        response_mask = torch.tensor(
            [_pad_or_truncate(m, teacher_len, 0) for m in response_mask_list],
            dtype=torch.long,
        )
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "response_mask": response_mask,
        }

    def _make_placeholder_for_hint_length(self, hint_text: str) -> str:
        """Build a placeholder string whose tokenization length matches hint_text's.

        Starts with a repeating filler ('. ') and grows or trims it toward
        the hint's token count, with at most 6 refinement rounds. This is
        necessarily approximate at the character-to-token boundary. The
        final student tensor is padded/truncated to the teacher length
        regardless.
        """
        target_len = len(self._tokenize_text(hint_text))
        if target_len == 0:
            return ""
        # Use a content-free placeholder that tokenizes predictably.
        placeholder = ". " * target_len
        ph_len = len(self._tokenize_text(placeholder))
        for _ in range(6):
            if ph_len == target_len:
                break
            if ph_len > target_len:
                # Trim char-by-char until at or below the target.
                while placeholder and ph_len > target_len:
                    placeholder = placeholder[:-1]
                    ph_len = len(self._tokenize_text(placeholder))
            else:
                placeholder = placeholder + ". "
                ph_len = len(self._tokenize_text(placeholder))
        return placeholder

    def _build_aligned_student_one(
        self,
        turns: Sequence[TraceTurn],
        hints: Sequence[str | None] | None = None,
    ) -> tuple[list[int], list[int], bool]:
        """Build the STUDENT context for one trace.

        The messages mirror the teacher's, except hint system-messages are
        replaced with placeholder system-messages of matching token length.
        response_mask is 1 over the content tokens of assistant turns and
        0 elsewhere (including template markers and placeholders).

        Returns (student_ids, response_mask, any_error_sites).
        """
        if self.config.hint_generator is None:
            return [], [], False
        if hints is None:
            hints = self._generate_hints(turns)

        messages: list[dict] = []
        is_response: list[bool] = []
        any_errors = False
        for turn, hint in zip(turns, hints):
            if hint is not None:
                any_errors = True
                # Placeholder occupies the SAME slot as the teacher's hint.
                messages.append({
                    "role": "system",
                    "content": self._make_placeholder_for_hint_length(hint),
                })
                is_response.append(False)
            if turn.get("content"):
                messages.append({
                    "role": turn.get("role", "assistant"),
                    "content": turn["content"],
                })
                is_response.append(turn.get("role") == "assistant")

        student_ids, spans = self._tokenize_with_message_spans(messages)
        resp_mask = self._mask_from_spans(
            student_ids, spans, messages, is_response, on=1, off=0
        )
        return student_ids, resp_mask, any_errors

    def _tokenize_with_message_spans(
        self, messages: Sequence[dict]
    ) -> tuple[list[int], list[tuple[int, int]]]:
        """Tokenize ``messages`` as one conversation and locate each message.

        Returns (ids, spans). spans[i] = (start, end) is the half-open token
        range contributed by messages[i], including any template markers
        emitted for it. Spans come from prefix tokenization: message i ends
        where the tokenization of messages[:i + 1] ends.

        This costs one extra `_tokenize_messages` call per message. Unlike
        tokenizing each message's text in isolation, though, it accounts
        for role headers, end-of-turn markers, or an injected default system
        prompt.
        """
        ids = self._tokenize_messages(messages)
        spans: list[tuple[int, int]] = []
        start = 0
        for i in range(1, len(messages) + 1):
            if i == len(messages):
                end = len(ids)  # full conversation — no extra call needed
            else:
                end = len(self._tokenize_messages(messages[:i]))
            # Clamp in case a template renders a prefix differently from the
            # full conversation: spans must stay ordered and in-bounds.
            end = min(max(end, start), len(ids))
            spans.append((start, end))
            start = end
        return ids, spans

    def _mask_from_spans(
        self,
        ids: Sequence[int],
        spans: Sequence[tuple[int, int]],
        messages: Sequence[dict],
        active: Sequence[bool],
        on: int,
        off: int,
    ) -> list[int]:
        """Return a per-token mask over ``ids``.

        The mask is ``on`` at the content tokens of each active message and
        ``off`` everywhere else. Its length is always len(ids).
        """
        mask = [off] * len(ids)
        for (start, end), message, is_active in zip(spans, messages, active):
            if not is_active or end <= start:
                continue
            lo, hi = self._locate_content(ids, start, end, message["content"])
            mask[lo:hi] = [on] * (hi - lo)
        return mask

    def _locate_content(
        self, ids: Sequence[int], start: int, end: int, content: str
    ) -> tuple[int, int]:
        """Find the token range of ``content`` inside the span [start, end).

        Searches for the standalone tokenization of ``content`` as a
        contiguous run inside the span, so role headers and end markers stay
        out of the mask. If the in-template tokenization differs (e.g. BPE
        merges across the header boundary), falls back to the whole span.
        """
        content_ids = self._tokenize_text(content)
        width = len(content_ids)
        window = list(ids[start:end])
        if 0 < width <= len(window):
            for offset in range(len(window) - width + 1):
                if window[offset:offset + width] == content_ids:
                    return start + offset, start + offset + width
        return start, end

    def _build_segment_mask(
        self, segments: Sequence[tuple[bool, str]]
    ) -> list[int]:
        """For each (is_loss, text) segment, tokenize and emit per-token mask values.

        Loss-active tokens get 1; non-loss tokens get config.ignore_index.
        NOTE: the segments are tokenized in isolation, so the result does not
        account for chat-template markers. It is no longer used to build
        masks for templated sequences (see `_tokenize_with_message_spans`).
        """
        out: list[int] = []
        for is_loss, text in segments:
            seg_ids = self._tokenize_text(text)
            mask_value = 1 if is_loss else self.config.ignore_index
            out.extend([mask_value] * len(seg_ids))
        return out

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO inputs
    # ----------------------------------------------------------------------

    def _build_dpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Tokenize chosen/rejected pairs from teacher disagreement.

        DPO accounting requires:
          - chosen_input_ids   = prompt + chosen_response
          - rejected_input_ids = prompt + rejected_response
          - response masks marking response (loss-bearing) vs prompt (no loss)

        Returns None when no example in the batch carries dpo_pairs.
        NOTE(review): sequences are right-truncated at max_dpo_seq_len, so a
        prompt longer than the cap leaves no response tokens in the mask.
        """
        all_chosen: list[list[int]] = []
        all_rejected: list[list[int]] = []
        all_chosen_resp_mask: list[list[int]] = []
        all_rejected_resp_mask: list[list[int]] = []
        for ex in batch:
            for pair in ex.get("dpo_pairs") or []:
                prompt_msgs = pair.get("state_messages", [])
                # add_generation_prompt=True appends the assistant-turn header,
                # so the response tokens follow it as they would at generation time.
                prompt_ids = self._tokenize_messages(prompt_msgs, add_generation_prompt=True)
                chosen_ids = self._tokenize_text(pair["chosen"])
                rejected_ids = self._tokenize_text(pair["rejected"])
                all_chosen.append(prompt_ids + chosen_ids)
                all_rejected.append(prompt_ids + rejected_ids)
                # response_mask is 0 over prompt, 1 over response
                all_chosen_resp_mask.append([0] * len(prompt_ids) + [1] * len(chosen_ids))
                all_rejected_resp_mask.append([0] * len(prompt_ids) + [1] * len(rejected_ids))
        if not all_chosen:
            return None  # no DPO pairs in this batch

        cap = self.config.max_dpo_seq_len
        max_len = min(cap, max(len(s) for s in (*all_chosen, *all_rejected)))
        pad_id = self.config.pad_token_id
        return {
            "dpo_chosen_input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, pad_id) for s in all_chosen],
                dtype=torch.long,
            ),
            "dpo_chosen_response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in all_chosen_resp_mask],
                dtype=torch.long,
            ),
            "dpo_rejected_input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, pad_id) for s in all_rejected],
                dtype=torch.long,
            ),
            "dpo_rejected_response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in all_rejected_resp_mask],
                dtype=torch.long,
            ),
        }

    # ----------------------------------------------------------------------
    # Tokenization helpers
    # ----------------------------------------------------------------------

    def _tokenize_trace(self, turns: Sequence[TraceTurn]) -> tuple[list[int], list[int]]:
        """Tokenize an entire trace's turn contents (no chat template).

        Returns (ids, response_mask): response_mask is 1 over assistant
        turns (the loss-bearing tokens for GRPO) and 0 over user/tool turns
        (prompt context). Turns with empty content are skipped.
        """
        all_ids: list[int] = []
        resp_mask: list[int] = []
        for turn in turns:
            if not turn.get("content"):
                continue
            ids = self._tokenize_text(turn["content"])
            mask_value = 1 if turn.get("role") == "assistant" else 0
            all_ids.extend(ids)
            resp_mask.extend([mask_value] * len(ids))
        return all_ids, resp_mask

    def _tokenize_text(self, text: str) -> list[int]:
        """Tokenize plain text via the tokenizer's __call__ (no special tokens)."""
        result = self.tokenizer(text, add_special_tokens=False)
        ids = result["input_ids"]
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # HF tokenizers often return list[list[int]] when batch-shaped; flatten if so
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)

    def _tokenize_messages(
        self, messages: Sequence[dict], add_generation_prompt: bool = False
    ) -> list[int]:
        """Tokenize a chat-formatted list of messages.

        Tries apply_chat_template first. It falls back to tokenizing the
        newline-joined contents when the tokenizer has no such method or
        no chat template is configured. HF raises ValueError in the latter
        case, so ValueError is caught as well.

        NOTE: HF tokenizers' `apply_chat_template(tokenize=True)` is not
        consistently typed across families. Some return `list[int]`, others
        a `BatchEncoding` (a dict-like with an `input_ids` key); Qwen2.5
        returns the latter. Stubs may return a rendered string. All three
        shapes are handled here.

        Args:
            messages: OpenAI-style ``{"role", "content"}`` dicts.
            add_generation_prompt: forwarded to apply_chat_template; ignored
                by the plain-text fallback.
        """
        if not messages:
            return []
        try:
            raw = self.tokenizer.apply_chat_template(
                list(messages), tokenize=True, add_generation_prompt=add_generation_prompt
            )
        except (AttributeError, NotImplementedError, TypeError, ValueError):
            # Stub tokenizer or no chat template defined — fall back to concatenated content
            text = "\n".join(m.get("content", "") for m in messages)
            return self._tokenize_text(text)
        if isinstance(raw, str):
            # Rendered but not tokenized; iterating it would yield characters.
            return self._tokenize_text(raw)
        # BatchEncoding (Qwen2.5 etc.): extract input_ids and unwrap if batched.
        if hasattr(raw, "keys") and "input_ids" in raw:
            ids = raw["input_ids"]
        else:
            ids = raw
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # If we got list[list[int]] (batch shape), unwrap the single example.
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)
# Public API of this module (the names exported by a star-import).
__all__ = ["ComposerDataCollator", "CollatorConfig", "TraceTurn", "TraceExample", "TokenizerLike"]