# Codeseys — commit 03bf323: "Wave 19: production-grade SDPO via
# ComposerDataCollator + adapter + collator fixes"
"""Adapter: ClaudeCodeIngester output → ComposerDataCollator input.
The ingester (`composer_replication.ingestion.claude_code.ClaudeCodeIngester`)
emits `TraceState` dicts with a `messages` field — a list of OpenAI-style
chat dicts. The data collator (`composer_replication.trainer.data_collator
.ComposerDataCollator`) expects `TraceExample` dicts with a `turns` field —
a list of `TraceTurn` dicts where each turn carries its own role, content,
and (critically) `tool_error` field for SDPO error-site detection.
This module bridges the two. The adapter:
1. Consumes a `TraceState` from the ingester.
2. Converts its `messages` (chat dicts) → `turns` (TraceTurns).
3. Detects tool-error sites by looking for the `[TOOL_RESULT (ERROR)]`
tag the ingester writes (per Claude Code's `is_error: true` flag in
the source JSONL).
4. Marks the first assistant turn that follows an error tool-result
(further user-role turns, e.g. a later successful tool-result, may sit
in between) with `tool_error="<error_kind>"` so the data collator's
`_build_hint_injected_trace` recognizes it as an SDPO error site.
Usage:
    from composer_replication.ingestion import ClaudeCodeIngester
    from composer_replication.ingestion.trace_examples import (
        claude_states_to_trace_examples,
    )
    from composer_replication.trainer.data_collator import (
        ComposerDataCollator, CollatorConfig,
    )
    ingester = ClaudeCodeIngester()
    states = list(ingester.ingest(session_jsonl_path))
    examples = claude_states_to_trace_examples(states)
    config = CollatorConfig(
        hint_generator=lambda kind, meta: "Hint: try a different path.",
        enable_replay_dpo=False,
    )
    collator = ComposerDataCollator(tokenizer=tok, config=config)
    batch = collator(examples)
    # batch now has properly-aligned ctx_teacher_input_ids + sdpo_loss_mask
This is the production-grade alignment path. Wave 18's
`examples/sdpo_with_real_traces/` is a wiring smoke test that bypasses
this adapter; Wave 19's `examples/sdpo_with_real_traces_production/`
uses this adapter for the real alignment.
"""
from __future__ import annotations
import re
from typing import Any, Callable, Iterable, Mapping
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Marker the ingester writes in front of tool_result text whose source JSONL
# entry carried `is_error: true`. Error sites are located by a plain
# substring search for this marker in user-role message content, so it must
# stay identical to the `tag = "[TOOL_RESULT (ERROR)]"` literal in
# `composer_replication.ingestion.claude_code._serialize_user_content`.
TOOL_ERROR_TAG: str = "[TOOL_RESULT (ERROR)]"
# Keyword heuristics that map a tool-error message to a short error_kind
# label. The label is what the data collator's `hint_generator` receives as
# its first argument, so hints can be tailored per category. This is a
# minimal v0 set; callers who need different categories should supply their
# own classifier via the `error_kind_fn` parameter of
# `claude_states_to_trace_examples`.
#
# Entries are tried in list order and the first match wins, so when a
# message could satisfy several patterns the earliest entry decides.
_ERROR_KIND_PATTERNS = [
    ("command_not_found", re.compile(r"(?i)command not found")),
    ("file_not_found", re.compile(r"(?i)\b(file does not exist|no such file or directory|file not found)\b")),
    ("permission_denied", re.compile(r"(?i)permission denied")),
    ("syntax_error", re.compile(r"(?i)syntax\s*error")),
    ("connection_error", re.compile(r"(?i)\b(connection|network|timeout) (error|refused)\b")),
]


def default_classify_error(content: str) -> str:
    """Map a tool-error message to a short error_kind label.

    Runs each regex in `_ERROR_KIND_PATTERNS` against `content` in order
    and returns the label of the first one that matches anywhere in the
    text. Falls back to the generic label "tool_error" when nothing
    matches. Pass a different `error_kind_fn` to
    `claude_states_to_trace_examples` to replace this classifier.
    """
    matching_kinds = (
        label
        for label, regex in _ERROR_KIND_PATTERNS
        if regex.search(content)
    )
    return next(matching_kinds, "tool_error")
