#!/usr/bin/env python3
"""Run Chief-of-Staff (CoS) policies against the ``autodatalab_plus`` CEO-brief environment.

Available policies: the scripted oracle, an OpenAI-compatible LLM, two heuristic
baselines (``single`` and ``roundrobin``) and a locally trained ``PolicyNet``.
By default each episode prints ``[START]`` / ``[STEP]`` / ``[END]`` lines on
stdout; ``--export DIR`` writes one JSON file per task plus ``summary.json``
instead, and ``--ablation`` additionally scores the comparison policies and
writes ``cache/ablation_results.json``.
"""
from __future__ import annotations

import argparse
import hashlib
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Iterable, Optional

API_BASE_URL = os.getenv('API_BASE_URL', 'https://router.huggingface.co/v1')
MODEL_NAME = os.getenv('MODEL_NAME', 'Qwen/Qwen2.5-72B-Instruct')
API_KEY = os.getenv('API_KEY') or os.getenv('HF_TOKEN') or ''
BENCHMARK = 'autodatalab_plus'
TASKS = [
    t.strip()
    for t in os.getenv(
        'AUTODATALAB_PLUS_TASKS', 'easy_brief,medium_brief,hard_brief,expert_brief'
    ).split(',')
    if t.strip()
]

# Range every logged reward/score is clamped into; the terminal score returned
# by the episode runners is clamped to the same range.
_LOG_SCORE_MIN = 0.001
_LOG_SCORE_MAX = 0.999

# Make the in-repo packages importable when this file is run as a script.
REPO = Path(__file__).resolve().parent
if str(REPO) not in sys.path:
    sys.path.insert(0, str(REPO))

from ceo_brief_env.environment import (  # noqa: E402  (needs the sys.path tweak above)
    CEOBriefEnvironment,
    oracle_action_for_observation,
    required_experts_for_task,
)
from ceo_brief_env.models import CoSAction  # noqa: E402

# Every policy maps an environment observation to the next action.
Policy = Callable[[Any], CoSAction]

_CHECKPOINT_DIR = REPO / "training" / "checkpoints"


# --------------------------------------------------------------------------- logging


def _bool_str(v: bool) -> str:
    """Render *v* as lowercase ``true`` / ``false`` for the log lines."""
    return str(bool(v)).lower()


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode log."""
    print(f'[START] task={task} env={env} model={model}', flush=True)


def _clamp_log_reward(r: float) -> float:
    """Clamp *r* into ``[_LOG_SCORE_MIN, _LOG_SCORE_MAX]`` for logging."""
    return max(_LOG_SCORE_MIN, min(_LOG_SCORE_MAX, float(r)))


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line; a falsy *error* is rendered as ``null``."""
    err = error if error else 'null'
    r = _clamp_log_reward(reward)
    print(f'[STEP] step={step} action={action} reward={r:.2f} done={_bool_str(done)} error={err}', flush=True)


