"""teacher_replay.py — N-teacher OpenRouter parallel client + DPO-pair extractor.

This is channel 3 of the integrated trainer: at each step of a frozen agentic
trace, query N pre-trained external teachers (frontier models from different
labs) and convert teacher disagreement with the student into preference pairs
for the DPO loss. Generalized from spike-001's `replay.py`.

Verified economic floor (✅ spike 001): $0.98 mean per-trace cost ungated,
$0.30/trace projected with VOI gating.

Usage:
    from teacher_replay import DEFAULT_TEACHERS, extract_dpo_pairs, replay_trace

    # 1. Replay each step of a frozen trace with N teachers.
    teacher_actions = await replay_trace(
        states=trace_states,
        teachers=DEFAULT_TEACHERS,
        max_total_usd=10.0,
    )

    # 2. Extract DPO pairs from teacher disagreement. The student's action for
    #    each step is read from that state's "student_action" field.
    pairs = extract_dpo_pairs(
        states=trace_states,
        teacher_actions=teacher_actions,
        agreement_threshold=2,  # at least 2 teachers (of the default 3) must agree
    )
    # → [{"state_id": …, "state_messages": …, "chosen": …, "rejected": …,
    #     "n_teachers_agreeing": …}, …]
"""

from __future__ import annotations

import asyncio
import json
import os
import time
from collections import Counter
from collections.abc import Sequence
from pathlib import Path
from typing import TypedDict

# httpx is lazy-imported inside replay_trace() so that DPO-pair extraction
# (the deterministic local logic) is testable without httpx installed.
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

# Default teacher panel: one frontier model per lab. Prices are USD per
# million tokens and feed the per-call cost estimate.
DEFAULT_TEACHERS: list["TeacherSpec"] = [
    {"slug": "anthropic/claude-opus-4.7", "input_per_mtok": 15.0, "output_per_mtok": 75.0},
    {"slug": "openai/gpt-5", "input_per_mtok": 1.25, "output_per_mtok": 10.0},
    {"slug": "deepseek/deepseek-v4-pro", "input_per_mtok": 1.10, "output_per_mtok": 4.40},
]

OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"


def _load_api_key() -> str:
    """Return the OpenRouter API key (same lookup as spike 001).

    Lookup order:
      1. The ``OPENROUTER_API_KEY`` environment variable, if it is set to a
         non-blank value (surrounding whitespace is stripped — a trailing
         newline would otherwise produce an invalid HTTP header).
      2. The first non-empty ``OPENROUTER_API_KEY=...`` assignment in
         ``~/.hermes/.env``. An optional leading ``export`` and surrounding
         single/double quotes around the value are stripped.

    Raises:
        RuntimeError: if neither source yields a non-empty key.
    """
    env_key = os.environ.get("OPENROUTER_API_KEY", "").strip()
    if env_key:
        return env_key

    hermes_env = Path.home() / ".hermes" / ".env"
    if hermes_env.is_file():
        for raw_line in hermes_env.read_text(encoding="utf-8").splitlines():
            line = raw_line.strip()
            # Accept shell-style `export NAME=value` as well as `NAME=value`.
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            name, sep, value = line.partition("=")
            if not sep or name.strip() != "OPENROUTER_API_KEY":
                continue
            key = value.strip().strip('"').strip("'")
            if key:  # an empty assignment is not a usable key; keep looking
                return key
    raise RuntimeError("OPENROUTER_API_KEY not found in env or ~/.hermes/.env")


# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------


class TeacherSpec(TypedDict):
    """One external teacher model and its OpenRouter pricing."""

    slug: str  # OpenRouter model slug, sent as the request's "model"
    input_per_mtok: float  # USD per 1M prompt tokens
    output_per_mtok: float  # USD per 1M completion tokens


class TraceState(TypedDict):
    """One step of a frozen agentic trace."""

    state_id: str  # unique within the trace
    messages: list[dict]  # the conversation up to and including this step's user prompt
    student_action: str  # what the student actually did at this step (for DPO comparison)


