Spaces:
Runtime error
Runtime error
import argparse
import json
import os
import sys
from pathlib import Path

# Make the repository root importable so the project-local modules below
# (agent_prompt, commitguard_env) resolve when this script is run directly.
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Unsloth has to be imported, and its GRPO patch applied, BEFORE trl is
# imported: Unsloth warns when it is imported after trl/transformers/peft, and
# a `GRPOTrainer` name bound before the patch runs may still refer to the
# unpatched class.
from unsloth import FastLanguageModel, PatchFastRL  # noqa: E402

PatchFastRL("GRPO", FastLanguageModel)

import requests  # noqa: E402
import torch  # noqa: E402
import wandb  # noqa: E402
from datasets import Dataset, load_dataset  # noqa: E402
from trl import GRPOConfig, GRPOTrainer  # noqa: E402

from agent_prompt import SYSTEM_PROMPT  # noqa: E402
from commitguard_env.parse_action import parse_action  # noqa: E402
from commitguard_env.reward import compute_reward  # noqa: E402
# --- Configuration ---
# Run settings are read from environment variables, with defaults for the
# model, output directory and wandb project. The env URL defaults to empty and
# has any trailing slash removed so endpoint paths can be appended directly.
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/commitguard-llama-3b-grpo")
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "commitguard")
ENV_URL = os.getenv("COMMITGUARD_ENV_URL", "").rstrip("/")

# Optional CWE keyword table; stays an empty mapping when the file is absent.
CWE_KEYWORDS_PATH = REPO_ROOT / "data" / "cwe_keywords.json"
CWE_KEYWORDS: dict[str, list[str]] = (
    json.loads(CWE_KEYWORDS_PATH.read_text(encoding="utf-8"))
    if CWE_KEYWORDS_PATH.exists()
    else {}
)

# Pre-built lookup: sample_id -> ground truth fields (loaded in build_dataset)
SAMPLE_LABELS: dict[str, dict] = {}
def _completion_text(completion) -> str:
    """Return the text carried by a single completion.

    Args:
        completion: Either a list of chat messages (mappings with a
            ``"content"`` key), in which case the content of the last message
            is returned, or any other object, which is converted with ``str``.

    Returns:
        The completion text. An empty message list yields ``""`` instead of
        raising ``IndexError``.
    """
    if not isinstance(completion, list):
        return str(completion)
    if not completion:
        # No messages were generated, so there is no text to score.
        return ""
    return completion[-1]["content"]
def get_reward_from_env(prompts, completions, sample_id, **kwargs) -> list[float]:
    """
    Judge-preferred path: score completions through a running CommitGuard env.

    The env owns ground truth and returns only scalar reward, preserving the
    no-leak server/client split required by the submission.

    For each (sample id, completion) pair the env is reset for that sample via
    ``POST {ENV_URL}/reset`` and the completion text is then submitted via
    ``POST {ENV_URL}/step``; the ``"reward"`` field of the step response is
    used as the score.

    Args:
        prompts: Unused here; accepted because the trainer passes it.
        completions: Generated completions, one per entry of ``sample_id``.
        sample_id: Sample identifiers aligned with ``completions``.
        **kwargs: Extra dataset columns passed by the trainer; ignored.

    Returns:
        One float per completion. Any failure (network error, HTTP error
        status, malformed response) scores ``-1.0`` and is reported on stderr
        so an unreachable env is not mistaken for uniformly bad completions.
    """
    # NOTE(review): reset and step are separate requests; this presumably
    # relies on the server pairing a step with the preceding reset - confirm.
    rewards = []
    for p_id, completion in zip(sample_id, completions):
        try:
            text = _completion_text(completion)
            reset = requests.post(f"{ENV_URL}/reset", json={"sample_id": p_id}, timeout=10)
            reset.raise_for_status()
            step = requests.post(f"{ENV_URL}/step", json={"action": text}, timeout=10)
            step.raise_for_status()
            reward = float(step.json().get("reward", -1.0))
        except Exception as exc:
            # Deliberately best-effort: one failed request must not abort the
            # training run, but it must be visible in the logs.
            print(
                f"[reward] env scoring failed for sample {p_id!r}: {exc!r}",
                file=sys.stderr,
            )
            reward = -1.0
        rewards.append(reward)
    return rewards
