File size: 8,399 Bytes
03bf323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""Adapter: ClaudeCodeIngester output → ComposerDataCollator input.

The ingester (`composer_replication.ingestion.claude_code.ClaudeCodeIngester`)
emits `TraceState` dicts with a `messages` field — a list of OpenAI-style
chat dicts. The data collator
(`composer_replication.trainer.data_collator.ComposerDataCollator`) expects
`TraceExample` dicts with a `turns` field —
a list of `TraceTurn` dicts where each turn carries its own role, content,
and (critically) `tool_error` field for SDPO error-site detection.

This module bridges the two. The adapter:

  1. Consumes a `TraceState` from the ingester.
  2. Converts its `messages` (chat dicts) → `turns` (TraceTurns).
  3. Detects tool-error sites by looking for the `[TOOL_RESULT (ERROR)]`
     tag the ingester writes (per Claude Code's `is_error: true` flag in
     the source JSONL).
  4. Marks the first assistant turn following an error tool-result (any
     turns in between must also be user-role) with
     `tool_error="<error_kind>"` so the data collator's
     `_build_hint_injected_trace` recognizes it as an SDPO error site.

Usage:
    from composer_replication.ingestion import ClaudeCodeIngester
    from composer_replication.ingestion.trace_examples import (
        claude_states_to_trace_examples,
    )
    from composer_replication.trainer.data_collator import (
        ComposerDataCollator, CollatorConfig,
    )

    ingester = ClaudeCodeIngester()
    states = list(ingester.ingest(session_jsonl_path))
    examples = claude_states_to_trace_examples(states)

    config = CollatorConfig(
        hint_generator=lambda kind, meta: "Hint: try a different path.",
        enable_replay_dpo=False,
    )
    collator = ComposerDataCollator(tokenizer=tok, config=config)
    batch = collator(examples)
    # batch now has properly-aligned ctx_teacher_input_ids + sdpo_loss_mask

This is the production-grade alignment path. Wave 18's
`examples/sdpo_with_real_traces/` is a wiring smoke that bypasses this
adapter; Wave 19's `examples/sdpo_with_real_traces_production/` uses
this adapter for the real alignment.
"""
from __future__ import annotations

import re
from typing import Any, Iterable, Mapping

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Marker the ingester emits for tool_result blocks whose source JSONL entry
# carried `is_error: true`. Error sites are located by a plain substring
# search for this marker in user-role turn content, so it must stay in sync
# with the `tag = "[TOOL_RESULT (ERROR)]"` literal in
# `composer_replication.ingestion.claude_code._serialize_user_content`.
TOOL_ERROR_TAG = "[TOOL_RESULT (ERROR)]"

# Keyword heuristics that map error text to a short `error_kind` label. The
# collator's `hint_generator` receives the label as its first argument, so
# hints can be tailored per category. This is a deliberately minimal v0 set;
# callers who need more can supply their own `error_kind_fn`.
#
# Entries are tried top to bottom and the first match wins. When a message
# matches several categories (e.g. it contains both "command not found" and
# "permission denied"), the earlier entry takes precedence, so order these by
# priority.
_ERROR_KIND_PATTERNS = [
    (kind, re.compile(regex))
    for kind, regex in (
        ("command_not_found", r"(?i)command not found"),
        (
            "file_not_found",
            r"(?i)\b(file does not exist|no such file or directory"
            r"|file not found)\b",
        ),
        ("permission_denied", r"(?i)permission denied"),
        ("syntax_error", r"(?i)syntax\s*error"),
        (
            "connection_error",
            r"(?i)\b(connection|network|timeout) (error|refused)\b",
        ),
    )
]


def default_classify_error(content: str) -> str:
    """Map a tool-error message to a short `error_kind` label.

    The `(kind, pattern)` entries of `_ERROR_KIND_PATTERNS` are tried in
    order. The kind of the first pattern found anywhere in *content* is
    returned. A message that matches no pattern gets the generic
    `"tool_error"` label.

    Callers of `claude_states_to_trace_examples` can substitute their own
    classifier through its `error_kind_fn` argument.
    """
    return next(
        (kind for kind, pattern in _ERROR_KIND_PATTERNS if pattern.search(content)),
        "tool_error",
    )


# ---------------------------------------------------------------------------
# Adapter
# ---------------------------------------------------------------------------


def _content_to_text(content: Any) -> Any:
    """Flatten a chat-message `content` value into plain text.

    - `None` becomes `""`. `msg.get("content", "")` only defaults a *missing*
      key, so a present-but-null value would otherwise leak through and
      break substring checks.
    - A list of parts is joined with newlines. For dict parts, the `"text"`
      value is used when present; otherwise `str(part)` is used.
    - Any other value is returned unchanged.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        # Defensive: some tokenizers / chat formats hand back lists.
        return "\n".join(
            str(part.get("text", part)) if isinstance(part, dict) else str(part)
            for part in content
        )
    return content


