File size: 9,834 Bytes
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
"""teacher_replay.py — N-teacher OpenRouter parallel client + DPO-pair extractor.

This is channel 3 of the integrated trainer: at each step of a frozen agentic
trace, query N pre-trained external teachers (frontier models from different
labs) and convert teacher disagreement into preference pairs for DPO loss.

Generalized from spike-001's `replay.py`. Verified economic floor (✅ spike 001):
$0.98 mean per-trace cost ungated, $0.30/trace projected with VOI gating.

Usage:
    from teacher_replay import replay_trace, extract_dpo_pairs

    # 1. Replay each step of a frozen trace with N teachers.
    teacher_actions = await replay_trace(
        states=trace_states,
        teachers=DEFAULT_TEACHERS,
        max_total_usd=10.0,
    )

    # 2. Extract DPO pairs from teacher disagreement. The student's action is
    #    read from each state's "student_action" field.
    pairs = extract_dpo_pairs(
        states=trace_states,
        teacher_actions=teacher_actions,
        agreement_threshold=2,  # at least 2/3 teachers must agree
    )
    # → [{"state_id": …, "state_messages": …, "chosen": …, "rejected": …,
    #     "n_teachers_agreeing": …}, …]
"""

from __future__ import annotations

import asyncio
import json
import os
import time
from collections import Counter
from collections.abc import Sequence
from pathlib import Path
from typing import TypedDict

# httpx is lazy-imported inside replay_trace() so that DPO-pair extraction
# (the deterministic local logic) is testable without httpx installed.


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

# Default teacher panel; replay_trace() snapshots it as the default value of
# its `teachers` parameter. Each entry's "slug" is sent as the request's
# "model" field, and the two price fields are multiplied by
# (token count / 1_000_000) to estimate the per-call cost in USD.
# NOTE(review): the prices are hard-coded — presumably a snapshot of provider
# pricing at the time of writing; confirm against current OpenRouter pricing.
DEFAULT_TEACHERS: list["TeacherSpec"] = [
    {"slug": "anthropic/claude-opus-4.7", "input_per_mtok": 15.0, "output_per_mtok": 75.0},
    {"slug": "openai/gpt-5",              "input_per_mtok": 1.25, "output_per_mtok": 10.0},
    {"slug": "deepseek/deepseek-v4-pro",  "input_per_mtok": 1.10, "output_per_mtok": 4.40},
]

# Endpoint every teacher request is POSTed to.
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"


def _load_api_key() -> str:
    """Return the OpenRouter API key from the environment or ~/.hermes/.env.

    Lookup order:
      1. The ``OPENROUTER_API_KEY`` environment variable.
      2. The first ``OPENROUTER_API_KEY=...`` line with a non-empty value in
         ``~/.hermes/.env``. A leading ``export `` and one layer of single or
         double quotes around the value are tolerated.

    A blank (empty or whitespace-only) value is treated as "not set" in both
    places, so an accidentally exported empty variable does not shadow a valid
    key in the file and never ends up in the ``Authorization`` header.

    Returns:
        The key, stripped of surrounding whitespace.

    Raises:
        RuntimeError: if no non-empty key is found in either location.
    """
    env_key = os.environ.get("OPENROUTER_API_KEY", "").strip()
    if env_key:
        return env_key

    hermes_env = Path.home() / ".hermes" / ".env"
    if hermes_env.is_file():
        for line in hermes_env.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            # Shell-style dotenv files often write `export KEY=value`.
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            if line.startswith("OPENROUTER_API_KEY="):
                value = line.split("=", 1)[1].strip().strip('"').strip("'")
                if value:
                    return value
    raise RuntimeError("OPENROUTER_API_KEY not found in env or ~/.hermes/.env")


# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------

class TeacherSpec(TypedDict):
    """Identity and pricing of one external teacher model."""
    slug: str               # model identifier sent as the request's "model" field
    input_per_mtok: float   # USD charged per 1,000,000 prompt tokens (used for cost estimates)
    output_per_mtok: float  # USD charged per 1,000,000 completion tokens (used for cost estimates)


class TraceState(TypedDict):
    """One step of a frozen agentic trace.

    `messages` is forwarded unchanged as the chat-completion request's
    "messages" field, and `student_action` is what extract_dpo_pairs()
    compares the teachers' responses against.
    """
    state_id: str           # unique within the trace
    messages: list[dict]    # the conversation up to and including this step's user prompt
    student_action: str     # what the student actually did at this step (for DPO comparison)


