Codeseys's picture
Wave 10 — packaging: composer_replication is now pip-installable
ac05fbf
"""teacher_replay.py — N-teacher OpenRouter parallel client + DPO-pair extractor.
This is channel 3 of the integrated trainer: at each step of a frozen agentic
trace, query N pre-trained external teachers (frontier models from different
labs) and convert teacher disagreement into preference pairs for DPO loss.
Generalized from spike-001's `replay.py`. Verified economic floor (✅ spike 001):
$0.98 mean per-trace cost ungated, $0.30/trace projected with VOI gating.
Usage:
from teacher_replay import replay_trace, extract_dpo_pairs
# 1. Replay each step of a frozen trace with N teachers.
teacher_actions = await replay_trace(
states=trace_states,
teachers=DEFAULT_TEACHERS,
max_total_usd=10.0,
)
# 2. Extract DPO pairs from teacher disagreement.
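# (Each TraceState already carries its own "student_action", so no separate
# student-actions argument is passed.)
pairs = extract_dpo_pairs(
    states=trace_states,
    teacher_actions=teacher_actions,
    agreement_threshold=2,  # at least 2 of the 3 teachers must agree
)
# → [{"state_id": …, "state_messages": …, "chosen": …, "rejected": …,
#     "n_teachers_agreeing": …}, …]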
"""
from __future__ import annotations
import asyncio
import json
import os
import time
from collections import Counter
from collections.abc import Sequence
from pathlib import Path
from typing import TypedDict
# httpx is lazy-imported inside replay_trace() so that DPO-pair extraction
# (the deterministic local logic) is testable without httpx installed.
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Teachers queried when the caller does not supply its own list. Both price
# fields are USD per million tokens (prompt side / completion side) and are
# used to estimate the cost of each call.
DEFAULT_TEACHERS: list["TeacherSpec"] = [
    {
        "slug": "anthropic/claude-opus-4.7",
        "input_per_mtok": 15.0,
        "output_per_mtok": 75.0,
    },
    {
        "slug": "openai/gpt-5",
        "input_per_mtok": 1.25,
        "output_per_mtok": 10.0,
    },
    {
        "slug": "deepseek/deepseek-v4-pro",
        "input_per_mtok": 1.10,
        "output_per_mtok": 4.40,
    },
]

# Chat-completions endpoint every teacher request is POSTed to.
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
def _load_api_key() -> str:
    """Return the OpenRouter API key from the environment or ~/.hermes/.env.

    Lookup order:
      1. The ``OPENROUTER_API_KEY`` environment variable.
      2. An ``OPENROUTER_API_KEY=...`` line in ``~/.hermes/.env``; the line may
         carry a shell-style ``export`` prefix and the value may be quoted.

    A blank value is treated as "not set" at both levels, so an empty variable
    falls through to the next source instead of being returned as a key.

    Raises:
        RuntimeError: if neither source yields a non-empty key.
    """
    env_value = os.environ.get("OPENROUTER_API_KEY", "").strip()
    if env_value:
        return env_value

    hermes_env = Path.home() / ".hermes" / ".env"
    if hermes_env.is_file():
        content = hermes_env.read_text(encoding="utf-8", errors="replace")
        for raw_line in content.splitlines():
            line = raw_line.strip()
            # Tolerate `export KEY=value` as written in shell-sourced .env files.
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            if not line.startswith("OPENROUTER_API_KEY="):
                continue
            value = line.split("=", 1)[1].strip().strip('"').strip("'")
            if value:
                return value
    raise RuntimeError("OPENROUTER_API_KEY not found in env or ~/.hermes/.env")
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
class TeacherSpec(TypedDict):
    """A teacher model plus the token prices used to estimate call cost."""
    slug: str  # model identifier sent as the request's "model" field
    input_per_mtok: float  # USD per million prompt tokens
    output_per_mtok: float  # USD per million completion tokens
class TraceState(TypedDict):
    """One step of a frozen agentic trace.

    `messages` is sent unchanged to every teacher; `student_action` is what the
    teachers' answers are compared against when building preference pairs.
    """
    state_id: str  # unique within the trace
    messages: list[dict]  # the conversation up to and including this step's user prompt
    student_action: str  # what the student actually did at this step (for DPO comparison)
class TeacherCallResult(TypedDict):
    """Outcome of querying one teacher for one state (success or failure)."""
    state_id: str  # copied from the TraceState that was replayed
    teacher_slug: str  # copied from the TeacherSpec that was queried
    response_text: str | None  # teacher's reply; None if the call failed before a reply was read
    latency_s: float  # wall-clock seconds for the call, rounded to 3 decimals
    prompt_tokens: int  # from the response's usage block; 0 when unavailable
    completion_tokens: int  # from the response's usage block; 0 when unavailable
    cost_usd: float  # token counts priced with the TeacherSpec rates, rounded to 6 decimals
    error: str | None  # truncated repr() of the caught exception; None on success
class DPOPair(TypedDict):
    """One preference pair: teacher-consensus action preferred over the student's."""
    state_id: str  # state the pair was derived from
    state_messages: list[dict]  # that state's `messages`, i.e. the prompt context
    chosen: str  # teacher-consensus action
    rejected: str  # student action
    n_teachers_agreeing: int  # how many teacher responses normalized to the chosen action
# ---------------------------------------------------------------------------
# Teacher replay
# ---------------------------------------------------------------------------
async def _call_teacher(
    client,  # httpx.AsyncClient — lazy-typed so module imports without httpx
    state: TraceState,
    teacher: TeacherSpec,
    api_key: str,
    max_tokens: int = 200,
) -> TeacherCallResult:
    """Send one state's messages to one teacher and record the outcome.

    Request or parsing failures are not raised: any ``Exception`` is captured
    and returned as a truncated ``repr`` in the result's ``error`` field, so a
    single failing teacher cannot abort the surrounding batch.

    Args:
        client: object exposing an awaitable ``post(url, json=, headers=,
            timeout=)`` (an ``httpx.AsyncClient`` in production).
        state: trace step whose ``messages`` are sent as the prompt.
        teacher: model slug and per-million-token prices.
        api_key: bearer token for the Authorization header.
        max_tokens: completion-length limit sent with the request.

    Returns:
        A TeacherCallResult with the reply text, latency, token counts, the
        cost derived from those counts, and ``error`` (None on success).
    """
    payload = {
        "model": teacher["slug"],
        "messages": state["messages"],
        "max_tokens": max_tokens,
        "temperature": 0.2,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/Codeseys/composer-replication-framework",
        "X-Title": "composer-replication-framework spike-005-skeleton",
    }
    t0 = time.perf_counter()
    err = None
    response_text = None
    prompt_tokens = 0
    completion_tokens = 0
    try:
        r = await client.post(OPENROUTER_URL, json=payload, headers=headers, timeout=120.0)
        r.raise_for_status()
        data = r.json()
        response_text = data["choices"][0]["message"]["content"]
        # "usage" and its counters may be missing or JSON null. Treat both as
        # zero so the cost arithmetic below never sees None, and so a good
        # reply is not reported as an error just because usage is absent.
        usage = data.get("usage") or {}
        prompt_tokens = int(usage.get("prompt_tokens") or 0)
        completion_tokens = int(usage.get("completion_tokens") or 0)
    except Exception as e:  # noqa: BLE001 — capture all for verdict logging
        err = repr(e)[:300]
    t1 = time.perf_counter()
    cost_usd = (
        (prompt_tokens / 1_000_000) * teacher["input_per_mtok"]
        + (completion_tokens / 1_000_000) * teacher["output_per_mtok"]
    )
    return {
        "state_id": state["state_id"],
        "teacher_slug": teacher["slug"],
        "response_text": response_text,
        "latency_s": round(t1 - t0, 3),
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "cost_usd": round(cost_usd, 6),
        "error": err,
    }
async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
) -> list[TeacherCallResult]:
    """Replay each state against every teacher and collect the per-call results.

    States are processed in order; within one state all teachers are queried
    concurrently. Aggregate the returned list by ``state_id`` downstream to
    extract DPO pairs.

    Budget: after each state finishes, the cost of its successful calls is
    added to a running total, and replay stops once that total exceeds
    ``max_total_usd``. The check runs after spending, so the final total can
    overshoot the limit by up to one state's worth of calls.

    Args:
        states: trace steps to replay.
        teachers: teacher models to query for every state.
        max_total_usd: spend threshold that stops further states.
        api_key: OpenRouter key; when falsy it is loaded via ``_load_api_key``.

    Returns:
        Flat list of TeacherCallResult, one per (state, teacher) call made.
        Empty when there are no states or no teachers.
    """
    # Nothing to query: return before requiring httpx or an API key.
    if len(states) == 0 or len(teachers) == 0:
        return []

    import httpx  # lazy import — only required for live-API replay

    api_key = api_key or _load_api_key()
    results: list[TeacherCallResult] = []
    cumulative_cost = 0.0
    async with httpx.AsyncClient() as client:
        for state in states:
            state_results = await asyncio.gather(
                *(_call_teacher(client, state, t, api_key) for t in teachers)
            )
            results.extend(state_results)
            # Failed calls report zero tokens, so only successes add to spend.
            cumulative_cost += sum(
                r["cost_usd"] for r in state_results if r["error"] is None
            )
            if cumulative_cost > max_total_usd:
                break
    return results
# ---------------------------------------------------------------------------
# DPO pair extraction
# ---------------------------------------------------------------------------
def _normalize_action(text: str | None) -> str:
    """Reduce an action string to a canonical form for equality comparison.

    Whitespace runs collapse to single spaces, surrounding whitespace is
    dropped, and the result is lower-cased. ``None`` maps to the empty string.

    This is the skeleton version; for real agentic traces it should instead
    parse the tool call (name + args) and return a canonical form.
    """
    if text is None:
        return ""
    # split() with no argument already discards leading/trailing whitespace.
    words = text.split()
    return " ".join(words).lower()
def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]:
    """Convert teacher-disagreement-with-student into preference pairs.

    Logic:
      - Drop teacher calls that errored or returned no text, then group the
        rest by state_id.
      - For each state, normalize every teacher response and the student's
        action, and count how many teachers produced each normalized action.
      - Walk the actions from most to least agreed-upon (ties keep first-seen
        order). The first non-empty action with at least
        ``agreement_threshold`` votes that differs from the student's action
        yields (chosen=that action, rejected=student_action).
      - Actions equal to the student's are skipped, so a lower-ranked teacher
        action can still be selected if it meets the threshold.
      - At most one pair per state; states with no qualifying action, or whose
        state_id is not in ``states``, yield nothing.

    Args:
        states: sequence of TraceState (must include state["student_action"]).
        teacher_actions: flat list of TeacherCallResult from replay_trace().
        agreement_threshold: min number of teachers that must agree for a pair.

    Returns:
        List of DPOPair dicts ready for DPO training. ``chosen`` is the first
        teacher's original (un-normalized) text for the winning action.
    """
    by_state: dict[str, list[TeacherCallResult]] = {}
    for tr in teacher_actions:
        if tr["error"] is None and tr["response_text"] is not None:
            by_state.setdefault(tr["state_id"], []).append(tr)

    state_lookup = {s["state_id"]: s for s in states}
    pairs: list[DPOPair] = []
    for state_id, calls in by_state.items():
        state = state_lookup.get(state_id)
        if state is None:
            continue
        student_norm = _normalize_action(state["student_action"])

        # Tally normalized actions and remember the first raw text for each,
        # so the emitted `chosen` keeps the teacher's original formatting.
        counts: Counter = Counter()
        first_raw: dict[str, str] = {}
        for call in calls:
            norm = _normalize_action(call["response_text"])
            counts[norm] += 1
            first_raw.setdefault(norm, call["response_text"])

        # most_common() is sorted by descending count, so the first match is
        # the most-agreed-upon qualifying action; plain dict order would pick
        # whichever cluster happened to be seen first.
        for action, n in counts.most_common():
            if n < agreement_threshold:
                break  # every remaining action has even fewer votes
            if not action or action == student_norm:
                continue
            pairs.append({
                "state_id": state_id,
                "state_messages": state["messages"],
                "chosen": first_raw[action],
                "rejected": state["student_action"],
                "n_teachers_agreeing": n,
            })
            break  # one pair per state
    return pairs
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None:
    """Write `pairs` to `path` as JSON Lines (one JSON object per line).

    Missing parent directories are created and an existing file is
    overwritten. Every record is terminated by a newline, so an empty `pairs`
    produces an empty file rather than a file holding a single blank line.

    Args:
        pairs: records to serialize with ``json.dumps``.
        path: destination file.
    """
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(d) + "\n" for d in pairs)
# Public API of this module: the names exported by `from teacher_replay import *`.
# Underscore-prefixed helpers are deliberately left out.
__all__ = [
    "DEFAULT_TEACHERS",
    "TeacherSpec",
    "TraceState",
    "TeacherCallResult",
    "DPOPair",
    "replay_trace",
    "extract_dpo_pairs",
    "save_pairs",
]