def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the ``[END]`` line with the clamped score and clamped per-step rewards."""
    s = _clamp_log_reward(score)
    rewards_str = ','.join(f'{_clamp_log_reward(r):.2f}' for r in rewards)
    print(f'[END] success={_bool_str(success)} steps={steps} score={s:.2f} rewards={rewards_str}', flush=True)


def _action_str(action: CoSAction) -> str:
    """Serialize *action* as compact, key-sorted JSON with ``None`` fields omitted."""
    return json.dumps(action.model_dump(exclude_none=True), separators=(',', ':'), sort_keys=True)


# --------------------------------------------------------------------------- heuristic baselines


def _consult_then_submit(obs, experts: Iterable[str]) -> CoSAction:
    """Consult each expert in *experts* (in order) once, then summarize once, then submit."""
    for expert in experts:
        if expert not in obs.consulted_experts:
            return CoSAction(action_type='consult', expert_id=expert)
    if obs.current_brief is None:
        return CoSAction(action_type='summarize')
    return CoSAction(action_type='submit')


def _single_baseline(obs) -> CoSAction:
    """Consult only the experts the task requires, then summarize and submit."""
    return _consult_then_submit(obs, required_experts_for_task(obs.task_name))


_ROUNDROBIN_EXPERTS = ('analyst', 'finance', 'strategy', 'hr')


def _roundrobin_baseline(obs) -> CoSAction:
    """Consult every expert in a fixed order, then summarize and submit."""
    return _consult_then_submit(obs, _ROUNDROBIN_EXPERTS)


# --------------------------------------------------------------------------- trained policy

# Lazily built trained policy, shared by every call of `_trained_action`.
_TRAINED_POLICY: dict[str, Any] = {"model": None, "load_status": None, "load_warned": False}


def _load_trained_policy():
    """Build the PolicyNet, load the best available checkpoint and cache it in `_TRAINED_POLICY`.

    Prefers ``cos_final.pt`` and falls back to ``cos_ckpt0.pt``; without either
    (or if loading fails) a randomly initialized network is used and a warning
    is printed to stderr.
    """
    from training.train_cos_local import PolicyNet, load_policy_state_dict_from_file

    ckpt = _CHECKPOINT_DIR / "cos_final.pt"
    if not ckpt.exists():
        ckpt = _CHECKPOINT_DIR / "cos_ckpt0.pt"
    model = PolicyNet()
    status = "random_init"
    if ckpt.exists():
        try:
            status = load_policy_state_dict_from_file(model, ckpt)
        except (OSError, RuntimeError, KeyError) as e:
            print(
                f"[inference] warning: could not load {ckpt} ({e}); using random init.",
                file=sys.stderr,
            )
            # A partial load may already have touched the weights; start clean.
            model = PolicyNet()
            status = f"load_failed: {e}"
    else:
        print(
            "[inference] warning: no checkpoint; using random PolicyNet. "
            "Run: python3 training/train_cos_local.py",
            file=sys.stderr,
        )
    # The isinstance guard keeps a non-string status from raising on `in`.
    if isinstance(status, str) and "padded" in status and not _TRAINED_POLICY.get("load_warned"):
        print(
            f"[inference] loaded checkpoint with compat: {status}. "
            "Re-run training for best results: python3 training/train_cos_local.py",
            file=sys.stderr,
        )
        _TRAINED_POLICY["load_warned"] = True
    model.eval()
    _TRAINED_POLICY["model"] = model
    _TRAINED_POLICY["load_status"] = status
    return model


def _trained_action(obs) -> CoSAction:
    """Return the arg-max action of the locally trained PolicyNet for *obs*."""
    import torch
    from training.train_cos_local import ACTIONS, featurize

    model = _TRAINED_POLICY["model"]
    if model is None:
        model = _load_trained_policy()
    feats = torch.from_numpy(featurize(obs)).unsqueeze(0)
    with torch.no_grad():
        logits = model(feats)
    idx = int(torch.argmax(logits, dim=-1).item())
    return ACTIONS[idx]


# --------------------------------------------------------------------------- LLM policy


def _cache_path(task_name: str, step_count: int, prompt: str) -> Path:
    """Return the cache file for one LLM decision, creating ``REPO/cache`` if needed.

    The file name embeds the task, the step and a 16-hex-char SHA-256 prefix of
    the prompt, so a changed prompt never reuses a stale answer.
    """
    prompt_hash = hashlib.sha256(prompt.encode('utf-8')).hexdigest()[:16]
    cache_dir = REPO / 'cache'
    cache_dir.mkdir(exist_ok=True)
    return cache_dir / f'{task_name}_step{step_count}_{prompt_hash}.json'


def _extract_json_object(text: str) -> Any:
    """Parse the span from the first ``{`` to the last ``}`` of *text*.

    Returns ``{}`` when there is no such span or it is not valid JSON.
    """
    start = text.find('{')
    end = text.rfind('}')
    if start == -1 or end < start:
        return {}
    try:
        return json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        return {}


def _payload_to_action(payload: Any) -> tuple[dict[str, Any], CoSAction]:
    """Validate *payload* as a `CoSAction`, falling back to a ``noop`` action.

    Returns the payload actually used together with the validated action, so
    the caller can cache exactly what the agent acted on.
    """
    if isinstance(payload, dict) and 'action_type' in payload:
        try:
            return payload, CoSAction.model_validate(payload)
        except ValueError:
            # Invalid field values; pydantic's ValidationError subclasses ValueError.
            pass
    fallback: dict[str, Any] = {'action_type': 'noop'}
    return fallback, CoSAction.model_validate(fallback)


def _llm_action(obs) -> CoSAction:
    """Ask the configured LLM for the next action, using an on-disk response cache.

    Cached answers (keyed by task, step and prompt hash) are replayed without
    contacting the API. Unparseable or invalid answers become ``noop``; only a
    payload that validated (or the ``noop`` fallback) is written to the cache.
    """
    prompt = (
        'You are the Chief of Staff in a multi-agent company simulation. '
        f'Task: {obs.task_name}. Instruction: {obs.instruction}. '
        f'Consulted experts: {obs.consulted_experts}. Current issues: {obs.issues}. '
        'Return exactly one JSON object with keys action_type, expert_id, sub_question_id, notes. '
        'Valid experts: analyst, finance, hr, strategy. Valid actions: consult, ask, summarize, submit, noop.'
    )
    cache_path = _cache_path(obs.task_name, obs.step_count, prompt)
    if cache_path.exists():
        try:
            cached = json.loads(cache_path.read_text(encoding='utf-8'))
        except json.JSONDecodeError:
            cached = {}
        return _payload_to_action(cached)[1]

    # Import/construct the client only on a cache miss so cached replays work offline.
    from openai import OpenAI

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0.1,
        messages=[
            {'role': 'system', 'content': 'Respond with strict JSON only.'},
            {'role': 'user', 'content': prompt},
        ],
    )
    text = completion.choices[0].message.content or '{}'
    payload, action = _payload_to_action(_extract_json_object(text))
    cache_path.write_text(json.dumps(payload, indent=2), encoding='utf-8')
    return action


# --------------------------------------------------------------------------- episodes


def run_episode_collect(
    task: str,
    picker: Policy,
    label: str,
    use_rag: bool = False,
    quiet: bool = False,
) -> dict[str, Any]:
    """
    Run one episode and return structured data (for --export).

    Same dynamics as `run_episode`. The episode stops when the environment
    reports ``done`` or after ``obs.max_steps`` actions. An exception raised by
    the policy or the environment ends the episode early: its message (newlines
    flattened) is stored under ``error`` and the terminal score stays at its
    floor value.

    Args:
        task: task name passed to ``CEOBriefEnvironment.reset``.
        picker: policy mapping an observation to a ``CoSAction``.
        label: policy name shown in the ``[START]`` line and stored in the result.
        use_rag: forwarded to ``reset`` (organizational memory on/off).
        quiet: suppress the ``[START]`` / ``[STEP]`` / ``[END]`` log lines.

    Returns:
        A dict with the clamped terminal score, success flag, rewards, a
        per-step trace and the final observation state.
    """
    env = CEOBriefEnvironment()
    obs = env.reset(task=task, use_rag=use_rag)
    rewards: list[float] = []
    trace: list[dict[str, Any]] = []
    steps = 0
    score = _LOG_SCORE_MIN
    success = False
    err: Optional[str] = None
    if not quiet:
        log_start(task=task, env=f'{BENCHMARK} rag={use_rag}', model=label)
    try:
        while not obs.done and steps < obs.max_steps:
            steps += 1
            action = picker(obs)
            obs = env.step(action)
            r = float(obs.reward)
            rewards.append(r)
            if not quiet:
                log_step(steps, _action_str(action), r, bool(obs.done), None)
            trace.append(
                {
                    "step": steps,
                    "action": action.model_dump(exclude_none=True),
                    "reward": round(r, 4),
                    "done": bool(obs.done),
                    "consulted_experts": list(obs.consulted_experts),
                }
            )
        # A missing (or zero) grader score falls back to the floor.
        score = float(obs.terminal_grader_score or _LOG_SCORE_MIN)
        score = max(_LOG_SCORE_MIN, min(_LOG_SCORE_MAX, score))
        success = score >= 0.5
    except Exception as exc:  # noqa: BLE001  (any failure ends the episode, recorded below)
        err = str(exc).replace('\n', ' ')
        if not quiet:
            log_step(max(steps, 1), 'exception', _LOG_SCORE_MIN, True, err)
    finally:
        if not quiet:
            log_end(success, steps, score, rewards)
    return {
        "task": task,
        "policy_label": label,
        "use_rag": use_rag,
        "success": success,
        "steps": steps,
        "terminal_score": round(score, 4),
        "cumulative_reward": round(sum(rewards), 4),
        "step_rewards": [round(x, 4) for x in rewards],
        "trace": trace,
        "error": err,
        "final_instruction": obs.instruction,
        "task_difficulty": obs.task_difficulty,
        "max_steps": obs.max_steps,
        "consulted_experts": list(obs.consulted_experts),
        "current_brief": obs.current_brief.model_dump() if obs.current_brief is not None else None,
        "expert_reports": {k: v.model_dump() for k, v in obs.expert_reports.items()},
    }


def run_episode(task: str, picker: Policy, label: str, use_rag: bool = False) -> float:
    """Run one logged episode and return its (clamped, rounded) terminal score."""
    out = run_episode_collect(task, picker, label, use_rag=use_rag, quiet=False)
    return float(out["terminal_score"])


# --------------------------------------------------------------------------- CLI


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--oracle', action='store_true')
    parser.add_argument('--baseline', choices=['single', 'roundrobin'])
    parser.add_argument(
        '--trained',
        action='store_true',
        help='use locally trained CoS policy (training/checkpoints/cos_final.pt)',
    )
    parser.add_argument('--task')
    parser.add_argument('--ablation', action='store_true')
    parser.add_argument(
        '--rag',
        action='store_true',
        help='enable organizational memory (RAG) in experts and grounding in the grader (default: off, team parity)',
    )
    parser.add_argument(
        '--export',
        metavar='DIR',
        help='save one JSON per task (plus summary.json) under DIR; uses all tasks from env unless --task is set',
    )
    return parser


def _default_policy(force_oracle: bool) -> tuple[Policy, str]:
    """Return the oracle when forced or when no API key is set, else the LLM policy."""
    if force_oracle or not API_KEY:
        return oracle_action_for_observation, 'oracle'
    return _llm_action, MODEL_NAME


def _select_policy(args: argparse.Namespace) -> tuple[Policy, str]:
    """Pick the policy and its label; ``--trained`` wins over ``--baseline``, which wins over the default."""
    if args.trained:
        return _trained_action, 'trained-cos'
    if args.baseline == 'single':
        return _single_baseline, 'single-baseline'
    if args.baseline == 'roundrobin':
        return _roundrobin_baseline, 'roundrobin-baseline'
    return _default_policy(args.oracle)


def _export(tasks: list[str], picker: Policy, label: str, use_rag: bool, export_dir: str) -> int:
    """Run every task quietly and write ``<task>.json`` plus ``summary.json`` under *export_dir*."""
    out_dir = Path(export_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    by_task: dict[str, dict[str, Any]] = {}
    for task in tasks:
        row = run_episode_collect(task, picker, label, use_rag=use_rag, quiet=True)
        by_task[task] = row
        (out_dir / f'{task}.json').write_text(json.dumps(row, indent=2, default=str), encoding='utf-8')
        print(
            f'[export] {task} terminal={row["terminal_score"]} steps={row["steps"]} success={row["success"]}',
            flush=True,
        )
    mean_term = float(sum(t['terminal_score'] for t in by_task.values()) / max(len(by_task), 1))
    summary: dict[str, Any] = {
        'policy_label': label,
        'use_rag': use_rag,
        'checkpoint_load': _TRAINED_POLICY.get('load_status') if label == 'trained-cos' else None,
        'tasks': {
            t: {
                'terminal_score': row['terminal_score'],
                'success': row['success'],
                'steps': row['steps'],
                'cumulative_reward': row['cumulative_reward'],
            }
            for t, row in by_task.items()
        },
        'mean_terminal': round(mean_term, 4),
        'export_dir': str(out_dir),
    }
    (out_dir / 'summary.json').write_text(json.dumps(summary, indent=2, default=str), encoding='utf-8')
    # One file per distinct task (duplicates overwrite each other) plus summary.json.
    print(f'[export] wrote {len(by_task) + 1} files to {out_dir}', flush=True)
    return 0


def _results_for(
    policy: Policy,
    label: str,
    tasks: list[str],
    use_rag: bool,
    primary: Policy,
    primary_results: dict[str, float],
) -> dict[str, float]:
    """Reuse the primary run's scores when *policy* is the primary policy, else run it on every task."""
    if policy is primary:
        return dict(primary_results)
    return {task: run_episode(task, policy, label, use_rag=use_rag) for task in tasks}


