Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
"""teacher_replay.py — N-teacher OpenRouter parallel client + DPO-pair extractor.

This is channel 3 of the integrated trainer: at each step of a frozen agentic
trace, query N pre-trained external teachers (frontier models from different
labs) and convert teacher disagreement with the student into preference pairs
for a DPO loss.

Generalized from spike-001's `replay.py`. Verified economic floor (✅ spike 001):
$0.98 mean per-trace cost ungated, $0.30/trace projected with VOI gating.

Usage:
    from teacher_replay import DEFAULT_TEACHERS, replay_trace, extract_dpo_pairs

    # 1. Replay each step of a frozen trace with N teachers.
    teacher_actions = await replay_trace(
        states=trace_states,
        teachers=DEFAULT_TEACHERS,
        max_total_usd=10.0,
    )

    # 2. Extract DPO pairs from teacher disagreement. The student's action is
    #    read from each state's "student_action" field.
    pairs = extract_dpo_pairs(
        states=trace_states,
        teacher_actions=teacher_actions,
        agreement_threshold=2,  # at least 2 teachers (of the default 3) must agree
    )
    # → [{"state_id": …, "state_messages": …, "chosen": …, "rejected": …,
    #     "n_teachers_agreeing": …}, …]
"""
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| import time | |
| from collections import Counter | |
| from collections.abc import Sequence | |
| from pathlib import Path | |
| from typing import TypedDict | |
| # httpx is lazy-imported inside replay_trace() so that DPO-pair extraction | |
| # (the deterministic local logic) is testable without httpx installed. | |
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Teachers queried at every trace step. Each price is in USD per one million
# tokens (prompt side / completion side) and drives the per-call cost figure.
DEFAULT_TEACHERS: list["TeacherSpec"] = [
    dict(slug="anthropic/claude-opus-4.7", input_per_mtok=15.0, output_per_mtok=75.0),
    dict(slug="openai/gpt-5", input_per_mtok=1.25, output_per_mtok=10.0),
    dict(slug="deepseek/deepseek-v4-pro", input_per_mtok=1.10, output_per_mtok=4.40),
]

# OpenRouter's OpenAI-compatible chat-completions endpoint.
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
def _load_api_key() -> str:
    """Return the OpenRouter API key (same lookup as spike 001).

    Lookup order:
      1. the ``OPENROUTER_API_KEY`` environment variable, if set and not blank;
      2. an ``OPENROUTER_API_KEY=...`` line in ``~/.hermes/.env``. An optional
         leading ``export`` and surrounding single/double quotes are accepted.

    The returned key is stripped of surrounding whitespace: a trailing newline
    would otherwise end up inside the ``Authorization`` header.

    Raises:
        RuntimeError: if neither source yields a non-empty key.
    """
    # A variable that is exported but blank is treated as "not set". Returning
    # "" would make every request fail with an auth error that is only
    # visible in the per-call error strings.
    env_key = os.environ.get("OPENROUTER_API_KEY", "").strip()
    if env_key:
        return env_key

    hermes_env = Path.home() / ".hermes" / ".env"
    if hermes_env.is_file():
        for raw_line in hermes_env.read_text(encoding="utf-8").splitlines():
            line = raw_line.strip()
            if line.startswith("export "):
                # Shell-style dotenv files often prefix assignments with `export`.
                line = line[len("export "):].lstrip()
            name, sep, value = line.partition("=")
            if sep and name.strip() == "OPENROUTER_API_KEY":
                key = value.strip().strip('"').strip("'")
                if key:
                    return key
    raise RuntimeError("OPENROUTER_API_KEY not found in env or ~/.hermes/.env")
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
class TeacherSpec(TypedDict):
    """One external teacher model and its token pricing."""

    slug: str  # OpenRouter model id; sent as the request's "model" field
    input_per_mtok: float  # USD per 1M prompt tokens
    output_per_mtok: float  # USD per 1M completion tokens
class TraceState(TypedDict):
    """One step of a frozen agentic trace."""

    # Unique within the trace. Teacher results are joined back to the step by
    # this id, and extract_dpo_pairs keys states by it (duplicates collapse to
    # the last one).
    state_id: str
    # The conversation up to and including this step's user prompt; sent
    # verbatim to each teacher as the chat "messages".
    messages: list[dict]
    # What the student actually did at this step; becomes the DPO "rejected" side.
    student_action: str
class TeacherCallResult(TypedDict):
    """Outcome of one (state, teacher) request, as built by _call_teacher."""

    state_id: str  # copied from the replayed TraceState
    teacher_slug: str  # copied from TeacherSpec["slug"]
    response_text: str | None  # first choice's message content; may be None (e.g. the request failed)
    latency_s: float  # elapsed seconds for the request (time.perf_counter), rounded to 3 decimals
    prompt_tokens: int  # from the API's "usage" block; 0 when not reported
    completion_tokens: int  # from the API's "usage" block; 0 when not reported
    cost_usd: float  # token counts priced with the TeacherSpec rates, rounded to 6 decimals
    error: str | None  # repr() of the captured exception, truncated to 300 chars; None on success
