"""Adapter: ClaudeCodeIngester output → ComposerDataCollator input. The ingester (`composer_replication.ingestion.claude_code.ClaudeCodeIngester`) emits `TraceState` dicts with a `messages` field — a list of OpenAI-style chat dicts. The data collator (`composer_replication.trainer.data_collator .ComposerDataCollator`) expects `TraceExample` dicts with a `turns` field — a list of `TraceTurn` dicts where each turn carries its own role, content, and (critically) `tool_error` field for SDPO error-site detection. This module bridges the two. The adapter: 1. Consumes a `TraceState` from the ingester. 2. Converts its `messages` (chat dicts) → `turns` (TraceTurns). 3. Detects tool-error sites by looking for the `[TOOL_RESULT (ERROR)]` tag the ingester writes (per Claude Code's `is_error: true` flag in the source JSONL). 4. Marks the assistant turn IMMEDIATELY AFTER an error tool-result with `tool_error=""` so the data collator's `_build_hint_injected_trace` recognizes it as an SDPO error site. Usage: from composer_replication.ingestion import ClaudeCodeIngester from composer_replication.ingestion.trace_examples import ( claude_states_to_trace_examples, ) from composer_replication.trainer.data_collator import ( ComposerDataCollator, CollatorConfig, ) ingester = ClaudeCodeIngester() states = list(ingester.ingest(session_jsonl_path)) examples = claude_states_to_trace_examples(states) config = CollatorConfig( hint_generator=lambda kind, meta: "Hint: try a different path.", enable_replay_dpo=False, ) collator = ComposerDataCollator(tokenizer=tok, config=config) batch = collator(examples) # batch now has properly-aligned ctx_teacher_input_ids + sdpo_loss_mask This is the production-grade alignment path. Wave 18's `examples/sdpo_with_real_traces/` is a wiring smoke that bypasses this adapter; Wave 19's `examples/sdpo_with_real_traces_production/` uses this adapter for the real alignment. """ from __future__ import annotations import re from typing import Any, Iterable, Mapping # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- # The ingester writes this tag for tool_results where the source JSONL had # is_error: true. We detect error sites by string-matching this tag in the # user-turn content. Matches the `tag = "[TOOL_RESULT (ERROR)]"` literal # in `composer_replication.ingestion.claude_code._serialize_user_content`. TOOL_ERROR_TAG = "[TOOL_RESULT (ERROR)]" # Heuristic: classify the error_kind by simple keyword match on the error # content. The data collator's `hint_generator` receives this string as # its first argument so the hint can be tailored. These categories are a # minimal v0 set; users can extend by passing their own classifier # function via the `error_kind_fn` parameter. _ERROR_KIND_PATTERNS = [ # Order matters: command_not_found must come BEFORE file_not_found # since "command not found" would also match a generic "not found". ("command_not_found", re.compile(r"(?i)command not found")), ("file_not_found", re.compile(r"(?i)\b(file does not exist|no such file or directory|file not found)\b")), ("permission_denied", re.compile(r"(?i)permission denied")), ("syntax_error", re.compile(r"(?i)syntax\s*error")), ("connection_error", re.compile(r"(?i)\b(connection|network|timeout) (error|refused)\b")), ] def default_classify_error(content: str) -> str: """Classify a tool-error message into a short error_kind string. Returns one of the named categories above, or "tool_error" for anything unmatched. Users can override by passing their own `error_kind_fn` to `claude_states_to_trace_examples`. """ for kind, pattern in _ERROR_KIND_PATTERNS: if pattern.search(content): return kind return "tool_error" # --------------------------------------------------------------------------- # Adapter # --------------------------------------------------------------------------- def claude_states_to_trace_examples( states: Iterable[Mapping[str, Any]], *, error_kind_fn=default_classify_error, final_reward: float = 0.0, ) -> list[dict[str, Any]]: """Convert ClaudeCodeIngester TraceState dicts → TraceExample dicts. Each input state's `messages` list (OpenAI chat dicts) is rewritten as a `turns` list of TraceTurn dicts. Tool-error sites are detected by matching the `[TOOL_RESULT (ERROR)]` tag in user-role messages (the ingester writes this tag whenever the source JSONL had `is_error: true`). When found, the assistant turn IMMEDIATELY after the error tool-result gets its `tool_error` field populated, which is what `ComposerDataCollator._build_hint_injected_trace` checks via `_is_error_turn`. Args: states: iterable of TraceState dicts (from `ClaudeCodeIngester.ingest`). error_kind_fn: callable(error_content) -> str for classifying errors. Defaults to the keyword-match classifier above. final_reward: scalar reward for the final assistant turn (the collator threads this into the GRPO channel; defaults to 0 since Claude Code traces don't carry RLVR rewards natively). Returns: list[TraceExample] (TypedDict — `{trace_id, turns, final_reward, dpo_pairs}`). dpo_pairs is omitted (Claude Code traces don't carry chosen/rejected pairs; use `teacher_replay.extract_dpo_pairs` for that channel separately). """ examples: list[dict[str, Any]] = [] for state in states: msgs = state.get("messages", []) turns: list[dict[str, Any]] = [] for i, msg in enumerate(msgs): content = msg.get("content", "") if isinstance(content, list): # Defensive: some tokenizers / chat formats hand back lists. content = "\n".join( str(c.get("text", c)) if isinstance(c, dict) else str(c) for c in content ) role = msg.get("role", "") turn: dict[str, Any] = {"role": role, "content": content} # An assistant turn is an error site iff a recent preceding # user-role turn contained the TOOL_ERROR_TAG. Walk backward # through user turns until we hit either an error-tagged user # turn (mark this assistant as the error recovery turn) or a # different role / no error tag (no error site). # # This handles chains where an error tool_result is followed # by additional user turns (e.g., a follow-up tool_result on # a successful retry) before the assistant recovery turn. if role == "assistant" and i > 0: error_kind_found: str | None = None error_content_found: str | None = None for j in range(i - 1, -1, -1): prev = msgs[j] if prev.get("role") != "user": break prev_content = prev.get("content", "") if isinstance(prev_content, list): prev_content = "\n".join( str(c.get("text", c)) if isinstance(c, dict) else str(c) for c in prev_content ) if TOOL_ERROR_TAG in prev_content: error_kind_found = error_kind_fn(prev_content) error_content_found = prev_content break if error_kind_found: turn["tool_error"] = error_kind_found turn["error_meta"] = { "source_role": "user", "source_content_excerpt": (error_content_found or "")[:200], } turns.append(turn) if not turns: continue examples.append({ "trace_id": str(state.get("state_id", "")), "turns": turns, "final_reward": float(final_reward), }) return examples __all__ = [ "claude_states_to_trace_examples", "default_classify_error", "TOOL_ERROR_TAG", ]