def _run_ablation(
    tasks: list[str],
    picker: Policy,
    results: dict[str, float],
    use_rag: bool,
    force_oracle: bool,
) -> None:
    """Score the comparison policies on *tasks* and write ``cache/ablation_results.json``.

    Keys: the ``single`` and ``roundrobin`` baselines, ``oracle_or_llm`` (the
    policy chosen when neither ``--trained`` nor ``--baseline`` is given) and,
    only when ``cos_final.pt`` exists, ``trained_cos``.
    """
    default_picker, default_label = _default_policy(force_oracle)
    ablations: dict[str, dict[str, float]] = {
        'single': _results_for(_single_baseline, 'single-baseline', tasks, use_rag, picker, results),
        'roundrobin': _results_for(_roundrobin_baseline, 'roundrobin-baseline', tasks, use_rag, picker, results),
        # Run the oracle/LLM itself unless it was the primary run, so scores from
        # --trained/--baseline are never filed under this key.
        'oracle_or_llm': _results_for(default_picker, default_label, tasks, use_rag, picker, results),
    }
    if (_CHECKPOINT_DIR / 'cos_final.pt').exists():
        ablations['trained_cos'] = _results_for(_trained_action, 'trained-cos', tasks, use_rag, picker, results)
    cache_dir = REPO / 'cache'
    cache_dir.mkdir(exist_ok=True)
    (cache_dir / 'ablation_results.json').write_text(json.dumps(ablations, indent=2), encoding='utf-8')


def main() -> int:
    """CLI entry point; returns the process exit code."""
    args = _build_parser().parse_args()
    use_rag = bool(args.rag)
    picker, label = _select_policy(args)
    tasks = [args.task] if args.task else TASKS

    if args.export:
        return _export(tasks, picker, label, use_rag, args.export)

    results = {task: run_episode(task, picker, label, use_rag=use_rag) for task in tasks}
    if args.ablation:
        _run_ablation(tasks, picker, results, use_rag, args.oracle)
    return 0


if __name__ == '__main__':
    raise SystemExit(main())