def claude_states_to_trace_examples(
    states: Iterable[Mapping[str, Any]],
    *,
    error_kind_fn=default_classify_error,
    final_reward: float = 0.0,
) -> list[dict[str, Any]]:
    """Convert ClaudeCodeIngester TraceState dicts → TraceExample dicts.

    Each input state's `messages` list (OpenAI-style chat dicts) is
    rewritten as a `turns` list of TraceTurn dicts (`role`, `content`).
    List-valued content is flattened to newline-joined text, and `None`
    content becomes `""`.

    Error sites: a user-role message containing `TOOL_ERROR_TAG` marks the
    first assistant turn after it, provided every turn in between is also
    user-role. The ingester writes that tag whenever the source JSONL had
    `is_error: true`. The marked turn receives two fields:

    - `tool_error`, set to `error_kind_fn(<error text>)`. This is what
      `ComposerDataCollator._build_hint_injected_trace` checks via
      `_is_error_turn`.
    - `error_meta`, a dict holding the first 200 characters of that text.

    If several tagged user turns precede the assistant turn, the nearest one
    wins. A falsy `error_kind_fn` result leaves the turn unmarked.

    Args:
        states: iterable of TraceState dicts (from `ClaudeCodeIngester.ingest`).
        error_kind_fn: callable(error_content) -> str for classifying
            errors. Defaults to `default_classify_error`.
        final_reward: scalar reward stored on every example (the collator
            threads this into the GRPO channel; defaults to 0 since Claude
            Code traces don't carry RLVR rewards natively).

    Returns:
        One dict per state that produced at least one turn. Each dict has
        three keys:

        - `trace_id`: `str(state["state_id"])`, or `""` when that key is
          absent.
        - `turns`
        - `final_reward`

        No `dpo_pairs` key is emitted. Claude Code traces don't carry
        chosen/rejected pairs; use `teacher_replay.extract_dpo_pairs` for
        that channel separately.

    Raises:
        Whatever `error_kind_fn` raises is propagated unchanged.
    """
    examples: list[dict[str, Any]] = []
    for state in states:
        turns: list[dict[str, Any]] = []
        # Text of the nearest error-tagged user turn within the current
        # unbroken run of user turns. Any non-user turn clears it, so an
        # error can only ever attach to the first assistant turn after it.
        pending_error: str | None = None

        for msg in state.get("messages") or []:
            role = msg.get("role", "")
            content = _content_to_text(msg.get("content", ""))
            turn: dict[str, Any] = {"role": role, "content": content}

            if role == "user":
                # Untagged user turns (e.g. a follow-up tool_result on a
                # successful retry) keep any earlier pending error alive.
                if isinstance(content, str) and TOOL_ERROR_TAG in content:
                    pending_error = content
                turns.append(turn)
                continue

            if role == "assistant" and pending_error is not None:
                error_kind = error_kind_fn(pending_error)
                if error_kind:
                    turn["tool_error"] = error_kind
                    turn["error_meta"] = {
                        "source_role": "user",
                        "source_content_excerpt": pending_error[:200],
                    }
            pending_error = None
            turns.append(turn)

        if not turns:
            continue

        examples.append({
            "trace_id": str(state.get("state_id", "")),
            "turns": turns,
            "final_reward": float(final_reward),
        })

    return examples


# Public surface: the conversion entry point, its default error classifier,
# and the marker string used to detect tool-error sites.
__all__ = ["claude_states_to_trace_examples", "default_classify_error", "TOOL_ERROR_TAG"]