#!/usr/bin/env python3
"""
Kaggle GPU trainer for AutoDataLab++ CoS routing policies on Qwen2.5-1.5B.

Goal
----
Try many lightweight training methods quickly, evaluate the *actual routing
trajectory*, and keep only the methods that improve behavior.

Recommended Kaggle setup
------------------------
Settings:
- Accelerator: GPU (a T4 is enough for Qwen2.5-1.5B QLoRA)
- Internet: On

Install cell:
    !pip install -q -U "transformers>=4.45,<4.49" "accelerate>=0.33,<1.1" \\
        "peft>=0.13,<0.16" "bitsandbytes>=0.45.0" "trl>=0.11,<0.13" \\
        "datasets>=2.20" "huggingface_hub>=0.24,<1.0" pydantic pandas matplotlib

Run examples:
    # 1) Fast supervised routing policy
    !python3 training/kaggle_train_1p5b_methods.py --method sft --epochs 2

    # 2) Preference optimization from the base model
    !python3 training/kaggle_train_1p5b_methods.py --method dpo --epochs 1

    # 3) Best practical path: SFT first, then DPO on top
    !python3 training/kaggle_train_1p5b_methods.py --method sft_then_dpo --sft-epochs 2 --dpo-epochs 1

Outputs
-------
Each run writes an adapter plus evidence files under
/kaggle/working/cos_1p5b_runs/<run_name>:
- adapter/              PEFT LoRA adapter
- eval/evidence.json    per-task trajectory evidence
- eval/evidence.md      PPT/report-ready table
- eval/reward_curve.png cumulative reward plot

Important
---------
This script produces *training evidence*. The evaluator runs with:
- strict reward shaping
- auto_fill_required=False during model-controlled policy steps

So if a policy tries `summarize -> summarize`, it is penalized and the output
shows the failure clearly.
"""
from __future__ import annotations

import argparse
import gc
import json
import os
import random
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
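
# Keep torch.compile/TorchDynamo off by default: compilation buys nothing for
# short QLoRA runs and (assumption) has been flaky on some Kaggle T4 images.
# TOKENIZERS_PARALLELISM=false silences the HF tokenizers fork warning.
# setdefault() respects any values the notebook has already exported.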
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

REPO = Path(__file__).resolve().parents[2]
if str(REPO) not in sys.path:
    sys.path.insert(0, str(REPO))

from ceo_brief_env.environment import CEOBriefEnvironment, oracle_action_for_observation, required_experts_for_task
from ceo_brief_env.models import CoSAction, CoSObservation

TASKS = ["easy_brief", "medium_brief", "hard_brief", "expert_brief", "risk_brief", "crisis_brief"]
VALID_ACTIONS = {"consult", "ask", "summarize", "submit", "noop"}
VALID_EXPERTS = {"analyst", "finance", "hr", "strategy"}
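
# A deliberately narrow JSON matcher: the action schema is a single flat object,
# so matching the first brace pair with no nested braces is sufficient and
# avoids over-matching chatty completions.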
JSON_RE = re.compile(r"\{[^{}]*\}", re.S)

SYSTEM_PROMPT = (
    "You are the Chief of Staff in AutoDataLab++. You orchestrate four specialists: "
    "analyst, finance, strategy, hr. Reply with STRICT JSON only.\n"
    'Schema: {"action_type": one of [consult, ask, summarize, submit, noop], '
    '"expert_id": one of [analyst, finance, hr, strategy] or null}.\n'
    "Rules: consult each required expert exactly once when required, then summarize, then submit. "
    "Never summarize while required experts are missing. Never repeat summarize."
)

def set_seed(seed: int) -> None:
    random.seed(seed)
    try:
        import numpy as np
        import torch

        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except Exception:
        pass

def action_json(action: CoSAction) -> str:
    return json.dumps(action.model_dump(exclude_none=True), separators=(",", ":"), sort_keys=True)
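
# The canonical form (sorted keys, compact separators) matters later:
# rejected_actions_for() dedupes candidate negatives against the chosen action
# by comparing these strings, so equivalent actions always collide.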

def render_obs(obs: CoSObservation, variant: int = 0) -> str:
    required = required_experts_for_task(obs.task_name)
    missing = [e for e in required if e not in obs.consulted_experts]
    base = {
        "task": obs.task_name,
        "difficulty": obs.task_difficulty,
        "step": obs.step_count,
        "max_steps": obs.max_steps,
        "rag_enabled": obs.rag_enabled,
        "required_experts": required,
        "consulted_experts": list(obs.consulted_experts),
        "missing_required_experts": missing,
        "brief_ready": obs.current_brief is not None,
        "issues": obs.issues[-3:],
    }
    if variant == 0:
        return json.dumps(base, separators=(",", ":"))
    if variant == 1:
        return (
            f"task={obs.task_name}; step={obs.step_count}/{obs.max_steps}; "
            f"required={required}; consulted={obs.consulted_experts}; missing={missing}; "
            f"brief_ready={obs.current_brief is not None}; choose next JSON action."
        )
    return (
        "Decision checklist:\n"
        f"- task: {obs.task_name}\n"
        f"- required experts: {', '.join(required)}\n"
        f"- consulted experts: {', '.join(obs.consulted_experts) or 'none'}\n"
        f"- missing required experts: {', '.join(missing) or 'none'}\n"
        f"- brief ready: {obs.current_brief is not None}\n"
        "Return the single next strict JSON action."
    )
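
# The three render_obs variants are the same observation in three surface forms
# (compact JSON, one-line key=value, bullet checklist); training across all of
# them is a cheap prompt-robustness augmentation. Illustrative variant-0 output
# (field values invented for the example):
#   {"task":"risk_brief","step":1,"max_steps":8,"missing_required_experts":["finance"],...}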

def messages_for(obs: CoSObservation, variant: int, assistant: str | None = None) -> list[dict[str, str]]:
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": render_obs(obs, variant=variant)},
    ]
    if assistant is not None:
        msgs.append({"role": "assistant", "content": assistant})
    return msgs

def rejected_actions_for(obs: CoSObservation, chosen: CoSAction) -> list[CoSAction]:
    required = required_experts_for_task(obs.task_name)
    missing = [e for e in required if e not in obs.consulted_experts]
    rejects: list[CoSAction] = []
    if missing:
        rejects.extend(
            [
                CoSAction(action_type="summarize"),
                CoSAction(action_type="submit"),
                CoSAction(action_type="noop"),
            ]
        )
        if obs.consulted_experts:
            rejects.append(CoSAction(action_type="consult", expert_id=obs.consulted_experts[0]))
    elif obs.current_brief is None:
        rejects.extend(
            [
                CoSAction(action_type="consult", expert_id=required[0] if required else "analyst"),
                CoSAction(action_type="submit"),
                CoSAction(action_type="noop"),
            ]
        )
    else:
        rejects.extend(
            [
                CoSAction(action_type="summarize"),
                CoSAction(action_type="consult", expert_id=required[0] if required else "analyst"),
                CoSAction(action_type="noop"),
            ]
        )
    chosen_s = action_json(chosen)
    out = []
    seen = {chosen_s}
    for r in rejects:
        s = action_json(r)
        if s not in seen:
            seen.add(s)
            out.append(r)
    return out

def collect_routing_records(tasks: list[str], rag_modes: list[bool], variants: int = 3) -> tuple[list[dict], list[dict]]:
    """Return (sft_records, dpo_records)."""
    sft_records: list[dict[str, Any]] = []
    dpo_records: list[dict[str, Any]] = []
    for task in tasks:
        for use_rag in rag_modes:
            env = CEOBriefEnvironment(auto_fill_required=False)
            obs = env.reset(task=task, use_rag=use_rag)
            while not obs.done and obs.step_count < obs.max_steps:
                chosen = oracle_action_for_observation(obs)
                chosen_s = action_json(chosen)
                for variant in range(variants):
                    prompt_msgs = messages_for(obs, variant=variant)
                    sft_records.append({"messages": prompt_msgs + [{"role": "assistant", "content": chosen_s}]})
                    for rejected in rejected_actions_for(obs, chosen)[:2]:
                        dpo_records.append(
                            {
                                "prompt": prompt_msgs,
                                "chosen": chosen_s,
                                "rejected": action_json(rejected),
                            }
                        )
                obs = env.step(chosen)
    return sft_records, dpo_records
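
# collect_routing_records() yields records shaped like (illustrative):
#   SFT: {"messages": [system, user, assistant-with-oracle-JSON]}
#   DPO: {"prompt": [system, user], "chosen": <oracle JSON>, "rejected": <wrong-but-plausible JSON>}
# Both datasets come from the same oracle rollouts, so SFT and DPO methods are
# compared on identical states.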

def parse_action(text: str) -> CoSAction:
    match = JSON_RE.search(text or "")
    if not match:
        return CoSAction(action_type="noop")
    try:
        payload = json.loads(match.group(0))
    except Exception:
        return CoSAction(action_type="noop")
    action_type = payload.get("action_type")
    if action_type not in VALID_ACTIONS:
        return CoSAction(action_type="noop")
    expert_id = payload.get("expert_id")
    if expert_id is not None and expert_id not in VALID_EXPERTS:
        expert_id = None
    return CoSAction(action_type=action_type, expert_id=expert_id)
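
# parse_action behavior (illustrative):
#   '{"action_type":"consult","expert_id":"finance"}'  -> consult finance
#   'sure! {"action_type":"summarize"}'                -> summarize
#   'I think we should consult finance'                -> noop (no JSON found)
# Malformed output degrades to a harmless noop instead of raising, so one bad
# completion never aborts an eval rollout.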

def action_label(action: CoSAction) -> str:
    if action.action_type in {"consult", "ask"}:
        return f"{action.action_type}:{action.expert_id or 'null'}"
    return action.action_type

@dataclass
class LoadedModel:
    tokenizer: Any
    model: Any

def load_base_with_lora(model_id: str, use_4bit: bool, hf_token: str | None):
    import torch
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tok = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    bnb = None
    if use_4bit:
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        device_map="auto",
        quantization_config=bnb,
        torch_dtype=torch.float16,
    )
    if use_4bit:
        model = prepare_model_for_kbit_training(model)
    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    model = get_peft_model(model, lora_cfg)
    model.config.use_cache = False
    return LoadedModel(tok, model)

def format_chat(tokenizer, messages: list[dict[str, str]], add_generation_prompt: bool = False) -> str:
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=add_generation_prompt
        )
    text = ""
    for m in messages:
        text += f"{m['role'].upper()}:\n{m['content']}\n"
    if add_generation_prompt:
        text += "ASSISTANT:\n"
    return text
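
# The plain-text branch is only a safety net: Qwen2.5 tokenizers ship a chat
# template, so in practice apply_chat_template() formats both training and eval
# prompts, keeping the two consistent.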

class SFTTorchDataset:
    def __init__(self, records: list[dict], tokenizer, max_length: int = 1024):
        self.records = records
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        rec = self.records[idx]
        messages = rec["messages"]
        prompt = format_chat(self.tokenizer, messages[:-1], add_generation_prompt=True)
        answer = messages[-1]["content"] + (self.tokenizer.eos_token or "")
        full = prompt + answer
        enc = self.tokenizer(full, truncation=True, max_length=self.max_length)
        prompt_ids = self.tokenizer(prompt, truncation=True, max_length=self.max_length)["input_ids"]
        labels = list(enc["input_ids"])
        mask_n = min(len(prompt_ids), len(labels))
        labels[:mask_n] = [-100] * mask_n
        return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"], "labels": labels}

def train_sft(args, out_dir: Path) -> Path:
    import torch
    from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments

    sft_records, _dpo_records = collect_routing_records(TASKS, [False, True], variants=args.variants)
    if args.max_train_examples:
        sft_records = sft_records[: args.max_train_examples]
    print(f"[data] SFT examples={len(sft_records)}", flush=True)
    loaded = load_base_with_lora(args.model_id, use_4bit=not args.no_4bit, hf_token=args.hf_token or None)
    train_ds = SFTTorchDataset(sft_records, loaded.tokenizer, max_length=args.max_length)
    collator = DataCollatorForSeq2Seq(loaded.tokenizer, model=loaded.model, padding=True)
    train_args = TrainingArguments(
        output_dir=str(out_dir / "trainer"),
        logging_dir=str(out_dir / "trainer" / "logs"),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        logging_steps=5,
        save_strategy="no",
        report_to=args.report_to,
        run_name=args.run_name or f"{args.method}_qwen15b",
        fp16=True,
        optim="paged_adamw_8bit" if not args.no_4bit else "adamw_torch",
        max_grad_norm=0.3,
        warmup_ratio=0.03,
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=loaded.model,
        args=train_args,
        train_dataset=train_ds,
        data_collator=collator,
    )
    trainer.train()
    adapter_dir = out_dir / "adapter"
    loaded.model.save_pretrained(adapter_dir)
    loaded.tokenizer.save_pretrained(adapter_dir)
    del loaded
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return adapter_dir
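
# train_dpo supports two starting points: a fresh LoRA on the base model
# (--method dpo) or stacking preference optimization on top of a finished SFT
# adapter (--method sft_then_dpo), the path the module docstring recommends.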

def train_dpo(args, out_dir: Path, base_adapter: Path | None = None) -> Path:
    import torch
    from datasets import Dataset
    from peft import LoraConfig, PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from trl import DPOConfig, DPOTrainer

    _sft_records, dpo_records = collect_routing_records(TASKS, [False, True], variants=args.variants)
    if args.max_train_examples:
        dpo_records = dpo_records[: args.max_train_examples]
    print(f"[data] DPO pairs={len(dpo_records)}", flush=True)
    tok_source = str(base_adapter) if base_adapter else args.model_id
    tok = AutoTokenizer.from_pretrained(tok_source, token=args.hf_token or None)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    def to_row(rec: dict[str, Any]) -> dict[str, str]:
        return {
            "prompt": format_chat(tok, rec["prompt"], add_generation_prompt=True),
            "chosen": rec["chosen"] + (tok.eos_token or ""),
            "rejected": rec["rejected"] + (tok.eos_token or ""),
        }

    ds = Dataset.from_list([to_row(r) for r in dpo_records])
    bnb = None
    if not args.no_4bit:
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        token=args.hf_token or None,
        device_map="auto",
        quantization_config=bnb,
        torch_dtype=torch.float16,
    )
    model.config.use_cache = False
    if base_adapter:
        model.resize_token_embeddings(len(tok))
        model = PeftModel.from_pretrained(model, str(base_adapter), is_trainable=True)
    else:
        lora_cfg = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )
        model.add_adapter(lora_cfg)
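    # ref_model=None: when the policy is a PEFT model, TRL scores the reference
    # log-probs by temporarily disabling the adapter rather than keeping a
    # second full copy of the weights, which is what lets this fit on one T4.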
    cfg = DPOConfig(
        output_dir=str(out_dir / "trainer"),
        logging_dir=str(out_dir / "trainer" / "logs"),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.dpo_epochs if args.method == "sft_then_dpo" else args.epochs,
        learning_rate=args.dpo_lr,
        logging_steps=5,
        save_strategy="no",
        report_to=args.report_to,
        run_name=args.run_name or f"{args.method}_qwen15b",
        fp16=True,
        beta=args.dpo_beta,
        max_length=args.max_length,
        max_prompt_length=min(768, args.max_length - 64),
        remove_unused_columns=False,
    )
    trainer = DPOTrainer(model=model, ref_model=None, args=cfg, train_dataset=ds, tokenizer=tok)
    trainer.train()
    adapter_dir = out_dir / "adapter"
    model.save_pretrained(adapter_dir)
    tok.save_pretrained(adapter_dir)
    del trainer, model, tok
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return adapter_dir

def load_for_eval(model_id: str, adapter_dir: Path, use_4bit: bool, hf_token: str | None):
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tok = AutoTokenizer.from_pretrained(adapter_dir, token=hf_token)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    bnb = None
    if use_4bit:
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        device_map="auto",
        quantization_config=bnb,
        torch_dtype=torch.float16,
    )
    model.resize_token_embeddings(len(tok))
    model = PeftModel.from_pretrained(model, adapter_dir)
    model.eval()
    return tok, model

def generate_action(tok, model, obs: CoSObservation, max_new_tokens: int = 48) -> tuple[CoSAction, str]:
    import torch

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": render_obs(obs, variant=0)},
    ]
    prompt = format_chat(tok, messages, add_generation_prompt=True)
    ids = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tok.pad_token_id or tok.eos_token_id,
        )
    text = tok.decode(out[0][ids["input_ids"].shape[1] :], skip_special_tokens=True)
    return parse_action(text), text.strip()
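
# Greedy decoding (do_sample=False) makes eval trajectories deterministic for a
# given adapter, so runs over the same tasks are directly comparable.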

def deterministic_next(obs: CoSObservation, task: str) -> CoSAction:
    missing = [e for e in required_experts_for_task(task) if e not in obs.consulted_experts]
    if missing:
        return CoSAction(action_type="consult", expert_id=missing[0])
    if obs.current_brief is None:
        return CoSAction(action_type="summarize")
    return CoSAction(action_type="submit")

def evaluate_adapter(args, adapter_dir: Path, eval_dir: Path) -> list[dict[str, Any]]:
    import matplotlib.pyplot as plt
    import torch

    tok, model = load_for_eval(args.model_id, adapter_dir, use_4bit=not args.no_4bit, hf_token=args.hf_token or None)
    rows: list[dict[str, Any]] = []
    rag_modes = [s.strip().lower() in {"1", "true", "yes", "rag"} for s in args.eval_rag_modes.split(",")]
    for use_rag in rag_modes:
        for task in args.eval_tasks.split(","):
            task = task.strip()
            if not task:
                continue
            env = CEOBriefEnvironment(shaping="strict", auto_fill_required=False)
            obs = env.reset(task=task, use_rag=use_rag)
            rewards: list[float] = []
            trace: list[dict[str, Any]] = []
            routed: list[str] = []
            for _ in range(args.policy_steps):
                if obs.done:
                    break
                action, completion = generate_action(tok, model, obs, max_new_tokens=args.eval_new_tokens)
                obs = env.step(action)
                rewards.append(float(obs.reward))
                if action.expert_id in required_experts_for_task(task) and action.expert_id not in routed:
                    routed.append(action.expert_id)
                trace.append(
                    {
                        "step": obs.step_count,
                        "action": action.model_dump(exclude_none=True),
                        "action_label": action_label(action),
                        "completion_preview": completion[:300],
                        "reward": round(float(obs.reward), 4),
                        "consulted_after": list(obs.consulted_experts),
                        "model_routed_required": list(routed),
                    }
                )
            fallback: list[str] = []
            while not obs.done and obs.step_count < obs.max_steps:
                act = deterministic_next(obs, task)
                obs = env.step(act)
                rewards.append(float(obs.reward))
                fallback.append(action_label(act))
            rows.append(
                {
                    "task": task,
                    "rag": bool(use_rag),
                    "action_sequence": [t["action_label"] for t in trace],
                    "model_routed_required": routed,
                    "required_experts": required_experts_for_task(task),
                    "fallback": fallback,
                    "needed_fallback": bool(fallback),
                    "policy_reward": round(sum(t["reward"] for t in trace), 4),
                    "total_reward": round(sum(rewards), 4),
                    "terminal_score": round(float(obs.terminal_grader_score or 0.0), 4),
                    "trace": trace,
                }
            )
    eval_dir.mkdir(parents=True, exist_ok=True)
    (eval_dir / "evidence.json").write_text(json.dumps(rows, indent=2, default=str), encoding="utf-8")
    md = [
        "# AutoDataLab++ 1.5B Training Evidence",
        "",
        "| Task | RAG | Action sequence | Routed required experts | Needed fallback | Policy reward | Terminal |",
        "|---|---:|---|---|---:|---:|---:|",
    ]
    for row in rows:
        md.append(
            f"| {row['task']} | {row['rag']} | `{' -> '.join(row['action_sequence'])}` | "
            f"{', '.join(row['model_routed_required']) or '-'} | {row['needed_fallback']} | "
            f"{row['policy_reward']} | {row['terminal_score']} |"
        )
    (eval_dir / "evidence.md").write_text("\n".join(md), encoding="utf-8")
    plt.figure(figsize=(8, 4))
    for row in rows:
        cum = []
        total = 0.0
        # plot policy-step rewards only, because that is the learning evidence
        for t in row["trace"]:
            total += float(t["reward"])
            cum.append(total)
        if cum:
            plt.plot(range(1, len(cum) + 1), cum, marker="o", label=f"{row['task']} rag={row['rag']}")
    plt.title("Model-controlled policy reward (strict shaping)")
    plt.xlabel("policy step")
    plt.ylabel("cumulative reward")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(eval_dir / "reward_curve.png", dpi=160)
    plt.close()
    del model, tok
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return rows

def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--method", choices=("sft", "dpo", "sft_then_dpo"), required=True)
    ap.add_argument("--model-id", default="Qwen/Qwen2.5-1.5B-Instruct")
    ap.add_argument("--run-name", default="")
    ap.add_argument(
        "--out-root",
        type=Path,
        default=Path("/kaggle/working/cos_1p5b_runs") if Path("/kaggle/working").is_dir() else Path("cos_1p5b_runs"),
    )
    ap.add_argument("--hf-token", default=os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or "")
    ap.add_argument("--no-4bit", action="store_true")
    ap.add_argument("--epochs", type=float, default=2.0)
    ap.add_argument("--sft-epochs", type=float, default=2.0)
    ap.add_argument("--dpo-epochs", type=float, default=1.0)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--dpo-lr", type=float, default=5e-5)
    ap.add_argument("--dpo-beta", type=float, default=0.1)
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--grad-accum", type=int, default=8)
    ap.add_argument("--max-length", type=int, default=1024)
    ap.add_argument("--variants", type=int, default=3)
    ap.add_argument("--max-train-examples", type=int, default=0)
    ap.add_argument("--policy-steps", type=int, default=6)
    ap.add_argument("--eval-new-tokens", type=int, default=48)
    ap.add_argument("--eval-tasks", default="expert_brief,risk_brief,crisis_brief")
    ap.add_argument("--eval-rag-modes", default="false,true", help="comma list: false,true")
    ap.add_argument(
        "--report-to",
        choices=("none", "tensorboard", "wandb"),
        default=os.environ.get("REPORT_TO", "tensorboard"),
        help="experiment tracking backend for Trainer logs; defaults to tensorboard for judging",
    )
| ap.add_argument("--seed", type=int, default=42) | |
| args = ap.parse_args() | |
| if args.max_train_examples == 0: | |
| args.max_train_examples = None | |
| set_seed(args.seed) | |
| run_name = args.run_name or f"{args.method}_qwen15b_seed{args.seed}" | |
| out_dir = args.out_root / run_name | |
| out_dir.mkdir(parents=True, exist_ok=True) | |
| (out_dir / "config.json").write_text(json.dumps(vars(args), indent=2, default=str), encoding="utf-8") | |
| print(f"[run] {run_name} -> {out_dir}", flush=True) | |
    if args.method == "sft":
        adapter = train_sft(args, out_dir)
    elif args.method == "dpo":
        adapter = train_dpo(args, out_dir)
    else:
        sft_dir = out_dir / "sft_stage"
        dpo_dir = out_dir / "dpo_stage"
        old_epochs = args.epochs
        args.epochs = args.sft_epochs
        sft_adapter = train_sft(args, sft_dir)
        args.epochs = old_epochs
        adapter = train_dpo(args, dpo_dir, base_adapter=sft_adapter)
        import shutil

        final_adapter = out_dir / "adapter"
        if final_adapter.exists():
            shutil.rmtree(final_adapter)
        shutil.copytree(adapter, final_adapter)
        adapter = final_adapter
    evidence = evaluate_adapter(args, adapter, out_dir / "eval")
    print("\n=== EVIDENCE SUMMARY ===", flush=True)
    for row in evidence:
        print(
            f"{row['task']}: {' -> '.join(row['action_sequence'])} | "
            f"routed={row['model_routed_required']} | fallback={row['needed_fallback']} | "
            f"policy_reward={row['policy_reward']} terminal={row['terminal_score']}",
            flush=True,
        )
    print(f"\n[adapter] {adapter}", flush=True)
    print(f"[eval] {out_dir / 'eval'}", flush=True)
    return 0

if __name__ == "__main__":
    raise SystemExit(main())