"""Headline benchmark with the eval-mode safety net OFF. Compares four routing policies on the three hard tasks, with five seeds and both RAG settings. The safety net is OFF (``auto_fill_required=False``) so an incomplete brief gets penalised by the grader. Policies -------- - ``base_naive``: untrained baseline, consults the analyst then submits. - ``base_roundrobin``: untrained second baseline, walks experts in fixed order. - ``trained_mlp``: actually trained CoS policy. A 2-layer MLP routing policy (REINFORCE, 600 episodes, lr 0.003) loaded from ``training/checkpoints/cos_final.pt``. This is the headline trained-model number. If torch / the checkpoint isn't available the row is skipped and the plot annotates that. - ``oracle_router``: deterministic upper bound. The handcoded routing policy that always consults the required experts in the canonical order. We label this an *upper bound*, not a trained model -- our trained policies (the MLP above and the SFT/GRPO LLMs) are trained to imitate this behaviour, and the number tells you how high a perfect routing policy can score on the current grader / RAG settings. Outputs ------- - ``training/evidence/headline_benchmark.json`` -- raw cells + per-seed runs. - ``training/evidence/plots/headline_terminal_reward.{png,svg}`` -- the chart. """ from __future__ import annotations import json import random import statistics import sys from pathlib import Path from typing import Any, Callable REPO = Path(__file__).resolve().parents[2] if str(REPO) not in sys.path: sys.path.insert(0, str(REPO)) from ceo_brief_env.environment import ( # noqa: E402 CEOBriefEnvironment, oracle_action_for_observation, ) from ceo_brief_env.models import CoSAction, CoSObservation # noqa: E402 HARD_TASKS = ["hard_brief", "expert_brief", "crisis_brief"] SEEDS = [11, 23, 47, 91, 137] RAG_SETTINGS = [False, True] def naive_picker(obs: CoSObservation) -> CoSAction: """Untrained base policy: consult analyst, then submit. Will be incomplete.""" if "analyst" not in obs.consulted_experts: return CoSAction(action_type="consult", expert_id="analyst") if obs.current_brief is None: return CoSAction(action_type="summarize") return CoSAction(action_type="submit") def roundrobin_picker(obs: CoSObservation) -> CoSAction: """Second base policy: walks experts in fixed order then submits.""" for expert in ["finance", "analyst", "hr", "strategy"]: if expert not in obs.consulted_experts: return CoSAction(action_type="consult", expert_id=expert) if obs.current_brief is None: return CoSAction(action_type="summarize") return CoSAction(action_type="submit") def _load_trained_mlp_picker() -> tuple[Callable[[CoSObservation], CoSAction], str] | None: """Load the actually-trained MLP CoS routing policy. Returns ``(picker_fn, info)`` or ``None`` if torch / the checkpoint isn't available. The MLP was trained with REINFORCE for 600 episodes (see ``training/scripts/train_cos_local.py``); ``cos_final.pt`` is the resulting state dict. 
""" try: import numpy as np # noqa: F401 import torch sys.path.insert(0, str(REPO / "training" / "scripts")) from train_cos_local import ( # type: ignore ACTIONS, PolicyNet, featurize, load_policy_state_dict_from_file, ) except Exception as exc: # pragma: no cover -- env without torch return None ckpt_paths = [ REPO / "training" / "checkpoints" / "cos_final.pt", REPO / "training" / "checkpoints" / "cos_ckpt0.pt", ] ckpt = next((p for p in ckpt_paths if p.exists()), None) if ckpt is None: return None model = PolicyNet() info = load_policy_state_dict_from_file(model, ckpt) model.eval() def picker(obs: CoSObservation) -> CoSAction: feats = torch.from_numpy(featurize(obs)).unsqueeze(0) with torch.no_grad(): logits = model(feats) idx = int(torch.argmax(logits, dim=-1).item()) return ACTIONS[idx] return picker, f"{ckpt.name} ({info})" def build_policies() -> tuple[dict[str, Callable[[CoSObservation], CoSAction]], dict[str, str]]: """Return the policy table plus a metadata dict for the plot legend.""" policies: dict[str, Callable[[CoSObservation], CoSAction]] = { "base_naive": naive_picker, "base_roundrobin": roundrobin_picker, } info: dict[str, str] = { "base_naive": "untrained baseline (analyst-only)", "base_roundrobin": "untrained baseline (fixed order)", } trained = _load_trained_mlp_picker() if trained is not None: policies["trained_mlp"] = trained[0] info["trained_mlp"] = f"trained MLP CoS (REINFORCE, 600 ep) - {trained[1]}" policies["oracle_router"] = oracle_action_for_observation info["oracle_router"] = "oracle router (upper bound, handcoded canonical sequence)" return policies, info def run_episode(picker: Callable[[CoSObservation], CoSAction], task: str, use_rag: bool, seed: int) -> dict[str, Any]: random.seed(seed) env = CEOBriefEnvironment(shaping="strict", auto_fill_required=False) obs = env.reset(task=task, use_rag=use_rag) max_steps = obs.max_steps steps = 0 consulted_path: list[str] = [] while not obs.done and steps < max_steps: steps += 1 action = picker(obs) if action.action_type == "consult" and action.expert_id: consulted_path.append(action.expert_id) obs = env.step(action) terminal = float(obs.terminal_grader_score or 0.0) cumulative = float(obs.reward_breakdown.cumulative if obs.reward_breakdown else 0.0) return { "terminal": terminal, "cumulative": cumulative, "steps": steps, "consulted": list(obs.consulted_experts), "path": consulted_path, "submitted": bool(obs.done), } def aggregate(samples: list[float]) -> dict[str, float]: if not samples: return {"mean": 0.0, "std": 0.0, "n": 0} if len(samples) == 1: return {"mean": samples[0], "std": 0.0, "n": 1} return { "mean": statistics.fmean(samples), "std": statistics.pstdev(samples), "n": len(samples), } def main() -> None: out_dir = REPO / "training" / "evidence" out_dir.mkdir(parents=True, exist_ok=True) plots_dir = out_dir / "plots" plots_dir.mkdir(parents=True, exist_ok=True) policies, policy_info = build_policies() results: dict[str, Any] = { "schema": "autodatalab-plus.headline_benchmark.v2", "config": { "tasks": HARD_TASKS, "policies": list(policies.keys()), "policy_info": policy_info, "seeds": SEEDS, "rag_settings": RAG_SETTINGS, "auto_fill_required": False, "shaping": "strict", }, "cells": [], } for task in HARD_TASKS: for use_rag in RAG_SETTINGS: for policy_name, policy_fn in policies.items(): terminals: list[float] = [] cumulatives: list[float] = [] runs: list[dict[str, Any]] = [] for seed in SEEDS: rollout = run_episode(policy_fn, task, use_rag, seed) terminals.append(rollout["terminal"]) 
                    cumulatives.append(rollout["cumulative"])
                    runs.append({"seed": seed, **rollout})
                cell = {
                    "task": task,
                    "policy": policy_name,
                    "use_rag": use_rag,
                    "terminal": aggregate(terminals),
                    "cumulative": aggregate(cumulatives),
                    "runs": runs,
                }
                results["cells"].append(cell)
                print(
                    f"task={task:13s} rag={str(use_rag):5s} policy={policy_name:18s} "
                    f"terminal_mean={cell['terminal']['mean']:.3f} "
                    f"std={cell['terminal']['std']:.3f}"
                )

    json_path = out_dir / "headline_benchmark.json"
    json_path.write_text(json.dumps(results, indent=2))
    print(f"\nwrote {json_path.relative_to(REPO)}")

    try:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not available; skipping money plot")
        return

    # Pool the per-seed terminal scores across both RAG settings for the chart.
    by_cell: dict[tuple[str, str], list[float]] = {}
    for cell in results["cells"]:
        key = (cell["task"], cell["policy"])
        by_cell.setdefault(key, []).extend([r["terminal"] for r in cell["runs"]])

    plot_order = [
        p
        for p in ("base_naive", "base_roundrobin", "trained_mlp", "oracle_router")
        if p in policies
    ]
    pretty_name = {
        "base_naive": "base (naive)",
        "base_roundrobin": "base (round-robin)",
        "trained_mlp": "trained MLP CoS\n(REINFORCE, 600 ep)",
        "oracle_router": "oracle router\n(upper bound)",
    }
    color = {
        "base_naive": "#9aa0a6",
        "base_roundrobin": "#bdc1c6",
        "trained_mlp": "#34a853",
        "oracle_router": "#1a73e8",
    }
    hatch = {
        "oracle_router": "//",
    }

    fig, ax = plt.subplots(figsize=(10.4, 5.0), dpi=140)
    n_groups = len(HARD_TASKS)
    width = 0.78 / max(1, len(plot_order))
    xs = list(range(n_groups))

    for i, policy in enumerate(plot_order):
        means = []
        stds = []
        for task in HARD_TASKS:
            samples = by_cell.get((task, policy), [])
            agg = aggregate(samples)
            means.append(agg["mean"])
            stds.append(agg["std"])
        # Centre each policy's bars around its task's x tick.
        offsets = [x + (i - (len(plot_order) - 1) / 2.0) * width for x in xs]
        bars = ax.bar(
            offsets,
            means,
            width=width,
            yerr=stds,
            capsize=4,
            color=color[policy],
            edgecolor="black",
            linewidth=0.6,
            hatch=hatch.get(policy, ""),
            label=pretty_name[policy],
        )
        for bar, mean in zip(bars, means):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 0.015,
                f"{mean:.2f}",
                ha="center",
                va="bottom",
                fontsize=8,
            )

    ax.set_xticks(xs)
    ax.set_xticklabels([t.replace("_brief", "") for t in HARD_TASKS])
    ax.set_ylim(0.0, 1.0)
    ax.set_ylabel("Terminal grader score (0..1)")
    ax.set_title(
        "Terminal reward, fallback disabled\n"
        "untrained baselines vs trained MLP CoS vs oracle router (upper bound)\n"
        "3 hard tasks, 5 seeds (RAG on/off averaged)"
    )
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.legend(loc="upper left", framealpha=0.9, fontsize=9)
    fig.tight_layout()

    png_path = plots_dir / "headline_terminal_reward.png"
    svg_path = plots_dir / "headline_terminal_reward.svg"
    fig.savefig(png_path)
    fig.savefig(svg_path)
    print(f"wrote {png_path.relative_to(REPO)}")
    print(f"wrote {svg_path.relative_to(REPO)}")


if __name__ == "__main__":
    main()
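
# Invocation sketch -- the script filename below is a guess. The
# ``Path(__file__).resolve().parents[2]`` line above assumes this file sits
# two levels below the repo root, e.g. in training/scripts/ next to
# train_cos_local.py:
#
#   python training/scripts/headline_benchmark.py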