def get_reward_local(prompts, completions, sample_id, **kwargs) -> list[float]:
    """Local fallback for debugging when no env URL is available.

    Each completion is parsed into an action and scored in-process against the
    labels cached in ``SAMPLE_LABELS`` for its sample id. A sample id with no
    cached labels is scored with every label field set to ``None``.

    Args:
        prompts: Unused here; accepted because the trainer passes it.
        completions: Generated completions, one per entry of ``sample_id``.
        sample_id: Sample identifiers aligned with ``completions``.
        **kwargs: Extra dataset columns passed by the trainer; ignored.

    Returns:
        One reward per completion, in input order.
    """

    def _score(sid, generated):
        # Parse first, then look up the ground truth for this sample.
        parsed = parse_action(_completion_text(generated))
        truth = SAMPLE_LABELS.get(sid, {})
        return compute_reward(
            action=parsed,
            is_vulnerable=truth.get("is_vulnerable"),
            cwe=truth.get("cwe"),
            target_file=truth.get("target_file"),
            cwe_keywords=CWE_KEYWORDS,
            context_requests=0,
        )

    return [_score(sid, generated) for sid, generated in zip(sample_id, completions)]
def format_prompt(sample):
    """Turn one dataset row into the columns used for training.

    Args:
        sample: Mapping with at least ``"diff"`` and ``"sample_id"`` keys.

    Returns:
        A dict with ``"prompt"`` (a system message carrying ``SYSTEM_PROMPT``
        followed by a user message embedding the diff in a fenced ``diff``
        block) and ``"sample_id"`` (copied through unchanged).
    """
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {
        "role": "user",
        "content": f"Analyze this commit and submit your verdict.\n\nCode diff:\n```diff\n{sample['diff']}\n```",
    }
    return {
        "prompt": [system_message, user_message],
        "sample_id": sample["sample_id"],
    }
def build_dataset(n_samples: int) -> Dataset:
    """Load the training samples and cache their ground-truth labels.

    Reads ``data/devign_filtered.jsonl`` under ``REPO_ROOT``, keeps at most the
    first ``n_samples`` rows, records each row's label fields in the
    module-level ``SAMPLE_LABELS`` lookup, and maps every row through
    ``format_prompt``.

    Args:
        n_samples: Upper bound on the number of rows to keep.

    Returns:
        The formatted dataset, or an empty dataset when the data file does not
        exist.
    """
    data_path = REPO_ROOT / "data" / "devign_filtered.jsonl"
    if not data_path.exists():
        print(f"Dataset file {data_path} not found.")
        return Dataset.from_list([])

    print(f"Loading training samples from {data_path}...")
    loaded = load_dataset("json", data_files=str(data_path), split="train")
    keep = min(n_samples, len(loaded))
    subset = loaded.select(range(keep))

    # Cache the label fields per sample id for the local reward function.
    label_fields = ("is_vulnerable", "cwe", "target_file")
    for record in subset:
        SAMPLE_LABELS[record["sample_id"]] = {
            field: record.get(field) for field in label_fields
        }

    dataset = subset.map(format_prompt)
    print(f"Loaded {len(dataset)} samples ({len(SAMPLE_LABELS)} labels cached in-process).")
    return dataset
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for a GRPO training run."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--samples", type=int, default=200)
    ap.add_argument("--max-steps", type=int, default=300)
    ap.add_argument("--save-steps", type=int, default=50)
    ap.add_argument("--num-generations", type=int, default=8)
    ap.add_argument("--batch-size", type=int, default=1)
    ap.add_argument("--grad-accum", type=int, default=8)
    ap.add_argument("--lr", type=float, default=5e-6)
    ap.add_argument("--no-wandb", action="store_true")
    ap.add_argument("--push-to-hub", action="store_true")
    ap.add_argument("--hub-model-id", type=str, default="inmodel-labs/commitguard-llama-3b")
    ap.add_argument("--env-url", default=ENV_URL, help="Running CommitGuard env URL, e.g. https://...hf.space")
    return ap.parse_args()