class DPOPair(TypedDict):
    """One preference pair produced by extract_dpo_pairs."""

    state_id: str  # the TraceState this pair was derived from
    state_messages: list[dict]  # that state's messages: the prompt both actions answer
    chosen: str  # teacher-consensus action (raw text of the first agreeing teacher)
    rejected: str  # student action, exactly as recorded in the trace
    n_teachers_agreeing: int  # teachers whose normalized response matched `chosen`
# ---------------------------------------------------------------------------
# Teacher replay
# ---------------------------------------------------------------------------
async def _call_teacher(
    client,  # httpx.AsyncClient — lazy-typed so module imports without httpx
    state: TraceState,
    teacher: TeacherSpec,
    api_key: str,
    max_tokens: int = 200,
    *,
    temperature: float = 0.2,
    timeout_s: float = 120.0,
) -> TeacherCallResult:
    """Send one state's messages to one teacher via OpenRouter and record the outcome.

    Ordinary failures (any ``Exception`` from the HTTP call or from parsing the
    response) are not raised. They are captured as ``repr(e)[:300]`` in
    ``error``, so a single bad teacher cannot abort the ``asyncio.gather`` in
    ``replay_trace``.

    Args:
        client: an ``httpx.AsyncClient`` (anything with a compatible async ``post``).
        state: the trace step to replay; its ``messages`` are sent verbatim.
        teacher: model slug plus per-million-token prices used for ``cost_usd``.
        api_key: OpenRouter API key, sent as a bearer token.
        max_tokens: completion-token cap for the request.
        temperature: sampling temperature (default 0.2, as before).
        timeout_s: per-request timeout in seconds (default 120, as before).

    Returns:
        A TeacherCallResult. ``cost_usd`` is computed from the token counts the
        API reports in ``usage`` and is 0.0 when no usage was reported.
    """
    payload = {
        "model": teacher["slug"],
        "messages": state["messages"],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/Codeseys/composer-replication-framework",
        "X-Title": "composer-replication-framework spike-005-skeleton",
    }
    err: str | None = None
    response_text: str | None = None
    prompt_tokens = 0
    completion_tokens = 0
    t0 = time.perf_counter()
    try:
        r = await client.post(OPENROUTER_URL, json=payload, headers=headers, timeout=timeout_s)
        r.raise_for_status()
        data = r.json()
        # Usage is read before the completion body, so billed tokens are still
        # recorded if the body turns out to be malformed. The `or` fallbacks
        # cover an explicit JSON null for the usage object or for either count.
        # Without them, the cost arithmetic below (outside this try) would
        # raise TypeError and abort every sibling call in the gather.
        usage = data.get("usage") or {}
        prompt_tokens = int(usage.get("prompt_tokens") or 0)
        completion_tokens = int(usage.get("completion_tokens") or 0)
        response_text = data["choices"][0]["message"]["content"]
    except Exception as e:  # noqa: BLE001 — capture all for verdict logging
        err = repr(e)[:300]
    t1 = time.perf_counter()

    cost_usd = (
        (prompt_tokens / 1_000_000) * teacher["input_per_mtok"]
        + (completion_tokens / 1_000_000) * teacher["output_per_mtok"]
    )
    return {
        "state_id": state["state_id"],
        "teacher_slug": teacher["slug"],
        "response_text": response_text,
        "latency_s": round(t1 - t0, 3),
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "cost_usd": round(cost_usd, 6),
        "error": err,
    }
async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
    max_tokens: int = 200,
) -> list[TeacherCallResult]:
    """Replay every step of a frozen trace against every teacher.

    States are processed one at a time, in order. Within a state, all teachers
    are queried concurrently with ``asyncio.gather``.

    Budget: before each state starts, the cumulative reported cost of every
    call so far is compared with ``max_total_usd``. Failed calls are included,
    because spend is spend. Replay stops once the budget is reached. A call's
    cost is only known after it returns, so the cap can be overshot by at
    most one state's worth of calls. ``max_total_usd <= 0`` makes no calls.

    Args:
        states: the trace steps to replay.
        teachers: the teacher models to query at every step.
        max_total_usd: spend budget, as described above.
        api_key: OpenRouter key; loaded via ``_load_api_key`` when falsy.
        max_tokens: completion-token cap per teacher call (default 200, as before).

    Returns:
        Flat list of per-call results in (state, teacher) order. Group them by
        ``state_id`` downstream (e.g. with ``extract_dpo_pairs``).

    Raises:
        ImportError: if ``httpx`` is not installed.
        RuntimeError: if no API key is given and none can be loaded.
    """
    import httpx  # lazy import — only required for live-API replay

    api_key = api_key or _load_api_key()
    results: list[TeacherCallResult] = []
    spent_usd = 0.0
    async with httpx.AsyncClient() as client:
        for state in states:
            # Check before spending, so an exhausted (or zero) budget starts nothing.
            if spent_usd >= max_total_usd:
                break
            state_results = await asyncio.gather(
                *(_call_teacher(client, state, t, api_key, max_tokens) for t in teachers)
            )
            results.extend(state_results)
            spent_usd += sum(r["cost_usd"] for r in state_results)
    return results