# ---------------------------------------------------------------------------
# Adapter
# ---------------------------------------------------------------------------
def _content_to_text(content: Any) -> str:
    """Coerce a chat message's ``content`` value to a plain string.

    * ``None`` (absent key, or a message that carries no text — OpenAI-style
      assistant messages with only tool calls may do this) becomes ``""``.
    * A list of content parts is joined with newlines; dict parts contribute
      their ``"text"`` value when present, otherwise their ``str()``.
    * Any other non-string value is passed through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        # Defensive: some tokenizers / chat formats hand back part lists.
        return "\n".join(
            str(part.get("text", part)) if isinstance(part, dict) else str(part)
            for part in content
        )
    return content if isinstance(content, str) else str(content)


def _nearest_error_text(
    roles: list[Any], texts: list[str], index: int
) -> str | None:
    """Return the text of the closest error-tagged user turn before ``index``.

    Only the contiguous run of ``"user"``-role turns directly preceding
    ``index`` is scanned, nearest first; the scan stops at the first turn
    with any other role. Returns ``None`` when no turn in that run contains
    ``TOOL_ERROR_TAG``.
    """
    for j in range(index - 1, -1, -1):
        if roles[j] != "user":
            return None
        if TOOL_ERROR_TAG in texts[j]:
            return texts[j]
    return None


def claude_states_to_trace_examples(
    states: Iterable[Mapping[str, Any]],
    *,
    error_kind_fn: Callable[[str], str] = default_classify_error,
    final_reward: float = 0.0,
) -> list[dict[str, Any]]:
    """Convert ClaudeCodeIngester TraceState dicts → TraceExample dicts.

    Each input state's `messages` list (OpenAI chat dicts) is rewritten
    as a `turns` list of TraceTurn dicts ``{"role", "content"}``. List-valued
    content is flattened to newline-joined text, and missing / ``None``
    content is normalized to ``""``.

    Tool-error sites are detected by matching `TOOL_ERROR_TAG`
    (``[TOOL_RESULT (ERROR)]``) in user-role messages (the ingester writes
    this tag whenever the source JSONL had `is_error: true`). An assistant
    turn is marked as an error site when the contiguous run of user-role
    turns directly before it contains the tag, so the error tool-result
    need not be the very last message before the assistant (e.g. a
    successful retry's tool-result may sit in between). The nearest tagged
    user turn is classified with `error_kind_fn`; if that returns a truthy
    value, the assistant turn gets:

    * ``tool_error``: the error kind (the field
      `ComposerDataCollator._build_hint_injected_trace` checks via
      `_is_error_turn`), and
    * ``error_meta``: ``{"source_role": "user",
      "source_content_excerpt": <first 200 chars of the tagged turn>}``.

    Args:
        states: iterable of TraceState dicts (from `ClaudeCodeIngester.ingest`).
            A missing or ``None`` `messages` field is treated as empty.
        error_kind_fn: callable(error_content) -> str used to classify
            errors. Defaults to the keyword-match classifier
            `default_classify_error`. A falsy return leaves the turn
            unmarked.
        final_reward: scalar reward stored as every emitted example's
            `final_reward` (the collator threads this into the GRPO
            channel; defaults to 0 since Claude Code traces don't carry
            RLVR rewards natively).

    Returns:
        list[TraceExample] (TypedDict — `{trace_id, turns, final_reward}`).
        `trace_id` is ``str`` of the state's `state_id`, or ``""`` when it
        is missing or ``None``. `dpo_pairs` is omitted (Claude Code traces
        don't carry chosen/rejected pairs; use
        `teacher_replay.extract_dpo_pairs` for that channel separately).
        States with no messages produce no example.

    Raises:
        TypeError, ValueError: if `final_reward` cannot be converted to
            ``float``. Exceptions raised by `error_kind_fn` propagate.
    """
    # Convert once, up front, so a bad reward fails fast.
    reward = float(final_reward)
    examples: list[dict[str, Any]] = []
    for state in states:
        # Materialize: `messages` may be any iterable (or None) and it is
        # read twice below.
        msgs = list(state.get("messages") or [])
        if not msgs:
            continue
        roles = [msg.get("role", "") for msg in msgs]
        # Flatten every content value once; the backward error scan
        # re-reads preceding user turns.
        texts = [_content_to_text(msg.get("content")) for msg in msgs]

        turns: list[dict[str, Any]] = []
        for i, (role, text) in enumerate(zip(roles, texts)):
            turn: dict[str, Any] = {"role": role, "content": text}
            if role == "assistant":
                error_text = _nearest_error_text(roles, texts, i)
                if error_text is not None:
                    error_kind = error_kind_fn(error_text)
                    if error_kind:
                        turn["tool_error"] = error_kind
                        turn["error_meta"] = {
                            "source_role": "user",
                            "source_content_excerpt": error_text[:200],
                        }
            turns.append(turn)

        state_id = state.get("state_id")
        examples.append({
            # Avoid the literal string "None" when state_id is explicitly None.
            "trace_id": "" if state_id is None else str(state_id),
            "turns": turns,
            "final_reward": reward,
        })
    return examples
# Public API of this module: the adapter, its default classifier, and the tag.
__all__ = ["claude_states_to_trace_examples", "default_classify_error", "TOOL_ERROR_TAG"]