Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto") - Notebooks
- Google Colab
- Kaggle
File size: 26,895 Bytes
ac05fbf 03bf323 ac05fbf 03bf323 ac05fbf 03bf323 ac05fbf 03bf323 ac05fbf 03bf323 ac05fbf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 
505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 | """data_collator.py — ComposerDataCollator: raw trace → trainer-ready batch.
Pipeline:
1. Take a frozen agentic trace + N-teacher DPO pairs (from spike 002 + 003).
2. Tokenize each turn of the trace.
3. Detect error sites (turns where a tool call failed) using a configurable predicate.
4. At each error site, build ctx_teacher = ctx_student with hint inserted at the error-turn boundary.
5. Pad/align ctx_student and ctx_teacher so SDPO logits compare position-by-position.
6. Construct sdpo_loss_mask = 1 at post-hint tokens of the error turn, ignore_index (-100 by default) elsewhere.
7. Tokenize DPO chosen/rejected pairs and build response masks; ref_logprobs are not computed here but by the trainer's reference-policy forward pass.
The output dict is what `ComposerReplicationTrainer._compute_loss` expects in its
`inputs` argument. See `trl_path/composer_trainer.py` for the consumer side.
Architectural note (verified via spike 005 test_opsd_loss.py): generalized_jsd_loss
requires student_logits and teacher_logits to have the SAME (B, T, V) shape — that's
why we pad/align here rather than inside the loss function. The post-hint section of
ctx_teacher must have token-by-token alignment with the same section of ctx_student.
"""
from __future__ import annotations

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, TypedDict

import torch
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
class TraceTurn(TypedDict, total=False):
    """One turn of an agentic trace.

    All keys are optional (``total=False``). Turns whose ``content`` is empty
    or missing are skipped when the collator tokenizes a trace.
    """

    role: str  # "user" | "assistant" | "tool"; only "assistant" turns are response tokens
    content: str  # text or tool result
    tool_call: dict | None  # parsed tool call, if assistant-issued (not read by the collator)
    tool_error: str | None  # error_kind from the env, e.g. "tool_not_found"; non-None marks an SDPO error site
    error_meta: dict  # extra info for hint generator (available_tools, etc.)
class TraceExample(TypedDict, total=False):
    """One training example: a (trace, optional DPO pairs) tuple.

    Although the TypedDict is ``total=False``, the collator indexes
    ``turns`` directly, so it is effectively required.
    """

    trace_id: str  # identifier only; not read by the collator
    turns: list[TraceTurn]
    # RLVR scalar (test-pass etc.) at trajectory end; the key actually read is
    # CollatorConfig.rlvr_reward_key, which defaults to this name.
    final_reward: float
    # From teacher_replay.extract_dpo_pairs. Each dict is read for
    # "state_messages" (optional prompt messages), "chosen" and "rejected".
    dpo_pairs: list[dict] | None
