Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto") - Notebooks
- Google Colab
- Kaggle
Wave 10 — packaging: composer_replication is now pip-installable
Browse filesPromotes the framework from "skeleton scattered across spike directories"
to a proper installable Python package (closes vision-validation gap V4).
What's new:
pyproject.toml at repo root
- Hatchling build backend
- Python ≥3.10
- Core deps: torch≥2.0, transformers≥4.46
- Optional extras: replay (httpx), diloco (torchft-nightly),
train (trl, peft, accelerate, datasets), dev (everything + pytest, ruff)
- Project URLs point to the HF repo + Discussions
composer_replication/ package (8 modules)
- __init__.py re-exports the framework's public API:
compose_loss, LossComponents, build_batch (Spike 006)
generalized_jsd_loss (verified port of OPSD)
ClaudeCodeIngester, IngestionStats, SYSTEM_PROMPT (Spike 007)
TraceState, DPOPair, TeacherSpec, replay_trace,
extract_dpo_pairs, DEFAULT_TEACHERS (Spike 001/005)
ComposerReplicationTrainer (Spike 005, TRL subclass)
make_diloco_outer_loop (Spike 008, optional)
- Submodules (loss, batch, opsd, teacher_replay, hint_generator,
ingestion.claude_code, trainer.composer_trainer + data_collator,
diloco) are 1:1 copies of the spike modules with sibling-relative
sys.path hacks replaced by package-absolute imports.
- DiLoCo import is guarded — package works without torchft installed,
_DILOCO_AVAILABLE flag exposes the state.
- Spike directories KEEP their own copies as verification harnesses;
the package and the spikes stay in sync because the package's imports
resolve cleanly without sys.path mutation, while the spikes still use
their original sys.path.insert() pattern for self-containment.
examples/qwen_05b_quickstart/
- run.py: end-to-end CPU smoke using the installed package — loads
Qwen2.5-0.5B-Instruct, runs 5 backward steps through the 3-channel
loss, prints the loss curve. ~3-5 min wall-clock, ~$0.
- README.md: step-by-step instructions + expected output.
- run.log: actual successful run output (Initial 0.7390 → Final 0.0031,
99.6% reduction, all grads finite).
Verification
- pip install -e . succeeds clean.
- All four import paths resolve under the installed package:
cr.compose_loss, cr.ClaudeCodeIngester, cr.ComposerReplicationTrainer,
cr.make_diloco_outer_loop.
- Quickstart end-to-end PASS on real Qwen2.5-0.5B with the same loss
trajectory as Spike 006.
- Spike 005 (38/38), 007 (15/15), 008 (5/5) all still pass with the
installed package — no regression.
Refs: BACKLOG.md "Wave 10 — Packaging"; docs/VISION_VALIDATION.md gap V4.
- README.md +19 -3
- composer_replication/README.md +36 -0
- composer_replication/__init__.py +89 -0
- composer_replication/batch.py +128 -0
- composer_replication/diloco/__init__.py +124 -0
- composer_replication/hint_generator.py +107 -0
- composer_replication/ingestion/__init__.py +20 -0
- composer_replication/ingestion/claude_code.py +295 -0
- composer_replication/loss.py +211 -0
- composer_replication/opsd.py +132 -0
- composer_replication/teacher_replay.py +280 -0
- composer_replication/trainer/__init__.py +10 -0
- composer_replication/trainer/composer_trainer.py +236 -0
- composer_replication/trainer/data_collator.py +440 -0
- examples/qwen_05b_quickstart/README.md +70 -0
- examples/qwen_05b_quickstart/run.py +83 -0
- pyproject.toml +92 -0
|
@@ -27,15 +27,31 @@ pretty_name: "Composer 2.5 Replication Framework — Research Synthesis"
|
|
| 27 |
|
| 28 |
# Composer 2.5 Replication Framework
|
| 29 |
|
| 30 |
-
> **Repo type:** `model` (methodology). **Status:** Research synthesis + v0.
|
| 31 |
> **Author:** [Codeseys](https://huggingface.co/Codeseys)
|
| 32 |
> **Goal:** Replicate Cursor's [Composer 2.5](https://cursor.com/blog/composer-2-5) (a post-trained Kimi K2.5 specialised for agentic coding) on **any** HuggingFace base model, using a synthesis of decentralized RL post-training techniques.
|
| 33 |
|
| 34 |
This repository is the **"paper of the project"** — it is the methodology / research / framework specification for an open replication of Cursor's Composer 2.5 system, plus a **novel multi-teacher trace-replay distillation channel** that stacks on top of the Composer recipe.
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
- 🟢 Spike 001 (kill-switch teacher cost) — **VALIDATED**: 150 real OpenRouter calls, $0.98/trace, p95 latency 20.5s. The novel research direction is economically viable.
|
| 38 |
-
- 🟢 Spike 005 (integrated 3-channel trainer skeleton) — **SKELETON-VALIDATED
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
- 📋 Spikes 002a/002b/003/004 — planned, awaiting GPU budget commitment.
|
| 40 |
|
| 41 |
📝 **Publication materials drafted:** [`publications/`](publications/) contains a complete pre-experimental release set — longform methodology paper, blog post (HF Blog format), repo Discussion announcement, X/LinkedIn threads, plus `CITATION.cff` and `CITATION.bib` at the repo root. Use [`publications/RELEASE_CHECKLIST.md`](publications/RELEASE_CHECKLIST.md) to coordinate the publication wave. Nothing posted publicly yet — this is a pre-experimental release, not a post-experimental one.
|
|
|
|
| 27 |
|
| 28 |
# Composer 2.5 Replication Framework
|
| 29 |
|
| 30 |
+
> **Repo type:** `model` (methodology). **Status:** Research synthesis + v0.1 framework with verified gap-closer spikes (2026-05-26).
|
| 31 |
> **Author:** [Codeseys](https://huggingface.co/Codeseys)
|
| 32 |
> **Goal:** Replicate Cursor's [Composer 2.5](https://cursor.com/blog/composer-2-5) (a post-trained Kimi K2.5 specialised for agentic coding) on **any** HuggingFace base model, using a synthesis of decentralized RL post-training techniques.
|
| 33 |
|
| 34 |
This repository is the **"paper of the project"** — it is the methodology / research / framework specification for an open replication of Cursor's Composer 2.5 system, plus a **novel multi-teacher trace-replay distillation channel** that stacks on top of the Composer recipe.
|
| 35 |
|
| 36 |
+
## Install
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
pip install -e .
|
| 40 |
+
python examples/qwen_05b_quickstart/run.py
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
The quickstart loads Qwen2.5-0.5B-Instruct and runs 5 backward steps through
|
| 44 |
+
the 3-channel loss on CPU in ~3-5 minutes. See
|
| 45 |
+
[`examples/qwen_05b_quickstart/README.md`](examples/qwen_05b_quickstart/README.md)
|
| 46 |
+
for what the output should look like.
|
| 47 |
+
|
| 48 |
+
**v0.1 spike progress (2026-05-26):**
|
| 49 |
- 🟢 Spike 001 (kill-switch teacher cost) — **VALIDATED**: 150 real OpenRouter calls, $0.98/trace, p95 latency 20.5s. The novel research direction is economically viable.
|
| 50 |
+
- 🟢 Spike 005 (integrated 3-channel trainer skeleton) — **SKELETON-VALIDATED**: 38/38 unit tests passing; the integration architecture claim ("all three channels run simultaneously, ablate cleanly, train without divergence") is empirically verified.
|
| 51 |
+
- 🟢 Spike 006 (real HF model smoke) — **PASSED**: Qwen2.5-0.5B-Instruct via `AutoModelForCausalLM`, 5 backward steps on CPU, loss 0.7390 → 0.0031 (99.6% reduction), all gradients finite. Closes vision-validation gap V8.
|
| 52 |
+
- 🟢 Spike 007 (real trace ingestion) — **PASSED**: `ClaudeCodeIngester.ingest()` converts Claude Code session JSONL → `TraceState` records. 15/15 tests including a real-session smoke. Closes V5.
|
| 53 |
+
- 🟢 Spike 008 (DiLoCo outer-loop smoke) — **PASSED**: `make_diloco_outer_loop()` wraps `torchft.local_sgd.DiLoCo` (BSD-3, Meta-maintained). 5/5 tests including pseudo-gradient sign-convention verification. Closes V2.
|
| 54 |
+
- 🟢 Wave 10 (packaging) — **DONE**: `pip install -e .` works; `composer_replication` package re-exports the verified APIs from the spike directories.
|
| 55 |
- 📋 Spikes 002a/002b/003/004 — planned, awaiting GPU budget commitment.
|
| 56 |
|
| 57 |
📝 **Publication materials drafted:** [`publications/`](publications/) contains a complete pre-experimental release set — longform methodology paper, blog post (HF Blog format), repo Discussion announcement, X/LinkedIn threads, plus `CITATION.cff` and `CITATION.bib` at the repo root. Use [`publications/RELEASE_CHECKLIST.md`](publications/RELEASE_CHECKLIST.md) to coordinate the publication wave. Nothing posted publicly yet — this is a pre-experimental release, not a post-experimental one.
|
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# composer_replication
|
| 2 |
+
|
| 3 |
+
The Composer 2.5 Replication Framework, packaged for `pip install`.
|
| 4 |
+
|
| 5 |
+
This package re-exports the verified APIs that live in the
|
| 6 |
+
[`spikes/`](../spikes/) directory of the parent repository, so that downstream
|
| 7 |
+
code can `import composer_replication` instead of poking at `sys.path`.
|
| 8 |
+
|
| 9 |
+
## Package map
|
| 10 |
+
|
| 11 |
+
| module | source spike | purpose |
|
| 12 |
+
|---|---|---|
|
| 13 |
+
| `composer_replication.loss` | spike 006 | Free `compose_loss(model, batch, ...)` 3-channel loss composer + `LossComponents` dataclass |
|
| 14 |
+
| `composer_replication.batch` | spike 006 | `build_batch(tokenizer)` — real chat-template batch from any HF tokenizer |
|
| 15 |
+
| `composer_replication.opsd` | spike 005 | `generalized_jsd_loss` (verified port of `siyan-zhao/OPSD`) |
|
| 16 |
+
| `composer_replication.teacher_replay` | spike 001/005 | `replay_trace`, `extract_dpo_pairs`, `TraceState`, `TeacherSpec` (multi-teacher OpenRouter replay) |
|
| 17 |
+
| `composer_replication.hint_generator` | spike 005 | Hint-text construction at error sites for SDPO channel |
|
| 18 |
+
| `composer_replication.trainer` | spike 005 | `ComposerReplicationTrainer` (TRL `GRPOTrainer` subclass with the 3 channels) |
|
| 19 |
+
| `composer_replication.ingestion` | spike 007 | `ClaudeCodeIngester` (Claude Code session JSONL → `TraceState`) |
|
| 20 |
+
| `composer_replication.diloco` | spike 008 | `make_diloco_outer_loop` (wraps `torchft.local_sgd.DiLoCo`) |
|
| 21 |
+
|
| 22 |
+
## Why a package on top of spikes?
|
| 23 |
+
|
| 24 |
+
The spikes are research artifacts: each one has its own `README.md`, tests,
|
| 25 |
+
verdict, and a `sys.path` hack to find sibling modules. They live forever as
|
| 26 |
+
verification harnesses.
|
| 27 |
+
|
| 28 |
+
Most users want to `pip install -e . && python my_training_script.py`. This
|
| 29 |
+
package is the pip-installable face of the framework. The two surfaces stay
|
| 30 |
+
in sync because the package modules are 1:1 copies of the spike modules with
|
| 31 |
+
only the import paths changed (sibling-relative → package-absolute).
|
| 32 |
+
|
| 33 |
+
## Quickstart
|
| 34 |
+
|
| 35 |
+
See [`examples/qwen_05b_quickstart/`](../examples/qwen_05b_quickstart/) at
|
| 36 |
+
the repo root.
|
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""composer_replication — Composer 2.5 Replication Framework.
|
| 2 |
+
|
| 3 |
+
A research-grade, open replication of Cursor Composer 2.5's training recipe:
|
| 4 |
+
take any HuggingFace model, further-RL-train it using a 3-channel loss combining
|
| 5 |
+
|
| 6 |
+
1. RLVR / GRPO (channel 1, via TRL)
|
| 7 |
+
2. SDPO hint-distillation (channel 2, OPSD-based)
|
| 8 |
+
3. Multi-teacher trace-replay DPO (channel 3, this framework's contribution)
|
| 9 |
+
|
| 10 |
+
with optional DiLoCo / Streaming DiLoCo outer-loop sync for distributed runs.
|
| 11 |
+
|
| 12 |
+
See https://huggingface.co/Codeseys/composer-replication-framework for the
|
| 13 |
+
full project README, design docs, ADRs, and verification spikes.
|
| 14 |
+
|
| 15 |
+
Quickstart:
|
| 16 |
+
>>> from composer_replication import compose_loss, build_batch
|
| 17 |
+
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 18 |
+
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
| 19 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
| 20 |
+
>>> batch = build_batch(tokenizer)
|
| 21 |
+
>>> components = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
|
| 22 |
+
>>> components.total.backward()
|
| 23 |
+
"""
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
# Loss composition (Spike 006)
|
| 27 |
+
from composer_replication.loss import LossComponents, compose_loss
|
| 28 |
+
from composer_replication.batch import build_batch
|
| 29 |
+
|
| 30 |
+
# Trace ingestion (Spike 007)
|
| 31 |
+
from composer_replication.ingestion.claude_code import (
|
| 32 |
+
SYSTEM_PROMPT,
|
| 33 |
+
ClaudeCodeIngester,
|
| 34 |
+
IngestionStats,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# OPSD / SDPO loss (verified extension from siyan-zhao/OPSD, MIT)
|
| 38 |
+
from composer_replication.opsd import generalized_jsd_loss
|
| 39 |
+
|
| 40 |
+
# Teacher replay (Spike 001 → trainer)
|
| 41 |
+
from composer_replication.teacher_replay import (
|
| 42 |
+
DEFAULT_TEACHERS,
|
| 43 |
+
DPOPair,
|
| 44 |
+
TeacherCallResult,
|
| 45 |
+
TeacherSpec,
|
| 46 |
+
TraceState,
|
| 47 |
+
extract_dpo_pairs,
|
| 48 |
+
replay_trace,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Trainer (Spike 005)
|
| 52 |
+
from composer_replication.trainer import ComposerReplicationTrainer
|
| 53 |
+
|
| 54 |
+
# DiLoCo (Spike 008) — optional, requires torchft
|
| 55 |
+
try:
|
| 56 |
+
from composer_replication.diloco import make_diloco_outer_loop
|
| 57 |
+
_DILOCO_AVAILABLE = True
|
| 58 |
+
except ImportError:
|
| 59 |
+
_DILOCO_AVAILABLE = False
|
| 60 |
+
make_diloco_outer_loop = None # type: ignore[assignment]
|
| 61 |
+
|
| 62 |
+
__version__ = "0.1.0"
|
| 63 |
+
|
| 64 |
+
__all__ = [
|
| 65 |
+
# Core loss
|
| 66 |
+
"compose_loss",
|
| 67 |
+
"LossComponents",
|
| 68 |
+
"build_batch",
|
| 69 |
+
"generalized_jsd_loss",
|
| 70 |
+
# Trace ingestion
|
| 71 |
+
"ClaudeCodeIngester",
|
| 72 |
+
"IngestionStats",
|
| 73 |
+
"SYSTEM_PROMPT",
|
| 74 |
+
"TraceState",
|
| 75 |
+
# Teacher replay
|
| 76 |
+
"DEFAULT_TEACHERS",
|
| 77 |
+
"DPOPair",
|
| 78 |
+
"TeacherCallResult",
|
| 79 |
+
"TeacherSpec",
|
| 80 |
+
"extract_dpo_pairs",
|
| 81 |
+
"replay_trace",
|
| 82 |
+
# Trainer
|
| 83 |
+
"ComposerReplicationTrainer",
|
| 84 |
+
# DiLoCo (optional)
|
| 85 |
+
"make_diloco_outer_loop",
|
| 86 |
+
# Meta
|
| 87 |
+
"_DILOCO_AVAILABLE",
|
| 88 |
+
"__version__",
|
| 89 |
+
]
|
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""real_batch.py — build a real, tokenized 3-channel batch from a HF tokenizer.
|
| 2 |
+
|
| 3 |
+
Used by Spike 006's smoke to generate inputs for `compose_loss` from a real
|
| 4 |
+
chat-template-formatted conversation, NOT random ints.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_batch(
|
| 14 |
+
tokenizer: Any,
|
| 15 |
+
*,
|
| 16 |
+
device: torch.device | str = "cpu",
|
| 17 |
+
seed: int = 42,
|
| 18 |
+
) -> dict[str, torch.Tensor]:
|
| 19 |
+
"""Construct a full 3-channel input batch from a real tokenizer.
|
| 20 |
+
|
| 21 |
+
Returns a dict with all keys `compose_loss` may consume:
|
| 22 |
+
input_ids, response_mask
|
| 23 |
+
ctx_teacher_input_ids, sdpo_loss_mask
|
| 24 |
+
dpo_chosen_input_ids, dpo_chosen_response_mask
|
| 25 |
+
dpo_rejected_input_ids, dpo_rejected_response_mask
|
| 26 |
+
dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs
|
| 27 |
+
|
| 28 |
+
The DPO ref logprobs are dummy tensors (not from a real reference policy
|
| 29 |
+
forward); the smoke is verifying the loss composition wires together,
|
| 30 |
+
not the reference-policy precompute pipeline.
|
| 31 |
+
"""
|
| 32 |
+
torch.manual_seed(seed)
|
| 33 |
+
|
| 34 |
+
# ------------------------------------------------------------------
|
| 35 |
+
# Conversation 1: student rollout
|
| 36 |
+
# ------------------------------------------------------------------
|
| 37 |
+
student_msgs = [
|
| 38 |
+
{"role": "system", "content": "You are a careful coding assistant."},
|
| 39 |
+
{"role": "user", "content": "Write a Python function to compute the factorial of n."},
|
| 40 |
+
{"role": "assistant", "content": "def factorial(n):\n if n <= 1: return 1\n return n * factorial(n - 1)"},
|
| 41 |
+
]
|
| 42 |
+
student_text = tokenizer.apply_chat_template(student_msgs, tokenize=False, add_generation_prompt=False)
|
| 43 |
+
student_enc = tokenizer(student_text, return_tensors="pt", add_special_tokens=False)
|
| 44 |
+
input_ids = student_enc["input_ids"].to(device)
|
| 45 |
+
|
| 46 |
+
# response_mask: rough heuristic — last 30% of tokens are "the response"
|
| 47 |
+
# (good enough for a smoke; production uses chat-template offsets)
|
| 48 |
+
T = input_ids.shape[1]
|
| 49 |
+
response_mask = torch.zeros_like(input_ids)
|
| 50 |
+
response_mask[:, int(T * 0.7):] = 1
|
| 51 |
+
|
| 52 |
+
# ------------------------------------------------------------------
|
| 53 |
+
# Conversation 2: hint-conditioned teacher context (SDPO)
|
| 54 |
+
# ------------------------------------------------------------------
|
| 55 |
+
teacher_msgs = [
|
| 56 |
+
{"role": "system", "content": "You are a careful coding assistant."},
|
| 57 |
+
{"role": "user", "content": "Write a Python function to compute the factorial of n."},
|
| 58 |
+
{"role": "user", "content": "[HINT] Recursion overflows for n>1000. Use an iterative loop."},
|
| 59 |
+
{"role": "assistant", "content": "def factorial(n):\n result = 1\n for i in range(2, n + 1):\n result *= i\n return result"},
|
| 60 |
+
]
|
| 61 |
+
teacher_text = tokenizer.apply_chat_template(teacher_msgs, tokenize=False, add_generation_prompt=False)
|
| 62 |
+
teacher_enc = tokenizer(teacher_text, return_tensors="pt", add_special_tokens=False)
|
| 63 |
+
ctx_teacher_input_ids = teacher_enc["input_ids"].to(device)
|
| 64 |
+
|
| 65 |
+
# SDPO loss mask: 1 on the post-hint assistant tokens (the "error site")
|
| 66 |
+
T_t = ctx_teacher_input_ids.shape[1]
|
| 67 |
+
sdpo_loss_mask = torch.zeros_like(ctx_teacher_input_ids)
|
| 68 |
+
sdpo_loss_mask[:, int(T_t * 0.7):] = 1
|
| 69 |
+
|
| 70 |
+
# ------------------------------------------------------------------
|
| 71 |
+
# Conversation 3 + 4: DPO chosen / rejected pairs
|
| 72 |
+
# ------------------------------------------------------------------
|
| 73 |
+
dpo_chosen_msgs = [
|
| 74 |
+
{"role": "system", "content": "You are a careful coding assistant."},
|
| 75 |
+
{"role": "user", "content": "What's the time complexity of binary search?"},
|
| 76 |
+
{"role": "assistant", "content": "Binary search is O(log n) because each comparison halves the search space."},
|
| 77 |
+
]
|
| 78 |
+
dpo_rejected_msgs = [
|
| 79 |
+
{"role": "system", "content": "You are a careful coding assistant."},
|
| 80 |
+
{"role": "user", "content": "What's the time complexity of binary search?"},
|
| 81 |
+
{"role": "assistant", "content": "It's O(n) I think, you have to look at every element."},
|
| 82 |
+
]
|
| 83 |
+
chosen_text = tokenizer.apply_chat_template(dpo_chosen_msgs, tokenize=False, add_generation_prompt=False)
|
| 84 |
+
rejected_text = tokenizer.apply_chat_template(dpo_rejected_msgs, tokenize=False, add_generation_prompt=False)
|
| 85 |
+
|
| 86 |
+
# Pad both sequences to the same length so we can stack them
|
| 87 |
+
chosen_enc = tokenizer(chosen_text, return_tensors="pt", add_special_tokens=False, padding=False)
|
| 88 |
+
rejected_enc = tokenizer(rejected_text, return_tensors="pt", add_special_tokens=False, padding=False)
|
| 89 |
+
|
| 90 |
+
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
|
| 91 |
+
|
| 92 |
+
chosen_ids = chosen_enc["input_ids"]
|
| 93 |
+
rejected_ids = rejected_enc["input_ids"]
|
| 94 |
+
L = max(chosen_ids.shape[1], rejected_ids.shape[1])
|
| 95 |
+
|
| 96 |
+
def _pad(ids: torch.Tensor, length: int) -> torch.Tensor:
|
| 97 |
+
cur = ids.shape[1]
|
| 98 |
+
if cur >= length:
|
| 99 |
+
return ids[:, :length]
|
| 100 |
+
return torch.cat([ids, torch.full((1, length - cur), pad_id, dtype=ids.dtype)], dim=1)
|
| 101 |
+
|
| 102 |
+
dpo_chosen_input_ids = _pad(chosen_ids, L).to(device)
|
| 103 |
+
dpo_rejected_input_ids = _pad(rejected_ids, L).to(device)
|
| 104 |
+
|
| 105 |
+
chosen_resp_mask = torch.zeros_like(dpo_chosen_input_ids)
|
| 106 |
+
chosen_resp_mask[:, int(L * 0.6):chosen_ids.shape[1]] = 1
|
| 107 |
+
rejected_resp_mask = torch.zeros_like(dpo_rejected_input_ids)
|
| 108 |
+
rejected_resp_mask[:, int(L * 0.6):rejected_ids.shape[1]] = 1
|
| 109 |
+
|
| 110 |
+
# Dummy reference-policy logprobs (in production: precomputed by data collator)
|
| 111 |
+
dpo_chosen_ref_logprobs = torch.tensor([-30.0], device=device)
|
| 112 |
+
dpo_rejected_ref_logprobs = torch.tensor([-35.0], device=device)
|
| 113 |
+
|
| 114 |
+
return {
|
| 115 |
+
"input_ids": input_ids,
|
| 116 |
+
"response_mask": response_mask,
|
| 117 |
+
"ctx_teacher_input_ids": ctx_teacher_input_ids,
|
| 118 |
+
"sdpo_loss_mask": sdpo_loss_mask,
|
| 119 |
+
"dpo_chosen_input_ids": dpo_chosen_input_ids,
|
| 120 |
+
"dpo_chosen_response_mask": chosen_resp_mask,
|
| 121 |
+
"dpo_rejected_input_ids": dpo_rejected_input_ids,
|
| 122 |
+
"dpo_rejected_response_mask": rejected_resp_mask,
|
| 123 |
+
"dpo_chosen_ref_logprobs": dpo_chosen_ref_logprobs,
|
| 124 |
+
"dpo_rejected_ref_logprobs": dpo_rejected_ref_logprobs,
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
__all__ = ["build_batch"]
|
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""composer_diloco.py — DiLoCo outer-loop wrapper for Composer Replication Framework.
|
| 2 |
+
|
| 3 |
+
Wraps `torchft.local_sgd.DiLoCo` with the framework's conventions:
|
| 4 |
+
- Sign convention is documented LOUDLY here once and tested via Spike 008.
|
| 5 |
+
- The wrapper exposes the same constructor shape as torchft's DiLoCo so a
|
| 6 |
+
future swap-in of the upstream class is a one-line change.
|
| 7 |
+
- Vanilla DiLoCo (Douillard et al. 2023) = `fragment_sync_delay=0`, single
|
| 8 |
+
fragment. Streaming DiLoCo (Liu et al. 2025) = non-zero delay, multiple
|
| 9 |
+
fragments. Spike 008 uses vanilla; Streaming is configured by the same API.
|
| 10 |
+
|
| 11 |
+
Reference: `docs/adrs/ADR-003-diloco-impl.md`.
|
| 12 |
+
|
| 13 |
+
Sign convention (READ THIS BEFORE TOUCHING):
|
| 14 |
+
torchft's `_save_grads()` (line 324 of torchft/local_sgd.py) computes
|
| 15 |
+
grad = θ_initial - θ_local
|
| 16 |
+
and stores it as `param.grad` for the outer optimizer to consume.
|
| 17 |
+
The outer optimizer then runs `param.data -= lr * grad`, equivalently
|
| 18 |
+
θ_new = θ_local + lr * (θ_initial - θ_local) if outer optimizer is plain SGD
|
| 19 |
+
which slurps the local-trained-θ TOWARD the initial-θ instead of away
|
| 20 |
+
from it. That looks wrong, but it's correct for SGD-with-Nesterov-momentum
|
| 21 |
+
on outer loop: the outer optimizer accumulates the negative-grad-direction
|
| 22 |
+
history, so the "wrong-sign" pseudogradient combined with SGD's "subtract
|
| 23 |
+
grad" semantics gives net "step in the local-Δ direction" once momentum
|
| 24 |
+
builds up. This is consistent with the DiLoCo paper's pseudo-code.
|
| 25 |
+
|
| 26 |
+
Bottom line: do NOT negate. torchft's pseudogradient sign + SGD outer
|
| 27 |
+
optimizer is the correct combination. Spike 008's
|
| 28 |
+
`test_diloco_pseudogradient_sign_convention` test catches a sign flip.
|
| 29 |
+
"""
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
from typing import Any
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
|
| 36 |
+
# Import lazily — torchft is an optional dep at framework level.
|
| 37 |
+
_TORCHFT_AVAILABLE = False
|
| 38 |
+
DiLoCo: Any = None
|
| 39 |
+
Manager: Any = None
|
| 40 |
+
_DummyWork: Any = None
|
| 41 |
+
try:
|
| 42 |
+
from torchft.local_sgd import DiLoCo as _DiLoCo # type: ignore[import]
|
| 43 |
+
from torchft.manager import Manager as _Manager # type: ignore[import]
|
| 44 |
+
from torchft.work import _DummyWork as __DummyWork # type: ignore[import]
|
| 45 |
+
|
| 46 |
+
_TORCHFT_AVAILABLE = True
|
| 47 |
+
DiLoCo = _DiLoCo
|
| 48 |
+
Manager = _Manager
|
| 49 |
+
_DummyWork = __DummyWork
|
| 50 |
+
except ImportError: # pragma: no cover — only hits in lighter-weight CI envs
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def make_diloco_outer_loop(
|
| 55 |
+
manager: Any,
|
| 56 |
+
model_fragments: list[torch.nn.Module],
|
| 57 |
+
inner_optimizer: torch.optim.Optimizer,
|
| 58 |
+
*,
|
| 59 |
+
outer_lr: float = 0.7,
|
| 60 |
+
outer_momentum: float = 0.9,
|
| 61 |
+
nesterov: bool = True,
|
| 62 |
+
sync_every: int = 100,
|
| 63 |
+
fragment_sync_delay: int = 0,
|
| 64 |
+
fragment_update_alpha: float = 0.0,
|
| 65 |
+
) -> Any:
|
| 66 |
+
"""Construct a DiLoCo wrapper around `model_fragments` with default DiLoCo hyperparams.
|
| 67 |
+
|
| 68 |
+
Default hyperparams (DiLoCo paper §3.2):
|
| 69 |
+
outer_lr = 0.7, outer_momentum = 0.9, Nesterov
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
manager: torchft.Manager (or test mock with `.allreduce`, `.should_commit`,
|
| 73 |
+
`.current_step`, `.start_quorum`)
|
| 74 |
+
model_fragments: list of nn.Modules. For vanilla DiLoCo, pass [whole_model].
|
| 75 |
+
For Streaming DiLoCo with N fragments, pass [frag_0, frag_1, ..., frag_N-1].
|
| 76 |
+
inner_optimizer: any torch.optim.Optimizer. Steps every batch.
|
| 77 |
+
outer_lr / outer_momentum / nesterov: outer SGD hyperparams.
|
| 78 |
+
Override defaults only if you know why.
|
| 79 |
+
sync_every: number of inner steps per outer round.
|
| 80 |
+
fragment_sync_delay: 0 = vanilla DiLoCo (sync at outer round).
|
| 81 |
+
>0 = Streaming DiLoCo with overlapped sync. Requires CUDA streams.
|
| 82 |
+
fragment_update_alpha: 0 = full replacement of fragment params on sync.
|
| 83 |
+
>0 = exponential mixing weight. Streaming DiLoCo only.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
A torchft.local_sgd.DiLoCo instance configured for the framework's
|
| 87 |
+
conventions. Use as a context manager:
|
| 88 |
+
with make_diloco_outer_loop(...) as outer:
|
| 89 |
+
for step in range(N):
|
| 90 |
+
inner_optimizer.zero_grad()
|
| 91 |
+
loss = compute_loss(...)
|
| 92 |
+
loss.backward()
|
| 93 |
+
inner_optimizer.step() # outer sync fires automatically
|
| 94 |
+
"""
|
| 95 |
+
if not _TORCHFT_AVAILABLE:
|
| 96 |
+
raise RuntimeError(
|
| 97 |
+
"torchft is not installed. `pip install torchft-nightly` to use DiLoCo."
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
outer_optimizer = torch.optim.SGD(
|
| 101 |
+
[p for frag in model_fragments for p in frag.parameters()],
|
| 102 |
+
lr=outer_lr,
|
| 103 |
+
momentum=outer_momentum,
|
| 104 |
+
nesterov=nesterov,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
return DiLoCo(
|
| 108 |
+
manager=manager,
|
| 109 |
+
model_fragments=model_fragments,
|
| 110 |
+
inner_optimizer=inner_optimizer,
|
| 111 |
+
outer_optimizer=outer_optimizer,
|
| 112 |
+
sync_every=sync_every,
|
| 113 |
+
fragment_sync_delay=fragment_sync_delay,
|
| 114 |
+
fragment_update_alpha=fragment_update_alpha,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
__all__ = [
|
| 119 |
+
"make_diloco_outer_loop",
|
| 120 |
+
"DiLoCo",
|
| 121 |
+
"Manager",
|
| 122 |
+
"_DummyWork",
|
| 123 |
+
"_TORCHFT_AVAILABLE",
|
| 124 |
+
]
|
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""hint_generator.py — Template-based hint generator (v0.1 starter).
|
| 2 |
+
|
| 3 |
+
Composer 2.5 inserts text hints at error-turn sites:
|
| 4 |
+
"Reminder: Available tools are: …" (when a tool-call refs a non-existent tool)
|
| 5 |
+
"Reminder: tool arguments must be valid JSON" (on JSONDecodeError)
|
| 6 |
+
... etc.
|
| 7 |
+
|
| 8 |
+
This module provides a registry of hint templates keyed by error_kind. The
|
| 9 |
+
data collator (in trl_path/data_collator.py) calls dispatch(error_kind, ctx)
|
| 10 |
+
to get the hint text to splice into ctx_teacher.
|
| 11 |
+
|
| 12 |
+
v0.2 will replace these templates with an LLM-driven hint generator (likely
|
| 13 |
+
Sonnet 4.6 or Opus 4.7 via OpenRouter) for cases where templates are too rigid
|
| 14 |
+
(style violations, wasteful explanations).
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from collections.abc import Callable
|
| 20 |
+
from typing import TypedDict
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class HintContext(TypedDict, total=False):
|
| 24 |
+
"""Per-error context the hint generator can use."""
|
| 25 |
+
error_kind: str # e.g. "tool_not_found", "json_decode", "type_error"
|
| 26 |
+
error_message: str # raw error from the env
|
| 27 |
+
available_tools: list[str] # for tool_not_found
|
| 28 |
+
tool_name: str # the failing tool, if known
|
| 29 |
+
tool_schema: dict # the schema, if known
|
| 30 |
+
intent: str # student's apparent intent, if extractable
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Hint templates
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
def hint_tool_not_found(ctx: HintContext) -> str:
|
| 38 |
+
tools = ctx.get("available_tools", [])
|
| 39 |
+
if tools:
|
| 40 |
+
tool_list = ", ".join(f"`{t}`" for t in tools)
|
| 41 |
+
return f"Reminder: Available tools are: {tool_list}. Please use one of these."
|
| 42 |
+
return "Reminder: the tool you tried to call does not exist. Use only available tools."
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def hint_json_decode(ctx: HintContext) -> str:
|
| 46 |
+
return (
|
| 47 |
+
"Reminder: tool arguments must be valid JSON. Common mistakes: "
|
| 48 |
+
"single quotes (use double), trailing commas, unescaped newlines in strings."
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def hint_type_error(ctx: HintContext) -> str:
|
| 53 |
+
name = ctx.get("tool_name")
|
| 54 |
+
schema = ctx.get("tool_schema")
|
| 55 |
+
if name and schema:
|
| 56 |
+
return (
|
| 57 |
+
f"Reminder: `{name}` expects arguments matching this schema:\n"
|
| 58 |
+
f" {schema}\n"
|
| 59 |
+
"Re-issue the call with arguments matching the schema."
|
| 60 |
+
)
|
| 61 |
+
return "Reminder: tool arguments do not match the expected types. Check the schema."
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def hint_runtime_error(ctx: HintContext) -> str:
|
| 65 |
+
msg = ctx.get("error_message", "an exception")
|
| 66 |
+
return (
|
| 67 |
+
f"Reminder: the previous tool call raised {msg}. "
|
| 68 |
+
"Reconsider the inputs or read the relevant code first to understand state."
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def hint_repeated_failure(ctx: HintContext) -> str:
|
| 73 |
+
"""Triggered when the same kind of error happens 3+ times in a row."""
|
| 74 |
+
return (
|
| 75 |
+
"Reminder: this approach has failed multiple times. "
|
| 76 |
+
"Step back and consider an alternative approach: read more files, "
|
| 77 |
+
"search for similar patterns elsewhere, or break the task down differently."
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# Registry
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
HINT_TEMPLATES: dict[str, Callable[[HintContext], str]] = {
|
| 86 |
+
"tool_not_found": hint_tool_not_found,
|
| 87 |
+
"json_decode": hint_json_decode,
|
| 88 |
+
"type_error": hint_type_error,
|
| 89 |
+
"runtime_error": hint_runtime_error,
|
| 90 |
+
"repeated_failure": hint_repeated_failure,
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None:
|
| 95 |
+
"""Generate a hint for the given error_kind. Returns None if unknown."""
|
| 96 |
+
fn = HINT_TEMPLATES.get(error_kind)
|
| 97 |
+
if fn is None:
|
| 98 |
+
return None
|
| 99 |
+
return fn(ctx or {})
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def register(error_kind: str, fn: Callable[[HintContext], str]) -> None:
|
| 103 |
+
"""Add a custom hint template."""
|
| 104 |
+
HINT_TEMPLATES[error_kind] = fn
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
__all__ = ["dispatch", "register", "HintContext", "HINT_TEMPLATES"]
|
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""composer_replication.ingestion — trace-source adapters.
|
| 2 |
+
|
| 3 |
+
v0.1: Claude Code session JSONL.
|
| 4 |
+
v0.2 candidates: OpenHands trajectories, SWE-smith-trajectories.
|
| 5 |
+
|
| 6 |
+
Per docs/adrs/ADR-002-trace-source.md.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from composer_replication.ingestion.claude_code import (
|
| 11 |
+
SYSTEM_PROMPT,
|
| 12 |
+
ClaudeCodeIngester,
|
| 13 |
+
IngestionStats,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"ClaudeCodeIngester",
|
| 18 |
+
"IngestionStats",
|
| 19 |
+
"SYSTEM_PROMPT",
|
| 20 |
+
]
|
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""claude_code_ingester.py — Claude Code session JSONL → TraceState iterator.
|
| 2 |
+
|
| 3 |
+
Maps the user's local `~/.claude/projects/<encoded>/<sessionId>.jsonl` files to
|
| 4 |
+
the existing `TraceState` schema (state_id + messages + student_action).
|
| 5 |
+
|
| 6 |
+
Design (per ADR-002):
|
| 7 |
+
- One TraceState per assistant TURN (not per tool_use block). Multiple tool_use
|
| 8 |
+
blocks in one assistant message belong to a single reasoning step.
|
| 9 |
+
- `student_action` = JSON-serialized list of (text + tool_use) blocks of the
|
| 10 |
+
assistant message. Teacher gets the message history before this turn and is
|
| 11 |
+
asked "what should the assistant do here?". Comparison vs the literal student
|
| 12 |
+
action gives our DPO signal.
|
| 13 |
+
- `messages` = OpenAI-style history of all records BEFORE this assistant turn.
|
| 14 |
+
System + user messages preserved; previous assistant turns flattened to text.
|
| 15 |
+
- `thinking` blocks STRIPPED from messages passed to teachers (teachers don't
|
| 16 |
+
have access to Claude's reasoning trace) but KEPT in student_action so the
|
| 17 |
+
reproduction loop sees what the student actually emitted.
|
| 18 |
+
- A synthetic system prompt is injected at messages[0] for trace IDs without one
|
| 19 |
+
(most Claude Code sessions don't have one written into the JSONL).
|
| 20 |
+
- Subagent traces (filenames starting with `agent-` OR records with
|
| 21 |
+
`isSidechain: True`) are SKIPPED in v0.1.
|
| 22 |
+
|
| 23 |
+
This is the v0.1 ingester. Non-goals:
|
| 24 |
+
- Reference-policy logprob precompute (lives in the data collator).
|
| 25 |
+
- Error-site detection (separate concern; uses tool_result is_error flag).
|
| 26 |
+
- DPO-pair extraction (lives in teacher_replay.extract_dpo_pairs).
|
| 27 |
+
"""
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import json
|
| 31 |
+
import logging
|
| 32 |
+
import re
|
| 33 |
+
import sys
|
| 34 |
+
from collections.abc import Iterator
|
| 35 |
+
from dataclasses import dataclass
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
from typing import Any, TypedDict
|
| 38 |
+
|
| 39 |
+
from composer_replication.teacher_replay import TraceState
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
SUPPORTED_VERSIONS = re.compile(r"^2\.\d+\.\d+$")
|
| 45 |
+
SYSTEM_PROMPT = (
|
| 46 |
+
"You are a senior software engineer working as a coding agent in a terminal "
|
| 47 |
+
"environment. You can call tools (Bash, Read, Write, Edit, Grep, etc.) and "
|
| 48 |
+
"see their outputs. Reason carefully before each action. When a tool fails, "
|
| 49 |
+
"diagnose the cause and adjust."
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class IngestionStats:
|
| 55 |
+
n_records_total: int = 0
|
| 56 |
+
n_records_skipped: int = 0
|
| 57 |
+
n_states_emitted: int = 0
|
| 58 |
+
n_assistant_turns: int = 0
|
| 59 |
+
n_tool_use_blocks: int = 0
|
| 60 |
+
n_text_blocks: int = 0
|
| 61 |
+
skipped_subagent: int = 0
|
| 62 |
+
skipped_summary: int = 0
|
| 63 |
+
skipped_truncated_lines: int = 0
|
| 64 |
+
version_warnings: list[str] | None = None
|
| 65 |
+
|
| 66 |
+
def __post_init__(self) -> None:
|
| 67 |
+
if self.version_warnings is None:
|
| 68 |
+
self.version_warnings = []
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ClaudeCodeIngester:
|
| 72 |
+
"""Convert one or more Claude Code session JSONL files to TraceState records.
|
| 73 |
+
|
| 74 |
+
Usage:
|
| 75 |
+
ingester = ClaudeCodeIngester()
|
| 76 |
+
for state in ingester.ingest(Path("session.jsonl")):
|
| 77 |
+
...
|
| 78 |
+
stats = ingester.last_stats
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
*,
|
| 84 |
+
system_prompt: str = SYSTEM_PROMPT,
|
| 85 |
+
skip_sidechain: bool = True,
|
| 86 |
+
strip_thinking: bool = True,
|
| 87 |
+
max_history_tokens: int | None = None,
|
| 88 |
+
) -> None:
|
| 89 |
+
self.system_prompt = system_prompt
|
| 90 |
+
self.skip_sidechain = skip_sidechain
|
| 91 |
+
self.strip_thinking = strip_thinking
|
| 92 |
+
self.max_history_tokens = max_history_tokens
|
| 93 |
+
self.last_stats = IngestionStats()
|
| 94 |
+
|
| 95 |
+
def ingest(self, path: Path) -> Iterator[TraceState]:
|
| 96 |
+
"""Yield one TraceState per assistant turn in the given session JSONL."""
|
| 97 |
+
self.last_stats = IngestionStats()
|
| 98 |
+
stats = self.last_stats
|
| 99 |
+
|
| 100 |
+
# Skip subagent files by filename convention
|
| 101 |
+
if self.skip_sidechain and path.name.startswith("agent-"):
|
| 102 |
+
logger.info("Skipping subagent file: %s", path)
|
| 103 |
+
stats.skipped_subagent = 1
|
| 104 |
+
return
|
| 105 |
+
|
| 106 |
+
records = list(self._iter_records(path))
|
| 107 |
+
# Build a quick lookup of records that ARE assistant turns; everything
|
| 108 |
+
# else feeds the message history we hand to teachers.
|
| 109 |
+
history: list[dict[str, Any]] = [
|
| 110 |
+
{"role": "system", "content": self.system_prompt}
|
| 111 |
+
]
|
| 112 |
+
state_idx = 0
|
| 113 |
+
for rec in records:
|
| 114 |
+
stats.n_records_total += 1
|
| 115 |
+
|
| 116 |
+
rec_type = rec.get("type")
|
| 117 |
+
if rec_type == "summary":
|
| 118 |
+
stats.skipped_summary += 1
|
| 119 |
+
continue
|
| 120 |
+
if rec_type in {"attachment", "queue-operation", "file-history-snapshot",
|
| 121 |
+
"last-prompt", "system"}:
|
| 122 |
+
stats.n_records_skipped += 1
|
| 123 |
+
continue
|
| 124 |
+
|
| 125 |
+
if self.skip_sidechain and rec.get("isSidechain") is True:
|
| 126 |
+
stats.skipped_subagent += 1
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
if rec_type == "user":
|
| 130 |
+
msg = rec.get("message", {})
|
| 131 |
+
content = msg.get("content")
|
| 132 |
+
if isinstance(content, str):
|
| 133 |
+
history.append({"role": "user", "content": content})
|
| 134 |
+
elif isinstance(content, list):
|
| 135 |
+
# Either text blocks (a real human prompt) or tool_result
|
| 136 |
+
# blocks (an observation). Both go into history as user
|
| 137 |
+
# messages, but we serialize them differently.
|
| 138 |
+
flat = self._flatten_user_content(content)
|
| 139 |
+
if flat:
|
| 140 |
+
history.append({"role": "user", "content": flat})
|
| 141 |
+
|
| 142 |
+
elif rec_type == "assistant":
|
| 143 |
+
msg = rec.get("message", {})
|
| 144 |
+
content = msg.get("content")
|
| 145 |
+
if not isinstance(content, list):
|
| 146 |
+
stats.n_records_skipped += 1
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
# Build student_action from this assistant message's content
|
| 150 |
+
# (KEEPING thinking blocks in student_action — that's the
|
| 151 |
+
# actual student emission we'd be RL-training).
|
| 152 |
+
student_action = self._serialize_assistant_content(
|
| 153 |
+
content, strip_thinking=False,
|
| 154 |
+
)
|
| 155 |
+
if not student_action:
|
| 156 |
+
# Empty assistant turn — skip
|
| 157 |
+
stats.n_records_skipped += 1
|
| 158 |
+
continue
|
| 159 |
+
|
| 160 |
+
# Track block counts
|
| 161 |
+
for block in content:
|
| 162 |
+
if isinstance(block, dict):
|
| 163 |
+
bt = block.get("type")
|
| 164 |
+
if bt == "tool_use":
|
| 165 |
+
stats.n_tool_use_blocks += 1
|
| 166 |
+
elif bt == "text":
|
| 167 |
+
stats.n_text_blocks += 1
|
| 168 |
+
|
| 169 |
+
# Build the messages handed to teachers — strip thinking
|
| 170 |
+
# blocks if configured.
|
| 171 |
+
teacher_history = self._maybe_strip_thinking(history)
|
| 172 |
+
|
| 173 |
+
state = TraceState(
|
| 174 |
+
state_id=f"{path.stem}::{state_idx:04d}",
|
| 175 |
+
messages=list(teacher_history), # snapshot
|
| 176 |
+
student_action=student_action,
|
| 177 |
+
)
|
| 178 |
+
yield state
|
| 179 |
+
stats.n_states_emitted += 1
|
| 180 |
+
state_idx += 1
|
| 181 |
+
stats.n_assistant_turns += 1
|
| 182 |
+
|
| 183 |
+
# Append a flattened version of this assistant turn to history
|
| 184 |
+
# for the NEXT teacher call (history grows with each turn).
|
| 185 |
+
history.append({
|
| 186 |
+
"role": "assistant",
|
| 187 |
+
"content": self._serialize_assistant_content(
|
| 188 |
+
content, strip_thinking=self.strip_thinking,
|
| 189 |
+
),
|
| 190 |
+
})
|
| 191 |
+
|
| 192 |
+
# Validate version field of last seen record (best-effort)
|
| 193 |
+
if records:
|
| 194 |
+
v = records[-1].get("version")
|
| 195 |
+
if v and not SUPPORTED_VERSIONS.match(str(v)):
|
| 196 |
+
stats.version_warnings.append(
|
| 197 |
+
f"Unrecognized version {v!r} in {path.name} — ingester "
|
| 198 |
+
"tested against 2.x.x. Check schema compatibility."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# ------------------------------------------------------------------
|
| 202 |
+
# Helpers
|
| 203 |
+
# ------------------------------------------------------------------
|
| 204 |
+
|
| 205 |
+
def _iter_records(self, path: Path) -> Iterator[dict[str, Any]]:
|
| 206 |
+
with path.open("r", encoding="utf-8") as f:
|
| 207 |
+
for line in f:
|
| 208 |
+
line = line.strip()
|
| 209 |
+
if not line:
|
| 210 |
+
continue
|
| 211 |
+
try:
|
| 212 |
+
yield json.loads(line)
|
| 213 |
+
except json.JSONDecodeError as e:
|
| 214 |
+
self.last_stats.skipped_truncated_lines += 1
|
| 215 |
+
logger.debug("Truncated/malformed line in %s: %s", path, e)
|
| 216 |
+
continue
|
| 217 |
+
|
| 218 |
+
def _flatten_user_content(self, content: list[Any]) -> str:
|
| 219 |
+
"""Convert a user record's content list to a single string."""
|
| 220 |
+
parts: list[str] = []
|
| 221 |
+
for block in content:
|
| 222 |
+
if not isinstance(block, dict):
|
| 223 |
+
continue
|
| 224 |
+
bt = block.get("type")
|
| 225 |
+
if bt == "text":
|
| 226 |
+
txt = block.get("text", "")
|
| 227 |
+
if txt:
|
| 228 |
+
parts.append(txt)
|
| 229 |
+
elif bt == "tool_result":
|
| 230 |
+
tc = block.get("content", "")
|
| 231 |
+
if isinstance(tc, list):
|
| 232 |
+
# Sometimes content is itself a list of blocks
|
| 233 |
+
sub = []
|
| 234 |
+
for sb in tc:
|
| 235 |
+
if isinstance(sb, dict) and sb.get("type") == "text":
|
| 236 |
+
sub.append(sb.get("text", ""))
|
| 237 |
+
tc = "\n".join(sub)
|
| 238 |
+
tu_id = block.get("tool_use_id", "<unknown>")
|
| 239 |
+
is_err = block.get("is_error", False)
|
| 240 |
+
tag = "[TOOL_RESULT (ERROR)]" if is_err else "[TOOL_RESULT]"
|
| 241 |
+
parts.append(f"{tag} (id={tu_id})\n{tc}")
|
| 242 |
+
elif bt == "image":
|
| 243 |
+
parts.append("[IMAGE OMITTED]")
|
| 244 |
+
return "\n\n".join(parts)
|
| 245 |
+
|
| 246 |
+
def _serialize_assistant_content(
|
| 247 |
+
self, content: list[Any], *, strip_thinking: bool,
|
| 248 |
+
) -> str:
|
| 249 |
+
"""Serialize an assistant message's content list to a string.
|
| 250 |
+
|
| 251 |
+
Preserves:
|
| 252 |
+
text blocks → as-is
|
| 253 |
+
thinking blocks → "[THINKING] ..." (or stripped)
|
| 254 |
+
tool_use blocks → "[TOOL_USE] name=... input={json}"
|
| 255 |
+
"""
|
| 256 |
+
parts: list[str] = []
|
| 257 |
+
for block in content:
|
| 258 |
+
if not isinstance(block, dict):
|
| 259 |
+
continue
|
| 260 |
+
bt = block.get("type")
|
| 261 |
+
if bt == "text":
|
| 262 |
+
parts.append(block.get("text", ""))
|
| 263 |
+
elif bt == "thinking":
|
| 264 |
+
if not strip_thinking:
|
| 265 |
+
parts.append(f"[THINKING] {block.get('thinking', '')}")
|
| 266 |
+
elif bt == "tool_use":
|
| 267 |
+
name = block.get("name", "")
|
| 268 |
+
inp = block.get("input", {})
|
| 269 |
+
try:
|
| 270 |
+
inp_str = json.dumps(inp, separators=(",", ":"))
|
| 271 |
+
except (TypeError, ValueError):
|
| 272 |
+
inp_str = str(inp)
|
| 273 |
+
parts.append(f"[TOOL_USE] name={name} input={inp_str}")
|
| 274 |
+
return "\n\n".join(p for p in parts if p)
|
| 275 |
+
|
| 276 |
+
def _maybe_strip_thinking(self, history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
| 277 |
+
if not self.strip_thinking:
|
| 278 |
+
return history
|
| 279 |
+
out = []
|
| 280 |
+
for msg in history:
|
| 281 |
+
if msg["role"] != "assistant":
|
| 282 |
+
out.append(msg)
|
| 283 |
+
continue
|
| 284 |
+
# Strip [THINKING] lines from assistant content
|
| 285 |
+
content = msg["content"]
|
| 286 |
+
if isinstance(content, str):
|
| 287 |
+
lines = content.split("\n\n")
|
| 288 |
+
kept = [l for l in lines if not l.strip().startswith("[THINKING]")]
|
| 289 |
+
out.append({"role": "assistant", "content": "\n\n".join(kept)})
|
| 290 |
+
else:
|
| 291 |
+
out.append(msg)
|
| 292 |
+
return out
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
__all__ = ["ClaudeCodeIngester", "IngestionStats", "SYSTEM_PROMPT"]
|
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""compose_loss.py — free 3-channel loss composer for verification smokes.
|
| 2 |
+
|
| 3 |
+
This is a verification-harness mirror of `ComposerReplicationTrainer._compute_loss`
|
| 4 |
+
that does NOT depend on TRL's GRPOTrainer parent. The GRPO channel is replaced
|
| 5 |
+
with standard LM next-token-prediction cross-entropy, which is the limit GRPO
|
| 6 |
+
converges to under deterministic rewards.
|
| 7 |
+
|
| 8 |
+
Use it for:
|
| 9 |
+
- CPU smokes on real HF models (Spike 006)
|
| 10 |
+
- Unit tests of loss composition without spinning up TRL
|
| 11 |
+
- Anywhere we want to verify gradient flow through the 3-channel sum
|
| 12 |
+
without paying TRL's full machinery cost
|
| 13 |
+
|
| 14 |
+
Do NOT use it as the production training loss. Production = ComposerReplicationTrainer
|
| 15 |
+
(a real GRPOTrainer subclass) which uses TRL's reward + advantage estimation.
|
| 16 |
+
|
| 17 |
+
Total loss:
|
| 18 |
+
total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo
|
| 19 |
+
|
| 20 |
+
Channels:
|
| 21 |
+
- lm_ce: standard cross-entropy on assistant-response tokens (GRPO stub)
|
| 22 |
+
- sdpo_jsd: generalized JSD between student and hint-conditioned-teacher logits
|
| 23 |
+
- trace_replay_dpo: DPO loss over (chosen, rejected) teacher-disagreement pairs
|
| 24 |
+
"""
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import sys
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
from composer_replication.opsd import generalized_jsd_loss
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class LossComponents:
|
| 39 |
+
"""Per-channel breakdown of the total loss for logging + ablation."""
|
| 40 |
+
lm_ce: torch.Tensor
|
| 41 |
+
sdpo_jsd: torch.Tensor
|
| 42 |
+
trace_replay_dpo: torch.Tensor
|
| 43 |
+
total: torch.Tensor
|
| 44 |
+
|
| 45 |
+
def detached(self) -> dict[str, float]:
|
| 46 |
+
return {
|
| 47 |
+
"lm_ce": float(self.lm_ce.detach()),
|
| 48 |
+
"sdpo_jsd": float(self.sdpo_jsd.detach()),
|
| 49 |
+
"trace_replay_dpo": float(self.trace_replay_dpo.detach()),
|
| 50 |
+
"total": float(self.total.detach()),
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def compose_loss(
|
| 55 |
+
model: torch.nn.Module,
|
| 56 |
+
inputs: dict[str, torch.Tensor],
|
| 57 |
+
*,
|
| 58 |
+
alpha_sdpo: float = 0.1,
|
| 59 |
+
beta_replay: float = 0.05,
|
| 60 |
+
sdpo_jsd_beta: float = 0.5,
|
| 61 |
+
sdpo_temperature: float = 1.0,
|
| 62 |
+
sdpo_token_clip: float | None = None,
|
| 63 |
+
replay_dpo_beta: float = 0.1,
|
| 64 |
+
lm_ce_label_smoothing: float = 0.0,
|
| 65 |
+
) -> LossComponents:
|
| 66 |
+
"""Compute total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo.
|
| 67 |
+
|
| 68 |
+
Required keys in `inputs`:
|
| 69 |
+
- input_ids: (B, T_s) student rollout
|
| 70 |
+
- response_mask: (B, T_s) 1 on assistant-response tokens, 0 elsewhere
|
| 71 |
+
|
| 72 |
+
Optional keys (channel auto-disables if missing OR if its weight = 0):
|
| 73 |
+
SDPO:
|
| 74 |
+
- ctx_teacher_input_ids: (B, T_t) hint-conditioned context
|
| 75 |
+
- sdpo_loss_mask: (B, T_t) 1 at error-turn tokens
|
| 76 |
+
DPO:
|
| 77 |
+
- dpo_chosen_input_ids, dpo_chosen_response_mask
|
| 78 |
+
- dpo_rejected_input_ids, dpo_rejected_response_mask
|
| 79 |
+
- dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs (precomputed)
|
| 80 |
+
"""
|
| 81 |
+
device = _device_of(model)
|
| 82 |
+
|
| 83 |
+
# ------------------------------------------------------------------
|
| 84 |
+
# Channel 1 (GRPO stub): LM cross-entropy on response tokens
|
| 85 |
+
# ------------------------------------------------------------------
|
| 86 |
+
lm_ce = _lm_response_ce(
|
| 87 |
+
model,
|
| 88 |
+
inputs["input_ids"],
|
| 89 |
+
inputs["response_mask"],
|
| 90 |
+
label_smoothing=lm_ce_label_smoothing,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# ------------------------------------------------------------------
|
| 94 |
+
# Channel 2 (SDPO): generalized JSD on hint-conditioned forward
|
| 95 |
+
# ------------------------------------------------------------------
|
| 96 |
+
sdpo_jsd = _zero(device)
|
| 97 |
+
if (
|
| 98 |
+
alpha_sdpo > 0.0
|
| 99 |
+
and "ctx_teacher_input_ids" in inputs
|
| 100 |
+
and inputs["ctx_teacher_input_ids"].numel() > 0
|
| 101 |
+
):
|
| 102 |
+
student_logits = model(input_ids=inputs["input_ids"]).logits
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
|
| 105 |
+
|
| 106 |
+
if student_logits.shape == teacher_logits.shape:
|
| 107 |
+
sdpo_jsd = generalized_jsd_loss(
|
| 108 |
+
student_logits=student_logits,
|
| 109 |
+
teacher_logits=teacher_logits,
|
| 110 |
+
labels=inputs.get("sdpo_loss_mask"),
|
| 111 |
+
beta=sdpo_jsd_beta,
|
| 112 |
+
temperature=sdpo_temperature,
|
| 113 |
+
token_clip=sdpo_token_clip,
|
| 114 |
+
reduction="batchmean",
|
| 115 |
+
)
|
| 116 |
+
# else: silently zero — the data collator is responsible for shape
|
| 117 |
+
# alignment in production. For the smoke we accept misalignment and
|
| 118 |
+
# exercise the fallback path.
|
| 119 |
+
|
| 120 |
+
# ------------------------------------------------------------------
|
| 121 |
+
# Channel 3 (trace-replay DPO): standard DPO loss on teacher-disagreement
|
| 122 |
+
# pairs.
|
| 123 |
+
# ------------------------------------------------------------------
|
| 124 |
+
trace_replay_dpo = _zero(device)
|
| 125 |
+
if (
|
| 126 |
+
beta_replay > 0.0
|
| 127 |
+
and "dpo_chosen_input_ids" in inputs
|
| 128 |
+
and inputs["dpo_chosen_input_ids"].numel() > 0
|
| 129 |
+
):
|
| 130 |
+
chosen_lp = _sequence_logprobs(
|
| 131 |
+
model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"]
|
| 132 |
+
)
|
| 133 |
+
rejected_lp = _sequence_logprobs(
|
| 134 |
+
model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"]
|
| 135 |
+
)
|
| 136 |
+
ref_chosen = inputs["dpo_chosen_ref_logprobs"]
|
| 137 |
+
ref_rejected = inputs["dpo_rejected_ref_logprobs"]
|
| 138 |
+
dpo_logits = replay_dpo_beta * (
|
| 139 |
+
(chosen_lp - ref_chosen) - (rejected_lp - ref_rejected)
|
| 140 |
+
)
|
| 141 |
+
trace_replay_dpo = -F.logsigmoid(dpo_logits).mean()
|
| 142 |
+
|
| 143 |
+
total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo
|
| 144 |
+
|
| 145 |
+
return LossComponents(
|
| 146 |
+
lm_ce=lm_ce,
|
| 147 |
+
sdpo_jsd=sdpo_jsd,
|
| 148 |
+
trace_replay_dpo=trace_replay_dpo,
|
| 149 |
+
total=total,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ----------------------------------------------------------------------
|
| 154 |
+
# Helpers
|
| 155 |
+
# ----------------------------------------------------------------------
|
| 156 |
+
|
| 157 |
+
def _zero(device: torch.device) -> torch.Tensor:
|
| 158 |
+
"""Differentiable zero — safe to add into a sum without breaking backward."""
|
| 159 |
+
return torch.zeros(1, device=device, requires_grad=True).squeeze()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _device_of(model: torch.nn.Module) -> torch.device:
|
| 163 |
+
return next(model.parameters()).device
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _lm_response_ce(
|
| 167 |
+
model: torch.nn.Module,
|
| 168 |
+
input_ids: torch.Tensor,
|
| 169 |
+
response_mask: torch.Tensor,
|
| 170 |
+
*,
|
| 171 |
+
label_smoothing: float = 0.0,
|
| 172 |
+
) -> torch.Tensor:
|
| 173 |
+
"""Standard next-token-prediction cross-entropy on response tokens only.
|
| 174 |
+
|
| 175 |
+
Mirrors what GRPO converges to under deterministic rewards (the policy
|
| 176 |
+
gradient devolves to behavior cloning of high-reward rollouts).
|
| 177 |
+
"""
|
| 178 |
+
outputs = model(input_ids=input_ids)
|
| 179 |
+
# Shift: logits[t] predicts input_ids[t+1]
|
| 180 |
+
logits = outputs.logits[:, :-1, :]
|
| 181 |
+
targets = input_ids[:, 1:]
|
| 182 |
+
mask = response_mask[:, 1:].float()
|
| 183 |
+
|
| 184 |
+
loss_per_token = F.cross_entropy(
|
| 185 |
+
logits.reshape(-1, logits.size(-1)),
|
| 186 |
+
targets.reshape(-1),
|
| 187 |
+
reduction="none",
|
| 188 |
+
label_smoothing=label_smoothing,
|
| 189 |
+
).view_as(targets)
|
| 190 |
+
|
| 191 |
+
masked = loss_per_token * mask
|
| 192 |
+
n_tokens = mask.sum().clamp_min(1.0)
|
| 193 |
+
return masked.sum() / n_tokens
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _sequence_logprobs(
|
| 197 |
+
model: torch.nn.Module,
|
| 198 |
+
input_ids: torch.Tensor,
|
| 199 |
+
response_mask: torch.Tensor,
|
| 200 |
+
) -> torch.Tensor:
|
| 201 |
+
"""Sum of next-token logprobs over response tokens (standard DPO accounting)."""
|
| 202 |
+
outputs = model(input_ids=input_ids)
|
| 203 |
+
logits = outputs.logits[:, :-1, :]
|
| 204 |
+
targets = input_ids[:, 1:]
|
| 205 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 206 |
+
token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
|
| 207 |
+
masked = token_lp * response_mask[:, 1:].float()
|
| 208 |
+
return masked.sum(dim=-1)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
__all__ = ["compose_loss", "LossComponents"]
|
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""opsd_loss.py — Self-distillation loss, lifted from siyan-zhao/OPSD.
|
| 2 |
+
|
| 3 |
+
Original source: github.com/siyan-zhao/OPSD::OPSDTrainer.generalized_jsd_loss (MIT).
|
| 4 |
+
Verified self-contained via DeepWiki audit on 2026-05-25.
|
| 5 |
+
|
| 6 |
+
Mathematical reference:
|
| 7 |
+
- OPSD paper: Zhao et al., "Self-Distilled Reasoner: On-Policy Self-Distillation
|
| 8 |
+
for LLMs", arXiv:2601.18734.
|
| 9 |
+
- SDPO paper: Hübotter et al., "Reinforcement Learning via Self-Distillation",
|
| 10 |
+
arXiv:2601.20802 (formalizes the same loss as Composer 2.5's "Targeted RL with
|
| 11 |
+
Textual Feedback").
|
| 12 |
+
|
| 13 |
+
The loss computes JSD/KL divergence between a teacher distribution (model
|
| 14 |
+
conditioned on privileged information / a hint) and a student distribution
|
| 15 |
+
(model on the original context). Both come from the SAME model — the teacher
|
| 16 |
+
is just "the model with hint inserted into context."
|
| 17 |
+
|
| 18 |
+
Composer 2.5 uses this with the privileged information being a "hint" inserted
|
| 19 |
+
at the error-turn site. We use the same loss; the data collator constructs
|
| 20 |
+
ctx_teacher = ctx_student + hint_at_error_turn for us.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def generalized_jsd_loss(
|
| 30 |
+
student_logits: torch.Tensor,
|
| 31 |
+
teacher_logits: torch.Tensor,
|
| 32 |
+
labels: torch.Tensor | None = None,
|
| 33 |
+
beta: float = 0.5,
|
| 34 |
+
temperature: float = 1.0,
|
| 35 |
+
reduction: str = "batchmean",
|
| 36 |
+
logits_are_probs: bool = False,
|
| 37 |
+
top_k: int | None = None,
|
| 38 |
+
token_clip: float | None = None,
|
| 39 |
+
) -> torch.Tensor:
|
| 40 |
+
"""Generalized Jensen-Shannon Divergence loss between student and teacher.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
student_logits: (B, T, V) — student model logits at each token position.
|
| 44 |
+
teacher_logits: (B, T, V) — teacher (= same model with hint context) logits.
|
| 45 |
+
labels: (B, T) — token-level mask. Positions with label == -100 are ignored
|
| 46 |
+
(standard HF padding/ignored convention). For Composer-style hint-distill,
|
| 47 |
+
mask should be 1 at error-turn tokens AFTER the hint, 0 elsewhere.
|
| 48 |
+
beta: in [0, 1]. 0 = forward KL (student → teacher); 1 = reverse KL
|
| 49 |
+
(teacher → student); 0.5 = symmetric JSD (default, recommended).
|
| 50 |
+
temperature: softens distributions; T > 1 encourages distribution-matching
|
| 51 |
+
on broader tail probabilities. SDPO paper uses 1.0.
|
| 52 |
+
reduction: "batchmean" (sum / batch_size, like torch.nn.KLDivLoss) or "sum".
|
| 53 |
+
logits_are_probs: if True, inputs are already probabilities (skip softmax).
|
| 54 |
+
top_k: restrict KL to top-k tokens of the teacher distribution.
|
| 55 |
+
Saves compute on large vocabularies (Qwen3 vocab = 152K).
|
| 56 |
+
token_clip: clip per-token JSD to this max. Stabilizes training.
|
| 57 |
+
SDPO paper does NOT clip; OPSD code defaults to None (no clip).
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Scalar loss tensor.
|
| 61 |
+
"""
|
| 62 |
+
# Temperature scaling
|
| 63 |
+
if not logits_are_probs:
|
| 64 |
+
student_logits = student_logits / temperature
|
| 65 |
+
teacher_logits = teacher_logits / temperature
|
| 66 |
+
|
| 67 |
+
# Top-k restriction (optional, for vocab-size compute savings)
|
| 68 |
+
if top_k is not None:
|
| 69 |
+
# Restrict to top-k tokens of teacher; renormalize both there.
|
| 70 |
+
teacher_topk_vals, teacher_topk_idx = teacher_logits.topk(top_k, dim=-1)
|
| 71 |
+
student_topk_vals = student_logits.gather(-1, teacher_topk_idx)
|
| 72 |
+
student_log_probs = F.log_softmax(student_topk_vals, dim=-1)
|
| 73 |
+
teacher_log_probs = F.log_softmax(teacher_topk_vals, dim=-1)
|
| 74 |
+
else:
|
| 75 |
+
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
| 76 |
+
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
| 77 |
+
|
| 78 |
+
# KL / JSD computation
|
| 79 |
+
if beta == 0.0:
|
| 80 |
+
# Forward KL: KL(student || teacher)
|
| 81 |
+
per_token_div = F.kl_div(
|
| 82 |
+
student_log_probs, teacher_log_probs,
|
| 83 |
+
reduction="none", log_target=True,
|
| 84 |
+
).sum(dim=-1)
|
| 85 |
+
elif beta == 1.0:
|
| 86 |
+
# Reverse KL: KL(teacher || student)
|
| 87 |
+
per_token_div = F.kl_div(
|
| 88 |
+
teacher_log_probs, student_log_probs,
|
| 89 |
+
reduction="none", log_target=True,
|
| 90 |
+
).sum(dim=-1)
|
| 91 |
+
else:
|
| 92 |
+
# JSD (symmetric, beta = 0.5 default):
|
| 93 |
+
# M = 0.5 * (P + Q); JSD = 0.5 * (KL(P||M) + KL(Q||M))
|
| 94 |
+
# Implementation via log-space mixture:
|
| 95 |
+
# log_m = logaddexp(log p, log q) - log 2
|
| 96 |
+
log_mixture = torch.logaddexp(student_log_probs, teacher_log_probs) - torch.log(
|
| 97 |
+
torch.tensor(2.0, device=student_logits.device)
|
| 98 |
+
)
|
| 99 |
+
kl_student_mixture = F.kl_div(
|
| 100 |
+
log_mixture, student_log_probs, reduction="none", log_target=True
|
| 101 |
+
).sum(dim=-1)
|
| 102 |
+
kl_teacher_mixture = F.kl_div(
|
| 103 |
+
log_mixture, teacher_log_probs, reduction="none", log_target=True
|
| 104 |
+
).sum(dim=-1)
|
| 105 |
+
per_token_div = beta * kl_student_mixture + (1.0 - beta) * kl_teacher_mixture
|
| 106 |
+
|
| 107 |
+
# Optional per-token clip (stability)
|
| 108 |
+
if token_clip is not None:
|
| 109 |
+
per_token_div = per_token_div.clamp(max=token_clip)
|
| 110 |
+
|
| 111 |
+
# Mask out ignored positions (labels == -100, the HF convention)
|
| 112 |
+
if labels is not None:
|
| 113 |
+
loss_mask = (labels != -100).float()
|
| 114 |
+
per_token_div = per_token_div * loss_mask
|
| 115 |
+
n_valid = loss_mask.sum().clamp(min=1.0)
|
| 116 |
+
else:
|
| 117 |
+
n_valid = torch.tensor(per_token_div.numel(), device=per_token_div.device, dtype=per_token_div.dtype)
|
| 118 |
+
|
| 119 |
+
if reduction == "batchmean":
|
| 120 |
+
# batchmean = sum over (B*T_valid) / B
|
| 121 |
+
return per_token_div.sum() / per_token_div.shape[0]
|
| 122 |
+
elif reduction == "sum":
|
| 123 |
+
return per_token_div.sum()
|
| 124 |
+
elif reduction == "mean":
|
| 125 |
+
return per_token_div.sum() / n_valid
|
| 126 |
+
elif reduction == "none":
|
| 127 |
+
return per_token_div
|
| 128 |
+
else:
|
| 129 |
+
raise ValueError(f"Unknown reduction: {reduction}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
__all__ = ["generalized_jsd_loss"]
|
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""teacher_replay.py — N-teacher OpenRouter parallel client + DPO-pair extractor.
|
| 2 |
+
|
| 3 |
+
This is channel 3 of the integrated trainer: at each step of a frozen agentic
|
| 4 |
+
trace, query N pre-trained external teachers (frontier models from different
|
| 5 |
+
labs) and convert teacher disagreement into preference pairs for DPO loss.
|
| 6 |
+
|
| 7 |
+
Generalized from spike-001's `replay.py`. Verified economic floor (✅ spike 001):
|
| 8 |
+
$0.98 mean per-trace cost ungated, $0.30/trace projected with VOI gating.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
from teacher_replay import replay_trace, extract_dpo_pairs
|
| 12 |
+
|
| 13 |
+
# 1. Replay each step of a frozen trace with N teachers.
|
| 14 |
+
teacher_actions = await replay_trace(
|
| 15 |
+
states=trace_states,
|
| 16 |
+
teachers=DEFAULT_TEACHERS,
|
| 17 |
+
max_total_usd=10.0,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# 2. Extract DPO pairs from teacher disagreement.
|
| 21 |
+
pairs = extract_dpo_pairs(
|
| 22 |
+
states=trace_states,
|
| 23 |
+
student_actions=trace_student_actions,
|
| 24 |
+
teacher_actions=teacher_actions,
|
| 25 |
+
agreement_threshold=2, # at least 2/3 teachers must agree
|
| 26 |
+
)
|
| 27 |
+
# → [{"chosen": …, "rejected": …, "state": …}, …]
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
import asyncio
|
| 33 |
+
import json
|
| 34 |
+
import os
|
| 35 |
+
import time
|
| 36 |
+
from collections import Counter
|
| 37 |
+
from collections.abc import Sequence
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
from typing import TypedDict
|
| 40 |
+
|
| 41 |
+
# httpx is lazy-imported inside replay_trace() so that DPO-pair extraction
|
| 42 |
+
# (the deterministic local logic) is testable without httpx installed.
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# Config
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
DEFAULT_TEACHERS: list["TeacherSpec"] = [
|
| 50 |
+
{"slug": "anthropic/claude-opus-4.7", "input_per_mtok": 15.0, "output_per_mtok": 75.0},
|
| 51 |
+
{"slug": "openai/gpt-5", "input_per_mtok": 1.25, "output_per_mtok": 10.0},
|
| 52 |
+
{"slug": "deepseek/deepseek-v4-pro", "input_per_mtok": 1.10, "output_per_mtok": 4.40},
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _load_api_key() -> str:
|
| 59 |
+
"""Load OPENROUTER_API_KEY from env or ~/.hermes/.env (same as spike 001)."""
|
| 60 |
+
if "OPENROUTER_API_KEY" in os.environ:
|
| 61 |
+
return os.environ["OPENROUTER_API_KEY"]
|
| 62 |
+
hermes_env = Path.home() / ".hermes" / ".env"
|
| 63 |
+
if hermes_env.exists():
|
| 64 |
+
for line in hermes_env.read_text().splitlines():
|
| 65 |
+
line = line.strip()
|
| 66 |
+
if line.startswith("OPENROUTER_API_KEY="):
|
| 67 |
+
return line.split("=", 1)[1].strip().strip('"').strip("'")
|
| 68 |
+
raise RuntimeError("OPENROUTER_API_KEY not found in env or ~/.hermes/.env")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# Types
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
class TeacherSpec(TypedDict):
|
| 76 |
+
slug: str
|
| 77 |
+
input_per_mtok: float
|
| 78 |
+
output_per_mtok: float
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class TraceState(TypedDict):
|
| 82 |
+
"""One step of a frozen agentic trace."""
|
| 83 |
+
state_id: str # unique within the trace
|
| 84 |
+
messages: list[dict] # the conversation up to and including this step's user prompt
|
| 85 |
+
student_action: str # what the student actually did at this step (for DPO comparison)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TeacherCallResult(TypedDict):
|
| 89 |
+
state_id: str
|
| 90 |
+
teacher_slug: str
|
| 91 |
+
response_text: str | None
|
| 92 |
+
latency_s: float
|
| 93 |
+
prompt_tokens: int
|
| 94 |
+
completion_tokens: int
|
| 95 |
+
cost_usd: float
|
| 96 |
+
error: str | None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DPOPair(TypedDict):
|
| 100 |
+
state_id: str
|
| 101 |
+
state_messages: list[dict]
|
| 102 |
+
chosen: str # teacher-consensus action
|
| 103 |
+
rejected: str # student action
|
| 104 |
+
n_teachers_agreeing: int
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
# Teacher replay
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
|
| 111 |
+
async def _call_teacher(
|
| 112 |
+
client, # httpx.AsyncClient — lazy-typed so module imports without httpx
|
| 113 |
+
state: TraceState,
|
| 114 |
+
teacher: TeacherSpec,
|
| 115 |
+
api_key: str,
|
| 116 |
+
max_tokens: int = 200,
|
| 117 |
+
) -> TeacherCallResult:
|
| 118 |
+
payload = {
|
| 119 |
+
"model": teacher["slug"],
|
| 120 |
+
"messages": state["messages"],
|
| 121 |
+
"max_tokens": max_tokens,
|
| 122 |
+
"temperature": 0.2,
|
| 123 |
+
}
|
| 124 |
+
headers = {
|
| 125 |
+
"Authorization": f"Bearer {api_key}",
|
| 126 |
+
"Content-Type": "application/json",
|
| 127 |
+
"HTTP-Referer": "https://huggingface.co/Codeseys/composer-replication-framework",
|
| 128 |
+
"X-Title": "composer-replication-framework spike-005-skeleton",
|
| 129 |
+
}
|
| 130 |
+
t0 = time.perf_counter()
|
| 131 |
+
err = None
|
| 132 |
+
response_text = None
|
| 133 |
+
prompt_tokens = 0
|
| 134 |
+
completion_tokens = 0
|
| 135 |
+
try:
|
| 136 |
+
r = await client.post(OPENROUTER_URL, json=payload, headers=headers, timeout=120.0)
|
| 137 |
+
r.raise_for_status()
|
| 138 |
+
data = r.json()
|
| 139 |
+
response_text = data["choices"][0]["message"]["content"]
|
| 140 |
+
usage = data.get("usage", {})
|
| 141 |
+
prompt_tokens = usage.get("prompt_tokens", 0)
|
| 142 |
+
completion_tokens = usage.get("completion_tokens", 0)
|
| 143 |
+
except Exception as e: # noqa: BLE001 — capture all for verdict logging
|
| 144 |
+
err = repr(e)[:300]
|
| 145 |
+
t1 = time.perf_counter()
|
| 146 |
+
cost_usd = (
|
| 147 |
+
(prompt_tokens / 1_000_000) * teacher["input_per_mtok"]
|
| 148 |
+
+ (completion_tokens / 1_000_000) * teacher["output_per_mtok"]
|
| 149 |
+
)
|
| 150 |
+
return {
|
| 151 |
+
"state_id": state["state_id"],
|
| 152 |
+
"teacher_slug": teacher["slug"],
|
| 153 |
+
"response_text": response_text,
|
| 154 |
+
"latency_s": round(t1 - t0, 3),
|
| 155 |
+
"prompt_tokens": prompt_tokens,
|
| 156 |
+
"completion_tokens": completion_tokens,
|
| 157 |
+
"cost_usd": round(cost_usd, 6),
|
| 158 |
+
"error": err,
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
async def replay_trace(
|
| 163 |
+
states: Sequence[TraceState],
|
| 164 |
+
teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
|
| 165 |
+
max_total_usd: float = 5.0,
|
| 166 |
+
api_key: str | None = None,
|
| 167 |
+
) -> list[TeacherCallResult]:
|
| 168 |
+
"""Query all (state, teacher) pairs in parallel within each state.
|
| 169 |
+
|
| 170 |
+
Hard-caps spend at max_total_usd. Returns per-call results; aggregate
|
| 171 |
+
by state_id downstream to extract DPO pairs.
|
| 172 |
+
"""
|
| 173 |
+
import httpx # lazy import — only required for live-API replay
|
| 174 |
+
|
| 175 |
+
api_key = api_key or _load_api_key()
|
| 176 |
+
results: list[TeacherCallResult] = []
|
| 177 |
+
cumulative_cost = 0.0
|
| 178 |
+
async with httpx.AsyncClient() as client:
|
| 179 |
+
for state in states:
|
| 180 |
+
tasks = [_call_teacher(client, state, t, api_key) for t in teachers]
|
| 181 |
+
state_results = await asyncio.gather(*tasks)
|
| 182 |
+
results.extend(state_results)
|
| 183 |
+
cumulative_cost += sum(
|
| 184 |
+
r["cost_usd"] for r in state_results if r["error"] is None
|
| 185 |
+
)
|
| 186 |
+
if cumulative_cost > max_total_usd:
|
| 187 |
+
break
|
| 188 |
+
return results
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
# DPO pair extraction
|
| 193 |
+
# ---------------------------------------------------------------------------
|
| 194 |
+
|
| 195 |
+
def _normalize_action(text: str | None) -> str:
|
| 196 |
+
"""Normalize an action string for cluster-by-equality.
|
| 197 |
+
|
| 198 |
+
For real agentic traces, this should parse the tool call (name + args) and
|
| 199 |
+
return a canonical form. For the skeleton we just normalize whitespace.
|
| 200 |
+
"""
|
| 201 |
+
if text is None:
|
| 202 |
+
return ""
|
| 203 |
+
return " ".join(text.split()).strip().lower()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def extract_dpo_pairs(
|
| 207 |
+
states: Sequence[TraceState],
|
| 208 |
+
teacher_actions: Sequence[TeacherCallResult],
|
| 209 |
+
agreement_threshold: int = 2,
|
| 210 |
+
) -> list[DPOPair]:
|
| 211 |
+
"""Convert teacher-disagreement-with-student into preference pairs.
|
| 212 |
+
|
| 213 |
+
Logic:
|
| 214 |
+
- Group teacher_actions by state_id.
|
| 215 |
+
- For each state, normalize all teacher responses + student response.
|
| 216 |
+
- If `agreement_threshold` or more teachers agree on action X,
|
| 217 |
+
and student_action != X:
|
| 218 |
+
emit (chosen=X, rejected=student_action) pair
|
| 219 |
+
- Otherwise no pair (no signal).
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
states: sequence of TraceState (must include state["student_action"]).
|
| 223 |
+
teacher_actions: flat list of TeacherCallResult from replay_trace().
|
| 224 |
+
agreement_threshold: min number of teachers that must agree for a pair.
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
List of DPOPair dicts ready for DPO training.
|
| 228 |
+
"""
|
| 229 |
+
by_state: dict[str, list[TeacherCallResult]] = {}
|
| 230 |
+
for tr in teacher_actions:
|
| 231 |
+
if tr["error"] is None and tr["response_text"] is not None:
|
| 232 |
+
by_state.setdefault(tr["state_id"], []).append(tr)
|
| 233 |
+
|
| 234 |
+
state_lookup = {s["state_id"]: s for s in states}
|
| 235 |
+
pairs: list[DPOPair] = []
|
| 236 |
+
|
| 237 |
+
for state_id, calls in by_state.items():
|
| 238 |
+
if state_id not in state_lookup:
|
| 239 |
+
continue
|
| 240 |
+
state = state_lookup[state_id]
|
| 241 |
+
student_norm = _normalize_action(state["student_action"])
|
| 242 |
+
|
| 243 |
+
teacher_norm = [_normalize_action(c["response_text"]) for c in calls]
|
| 244 |
+
counts = Counter(teacher_norm)
|
| 245 |
+
|
| 246 |
+
for action, n in counts.items():
|
| 247 |
+
if n >= agreement_threshold and action != student_norm and action:
|
| 248 |
+
# Find the original (un-normalized) teacher response for the chosen action.
|
| 249 |
+
chosen_text = next(
|
| 250 |
+
c["response_text"] for c, norm in zip(calls, teacher_norm)
|
| 251 |
+
if norm == action and c["response_text"]
|
| 252 |
+
)
|
| 253 |
+
pairs.append({
|
| 254 |
+
"state_id": state_id,
|
| 255 |
+
"state_messages": state["messages"],
|
| 256 |
+
"chosen": chosen_text,
|
| 257 |
+
"rejected": state["student_action"],
|
| 258 |
+
"n_teachers_agreeing": n,
|
| 259 |
+
})
|
| 260 |
+
break # one pair per state — the most-agreed-upon teacher action
|
| 261 |
+
|
| 262 |
+
return pairs
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None:
|
| 266 |
+
p = Path(path)
|
| 267 |
+
p.parent.mkdir(parents=True, exist_ok=True)
|
| 268 |
+
p.write_text("\n".join(json.dumps(d) for d in pairs) + "\n")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
__all__ = [
|
| 272 |
+
"DEFAULT_TEACHERS",
|
| 273 |
+
"TeacherSpec",
|
| 274 |
+
"TraceState",
|
| 275 |
+
"TeacherCallResult",
|
| 276 |
+
"DPOPair",
|
| 277 |
+
"replay_trace",
|
| 278 |
+
"extract_dpo_pairs",
|
| 279 |
+
"save_pairs",
|
| 280 |
+
]
|
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""composer_replication.trainer — TRL GRPOTrainer subclass + data collator.
|
| 2 |
+
|
| 3 |
+
Per docs/INTEGRATION_ARCHITECTURE.md § "Recipe A".
|
| 4 |
+
Per docs/adrs/ADR-003 (also wraps DiLoCo when training distributed).
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer
|
| 9 |
+
|
| 10 |
+
__all__ = ["ComposerReplicationTrainer"]
|
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""composer_trainer.py — TRL GRPOTrainer subclass with SDPO + trace-replay channels.
|
| 2 |
+
|
| 3 |
+
Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A".
|
| 4 |
+
Verified extension point: GRPOTrainer._compute_loss(model, inputs)
|
| 5 |
+
(DeepWiki audit of huggingface/trl, 2026-05-25).
|
| 6 |
+
|
| 7 |
+
Total loss:
|
| 8 |
+
total_loss = grpo_loss
|
| 9 |
+
+ alpha_sdpo * sdpo_kl_at_error_turns
|
| 10 |
+
+ beta_replay * trace_replay_dpo_loss
|
| 11 |
+
|
| 12 |
+
Where:
|
| 13 |
+
- grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches).
|
| 14 |
+
- sdpo_kl_at_error_turns is generalized_jsd_loss between student's logits and
|
| 15 |
+
teacher's (= same-model-with-hint-context) logits, masked to error-turn tokens only.
|
| 16 |
+
- trace_replay_dpo_loss is DPO loss over (chosen, rejected) pairs derived from
|
| 17 |
+
N external teacher disagreement with the student.
|
| 18 |
+
|
| 19 |
+
The data collator (data_collator.py) is responsible for:
|
| 20 |
+
- Detecting error sites in the rollout and constructing ctx_teacher = ctx_student + hint.
|
| 21 |
+
- Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere).
|
| 22 |
+
- Loading DPO pairs from the trace-replay output (see teacher_replay.py).
|
| 23 |
+
- Precomputing reference-policy logprobs for DPO.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import logging
|
| 29 |
+
from typing import Any
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
# These imports work when TRL is installed — they're not skeleton imports.
|
| 35 |
+
# The example_run.py guards against missing TRL with an import-time check.
|
| 36 |
+
try:
|
| 37 |
+
from trl import GRPOTrainer # type: ignore
|
| 38 |
+
except ImportError: # pragma: no cover — only hit in unit-test stubs without TRL
|
| 39 |
+
GRPOTrainer = object # type: ignore — fallback so module imports without TRL
|
| 40 |
+
|
| 41 |
+
from composer_replication.opsd import generalized_jsd_loss
|
| 42 |
+
|
| 43 |
+
logger = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
|
| 47 |
+
"""TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.
|
| 48 |
+
|
| 49 |
+
Args (in addition to GRPOTrainer's):
|
| 50 |
+
alpha_sdpo: weight on SDPO hint-distill loss. Set to 0 to disable
|
| 51 |
+
channel 2 (e.g. for the v0.1 ablation baseline).
|
| 52 |
+
beta_replay: weight on trace-replay DPO loss. Set to 0 to disable
|
| 53 |
+
channel 3 (e.g. for the Composer-recipe-only ablation arm).
|
| 54 |
+
sdpo_jsd_beta: beta param of generalized_jsd_loss (0=fwd KL, 0.5=JSD, 1=rev KL).
|
| 55 |
+
sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
|
| 56 |
+
sdpo_token_clip: per-token JSD clip for stability; None = no clip.
|
| 57 |
+
replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
*args: Any,
|
| 63 |
+
alpha_sdpo: float = 0.1,
|
| 64 |
+
beta_replay: float = 0.05,
|
| 65 |
+
sdpo_jsd_beta: float = 0.5,
|
| 66 |
+
sdpo_temperature: float = 1.0,
|
| 67 |
+
sdpo_token_clip: float | None = None,
|
| 68 |
+
replay_dpo_beta: float = 0.1,
|
| 69 |
+
**kwargs: Any,
|
| 70 |
+
):
|
| 71 |
+
super().__init__(*args, **kwargs)
|
| 72 |
+
self.alpha_sdpo = alpha_sdpo
|
| 73 |
+
self.beta_replay = beta_replay
|
| 74 |
+
self.sdpo_jsd_beta = sdpo_jsd_beta
|
| 75 |
+
self.sdpo_temperature = sdpo_temperature
|
| 76 |
+
self.sdpo_token_clip = sdpo_token_clip
|
| 77 |
+
self.replay_dpo_beta = replay_dpo_beta
|
| 78 |
+
|
| 79 |
+
# ----------------------------------------------------------------------
|
| 80 |
+
# Loss override (the integration core)
|
| 81 |
+
# ----------------------------------------------------------------------
|
| 82 |
+
|
| 83 |
+
def _compute_loss(
|
| 84 |
+
self,
|
| 85 |
+
model: torch.nn.Module,
|
| 86 |
+
inputs: dict[str, torch.Tensor],
|
| 87 |
+
) -> torch.Tensor:
|
| 88 |
+
"""Override: total_loss = grpo + α*sdpo + β*replay."""
|
| 89 |
+
# Channel 1: standard GRPO loss
|
| 90 |
+
grpo_loss = super()._compute_loss(model, inputs)
|
| 91 |
+
|
| 92 |
+
# Channel 2: SDPO hint-distill at error sites
|
| 93 |
+
sdpo_kl = self._compute_sdpo_loss(model, inputs)
|
| 94 |
+
|
| 95 |
+
# Channel 3: trace-replay DPO from teacher disagreement
|
| 96 |
+
replay_dpo = self._compute_trace_replay_loss(model, inputs)
|
| 97 |
+
|
| 98 |
+
# Compose
|
| 99 |
+
total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo
|
| 100 |
+
|
| 101 |
+
# Log per-channel components (so we can ablate post-hoc)
|
| 102 |
+
if hasattr(self, "state") and getattr(self, "args", None) is not None:
|
| 103 |
+
log_steps = getattr(self.args, "logging_steps", 50)
|
| 104 |
+
if self.state.global_step % log_steps == 0:
|
| 105 |
+
self.log({ # type: ignore[attr-defined]
|
| 106 |
+
"loss/grpo": float(grpo_loss.detach()),
|
| 107 |
+
"loss/sdpo_kl": float(sdpo_kl.detach()),
|
| 108 |
+
"loss/trace_replay_dpo": float(replay_dpo.detach()),
|
| 109 |
+
"loss/total": float(total.detach()),
|
| 110 |
+
"loss/alpha_sdpo": self.alpha_sdpo,
|
| 111 |
+
"loss/beta_replay": self.beta_replay,
|
| 112 |
+
})
|
| 113 |
+
|
| 114 |
+
return total
|
| 115 |
+
|
| 116 |
+
# ----------------------------------------------------------------------
|
| 117 |
+
# Channel 2: SDPO hint-distill
|
| 118 |
+
# ----------------------------------------------------------------------
|
| 119 |
+
|
| 120 |
+
def _compute_sdpo_loss(
|
| 121 |
+
self,
|
| 122 |
+
model: torch.nn.Module,
|
| 123 |
+
inputs: dict[str, torch.Tensor],
|
| 124 |
+
) -> torch.Tensor:
|
| 125 |
+
"""Compute generalized_jsd_loss between student and hint-conditioned teacher.
|
| 126 |
+
|
| 127 |
+
Both come from the SAME model — teacher just has hint inserted into context.
|
| 128 |
+
Skipped (returns 0) if the batch has no error sites (data collator emits
|
| 129 |
+
empty ctx_teacher_input_ids).
|
| 130 |
+
"""
|
| 131 |
+
if (
|
| 132 |
+
self.alpha_sdpo == 0.0
|
| 133 |
+
or "ctx_teacher_input_ids" not in inputs
|
| 134 |
+
or inputs["ctx_teacher_input_ids"].numel() == 0
|
| 135 |
+
):
|
| 136 |
+
return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
|
| 137 |
+
|
| 138 |
+
# Student forward (with grad, on the original-context input)
|
| 139 |
+
student_logits = model(input_ids=inputs["input_ids"]).logits
|
| 140 |
+
|
| 141 |
+
# Teacher forward (no grad — same model, hint-conditioned context)
|
| 142 |
+
with torch.no_grad():
|
| 143 |
+
teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
|
| 144 |
+
|
| 145 |
+
# NOTE: in real implementation, ctx_teacher and ctx_student must be the
|
| 146 |
+
# SAME LENGTH at the post-hint section so logits align position-by-position.
|
| 147 |
+
# The data collator pads/aligns. The skeleton trusts that's done correctly.
|
| 148 |
+
if student_logits.shape != teacher_logits.shape:
|
| 149 |
+
logger.warning(
|
| 150 |
+
"SDPO logit shape mismatch: student=%s vs teacher=%s. "
|
| 151 |
+
"Skipping SDPO loss for this step. Check the data collator's "
|
| 152 |
+
"alignment — the post-hint section must have identical token-counts.",
|
| 153 |
+
student_logits.shape, teacher_logits.shape,
|
| 154 |
+
)
|
| 155 |
+
return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
|
| 156 |
+
|
| 157 |
+
return generalized_jsd_loss(
|
| 158 |
+
student_logits=student_logits,
|
| 159 |
+
teacher_logits=teacher_logits,
|
| 160 |
+
labels=inputs.get("sdpo_loss_mask"), # error-turn token mask
|
| 161 |
+
beta=self.sdpo_jsd_beta,
|
| 162 |
+
temperature=self.sdpo_temperature,
|
| 163 |
+
token_clip=self.sdpo_token_clip,
|
| 164 |
+
reduction="batchmean",
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# ----------------------------------------------------------------------
|
| 168 |
+
# Channel 3: trace-replay DPO
|
| 169 |
+
# ----------------------------------------------------------------------
|
| 170 |
+
|
| 171 |
+
def _compute_trace_replay_loss(
|
| 172 |
+
self,
|
| 173 |
+
model: torch.nn.Module,
|
| 174 |
+
inputs: dict[str, torch.Tensor],
|
| 175 |
+
) -> torch.Tensor:
|
| 176 |
+
"""Standard DPO loss using (chosen, rejected) pairs from teacher disagreement.
|
| 177 |
+
|
| 178 |
+
DPO loss formula (Rafailov et al. 2023):
|
| 179 |
+
L = -log σ(β · (logπ(chosen) - logπ_ref(chosen)
|
| 180 |
+
- logπ(rejected) + logπ_ref(rejected)))
|
| 181 |
+
|
| 182 |
+
Where logπ_ref are precomputed by the data collator using the
|
| 183 |
+
reference (init student) policy.
|
| 184 |
+
"""
|
| 185 |
+
if (
|
| 186 |
+
self.beta_replay == 0.0
|
| 187 |
+
or "dpo_chosen_input_ids" not in inputs
|
| 188 |
+
or inputs["dpo_chosen_input_ids"].numel() == 0
|
| 189 |
+
):
|
| 190 |
+
return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
|
| 191 |
+
|
| 192 |
+
# Forward passes for chosen and rejected, gather logprobs at response tokens
|
| 193 |
+
chosen_logprobs = self._sequence_logprobs(
|
| 194 |
+
model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"]
|
| 195 |
+
)
|
| 196 |
+
rejected_logprobs = self._sequence_logprobs(
|
| 197 |
+
model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"]
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
ref_chosen_logprobs = inputs["dpo_chosen_ref_logprobs"]
|
| 201 |
+
ref_rejected_logprobs = inputs["dpo_rejected_ref_logprobs"]
|
| 202 |
+
|
| 203 |
+
logits = self.replay_dpo_beta * (
|
| 204 |
+
(chosen_logprobs - ref_chosen_logprobs)
|
| 205 |
+
- (rejected_logprobs - ref_rejected_logprobs)
|
| 206 |
+
)
|
| 207 |
+
return -F.logsigmoid(logits).mean()
|
| 208 |
+
|
| 209 |
+
@staticmethod
|
| 210 |
+
def _sequence_logprobs(
|
| 211 |
+
model: torch.nn.Module,
|
| 212 |
+
input_ids: torch.Tensor,
|
| 213 |
+
response_mask: torch.Tensor,
|
| 214 |
+
) -> torch.Tensor:
|
| 215 |
+
"""Sum logprob of response tokens given the prompt prefix.
|
| 216 |
+
|
| 217 |
+
Standard DPO accounting: we only score the response tokens (where
|
| 218 |
+
response_mask == 1), not the prompt tokens.
|
| 219 |
+
"""
|
| 220 |
+
outputs = model(input_ids=input_ids)
|
| 221 |
+
# Shift for next-token prediction: logits[t] predicts input_ids[t+1]
|
| 222 |
+
logits = outputs.logits[:, :-1, :]
|
| 223 |
+
targets = input_ids[:, 1:]
|
| 224 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 225 |
+
token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
|
| 226 |
+
# Mask out prompt + padding; sum response-token logprobs
|
| 227 |
+
masked = token_logprobs * response_mask[:, 1:].float()
|
| 228 |
+
return masked.sum(dim=-1)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _device_of(model: torch.nn.Module) -> torch.device:
|
| 232 |
+
"""Return the device of any parameter of the model — robust to FSDP/DDP wrappers."""
|
| 233 |
+
return next(model.parameters()).device
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
__all__ = ["ComposerReplicationTrainer"]
|
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""data_collator.py — ComposerDataCollator: raw trace → trainer-ready batch.
|
| 2 |
+
|
| 3 |
+
Pipeline:
|
| 4 |
+
1. Take a frozen agentic trace + N-teacher DPO pairs (from spike 002 + 003).
|
| 5 |
+
2. Tokenize each turn of the trace.
|
| 6 |
+
3. Detect error sites (turns where a tool call failed) using a configurable predicate.
|
| 7 |
+
4. At each error site, build ctx_teacher = ctx_student with hint inserted at the error-turn boundary.
|
| 8 |
+
5. Pad/align ctx_student and ctx_teacher so SDPO logits compare position-by-position.
|
| 9 |
+
6. Construct sdpo_loss_mask = 1 at post-hint tokens of the error turn, 0 elsewhere.
|
| 10 |
+
7. Tokenize DPO chosen/rejected pairs, build response masks, leave ref_logprobs as a precompute step.
|
| 11 |
+
|
| 12 |
+
The output dict is what `ComposerReplicationTrainer._compute_loss` expects in its
|
| 13 |
+
`inputs` argument. See `trl_path/composer_trainer.py` for the consumer side.
|
| 14 |
+
|
| 15 |
+
Architectural note (verified via spike 005 test_opsd_loss.py): generalized_jsd_loss
|
| 16 |
+
requires student_logits and teacher_logits to have the SAME (B, T, V) shape — that's
|
| 17 |
+
why we pad/align here rather than inside the loss function. The post-hint section of
|
| 18 |
+
ctx_teacher must have token-by-token alignment with the same section of ctx_student.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from collections.abc import Callable, Sequence
|
| 24 |
+
from dataclasses import dataclass, field
|
| 25 |
+
from typing import Any, TypedDict
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Types
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
class TraceTurn(TypedDict, total=False):
|
| 35 |
+
"""One turn of an agentic trace."""
|
| 36 |
+
role: str # "user" | "assistant" | "tool"
|
| 37 |
+
content: str # text or tool result
|
| 38 |
+
tool_call: dict | None # parsed tool call, if assistant-issued
|
| 39 |
+
tool_error: str | None # error_kind from the env, e.g. "tool_not_found"
|
| 40 |
+
error_meta: dict # extra info for hint generator (available_tools, etc.)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TraceExample(TypedDict, total=False):
|
| 44 |
+
"""One training example: a (trace, optional DPO pairs) tuple."""
|
| 45 |
+
trace_id: str
|
| 46 |
+
turns: list[TraceTurn]
|
| 47 |
+
final_reward: float # RLVR scalar (test-pass etc.) at trajectory end
|
| 48 |
+
dpo_pairs: list[dict] | None # from teacher_replay.extract_dpo_pairs
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# Tokenizer protocol — duck-typed against HF AutoTokenizer
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
class TokenizerLike:
|
| 56 |
+
"""Minimal protocol the collator needs from a tokenizer.
|
| 57 |
+
|
| 58 |
+
Compatible with HuggingFace `AutoTokenizer` instances (the typical case),
|
| 59 |
+
but also satisfiable by simpler stubs for unit-testing.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
pad_token_id: int
|
| 63 |
+
|
| 64 |
+
def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: # pragma: no cover
|
| 65 |
+
...
|
| 66 |
+
|
| 67 |
+
def apply_chat_template( # pragma: no cover
|
| 68 |
+
self, messages: list[dict], **kwargs: Any
|
| 69 |
+
) -> str | list[int]:
|
| 70 |
+
...
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
# Configuration
|
| 75 |
+
# ---------------------------------------------------------------------------
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class CollatorConfig:
|
| 79 |
+
"""Tunables for ComposerDataCollator."""
|
| 80 |
+
max_seq_len: int = 4096
|
| 81 |
+
max_dpo_seq_len: int = 2048
|
| 82 |
+
pad_token_id: int = 0
|
| 83 |
+
ignore_index: int = -100 # standard HF "ignore in loss" sentinel
|
| 84 |
+
|
| 85 |
+
# SDPO behavior
|
| 86 |
+
enable_sdpo: bool = True
|
| 87 |
+
hint_generator: Callable[[str, dict], str | None] | None = None
|
| 88 |
+
"""Callable error_kind, error_meta -> hint_text (or None to skip)."""
|
| 89 |
+
|
| 90 |
+
# Trace-replay DPO behavior
|
| 91 |
+
enable_replay_dpo: bool = True
|
| 92 |
+
|
| 93 |
+
# Reward shaping
|
| 94 |
+
rlvr_reward_key: str = "final_reward"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
# Helpers
|
| 99 |
+
# ---------------------------------------------------------------------------
|
| 100 |
+
|
| 101 |
+
def _is_error_turn(turn: TraceTurn) -> bool:
|
| 102 |
+
"""Predicate: is this turn an error site that should trigger SDPO?"""
|
| 103 |
+
return turn.get("tool_error") is not None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _build_chat_messages(turns: Sequence[TraceTurn]) -> list[dict]:
|
| 107 |
+
"""Convert TraceTurns to OpenAI-style chat messages for tokenizer.apply_chat_template."""
|
| 108 |
+
return [
|
| 109 |
+
{"role": t["role"], "content": t["content"]}
|
| 110 |
+
for t in turns if t.get("content")
|
| 111 |
+
]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]:
|
| 115 |
+
"""Right-pad with pad_id, or right-truncate to target_len."""
|
| 116 |
+
if len(seq) >= target_len:
|
| 117 |
+
return seq[:target_len]
|
| 118 |
+
return seq + [pad_id] * (target_len - len(seq))
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
# The collator
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
|
| 125 |
+
@dataclass
|
| 126 |
+
class ComposerDataCollator:
|
| 127 |
+
"""Build trainer-ready batches from raw traces + optional DPO pairs.
|
| 128 |
+
|
| 129 |
+
Usage:
|
| 130 |
+
collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
|
| 131 |
+
batch = collator([trace_example_0, trace_example_1, ...])
|
| 132 |
+
# batch is a dict[str, torch.Tensor] ready for ComposerReplicationTrainer
|
| 133 |
+
|
| 134 |
+
The dict contains:
|
| 135 |
+
# Channel 1 (GRPO/RLVR — handled by the parent GRPOTrainer)
|
| 136 |
+
- input_ids: (B, T_max)
|
| 137 |
+
- attention_mask: (B, T_max)
|
| 138 |
+
- response_mask: (B, T_max)
|
| 139 |
+
- rewards: (B,)
|
| 140 |
+
|
| 141 |
+
# Channel 2 (SDPO hint-distill) — present when any example has error turns
|
| 142 |
+
- ctx_teacher_input_ids: (B, T_max)
|
| 143 |
+
- sdpo_loss_mask: (B, T_max), 1 at post-hint error-turn tokens
|
| 144 |
+
|
| 145 |
+
# Channel 3 (trace-replay DPO) — present when any example has dpo_pairs
|
| 146 |
+
- dpo_chosen_input_ids: (B', T_dpo)
|
| 147 |
+
- dpo_chosen_response_mask: (B', T_dpo)
|
| 148 |
+
- dpo_rejected_input_ids: (B', T_dpo)
|
| 149 |
+
- dpo_rejected_response_mask: (B', T_dpo)
|
| 150 |
+
# ref_logprobs are NOT computed here — the trainer's reference-policy
|
| 151 |
+
# forward pass at training time produces them.
|
| 152 |
+
"""
|
| 153 |
+
tokenizer: TokenizerLike
|
| 154 |
+
config: CollatorConfig = field(default_factory=CollatorConfig)
|
| 155 |
+
|
| 156 |
+
def __call__(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
|
| 157 |
+
out: dict[str, torch.Tensor] = {}
|
| 158 |
+
|
| 159 |
+
# --- Channel 1: GRPO core fields ---
|
| 160 |
+
out.update(self._build_grpo_fields(batch))
|
| 161 |
+
|
| 162 |
+
# --- Channel 2: SDPO hint-distill fields ---
|
| 163 |
+
if self.config.enable_sdpo:
|
| 164 |
+
sdpo = self._build_sdpo_fields(batch)
|
| 165 |
+
if sdpo is not None:
|
| 166 |
+
out.update(sdpo)
|
| 167 |
+
|
| 168 |
+
# --- Channel 3: trace-replay DPO fields ---
|
| 169 |
+
if self.config.enable_replay_dpo:
|
| 170 |
+
dpo = self._build_dpo_fields(batch)
|
| 171 |
+
if dpo is not None:
|
| 172 |
+
out.update(dpo)
|
| 173 |
+
|
| 174 |
+
return out
|
| 175 |
+
|
| 176 |
+
# ----------------------------------------------------------------------
|
| 177 |
+
# Channel 1: standard GRPO inputs
|
| 178 |
+
# ----------------------------------------------------------------------
|
| 179 |
+
|
| 180 |
+
def _build_grpo_fields(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
|
| 181 |
+
input_ids_list: list[list[int]] = []
|
| 182 |
+
response_masks_list: list[list[int]] = []
|
| 183 |
+
rewards: list[float] = []
|
| 184 |
+
|
| 185 |
+
for ex in batch:
|
| 186 |
+
ids, resp_mask = self._tokenize_trace(ex["turns"])
|
| 187 |
+
input_ids_list.append(ids)
|
| 188 |
+
response_masks_list.append(resp_mask)
|
| 189 |
+
rewards.append(float(ex.get(self.config.rlvr_reward_key, 0.0)))
|
| 190 |
+
|
| 191 |
+
max_len = min(self.config.max_seq_len, max(len(s) for s in input_ids_list))
|
| 192 |
+
|
| 193 |
+
input_ids = torch.tensor(
|
| 194 |
+
[_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in input_ids_list],
|
| 195 |
+
dtype=torch.long,
|
| 196 |
+
)
|
| 197 |
+
response_mask = torch.tensor(
|
| 198 |
+
[_pad_or_truncate(m, max_len, 0) for m in response_masks_list],
|
| 199 |
+
dtype=torch.long,
|
| 200 |
+
)
|
| 201 |
+
attention_mask = (input_ids != self.config.pad_token_id).long()
|
| 202 |
+
|
| 203 |
+
return {
|
| 204 |
+
"input_ids": input_ids,
|
| 205 |
+
"attention_mask": attention_mask,
|
| 206 |
+
"response_mask": response_mask,
|
| 207 |
+
"rewards": torch.tensor(rewards, dtype=torch.float),
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
# ----------------------------------------------------------------------
|
| 211 |
+
# Channel 2: SDPO hint-distill inputs
|
| 212 |
+
# ----------------------------------------------------------------------
|
| 213 |
+
|
| 214 |
+
def _build_sdpo_fields(
|
| 215 |
+
self, batch: Sequence[TraceExample]
|
| 216 |
+
) -> dict[str, torch.Tensor] | None:
|
| 217 |
+
"""Build ctx_teacher + sdpo_loss_mask, aligned to ctx_student length."""
|
| 218 |
+
if self.config.hint_generator is None:
|
| 219 |
+
return None # nothing to do without a hint generator
|
| 220 |
+
|
| 221 |
+
ctx_teacher_list: list[list[int]] = []
|
| 222 |
+
sdpo_mask_list: list[list[int]] = []
|
| 223 |
+
any_error_sites = False
|
| 224 |
+
|
| 225 |
+
for ex in batch:
|
| 226 |
+
ctx_teacher_ids, sdpo_mask, has_errors = self._build_hint_injected_trace(ex["turns"])
|
| 227 |
+
ctx_teacher_list.append(ctx_teacher_ids)
|
| 228 |
+
sdpo_mask_list.append(sdpo_mask)
|
| 229 |
+
any_error_sites = any_error_sites or has_errors
|
| 230 |
+
|
| 231 |
+
if not any_error_sites:
|
| 232 |
+
return None # batch has no error sites — SDPO is a no-op for this step
|
| 233 |
+
|
| 234 |
+
max_len = min(self.config.max_seq_len, max(len(s) for s in ctx_teacher_list))
|
| 235 |
+
|
| 236 |
+
ctx_teacher = torch.tensor(
|
| 237 |
+
[_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in ctx_teacher_list],
|
| 238 |
+
dtype=torch.long,
|
| 239 |
+
)
|
| 240 |
+
sdpo_mask = torch.tensor(
|
| 241 |
+
[_pad_or_truncate(m, max_len, self.config.ignore_index) for m in sdpo_mask_list],
|
| 242 |
+
dtype=torch.long,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
return {
|
| 246 |
+
"ctx_teacher_input_ids": ctx_teacher,
|
| 247 |
+
"sdpo_loss_mask": sdpo_mask,
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
def _build_hint_injected_trace(
|
| 251 |
+
self, turns: Sequence[TraceTurn]
|
| 252 |
+
) -> tuple[list[int], list[int], bool]:
|
| 253 |
+
"""Walk the trace; at each error-turn boundary, inject a hint and mark
|
| 254 |
+
the post-hint tokens as in-loss.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
(ctx_teacher_ids, sdpo_loss_mask, any_error_sites)
|
| 258 |
+
"""
|
| 259 |
+
if self.config.hint_generator is None:
|
| 260 |
+
# Caller responsibility — short-circuited by the dispatch.
|
| 261 |
+
empty: list[int] = []
|
| 262 |
+
return empty, empty, False
|
| 263 |
+
|
| 264 |
+
teacher_messages: list[dict] = []
|
| 265 |
+
teacher_loss_segments: list[tuple[bool, str]] = [] # (is_loss_segment, text)
|
| 266 |
+
any_errors = False
|
| 267 |
+
|
| 268 |
+
for turn in turns:
|
| 269 |
+
if _is_error_turn(turn):
|
| 270 |
+
hint_text = self.config.hint_generator(
|
| 271 |
+
turn.get("tool_error", "unknown"),
|
| 272 |
+
turn.get("error_meta", {}),
|
| 273 |
+
)
|
| 274 |
+
if hint_text:
|
| 275 |
+
any_errors = True
|
| 276 |
+
# Inject hint as a system-style addendum BEFORE the assistant's response
|
| 277 |
+
teacher_messages.append({"role": "system", "content": hint_text})
|
| 278 |
+
teacher_loss_segments.append((False, hint_text))
|
| 279 |
+
if turn.get("content"):
|
| 280 |
+
teacher_messages.append({
|
| 281 |
+
"role": turn.get("role", "assistant"),
|
| 282 |
+
"content": turn["content"],
|
| 283 |
+
})
|
| 284 |
+
teacher_loss_segments.append((True, turn["content"])) # post-hint tokens = loss
|
| 285 |
+
continue
|
| 286 |
+
# Non-error turn (or hint generator returned None) — passthrough
|
| 287 |
+
if turn.get("content"):
|
| 288 |
+
teacher_messages.append({
|
| 289 |
+
"role": turn.get("role", "assistant"),
|
| 290 |
+
"content": turn["content"],
|
| 291 |
+
})
|
| 292 |
+
teacher_loss_segments.append((False, turn["content"]))
|
| 293 |
+
|
| 294 |
+
# Tokenize the full teacher conversation
|
| 295 |
+
teacher_ids = self._tokenize_messages(teacher_messages)
|
| 296 |
+
# Build the per-token loss mask by tokenizing each segment and concatenating
|
| 297 |
+
sdpo_mask = self._build_segment_mask(teacher_loss_segments)
|
| 298 |
+
# Truncate mask to teacher_ids length if tokenization round-tripped slightly differently
|
| 299 |
+
sdpo_mask = sdpo_mask[: len(teacher_ids)]
|
| 300 |
+
if len(sdpo_mask) < len(teacher_ids):
|
| 301 |
+
sdpo_mask = sdpo_mask + [self.config.ignore_index] * (len(teacher_ids) - len(sdpo_mask))
|
| 302 |
+
|
| 303 |
+
return teacher_ids, sdpo_mask, any_errors
|
| 304 |
+
|
| 305 |
+
def _build_segment_mask(
|
| 306 |
+
self, segments: Sequence[tuple[bool, str]]
|
| 307 |
+
) -> list[int]:
|
| 308 |
+
"""For each (is_loss, text) segment, tokenize and emit per-token mask values.
|
| 309 |
+
|
| 310 |
+
Loss-active tokens get 1; non-loss tokens get -100 (ignore_index).
|
| 311 |
+
"""
|
| 312 |
+
out: list[int] = []
|
| 313 |
+
for is_loss, text in segments:
|
| 314 |
+
seg_ids = self._tokenize_text(text)
|
| 315 |
+
mask_value = 1 if is_loss else self.config.ignore_index
|
| 316 |
+
out.extend([mask_value] * len(seg_ids))
|
| 317 |
+
return out
|
| 318 |
+
|
| 319 |
+
# ----------------------------------------------------------------------
|
| 320 |
+
# Channel 3: trace-replay DPO inputs
|
| 321 |
+
# ----------------------------------------------------------------------
|
| 322 |
+
|
| 323 |
+
def _build_dpo_fields(
|
| 324 |
+
self, batch: Sequence[TraceExample]
|
| 325 |
+
) -> dict[str, torch.Tensor] | None:
|
| 326 |
+
"""Tokenize chosen/rejected pairs from teacher disagreement.
|
| 327 |
+
|
| 328 |
+
DPO accounting requires:
|
| 329 |
+
- chosen_input_ids = prompt + chosen_response
|
| 330 |
+
- rejected_input_ids = prompt + rejected_response
|
| 331 |
+
- response_masks indicating which tokens are response (loss-bearing) vs prompt (no loss)
|
| 332 |
+
"""
|
| 333 |
+
all_chosen: list[list[int]] = []
|
| 334 |
+
all_rejected: list[list[int]] = []
|
| 335 |
+
all_chosen_resp_mask: list[list[int]] = []
|
| 336 |
+
all_rejected_resp_mask: list[list[int]] = []
|
| 337 |
+
|
| 338 |
+
for ex in batch:
|
| 339 |
+
for pair in ex.get("dpo_pairs") or []:
|
| 340 |
+
prompt_msgs = pair.get("state_messages", [])
|
| 341 |
+
prompt_ids = self._tokenize_messages(prompt_msgs)
|
| 342 |
+
chosen_ids = self._tokenize_text(pair["chosen"])
|
| 343 |
+
rejected_ids = self._tokenize_text(pair["rejected"])
|
| 344 |
+
|
| 345 |
+
chosen_full = prompt_ids + chosen_ids
|
| 346 |
+
rejected_full = prompt_ids + rejected_ids
|
| 347 |
+
|
| 348 |
+
# response_mask is 0 over prompt, 1 over response
|
| 349 |
+
chosen_mask = [0] * len(prompt_ids) + [1] * len(chosen_ids)
|
| 350 |
+
rejected_mask = [0] * len(prompt_ids) + [1] * len(rejected_ids)
|
| 351 |
+
|
| 352 |
+
all_chosen.append(chosen_full)
|
| 353 |
+
all_rejected.append(rejected_full)
|
| 354 |
+
all_chosen_resp_mask.append(chosen_mask)
|
| 355 |
+
all_rejected_resp_mask.append(rejected_mask)
|
| 356 |
+
|
| 357 |
+
if not all_chosen:
|
| 358 |
+
return None # no DPO pairs in this batch
|
| 359 |
+
|
| 360 |
+
cap = self.config.max_dpo_seq_len
|
| 361 |
+
max_len = min(cap, max(len(s) for s in (*all_chosen, *all_rejected)))
|
| 362 |
+
|
| 363 |
+
return {
|
| 364 |
+
"dpo_chosen_input_ids": torch.tensor(
|
| 365 |
+
[_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in all_chosen],
|
| 366 |
+
dtype=torch.long,
|
| 367 |
+
),
|
| 368 |
+
"dpo_chosen_response_mask": torch.tensor(
|
| 369 |
+
[_pad_or_truncate(m, max_len, 0) for m in all_chosen_resp_mask],
|
| 370 |
+
dtype=torch.long,
|
| 371 |
+
),
|
| 372 |
+
"dpo_rejected_input_ids": torch.tensor(
|
| 373 |
+
[_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in all_rejected],
|
| 374 |
+
dtype=torch.long,
|
| 375 |
+
),
|
| 376 |
+
"dpo_rejected_response_mask": torch.tensor(
|
| 377 |
+
[_pad_or_truncate(m, max_len, 0) for m in all_rejected_resp_mask],
|
| 378 |
+
dtype=torch.long,
|
| 379 |
+
),
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
# ----------------------------------------------------------------------
|
| 383 |
+
# Tokenization helpers
|
| 384 |
+
# ----------------------------------------------------------------------
|
| 385 |
+
|
| 386 |
+
def _tokenize_trace(self, turns: Sequence[TraceTurn]) -> tuple[list[int], list[int]]:
|
| 387 |
+
"""Tokenize an entire trace; return (ids, response_mask).
|
| 388 |
+
|
| 389 |
+
response_mask = 1 over assistant turns (those are the loss-bearing tokens
|
| 390 |
+
for GRPO), 0 over user/tool turns (prompt context).
|
| 391 |
+
"""
|
| 392 |
+
all_ids: list[int] = []
|
| 393 |
+
resp_mask: list[int] = []
|
| 394 |
+
for turn in turns:
|
| 395 |
+
if not turn.get("content"):
|
| 396 |
+
continue
|
| 397 |
+
ids = self._tokenize_text(turn["content"])
|
| 398 |
+
mask_value = 1 if turn.get("role") == "assistant" else 0
|
| 399 |
+
all_ids.extend(ids)
|
| 400 |
+
resp_mask.extend([mask_value] * len(ids))
|
| 401 |
+
return all_ids, resp_mask
|
| 402 |
+
|
| 403 |
+
def _tokenize_text(self, text: str) -> list[int]:
|
| 404 |
+
"""Tokenize plain text via the tokenizer's __call__."""
|
| 405 |
+
result = self.tokenizer(text, add_special_tokens=False)
|
| 406 |
+
ids = result["input_ids"]
|
| 407 |
+
if hasattr(ids, "tolist"):
|
| 408 |
+
ids = ids.tolist()
|
| 409 |
+
# HF tokenizers often return list[list[int]] when batch-shaped; flatten if so
|
| 410 |
+
if ids and isinstance(ids[0], list):
|
| 411 |
+
ids = ids[0]
|
| 412 |
+
return list(ids)
|
| 413 |
+
|
| 414 |
+
def _tokenize_messages(self, messages: Sequence[dict]) -> list[int]:
|
| 415 |
+
"""Tokenize a chat-formatted list of messages.
|
| 416 |
+
|
| 417 |
+
Tries apply_chat_template first; falls back to concatenated content if not available.
|
| 418 |
+
"""
|
| 419 |
+
if not messages:
|
| 420 |
+
return []
|
| 421 |
+
try:
|
| 422 |
+
ids = self.tokenizer.apply_chat_template(
|
| 423 |
+
list(messages), tokenize=True, add_generation_prompt=False
|
| 424 |
+
)
|
| 425 |
+
if hasattr(ids, "tolist"):
|
| 426 |
+
ids = ids.tolist()
|
| 427 |
+
return list(ids)
|
| 428 |
+
except (AttributeError, NotImplementedError, TypeError):
|
| 429 |
+
# Stub tokenizer or no chat template defined — fall back to concatenated content
|
| 430 |
+
text = "\n".join(m.get("content", "") for m in messages)
|
| 431 |
+
return self._tokenize_text(text)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
__all__ = [
|
| 435 |
+
"ComposerDataCollator",
|
| 436 |
+
"CollatorConfig",
|
| 437 |
+
"TraceTurn",
|
| 438 |
+
"TraceExample",
|
| 439 |
+
"TokenizerLike",
|
| 440 |
+
]
|
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quickstart: Qwen2.5-0.5B-Instruct on CPU
|
| 2 |
+
|
| 3 |
+
Run the Composer Replication Framework's 3-channel loss composition end-to-end
|
| 4 |
+
on a small open model in under 5 minutes on CPU.
|
| 5 |
+
|
| 6 |
+
## Setup
|
| 7 |
+
|
| 8 |
+
```bash
|
| 9 |
+
cd /path/to/composer-replication-framework
|
| 10 |
+
pip install -e .
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
(`-e` for editable install — picks up local code changes without re-installing.)
|
| 14 |
+
|
| 15 |
+
## Run
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
python examples/qwen_05b_quickstart/run.py
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Expected output
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
[quickstart] loading Qwen/Qwen2.5-0.5B-Instruct (CPU, fp32) ...
|
| 25 |
+
[quickstart] loaded — 0.494B params
|
| 26 |
+
[quickstart] building real chat-template batch ...
|
| 27 |
+
[quickstart] running 5 backward steps ...
|
| 28 |
+
step 0: total=0.7390 lm_ce=0.7385 sdpo=0.0000 dpo=0.0114 finite=True
|
| 29 |
+
step 1: total=0.2090 lm_ce=0.2086 sdpo=0.0000 dpo=0.0084 finite=True
|
| 30 |
+
step 2: total=0.0501 lm_ce=0.0496 sdpo=0.0000 dpo=0.0093 finite=True
|
| 31 |
+
step 3: total=0.0094 lm_ce=0.0089 sdpo=0.0000 dpo=0.0094 finite=True
|
| 32 |
+
step 4: total=0.0031 lm_ce=0.0029 sdpo=0.0000 dpo=0.0044 finite=True
|
| 33 |
+
|
| 34 |
+
========================================================
|
| 35 |
+
Initial loss: 0.7390
|
| 36 |
+
Final loss: 0.0031
|
| 37 |
+
Reduction: 99.6%
|
| 38 |
+
Verdict: PASS
|
| 39 |
+
========================================================
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## What this demonstrates
|
| 43 |
+
|
| 44 |
+
- `build_batch(tokenizer)` produces a real chat-template-formatted batch
|
| 45 |
+
with all keys the 3-channel loss composer needs.
|
| 46 |
+
- `compose_loss(model, batch, alpha_sdpo, beta_replay)` returns
|
| 47 |
+
`LossComponents` with per-channel breakdown.
|
| 48 |
+
- Backward pass through `components.total` flows into all three channels:
|
| 49 |
+
- `lm_ce`: the GRPO stub (cross-entropy on response tokens, the limit
|
| 50 |
+
GRPO converges to under deterministic rewards).
|
| 51 |
+
- `sdpo_jsd`: hint-distillation between student logits and
|
| 52 |
+
hint-conditioned-teacher logits.
|
| 53 |
+
- `trace_replay_dpo`: DPO loss over (chosen, rejected) pairs from
|
| 54 |
+
multi-teacher disagreement.
|
| 55 |
+
|
| 56 |
+
## What this does NOT demonstrate
|
| 57 |
+
|
| 58 |
+
- Real GRPO rollouts + reward calculation (use `ComposerReplicationTrainer`
|
| 59 |
+
for that — a TRL `GRPOTrainer` subclass that wraps the same 3-channel
|
| 60 |
+
loss).
|
| 61 |
+
- Real teacher calls (those go through `composer_replication.replay_trace`
|
| 62 |
+
+ OpenRouter; ~$0.98 per 50-step trace at last measurement).
|
| 63 |
+
- DiLoCo outer loop (separate; needs `torchft-nightly` and is a
|
| 64 |
+
`make_diloco_outer_loop()` away once installed).
|
| 65 |
+
|
| 66 |
+
## Cost
|
| 67 |
+
|
| 68 |
+
- $0
|
| 69 |
+
- ~3-5 minutes wall-clock on CPU
|
| 70 |
+
- ~1 GB disk for Qwen2.5-0.5B weights (downloaded once into `~/.cache/huggingface`)
|
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Composer Replication Framework — quickstart smoke.
|
| 2 |
+
|
| 3 |
+
Runs the same 5-step CPU smoke as Spike 006, but using the installed package
|
| 4 |
+
API instead of importing from the spike directory.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
cd composer-replication-framework
|
| 8 |
+
pip install -e .
|
| 9 |
+
python examples/qwen_05b_quickstart/run.py
|
| 10 |
+
|
| 11 |
+
Expected: loss decreases from ~0.7 to <0.01 over 5 backward steps; all
|
| 12 |
+
gradients finite; ~3-5 min wall-clock on CPU; ~1 GB disk for Qwen2.5-0.5B
|
| 13 |
+
weights (downloaded once into HF cache).
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
# After `pip install -e .` from repo root, this import resolves cleanly.
|
| 22 |
+
from composer_replication import build_batch, compose_loss
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main() -> int:
|
| 29 |
+
print(f"[quickstart] loading {MODEL_REPO} (CPU, fp32) ...")
|
| 30 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 31 |
+
|
| 32 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
|
| 33 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
|
| 34 |
+
model = model.to("cpu")
|
| 35 |
+
model.train()
|
| 36 |
+
n_params_b = sum(p.numel() for p in model.parameters()) / 1e9
|
| 37 |
+
print(f"[quickstart] loaded — {n_params_b:.3f}B params")
|
| 38 |
+
|
| 39 |
+
print("[quickstart] building real chat-template batch ...")
|
| 40 |
+
batch = build_batch(tokenizer, device="cpu")
|
| 41 |
+
|
| 42 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
|
| 43 |
+
|
| 44 |
+
print("[quickstart] running 5 backward steps ...")
|
| 45 |
+
losses: list[float] = []
|
| 46 |
+
for step in range(5):
|
| 47 |
+
optimizer.zero_grad()
|
| 48 |
+
components = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
|
| 49 |
+
components.total.backward()
|
| 50 |
+
|
| 51 |
+
# Verify finite grads
|
| 52 |
+
finite = all(
|
| 53 |
+
(p.grad is None or torch.isfinite(p.grad).all().item())
|
| 54 |
+
for p in model.parameters()
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
optimizer.step()
|
| 58 |
+
|
| 59 |
+
c = components.detached()
|
| 60 |
+
losses.append(c["total"])
|
| 61 |
+
print(
|
| 62 |
+
f" step {step}: total={c['total']:.4f} "
|
| 63 |
+
f"lm_ce={c['lm_ce']:.4f} "
|
| 64 |
+
f"sdpo={c['sdpo_jsd']:.4f} "
|
| 65 |
+
f"dpo={c['trace_replay_dpo']:.4f} "
|
| 66 |
+
f"finite={finite}"
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
initial, final = losses[0], losses[-1]
|
| 70 |
+
decreased = final < initial
|
| 71 |
+
print()
|
| 72 |
+
print("=" * 56)
|
| 73 |
+
print(f" Initial loss: {initial:.4f}")
|
| 74 |
+
print(f" Final loss: {final:.4f}")
|
| 75 |
+
print(f" Reduction: {(1 - final / initial) * 100:.1f}%")
|
| 76 |
+
print(f" Verdict: {'PASS' if decreased else 'FAIL'}")
|
| 77 |
+
print("=" * 56)
|
| 78 |
+
|
| 79 |
+
return 0 if decreased else 1
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
sys.exit(main())
|
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling>=1.21"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "composer-replication"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Open replication framework for Cursor Composer 2.5: GRPO + SDPO + multi-teacher trace-replay DPO with optional DiLoCo outer loop."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = { file = "LICENSE" }
|
| 11 |
+
authors = [
|
| 12 |
+
{ name = "Codeseys", email = "bbaladithyab@gmail.com" }
|
| 13 |
+
]
|
| 14 |
+
keywords = [
|
| 15 |
+
"rl-training",
|
| 16 |
+
"rlvr",
|
| 17 |
+
"grpo",
|
| 18 |
+
"sdpo",
|
| 19 |
+
"dpo",
|
| 20 |
+
"diloco",
|
| 21 |
+
"agentic",
|
| 22 |
+
"coding-agents",
|
| 23 |
+
"composer-2-5",
|
| 24 |
+
"cursor",
|
| 25 |
+
"trl",
|
| 26 |
+
"verl",
|
| 27 |
+
"openenv",
|
| 28 |
+
"torchft",
|
| 29 |
+
]
|
| 30 |
+
classifiers = [
|
| 31 |
+
"Development Status :: 3 - Alpha",
|
| 32 |
+
"Intended Audience :: Science/Research",
|
| 33 |
+
"License :: OSI Approved :: MIT License",
|
| 34 |
+
"Programming Language :: Python :: 3.10",
|
| 35 |
+
"Programming Language :: Python :: 3.11",
|
| 36 |
+
"Programming Language :: Python :: 3.12",
|
| 37 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 38 |
+
]
|
| 39 |
+
requires-python = ">=3.10"
|
| 40 |
+
dependencies = [
|
| 41 |
+
"torch>=2.0",
|
| 42 |
+
"transformers>=4.46",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
[project.optional-dependencies]
|
| 46 |
+
# Real teacher-replay over OpenRouter
|
| 47 |
+
replay = [
|
| 48 |
+
"httpx>=0.27",
|
| 49 |
+
]
|
| 50 |
+
# DiLoCo outer-loop optimizer
|
| 51 |
+
diloco = [
|
| 52 |
+
"torchft-nightly",
|
| 53 |
+
]
|
| 54 |
+
# Production training (TRL GRPOTrainer subclass)
|
| 55 |
+
train = [
|
| 56 |
+
"trl>=0.12",
|
| 57 |
+
"peft>=0.13",
|
| 58 |
+
"accelerate>=1.0",
|
| 59 |
+
"datasets>=3.0",
|
| 60 |
+
]
|
| 61 |
+
# Everything for development
|
| 62 |
+
dev = [
|
| 63 |
+
"pytest>=8.0",
|
| 64 |
+
"ruff>=0.6",
|
| 65 |
+
"composer-replication[replay,diloco,train]",
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
[project.urls]
|
| 69 |
+
Homepage = "https://huggingface.co/Codeseys/composer-replication-framework"
|
| 70 |
+
Documentation = "https://huggingface.co/Codeseys/composer-replication-framework/blob/main/docs/INTEGRATION_ARCHITECTURE.md"
|
| 71 |
+
Repository = "https://huggingface.co/Codeseys/composer-replication-framework"
|
| 72 |
+
Issues = "https://huggingface.co/Codeseys/composer-replication-framework/discussions"
|
| 73 |
+
|
| 74 |
+
[tool.hatch.build.targets.wheel]
|
| 75 |
+
packages = ["composer_replication"]
|
| 76 |
+
|
| 77 |
+
[tool.hatch.build.targets.sdist]
|
| 78 |
+
include = [
|
| 79 |
+
"/composer_replication",
|
| 80 |
+
"/README.md",
|
| 81 |
+
"/LICENSE",
|
| 82 |
+
"/CITATION.cff",
|
| 83 |
+
"/CITATION.bib",
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
[tool.ruff]
|
| 87 |
+
line-length = 100
|
| 88 |
+
target-version = "py310"
|
| 89 |
+
|
| 90 |
+
[tool.ruff.lint]
|
| 91 |
+
select = ["E", "F", "W", "I", "N", "UP", "B"]
|
| 92 |
+
ignore = ["E501", "E741"]
|