Commit 6806cf7 by Codeseys · 1 parent: a384097

Wave 21: close both Wave 20 debt items — chat-template alignment + structural is_error

Two architectural-debt items flagged "Wave 20 candidate" in the framework doc,
both CPU-only, both now fixed and regression-guarded.

## 1. SDPO mask chat-template drift (was ~67% aligned, now ~100%)

`ComposerDataCollator._build_segment_mask` built the sdpo_loss_mask (and the
aligned-student response_mask) by tokenizing each content segment in isolation
and concatenating. That ignored the scaffolding tokens apply_chat_template
inserts around every message (<|im_start|>{role}\n ... <|im_end|>\n, BOS), so
mask bits drifted left of the real content tokens — the residual ~33%
contamination documented in the Wave 19 production audit.
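
The drift can be reproduced with a toy setup. This is only a sketch: `toy_tokenize`, `toy_chat_template`, and `naive_segment_mask` are hypothetical stand-ins (a whitespace tokenizer and a ChatML-like wrapper), not the collator's real helpers.

```python
IGNORE_INDEX = -100


def toy_tokenize(text):
    """Whitespace tokenizer standing in for a real subword tokenizer."""
    return text.split()


def toy_chat_template(messages):
    """Wrap every message in ChatML-like scaffolding tokens (assumed shape)."""
    tokens = []
    for msg in messages:
        tokens += ["<|im_start|>", msg["role"]]
        tokens += toy_tokenize(msg["content"])
        tokens += ["<|im_end|>"]
    return tokens


def naive_segment_mask(segments, total_len):
    """Old-style mask: tokenize each segment alone, concatenate, then pad."""
    mask = []
    for is_loss, text in segments:
        mask += [1 if is_loss else IGNORE_INDEX] * len(toy_tokenize(text))
    mask = mask[:total_len]
    return mask + [IGNORE_INDEX] * (total_len - len(mask))


messages = [
    {"role": "user", "content": "run the build"},
    {"role": "assistant", "content": "fixing the target"},
]
segments = [(False, "run the build"), (True, "fixing the target")]

full_tokens = toy_chat_template(messages)
mask = naive_segment_mask(segments, len(full_tokens))
# Tokens that the naive mask would train on.
marked = [tok for tok, bit in zip(full_tokens, mask) if bit == 1]
print(marked)
```

In this toy case the loss bits land on the tail of the user message and its end marker instead of on the assistant content, because the scaffolding tokens before each message are never counted.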

Fix: new `_build_chat_aligned_mask` derives the mask from per-message
apply_chat_template prefix deltas. For message k it computes the token span
len(template(msgs[:k+1])) - len(template(msgs[:k])), then locates the content
run inside that span by subsequence match and marks ONLY those positions as
loss. Falls back to whole-span marking if the content run can't be located
(tokenizer merge across the boundary) so SDPO signal is never silently dropped.
Degenerates exactly to the old concat behavior on stub tokenizers (no template),
so the 15 stub collator tests stay green.
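
A minimal sketch of the prefix-delta idea, under the same toy assumptions (whitespace tokenizer, ChatML-like wrapper). The function names here are illustrative; the actual implementation is `_build_chat_aligned_mask` in the diff.

```python
IGNORE_INDEX = -100


def toy_tokenize(text):
    """Whitespace tokenizer standing in for a real subword tokenizer."""
    return text.split()


def toy_chat_template(messages):
    """Wrap every message in ChatML-like scaffolding tokens (assumed shape)."""
    tokens = []
    for msg in messages:
        tokens += ["<|im_start|>", msg["role"]]
        tokens += toy_tokenize(msg["content"])
        tokens += ["<|im_end|>"]
    return tokens


def find_subseq(haystack, needle, start=0):
    """Index of the first occurrence of needle at or after start, else -1."""
    if not needle:
        return start
    for i in range(start, len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1


def chat_aligned_mask(messages, segments):
    """Mark only the content tokens of loss segments, located per message."""
    full = toy_chat_template(messages)
    mask = [IGNORE_INDEX] * len(full)
    prev_len = 0
    for k in range(len(messages)):
        # Message k occupies full[prev_len:cur_len], scaffolding included.
        cur_len = len(toy_chat_template(messages[: k + 1]))
        span_start, span_end = prev_len, cur_len
        prev_len = cur_len
        is_loss, content = segments[k]
        if not is_loss:
            continue
        content_ids = toy_tokenize(content)
        idx = find_subseq(full[:span_end], content_ids, span_start)
        if idx == -1:
            # Fallback: mark the whole span rather than drop the signal.
            positions = range(span_start, span_end)
        else:
            positions = range(idx, idx + len(content_ids))
        for p in positions:
            mask[p] = 1
    return full, mask


messages = [
    {"role": "user", "content": "run the build"},
    {"role": "assistant", "content": "fixing the target"},
]
segments = [(False, "run the build"), (True, "fixing the target")]
full_tokens, mask = chat_aligned_mask(messages, segments)
marked = [tok for tok, bit in zip(full_tokens, mask) if bit == 1]
print(marked)
```

Re-templating every prefix costs one template call per message, which is acceptable for short conversations and keeps the mask independent of the template's exact marker layout.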

Validated against the real Qwen2.5-0.5B-Instruct chat template: alignment ratio
67% -> 100%, and the in-loss tokens decode to exactly the recovery turn's
content with zero <|im_start|>/<|im_end|> leakage.
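
The alignment ratio quoted above is the fraction of in-loss positions where the student and teacher token ids agree. A plain-list sketch of that audit follows; the real test operates on batch tensors, and `alignment_ratio` is a hypothetical helper name.

```python
def alignment_ratio(student_ids, teacher_ids, loss_mask):
    """Fraction of in-loss positions (mask == 1) where student == teacher."""
    pairs = [
        (s, t)
        for s, t, bit in zip(student_ids, teacher_ids, loss_mask)
        if bit == 1
    ]
    if not pairs:
        raise ValueError("no in-loss positions: the mask is empty")
    return sum(s == t for s, t in pairs) / len(pairs)


# Three in-loss positions, one of which disagrees.
ratio = alignment_ratio(
    student_ids=[1, 2, 3, 4, 5],
    teacher_ids=[9, 2, 3, 7, 5],
    loss_mask=[-100, 1, 1, 1, -100],
)
print(f"{100 * ratio:.1f}% aligned")
```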

## 2. TOOL_ERROR_TAG string-coupling

ClaudeCodeIngester wrote the literal "[TOOL_RESULT (ERROR)]" string for
is_error:true tool_results, then the trace_examples adapter grepped that same
string back out to detect SDPO error sites. Brittle: any serialization drift
silently darkened the SDPO channel.

Fix: ingester now surfaces a structural `tool_error: True` boolean on user
messages (the is_error bool was already known, just discarded into the tag).
The adapter's new `_user_turn_has_error` reads the boolean first and falls back
to the string tag only when the structural flag is absent (backward-compat for
old traces / third-party producers). Structural flag wins both ways — a
producer can set tool_error:False to suppress a tag that appears in quoted text.
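
The precedence rule can be sketched standalone. This is a simplified version of `_user_turn_has_error` that reads the content from the message itself; the tag value is assumed to match the literal the ingester writes.

```python
# Assumed to equal the ingester's literal tag string.
TOOL_ERROR_TAG = "[TOOL_RESULT (ERROR)]"


def user_turn_has_error(msg):
    """Structural flag first; string tag only when the flag is absent."""
    structural = msg.get("tool_error")
    if structural is not None:
        return bool(structural)
    return TOOL_ERROR_TAG in msg.get("content", "")


# Flag present, tag absent: still detected (drift resilience).
drifted = {"role": "user", "content": "build failed", "tool_error": True}
# Flag False, tag quoted in text: suppressed.
quoted = {
    "role": "user",
    "content": f"docs mention {TOOL_ERROR_TAG}",
    "tool_error": False,
}
# No flag at all: legacy trace, fall back to the tag.
legacy = {"role": "user", "content": f"{TOOL_ERROR_TAG} (id=x)\nno such file"}

print(user_turn_has_error(drifted), user_turn_has_error(quoted), user_turn_has_error(legacy))
```

Checking `is not None` rather than truthiness is what lets an explicit `False` override the string match.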

## Tests
- composer_replication/trainer/tests/ (new): 3 real-tokenizer alignment guards
  (skip cleanly when transformers/model cache absent)
- test_trace_examples_adapter.py: +5 structural-flag tests incl. drift-resilience
  (tag absent but flag present -> still detected) and inverse (flag False wins)
- Full package suite: 146 passed, 16 skipped, 0 failed

## Docs
- production SDPO example audit note updated: drift is fixed, sub-95% now means
  a NEW regression rather than a known-residual bug

Also adds uv.lock (uv-pinned deps; pyproject already present, lock was untracked).

composer_replication/ingestion/claude_code.py CHANGED
@@ -135,9 +135,16 @@ class ClaudeCodeIngester:
                 # Either text blocks (a real human prompt) or tool_result
                 # blocks (an observation). Both go into history as user
                 # messages, but we serialize them differently.
-                flat = self._flatten_user_content(content)
+                flat, had_tool_error = self._flatten_user_content(content)
                 if flat:
-                    history.append({"role": "user", "content": flat})
+                    user_msg: dict[str, Any] = {"role": "user", "content": flat}
+                    if had_tool_error:
+                        # Structural error flag — the SDPO source of truth.
+                        # The [TOOL_RESULT (ERROR)] string tag still lives in
+                        # `content` for readability, but downstream detection
+                        # reads THIS boolean (see trace_examples adapter).
+                        user_msg["tool_error"] = True
+                    history.append(user_msg)
 
             elif rec_type == "assistant":
                 msg = rec.get("message", {})
@@ -215,9 +222,20 @@ class ClaudeCodeIngester:
                     logger.debug("Truncated/malformed line in %s: %s", path, e)
                     continue
 
-    def _flatten_user_content(self, content: list[Any]) -> str:
-        """Convert a user record's content list to a single string."""
+    def _flatten_user_content(self, content: list[Any]) -> tuple[str, bool]:
+        """Convert a user record's content list to a single string.
+
+        Returns ``(flattened_text, had_tool_error)`` where ``had_tool_error``
+        is True iff any ``tool_result`` block in this user content carried
+        ``is_error: true``. The boolean is the STRUCTURAL source of truth for
+        SDPO error-site detection; the ``[TOOL_RESULT (ERROR)]`` string tag in
+        the text is kept only for human-readability and ``apply_chat_template``
+        rendering. Downstream consumers (the trace_examples adapter) should
+        read the structural flag, never grep the tag — see Wave 20 design note
+        on TOOL_ERROR_TAG string-coupling debt.
+        """
         parts: list[str] = []
+        had_tool_error = False
         for block in content:
             if not isinstance(block, dict):
                 continue
@@ -237,11 +255,13 @@ class ClaudeCodeIngester:
                 tc = "\n".join(sub)
                 tu_id = block.get("tool_use_id", "<unknown>")
                 is_err = block.get("is_error", False)
+                if is_err:
+                    had_tool_error = True
                 tag = "[TOOL_RESULT (ERROR)]" if is_err else "[TOOL_RESULT]"
                 parts.append(f"{tag} (id={tu_id})\n{tc}")
             elif bt == "image":
                 parts.append("[IMAGE OMITTED]")
-        return "\n\n".join(parts)
+        return "\n\n".join(parts), had_tool_error
 
     def _serialize_assistant_content(
         self, content: list[Any], *, strip_thinking: bool,
composer_replication/ingestion/tests/test_trace_examples_adapter.py CHANGED
@@ -174,6 +174,126 @@ def test_tool_error_tag_matches_ingester_output():
     )
 
 
+# ----------------------------------------------------------------------
+# Structural error flag (Wave 20 — eliminate TOOL_ERROR_TAG coupling)
+# ----------------------------------------------------------------------
+
+
+def test_ingester_sets_structural_tool_error_flag():
+    """The ingester must set a STRUCTURAL `tool_error: True` boolean on
+    user messages whose source JSONL had `is_error: true`, independent of
+    the rendered string tag."""
+    ingester = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=True)
+    states = list(ingester.ingest(ERROR_FIXTURE))
+    flagged = [
+        m for s in states for m in s["messages"]
+        if m.get("role") == "user" and m.get("tool_error") is True
+    ]
+    assert flagged, (
+        "Expected ≥1 user message with structural tool_error=True flag; "
+        "the ingester is not surfacing is_error structurally."
+    )
+    # And every structurally-flagged message must also render the tag
+    # (the tag is kept for readability — both should co-occur on the fixture).
+    for m in flagged:
+        assert TOOL_ERROR_TAG in m["content"], (
+            "Structural flag set but string tag missing — the two views "
+            "of the same error have diverged within the ingester."
+        )
+
+
+def test_clean_fixture_has_no_structural_flag():
+    """No user message on the clean fixture should carry tool_error=True."""
+    ingester = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=True)
+    states = list(ingester.ingest(OK_FIXTURE))
+    flagged = [
+        m for s in states for m in s["messages"]
+        if m.get("role") == "user" and m.get("tool_error")
+    ]
+    assert not flagged, f"Clean fixture should have 0 structural flags; got {len(flagged)}"
+
+
+def test_structural_flag_survives_tag_drift():
+    """THE drift-resilience guarantee: if the rendered string tag drifts
+    (e.g. a future serialization change strips or renames it) but the
+    structural `tool_error: True` flag is present, the adapter MUST still
+    detect the error site. This is the entire point of the Wave 20 fix —
+    detection no longer depends on grepping a human-readable string."""
+    # Hand-build a state where the tag is ABSENT from content but the
+    # structural flag is set — simulating ingester serialization drift.
+    states = [{
+        "state_id": "drift-0",
+        "messages": [
+            {"role": "system", "content": "sys"},
+            {"role": "user", "content": "run the build"},
+            {"role": "assistant", "content": "[TOOL_USE] name=Bash input={}"},
+            # Tag DELIBERATELY absent from content; only the structural flag.
+            {"role": "user", "content": "build failed: missing target",
+             "tool_error": True},
+            {"role": "assistant", "content": "Let me fix the target."},
+        ],
+    }]
+    examples = claude_states_to_trace_examples(states)
+    assert len(examples) == 1
+    err_turns = [t for t in examples[0]["turns"] if t.get("tool_error")]
+    assert len(err_turns) == 1, (
+        "Structural flag present but adapter failed to detect the error "
+        "site without the string tag — the coupling fix is broken."
+    )
+    # The recovery turn is the assistant immediately after the flagged user turn.
+    assert err_turns[0]["content"] == "Let me fix the target."
+
+
+def test_structural_false_suppresses_tag_match():
+    """Inverse drift case: a producer sets `tool_error: False` to assert
+    'this is NOT an error' even though the rendered content happens to
+    contain the tag string. The structural flag must WIN over the string."""
+    states = [{
+        "state_id": "false-0",
+        "messages": [
+            {"role": "system", "content": "sys"},
+            {"role": "user", "content": "look at this log"},
+            {"role": "assistant", "content": "[TOOL_USE] name=Read input={}"},
+            # Content contains the tag verbatim (e.g. quoting a prior log)
+            # but the producer asserts it's not a live error site.
+            {"role": "user",
+             "content": f"the docs mention {TOOL_ERROR_TAG} as an example",
+             "tool_error": False},
+            {"role": "assistant", "content": "I see, that's just documentation."},
+        ],
+    }]
+    examples = claude_states_to_trace_examples(states)
+    err_turns = [t for t in examples[0]["turns"] if t.get("tool_error")]
+    assert not err_turns, (
+        "tool_error=False should suppress detection even when the string "
+        "tag is present in content; structural flag must take precedence."
+    )
+
+
+def test_string_tag_fallback_when_no_structural_flag():
+    """Backward-compat: an OLD trace (no structural flag anywhere) with the
+    tag in content must STILL be detected via the string fallback path."""
+    states = [{
+        "state_id": "legacy-0",
+        "messages": [
+            {"role": "system", "content": "sys"},
+            {"role": "user", "content": "run it"},
+            {"role": "assistant", "content": "[TOOL_USE] name=Bash input={}"},
+            # No tool_error key at all — pure legacy serialization.
+            {"role": "user",
+             "content": f"{TOOL_ERROR_TAG} (id=x)\nno such file or directory"},
+            {"role": "assistant", "content": "Creating the file first."},
+        ],
+    }]
+    examples = claude_states_to_trace_examples(states)
+    err_turns = [t for t in examples[0]["turns"] if t.get("tool_error")]
+    assert len(err_turns) == 1, (
+        "Legacy trace without structural flag must fall back to the string "
+        "tag match; backward compatibility broken."
+    )
+    assert err_turns[0]["tool_error"] == "file_not_found"
+
+
 # ----------------------------------------------------------------------
 # Empty input
 # ----------------------------------------------------------------------
composer_replication/ingestion/trace_examples.py CHANGED
@@ -88,6 +88,30 @@ def default_classify_error(content: str) -> str:
     return "tool_error"
 
 
+def _user_turn_has_error(msg: Mapping[str, Any], flat_content: str) -> bool:
+    """Decide whether a user-role turn is a tool-error site.
+
+    Precedence (Wave 20 — eliminate TOOL_ERROR_TAG string-coupling):
+
+    1. **Structural flag** — if the message dict carries an explicit
+       ``tool_error`` key, trust it as the source of truth. The ingester sets
+       ``tool_error: True`` whenever the source JSONL had ``is_error: true``.
+       A third-party producer can set ``tool_error: False`` to assert "no
+       error here" even if the rendered text happens to contain the tag.
+    2. **String-tag fallback** — only when no structural flag is present
+       (older serialized traces, or producers that never learned the boolean
+       contract) do we fall back to matching ``TOOL_ERROR_TAG`` in the
+       rendered content. This keeps backward compatibility without making the
+       brittle string match the primary path.
+
+    Returns True iff the turn should trigger SDPO error-site handling.
+    """
+    structural = msg.get("tool_error")
+    if structural is not None:
+        return bool(structural)
+    return TOOL_ERROR_TAG in flat_content
+
+
 # ---------------------------------------------------------------------------
 # Adapter
 # ---------------------------------------------------------------------------
@@ -163,7 +187,14 @@ def claude_states_to_trace_examples(
                     str(c.get("text", c)) if isinstance(c, dict) else str(c)
                     for c in prev_content
                 )
-                if TOOL_ERROR_TAG in prev_content:
+                # STRUCTURAL detection (Wave 20): the ingester sets a
+                # `tool_error: True` boolean on user messages whose source
+                # JSONL had `is_error: true`. This is the source of truth.
+                # We fall back to string-matching the TOOL_ERROR_TAG only
+                # for messages that lack the structural flag (older traces
+                # or third-party producers that didn't set it) — see
+                # `_user_turn_has_error`.
+                if _user_turn_has_error(prev, prev_content):
                     error_kind_found = error_kind_fn(prev_content)
                     error_content_found = prev_content
                     break
composer_replication/trainer/data_collator.py CHANGED
@@ -318,8 +318,18 @@ class ComposerDataCollator:
 
         # Tokenize the full teacher conversation
         teacher_ids = self._tokenize_messages(teacher_messages)
-        # Build the per-token loss mask by tokenizing each segment and concatenating
-        sdpo_mask = self._build_segment_mask(teacher_loss_segments)
+        # Build the per-token loss mask ALIGNED to the chat-template tokenization
+        # (Wave 20 fix). The old path tokenized each segment's raw text in
+        # isolation and concatenated; that ignored the scaffolding tokens
+        # (<|im_start|>{role}\n ... <|im_end|>\n, BOS, etc.) that
+        # apply_chat_template inserts, so mask positions drifted left of the
+        # real content tokens — the residual ~33% misalignment documented in
+        # the Wave 19 production audit. `_build_chat_aligned_mask` derives the
+        # mask from per-message apply_chat_template deltas instead, so loss
+        # bits land exactly on content tokens regardless of template markers.
+        sdpo_mask = self._build_chat_aligned_mask(
+            teacher_messages, teacher_loss_segments, teacher_ids
+        )
         # Truncate mask to teacher_ids length if tokenization round-tripped slightly differently
         sdpo_mask = sdpo_mask[: len(teacher_ids)]
         if len(sdpo_mask) < len(teacher_ids):
@@ -465,14 +475,17 @@ class ComposerDataCollator:
         # Tokenize the full student conversation via apply_chat_template
         # (mirrors teacher's path so chat-template markers are identical).
         student_ids = self._tokenize_messages(student_messages)
-        # Build response mask via the same segment-tokenization helper used
-        # for sdpo_mask, then reinterpret 1=in-response, 0=not-in-response.
-        # We can't reuse _build_segment_mask (which uses ignore_index for
-        # non-loss); inline a 0/1 variant.
-        resp_mask: list[int] = []
-        for is_resp, text in student_loss_segments:
-            seg_ids = self._tokenize_text(text)
-            resp_mask.extend([1 if is_resp else 0] * len(seg_ids))
+        # Build response mask ALIGNED to the chat-template tokenization (Wave 20
+        # fix same drift bug as the teacher sdpo_mask path). We derive the
+        # mask from per-message apply_chat_template deltas so 1-bits land on
+        # the assistant content tokens exactly, not shifted by the template
+        # scaffolding. `_build_chat_aligned_mask` emits 1 for loss segments and
+        # ignore_index for the rest; we remap ignore_index -> 0 because the
+        # response_mask convention here is 1=in-response, 0=not.
+        raw_mask = self._build_chat_aligned_mask(
+            student_messages, student_loss_segments, student_ids
+        )
+        resp_mask = [1 if v == 1 else 0 for v in raw_mask]
         # Pad/truncate response_mask to student_ids length (same as teacher path).
         resp_mask = resp_mask[: len(student_ids)]
         if len(resp_mask) < len(student_ids):
@@ -486,6 +499,15 @@ class ComposerDataCollator:
         """For each (is_loss, text) segment, tokenize and emit per-token mask values.
 
         Loss-active tokens get 1; non-loss tokens get -100 (ignore_index).
+
+        NOTE (Wave 20): this naive per-segment concatenation IGNORES the
+        chat-template scaffolding that `apply_chat_template` inserts around
+        each message, so the resulting mask drifts out of alignment with a
+        sequence produced via `_tokenize_messages`. It is retained only for
+        the degenerate fallback inside `_build_chat_aligned_mask` and for
+        callers that build sequences via raw segment concatenation (no chat
+        template). The SDPO/response-mask paths now use
+        `_build_chat_aligned_mask` instead.
         """
         out: list[int] = []
         for is_loss, text in segments:
@@ -494,6 +516,94 @@ class ComposerDataCollator:
             out.extend([mask_value] * len(seg_ids))
         return out
 
+    @staticmethod
+    def _find_subseq(haystack: list[int], needle: list[int], start: int = 0) -> int:
+        """Return the index where ``needle`` first occurs in ``haystack`` at or
+        after ``start``, or -1 if absent. Linear scan (spans are short)."""
+        if not needle:
+            return start
+        n, m = len(haystack), len(needle)
+        for i in range(start, n - m + 1):
+            if haystack[i:i + m] == needle:
+                return i
+        return -1
+
+    def _build_chat_aligned_mask(
+        self,
+        messages: Sequence[dict],
+        segments: Sequence[tuple[bool, str]],
+        full_ids: list[int],
+    ) -> list[int]:
+        """Build a per-token loss mask aligned to a chat-template tokenization.
+
+        The caller builds ``messages`` and ``segments`` in lockstep — element
+        ``k`` of each describes the same logical chunk, where ``segments[k] =
+        (is_loss, content_text)`` and ``messages[k] = {role, content}``.
+
+        We need a mask over ``full_ids = apply_chat_template(messages)`` whose
+        1-bits sit exactly on the content tokens of loss segments. The hard
+        part is that ``apply_chat_template`` inserts role/BOS/EOS scaffolding
+        between and around messages, so the naive ``_build_segment_mask``
+        (which tokenizes each content string in isolation and concatenates)
+        drifts: its k-th block of mask bits lands at the wrong offset because
+        all the preceding scaffolding tokens are unaccounted for.
+
+        Algorithm — per-message prefix deltas:
+
+            prev_len = len(apply_chat_template(messages[:k]))
+            cur_len = len(apply_chat_template(messages[:k+1]))
+            # message k occupies full_ids[prev_len : cur_len] (content + its
+            # own scaffolding). Locate the content token run inside that span
+            # by subsequence match against the content's standalone
+            # tokenization, mark THOSE positions with the segment value and
+            # leave the scaffolding as ignore_index.
+
+        Falls back gracefully:
+          * If the tokenizer has no usable chat template (stub / no template),
+            ``_tokenize_messages`` returns a plain concatenation and the prefix
+            deltas equal the raw content token counts — so the content
+            subsequence match is trivially the whole span and the result
+            matches ``_build_segment_mask`` exactly (stub tests stay green).
+          * If a content run can't be located inside its span (rare tokenizer
+            merges across the content/scaffolding boundary), we mark the whole
+            message span with the segment value when it is a loss segment, so
+            we never silently drop SDPO signal — we over-include by at most a
+            couple scaffolding tokens rather than misalign.
+        """
+        mask = [self.config.ignore_index] * len(full_ids)
+        prev_len = 0
+        search_from = 0
+        for k, msg in enumerate(messages):
+            prefix_ids = self._tokenize_messages(list(messages[: k + 1]))
+            cur_len = len(prefix_ids)
+            span_start, span_end = prev_len, cur_len
+            prev_len = cur_len
+            if span_end <= span_start:
+                continue
+            is_loss = segments[k][0] if k < len(segments) else False
+            content = segments[k][1] if k < len(segments) else msg.get("content", "")
+            if not is_loss:
+                search_from = span_end
+                continue
+            # Loss segment: mark only the content tokens within the span.
+            content_ids = self._tokenize_text(content)
+            # Search for the content run inside this message's span. Anchor the
+            # search at span_start so we don't match content from a later msg.
+            idx = self._find_subseq(full_ids[:span_end], content_ids, start=max(span_start, search_from))
+            if idx != -1 and idx >= span_start:
+                for p in range(idx, min(idx + len(content_ids), span_end)):
+                    mask[p] = 1
+                search_from = idx + len(content_ids)
+            else:
+                # Fallback: couldn't locate the content run (tokenizer merged
+                # the content/scaffolding boundary). Mark the whole span as
+                # loss rather than drop the SDPO signal entirely. Over-includes
+                # at most the message's own scaffolding tokens.
+                for p in range(span_start, span_end):
+                    mask[p] = 1
+                search_from = span_end
+        return mask
+
     # ----------------------------------------------------------------------
     # Channel 3: trace-replay DPO inputs
     # ----------------------------------------------------------------------
composer_replication/trainer/tests/__init__.py ADDED
File without changes
composer_replication/trainer/tests/test_chat_template_alignment.py ADDED
@@ -0,0 +1,146 @@
+"""Wave 20 — chat-template alignment regression guard for the PACKAGE collator.
+
+`composer_replication.trainer.data_collator.ComposerDataCollator` builds the
+SDPO `sdpo_loss_mask` (and the aligned-student `response_mask`) so that in-loss
+positions sit exactly on content tokens. The hard part is that
+`apply_chat_template` inserts role/BOS/EOS scaffolding around each message; the
+old `_build_segment_mask` tokenized each content string in isolation and
+concatenated, so the mask drifted left of the real content tokens. The Wave 19
+production audit measured this drift at ~67% aligned. Wave 20's
+`_build_chat_aligned_mask` derives the mask from per-message
+`apply_chat_template` prefix deltas instead, restoring ~100% alignment.
+
+These tests use a REAL chat-template tokenizer (the stub used by
+spikes/005 cannot expose the drift — its `apply_chat_template` adds no
+scaffolding). They skip cleanly when transformers / the model cache is absent.
+"""
+from __future__ import annotations
+
+import pytest
+
+from composer_replication.trainer.data_collator import (
+    CollatorConfig,
+    ComposerDataCollator,
+)
+
+
+def _load_real_chat_tokenizer():
+    """Return a real tokenizer with a chat template, or None to skip."""
+    try:
+        import os
+
+        os.environ.setdefault("HF_HUB_OFFLINE", "1")
+        os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
+        from transformers import AutoTokenizer
+    except Exception:
+        return None
+    for model in ("Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct"):
+        try:
+            t = AutoTokenizer.from_pretrained(model)
+            if getattr(t, "chat_template", None):
+                return t
+        except Exception:
+            continue
+    return None
+
+
+_REAL_TOK = _load_real_chat_tokenizer()
+_SKIP_REASON = "real chat-template tokenizer not available (offline / not cached)"
+
+
+@pytest.fixture
+def real_chat_tok():
+    if _REAL_TOK is None:
+        pytest.skip(_SKIP_REASON)
+    return _REAL_TOK
+
+
+@pytest.fixture
+def multiturn_error_trace():
+    """Multi-turn trace with an error site after several turns, so the
+    chat-template scaffolding drift compounds (what exposed the old 33%)."""
+    return {
+        "trace_id": "real-align-1",
+        "turns": [
+            {"role": "user", "content": "Read /etc/app/config.yaml and summarize it."},
+            {"role": "assistant", "content": '[TOOL_USE] name=Read input={"path":"/etc/app/config.yaml"}'},
+            {"role": "user", "content": "[TOOL_RESULT (ERROR)] (id=t1)\nError: no such file or directory"},
+            {
+                "role": "assistant",
+                "content": "The file does not exist there. Let me search for it instead.",
+                "tool_error": "file_not_found",
+                "error_meta": {"source_role": "user"},
+            },
+            {"role": "user", "content": "[TOOL_RESULT] (id=t2)\nFound /opt/app/config.yaml"},
+            {"role": "assistant", "content": "Found it at /opt/app/config.yaml. Reading now."},
+        ],
+        "final_reward": 0.0,
+    }
+
+
+def _hint_gen(kind, _meta):
+    return f"The path was wrong (kind: {kind}). Search with Glob before reading."
+
+
+def test_real_chat_template_sdpo_mask_fully_aligned(real_chat_tok, multiturn_error_trace):
+    """THE Wave 20 guarantee: with a REAL chat template, every in-loss
+    sdpo_loss_mask position must have student==teacher token id. Before the
+    fix this drifted to ~67% because the mask was built from per-segment
+    tokenization that ignored apply_chat_template scaffolding."""
+    cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
+    collator = ComposerDataCollator(tokenizer=real_chat_tok, config=cfg)
+    batch = collator([multiturn_error_trace])
+
+    assert "sdpo_loss_mask" in batch, "SDPO channel did not fire on the error trace"
+    s_in = batch["input_ids"]
+    t_in = batch["ctx_teacher_input_ids"]
+    m_in = batch["sdpo_loss_mask"]
+    assert s_in.shape == t_in.shape == m_in.shape
+
+    n_aligned = n_total = 0
+    for row in range(s_in.shape[0]):
+        in_loss = m_in[row] == 1
+        if int(in_loss.sum()) == 0:
+            continue
+        s_at = s_in[row][in_loss]
+        t_at = t_in[row][in_loss]
+        n_aligned += int((s_at == t_at).sum().item())
+        n_total += int(in_loss.sum().item())
+
+    assert n_total > 0, "No in-loss positions — SDPO mask is empty"
+    ratio = n_aligned / n_total
+    assert ratio >= 0.95, (
+        f"SDPO mask alignment is only {100 * ratio:.1f}% ({n_aligned}/{n_total}); "
+        f"the chat-template drift fix has regressed. Expected ~100%."
+    )
+
+
+def test_real_chat_template_in_loss_tokens_are_content_not_scaffolding(
+    real_chat_tok, multiturn_error_trace
+):
+    """The in-loss teacher tokens must decode to the recovery turn's CONTENT,
+    not chat-template markers (<|im_start|>, role strings, etc.)."""
+    cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
+    collator = ComposerDataCollator(tokenizer=real_chat_tok, config=cfg)
+    batch = collator([multiturn_error_trace])
+
+    t_in = batch["ctx_teacher_input_ids"][0]
+    m_in = batch["sdpo_loss_mask"][0]
+    in_loss = m_in == 1
+    decoded = real_chat_tok.decode(t_in[in_loss].tolist())
+    assert "does not exist" in decoded, (
+        f"In-loss tokens don't contain the recovery content; got: {decoded!r}"
+    )
+    for marker in ("<|im_start|>", "<|im_end|>", "<|endoftext|>"):
+        assert marker not in decoded, (
+            f"Chat-template marker {marker!r} leaked into the in-loss span: {decoded!r}"
+        )
+
+
+def test_real_chat_template_student_teacher_shapes_match(real_chat_tok, multiturn_error_trace):
+    """The SDPO gate requires student_logits.shape == teacher_logits.shape;
+    verify the aligned-student path produces matching sequence lengths."""
+    cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
+    collator = ComposerDataCollator(tokenizer=real_chat_tok, config=cfg)
+    batch = collator([multiturn_error_trace])
+    assert batch["input_ids"].shape == batch["ctx_teacher_input_ids"].shape
examples/sdpo_with_real_traces_production/run.py CHANGED
@@ -308,16 +308,22 @@ def main() -> int:
         ratio = n_aligned / n_total_in_loss
         log.info(" alignment audit: %d / %d in-loss positions match student==teacher (%.1f%%)",
                  n_aligned, n_total_in_loss, 100 * ratio)
-        if ratio < 1.0:
+        if ratio < 0.95:
             log.warning(
                 " NOTE: %d positions (%.1f%%) of the SDPO mask cover non-aligned "
-                "tokens. This is a residual segment-vs-chat-template drift bug "
-                "in the existing _build_segment_mask: the segment-tokenizer "
-                "doesn't account for chat-template markers added by "
-                "apply_chat_template. Tracked for Wave 20.",
+                "tokens. As of Wave 20 the chat-template drift was fixed via "
+                "ComposerDataCollator._build_chat_aligned_mask (per-message "
+                "apply_chat_template prefix deltas). A ratio below ~100%% now "
+                "indicates a NEW regression — investigate the collator, not a "
+                "known-residual bug.",
                 n_total_in_loss - n_aligned,
                 100 * (1 - ratio),
             )
+        else:
+            log.info(
+                " ✓ Wave 20 chat-template alignment holding (%.1f%% — was ~67%% "
+                "before the _build_chat_aligned_mask fix).", 100 * ratio,
+            )
 
     log.info("=" * 64)
     log.info("Summary")
uv.lock ADDED
The diff for this file is too large to render. See raw diff