Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 8,399 Bytes
03bf323
"""Adapter: ClaudeCodeIngester output → ComposerDataCollator input.
The ingester (`composer_replication.ingestion.claude_code.ClaudeCodeIngester`)
emits `TraceState` dicts with a `messages` field — a list of OpenAI-style
chat dicts. The data collator (`composer_replication.trainer.data_collator
.ComposerDataCollator`) expects `TraceExample` dicts with a `turns` field —
a list of `TraceTurn` dicts where each turn carries its own role, content,
and (critically) `tool_error` field for SDPO error-site detection.
This module bridges the two. The adapter:
1. Consumes a `TraceState` from the ingester.
2. Converts its `messages` (chat dicts) → `turns` (TraceTurns).
3. Detects tool-error sites by looking for the `[TOOL_RESULT (ERROR)]`
tag the ingester writes (per Claude Code's `is_error: true` flag in
the source JSONL).
4. Marks the first assistant turn that follows an error tool-result (any
further user turns sitting between the two are skipped over) with
`tool_error="<error_kind>"` so the data collator's
`_build_hint_injected_trace` recognizes it as an SDPO error site.
Usage:
from composer_replication.ingestion import ClaudeCodeIngester
from composer_replication.ingestion.trace_examples import (
claude_states_to_trace_examples,
)
from composer_replication.trainer.data_collator import (
ComposerDataCollator, CollatorConfig,
)
ingester = ClaudeCodeIngester()
states = list(ingester.ingest(session_jsonl_path))
examples = claude_states_to_trace_examples(states)
config = CollatorConfig(
hint_generator=lambda kind, meta: "Hint: try a different path.",
enable_replay_dpo=False,
)
collator = ComposerDataCollator(tokenizer=tok, config=config)
batch = collator(examples)
# batch now has properly-aligned ctx_teacher_input_ids + sdpo_loss_mask
This is the production-grade alignment path. Wave 18's
`examples/sdpo_with_real_traces/` is a wiring smoke that bypasses this
adapter; Wave 19's `examples/sdpo_with_real_traces_production/` uses
this adapter for the real alignment.
"""
from __future__ import annotations
import re
from typing import Any, Iterable, Mapping
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Marker the ingester writes in front of a tool_result whose source JSONL
# entry had `is_error: true`. Error sites are detected by plain substring
# matching of this tag against user-turn content, so it must stay identical
# to the `tag = "[TOOL_RESULT (ERROR)]"` literal in
# `composer_replication.ingestion.claude_code._serialize_user_content`.
TOOL_ERROR_TAG = "[TOOL_RESULT (ERROR)]"
# Keyword heuristics used to label a tool error with a short error_kind
# string. The data collator's `hint_generator` is handed that label as its
# first argument so the hint can be tailored per category. This is a minimal
# v0 table; callers needing more categories can supply their own classifier
# through the `error_kind_fn` parameter.
#
# Entries are tried top to bottom and the first hit wins. command_not_found
# is deliberately listed ahead of file_not_found so it keeps priority over
# any broader "not found" style match.
_ERROR_KIND_PATTERNS = [
    (
        "command_not_found",
        re.compile(r"(?i)command not found"),
    ),
    (
        "file_not_found",
        re.compile(r"(?i)\b(file does not exist|no such file or directory|file not found)\b"),
    ),
    (
        "permission_denied",
        re.compile(r"(?i)permission denied"),
    ),
    (
        "syntax_error",
        re.compile(r"(?i)syntax\s*error"),
    ),
    (
        "connection_error",
        re.compile(r"(?i)\b(connection|network|timeout) (error|refused)\b"),
    ),
]


def default_classify_error(content: str) -> str:
    """Map a tool-error message onto a short error_kind label.

    The message is searched against each entry of `_ERROR_KIND_PATTERNS`
    in order; the label of the first matching entry is returned. When no
    entry matches, the generic label "tool_error" is returned instead.

    Callers can replace this heuristic by passing their own
    `error_kind_fn` to `claude_states_to_trace_examples`.
    """
    matching_labels = (
        label for label, regex in _ERROR_KIND_PATTERNS if regex.search(content)
    )
    return next(matching_labels, "tool_error")