class TeacherCallResult(TypedDict):
    """Outcome of querying one teacher on one state."""

    state_id: str
    teacher_slug: str
    response_text: str | None  # None when the call failed or returned no content
    latency_s: float  # wall-clock seconds for the request
    prompt_tokens: int  # as reported by the API's usage block (0 if absent)
    completion_tokens: int
    cost_usd: float  # computed from token counts and TeacherSpec prices
    error: str | None  # truncated repr of the exception, or None on success


class DPOPair(TypedDict):
    """One preference pair: teacher consensus preferred over the student."""

    state_id: str
    state_messages: list[dict]
    chosen: str  # teacher-consensus action
    rejected: str  # student action
    n_teachers_agreeing: int


# ---------------------------------------------------------------------------
# Teacher replay
# ---------------------------------------------------------------------------


async def _call_teacher(
    client,  # httpx.AsyncClient — lazy-typed so module imports without httpx
    state: TraceState,
    teacher: TeacherSpec,
    api_key: str,
    max_tokens: int = 200,
) -> TeacherCallResult:
    """Send one state's conversation to one teacher and record the outcome.

    Request/response failures never propagate. Any ``Exception`` raised while
    posting or parsing is captured as a truncated repr in ``error``, so one bad
    teacher cannot abort the other calls gathered for the same state.

    Cost is computed from the token counts in the response's ``usage`` block
    (0 when absent or null) and the teacher's per-million-token prices.
    """
    payload = {
        "model": teacher["slug"],
        "messages": state["messages"],
        "max_tokens": max_tokens,
        "temperature": 0.2,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/Codeseys/composer-replication-framework",
        "X-Title": "composer-replication-framework spike-005-skeleton",
    }
    t0 = time.perf_counter()
    err = None
    response_text = None
    prompt_tokens = 0
    completion_tokens = 0
    try:
        r = await client.post(OPENROUTER_URL, json=payload, headers=headers, timeout=120.0)
        r.raise_for_status()
        data = r.json()
        # Read usage before the message so billed tokens are still recorded
        # (and costed) even if the message payload is malformed. The block or
        # its fields may be present but null; treat those as 0.
        usage = data.get("usage") or {}
        prompt_tokens = int(usage.get("prompt_tokens") or 0)
        completion_tokens = int(usage.get("completion_tokens") or 0)
        response_text = data["choices"][0]["message"]["content"]
    except Exception as e:  # noqa: BLE001 — capture all for verdict logging
        err = repr(e)[:300]
    t1 = time.perf_counter()

    cost_usd = (
        (prompt_tokens / 1_000_000) * teacher["input_per_mtok"]
        + (completion_tokens / 1_000_000) * teacher["output_per_mtok"]
    )
    return {
        "state_id": state["state_id"],
        "teacher_slug": teacher["slug"],
        "response_text": response_text,
        "latency_s": round(t1 - t0, 3),
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "cost_usd": round(cost_usd, 6),
        "error": err,
    }


async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
    max_tokens: int = 200,
) -> list[TeacherCallResult]:
    """Query every teacher on each state: states in order, teachers in parallel.

    For each state, all teachers are called concurrently with
    ``asyncio.gather``. States are processed one after another so spend can be
    checked between them.

    Budget: ``max_total_usd`` is a *soft* cap. It is checked before each state
    is dispatched, so no new state starts once cumulative cost has reached the
    cap. The last dispatched state can still overshoot by up to one state's
    worth of teacher calls. A cap <= 0 dispatches nothing. The cost of every
    call counts toward the cap, including calls that ended in an error.

    Args:
        states: trace steps to replay, in order.
        teachers: teacher panel to query for every state.
        max_total_usd: soft spend cap in USD (see above).
        api_key: OpenRouter key. Falls back to ``_load_api_key()`` when falsy.
        max_tokens: per-call completion limit forwarded to each teacher.

    Returns:
        Flat list of per-call results, grouped by state in input order.
        Aggregate by ``state_id`` downstream to extract DPO pairs.
    """
    import httpx  # lazy import — only required for live-API replay

    api_key = api_key or _load_api_key()
    results: list[TeacherCallResult] = []
    cumulative_cost = 0.0
    async with httpx.AsyncClient() as client:
        for state in states:
            if cumulative_cost >= max_total_usd:
                break
            state_results = await asyncio.gather(
                *(_call_teacher(client, state, t, api_key, max_tokens) for t in teachers)
            )
            results.extend(state_results)
            cumulative_cost += sum(r["cost_usd"] for r in state_results)
    return results