class TeacherCallResult(TypedDict):
    """Outcome of querying one teacher about one trace state."""
    state_id: str               # state_id of the TraceState that was replayed
    teacher_slug: str           # slug of the teacher that was queried
    response_text: str | None   # first choice's message content; None if it could not be read
    latency_s: float            # wall-clock seconds for the call, rounded to 3 decimals
    prompt_tokens: int          # from the response's "usage" block (0 if unavailable)
    completion_tokens: int      # from the response's "usage" block (0 if unavailable)
    cost_usd: float             # estimate from token counts x TeacherSpec prices, rounded to 6 decimals
    error: str | None           # truncated repr() of the exception raised by the call, else None


class DPOPair(TypedDict):
    """One preference pair produced by extract_dpo_pairs()."""
    state_id: str               # state the pair was extracted from
    state_messages: list[dict]  # that state's `messages` (the prompt for both completions)
    chosen: str       # teacher-consensus action
    rejected: str     # student action
    n_teachers_agreeing: int    # how many teacher responses normalized to the chosen action


# ---------------------------------------------------------------------------
# Teacher replay
# ---------------------------------------------------------------------------

async def _call_teacher(
    client,  # httpx.AsyncClient — lazy-typed so module imports without httpx
    state: TraceState,
    teacher: TeacherSpec,
    api_key: str,
    max_tokens: int = 200,
) -> TeacherCallResult:
    """Send one state's messages to one teacher and record the outcome.

    Request and parse failures never propagate: any ``Exception`` is captured
    in the result's ``error`` field (``repr`` truncated to 300 characters), so
    one failing teacher does not abort sibling calls awaited alongside it.

    Args:
        client: object exposing an awaitable
            ``post(url, json=..., headers=..., timeout=...)``
            (an ``httpx.AsyncClient`` in production).
        state: trace step whose ``messages`` are sent as the conversation.
        teacher: model slug plus per-million-token prices used for costing.
        api_key: bearer token placed in the ``Authorization`` header.
        max_tokens: completion-length cap forwarded in the request.

    Returns:
        A ``TeacherCallResult``. When ``error`` is not ``None``,
        ``response_text`` is ``None``; token counts and cost still reflect the
        response's ``usage`` block whenever it could be read.
    """
    payload = {
        "model": teacher["slug"],
        "messages": state["messages"],
        "max_tokens": max_tokens,
        "temperature": 0.2,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/Codeseys/composer-replication-framework",
        "X-Title": "composer-replication-framework spike-005-skeleton",
    }
    t0 = time.perf_counter()
    err = None
    response_text = None
    prompt_tokens = 0
    completion_tokens = 0
    try:
        r = await client.post(OPENROUTER_URL, json=payload, headers=headers, timeout=120.0)
        r.raise_for_status()
        data = r.json()
        # Read usage before the content: if `choices` turns out to be missing
        # or malformed, the tokens already reported are still accounted for.
        # `or {}` / `or 0` tolerate an explicit JSON null for these fields.
        usage = data.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0
        response_text = data["choices"][0]["message"]["content"]
    except Exception as e:  # noqa: BLE001 — capture all for verdict logging
        err = repr(e)[:300]
    t1 = time.perf_counter()
    cost_usd = (
        (prompt_tokens / 1_000_000) * teacher["input_per_mtok"]
        + (completion_tokens / 1_000_000) * teacher["output_per_mtok"]
    )
    return {
        "state_id": state["state_id"],
        "teacher_slug": teacher["slug"],
        "response_text": response_text,
        "latency_s": round(t1 - t0, 3),
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "cost_usd": round(cost_usd, 6),
        "error": err,
    }


async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
) -> list[TeacherCallResult]:
    """Query every teacher for each state, one state at a time.

    States are processed sequentially; within a state all teachers are queried
    concurrently. Aggregate the returned per-call results by ``state_id``
    downstream to extract DPO pairs.

    Budget: before each state is dispatched, the running cost estimate is
    compared with ``max_total_usd`` and replay stops once the budget is used
    up. Cost is only known after a call returns, so the total can exceed the
    cap by at most one state's worth of calls. A budget of zero or less issues
    no requests. Every call's ``cost_usd`` counts toward the total, including
    calls that ended in an error.

    Args:
        states: trace steps to replay, in order.
        teachers: teacher models to query at each state.
        max_total_usd: spend budget in USD (estimated from token usage).
        api_key: OpenRouter key; looked up via ``_load_api_key()`` when falsy.

    Returns:
        Flat list of per-call results for every state that was dispatched,
        grouped by state in input order.

    Raises:
        RuntimeError: if no API key is supplied and none can be found.
    """
    import httpx  # lazy import — only required for live-API replay

    api_key = api_key or _load_api_key()
    results: list[TeacherCallResult] = []
    cumulative_cost = 0.0
    async with httpx.AsyncClient() as client:
        for state in states:
            # Check before spending so an exhausted budget stops new requests.
            if cumulative_cost >= max_total_usd:
                break
            state_results = await asyncio.gather(
                *(_call_teacher(client, state, t, api_key) for t in teachers)
            )
            results.extend(state_results)
            cumulative_cost += sum(r["cost_usd"] for r in state_results)
    return results