def main():
    """Run GRPO fine-tuning end to end.

    Steps: parse and validate the CLI options, build the dataset, optionally
    start a wandb run, load the 4-bit base model and attach LoRA adapters,
    train with ``GRPOTrainer``, save the adapter under ``OUTPUT_DIR/final``
    and optionally push the model and tokenizer to the Hugging Face Hub.

    Rewards come from the live CommitGuard env when an env URL is configured
    (``--env-url`` or ``COMMITGUARD_ENV_URL``), otherwise from the local
    label-based fallback.

    Raises:
        ValueError: If ``--num-generations`` is below 2, or if
            ``--batch-size * --grad-accum`` is not divisible by it.
        RuntimeError: If no training samples could be loaded.
    """
    global ENV_URL
    args = _parse_args()
    ENV_URL = args.env_url.rstrip("/")

    if args.num_generations < 2:
        raise ValueError("--num-generations must be at least 2 for GRPO")
    effective_batch = args.batch_size * args.grad_accum
    if effective_batch % args.num_generations != 0:
        raise ValueError(
            "For single-process GRPO training, --batch-size * --grad-accum "
            f"must be divisible by --num-generations; got {args.batch_size} * "
            f"{args.grad_accum} = {effective_batch}, num_generations={args.num_generations}."
        )

    # 1. Build dataset first: it is cheap, and a missing or empty data file
    # should stop the run before a wandb run is opened or the model is loaded.
    dataset = build_dataset(args.samples)
    if len(dataset) == 0:
        raise RuntimeError(
            f"No training samples were loaded (requested {args.samples}); check that "
            "data/devign_filtered.jsonl exists under the repository root and is non-empty."
        )

    if not args.no_wandb and not os.getenv("WANDB_API_KEY"):
        print("WANDB_API_KEY not set — disabling wandb logging")
        args.no_wandb = True
    if not args.no_wandb:
        wandb.init(project=WANDB_PROJECT, name=f"grpo-{MODEL_NAME.split('/')[-1]}-run1")

    # 2. Load Model
    hf_token = os.getenv("HF_TOKEN")
    print(f"Loading {MODEL_NAME} with Unsloth 4-bit...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=2048,
        load_in_4bit=True,
        fast_inference=True,
        max_lora_rank=16,
        token=hf_token,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    # 3. GRPO config
    # Use bf16 when the GPU supports it, fp16 otherwise.
    use_bf16 = torch.cuda.is_bf16_supported()
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_generations=args.num_generations,
        max_completion_length=256,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=1,
        save_steps=args.save_steps,
        max_steps=args.max_steps,
        report_to="none" if args.no_wandb else "wandb",
        bf16=use_bf16,
        fp16=not use_bf16,
    )

    reward_func = get_reward_from_env if ENV_URL else get_reward_local
    if ENV_URL:
        print(f"Using live CommitGuard env for rewards: {ENV_URL}")
    else:
        print("COMMITGUARD_ENV_URL not set; using local label-grounded reward fallback.")

    # 4. Train
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[reward_func],
        args=training_args,
        train_dataset=dataset,
    )
    print("Starting GRPO training...")
    trainer.train()

    # 5. Save
    final_dir = f"{OUTPUT_DIR}/final"
    model.save_pretrained_merged(final_dir, tokenizer, save_method="lora")
    print(f"Training complete. LoRA adapter saved to {final_dir}")

    if args.push_to_hub:
        print(f"Pushing to HF Hub: {args.hub_model_id}")
        model.push_to_hub(args.hub_model_id, token=True)
        tokenizer.push_to_hub(args.hub_model_id, token=True)
| if __name__ == "__main__": | |
| main() | |