# training/scripts/run_headline_benchmark.py
"""Headline benchmark with the eval-mode safety net OFF.
Compares four routing policies on the three hard tasks, with five seeds and
both RAG settings. The safety net is OFF (``auto_fill_required=False``) so an
incomplete brief gets penalised by the grader.
Policies
--------
- ``base_naive``: untrained baseline, consults the analyst then submits.
- ``base_roundrobin``: untrained second baseline, walks experts in fixed order.
- ``trained_mlp``: actually trained CoS policy. A 2-layer MLP routing policy
(REINFORCE, 600 episodes, lr 0.003) loaded from
``training/checkpoints/cos_final.pt``. This is the headline trained-model
number. If torch / the checkpoint isn't available the row is skipped and the
plot annotates that.
- ``oracle_router``: deterministic upper bound. The handcoded routing policy
that always consults the required experts in the canonical order. We label
this an *upper bound*, not a trained model -- our trained policies (the MLP
above and the SFT/GRPO LLMs) are trained to imitate this behaviour, and the
number tells you how high a perfect routing policy can score on the
current grader / RAG settings.
Outputs
-------
- ``training/evidence/headline_benchmark.json`` -- raw cells + per-seed runs.
- ``training/evidence/plots/headline_terminal_reward.{png,svg}`` -- the chart.
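
Usage
-----
The script inserts the repo root into ``sys.path`` itself, so it can be run
from any working directory (illustrative invocation)::

    python training/scripts/run_headline_benchmark.py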
"""
from __future__ import annotations
import json
import random
import statistics
import sys
from pathlib import Path
from typing import Any, Callable
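
# Make the repo root importable when this script is run directly; this file
# lives two directories below the root (training/scripts/).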
REPO = Path(__file__).resolve().parents[2]
if str(REPO) not in sys.path:
sys.path.insert(0, str(REPO))
from ceo_brief_env.environment import ( # noqa: E402
CEOBriefEnvironment,
oracle_action_for_observation,
)
from ceo_brief_env.models import CoSAction, CoSObservation # noqa: E402
HARD_TASKS = ["hard_brief", "expert_brief", "crisis_brief"]
SEEDS = [11, 23, 47, 91, 137]
RAG_SETTINGS = [False, True]
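
# Full grid: 3 tasks x 2 RAG settings x 5 seeds per policy -- 30 episodes per
# policy, 120 in total when the trained-MLP checkpoint is available.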


def naive_picker(obs: CoSObservation) -> CoSAction:
"""Untrained base policy: consult analyst, then submit. Will be incomplete."""
if "analyst" not in obs.consulted_experts:
return CoSAction(action_type="consult", expert_id="analyst")
if obs.current_brief is None:
return CoSAction(action_type="summarize")
return CoSAction(action_type="submit")


def roundrobin_picker(obs: CoSObservation) -> CoSAction:
"""Second base policy: walks experts in fixed order then submits."""
for expert in ["finance", "analyst", "hr", "strategy"]:
if expert not in obs.consulted_experts:
return CoSAction(action_type="consult", expert_id=expert)
if obs.current_brief is None:
return CoSAction(action_type="summarize")
return CoSAction(action_type="submit")


def _load_trained_mlp_picker() -> tuple[Callable[[CoSObservation], CoSAction], str] | None:
"""Load the actually-trained MLP CoS routing policy.
Returns ``(picker_fn, info)`` or ``None`` if torch / the checkpoint isn't
available. The MLP was trained with REINFORCE for 600 episodes (see
``training/scripts/train_cos_local.py``); ``cos_final.pt`` is the resulting
state dict.
"""
try:
import numpy as np # noqa: F401
import torch
sys.path.insert(0, str(REPO / "training" / "scripts"))
from train_cos_local import ( # type: ignore
ACTIONS,
PolicyNet,
featurize,
load_policy_state_dict_from_file,
)
    except Exception as exc:  # pragma: no cover -- env without torch
        print(f"trained_mlp unavailable ({exc}); skipping that row")
        return None
ckpt_paths = [
REPO / "training" / "checkpoints" / "cos_final.pt",
REPO / "training" / "checkpoints" / "cos_ckpt0.pt",
]
ckpt = next((p for p in ckpt_paths if p.exists()), None)
if ckpt is None:
return None
model = PolicyNet()
info = load_policy_state_dict_from_file(model, ckpt)
model.eval()
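
    # Greedy decoding at evaluation time: pick the argmax action instead of
    # sampling from the policy's distribution.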
def picker(obs: CoSObservation) -> CoSAction:
feats = torch.from_numpy(featurize(obs)).unsqueeze(0)
with torch.no_grad():
logits = model(feats)
idx = int(torch.argmax(logits, dim=-1).item())
return ACTIONS[idx]
return picker, f"{ckpt.name} ({info})"


def build_policies() -> tuple[
    dict[str, Callable[[CoSObservation], CoSAction]],
    dict[str, str],
]:
"""Return the policy table plus a metadata dict for the plot legend."""
policies: dict[str, Callable[[CoSObservation], CoSAction]] = {
"base_naive": naive_picker,
"base_roundrobin": roundrobin_picker,
}
info: dict[str, str] = {
"base_naive": "untrained baseline (analyst-only)",
"base_roundrobin": "untrained baseline (fixed order)",
}
trained = _load_trained_mlp_picker()
if trained is not None:
policies["trained_mlp"] = trained[0]
info["trained_mlp"] = f"trained MLP CoS (REINFORCE, 600 ep) - {trained[1]}"
policies["oracle_router"] = oracle_action_for_observation
info["oracle_router"] = "oracle router (upper bound, handcoded canonical sequence)"
return policies, info


def run_episode(
    picker: Callable[[CoSObservation], CoSAction],
    task: str,
    use_rag: bool,
    seed: int,
) -> dict[str, Any]:
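    """Roll out one episode with a fixed seed and return its summary stats.

    ``terminal`` is the grader score at submission (0.0 if the episode never
    reached grading), ``cumulative`` is the shaped per-step reward total, and
    ``path`` records the order in which the policy actually consulted experts.
    """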
random.seed(seed)
env = CEOBriefEnvironment(shaping="strict", auto_fill_required=False)
obs = env.reset(task=task, use_rag=use_rag)
max_steps = obs.max_steps
steps = 0
consulted_path: list[str] = []
while not obs.done and steps < max_steps:
steps += 1
action = picker(obs)
if action.action_type == "consult" and action.expert_id:
consulted_path.append(action.expert_id)
obs = env.step(action)
terminal = float(obs.terminal_grader_score or 0.0)
cumulative = float(obs.reward_breakdown.cumulative if obs.reward_breakdown else 0.0)
return {
"terminal": terminal,
"cumulative": cumulative,
"steps": steps,
"consulted": list(obs.consulted_experts),
"path": consulted_path,
"submitted": bool(obs.done),
}


def aggregate(samples: list[float]) -> dict[str, float]:
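    """Mean, population standard deviation, and count for a list of scores.

    A quick sanity check of the general case:

    >>> aggregate([1.0, 3.0])
    {'mean': 2.0, 'std': 1.0, 'n': 2}
    """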
if not samples:
return {"mean": 0.0, "std": 0.0, "n": 0}
if len(samples) == 1:
return {"mean": samples[0], "std": 0.0, "n": 1}
return {
"mean": statistics.fmean(samples),
"std": statistics.pstdev(samples),
"n": len(samples),
}


def main() -> None:
out_dir = REPO / "training" / "evidence"
out_dir.mkdir(parents=True, exist_ok=True)
plots_dir = out_dir / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)
policies, policy_info = build_policies()
results: dict[str, Any] = {
"schema": "autodatalab-plus.headline_benchmark.v2",
"config": {
"tasks": HARD_TASKS,
"policies": list(policies.keys()),
"policy_info": policy_info,
"seeds": SEEDS,
"rag_settings": RAG_SETTINGS,
"auto_fill_required": False,
"shaping": "strict",
},
"cells": [],
}
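    # Sweep the full task x RAG x policy grid; every cell runs all five seeds.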
for task in HARD_TASKS:
for use_rag in RAG_SETTINGS:
for policy_name, policy_fn in policies.items():
terminals: list[float] = []
cumulatives: list[float] = []
runs: list[dict[str, Any]] = []
for seed in SEEDS:
rollout = run_episode(policy_fn, task, use_rag, seed)
terminals.append(rollout["terminal"])
cumulatives.append(rollout["cumulative"])
runs.append({"seed": seed, **rollout})
cell = {
"task": task,
"policy": policy_name,
"use_rag": use_rag,
"terminal": aggregate(terminals),
"cumulative": aggregate(cumulatives),
"runs": runs,
}
results["cells"].append(cell)
print(
f"task={task:13s} rag={str(use_rag):5s} policy={policy_name:18s} "
f"terminal_mean={cell['terminal']['mean']:.3f} "
f"std={cell['terminal']['std']:.3f}"
)
json_path = out_dir / "headline_benchmark.json"
json_path.write_text(json.dumps(results, indent=2))
print(f"\nwrote {json_path.relative_to(REPO)}")
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
except ImportError:
print("matplotlib not available; skipping money plot")
return
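    # Pool per-seed terminal scores across both RAG settings, so each
    # (task, policy) bar below averages 2 x 5 = 10 runs.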
by_cell: dict[tuple[str, str], list[float]] = {}
for cell in results["cells"]:
key = (cell["task"], cell["policy"])
by_cell.setdefault(key, []).extend([r["terminal"] for r in cell["runs"]])
plot_order = [p for p in ("base_naive", "base_roundrobin", "trained_mlp", "oracle_router") if p in policies]
pretty_name = {
"base_naive": "base (naive)",
"base_roundrobin": "base (round-robin)",
"trained_mlp": "trained MLP CoS\n(REINFORCE, 600 ep)",
"oracle_router": "oracle router\n(upper bound)",
}
color = {
"base_naive": "#9aa0a6",
"base_roundrobin": "#bdc1c6",
"trained_mlp": "#34a853",
"oracle_router": "#1a73e8",
}
hatch = {
"oracle_router": "//",
}
fig, ax = plt.subplots(figsize=(10.4, 5.0), dpi=140)
n_groups = len(HARD_TASKS)
width = 0.78 / max(1, len(plot_order))
xs = list(range(n_groups))
for i, policy in enumerate(plot_order):
means = []
stds = []
for task in HARD_TASKS:
samples = by_cell.get((task, policy), [])
agg = aggregate(samples)
means.append(agg["mean"])
stds.append(agg["std"])
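            # Grouped bars: shift bar i by (i - (k - 1) / 2) * width so the
            # cluster of k bars is centered on each task's x tick.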
offsets = [x + (i - (len(plot_order) - 1) / 2.0) * width for x in xs]
bars = ax.bar(
offsets,
means,
width=width,
yerr=stds,
capsize=4,
color=color[policy],
edgecolor="black",
linewidth=0.6,
hatch=hatch.get(policy, ""),
label=pretty_name[policy],
)
for bar, mean in zip(bars, means):
ax.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height() + 0.015,
f"{mean:.2f}",
ha="center",
va="bottom",
fontsize=8,
)
ax.set_xticks(xs)
ax.set_xticklabels([t.replace("_brief", "") for t in HARD_TASKS])
    ax.set_ylim(0.0, 1.05)  # headroom so value labels on ~1.0 bars aren't clipped
ax.set_ylabel("Terminal grader score (0..1)")
ax.set_title(
"Terminal reward, fallback disabled\n"
"untrained baselines vs trained MLP CoS vs oracle router (upper bound)\n"
"3 hard tasks, 5 seeds (RAG on/off averaged)"
)
ax.grid(axis="y", linestyle="--", alpha=0.4)
ax.legend(loc="upper left", framealpha=0.9, fontsize=9)
fig.tight_layout()
png_path = plots_dir / "headline_terminal_reward.png"
svg_path = plots_dir / "headline_terminal_reward.svg"
fig.savefig(png_path)
fig.savefig(svg_path)
print(f"wrote {png_path.relative_to(REPO)}")
print(f"wrote {svg_path.relative_to(REPO)}")


if __name__ == "__main__":
main()