| """Headline benchmark with the eval-mode safety net OFF. | |
| Compares four routing policies on the three hard tasks, with five seeds and | |
| both RAG settings. The safety net is OFF (``auto_fill_required=False``) so an | |
| incomplete brief gets penalised by the grader. | |
| Policies | |
| -------- | |
| - ``base_naive``: untrained baseline, consults the analyst then submits. | |
| - ``base_roundrobin``: untrained second baseline, walks experts in fixed order. | |
| - ``trained_mlp``: actually trained CoS policy. A 2-layer MLP routing policy | |
| (REINFORCE, 600 episodes, lr 0.003) loaded from | |
| ``training/checkpoints/cos_final.pt``. This is the headline trained-model | |
| number. If torch / the checkpoint isn't available the row is skipped and the | |
| plot annotates that. | |
| - ``oracle_router``: deterministic upper bound. The handcoded routing policy | |
| that always consults the required experts in the canonical order. We label | |
| this an *upper bound*, not a trained model -- our trained policies (the MLP | |
| above and the SFT/GRPO LLMs) are trained to imitate this behaviour, and the | |
| number tells you how high a perfect routing policy can score on the | |
| current grader / RAG settings. | |
| Outputs | |
| ------- | |
| - ``training/evidence/headline_benchmark.json`` -- raw cells + per-seed runs. | |
| - ``training/evidence/plots/headline_terminal_reward.{png,svg}`` -- the chart. | |
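
Usage
-----
The repo root is resolved from this file's own location (two directories up),
so the script is assumed to live two levels below the repo root, e.g. under
``training/scripts/``. A typical invocation (hypothetical file name)::

    python training/scripts/headline_benchmark.py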
| """ | |
from __future__ import annotations

import json
import random
import statistics
import sys
from pathlib import Path
from typing import Any, Callable

REPO = Path(__file__).resolve().parents[2]
if str(REPO) not in sys.path:
    sys.path.insert(0, str(REPO))

from ceo_brief_env.environment import (  # noqa: E402
    CEOBriefEnvironment,
    oracle_action_for_observation,
)
from ceo_brief_env.models import CoSAction, CoSObservation  # noqa: E402

HARD_TASKS = ["hard_brief", "expert_brief", "crisis_brief"]
SEEDS = [11, 23, 47, 91, 137]
RAG_SETTINGS = [False, True]
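# Full grid: 3 tasks x 2 RAG settings x up to 4 policies x 5 seeds
# = up to 120 episodes per run of this script.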


def naive_picker(obs: CoSObservation) -> CoSAction:
    """Untrained base policy: consult the analyst, then submit (deliberately incomplete)."""
    if "analyst" not in obs.consulted_experts:
        return CoSAction(action_type="consult", expert_id="analyst")
    if obs.current_brief is None:
        return CoSAction(action_type="summarize")
    return CoSAction(action_type="submit")


def roundrobin_picker(obs: CoSObservation) -> CoSAction:
    """Second base policy: walk the experts in a fixed order, then submit."""
    for expert in ["finance", "analyst", "hr", "strategy"]:
        if expert not in obs.consulted_experts:
            return CoSAction(action_type="consult", expert_id=expert)
    if obs.current_brief is None:
        return CoSAction(action_type="summarize")
    return CoSAction(action_type="submit")


def _load_trained_mlp_picker() -> tuple[Callable[[CoSObservation], CoSAction], str] | None:
    """Load the actually-trained MLP CoS routing policy.

    Returns ``(picker_fn, info)``, or ``None`` if torch / the checkpoint isn't
    available. The MLP was trained with REINFORCE for 600 episodes (see
    ``training/scripts/train_cos_local.py``); ``cos_final.pt`` is the resulting
    state dict.
    """
    try:
        import numpy as np  # noqa: F401 -- probe that numpy is importable
        import torch

        sys.path.insert(0, str(REPO / "training" / "scripts"))
        from train_cos_local import (  # type: ignore
            ACTIONS,
            PolicyNet,
            featurize,
            load_policy_state_dict_from_file,
        )
    except Exception:  # pragma: no cover -- env without torch / the training module
        return None

    ckpt_paths = [
        REPO / "training" / "checkpoints" / "cos_final.pt",
        REPO / "training" / "checkpoints" / "cos_ckpt0.pt",
    ]
    ckpt = next((p for p in ckpt_paths if p.exists()), None)
    if ckpt is None:
        return None

    model = PolicyNet()
    info = load_policy_state_dict_from_file(model, ckpt)
    model.eval()

    def picker(obs: CoSObservation) -> CoSAction:
        # Greedy decode: argmax over action logits rather than sampling.
        feats = torch.from_numpy(featurize(obs)).unsqueeze(0)
        with torch.no_grad():
            logits = model(feats)
        idx = int(torch.argmax(logits, dim=-1).item())
        return ACTIONS[idx]

    return picker, f"{ckpt.name} ({info})"
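

# Interface assumed from ``train_cos_local`` (inferred from the call sites in
# ``_load_trained_mlp_picker``; the authoritative definitions live in that
# module):
#   featurize(obs: CoSObservation) -> np.ndarray   # fixed-size feature vector
#   PolicyNet()                                    # nn.Module: features -> len(ACTIONS) logits
#   ACTIONS: list[CoSAction]                       # discrete action table
#   load_policy_state_dict_from_file(model, path)  # loads weights, returns an info string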


def build_policies() -> tuple[dict[str, Callable[[CoSObservation], CoSAction]], dict[str, str]]:
    """Return the policy table plus a metadata dict for the plot legend."""
    policies: dict[str, Callable[[CoSObservation], CoSAction]] = {
        "base_naive": naive_picker,
        "base_roundrobin": roundrobin_picker,
    }
    info: dict[str, str] = {
        "base_naive": "untrained baseline (analyst-only)",
        "base_roundrobin": "untrained baseline (fixed order)",
    }
    trained = _load_trained_mlp_picker()
    if trained is not None:
        policies["trained_mlp"] = trained[0]
        info["trained_mlp"] = f"trained MLP CoS (REINFORCE, 600 ep) - {trained[1]}"
    policies["oracle_router"] = oracle_action_for_observation
    info["oracle_router"] = "oracle router (upper bound, handcoded canonical sequence)"
    return policies, info


def run_episode(
    picker: Callable[[CoSObservation], CoSAction], task: str, use_rag: bool, seed: int
) -> dict[str, Any]:
    random.seed(seed)
    env = CEOBriefEnvironment(shaping="strict", auto_fill_required=False)
    obs = env.reset(task=task, use_rag=use_rag)
    max_steps = obs.max_steps
    steps = 0
    consulted_path: list[str] = []
    while not obs.done and steps < max_steps:
        steps += 1
        action = picker(obs)
        if action.action_type == "consult" and action.expert_id:
            consulted_path.append(action.expert_id)
        obs = env.step(action)
    terminal = float(obs.terminal_grader_score or 0.0)
    cumulative = float(obs.reward_breakdown.cumulative if obs.reward_breakdown else 0.0)
    return {
        "terminal": terminal,
        "cumulative": cumulative,
        "steps": steps,
        "consulted": list(obs.consulted_experts),
        "path": consulted_path,
        "submitted": bool(obs.done),
    }
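

# Quick manual smoke test (hypothetical snippet, not run by this script):
#
#     rollout = run_episode(oracle_action_for_observation, "hard_brief", use_rag=False, seed=11)
#     print(rollout["terminal"], rollout["path"])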


def aggregate(samples: list[float]) -> dict[str, float]:
    if not samples:
        return {"mean": 0.0, "std": 0.0, "n": 0}
    if len(samples) == 1:
        return {"mean": samples[0], "std": 0.0, "n": 1}
    return {
        "mean": statistics.fmean(samples),
        "std": statistics.pstdev(samples),
        "n": len(samples),
    }
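

# Note: ``statistics.pstdev`` is the *population* standard deviation (divides
# by n, not n - 1), so the error bars describe the spread over exactly these
# seeds. For example, aggregate([0.2, 0.4, 0.6]) -> mean 0.4, std ~0.163, n 3.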


def main() -> None:
    out_dir = REPO / "training" / "evidence"
    out_dir.mkdir(parents=True, exist_ok=True)
    plots_dir = out_dir / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    policies, policy_info = build_policies()
    results: dict[str, Any] = {
        "schema": "autodatalab-plus.headline_benchmark.v2",
        "config": {
            "tasks": HARD_TASKS,
            "policies": list(policies.keys()),
            "policy_info": policy_info,
            "seeds": SEEDS,
            "rag_settings": RAG_SETTINGS,
            "auto_fill_required": False,
            "shaping": "strict",
        },
        "cells": [],
    }
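    # Each cell appended below has this shape (per-run fields come from
    # run_episode):
    #   {"task", "policy", "use_rag",
    #    "terminal": {"mean", "std", "n"}, "cumulative": {"mean", "std", "n"},
    #    "runs": [{"seed", "terminal", "cumulative", "steps",
    #              "consulted", "path", "submitted"}, ...]}
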
    for task in HARD_TASKS:
        for use_rag in RAG_SETTINGS:
            for policy_name, policy_fn in policies.items():
                terminals: list[float] = []
                cumulatives: list[float] = []
                runs: list[dict[str, Any]] = []
                for seed in SEEDS:
                    rollout = run_episode(policy_fn, task, use_rag, seed)
                    terminals.append(rollout["terminal"])
                    cumulatives.append(rollout["cumulative"])
                    runs.append({"seed": seed, **rollout})
                cell = {
                    "task": task,
                    "policy": policy_name,
                    "use_rag": use_rag,
                    "terminal": aggregate(terminals),
                    "cumulative": aggregate(cumulatives),
                    "runs": runs,
                }
                results["cells"].append(cell)
                print(
                    f"task={task:13s} rag={str(use_rag):5s} policy={policy_name:18s} "
                    f"terminal_mean={cell['terminal']['mean']:.3f} "
                    f"std={cell['terminal']['std']:.3f}"
                )

    json_path = out_dir / "headline_benchmark.json"
    json_path.write_text(json.dumps(results, indent=2))
    print(f"\nwrote {json_path.relative_to(REPO)}")

    try:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not available; skipping money plot")
        return
    by_cell: dict[tuple[str, str], list[float]] = {}
    for cell in results["cells"]:
        key = (cell["task"], cell["policy"])
        by_cell.setdefault(key, []).extend([r["terminal"] for r in cell["runs"]])

    plot_order = [
        p
        for p in ("base_naive", "base_roundrobin", "trained_mlp", "oracle_router")
        if p in policies
    ]
    pretty_name = {
        "base_naive": "base (naive)",
        "base_roundrobin": "base (round-robin)",
        "trained_mlp": "trained MLP CoS\n(REINFORCE, 600 ep)",
        "oracle_router": "oracle router\n(upper bound)",
    }
    color = {
        "base_naive": "#9aa0a6",
        "base_roundrobin": "#bdc1c6",
        "trained_mlp": "#34a853",
        "oracle_router": "#1a73e8",
    }
    hatch = {
        "oracle_router": "//",
    }

    fig, ax = plt.subplots(figsize=(10.4, 5.0), dpi=140)
    n_groups = len(HARD_TASKS)
    width = 0.78 / max(1, len(plot_order))
    xs = list(range(n_groups))
    for i, policy in enumerate(plot_order):
        means = []
        stds = []
        for task in HARD_TASKS:
            samples = by_cell.get((task, policy), [])
            agg = aggregate(samples)
            means.append(agg["mean"])
            stds.append(agg["std"])
        # Center each policy's bar within the task group.
        offsets = [x + (i - (len(plot_order) - 1) / 2.0) * width for x in xs]
        bars = ax.bar(
            offsets,
            means,
            width=width,
            yerr=stds,
            capsize=4,
            color=color[policy],
            edgecolor="black",
            linewidth=0.6,
            hatch=hatch.get(policy, ""),
            label=pretty_name[policy],
        )
        for bar, mean in zip(bars, means):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 0.015,
                f"{mean:.2f}",
                ha="center",
                va="bottom",
                fontsize=8,
            )
    ax.set_xticks(xs)
    ax.set_xticklabels([t.replace("_brief", "") for t in HARD_TASKS])
    ax.set_ylim(0.0, 1.0)
    ax.set_ylabel("Terminal grader score (0..1)")
    ax.set_title(
        "Terminal reward, fallback disabled\n"
        "untrained baselines vs trained MLP CoS vs oracle router (upper bound)\n"
        "3 hard tasks, 5 seeds (RAG on/off averaged)"
    )
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.legend(loc="upper left", framealpha=0.9, fontsize=9)
    if "trained_mlp" not in policies:
        # Keep the module docstring's promise: annotate the plot when the
        # trained row was skipped (torch / checkpoint unavailable).
        ax.text(
            0.99,
            0.02,
            "trained_mlp row skipped (torch / checkpoint unavailable)",
            transform=ax.transAxes,
            ha="right",
            va="bottom",
            fontsize=8,
        )
    fig.tight_layout()

    png_path = plots_dir / "headline_terminal_reward.png"
    svg_path = plots_dir / "headline_terminal_reward.svg"
    fig.savefig(png_path)
    fig.savefig(svg_path)
    print(f"wrote {png_path.relative_to(REPO)}")
    print(f"wrote {svg_path.relative_to(REPO)}")


if __name__ == "__main__":
    main()