{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CommitGuard GRPO Training Notebook\n",
    "\n",
    "Train Llama-3.2-3B-Instruct to detect exploitable vulnerabilities in code commits using GRPO (Group Relative Policy Optimization).\n",
    "\n",
    "**Requirements:** NVIDIA GPU with at least 16 GB VRAM (L4/A100/T4). Run this notebook on a GCP VM with a GPU attached.\n",
    "\n",
    "## Setup\n",
    "Connect to this notebook via SSH tunnel:\n",
    "```bash\n",
    "# On the GCP VM:\n",
    "jupyter notebook --no-browser --port=8888\n",
    "\n",
    "# On your local machine:\n",
    "gcloud compute ssh commitguard-train --zone=us-central1-a -- -NL 8888:localhost:8888\n",
    "# Then open http://localhost:8888 in your browser\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 1 \u2014 Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Requirements with version bounds or extras are quoted: unquoted, `trl>=0.12` is parsed by\n",
    "# the shell as `trl` plus a redirect of stdout into a file named \"=0.12\", so the bound never\n",
    "# reached pip. %pip (rather than !pip) installs into the running kernel's environment.\n",
    "# GRPOTrainer/GRPOConfig (used in Cells 6 and 9) first shipped in TRL 0.14.\n",
    "%pip install -q unsloth\n",
    "%pip uninstall -y unsloth\n",
    "%pip install -q --upgrade --no-cache-dir \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
    "%pip install -q \"trl>=0.14\" peft bitsandbytes transformers datasets accelerate wandb fastapi \"uvicorn[standard]\" requests matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 2 \u2014 Verify GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "print(f\"PyTorch: {torch.__version__}\")\n",
    "print(f\"CUDA: {torch.cuda.is_available()}\")\n",
    "if not torch.cuda.is_available():\n",
    "    raise RuntimeError(\"No GPU detected \u2014 this notebook requires a CUDA GPU.\")\n",
    "\n",
    "print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\")\n",
    "print(f\"BF16: {torch.cuda.is_bf16_supported()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 3 \u2014 Set Up Project & Start Env Server\n",
    "\n",
    "In Colab the project is uploaded as `project.zip`; elsewhere the notebook runs from the repo root (or its `notebooks/` folder). The env server's output is written to `env_server.log` in the repo root."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, subprocess, time, requests, sys\n",
    "\n",
    "ENV_URL = \"http://localhost:8000\"\n",
    "\n",
    "# Check if running in Google Colab\n",
    "if \"google.colab\" in sys.modules:\n",
    "    print(\"Running in Google Colab.\")\n",
    "    # Reset to base directory in case cell is run multiple times\n",
    "    os.chdir(\"/content\")\n",
    "\n",
    "    if not os.path.exists(\"/content/project.zip\"):\n",
    "        from google.colab import files\n",
    "        print(\"\\n--- WE NEED YOUR PROJECT.ZIP ---\")\n",
    "        print(\"Please click 'Choose Files' below and select project.zip from your computer:\\n\")\n",
    "        uploaded = files.upload()\n",
    "\n",
    "    if not os.path.exists(\"/content/project.zip\"):\n",
    "        # Raise rather than sys.exit(): a real exception reliably halts \"Run All\".\n",
    "        raise FileNotFoundError(\"project.zip still not found!\")\n",
    "    print(\"Extracting project.zip...\")\n",
    "    !unzip -q -o /content/project.zip -d /content/commitguard\n",
    "\n",
    "    os.chdir(\"/content/commitguard\")\n",
    "    REPO_DIR = os.getcwd()\n",
    "else:\n",
    "    # Run from the repo root, or from its notebooks/ folder.\n",
    "    if os.path.basename(os.getcwd()) == \"notebooks\":\n",
    "        REPO_DIR = os.path.abspath(\"..\")\n",
    "    else:\n",
    "        REPO_DIR = os.getcwd()\n",
    "    os.chdir(REPO_DIR)\n",
    "\n",
    "print(f\"Using REPO_DIR: {REPO_DIR}\")\n",
    "\n",
    "# Install the current project in editable mode\n",
    "%pip install -q -e .\n",
    "\n",
    "# Stop a server left over from a previous run of this cell: the new one could not bind\n",
    "# port 8000, and the health check below would silently talk to the stale process.\n",
    "if \"server_proc\" in globals() and server_proc.poll() is None:\n",
    "    server_proc.terminate()\n",
    "    server_proc.wait(timeout=10)\n",
    "if \"server_log\" in globals():\n",
    "    server_log.close()\n",
    "\n",
    "# Start the env server in the background. Its output goes to a log file, not\n",
    "# subprocess.PIPE: nothing reads the pipes while training runs, so once the OS pipe\n",
    "# buffer filled up the server would block on its next write and stall reward calls.\n",
    "SERVER_LOG = os.path.join(REPO_DIR, \"env_server.log\")\n",
    "server_log = open(SERVER_LOG, \"w\")\n",
    "server_proc = subprocess.Popen(\n",
    "    [sys.executable, \"-m\", \"commitguard_env.server\"],\n",
    "    stdout=server_log, stderr=subprocess.STDOUT,\n",
    ")\n",
    "\n",
    "\n",
    "def wait_for_server(url, proc, timeout_s=30.0):\n",
    "    \"\"\"Poll `url` until it answers HTTP 200 and return its JSON body.\n",
    "\n",
    "    Returns None if `proc` exits first or `timeout_s` seconds pass without a 200.\n",
    "    \"\"\"\n",
    "    deadline = time.monotonic() + timeout_s\n",
    "    while time.monotonic() < deadline:\n",
    "        if proc.poll() is not None:\n",
    "            return None  # server process exited during startup\n",
    "        try:\n",
    "            resp = requests.get(url, timeout=2)\n",
    "            if resp.status_code == 200:\n",
    "                return resp.json()\n",
    "        except (requests.RequestException, ValueError):\n",
    "            pass  # not accepting connections (or not answering JSON) yet\n",
    "        time.sleep(0.5)\n",
    "    return None\n",
    "\n",
    "\n",
    "health = wait_for_server(f\"{ENV_URL}/health\", server_proc)\n",
    "if health is None:\n",
    "    with open(SERVER_LOG) as log_f:\n",
    "        print(log_f.read()[-4000:])\n",
    "    raise RuntimeError(f\"Env server failed to start; see {SERVER_LOG}\")\n",
    "print(f\"Env server: {health}\")\n",
    "\n",
    "# Quick sanity reset\n",
    "r = requests.post(f\"{ENV_URL}/reset\", json={}, timeout=10)\n",
    "r.raise_for_status()\n",
    "obs = r.json()[\"observation\"]\n",
    "print(f\"Sample diff length: {len(obs['diff'])} chars, files: {obs['available_files']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 4 \u2014 HuggingFace Login (for gated Llama model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Prefer a token from the environment (non-interactive); otherwise prompt for one.\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "if HF_TOKEN:\n",
    "    login(token=HF_TOKEN)\n",
    "    print(\"Logged in via token.\")\n",
    "else:\n",
    "    login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 5 \u2014 Wandb Login (optional but recommended)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Set to True to log training metrics to Weights & Biases (Cell 9 reads this flag).\n",
    "USE_WANDB = False\n",
    "\n",
    "if USE_WANDB:\n",
    "    import wandb\n",
    "    os.environ.pop(\"WANDB_DISABLED\", None)  # undo a previous disabled run\n",
    "    wandb.login()\n",
    "else:\n",
    "    os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
    "    print(\"Wandb disabled.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## Cell 6 Load Model with Unsloth (4-bit LoRA)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from unsloth import FastLanguageModel, PatchFastRL\n", "from trl import GRPOConfig, GRPOTrainer\n", "\n", "PatchFastRL(\"GRPO\", FastLanguageModel)\n", "\n", "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n", "\n", "print(f\"Loading {MODEL_NAME} in 4-bit...\")\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", " model_name=MODEL_NAME,\n", " max_seq_length=2048,\n", " load_in_4bit=True,\n", " fast_inference=False,\n", " max_lora_rank=16,\n", ")\n", "\n", "model = FastLanguageModel.get_peft_model(\n", " model,\n", " r=8,\n", " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", " \"gate_proj\", \"up_proj\", \"down_proj\"],\n", " lora_alpha=16,\n", " lora_dropout=0,\n", " bias=\"none\",\n", " use_gradient_checkpointing=\"unsloth\",\n", " random_state=3407,\n", ")\n", "\n", "print(f\"Model loaded. Trainable params: {model.print_trainable_parameters()}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cell 7 Build Training Dataset from Env" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys, requests\n", "from datasets import Dataset\n", "\n", "sys.path.insert(0, os.path.join(REPO_DIR, \"scripts\"))\n", "from agent_prompt import SYSTEM_PROMPT, get_agent_prompt\n", "\n", "ENV_URL = \"http://localhost:8000\"\n", "N_SAMPLES = 200 # Number of training prompts (updated)\n", "\n", "samples = []\n", "for i in range(N_SAMPLES):\n", " r = requests.post(f\"{ENV_URL}/reset\", json={}, timeout=10)\n", " if r.status_code != 200:\n", " continue\n", " obs = r.json()[\"observation\"]\n", " state_r = requests.get(f\"{ENV_URL}/state\").json()\n", " current_sample_id = state_r.get(\"state\", {}).get(\"current_sample_id\", \"unknown\")\n", " user_msg = get_agent_prompt(obs[\"diff\"], obs[\"available_files\"], obs.get(\"step_idx\", 0))\n", " 
samples.append({\n", " \"prompt\": [\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": user_msg},\n", " ],\n", " \"sample_id\": current_sample_id,\n", " })\n", " if (i + 1) % 50 == 0:\n", " print(f\" fetched {i + 1}/{N_SAMPLES}\")\n", "\n", "dataset = Dataset.from_list(samples)\n", "print(f\"\\nDataset ready: {len(dataset)} samples\")\n", "print(f\"Sample prompt preview: {str(dataset[0]['prompt'][1]['content'])[:200]}...\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cell 8 Define Reward Function" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_reward_from_env(prompts, completions, sample_id, **kwargs) -> list[float]:\n", " \"\"\"Send each completion to the env as an action, collect reward.\"\"\"\n", " rewards = []\n", " for p_id, completion in zip(sample_id, completions):\n", " try:\n", " requests.post(f\"{ENV_URL}/reset\", json={\"sample_id\": p_id}, timeout=10)\n", " text = completion[-1][\"content\"] if isinstance(completion, list) else str(completion)\n", " r = requests.post(f\"{ENV_URL}/step\", json={\"action\": text}, timeout=10)\n", " if r.status_code == 200:\n", " rewards.append(float(r.json().get(\"reward\", 0.0)))\n", " else:\n", " rewards.append(-0.5)\n", " except Exception:\n", " rewards.append(-1.0)\n", " return rewards\n", "\n", "# Quick test\n", "test_r = get_reward_from_env(\n", " [\"test\"],\n", " [\"verdicttrueCWE-119buffer overflow\"],\n", " [\"test_id\"]\n", ")\n", "print(f\"Reward function test: {test_r}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cell 9 Configure & Launch GRPO Training\n", "\n", "This is the main training loop. ~2-3 hours on L4 for 300 steps." 
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "OUTPUT_DIR = \"outputs/commitguard-llama-3b\"\n",
    "NUM_GENERATIONS = 4  # completions sampled per prompt; GRPO compares them as one group\n",
    "\n",
    "training_args = GRPOConfig(\n",
    "    output_dir=OUTPUT_DIR,\n",
    "    num_generations=NUM_GENERATIONS,\n",
    "    max_completion_length=512,\n",
    "    # The completions in a batch must split evenly into groups of num_generations.\n",
    "    # Some TRL releases reject a per-device batch of 1 with 4 generations, and Unsloth's\n",
    "    # patched config bumps it to num_generations anyway, so set that value explicitly.\n",
    "    per_device_train_batch_size=NUM_GENERATIONS,\n",
    "    gradient_accumulation_steps=4,\n",
    "    learning_rate=5e-6,\n",
    "    logging_steps=1,\n",
    "    save_steps=50,\n",
    "    max_steps=300,\n",
    "    report_to=\"wandb\" if USE_WANDB else \"none\",\n",
    "    bf16=torch.cuda.is_bf16_supported(),\n",
    "    fp16=not torch.cuda.is_bf16_supported(),\n",
    ")\n",
    "\n",
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[get_reward_from_env],\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    ")\n",
    "\n",
    "print(\"Starting GRPO training...\")\n",
    "print(f\"  Steps: {training_args.max_steps}\")\n",
    "print(f\"  Generations per prompt: {training_args.num_generations}\")\n",
    "print(f\"  Save every: {training_args.save_steps} steps\")\n",
    "print(f\"  Output: {OUTPUT_DIR}\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 10 \u2014 Save Final LoRA Adapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "FINAL_DIR = os.path.join(OUTPUT_DIR, \"final\")\n",
    "# save_method=\"lora\" saves only the adapter weights, not a merged model.\n",
    "model.save_pretrained_merged(FINAL_DIR, tokenizer, save_method=\"lora\")\n",
    "print(f\"LoRA adapter saved to {FINAL_DIR}\")\n",
    "\n",
    "# List saved files with their sizes\n",
    "for fname in sorted(os.listdir(FINAL_DIR)):\n",
    "    size_mb = os.path.getsize(os.path.join(FINAL_DIR, fname)) / 1024**2\n",
    "    print(f\"  {fname}: {size_mb:.1f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 11 \u2014 Quick Evaluation (Trained Model)\n",
    "\n",
    "Runs the trained model over the held-out `data/devign_test.jsonl` split and writes `eval_trained.json`. The baseline accuracy used in Cells 12-13 is read from `eval_baseline.json`, which this notebook does not produce."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, os, sys\n",
    "\n",
    "import torch\n",
    "\n",
    "# The project is installed editable in Cell 3; putting the repo root (not the package\n",
    "# directory itself) on sys.path keeps `commitguard_env` importable without letting its\n",
    "# individual modules shadow top-level imports. Done once, not once per sample.\n",
    "if REPO_DIR not in sys.path:\n",
    "    sys.path.insert(0, REPO_DIR)\n",
    "from commitguard_env.parse_action import parse_action\n",
    "\n",
    "# Load test set\n",
    "test_path = os.path.join(REPO_DIR, \"data\", \"devign_test.jsonl\")\n",
    "with open(test_path) as f:\n",
    "    test_samples = [json.loads(line) for line in f if line.strip()]\n",
    "if not test_samples:\n",
    "    raise ValueError(f\"No samples found in {test_path}\")\n",
    "\n",
    "print(f\"Evaluating on {len(test_samples)} held-out samples...\")\n",
    "\n",
    "# Run trained model on test set\n",
    "FastLanguageModel.for_inference(model)\n",
    "\n",
    "correct = 0\n",
    "unparsed = 0\n",
    "results = []\n",
    "\n",
    "for i, sample in enumerate(test_samples):\n",
    "    user_msg = get_agent_prompt(sample[\"diff\"], sample[\"available_files\"], 0)\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "        {\"role\": \"user\", \"content\": user_msg},\n",
    "    ]\n",
    "    # return_dict=True also yields the attention mask for generate().\n",
    "    enc = tokenizer.apply_chat_template(\n",
    "        messages, add_generation_prompt=True, return_tensors=\"pt\", return_dict=True,\n",
    "    ).to(model.device)\n",
    "    with torch.no_grad():\n",
    "        # Greedy decoding so the reported accuracy is reproducible from run to run.\n",
    "        output = model.generate(**enc, max_new_tokens=512, do_sample=False)\n",
    "    prompt_len = enc[\"input_ids\"].shape[1]\n",
    "    response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)\n",
    "\n",
    "    # Parse verdict; a missing verdict is scored as \"not vulnerable\" but counted below.\n",
    "    action = parse_action(response)\n",
    "    parsed = action.is_vulnerable is not None\n",
    "    if not parsed:\n",
    "        unparsed += 1\n",
    "    pred_vuln = bool(action.is_vulnerable) if parsed else False\n",
    "    truth_vuln = sample[\"is_vulnerable\"]\n",
    "\n",
    "    if pred_vuln == truth_vuln:\n",
    "        correct += 1\n",
    "\n",
    "    results.append({\n",
    "        \"sample_id\": sample[\"sample_id\"],\n",
    "        \"pred\": pred_vuln,\n",
    "        \"truth\": truth_vuln,\n",
    "        \"cwe\": sample.get(\"cwe\"),\n",
    "        \"vuln_type\": action.vuln_type,\n",
    "        \"parsed\": parsed,\n",
    "    })\n",
    "\n",
    "    if (i + 1) % 20 == 0:\n",
    "        print(f\"  {i + 1}/{len(test_samples)} \u2014 running accuracy: {100 * correct / (i + 1):.1f}%\")\n",
    "\n",
    "accuracy = 100 * correct / len(test_samples)\n",
    "print(f\"\\nFinal trained accuracy: {accuracy:.1f}% ({unparsed} responses had no parseable verdict)\")\n",
    "\n",
    "with open(os.path.join(REPO_DIR, \"eval_trained.json\"), \"w\") as f:\n",
    "    json.dump(results, f, indent=2)\n",
    "print(\"Results saved to eval_trained.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 12 \u2014 Generate Plots"
   ]
  },
  {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from collections import Counter\n", "\n", "os.makedirs(os.path.join(REPO_DIR, \"plots\"), exist_ok=True)\n", "\n", "# --- Plot 1: Training reward curve (from trainer logs) ---\n", "if hasattr(trainer, 'state') and trainer.state.log_history:\n", " steps = [l[\"step\"] for l in trainer.state.log_history if \"loss\" in l]\n", " losses = [l[\"loss\"] for l in trainer.state.log_history if \"loss\" in l]\n", " \n", " fig, ax = plt.subplots(figsize=(10, 5))\n", " ax.plot(steps, losses, color=\"#2ecc71\", linewidth=2)\n", " ax.set_xlabel(\"Training Step\")\n", " ax.set_ylabel(\"Loss\")\n", " ax.set_title(\"CommitGuard GRPO Training Loss\")\n", " ax.grid(True, linestyle=\"--\", alpha=0.5)\n", " fig.savefig(os.path.join(REPO_DIR, \"plots\", \"reward_curve.png\"), dpi=150)\n", " plt.show()\n", " print(\"Saved plots/reward_curve.png\")\n", "\n", " # --- Plot 2: Accuracy comparison ---\n", " with open(os.path.join(REPO_DIR, \"eval_baseline.json\")) as f:\n", " b_data = json.load(f)\n", " baseline_acc = 100 * sum(1 for x in b_data if x['pred'] == x['truth']) / len(b_data)\n", " trained_acc = accuracy\n", "\n", " fig, ax = plt.subplots(figsize=(8, 5))\n", " bars = ax.bar([\"Baseline (Untrained)\", \"CommitGuard (Trained)\"],\n", " [baseline_acc, trained_acc],\n", " color=[\"#95a5a6\", \"#3498db\"])\n", " ax.set_ylabel(\"Detection Accuracy (%)\")\n", " ax.set_title(\"Vulnerability Detection: Baseline vs. 
Trained\")\n", " ax.set_ylim(0, 100)\n", " for bar in bars:\n", " h = bar.get_height()\n", " ax.text(bar.get_x() + bar.get_width()/2., h + 1, f\"{h:.1f}%\",\n", " ha=\"center\", fontweight=\"bold\")\n", " fig.savefig(os.path.join(REPO_DIR, \"plots\", \"baseline_vs_trained.png\"), dpi=150)\n", " plt.show()\n", " print(\"Saved plots/baseline_vs_trained.png\")\n", "\n", " # --- Plot 3: Per-CWE breakdown ---\n", " cwe_correct = Counter()\n", " cwe_total = Counter()\n", " for r in results:\n", " if r[\"cwe\"]:\n", " cwe_total[r[\"cwe\"]] += 1\n", " if r[\"pred\"] == r[\"truth\"]:\n", " cwe_correct[r[\"cwe\"]] += 1\n", "\n", " cwes = sorted(cwe_total.keys())\n", " accs = [100 * cwe_correct[c] / cwe_total[c] if cwe_total[c] > 0 else 0 for c in cwes]\n", "\n", " if cwes:\n", " fig, ax = plt.subplots(figsize=(10, 5))\n", " ax.bar(cwes, accs, color=\"#e67e22\")\n", " ax.set_ylabel(\"Accuracy (%)\")\n", " ax.set_title(\"Trained Model Accuracy by CWE Type\")\n", " ax.set_ylim(0, 100)\n", " plt.xticks(rotation=45)\n", " plt.tight_layout()\n", " fig.savefig(os.path.join(REPO_DIR, \"plots\", \"per_cwe.png\"), dpi=150)\n", " plt.show()\n", " print(\"Saved plots/per_cwe.png\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cell 13 Cleanup\n", "\n", "Stop the env server and print final summary." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "server_proc.terminate()\n", "print(\"Env server stopped.\")\n", "\n", "print(\"\\n\" + \"=\"*50)\n", "print(\" TRAINING COMPLETE\")\n", "print(\"=\"*50)\n", "print(f\" Model: {MODEL_NAME}\")\n", "print(f\" Steps: {training_args.max_steps}\")\n", "print(f\" Accuracy: {baseline_acc:.1f}% {trained_acc:.1f}% (+{trained_acc - baseline_acc:.1f}pp)\")\n", "print(f\" Adapter: {FINAL_DIR}\")\n", "print(f\" Plots: plots/reward_curve.png, baseline_vs_trained.png, per_cwe.png\")\n", "\n", "print(\"\\nNext: copy outputs/ and plots/ back to your local machine.\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.13" } }, "nbformat": 4, "nbformat_minor": 4 }