# ---------------------------------------------------------------------------
# Adapter
# ---------------------------------------------------------------------------
def _flatten_content(content: Any) -> Any:
    """Collapse a chat message's `content` value into plain text.

    `None` becomes the empty string. A list of parts is joined with
    newlines: a dict part contributes its "text" value (or the dict's own
    string form when it has no "text" key), any other part contributes
    `str(part)`. Every other value is returned unchanged.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        return "\n".join(
            str(part.get("text", part)) if isinstance(part, dict) else str(part)
            for part in content
        )
    return content


def claude_states_to_trace_examples(
    states: Iterable[Mapping[str, Any]],
    *,
    error_kind_fn=default_classify_error,
    final_reward: float = 0.0,
) -> list[dict[str, Any]]:
    """Convert ClaudeCodeIngester TraceState dicts → TraceExample dicts.

    Each input state's `messages` list (OpenAI-style chat dicts) is
    rewritten as a `turns` list of `{"role", "content"}` dicts. Message
    content that is `None` is normalized to "" and list-valued content is
    flattened to newline-joined text.

    Tool-error sites are detected by matching the `[TOOL_RESULT (ERROR)]`
    tag in user-role messages. An assistant turn is marked as an error
    site when the unbroken run of user-role messages directly before it
    contains a tagged message; if several are tagged, the one closest to
    the assistant turn is used. A message of any other role ends the run.
    A marked turn gains two extra keys:

    * `tool_error`: the string returned by `error_kind_fn` for the tagged
      user content (the turn is left unmarked if that string is empty).
    * `error_meta`: `{"source_role": "user", "source_content_excerpt": ...}`
      where the excerpt is the first 200 characters of that content.

    Args:
        states: iterable of TraceState dicts (from
            `ClaudeCodeIngester.ingest`). A missing or `None` `messages`
            entry is treated as an empty conversation.
        error_kind_fn: callable(error_content) -> str for classifying
            errors. Defaults to the keyword-match classifier.
        final_reward: scalar stored as `final_reward` on every example.

    Returns:
        A list of dicts with keys `trace_id` (the state's `state_id` as a
        string), `turns`, and `final_reward`. States that yield no turns
        are skipped. No `dpo_pairs` key is emitted.
    """
    examples: list[dict[str, Any]] = []
    for state in states:
        turns: list[dict[str, Any]] = []
        # Content of the most recent error-tagged user message in the
        # current run of consecutive user messages; None when there is none.
        pending_error: str | None = None

        # `or []` also covers an explicit `"messages": None`.
        for msg in state.get("messages") or []:
            role = msg.get("role", "")
            content = _flatten_content(msg.get("content", ""))
            turn: dict[str, Any] = {"role": role, "content": content}

            if role == "user":
                # Only text can carry the tag; a later tagged message in the
                # same run replaces an earlier one (nearest one wins).
                if isinstance(content, str) and TOOL_ERROR_TAG in content:
                    pending_error = content
            else:
                if role == "assistant" and pending_error is not None:
                    error_kind = error_kind_fn(pending_error)
                    if error_kind:
                        turn["tool_error"] = error_kind
                        turn["error_meta"] = {
                            "source_role": "user",
                            "source_content_excerpt": pending_error[:200],
                        }
                # Any non-user message closes the run of user messages, so
                # an error cannot be attributed across it.
                pending_error = None

            turns.append(turn)

        if not turns:
            continue
        examples.append({
            "trace_id": str(state.get("state_id", "")),
            "turns": turns,
            "final_reward": float(final_reward),
        })
    return examples
# Names exported by `from <this module> import *`: the adapter entry point,
# the default error classifier, and the error-tag constant.
__all__ = [
"claude_states_to_trace_examples",
"default_classify_error",
"TOOL_ERROR_TAG",
]
|