# ---------------------------------------------------------------------------
# DPO pair extraction
# ---------------------------------------------------------------------------

def _normalize_action(text: str | None) -> str:
    """Normalize an action string for cluster-by-equality.

    For real agentic traces, this should parse the tool call (name + args) and
    return a canonical form. For the skeleton we just collapse runs of
    whitespace to single spaces, trim the ends, and lowercase.

    Returns:
        The normalized string; ``""`` when *text* is ``None``.
    """
    if text is None:
        return ""
    # str.split() with no argument already discards leading/trailing whitespace.
    return " ".join(text.split()).lower()


def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]:
    """Convert teacher-disagreement-with-student into preference pairs.

    Logic:
      - Group successful teacher_actions (no error, non-None text) by state_id.
      - For each state, normalize all teacher responses + student response.
      - Walk the teacher actions from most- to least-agreed-upon. The first
        non-empty action X that `agreement_threshold` or more teachers agree
        on and that differs from the student's action yields
            (chosen=X, rejected=student_action)
      - Otherwise no pair (no signal).

    At most one pair is emitted per state. Ties in agreement count are broken
    by the order in which the actions first appear in `teacher_actions`.

    Args:
        states: sequence of TraceState (must include state["student_action"]).
        teacher_actions: flat list of TeacherCallResult from replay_trace().
            Results whose state_id is not present in `states` are ignored.
        agreement_threshold: min number of teachers that must agree for a pair.

    Returns:
        List of DPOPair dicts ready for DPO training, ordered by each state's
        first appearance in `teacher_actions`. `chosen` is the original
        (un-normalized) text of the first teacher response in the winning
        cluster.
    """
    by_state: dict[str, list[TeacherCallResult]] = {}
    for tr in teacher_actions:
        if tr["error"] is None and tr["response_text"] is not None:
            by_state.setdefault(tr["state_id"], []).append(tr)

    state_lookup = {s["state_id"]: s for s in states}
    pairs: list[DPOPair] = []

    for state_id, calls in by_state.items():
        state = state_lookup.get(state_id)
        if state is None:
            continue
        student_norm = _normalize_action(state["student_action"])

        counts: Counter = Counter()
        first_text: dict[str, str] = {}  # normalized action -> first raw response
        for call in calls:
            norm = _normalize_action(call["response_text"])
            counts[norm] += 1
            first_text.setdefault(norm, call["response_text"])

        # most_common() sorts by count (descending), so the first qualifying
        # action really is the most-agreed-upon one, not merely the first seen.
        for action, n in counts.most_common():
            if n < agreement_threshold:
                break  # every remaining action has even fewer votes
            if not action or action == student_norm:
                continue
            pairs.append({
                "state_id": state_id,
                "state_messages": state["messages"],
                "chosen": first_text[action],
                "rejected": state["student_action"],
                "n_teachers_agreeing": n,
            })
            break  # one pair per state — the most-agreed-upon teacher action

    return pairs


def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None:
    """Write *pairs* to *path* as JSON Lines (one JSON object per line).

    Missing parent directories are created and an existing file is
    overwritten. Each record is terminated by a newline, so an empty *pairs*
    produces an empty file rather than a file holding a single blank line.

    Args:
        pairs: preference pairs to serialize with ``json.dumps``.
        path: destination file path.
    """
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text("".join(json.dumps(d) + "\n" for d in pairs), encoding="utf-8")


# Public API exported by `from teacher_replay import *`: the default teacher
# panel, the TypedDict schemas, and the replay / extraction / persistence
# entry points. Underscore-prefixed helpers and OPENROUTER_URL are not listed.
__all__ = [
    "DEFAULT_TEACHERS",
    "TeacherSpec",
    "TraceState",
    "TeacherCallResult",
    "DPOPair",
    "replay_trace",
    "extract_dpo_pairs",
    "save_pairs",
]