{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CommitGuard GRPO Training Notebook\n",
    "\n",
    "Train Llama-3.2-3B-Instruct to detect exploitable vulnerabilities in code commits using GRPO (Group Relative Policy Optimization).\n",
    "\n",
    "**Requirements:** an NVIDIA GPU with at least 16 GB of VRAM (e.g. T4, L4, or A100). Run this notebook on a GCP VM with the GPU attached.\n",
    "\n",
    "## Setup\n",
    "Connect to this notebook via an SSH tunnel:\n",
    "```bash\n",
    "# On the GCP VM:\n",
    "jupyter notebook --no-browser --port=8888\n",
    "\n",
    "# On your local machine:\n",
    "gcloud compute ssh commitguard-train --zone=us-central1-a -- -NL 8888:localhost:8888\n",
    "# Then open http://localhost:8888 in your browser\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 1 — Install Dependencies\n",
    "\n",
    "`%pip` installs into the environment of the running kernel; a bare `pip` inside a `%%bash` cell can pick up a different interpreter. Restart the kernel after the first install."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GRPOTrainer first shipped in TRL 0.14, so older TRL releases cannot run this notebook.\n",
    "%pip install -q \\\n",
    "    \"unsloth[cu124-torch240]\" \\\n",
    "    \"trl>=0.14\" \\\n",
    "    \"peft>=0.13\" \\\n",
    "    \"bitsandbytes>=0.44\" \\\n",
    "    \"transformers>=4.46\" \\\n",
    "    \"datasets>=3.0\" \\\n",
    "    \"accelerate>=1.0\" \\\n",
    "    \"wandb\" \\\n",
    "    \"fastapi\" \\\n",
    "    \"uvicorn[standard]\" \\\n",
    "    \"requests\" \\\n",
    "    \"matplotlib\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 2 — Verify GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "print(f\"PyTorch: {torch.__version__}\")\n",
    "print(f\"CUDA: {torch.cuda.is_available()}\")\n",
    "if not torch.cuda.is_available():\n",
    "    raise RuntimeError(\"No GPU detected — this notebook requires a CUDA GPU.\")\n",
    "\n",
    "print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\")\n",
    "print(f\"BF16: {torch.cuda.is_bf16_supported()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 3 — Clone Repo & Start Env Server"
   ]
  },
  {
   "cell_type": "code",
"execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, subprocess, sys, time\n",
    "import requests\n",
    "\n",
    "REPO_DIR = os.path.expanduser(\"~/commitguard\")\n",
    "ENV_URL = \"http://localhost:8000\"\n",
    "\n",
    "if not os.path.isdir(REPO_DIR):\n",
    "    !git clone https://github.com/NitishKumar-ai/commitguard.git {REPO_DIR}\n",
    "else:\n",
    "    !cd {REPO_DIR} && git pull\n",
    "\n",
    "os.chdir(REPO_DIR)\n",
    "%pip install -e . -q\n",
    "\n",
    "# Re-running this cell must not leave an old server holding port 8000.\n",
    "if \"server_proc\" in globals() and server_proc.poll() is None:\n",
    "    server_proc.terminate()\n",
    "    server_proc.wait(timeout=10)\n",
    "\n",
    "# Launch with the kernel's own interpreter (where the package was just installed)\n",
    "# and keep the server's output in a log file instead of discarding it.\n",
    "with open(os.path.join(REPO_DIR, \"env_server.log\"), \"w\") as server_log:\n",
    "    server_proc = subprocess.Popen(\n",
    "        [sys.executable, \"-m\", \"commitguard_env.server\"],\n",
    "        stdout=server_log, stderr=subprocess.STDOUT,\n",
    "    )\n",
    "\n",
    "\n",
    "def wait_for_server(url: str, timeout: float = 60.0) -> dict:\n",
    "    \"\"\"Poll `url`/health until it returns 200; raise if the server exits or `timeout` seconds pass.\"\"\"\n",
    "    deadline = time.monotonic() + timeout\n",
    "    while time.monotonic() < deadline:\n",
    "        if server_proc.poll() is not None:\n",
    "            raise RuntimeError(f\"Env server exited with code {server_proc.returncode}; see env_server.log\")\n",
    "        try:\n",
    "            resp = requests.get(f\"{url}/health\", timeout=2)\n",
    "            if resp.status_code == 200:\n",
    "                return resp.json()\n",
    "        except requests.RequestException:\n",
    "            pass  # not accepting connections yet\n",
    "        time.sleep(1)\n",
    "    raise TimeoutError(f\"Env server was not healthy after {timeout:.0f}s; see env_server.log\")\n",
    "\n",
    "\n",
    "print(f\"Env server: {wait_for_server(ENV_URL)}\")\n",
    "\n",
    "# Quick sanity reset\n",
    "r = requests.post(f\"{ENV_URL}/reset\", json={}, timeout=10)\n",
    "r.raise_for_status()\n",
    "obs = r.json()[\"observation\"]\n",
    "print(f\"Sample diff length: {len(obs['diff'])} chars, files: {obs['available_files']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 4 — Hugging Face Login (for the gated Llama model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "# Supply a token via the HF_TOKEN env var, or paste it at the interactive prompt\n",
    "# (never hardcode it in the notebook). Get one at: https://huggingface.co/settings/tokens\n",
    "# Make sure you accepted the Llama license: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct\n",
    "\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\", \"\")\n",
    "if HF_TOKEN:\n",
    "    login(token=HF_TOKEN)\n",
    "    print(\"Logged in via env var.\")\n",
    "else:\n",
    "    login()  # interactive prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 5 — Weights & Biases Login (optional but recommended)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "\n",
    "USE_WANDB = True  # Set False to skip\n",
    "\n",
    "if USE_WANDB:\n",
    "    WANDB_KEY = os.getenv(\"WANDB_API_KEY\", \"\")\n",
    "    if WANDB_KEY:\n",
    "        wandb.login(key=WANDB_KEY)\n",
    "    else:\n",
    "        wandb.login()  # interactive\n",
    "    os.environ[\"WANDB_PROJECT\"] = \"commitguard\"\n",
    "    print(\"Wandb ready.\")\n",
    "else:\n",
    "    os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
    "    print(\"Wandb disabled.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 6 — Load Model with Unsloth (4-bit LoRA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel, PatchFastRL\n",
    "\n",
    "PatchFastRL(\"GRPO\", FastLanguageModel)\n",
    "\n",
    "# Import TRL only after patching: PatchFastRL patches the trainer classes exposed by trl,\n",
    "# and a name imported before the patch could keep pointing at the stock class.\n",
    "from trl import GRPOConfig, GRPOTrainer\n",
    "\n",
    "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "MAX_SEQ_LENGTH = 2048        # prompt + completion token budget\n",
    "MAX_COMPLETION_LENGTH = 512  # tokens generated per completion (training and eval)\n",
    "# vLLM-backed generation needs the `vllm` package, which Cell 1 does not install.\n",
    "# Install it and set this to True for faster GRPO rollouts.\n",
    "USE_VLLM = False\n",
    "\n",
    "print(f\"Loading {MODEL_NAME} in 4-bit...\")\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=MODEL_NAME,\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    load_in_4bit=True,\n",
    "    fast_inference=USE_VLLM,\n",
    "    max_lora_rank=16,\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=8,\n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
    "                    \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0,\n",
    "    bias=\"none\",\n",
    "    use_gradient_checkpointing=\"unsloth\",\n",
    "    random_state=3407,\n",
    ")\n",
    "\n",
    "print(\"Model loaded.\")\n",
    "model.print_trainable_parameters()  # prints the counts itself and returns None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 7 — Build Training Dataset from Env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from datasets import Dataset\n",
    "\n",
    "SCRIPTS_DIR = os.path.join(REPO_DIR, \"scripts\")\n",
    "if SCRIPTS_DIR not in sys.path:  # avoid growing sys.path on re-runs\n",
    "    sys.path.insert(0, SCRIPTS_DIR)\n",
    "from agent_prompt import SYSTEM_PROMPT, get_agent_prompt\n",
    "\n",
    "N_SAMPLES = 200  # number of /reset calls; failed calls are skipped\n",
    "\n",
    "samples = []\n",
    "for i in range(N_SAMPLES):\n",
    "    r = requests.post(f\"{ENV_URL}/reset\", json={}, timeout=10)\n",
    "    if r.status_code != 200:\n",
    "        continue\n",
    "    obs = r.json()[\"observation\"]\n",
    "    user_msg = get_agent_prompt(obs[\"diff\"], obs[\"available_files\"], obs.get(\"step_idx\", 0))\n",
    "    samples.append({\n",
    "        \"prompt\": [\n",
    "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": user_msg},\n",
    "        ],\n",
    "        # GRPOTrainer passes extra dataset columns to reward functions as keyword\n",
    "        # arguments; the reward function uses this one to spot episode mismatches.\n",
    "        \"diff\": obs[\"diff\"],\n",
    "    })\n",
    "    if (i + 1) % 50 == 0:\n",
    "        print(f\"  fetched {i + 1}/{N_SAMPLES}\")\n",
    "\n",
    "if not samples:\n",
    "    raise RuntimeError(f\"All {N_SAMPLES} /reset calls failed; check env_server.log\")\n",
    "\n",
    "dataset = Dataset.from_list(samples)\n",
    "print(f\"\\nDataset ready: {len(dataset)} samples ({N_SAMPLES - len(samples)} resets failed)\")\n",
    "print(f\"Sample prompt preview: {str(dataset[0]['prompt'][1]['content'])[:200]}...\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 8 — Define Reward Function\n",
    "\n",
    "**Caveat:** each `/reset` below is sent with an empty payload, so the env may serve a different sample than the one whose diff appears in the prompt. When that happens, the reward scores the completion against the wrong commit. The function counts these mismatches in `REWARD_STATS` so the problem is visible after training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reward-quality diagnostics: how many env episodes were compared against the\n",
    "# prompt's diff, and how many of those served a different diff.\n",
    "REWARD_STATS = {\"checked\": 0, \"diff_mismatch\": 0}\n",
    "\n",
    "\n",
    "def get_reward_from_env(prompts, completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Send each completion to the env as an action and collect its reward.\n",
    "\n",
    "    For each completion: POST /reset, then POST /step with the completion text, and\n",
    "    use the step's `reward` field (0.0 if absent). A non-200 /reset or /step scores\n",
    "    -0.5; any exception scores -1.0. `prompts` is accepted for GRPOTrainer's calling\n",
    "    convention but is unused. If a `diff` list is passed in kwargs, each reset\n",
    "    observation's diff is compared with it and tallied in REWARD_STATS.\n",
    "    \"\"\"\n",
    "    expected_diffs = kwargs.get(\"diff\") or [None] * len(completions)\n",
    "    rewards = []\n",
    "    for completion, expected_diff in zip(completions, expected_diffs):\n",
    "        try:\n",
    "            reset = requests.post(f\"{ENV_URL}/reset\", json={}, timeout=10)\n",
    "            if reset.status_code != 200:\n",
    "                rewards.append(-0.5)\n",
    "                continue\n",
    "            # NOTE(review): /reset serves a fresh episode that need not match this\n",
    "            # completion's prompt. TODO: pass a sample identifier to /reset if the\n",
    "            # env supports one, so the reward scores the commit the model actually saw.\n",
    "            if expected_diff is not None:\n",
    "                REWARD_STATS[\"checked\"] += 1\n",
    "                if reset.json()[\"observation\"][\"diff\"] != expected_diff:\n",
    "                    REWARD_STATS[\"diff_mismatch\"] += 1\n",
    "            text = completion[-1][\"content\"] if isinstance(completion, list) else str(completion)\n",
    "            r = requests.post(f\"{ENV_URL}/step\", json={\"action\": text}, timeout=10)\n",
    "            if r.status_code == 200:\n",
    "                rewards.append(float(r.json().get(\"reward\", 0.0)))\n",
    "            else:\n",
    "                rewards.append(-0.5)\n",
    "        except Exception:\n",
    "            # Deliberately broad: one bad request must not abort a multi-hour run.\n",
    "            rewards.append(-1.0)\n",
    "    return rewards\n",
    "\n",
    "\n",
    "# Quick smoke test (no \"diff\" kwarg, so REWARD_STATS is untouched).\n",
    "# NOTE(review): this completion string looks like its markup was stripped; replace it\n",
    "# with one in the format SYSTEM_PROMPT asks for to exercise a real verdict.\n",
    "test_r = get_reward_from_env(\n",
    "    [\"test\"],\n",
    "    [\"verdicttrueCWE-119buffer overflow\"],\n",
    ")\n",
    "print(f\"Reward function test: {test_r}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 9 — Configure & Launch GRPO Training\n",
    "\n",
    "This is the main training loop: roughly 2–3 hours on an L4 for 300 steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "OUTPUT_DIR = \"outputs/commitguard-llama-3b\"\n",
    "NUM_GENERATIONS = 4  # completions sampled per prompt (the GRPO group size)\n",
    "\n",
    "training_args = GRPOConfig(\n",
    "    output_dir=OUTPUT_DIR,\n",
    "    num_generations=NUM_GENERATIONS,\n",
    "    # Recent GRPOTrainer versions count completions, not prompts, per device and require\n",
    "    # the global batch to be divisible by num_generations: one prompt's group per step.\n",
    "    per_device_train_batch_size=NUM_GENERATIONS,\n",
    "    gradient_accumulation_steps=4,\n",
    "    # TRL's default max_prompt_length (512) would left-truncate long diffs; give the\n",
    "    # prompt whatever the sequence budget leaves after the completion.\n",
    "    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
    "    max_completion_length=MAX_COMPLETION_LENGTH,\n",
    "    use_vllm=USE_VLLM,\n",
    "    learning_rate=5e-6,\n",
    "    logging_steps=1,\n",
    "    save_steps=50,\n",
    "    max_steps=300,\n",
    "    report_to=\"wandb\" if USE_WANDB else \"none\",\n",
    "    bf16=torch.cuda.is_bf16_supported(),\n",
    "    fp16=not torch.cuda.is_bf16_supported(),\n",
    ")\n",
    "\n",
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[get_reward_from_env],\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    ")\n",
    "\n",
    "print(\"Starting GRPO training...\")\n",
    "print(f\"  Steps: {training_args.max_steps}\")\n",
    "print(f\"  Generations per prompt: {training_args.num_generations}\")\n",
    "print(f\"  Save every: {training_args.save_steps} steps\")\n",
    "print(f\"  Output: {OUTPUT_DIR}\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "REWARD_STATS.update(checked=0, diff_mismatch=0)\n",
    "trainer.train()\n",
    "\n",
    "if REWARD_STATS[\"checked\"]:\n",
    "    mismatch_pct = 100 * REWARD_STATS[\"diff_mismatch\"] / REWARD_STATS[\"checked\"]\n",
    "    print(f\"Env episode differed from the prompt's diff in {mismatch_pct:.1f}% of checked rewards\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 10 — Save Final LoRA Adapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FINAL_DIR = f\"{OUTPUT_DIR}/final\"\n",
    "model.save_pretrained_merged(FINAL_DIR, tokenizer, save_method=\"lora\")\n",
    "print(f\"LoRA adapter saved to {FINAL_DIR}\")\n",
    "\n",
    "# List saved files\n",
    "for fname in sorted(os.listdir(FINAL_DIR)):\n",
    "    size_mb = os.path.getsize(os.path.join(FINAL_DIR, fname)) / 1024**2\n",
    "    print(f\"  {fname}: {size_mb:.1f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 11 — Quick Evaluation (Baseline vs Trained)\n",
    "\n",
    "Scores the held-out set twice: once with the LoRA adapter disabled (the untrained base model) and once with it enabled. Decoding is greedy, so the numbers are reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# NOTE(review): commitguard_env is pip-installed, so this path entry may be unneeded;\n",
    "# kept (added once, not per sample) in case its modules rely on top-level imports.\n",
    "ENV_PKG_DIR = os.path.join(REPO_DIR, \"commitguard_env\")\n",
    "if ENV_PKG_DIR not in sys.path:\n",
    "    sys.path.insert(0, ENV_PKG_DIR)\n",
    "from commitguard_env.parse_action import parse_action\n",
    "\n",
    "RUN_BASELINE_EVAL = True  # False skips the base-model pass and uses a 50% placeholder\n",
    "\n",
    "# Load test set\n",
    "test_path = os.path.join(REPO_DIR, \"data\", \"devign_test.jsonl\")\n",
    "with open(test_path) as f:\n",
    "    test_samples = [json.loads(line) for line in f if line.strip()]\n",
    "if not test_samples:\n",
    "    raise RuntimeError(f\"No samples found in {test_path}\")\n",
    "\n",
    "\n",
    "def evaluate(samples: list[dict], label: str) -> tuple[float, list[dict]]:\n",
    "    \"\"\"Greedy-decode a verdict for every sample; return (accuracy %, per-sample results).\n",
    "\n",
    "    A response whose parsed `is_vulnerable` is None counts as a \"not vulnerable\" prediction.\n",
    "    \"\"\"\n",
    "    correct = 0\n",
    "    results = []\n",
    "    for i, sample in enumerate(samples):\n",
    "        user_msg = get_agent_prompt(sample[\"diff\"], sample[\"available_files\"], 0)\n",
    "        messages = [\n",
    "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": user_msg},\n",
    "        ]\n",
    "        # return_dict=True also yields the attention mask for generate().\n",
    "        inputs = tokenizer.apply_chat_template(\n",
    "            messages, add_generation_prompt=True, return_tensors=\"pt\", return_dict=True,\n",
    "        ).to(model.device)\n",
    "        with torch.no_grad():\n",
    "            output = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, do_sample=False)\n",
    "        prompt_len = inputs[\"input_ids\"].shape[1]\n",
    "        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)\n",
    "\n",
    "        action = parse_action(response)\n",
    "        pred_vuln = bool(action.is_vulnerable) if action.is_vulnerable is not None else False\n",
    "        truth_vuln = sample[\"is_vulnerable\"]\n",
    "        if pred_vuln == truth_vuln:\n",
    "            correct += 1\n",
    "\n",
    "        results.append({\n",
    "            \"sample_id\": sample[\"sample_id\"],\n",
    "            \"pred\": pred_vuln,\n",
    "            \"truth\": truth_vuln,\n",
    "            \"cwe\": sample.get(\"cwe\"),\n",
    "            \"vuln_type\": action.vuln_type,\n",
    "        })\n",
    "\n",
    "        if (i + 1) % 20 == 0:\n",
    "            print(f\"  [{label}] {i + 1}/{len(samples)} running accuracy: {100 * correct / (i + 1):.1f}%\")\n",
    "    return 100 * correct / len(samples), results\n",
    "\n",
    "\n",
    "print(f\"Evaluating on {len(test_samples)} held-out samples...\")\n",
    "FastLanguageModel.for_inference(model)\n",
    "\n",
    "if RUN_BASELINE_EVAL:\n",
    "    # PEFT's disable_adapter() context runs the base weights without the LoRA deltas.\n",
    "    with model.disable_adapter():\n",
    "        baseline_acc, baseline_results = evaluate(test_samples, \"baseline\")\n",
    "    print(f\"\\nBaseline (adapter disabled) accuracy: {baseline_acc:.1f}%\")\n",
    "else:\n",
    "    baseline_acc = 50.0  # placeholder, not a measured number\n",
    "    print(\"Baseline evaluation skipped; using a 50.0% placeholder.\")\n",
    "\n",
    "accuracy, results = evaluate(test_samples, \"trained\")\n",
    "print(f\"\\nFinal trained accuracy: {accuracy:.1f}%\")\n",
    "\n",
    "with open(os.path.join(REPO_DIR, \"eval_trained.json\"), \"w\") as f:\n",
    "    json.dump(results, f, indent=2)\n",
    "print(\"Results saved to eval_trained.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 12 — Generate Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "\n",
    "PLOTS_DIR = os.path.join(REPO_DIR, \"plots\")\n",
    "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
    "\n",
    "# --- Plot 1: Training reward and loss curves (from trainer logs) ---\n",
    "# GRPOTrainer logs each step's mean reward under the \"reward\" key.\n",
    "log_history = trainer.state.log_history if hasattr(trainer, \"state\") else []\n",
    "reward_logs = [entry for entry in log_history if \"reward\" in entry]\n",
    "loss_logs = [entry for entry in log_history if \"loss\" in entry]\n",
    "if reward_logs or loss_logs:\n",
    "    fig, (ax_reward, ax_loss) = plt.subplots(1, 2, figsize=(14, 5))\n",
    "    ax_reward.plot([e[\"step\"] for e in reward_logs], [e[\"reward\"] for e in reward_logs],\n",
    "                   color=\"#2ecc71\", linewidth=2)\n",
    "    ax_reward.set(xlabel=\"Training Step\", ylabel=\"Mean Reward\", title=\"CommitGuard GRPO Training Reward\")\n",
    "    ax_loss.plot([e[\"step\"] for e in loss_logs], [e[\"loss\"] for e in loss_logs],\n",
    "                 color=\"#3498db\", linewidth=2)\n",
    "    ax_loss.set(xlabel=\"Training Step\", ylabel=\"Loss\", title=\"CommitGuard GRPO Training Loss\")\n",
    "    for ax in (ax_reward, ax_loss):\n",
    "        ax.grid(True, linestyle=\"--\", alpha=0.5)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(PLOTS_DIR, \"reward_curve.png\"), dpi=150)\n",
    "    plt.show()\n",
    "    print(\"Saved plots/reward_curve.png\")\n",
    "\n",
    "# --- Plot 2: Accuracy comparison (baseline_acc comes from Cell 11) ---\n",
    "trained_acc = accuracy\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 5))\n",
    "bars = ax.bar([\"Baseline (Untrained)\", \"CommitGuard (Trained)\"],\n",
    "              [baseline_acc, trained_acc],\n",
    "              color=[\"#95a5a6\", \"#3498db\"])\n",
    "ax.set_ylabel(\"Detection Accuracy (%)\")\n",
    "ax.set_title(\"Vulnerability Detection: Baseline vs. Trained\")\n",
    "ax.set_ylim(0, 100)\n",
    "for bar in bars:\n",
    "    h = bar.get_height()\n",
    "    ax.text(bar.get_x() + bar.get_width() / 2., h + 1, f\"{h:.1f}%\",\n",
    "            ha=\"center\", fontweight=\"bold\")\n",
    "fig.savefig(os.path.join(PLOTS_DIR, \"baseline_vs_trained.png\"), dpi=150)\n",
    "plt.show()\n",
    "print(\"Saved plots/baseline_vs_trained.png\")\n",
    "\n",
    "# --- Plot 3: Per-CWE breakdown (samples without a CWE label are skipped) ---\n",
    "cwe_correct = Counter()\n",
    "cwe_total = Counter()\n",
    "for res in results:\n",
    "    if res[\"cwe\"]:\n",
    "        cwe_total[res[\"cwe\"]] += 1\n",
    "        if res[\"pred\"] == res[\"truth\"]:\n",
    "            cwe_correct[res[\"cwe\"]] += 1\n",
    "\n",
    "cwes = sorted(cwe_total)\n",
    "accs = [100 * cwe_correct[c] / cwe_total[c] for c in cwes]\n",
    "\n",
    "if cwes:\n",
    "    fig, ax = plt.subplots(figsize=(10, 5))\n",
    "    ax.bar(cwes, accs, color=\"#e67e22\")\n",
    "    ax.set_ylabel(\"Accuracy (%)\")\n",
    "    ax.set_title(\"Trained Model Accuracy by CWE Type\")\n",
    "    ax.set_ylim(0, 100)\n",
    "    ax.tick_params(axis=\"x\", rotation=45)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(PLOTS_DIR, \"per_cwe.png\"), dpi=150)\n",
    "    plt.show()\n",
    "    print(\"Saved plots/per_cwe.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 13 — Cleanup\n",
    "\n",
    "Stop the env server and print a final summary."
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server_proc.terminate()\n",
    "try:\n",
    "    server_proc.wait(timeout=10)  # reap the process instead of leaving it behind\n",
    "except subprocess.TimeoutExpired:\n",
    "    server_proc.kill()  # did not exit on terminate(); force it down\n",
    "    server_proc.wait()\n",
    "print(\"Env server stopped.\")\n",
    "\n",
    "print(\"\\n\" + \"=\" * 50)\n",
    "print(\"  TRAINING COMPLETE\")\n",
    "print(\"=\" * 50)\n",
    "print(f\"  Model:    {MODEL_NAME}\")\n",
    "print(f\"  Steps:    {training_args.max_steps}\")\n",
    "print(f\"  Accuracy: {baseline_acc:.1f}% → {trained_acc:.1f}% ({trained_acc - baseline_acc:+.1f}pp)\")\n",
    "print(f\"  Adapter:  {FINAL_DIR}\")\n",
    "print(\"  Plots:    plots/reward_curve.png, plots/baseline_vs_trained.png, plots/per_cwe.png\")\n",
    "print(\"\\nNext: copy outputs/ and plots/ back to your local machine.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}