# ---------------------------------------------------------------------------
# DPO pair extraction
# ---------------------------------------------------------------------------


def _normalize_action(text: str | None) -> str:
    """Normalize an action string for cluster-by-equality.

    For real agentic traces, this should parse the tool call (name + args) and
    return a canonical form. For the skeleton we just collapse whitespace runs
    to single spaces and lowercase. ``None`` maps to ``""``.
    """
    if text is None:
        return ""
    return " ".join(text.split()).lower()


def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]:
    """Convert teacher-disagreement-with-student into preference pairs.

    Logic:
      - Group successful teacher_actions (no error, non-None text) by state_id.
      - For each state, normalize all teacher responses and the student response.
      - Walk the distinct teacher actions from most to least agreed-upon. Ties
        keep first-seen order. Take the first action X such that:
        at least ``agreement_threshold`` teachers gave X, X is non-empty, and
        X differs from the student's action. Emit
        (chosen=X, rejected=student_action).
      - Otherwise no pair (no signal). At most one pair is emitted per state.

    Args:
        states: sequence of TraceState (must include state["student_action"]).
        teacher_actions: flat list of TeacherCallResult from replay_trace().
        agreement_threshold: min number of teachers that must agree for a pair.

    Returns:
        List of DPOPair dicts ready for DPO training, in order of each state's
        first successful teacher result. Results for state_ids not present in
        ``states`` are ignored.
    """
    by_state: dict[str, list[TeacherCallResult]] = {}
    for tr in teacher_actions:
        if tr["error"] is None and tr["response_text"] is not None:
            by_state.setdefault(tr["state_id"], []).append(tr)

    state_lookup = {s["state_id"]: s for s in states}
    pairs: list[DPOPair] = []
    for state_id, calls in by_state.items():
        state = state_lookup.get(state_id)
        if state is None:
            continue
        student_norm = _normalize_action(state["student_action"])
        teacher_norm = [_normalize_action(c["response_text"]) for c in calls]

        # most_common() yields counts in descending order, so the first
        # qualifying action is genuinely the most-agreed-upon one.
        # NOTE(review): when the student's own action is the top teacher
        # action, a less-agreed alternative that still meets the threshold
        # produces a pair (per the rule above). Confirm this is desired.
        for action, n in Counter(teacher_norm).most_common():
            if n < agreement_threshold:
                break  # counts only decrease from here; nothing can qualify
            if not action or action == student_norm:
                continue
            # Emit the first original (un-normalized) teacher text in this cluster.
            chosen_text = next(
                c["response_text"]
                for c, norm in zip(calls, teacher_norm)
                if norm == action
            )
            pairs.append({
                "state_id": state_id,
                "state_messages": state["messages"],
                "chosen": chosen_text,
                "rejected": state["student_action"],
                "n_teachers_agreeing": n,
            })
            break  # one pair per state
    return pairs


def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None:
    """Write ``pairs`` to ``path`` as JSON Lines (one JSON object per line).

    Parent directories are created as needed and an existing file is
    overwritten. Empty ``pairs`` produces an empty file, not a lone blank
    line, because line-by-line ``json.loads`` readers would reject a blank line.
    """
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text("".join(json.dumps(d) + "\n" for d in pairs), encoding="utf-8")


__all__ = [
    "DEFAULT_TEACHERS",
    "TeacherSpec",
    "TraceState",
    "TeacherCallResult",
    "DPOPair",
    "replay_trace",
    "extract_dpo_pairs",
    "save_pairs",
]