# ---------------------------------------------------------------------------
# Tokenizer protocol — duck-typed against HF AutoTokenizer
# ---------------------------------------------------------------------------
class TokenizerLike(Protocol):
    """Structural (duck-typed) protocol the collator needs from a tokenizer.

    Declared as a ``typing.Protocol`` so static type checkers accept any
    object with these members without it inheriting from this class. That
    covers HuggingFace ``AutoTokenizer`` instances (the typical case) as well
    as simpler stubs for unit-testing. A plain base class would make real
    tokenizers fail a nominal type check.

    ``pad_token_id`` is part of the expected surface. Note that the collator
    pads with ``CollatorConfig.pad_token_id``, not with this attribute.
    """

    pad_token_id: int

    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]:  # pragma: no cover
        """Tokenize *text*; the result must expose an ``input_ids`` entry."""
        ...

    def apply_chat_template(  # pragma: no cover
        self, messages: list[dict], **kwargs: Any
    ) -> str | list[int]:
        """Render chat *messages*; the collator calls it with ``tokenize=True``."""
        ...
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class CollatorConfig:
    """Tunables for ComposerDataCollator.

    Every field has a default, so ``CollatorConfig()`` is a valid configuration.
    """

    # Cap (in tokens) on the padded sequence length of the GRPO and SDPO tensors.
    max_seq_len: int = 4096
    # Cap (in tokens) on the padded sequence length of the trace-replay DPO tensors.
    max_dpo_seq_len: int = 2048
    # Value used to right-pad input_ids. The collator pads with this value,
    # not with tokenizer.pad_token_id, so keep the two in sync.
    pad_token_id: int = 0
    ignore_index: int = -100  # standard HF "ignore in loss" sentinel; fills non-loss slots of sdpo_loss_mask
    # SDPO behavior
    # SDPO tensors are produced only when enable_sdpo is True AND hint_generator
    # is set. A falsy hint (None or "") leaves that error turn without a hint.
    enable_sdpo: bool = True
    hint_generator: Callable[[str, dict], str | None] | None = None
    """Callable error_kind, error_meta -> hint_text (or None to skip)."""
    # Trace-replay DPO behavior
    # When True, each example's dpo_pairs are tokenized into the dpo_* tensors.
    enable_replay_dpo: bool = True
    # Reward shaping
    # Key read from each TraceExample for its scalar reward (missing key -> 0.0).
    rlvr_reward_key: str = "final_reward"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _is_error_turn(turn: TraceTurn) -> bool:
"""Predicate: is this turn an error site that should trigger SDPO?"""
return turn.get("tool_error") is not None
def _build_chat_messages(turns: Sequence[TraceTurn]) -> list[dict]:
"""Convert TraceTurns to OpenAI-style chat messages for tokenizer.apply_chat_template."""
return [
{"role": t["role"], "content": t["content"]}
for t in turns if t.get("content")
]
def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]:
"""Right-pad with pad_id, or right-truncate to target_len."""
if len(seq) >= target_len:
return seq[:target_len]
return seq + [pad_id] * (target_len - len(seq))
# ---------------------------------------------------------------------------
# The collator
# ---------------------------------------------------------------------------
@dataclass
class ComposerDataCollator:
    """Build trainer-ready batches from raw traces + optional DPO pairs.

    Usage:
        collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
        batch = collator([trace_example_0, trace_example_1, ...])
        # batch is a dict[str, torch.Tensor] ready for ComposerReplicationTrainer

    The dict contains:
        # Channel 1 (GRPO/RLVR — handled by the parent GRPOTrainer)
        - input_ids:       (B, T_max)
        - attention_mask:  (B, T_max), 1 over each row's real tokens, 0 over padding
        - response_mask:   (B, T_max), 1 over assistant-turn tokens
        - rewards:         (B,)
        # Channel 2 (SDPO hint-distill) — present when config.hint_generator is
        # set and any example has an error turn that received a hint
        - ctx_teacher_input_ids: (B, T_max)
        - sdpo_loss_mask:  (B, T_max), 1 at post-hint error-turn tokens,
                           config.ignore_index everywhere else (incl. padding)
        #   When channel 2 is present, input_ids / attention_mask / response_mask
        #   are REPLACED by a student view that is position-aligned with
        #   ctx_teacher_input_ids: same length, with hint tokens swapped for filler.
        # Channel 3 (trace-replay DPO) — present when any example has dpo_pairs
        - dpo_chosen_input_ids:       (B', T_dpo)
        - dpo_chosen_response_mask:   (B', T_dpo)
        - dpo_rejected_input_ids:     (B', T_dpo)
        - dpo_rejected_response_mask: (B', T_dpo)
        # ref_logprobs are NOT computed here — the trainer's reference-policy
        # forward pass at training time produces them.
    """

    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Collate *batch* into the tensor dict described in the class docstring."""
        out: dict[str, torch.Tensor] = {}

        # --- Channel 1: GRPO core fields ---
        out.update(self._build_grpo_fields(batch))

        # --- Channel 2: SDPO hint-distill fields ---
        if self.config.enable_sdpo:
            # Teacher and student are built in ONE pass, so the hint generator
            # runs once per error turn and both views share its output.
            sdpo = self._build_sdpo_channel(batch)
            if sdpo is not None:
                # Overwrites the Channel-1 input_ids / attention_mask /
                # response_mask with the teacher-aligned student view.
                out.update(sdpo)

        # --- Channel 3: trace-replay DPO fields ---
        if self.config.enable_replay_dpo:
            dpo = self._build_dpo_fields(batch)
            if dpo is not None:
                out.update(dpo)
        return out

    # ----------------------------------------------------------------------
    # Tensor helpers
    # ----------------------------------------------------------------------
    @staticmethod
    def _batch_len(rows: Sequence[Sequence[int]], cap: int) -> int:
        """Longest row length clipped to ``[0, cap]``; 0 for an empty batch."""
        return max(0, min(cap, max((len(r) for r in rows), default=0)))

    @staticmethod
    def _pad_rows(
        rows: Sequence[list[int]], target_len: int, pad_value: int
    ) -> torch.Tensor:
        """Right-pad/truncate every row to *target_len*.

        Returns a ``(len(rows), target_len)`` long tensor.
        """
        padded = [_pad_or_truncate(list(r), target_len, pad_value) for r in rows]
        # reshape keeps the result 2-D even for an empty batch
        # (torch.tensor([]) alone would be 1-D).
        return torch.tensor(padded, dtype=torch.long).reshape(len(padded), target_len)

    @staticmethod
    def _length_mask(rows: Sequence[Sequence[int]], target_len: int) -> torch.Tensor:
        """Build an attention mask from each row's true length (1 = real token, 0 = padding).

        The mask is derived from lengths rather than ``input_ids != pad_token_id``.
        Otherwise real tokens whose id equals the pad id would be masked out,
        e.g. the default pad id 0, which is an ordinary vocabulary entry in
        many tokenizers, or pad == eos setups.
        """
        lengths = torch.tensor([min(len(r), target_len) for r in rows], dtype=torch.long)
        positions = torch.arange(target_len, dtype=torch.long)
        return (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

    # ----------------------------------------------------------------------
    # Channel 1: standard GRPO inputs
    # ----------------------------------------------------------------------
    def _build_grpo_fields(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Tokenize each trace (raw turn contents, no chat template) for GRPO."""
        input_ids_list: list[list[int]] = []
        response_masks_list: list[list[int]] = []
        rewards: list[float] = []
        for ex in batch:
            ids, resp_mask = self._tokenize_trace(ex["turns"])
            input_ids_list.append(ids)
            response_masks_list.append(resp_mask)
            rewards.append(float(ex.get(self.config.rlvr_reward_key, 0.0)))

        max_len = self._batch_len(input_ids_list, self.config.max_seq_len)
        return {
            "input_ids": self._pad_rows(input_ids_list, max_len, self.config.pad_token_id),
            "attention_mask": self._length_mask(input_ids_list, max_len),
            "response_mask": self._pad_rows(response_masks_list, max_len, 0),
            "rewards": torch.tensor(rewards, dtype=torch.float),
        }

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill inputs
    # ----------------------------------------------------------------------
    def _build_sdpo_channel(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Build teacher AND aligned student tensors for SDPO in a single pass.

        Why the student cannot simply be a hint-free tokenization right-padded
        to teacher length: hint injection adds tokens IN THE MIDDLE of the
        teacher sequence. The recovery turn therefore sits later in the teacher
        than in a hint-free student. Right-padding would alias pad tokens into
        the sdpo_loss_mask region and give a degenerate ~ln(2) JSD signal that
        looks healthy but is meaningless (Gemini W19 R1 BLOCKER).

        Instead, the student is the teacher sequence with the hint CONTENT
        tokens replaced by filler. Every other token then sits at the same
        index in both tensors, which satisfies compose_loss's
        ``student_logits.shape == teacher_logits.shape`` gate by construction.

        Returns None when there is no hint generator or no example received a hint.
        """
        if self.config.hint_generator is None:
            return None  # nothing to do without a hint generator
        examples = [self._build_sdpo_example(ex["turns"]) for ex in batch]
        if not any(has_errors for *_, has_errors in examples):
            return None  # batch has no error sites — SDPO is a no-op for this step

        teacher_rows = [e[0] for e in examples]
        student_rows = [e[2] for e in examples]
        max_len = self._batch_len(teacher_rows, self.config.max_seq_len)
        pad_id = self.config.pad_token_id
        return {
            "ctx_teacher_input_ids": self._pad_rows(teacher_rows, max_len, pad_id),
            "sdpo_loss_mask": self._pad_rows(
                [e[1] for e in examples], max_len, self.config.ignore_index
            ),
            "input_ids": self._pad_rows(student_rows, max_len, pad_id),
            "attention_mask": self._length_mask(student_rows, max_len),
            "response_mask": self._pad_rows([e[3] for e in examples], max_len, 0),
        }

    def _build_sdpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Return only the teacher-side SDPO tensors (ctx_teacher_input_ids, sdpo_loss_mask).

        Kept for callers that want these two fields alone. ``__call__`` uses
        ``_build_sdpo_channel`` so that the student is built from the same hints.
        """
        channel = self._build_sdpo_channel(batch)
        if channel is None:
            return None
        return {key: channel[key] for key in ("ctx_teacher_input_ids", "sdpo_loss_mask")}

    def _build_hint_injected_trace(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], bool]:
        """Teacher view of one trace.

        Returns:
            (ctx_teacher_ids, sdpo_loss_mask, any_error_sites)
        """
        teacher_ids, sdpo_mask, _, _, has_errors = self._build_sdpo_example(turns)
        return teacher_ids, sdpo_mask, has_errors

    def _build_aligned_student_for_sdpo(
        self,
        batch: Sequence[TraceExample],
        teacher_len: int,
    ) -> dict[str, torch.Tensor] | None:
        """Return student input_ids / attention_mask / response_mask padded to *teacher_len*.

        This re-runs the hint generator. The result therefore aligns with a
        teacher built in a separate call only if the generator returns
        same-length hints every time. ``__call__`` avoids the issue by using
        ``_build_sdpo_channel``.

        Returns None if no error sites exist.
        """
        if self.config.hint_generator is None:
            return None
        examples = [self._build_sdpo_example(ex["turns"]) for ex in batch]
        if not any(has_errors for *_, has_errors in examples):
            return None
        student_rows = [e[2] for e in examples]
        target_len = max(0, teacher_len)
        return {
            "input_ids": self._pad_rows(student_rows, target_len, self.config.pad_token_id),
            "attention_mask": self._length_mask(student_rows, target_len),
            "response_mask": self._pad_rows([e[3] for e in examples], target_len, 0),
        }

    def _build_aligned_student_one(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], bool]:
        """Student view of one trace: (student_ids, response_mask, any_error_sites)."""
        _, _, student_ids, response_mask, has_errors = self._build_sdpo_example(turns)
        return student_ids, response_mask, has_errors

    def _build_sdpo_example(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], list[int], list[int], bool]:
        """Build the position-aligned teacher and student views of one trace.

        The turns are walked once. At every error turn for which the hint
        generator returns a non-empty hint:
          - a ``system`` message carrying the hint is inserted BEFORE that
            turn's message;
          - that turn's content tokens become SDPO loss positions.

        The hint generator is called exactly once per error turn. The student
        view is derived from the resulting teacher tokens, so a
        non-deterministic generator cannot desynchronise the two.

        Returns:
            (teacher_ids, sdpo_mask, student_ids, student_response_mask, has_error_sites)
            All four lists have the same length.
              - sdpo_mask: 1 on error-turn content that follows a hint,
                ignore_index elsewhere.
              - student_ids: equal to teacher_ids except that hint content
                tokens are replaced by filler tokens.
              - student_response_mask: 1 on assistant-turn content, 0 elsewhere.
        """
        hint_generator = self.config.hint_generator
        if hint_generator is None:
            return [], [], [], [], False

        messages: list[dict] = []
        # Parallel to `messages`:
        #   kinds: "hint" | "loss" | "context"
        #   assistant_flags: whether the source turn is an assistant turn,
        #     which drives the student response mask.
        kinds: list[str] = []
        assistant_flags: list[bool] = []
        any_errors = False
        for turn in turns:
            is_loss_turn = False
            if _is_error_turn(turn):
                hint_text = hint_generator(
                    turn.get("tool_error", "unknown"),
                    turn.get("error_meta", {}),
                )
                if hint_text:
                    any_errors = True
                    is_loss_turn = True
                    # Hint goes in as a system-style addendum BEFORE the turn it helps.
                    messages.append({"role": "system", "content": hint_text})
                    kinds.append("hint")
                    assistant_flags.append(False)
            if turn.get("content"):
                messages.append({
                    "role": turn.get("role", "assistant"),
                    "content": turn["content"],
                })
                kinds.append("loss" if is_loss_turn else "context")
                assistant_flags.append(turn.get("role") == "assistant")

        teacher_ids, spans = self._tokenize_messages_with_spans(messages)
        sdpo_mask = [self.config.ignore_index] * len(teacher_ids)
        student_ids = list(teacher_ids)
        response_mask = [0] * len(teacher_ids)
        for message, kind, is_assistant, (start, end) in zip(
            messages, kinds, assistant_flags, spans
        ):
            lo, hi = self._locate_content(teacher_ids, start, end, message["content"])
            if kind == "hint":
                # Same length in, same length out: positions never shift.
                student_ids[lo:hi] = self._filler_ids(hi - lo)
                continue
            if kind == "loss":
                sdpo_mask[lo:hi] = [1] * (hi - lo)
            if is_assistant:
                response_mask[lo:hi] = [1] * (hi - lo)
        return teacher_ids, sdpo_mask, student_ids, response_mask, any_errors

    def _tokenize_messages_with_spans(
        self, messages: Sequence[dict]
    ) -> tuple[list[int], list[tuple[int, int]]]:
        """Tokenize *messages* and report the token span each message occupies.

        Span i is ``[len(tok(messages[:i])), len(tok(messages[:i + 1])))``.
        That is the set of tokens contributed by appending message i,
        chat-template markers included. Masks built from these spans line up
        with the ``apply_chat_template`` output. Masks built by tokenizing each
        message's content in isolation do not: they ignore the role markers and
        drift further off with every turn.

        Cost: one extra tokenization per message. Chat templates are not
        guaranteed to be prefix-stable, so spans are clamped to be contiguous,
        non-decreasing, and to end at ``len(ids)``.
        """
        ids = self._tokenize_messages(messages)
        total = len(ids)
        spans: list[tuple[int, int]] = []
        start = 0
        last = len(messages) - 1
        for i in range(len(messages)):
            end = total if i == last else len(self._tokenize_messages(messages[: i + 1]))
            end = min(max(end, start), total)
            spans.append((start, end))
            start = end
        return ids, spans

    def _locate_content(
        self, ids: list[int], start: int, end: int, content: str
    ) -> tuple[int, int]:
        """Return the ``[lo, hi)`` sub-span of ``ids[start:end]`` holding *content*'s tokens.

        Falls back to the whole ``[start, end)`` span when the content's
        stand-alone tokenization does not appear verbatim inside it, e.g.
        because of a BPE merge across the role-marker boundary.
        """
        content_ids = self._tokenize_text(content)
        width = len(content_ids)
        if 0 < width <= end - start:
            first = content_ids[0]
            for lo in range(start, end - width + 1):
                if ids[lo] == first and ids[lo : lo + width] == content_ids:
                    return lo, lo + width
        return start, end

    def _filler_ids(self, count: int) -> list[int]:
        """Return exactly *count* content-free token ids (a repeated ". " filler)."""
        if count <= 0:
            return []
        unit = self._tokenize_text(". ") or [self.config.pad_token_id]
        repeats = -(-count // len(unit))  # ceiling division
        return (unit * repeats)[:count]

    def _make_placeholder_for_hint_length(self, hint_text: str) -> str:
        """Build a placeholder string whose tokenization length approximates hint_text's.

        Not used by ``__call__``: student alignment now swaps hint tokens for
        filler tokens directly, which is exact. Retained for backward
        compatibility.

        Starts with a short repeating filler ('. ') and refines it for at
        most 6 rounds. Overshoots are trimmed char-by-char and undershoots
        are grown by one '. ' per round. The result can still differ from
        the hint's token length.
        """
        target_len = len(self._tokenize_text(hint_text))
        if target_len == 0:
            return ""
        placeholder = ". " * target_len
        ph_len = len(self._tokenize_text(placeholder))
        for _ in range(6):
            if ph_len == target_len:
                break
            if ph_len > target_len:
                while placeholder and ph_len > target_len:
                    placeholder = placeholder[:-1]
                    ph_len = len(self._tokenize_text(placeholder))
            else:
                placeholder = placeholder + ". "
                ph_len = len(self._tokenize_text(placeholder))
        return placeholder

    def _build_segment_mask(
        self, segments: Sequence[tuple[bool, str]]
    ) -> list[int]:
        """For each (is_loss, text) segment, tokenize and emit per-token mask values.

        Loss-active tokens get 1; non-loss tokens get ignore_index.

        Note: each segment is tokenized in isolation, so the result does NOT
        account for chat-template markers. It is not used by ``__call__``.
        """
        out: list[int] = []
        for is_loss, text in segments:
            seg_ids = self._tokenize_text(text)
            mask_value = 1 if is_loss else self.config.ignore_index
            out.extend([mask_value] * len(seg_ids))
        return out

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO inputs
    # ----------------------------------------------------------------------
    def _build_dpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Tokenize chosen/rejected pairs from teacher disagreement.

        DPO accounting requires:
          - chosen_input_ids   = prompt + chosen_response
          - rejected_input_ids = prompt + rejected_response
          - response masks: 1 over response (loss-bearing) tokens, 0 over prompt

        When prompt + response exceeds ``config.max_dpo_seq_len``, the shared
        prompt is truncated from the LEFT (oldest context first), using one
        budget for both chosen and rejected, so the responses survive. Plain
        right truncation would cut the responses first and silently leave an
        all-zero response mask, i.e. a pair that contributes nothing to the
        DPO loss. A response longer than the cap on its own is right-truncated.

        Returns None when no example in *batch* carries DPO pairs.
        """
        cap = self.config.max_dpo_seq_len
        chosen_rows: list[list[int]] = []
        rejected_rows: list[list[int]] = []
        chosen_masks: list[list[int]] = []
        rejected_masks: list[list[int]] = []
        for ex in batch:
            for pair in ex.get("dpo_pairs") or []:
                # NOTE(review): the prompt is rendered with add_generation_prompt=False,
                # so no assistant header precedes the response. Confirm the
                # chosen/rejected strings are meant to follow the prompt directly.
                prompt_ids = self._tokenize_messages(pair.get("state_messages", []))
                chosen_ids = self._tokenize_text(pair["chosen"])
                rejected_ids = self._tokenize_text(pair["rejected"])

                prompt_budget = max(0, cap - max(len(chosen_ids), len(rejected_ids)))
                if len(prompt_ids) > prompt_budget:
                    # Slice from an explicit start: prompt_ids[-0:] would keep everything.
                    prompt_ids = prompt_ids[len(prompt_ids) - prompt_budget :]

                chosen_rows.append(prompt_ids + chosen_ids)
                rejected_rows.append(prompt_ids + rejected_ids)
                chosen_masks.append([0] * len(prompt_ids) + [1] * len(chosen_ids))
                rejected_masks.append([0] * len(prompt_ids) + [1] * len(rejected_ids))

        if not chosen_rows:
            return None  # no DPO pairs in this batch

        max_len = self._batch_len([*chosen_rows, *rejected_rows], cap)
        pad_id = self.config.pad_token_id
        return {
            "dpo_chosen_input_ids": self._pad_rows(chosen_rows, max_len, pad_id),
            "dpo_chosen_response_mask": self._pad_rows(chosen_masks, max_len, 0),
            "dpo_rejected_input_ids": self._pad_rows(rejected_rows, max_len, pad_id),
            "dpo_rejected_response_mask": self._pad_rows(rejected_masks, max_len, 0),
        }

    # ----------------------------------------------------------------------
    # Tokenization helpers
    # ----------------------------------------------------------------------
    def _tokenize_trace(self, turns: Sequence[TraceTurn]) -> tuple[list[int], list[int]]:
        """Tokenize an entire trace; return (ids, response_mask).

        Turn contents are tokenized one by one and concatenated, with no chat
        template. response_mask = 1 over assistant turns (the loss-bearing
        tokens for GRPO) and 0 over user/tool turns (prompt context).
        """
        all_ids: list[int] = []
        resp_mask: list[int] = []
        for turn in turns:
            if not turn.get("content"):
                continue
            ids = self._tokenize_text(turn["content"])
            mask_value = 1 if turn.get("role") == "assistant" else 0
            all_ids.extend(ids)
            resp_mask.extend([mask_value] * len(ids))
        return all_ids, resp_mask

    def _tokenize_text(self, text: str) -> list[int]:
        """Tokenize plain text via the tokenizer's __call__ (no special tokens)."""
        result = self.tokenizer(text, add_special_tokens=False)
        ids = result["input_ids"]
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # HF tokenizers often return list[list[int]] when batch-shaped; flatten if so.
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)

    def _tokenize_messages(self, messages: Sequence[dict]) -> list[int]:
        """Tokenize a chat-formatted list of messages.

        Tries apply_chat_template first. Falls back to tokenizing the
        "\\n"-joined contents if the tokenizer has no usable chat template.

        NOTE: HF tokenizers' `apply_chat_template(tokenize=True)` is not
        consistently typed across families. Some return `list[int]`, others
        a `BatchEncoding` (a dict-like with `input_ids` key) — Qwen2.5
        returns the latter. Both shapes are handled here.
        """
        if not messages:
            return []
        try:
            raw = self.tokenizer.apply_chat_template(
                list(messages), tokenize=True, add_generation_prompt=False
            )
        except (AttributeError, NotImplementedError, TypeError):
            # Stub tokenizer or no chat template defined — fall back to concatenated content.
            text = "\n".join(m.get("content", "") for m in messages)
            return self._tokenize_text(text)
        # BatchEncoding (Qwen2.5 etc.): extract input_ids and unwrap if batched.
        if hasattr(raw, "keys") and "input_ids" in raw:
            ids = raw["input_ids"]
        else:
            ids = raw
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # If we got list[list[int]] (batch shape), unwrap the single example.
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)
# Public API of this module (what ``from ... import *`` exports); the
# underscore-prefixed helpers are module-private.
__all__ = [
    "ComposerDataCollator",
    "CollatorConfig",
    "TraceTurn",
    "TraceExample",
    "TokenizerLike",
]
|