# ---------------------------------------------------------------------------
# DPO pair extraction
# ---------------------------------------------------------------------------
def _normalize_action(text: str | None) -> str:
    """Return a canonical key used to group equivalent actions.

    Collapses every run of whitespace to a single space, trims the ends and
    lowercases. ``None`` maps to ``""``. This is a skeleton heuristic. For
    real agentic traces it should parse the tool call (name + args) and
    return a canonical form instead; lowercasing is lossy for code and paths.
    """
    if text is None:
        return ""
    # str.split() with no argument already drops leading/trailing whitespace.
    return " ".join(text.split()).lower()


def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]:
    """Convert teacher disagreement with the student into DPO preference pairs.

    For each state that has at least one usable teacher response:

    - Teacher responses are normalized with ``_normalize_action``. Failed
      calls (``error`` set or ``response_text`` None) and responses that
      normalize to ``""`` are not counted as votes.
    - The consensus action X is the normalized response shared by the most
      teachers. Ties go to the action seen first.
    - A pair ``(chosen=X, rejected=student_action)`` is emitted only if X has
      at least ``agreement_threshold`` votes AND strictly more votes than the
      student's own normalized action. This excludes both the student matching
      X and the student tying X, since neither gives a preference signal.

    At most one pair is emitted per state. Pairs follow the order in which
    each ``state_id`` first appears in ``teacher_actions``. Results whose
    ``state_id`` is not in ``states`` are skipped.

    Args:
        states: sequence of TraceState (must include state["student_action"]).
        teacher_actions: flat list of TeacherCallResult from replay_trace().
        agreement_threshold: min number of teachers that must agree for a pair.

    Returns:
        List of DPOPair dicts ready for DPO training. ``chosen`` is the raw
        text of the first teacher that produced the consensus action.
    """
    by_state: dict[str, list[TeacherCallResult]] = {}
    for tr in teacher_actions:
        if tr["error"] is None and tr["response_text"] is not None:
            by_state.setdefault(tr["state_id"], []).append(tr)

    state_lookup = {s["state_id"]: s for s in states}
    pairs: list[DPOPair] = []
    for state_id, calls in by_state.items():
        state = state_lookup.get(state_id)
        if state is None:
            continue

        # Count votes per normalized action and remember the first raw text
        # seen for each, which becomes the `chosen` string.
        counts: Counter[str] = Counter()
        first_raw: dict[str, str] = {}
        for c in calls:
            norm = _normalize_action(c["response_text"])
            if not norm:
                continue  # empty/whitespace-only response: no action, no vote
            counts[norm] += 1
            first_raw.setdefault(norm, c["response_text"])
        if not counts:
            continue

        # most_common() sorts stably, so ties resolve to the first-seen action.
        # Iterating counts.items() and stopping at the first action over the
        # threshold would pick by arrival order instead.
        action, n = counts.most_common(1)[0]
        student_norm = _normalize_action(state["student_action"])
        if n < agreement_threshold or counts[student_norm] >= n:
            continue

        pairs.append({
            "state_id": state_id,
            "state_messages": state["messages"],
            "chosen": first_raw[action],
            "rejected": state["student_action"],
            "n_teachers_agreeing": n,
        })
    return pairs
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None:
    """Write ``pairs`` to ``path`` as JSON Lines, one JSON object per line.

    Parent directories are created as needed, and an existing file is
    overwritten. Every record, including the last, is newline-terminated.
    An empty ``pairs`` produces an empty file rather than a lone blank line,
    which line-by-line JSONL readers would try (and fail) to parse as a record.
    """
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    # json.dumps defaults to ensure_ascii=True, so the text is pure ASCII; the
    # explicit encoding just avoids depending on the platform locale.
    p.write_text("".join(json.dumps(d) + "\n" for d in pairs), encoding="utf-8")
# Public API of teacher_replay: the names `from teacher_replay import *` exports.
# Kept as a literal list so static analyzers can read it.
__all__ = [
    # Configuration
    "DEFAULT_TEACHERS",
    # Record types
    "TeacherSpec",
    "TraceState",
    "TeacherCallResult",
    "DPOPair",
    # Entry points
    "replay_trace",
    "extract_dpo_pairs",
    "save_pairs",
]