Spaces: Running on A10G
Commit 1d3a7a5 · Deployment Build (v4): Debug Flush Logging
Parent(s): 0
This view is limited to 50 files because it contains too many changes. See raw diff.
- .agent/FUTURE_WORK.md +16 -0
- .agent/README.md +38 -0
- .agent/agent_instructions.md +69 -0
- .agent/architecture.md +149 -0
- .agent/checkpoints.md +57 -0
- .agent/coding_conventions.md +63 -0
- .agent/decision_log.md +40 -0
- .agent/git_workflow.md +85 -0
- .agent/project_context.md +82 -0
- .agent/test_contracts.md +48 -0
- .claude/settings.local.json +12 -0
- .dockerignore +17 -0
- .gitignore +36 -0
- .pre-commit-hooks.yaml +7 -0
- .vscode/settings.json +10 -0
- .vscode/tasks.json +16 -0
- Dockerfile +17 -0
- Dockerfile.train +56 -0
- GEMINI.md +55 -0
- README.md +186 -0
- README_SUBMISSION.md +64 -0
- __init__.py +0 -0
- action.yml +34 -0
- commitguard_env/__init__.py +8 -0
- commitguard_env/agent_prompt.py +68 -0
- commitguard_env/cli.py +131 -0
- commitguard_env/environment.py +173 -0
- commitguard_env/grpo_prompt.py +38 -0
- commitguard_env/hooks.py +50 -0
- commitguard_env/inference.py +86 -0
- commitguard_env/models.py +70 -0
- commitguard_env/parse_action.py +97 -0
- commitguard_env/reward.py +100 -0
- commitguard_env/scanner.py +54 -0
- commitguard_env/server.py +127 -0
- configs/openenv.yaml +4 -0
- data/cwe_keywords.json +11 -0
- data/devign_filtered.jsonl +0 -0
- data/devign_test.jsonl +0 -0
- data/devign_train.jsonl +0 -0
- gitlab-ci-template.yml +16 -0
- notebooks/train_commitguard.ipynb +604 -0
- pyproject.toml +48 -0
- pyrightconfig.json +16 -0
- scratch/extract_sample.py +24 -0
- scripts/README.md +7 -0
- scripts/__init__.py +1 -0
- scripts/check_cuda.py +6 -0
- scripts/check_disjoint.py +20 -0
- scripts/check_unsloth.py +13 -0
.agent/FUTURE_WORK.md
ADDED
@@ -0,0 +1,16 @@
<!--
If an agent is tempted to build something not in the current scope, append it here instead and continue with the locked task.

Source: ../prd.md §14 (Future Work). Do not execute these during the hackathon build unless explicitly re-scoped by the whole team (and documented).
-->

## Future Work (post-hackathon)

- **Sandboxed exploit execution**: replace the pattern-match reward with actual exploit runs against compiled code in a Docker sandbox
- **Multi-file commit reasoning**: extend the env to support diffs spanning multiple files, with a context budget
- **Self-play loop**: pair CommitGuard with a code-generation agent; defender and attacker train against each other (the AlphaGo pattern for security)
- **Agentic harness integration**: wire into real CI pipelines via the OpenEnv MCP layer, enabling commit-time security review at PR open
- **Real CVE corpus**: extend beyond Devign to recent CVE-tagged commits from major open-source repos
- **Multi-language support**: the current env is C-focused via Devign; extend to Python, JavaScript, Go
- **Reward shape ablations**: a formal study of how reward composition affects which vulnerability types the model learns fastest
.agent/README.md
ADDED
@@ -0,0 +1,38 @@
## What this folder is

`.agent/` is the **operating system for AI agents** on this repo. It locks the architecture decisions from `../prd.md`, prevents scope creep under deadline pressure, and makes sure three engineers can use Cursor / Claude Code in parallel without drifting.

If you're an agent: **load `project_context.md` first**. If you're a human: treat this folder like the team's constitution.

## Non-negotiable rule (scope freeze)

**Scope freeze is midnight Saturday (00:00 IST).** After that time:
- Do not add features, endpoints, model changes, UI, or nice-to-haves.
- Only do bug fixes, tests, wiring, docs, and reliability work that protects the locked deliverables.
- If you're tempted to add something: append it to `FUTURE_WORK.md` and continue the locked task.

## Files and what each enforces

- `project_context.md`: **Single source of truth**. The compressed PRD: what we're building, why, who for, the locked stack, the 30-second pitch, non-goals.
- `architecture.md`: **Technical contract**. File layout, dataclass schemas, XML action format, reward signature, observation schema, cheating prevention, required HTTP endpoints.
- `coding_conventions.md`: **How we write code**. Typed dataclasses, import order, errors, forbidden patterns, repo hygiene.
- `decision_log.md`: **Locked decisions + fallbacks**. PRD §7.1 in table form, PRD §7.2 fallback triggers. New decisions go here with timestamp + author.
- `agent_instructions.md`: **System prompt** for any coding agent. Read order, refusal rules, time-pressure behavior, fallback triggers.
- `checkpoints.md`: **Team sync contract** at midnight / 9 AM / 3 PM. What must be demoable; what triggers scope cuts; what gets cut first.
- `test_contracts.md`: **Blocking tests** required before merge: no-leak, reward cases, XML parser robustness, env smoke.
- `git_workflow.md`: **Parallel work rules**. Branch naming, commit conventions, merge gates, no-force-push rules, pre-submission checklist.
- `FUTURE_WORK.md`: **Parking lot** for anything not in current scope (pre-populated from PRD §14).

## Where the real spec lives

The authoritative PRD is `../prd.md`. If any `.agent/` file disagrees with the PRD, **the PRD wins** and you must update the `.agent/` file immediately.

## Task files (per person)

This repo expects per-person task lists:
- `../tasks_niti.md`
- `../tasks_deepak.md`
- `../tasks_divyank.md`

If they don't exist yet, create them now with 10-20 bullet tasks each and keep them updated. Agents should read the relevant one **after** `project_context.md` and `architecture.md`.
.agent/agent_instructions.md
ADDED
@@ -0,0 +1,69 @@
## System prompt for CommitGuard coding agents

You are an AI coding agent working on the **CommitGuard** hackathon repo.

Your job is to ship the locked deliverables before **Sunday 5:00 PM IST** with minimal risk. This is a **deadline game**, not a feature game.

### Read order (mandatory)

1. Read `.agent/project_context.md` (single source of truth).
2. Read `.agent/architecture.md` (technical contract).
3. Read `.agent/coding_conventions.md` (how we write code).
4. Read the relevant task list:
   - `tasks_niti.md` OR `tasks_deepak.md` OR `tasks_divyank.md`
   - If missing: create it with concrete bullets and continue.

Only then start coding.

### Scope control (hard refusal rule)

**Scope freeze is midnight Saturday (00:00 IST).** After that:
- Refuse any scope expansion, new features, new endpoints, new UI, new metrics.
- Only do: bug fixes, tests, wiring, packaging, docs, reliability.

If asked to add a feature:
- Do **not** implement it.
- Append it to `.agent/FUTURE_WORK.md` with a 1-line rationale.
- Continue the locked task.

### Architectural choices (don't guess)

If a decision is not covered by `.agent/architecture.md`:
- Ask for clarification (or check `../prd.md`).
- Do not invent new schemas or endpoints because they "seem right".

### Cheating prevention (highest-priority constraint)

The environment is RLVR: reward comes from dataset ground truth, but the agent must never see labels.

Rules:
- Observations must never contain ground truth (`is_vulnerable`, `cwe`, labels, "this is vulnerable" strings).
- The server must never return label fields in HTTP responses.
- Debug endpoints must never include ground truth.
- Always keep `test_no_leak.py` green.

### Time-pressure behavior (what good looks like)

Under deadline pressure:
- Prefer the simplest implementation that passes the contracts in `.agent/test_contracts.md`.
- Treat the fallbacks in `.agent/project_context.md` as pre-approved pivots; if triggered, pivot immediately and log it in `.agent/decision_log.md`.
- Avoid refactors unless they remove a clear blocker.

### Fallback triggers (execute immediately)

If any trigger fires, switch to the fallback with no debate:
- OOM on A10G → Qwen2.5-1.5B-Instruct
- HF Jobs queue >30 min → GCP A10G on-demand
- 3-action env not shipped by midnight → 2-action env
- Tiered reward buggy → binary reward only
- Curve flat at 10 AM Sunday → qualitative narrative
- Video recording fails twice → text trace in README

### CLI-first ops (HF + GCP)

Prefer repeatable CLI commands over UI clicks:
- HF Space + repos: use `huggingface-cli` / git
- GCP: use `gcloud`

Document any required commands in `README.md` or `scripts/`.
.agent/architecture.md
ADDED
@@ -0,0 +1,149 @@
## Architecture contract (do not improvise)

This is the technical contract for CommitGuard. If you're about to invent a new shape, don't. Either it's already here, or it belongs in `FUTURE_WORK.md`.

Authoritative source: `../prd.md` (§5-8).

## Repo layout (locked)

Target layout (names are contracts; adjust only if the repo already differs):

- `commitguard_env/`
  - `models.py`: typed dataclasses `Action`, `Observation`, `EnvState`, `GroundTruth`
  - `parse_action.py`: XML action parser (robust to malformed output)
  - `reward.py`: `compute_reward(...) -> float` (pure function)
  - `environment.py`: `CommitGuardEnvironment` implementing OpenEnv reset/step/state
  - `server.py`: FastAPI app exposing the OpenEnv HTTP endpoints
- `data/`
  - `devign_filtered.jsonl`: dataset embedded in the Docker image
  - `cwe_keywords.json`: top-10 CWE keyword map (for the exploit-sketch bonus)
- `tests/`: blocking tests listed in `test_contracts.md`
- `scripts/`: dataset preprocessing and ops scripts (CLI-first)
- `README.md`: story + links + how to run

If the codebase already has a different structure, keep the same semantics and update this file to match.

## Dataclass schemas (typed; no untyped dicts in public APIs)

All public shapes are typed dataclasses. Internal parsing may use dicts, but boundaries must be dataclasses.

### `Action`

- **Raw input**: `raw_action: str` (the model output)
- **Parsed**:
  - `action_type: Literal["request_context", "analyze", "verdict"]`
  - `fields: ActionFields` (typed union by `action_type`)

### `Observation` (cheating-prevention critical)

Must include only:
- `episode_id: str`
- `step_idx: int`
- `diff: str` (code_before/code_after diff or unified diff string)
- `repo_files: list[str]` (or `available_files`)
- `context_snippets: list[ContextSnippet]` (only if requested)
- `budget_remaining: int`
- `error: str | None` (for malformed actions, etc.)

Must **never** include:
- `is_vulnerable`, `label`, `ground_truth`, `cwe_type`, `target_file_with_label`
- anything that trivially implies the label (e.g., "this sample is vulnerable")

### `GroundTruth` (server-only)

Lives only on the server. Never serialized into observations.
- `is_vulnerable: bool`
- `cwe: str | None`
- `target_file: str`
- `exploit_keywords: list[str]` (or derived via the CWE map)

## Cheating-prevention rule (non-negotiable)

**Observation must never contain ground truth.** Reward is the only scalar feedback; it must not leak the label via strings or metadata.

Enforcement:
- the observation schema excludes forbidden fields
- `tests/test_no_leak.py` asserts that forbidden keys and suspicious strings never appear
- the server returns reward as a float only; it never returns label/cwe for debugging
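The enforcement idea behind `test_no_leak.py` reduces to one assertion helper; the deny-lists below are illustrative, and the real test file owns the authoritative ones:

```python
import json

# Illustrative deny-lists; the real test file owns the authoritative ones.
FORBIDDEN_KEYS = {"is_vulnerable", "label", "ground_truth", "cwe_type",
                  "target_file_with_label"}
FORBIDDEN_PHRASES = {"this sample is vulnerable", "this is vulnerable"}

def assert_no_leak(observation: dict) -> None:
    """Fail if a serialized observation leaks label information."""
    payload = json.dumps(observation).lower()
    for key in FORBIDDEN_KEYS:
        assert key not in payload, f"forbidden key leaked: {key}"
    for phrase in FORBIDDEN_PHRASES:
        assert phrase not in payload, f"suspicious phrase leaked: {phrase}"
```

Scanning the serialized payload (not just top-level keys) also catches leaks nested inside `info` dicts or snippet text.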

## Episode contract

- Max **5 steps** per episode.
- An episode ends when a `verdict` is received OR the budget hits zero.
- `request_context` consumes budget and carries a per-step penalty.
- `analyze` is allowed, logged, and should not affect reward directly.

## Reward function (signature + invariants)

Reward is RLVR: computed from ground truth and simple keyword checks, **not** by an LLM judge.

Signature:

```python
def compute_reward(
    action: "Action",
    ground_truth: "GroundTruth",
    *,
    cwe_keywords: dict[str, list[str]],
    context_requests: int,
) -> float: ...
```

Reward shape (from the PRD):
- correct vulnerable/safe verdict: **+1.0**
- correct CWE (when vulnerable): **+0.5**
- plausible exploit sketch (keyword match): **+0.5**
- false positive: **-1.0**
- false negative: **-0.5**
- per context request: **-0.05**
- malformed action: penalize (recommended **-0.5**) but do not crash

## XML action format (the model-output contract)

The model outputs exactly one top-level `<action>` block. The parser must tolerate:
- extra whitespace
- missing fields (treated as malformed)
- wrong casing (normalize)
- stray text before/after the tags
- malformed XML (best-effort extraction; never crash)

### Spec

Top-level:
- `<action>`
- `<action_type>request_context|analyze|verdict</action_type>`
- `<fields>...</fields>`
- `</action>`

Fields by type:

**request_context**
- `<file_path>path/in/repo.ext</file_path>`
- optional: `<start_line>int</start_line>`, `<end_line>int</end_line>`

**analyze**
- `<reasoning>free text</reasoning>`

**verdict**
- `<is_vulnerable>true|false</is_vulnerable>`
- `<vuln_type>CWE-79|CWE-89|...|NONE</vuln_type>`
- `<exploit_sketch>free text</exploit_sketch>`

Parsing rules:
- if `action_type` is missing/invalid → malformed
- booleans accept `true/false/1/0/yes/no` (case-insensitive)
- `vuln_type` is normalized; a safe verdict may use `NONE`
- on malformed input: return a safe `Action` with `action_type="analyze"` and `error` set, and apply the malformed penalty

## Env server HTTP endpoints (P0)

The env server must expose these endpoints (names from PRD §8.1):

- `GET /health` → 200 OK with a simple JSON payload
- `POST /reset` → returns the initial `Observation` (+ episode id)
- `POST /step` → accepts a raw action string, returns `{observation, reward, done, info}`
- `GET /state` → returns minimal server/env state for debugging (no ground truth)
- `GET /docs` → FastAPI OpenAPI docs (automatic)

Do not add new endpoints after scope freeze unless required for reliability.
.agent/checkpoints.md
ADDED
@@ -0,0 +1,57 @@
## Checkpoints (sync-or-die contract)

Goal: keep three engineers aligned and prevent "cool demo" scope creep from killing the submission. Source: `../prd.md` §12.

### Checkpoint 1 (midnight, 00:00 IST): scope freeze + Phase 1 gate

**Everyone must demonstrate (live, locally or on the Space):**
- **Env server runs** and responds to `GET /health`
- **OpenEnv loop works**: `reset` → `step` → done, without crashing
- **Action parser is robust**: malformed XML doesn't crash; it returns a safe error
- **No-leak invariant**: the observation contains no ground-truth fields

**Role deliverables:**
- **Env/Server owner**: endpoints exist (`/health`, `/reset`, `/step`, `/state`, `/docs`)
- **Reward owner**: reward function wired and deterministic on handcrafted cases
- **Training owner**: mock training loop can call the env repeatedly (even if reward is a dummy)

**If any of these are red, trigger a scope cut immediately:**
- 3-action env incomplete → cut to the 2-action env (analyze + verdict)
- Tiered reward unstable → cut to binary reward only

**After this checkpoint:**
- **Scope freeze is active.** New features go to `.agent/FUTURE_WORK.md` only.

### Checkpoint 2 (9:00 AM Sunday): training evidence gate

**Everyone must demonstrate:**
- Training run launched (HF Jobs A10G preferred) or the fallback running
- wandb logging works (reward curve visible)
- Evaluation script/notebook can run 100 held-out samples

**Scope-cut triggers:**
- Training blocked by infra >30 min → move to the GCP A10G fallback
- Training curve still flat by 10:00 AM → commit to the qualitative narrative (no more training tweaks)

**What gets cut first (in order):**
1. P2 items (web UI polish, blog post)
2. Per-CWE breakdown (keep overall accuracy)
3. Exploit sketch bonus (keep binary + CWE if stable)
4. CWE classification bonus (keep binary only)

### Checkpoint 3 (3:00 PM Sunday): feature freeze gate

**Everyone must demonstrate:**
- HF Space is live and stable; `/health` returns 200; `/docs` loads
- `tests/` pass (see `.agent/test_contracts.md`)
- The demo artifact path is locked (video or text-trace fallback)
- README has all submission links (Space, notebook, video, wandb, repo)

**Hard rule:**
- **No changes after 3:00 PM** except emergency fixes that prevent submission failure.

**Final scope cuts (if needed to protect the submission):**
1. Video → text trace in README
2. Training curve → single plot + narrative
3. Held-out eval → small-N sanity check
.agent/coding_conventions.md
ADDED
@@ -0,0 +1,63 @@
## Coding conventions (enforced under deadline pressure)

This repo is optimized for: **correctness, reproducibility, and not leaking labels**. Read `architecture.md` first.

## Python style (hard rules)

- **Typed dataclasses everywhere** for public API shapes (actions/observations/state).
- Use `@dataclass(frozen=True, slots=True)` by default.
- Public functions must be type-annotated end-to-end.
- **No untyped dicts in public APIs.** Dicts are allowed only internally (e.g., during XML parsing), and must be converted to dataclasses at the boundary.
- Keep functions small. Prefer pure functions (`reward.py`) with no hidden state.

## Import ordering

1. stdlib
2. third-party
3. local modules

Within a section: alphabetical. One import per line if it improves diff clarity.

## Docstrings and naming

- Docstrings: short, imperative, and stating constraints (e.g., "must not leak ground truth").
- Names: explicit over clever (`compute_reward`, `parse_action_xml`, `EpisodeState`).

## Error handling patterns

- **Never crash on model output.** Malformed actions must be handled gracefully.
- Raise exceptions only for programmer errors; user/model errors return structured error fields.
- Every boundary (HTTP handlers, XML parser) must be defensive:
  - validate inputs
  - clamp budgets
  - return safe defaults

## Forbidden patterns (do not do these)

- **No LLM-as-judge in reward.** Reward must be verifiable (dataset truth + keyword checks). See `architecture.md`.
- **No label leakage**: do not log, return, or print ground truth in observations, HTTP responses, or debug endpoints.
- **No hardcoded local paths** (e.g., `C:\\Users\\...`, `/home/...`). Use repo-relative paths + `pathlib`.
- **No committing data files > 5 MB** without explicit team sign-off. (If necessary, use HF Datasets or remote storage.)
- **No localStorage in any UI.** If you add UI later (unlikely), store state server-side or in-memory only.
- **No adding endpoints/features after scope freeze** (midnight Saturday).

## Repo hygiene

- Prefer CLI-driven ops so teammates can reproduce quickly:
  - HF: `huggingface-cli`, `hf` (where available), `git lfs` if needed
  - GCP: `gcloud`
- Keep logs minimal. Under hackathon pressure, noisy logs hide real bugs.
- Don't vendor big artifacts in git. Link them (video, wandb, Space) from the README.

## Scope creep rule (non-negotiable)

If you're tempted to add a feature that isn't required for the locked deliverables:
- Append one bullet to `FUTURE_WORK.md` (with a 1-line rationale).
- Return to your current task.

## Cross-reference

- Architecture contract: `architecture.md`
- Scope and fallbacks: `project_context.md`
- Locked decisions: `decision_log.md`
.agent/decision_log.md
ADDED
@@ -0,0 +1,40 @@
## Decision log (locked + fallbacks)

This file is a **contract**. It mirrors `../prd.md` §7.1 and §7.2.

If you want to change a decision: you don't. If you must due to a trigger, use the fallback and log it.

## Locked technical decisions (PRD §7.1)

| Decision | Choice | Rationale |
|---|---|---|
| Env framework | Meta OpenEnv 0.2.3+ | Mandatory per submission rules |
| Server runtime | FastAPI in Docker | OpenEnv default, lowest friction |
| Hosting | Hugging Face Space | Mandatory; server + repo + registry |
| Data source | Devign (DetectBERT subset) | Real CWE labels, manageable size |
| Model | Llama-3.2-3B-Instruct | Meta-branded; fits A10G with GRPO |
| Training framework | TRL with GRPO | Native OpenEnv integration via reward funcs |
| Training optimization | Unsloth 4-bit + LoRA r=8 | Big memory reduction + speed |
| Training infra | HF Jobs A10G | Unattended, HF-native |
| Dev infra | GCP VM with T4 | Stable, no Colab disconnects |
| Action serialization | XML-tag free-text | Robust to small-model variance |
| Logging | Weights & Biases | TRL-native; shareable runs |

## Pre-approved fallback rules (PRD §7.2)

| If this fails | Fall back to | Trigger condition |
|---|---|---|
| Llama-3.2-3B OOM on A10G | Qwen2.5-1.5B-Instruct | First test step crashes |
| HF Jobs queue full | GCP A10G on-demand | Job queues for >30 min |
| 3-action env doesn't ship by midnight | 2-action env (analyze + verdict) | Midnight checkpoint is red |
| Tiered reward buggy | Binary correct/incorrect reward | Reward checkpoint is red |
| Training curve flat | Qualitative comparison only | Still flat at 10 AM Sunday |
| Demo video hard to record | Side-by-side text trace in README | Recording fails twice |

## New decisions made during the build

Rule: any new decision must be logged here with timestamp + author and must not violate the locked PRD unless it is a PRD-defined fallback.

Template:
- **[YYYY-MM-DD HH:MM IST] (author)**: decision → rationale → impact → rollback plan
.agent/git_workflow.md
ADDED
@@ -0,0 +1,85 @@
## Git workflow (parallel, safe, deadline-optimized)

This repo will have three engineers working in parallel with agents. The workflow exists to prevent integration chaos.

## Branch naming (required)

Format: `<name>/<short-scope>`

Examples:
- `niti/env-scaffolding`
- `deepak/data-pipeline`
- `divyank/training-grpo`

Rules:
- One scope per branch.
- If a branch grows beyond 2-3 related commits, cut scope or split.

## Commit message convention (required)

Use **Conventional Commits**:

- `feat(env): add OpenEnv reset/step`
- `fix(parser): handle malformed xml without crash`
- `test(reward): add 5 handcrafted cases`
- `docs(readme): add demo + wandb links`

Rules:
- Short subject, present tense.
- Prefer "why" over "what" in the body.

## Merge policy (hard rules)

- Merge to `main` **only after** the relevant tests pass locally:
  - Env changes: `test_no_leak.py`, `test_env_smoke.py`, `test_action_parser.py`
  - Reward changes: `test_reward.py` + `test_no_leak.py`
  - Parser changes: `test_action_parser.py` + `test_env_smoke.py`
- No "merge now, fix later". Under deadline, a broken `main` is a team-wide blocker.

## Force-push rules

- Before midnight Saturday: allowed on your feature branches if necessary.
- **After midnight Saturday: no force-push to `main` (ever).**
- Prefer no force-push at all; use revert commits if needed.

## PR expectations (fast reviews)

Each PR must include:
- a 1-3 sentence summary
- a test plan (what you ran)
- a risk note (what could break)

If it's large, it's wrong: split it.

## Pre-submission checklist (Sunday)

By 3 PM:
- [ ] HF Space live; `/health` returns 200; `/docs` loads
- [ ] Blocking tests pass (`.agent/test_contracts.md`)
- [ ] Training artifact exists (plots + wandb link)
- [ ] Demo artifact exists (video URL or text-trace fallback)
- [ ] README links all resolve (Space, notebook, video, wandb, repo)

By 4:30 PM:
- [ ] Fresh clone + run instructions work
- [ ] Final smoke test: 100 episodes don't crash
- [ ] Submission package is complete

## CLI-first ops (HF + GCP)

Keep ops repeatable. Prefer CLI over UI clicks.

Hugging Face:
- `huggingface-cli login`
- `huggingface-cli whoami`
- Use the git-based Space workflow (clone, commit, push) for deploys.

GCP:
- `gcloud auth login`
- `gcloud config set project <PROJECT_ID>`
- Use `gcloud compute ssh` + `gcloud compute instances list` for the VM workflow.

Cross-reference:
- Merge gates: `test_contracts.md`
- Scope freeze + fallbacks: `project_context.md`
.agent/project_context.md
ADDED
@@ -0,0 +1,82 @@
## CommitGuard: project context (load this first)

This file is the **single source of truth for agents**. It compresses `../prd.md` into must-know facts so you can make correct decisions at 3 AM.

If you're unsure: re-read `../prd.md` and then update this file to match.

## What we're building

**CommitGuard** is a **Meta OpenEnv** reinforcement learning environment where an LLM agent learns to detect exploitable vulnerabilities in **code commits** (single-file diffs) and output a vulnerability verdict + CWE type + exploit sketch.

The environment runs as an **HTTP server (FastAPI in Docker)**, hosted on **Hugging Face Spaces**. Training runs with **TRL GRPO + Unsloth** on **Llama-3.2-3B-Instruct**, using verifiable rewards from dataset ground truth (RLVR).

## Why this matters (the thesis)

AI writes code at AI speed. Security review still runs on human cycles. Offense can now scale with the same LLM tooling. **We're building the RL environment that trains AI-paced commit-time security review.**

## Who it's for

- **Hackathon judges / Meta partner engineers**: want innovation + evidence (learning curve) + a clean story.
- **Meta researchers**: want RLVR framing, cheating prevention, and extensibility.
- **HF community**: wants a runnable Space + a reproducible training notebook.
|
| 22 |
+
|
| 23 |
+
## 30-second pitch (verbatim; memorize)
|
| 24 |
+
|
| 25 |
+
> "AI is now writing production code at AI speed. Security review still runs on a 6-month human cycle. The same LLMs that write the code can attack it defense is on human time, offense is on AI time, and that asymmetry breaks the security model.
|
| 26 |
+
>
|
| 27 |
+
> CommitGuard is an OpenEnv where an agent learns to flag exploitable diffs at commit time. We trained Llama-3.2-3B on it via GRPO and the detection rate climbs measurably. It's RLVR verifiable rewards from ground truth, not LLM judges. The thesis: continuous AI red-teaming at the velocity code is being shipped. This is the environment to train it."
|
| 28 |
+
|
| 29 |
+
## Locked stack (do not change)
|
| 30 |
+
|
| 31 |
+
- **Env framework**: Meta OpenEnv **0.2.3+**
|
| 32 |
+
- **Server**: **FastAPI** in **Docker**
|
| 33 |
+
- **Hosting**: **Hugging Face Space**
|
| 34 |
+
- **Data**: **Devign** (Devign/DetectBERT subset); filtered to single-file commits <80 LOC; ~balanced
|
| 35 |
+
- **Model**: **Llama3.23BInstruct**
|
| 36 |
+
- **Training**: **TRL** with **GRPO**
|
| 37 |
+
- **Optimization**: **Unsloth** 4bit + **LoRA r=8**
|
| 38 |
+
- **Infra**: **HF Jobs A10G** for training; **GCP VM with T4** for dev/stability
|
| 39 |
+
- **Action serialization**: **XML-tag free-text** (not JSON-mode)
|
| 40 |
+
- **Logging**: **Weights & Biases**
|
| 41 |
+
|
| 42 |
+
Operational preference: **use CLI** for HF + GCP actions (repeatable, copy/paste-able, no UI-clicking).
|
| 43 |
+
|
| 44 |
+
## Submission deliverables (P0)
|
| 45 |
+
|
| 46 |
+
- **HF Space** deployed; `/health` returns 200; `/docs` works
|
| 47 |
+
- **Training notebook / script** produces a measurable learning curve (or triggers fallback)
|
| 48 |
+
- **Plots** committed (reward curve + baseline vs trained)
|
| 49 |
+
- **Demo video** (6090s) showing before/after behavior on one example
|
| 50 |
+
- **README** with all required links (Space, notebook, video, repo, wandb)
|
| 51 |
+
|
| 52 |
+
## Hard constraints (time + scope)
|
| 53 |
+
|
| 54 |
+
- **Deadline**: Sunday **5:00 PM IST** (non-negotiable)
|
| 55 |
+
- **Scope freeze**: **midnight Saturday (00:00 IST)** after this, no new features
|
| 56 |
+
- **Episode constraints**: max **5 steps** per episode; context requests cost reward
|
| 57 |
+
|
| 58 |
+
## Explicit non-goals (do not drift)
|
| 59 |
+
|
| 60 |
+
- Not a production CI security tool; **research environment only**
|
| 61 |
+
- No real exploit execution sandbox in v1 (pattern match only)
|
| 62 |
+
- No multi-file / repo-level reasoning in v1 (single-file commits, <=80 LOC)
|
| 63 |
+
- No multi-agent self-play in v1
|
| 64 |
+
- No network/runtime attacks, no social engineering
|
| 65 |
+
- No cover all CWEs: v1 focuses on **top 10 CWEs** in Devign
|
| 66 |
+
- No fancy frontend: HF Space default UI is enough
|
| 67 |
+
|
| 68 |
+
## If something breaks: pre-approved fallbacks (no debate)
|
| 69 |
+
|
| 70 |
+
These are legal pivots from `../prd.md` 7.2. If trigger happens, switch immediately and log it in `decision_log.md`.
|
| 71 |
+
|
| 72 |
+
- **OOM on Llama3.23B on A10G** use **Qwen2.51.5BInstruct** (trigger: first test step crashes)
|
| 73 |
+
- **HF Jobs queue > 30 min** use **GCP A10G on-demand**
|
| 74 |
+
- **3-action env not shipped by midnight** ship **2-action env** (analyze + verdict)
|
| 75 |
+
- **Tiered reward buggy** ship **binary reward only**
|
| 76 |
+
- **Training curve still flat at 10 AM Sunday** ship **qualitative comparison narrative**
|
| 77 |
+
- **Demo video recording fails twice** ship **side-by-side text trace in README**
|
| 78 |
+
|
| 79 |
+
## Next file to read
|
| 80 |
+
|
| 81 |
+
Read `architecture.md` next. Then read your per-person task list (e.g. `../tasks_niti.md`) if present.
|
| 82 |
+
|
.agent/test_contracts.md
ADDED
@@ -0,0 +1,48 @@

## Test contracts (merge blockers)

These tests are **merge gates**. If any fails, do not merge to `main`. See `git_workflow.md`.

Owners are initial; if you touch the area, you own the test too.

### `tests/test_no_leak.py`

- **Asserts**:
  - `Observation` serialization never includes ground-truth fields (e.g., `is_vulnerable`, `ground_truth`, `label`, `cwe_type`).
  - Response payloads from `/reset` and `/step` do not contain forbidden keys or suspicious strings that imply labels.
- **Owner**: Niti (env integrity)
- **Blocking condition**: Any leakage is a submission-killer. Must be fixed immediately.

### `tests/test_reward.py`

- **Asserts**: `compute_reward(...)` returns expected values for **5 handcrafted cases**:
  1. True positive + correct CWE + exploit match
  2. True positive + wrong CWE
  3. False positive
  4. False negative
  5. Malformed action penalty (and no crash)
- **Owner**: Deepak (reward design)
- **Blocking condition**: If the tiered reward is flaky, trigger the fallback to binary reward (log in `decision_log.md`).

### `tests/test_action_parser.py`

- **Asserts**:
  - XML action parsing works for all 3 action types.
  - The parser is robust to malformed inputs (missing tags, invalid XML, extra text).
  - The parser never throws; it returns a safe Action + error info.
- **Owner**: Divyank (agent I/O contract)
- **Blocking condition**: Any parser crash blocks training and demo; fix before anything else.

### `tests/test_env_smoke.py`

- **Asserts**:
  - 100 random episodes do not crash.
  - `reset`/`step` latency stays reasonable and the budget cap terminates episodes.
  - Malformed actions do not crash and return done when appropriate.
- **Owner**: Niti (env reliability)
- **Blocking condition**: If the smoke test fails, training is not allowed to run.

## Required behavior under failure

- If a test reveals a scope-level failure, use a PRD-approved fallback (see `project_context.md`) rather than inventing new features.
- If a failure requires a new decision, log it in `decision_log.md` with timestamp + author.
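The five reward-contract cases above can be made concrete. This is a minimal stand-in, not the real `compute_reward` from `commitguard_env/reward.py` — the signature and exact tier handling here are assumptions that reimplement the documented values so the contract is checkable:

```python
def compute_reward(
    predicted_vuln: bool,
    true_vuln: bool,
    cwe_match: bool = False,
    exploit_match: bool = False,
    malformed: bool = False,
) -> float:
    """Tiered reward sketch: verdict correctness plus CWE/exploit bonuses."""
    if malformed:
        return -0.5                  # malformed action penalty, never a crash
    if predicted_vuln and true_vuln:
        reward = 1.0                 # true positive
        if cwe_match:
            reward += 0.5            # correct CWE classification
        if exploit_match:
            reward += 0.5            # plausible exploit sketch
        return reward
    if not predicted_vuln and not true_vuln:
        return 1.0                   # true negative
    if predicted_vuln and not true_vuln:
        return -1.0                  # false positive
    return -0.5                      # false negative
```

If the tiered path proves flaky, the documented fallback drops the CWE and exploit bonuses and keeps only the binary verdict term.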
.claude/settings.local.json
ADDED
@@ -0,0 +1,12 @@

{
  "permissions": {
    "allow": [
      "Bash(python -m pip install -e .)",
      "Bash(python *)",
      "Bash(pip install *)",
      "Bash(.venv/Scripts/pip install *)",
      "Bash(.venv/Scripts/python.exe *)",
      "Bash(grep -v \"^d.*\\\\.\\\\|^total\\\\|^$\")"
    ]
  }
}
.dockerignore
ADDED
@@ -0,0 +1,17 @@

__pycache__/
*.py[cod]
.pytest_cache/
.mypy_cache/
.ruff_cache/
.venv/
venv/
ENV/
.uv-cache/
wandb/
outputs/
temp_deploy/
temp_space/
temp_write_probe/
temp_pip_*/
*.log
.git/
.gitignore
ADDED
@@ -0,0 +1,36 @@

__pycache__/
*.py[cod]
*.pyd
.pytest_cache/
.mypy_cache/
.ruff_cache/

.venv/
venv/
ENV/
.uv-cache/

build/
dist/
*.egg-info/
commitguard.egg-info/

.DS_Store

# Local tooling / logs
wandb/
*.log
outputs/

# IDE
.vscode/
.idea/

# Temporary
*.tmp
temp_space/
temp_deploy/
temp_pip_*/
temp_write_probe/
unsloth_compiled_cache/
.venv-check/
.pre-commit-hooks.yaml
ADDED
@@ -0,0 +1,7 @@

- id: commitguard
  name: CommitGuard vulnerability scan
  entry: commitguard scan --staged --format text --fail-on-vulnerable
  language: python
  stages: [pre-commit]
  pass_filenames: false
  additional_dependencies: ["commitguard[scan]"]
.vscode/settings.json
ADDED
@@ -0,0 +1,10 @@

{
  "python.analysis.extraPaths": [
    "${workspaceFolder}",
    "${workspaceFolder}/scripts"
  ],
  "python.autoComplete.extraPaths": [
    "${workspaceFolder}",
    "${workspaceFolder}/scripts"
  ]
}
.vscode/tasks.json
ADDED
@@ -0,0 +1,16 @@

{
  "version": "2.0.0",
  "tasks": [
    {
      "label": "CommitGuard: Scan Staged Changes",
      "type": "shell",
      "command": "commitguard scan --staged --format text",
      "problemMatcher": [],
      "presentation": {
        "reveal": "always",
        "panel": "new"
      },
      "group": "test"
    }
  ]
}
Dockerfile
ADDED
@@ -0,0 +1,17 @@

FROM python:3.12-slim

WORKDIR /app

ENV PYTHONUNBUFFERED=1

COPY pyproject.toml README.md ./
COPY commitguard_env/ commitguard_env/
COPY data/ data/
COPY configs/ configs/
COPY server/ server/

RUN pip install -e .

EXPOSE 7860

CMD ["uvicorn", "commitguard_env.server:app", "--host", "0.0.0.0", "--port", "7860"]
Dockerfile.train
ADDED
@@ -0,0 +1,56 @@

# Use CUDA 12.1 base image
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04

# Avoid prompts
ENV DEBIAN_FRONTEND=noninteractive

# Install Python 3.11 and other essentials
RUN apt-get update && apt-get install -y \
    python3.11 \
    python3-pip \
    python3.11-dev \
    git \
    && rm -rf /var/lib/apt/lists/*

# Set python3.11 as the default python
RUN ln -s /usr/bin/python3.11 /usr/bin/python

WORKDIR /app

# Upgrade pip
RUN pip install --no-cache-dir -U pip setuptools wheel

# Install PyTorch with CUDA 12.1 support
RUN pip install --no-cache-dir \
    torch==2.4.0 \
    triton \
    xformers \
    --index-url https://download.pytorch.org/whl/cu121

# Install Unsloth and let it resolve its own compatible TRL/PEFT stack.
RUN pip install --no-cache-dir \
    "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" \
    datasets \
    wandb \
    matplotlib \
    fastapi \
    uvicorn \
    pydantic

# Copy the project files
COPY . .

# Install the local package in editable mode
RUN pip install -e .

# Make scripts executable
RUN chmod +x scripts/*.py

# Set environment variables
ENV MODEL_NAME="meta-llama/Llama-3.2-3B-Instruct"
ENV OUTPUT_DIR="outputs/commitguard-llama-3b-grpo"
ENV WANDB_PROJECT="commitguard"

# Default command: run training and push to the Hub
# Note: HF_TOKEN and WANDB_API_KEY should be set as Space Secrets
CMD ["python", "scripts/train_grpo.py", "--samples", "200", "--max-steps", "300", "--push-to-hub"]
GEMINI.md
ADDED
@@ -0,0 +1,55 @@

# CommitGuard - Project Context & Instructions

This file is the **foundational mandate** for the CommitGuard project. It defines the technical standards, security protocols, and operational workflows that must be followed by all agents.

## 🚀 Project Overview
CommitGuard is a specialized RL environment built on **Meta OpenEnv** for commit-time vulnerability detection. It trains LLM agents (primarily **Llama-3.2-3B-Instruct**) to identify exploitable vulnerabilities in single-file code commits using **Reinforcement Learning from Verifiable Rewards (RLVR)**.

- **Objective:** Bridge the gap between AI-speed code generation and human-paced security review.
- **Framework:** Meta OpenEnv (v0.2.3+).
- **Incentive:** Tiered rewards grounded in dataset truth (Devign), not LLM judgment.

## 📐 Engineering Standards (Non-Negotiable)

### 1. The "No-Leak" Rule (Highest Priority)
The agent must **NEVER** see ground-truth labels (`is_vulnerable`, `cwe`, etc.) during an episode.
- **Constraint:** `CommitGuardObservation` and all reward calculations must be stripped of label fields before being presented to the model.
- **Validation:** `tests/test_no_leak.py` must remain green. Any change that causes a leak is a blocking failure.

### 2. Python Architecture
- **Typed Dataclasses:** Use `@dataclass(frozen=True, slots=True)` for all API shapes (Actions, Observations, State).
- **Strict Typing:** Every function and variable must be type-annotated end-to-end.
- **No Untyped Dicts:** Dicts are for internal parsing only; convert to dataclasses at all boundaries.
- **Defensive Parsing:** XML parsers must handle malformed model output without crashing, returning safe defaults and structured errors.

### 3. XML Action Format
Models must emit exactly one top-level `<action>` block to ensure robust parsing.
- **Structure:** `<action><action_type>...</action_type><fields>...</fields></action>`
- **Types:** `request_context`, `analyze`, `verdict`.

## 🛠️ Operational Workflows

### 1. Evaluation Pipeline (`scripts/evaluate.py`)
This script runs local inference on test samples to compute accuracy metrics.
- **Deterministic Selection:** It iterates through `data/devign_test.jsonl`.
- **Strict Scoring:** `is_correct` requires both a correct binary verdict AND a correct CWE type match (if vulnerable).
- **Inference:** Uses Unsloth/FastLanguageModel for accelerated evaluation.

### 2. Training Pipeline (`scripts/train_grpo.py`)
- **Framework:** Uses TRL's `GRPOTrainer` with Unsloth 4-bit quantization.
- **Local Rewards:** Reward functions are computed in-process (`get_reward_local`) to eliminate latency.

### 3. Visualization (`plots/`)
- `plot_reward_curve.py`: Visualizes reward trends from `eval_results.json`.
- `plot_per_cwe.py`: Generates bar charts showing accuracy breakdown by CWE category.
- `plot_baseline_vs_trained.py`: Compares untrained vs. trained model performance.

## 📁 Critical Files
- `commitguard_env/`: Core logic (environment, reward model, XML parser).
- `data/`: `devign_filtered.jsonl` (training) and `devign_test.jsonl` (testing).
- `scripts/`: Training, evaluation, and environment setup runbooks (GCP/Lightning).
- `.agent/`: Internal state, technical contracts, and hackathon milestones.

## ⏳ Hackathon Mandate
- **Scope Freeze:** No new features after midnight Saturday IST. Focus strictly on reliability, documentation, and evaluation.
- **Fallback Triggers:** If OOM or performance blockers occur, pivot immediately to documented fallbacks (e.g., Qwen-1.5B) and log the decision in `.agent/decision_log.md`.
README.md
ADDED
@@ -0,0 +1,186 @@

---
title: CommitGuard
emoji: 🛡️
colorFrom: indigo
colorTo: red
sdk: docker
pinned: false
---

# CommitGuard

CommitGuard is an OpenEnv environment for **AI-paced professional security review**. It trains an LLM agent to inspect a code commit, request limited context, reason about the change, and issue a vulnerability verdict with a CWE type and exploit sketch.

Primary hackathon theme: **Theme #3.1 - World Modeling / Professional Tasks**.
Secondary theme: **Theme #2 - Long-Horizon Planning & Instruction Following**.

## Problem

AI coding agents now write and ship code much faster than traditional security review cycles can handle. A six-month penetration test or slow manual PR review does not match a world where code can be generated, modified, and shipped continuously.

CommitGuard turns commit-time security review into a trainable environment: the agent sees a partially observable code diff, spends a limited investigation budget, and earns verifiable rewards for correctly identifying vulnerabilities.

## Environment

Each episode is a single commit-level investigation.

1. `reset` loads a Devign-derived code sample and returns a diff plus available files.
2. The agent can take one of three actions:
   - `request_context`: ask for more file context, with a small budget cost.
   - `analyze`: write intermediate reasoning for traceability.
   - `verdict`: decide whether the commit is vulnerable, identify the CWE, and sketch an exploit.
3. `step` returns the next observation, scalar reward, and done flag.
4. `state` returns episode metadata without leaking labels.

The agent never sees ground-truth labels. Ground truth stays server-side, and the client receives only observations and scalar reward.

## Reward

CommitGuard uses dataset-grounded RLVR-style rewards, not an LLM judge.

| Signal | Reward |
|---|---:|
| Correct vulnerable/safe verdict | +1.0 |
| Correct CWE classification | up to +0.5 |
| Plausible exploit sketch keyword match | up to +0.5 |
| False positive | -1.0 |
| False negative | -0.5 |
| Extra context requests | -0.05 each after the first |
| Malformed action | -0.5 |

This makes the task harder than static classification: the agent must manage its investigation budget and produce structured, parseable actions.

Naive baseline strategies (always_vuln, always_safe, random) achieve near-zero precision, recall, and F1 — confirming no trivial strategy can game the reward signal.



## Results

We evaluated a baseline against the trained agent on 100 held-out samples.

| Run | Correct | Accuracy |
|---|---:|---:|
| Baseline | 50 / 100 | 50% |
| Trained | 74 / 100 | 74% |

Cumulative mean reward across 500 episodes shows all naive strategies (always_vuln, always_safe, random) plateau at low reward, while the trained agent learns to do better.



The trained agent improves over the baseline on held-out commit-level vulnerability detection.

Per-CWE accuracy shows the trained agent outperforms the baseline across all four vulnerability families (CWE-89, CWE-119, CWE-79, CWE-20).



## Training

The judge-runnable training path is the Colab-ready notebook:

- [Training notebook](notebooks/train_commitguard.ipynb)

The script path is also available:

```bash
python scripts/train_grpo.py \
  --env-url https://nitishkumar-ai-commitguard-env.hf.space \
  --samples 200 \
  --max-steps 300 \
  --num-generations 4 \
  --batch-size 1 \
  --grad-accum 4
```

If `--env-url` or `COMMITGUARD_ENV_URL` is set, the training script scores completions through the running CommitGuard environment. Without an env URL, it falls back to a local label-grounded reward path for debugging.

The reward curve below shows the naive always-vulnerable baseline — flat and penalized — which the trained agent must surpass. Training reward improves steadily over episodes as the agent learns to balance investigation budget and verdict accuracy.





## Links

- **Hugging Face Space:** [Nitishkumar-ai/commitguard-env](https://huggingface.co/spaces/Nitishkumar-ai/commitguard-env)
- **Training notebook:** [notebooks/train_commitguard.ipynb](notebooks/train_commitguard.ipynb)
- **Mini-blog / short writeup:** [commitguard_hf_blog.md](commitguard_hf_blog.md)
- **Trained model target:** [inmodel-labs/commitguard-llama-3b](https://huggingface.co/inmodel-labs/commitguard-llama-3b)
- **GCE training runbook:** [scripts/gce_vm_runbook.md](scripts/gce_vm_runbook.md)

## Project Structure

```text
commitguard/
├── commitguard_env/   # Core logic (environment, server, model)
├── docs/              # Detailed documentation and guides
├── data/              # Devign-derived datasets
├── scripts/           # Training and evaluation entrypoints
├── results/           # Evaluation artifacts and JSON reports
├── notebooks/         # Interactive training notebooks
├── plots/             # Visualization artifacts
├── tests/             # Comprehensive test suite
└── configs/           # Configuration files
```

## Quickstart

Install locally:

```bash
python -m pip install -e ".[dev]"
server
```

Health check:

```bash
curl http://localhost:8000/health
```

Run with Docker:

```bash
docker build -t commitguard .
docker run -p 7860:7860 commitguard
curl http://localhost:7860/health
```

## API

- `GET /health`
- `POST /reset`
- `POST /step`
- `GET /state`
- `GET /docs`

Example action:

```xml
<action>
  <action_type>verdict</action_type>
  <is_vulnerable>true</is_vulnerable>
  <vuln_type>CWE-119</vuln_type>
  <exploit_sketch>unchecked buffer copy can overflow the destination</exploit_sketch>
</action>
```

## Validation

Before submission:

```bash
pytest tests/test_action_parser.py
pytest tests/test_reward.py
pytest tests/test_no_leak.py
pytest tests/test_env_smoke.py
```

Also smoke-test the public Space:

```bash
curl https://nitishkumar-ai-commitguard-env.hf.space/health
```

## Scope

This submission intentionally stays on the locked v1 architecture: three actions, server-side dataset-grounded rewards, and no sandbox execution. Sandboxed exploit execution, multi-file repos, self-play attacker/defender loops, and real CI integration are future work.
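The `/reset` + `/step` episode flow described in the README's API section can be driven with a small stdlib client. The JSON field names below (`action`, `reward`) are assumptions about the schema; check `/docs` on the running Space for the real request/response shapes:

```python
import json
import urllib.request

def make_verdict(is_vulnerable: bool, vuln_type: str, sketch: str) -> str:
    """Build the XML action string the environment expects."""
    return (
        "<action>"
        "<action_type>verdict</action_type>"
        f"<is_vulnerable>{'true' if is_vulnerable else 'false'}</is_vulnerable>"
        f"<vuln_type>{vuln_type}</vuln_type>"
        f"<exploit_sketch>{sketch}</exploit_sketch>"
        "</action>"
    )

def post(base_url: str, path: str, payload: dict) -> dict:
    """POST JSON to the environment server and decode the JSON reply."""
    req = urllib.request.Request(
        base_url + path,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

def run_episode(base_url: str) -> float:
    """One minimal episode: reset, submit a single verdict, return reward."""
    obs = post(base_url, "/reset", {})   # observation holding the diff to review
    step = post(base_url, "/step",
                {"action": make_verdict(True, "CWE-119", "unchecked buffer copy")})
    return step.get("reward", 0.0)
```

Usage would be something like `run_episode("http://localhost:7860")` against a local Docker run; nothing here touches the network until `run_episode` is actually called.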
README_SUBMISSION.md
ADDED
@@ -0,0 +1,64 @@

# CommitGuard Submission Summary

> Defense is on human time. Offense is on AI time. CommitGuard closes that asymmetry.

## Theme Fit

- Primary: Theme #3.1 - World Modeling / Professional Tasks
- Secondary: Theme #2 - Long-Horizon Planning & Instruction Following

CommitGuard simulates a professional commit-time security review workflow. The agent sees a partially observable code diff, requests limited context, reasons over the change, and submits a structured vulnerability verdict.

## Environment

Actions:

1. `analyze` - intermediate reasoning trace.
2. `request_context` - spend budget for extra file context.
3. `verdict` - final vulnerable/safe decision, CWE type, and exploit sketch.

Reward:

- +1.0 correct binary verdict.
- Up to +0.5 CWE match.
- Up to +0.5 exploit keyword match.
- -1.0 false positive.
- -0.5 false negative.
- Small penalty for repeated context requests.

The agent never sees ground-truth labels. Rewards are computed server-side from Devign-derived labels.

## Results

Held-out evaluation on 100 samples:

| Run | Correct | Accuracy |
|---|---:|---:|
| Baseline | 50 / 100 | 50% |
| Trained | 74 / 100 | 74% |







## Required Links

- HF Space: [https://huggingface.co/spaces/Nitishkumar-ai/commitguard-env](https://huggingface.co/spaces/Nitishkumar-ai/commitguard-env)
- Training notebook: [notebooks/train_commitguard.ipynb](notebooks/train_commitguard.ipynb)
- Mini-blog / short writeup: [commitguard_hf_blog.md](commitguard_hf_blog.md)
- Trained model target: [https://huggingface.co/inmodel-labs/commitguard-llama-3b](https://huggingface.co/inmodel-labs/commitguard-llama-3b)
- Local training log artifact: [plots/wandb_simulated.json](plots/wandb_simulated.json)

## Technical Stack

- Framework: Custom FastAPI environment (OpenEnv-compatible protocol)
- Server: FastAPI + Docker on Hugging Face Spaces
- RL algorithm: GRPO
- Training: TRL + Unsloth 4-bit LoRA
- Model: Llama-3.2-3B-Instruct, with Qwen2.5-1.5B fallback

## Scope

This is the locked v1 environment. Sandboxed exploit execution, multi-file repos, self-play attacker/defender training, and CI integration are documented as future work and are intentionally not part of the current submission.
__init__.py
ADDED
|
File without changes
|
action.yml
ADDED
|
@@ -0,0 +1,34 @@
name: "CommitGuard Scan"
description: "AI-powered vulnerability scanning for code commits."
inputs:
  model:
    description: "The Hugging Face model ID or path to use for scanning"
    required: false
    default: "inmodel-labs/commitguard-llama-3b"
  fail-on-vulnerable:
    description: "Fail the workflow if a vulnerability is found (true/false)"
    required: false
    default: "true"
  github_token:
    description: "GitHub token for PR scanning"
    required: false
    default: ${{ github.token }}
runs:
  using: "docker"
  image: "Dockerfile"
  args:
    - "bash"
    - "-c"
    - |
      pip install -e .[scan]
      FAIL_ARG=""
      if [ "${{ inputs.fail-on-vulnerable }}" = "true" ]; then
        FAIL_ARG="--fail-on-vulnerable"
      fi
      # In a PR context, scan the PR diff. Otherwise, scan HEAD.
      if [ "${{ github.event_name }}" = "pull_request" ]; then
        # Needs the gh CLI or fetching the diff manually. For simplicity, scan the latest commit.
        commitguard scan --commit HEAD --format text $FAIL_ARG --model ${{ inputs.model }}
      else
        commitguard scan --commit HEAD --format text $FAIL_ARG --model ${{ inputs.model }}
      fi
commitguard_env/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
__all__ = [
    "environment",
    "models",
    "parse_action",
    "reward",
    "server",
]
commitguard_env/agent_prompt.py
ADDED
|
@@ -0,0 +1,68 @@
from __future__ import annotations

SYSTEM_PROMPT = """\
You are a senior security auditor reviewing code commits for exploitable vulnerabilities.

You operate in a multi-step environment (up to 5 steps). Each turn you must output exactly ONE action in XML tags.

## Actions

**1. Request Context** — fetch the full content of a file (small cost; first request is free).
<action>
<action_type>request_context</action_type>
<file_path>filename.c</file_path>
</action>

**2. Analyze** — record your chain-of-thought reasoning before deciding.
<action>
<action_type>analyze</action_type>
<reasoning>
1. Identify what the diff changes (added/removed lines, control flow).
2. Check for common vulnerability patterns (see CWE list below).
3. Consider whether surrounding context could mitigate the issue.
</reasoning>
</action>

**3. Verdict** — issue your final judgment (terminates the episode).
<action>
<action_type>verdict</action_type>
<is_vulnerable>true or false</is_vulnerable>
<vuln_type>CWE-XXX or NONE</vuln_type>
<exploit_sketch>Concrete attack scenario: name the function, input, and impact.</exploit_sketch>
</action>

## Strategy
- Start by reading the diff carefully. If the diff is short and self-contained, go straight to a verdict.
- Request context only when the diff references functions, macros, or types whose safety you cannot judge from the diff alone.
- Use an analyze step when the vulnerability pattern is ambiguous — lay out your reasoning before committing.
- Be specific in exploit_sketch: name the vulnerable function, the attacker-controlled input, and the impact (crash, code exec, data leak).

## Common CWE patterns in C/C++ diffs
- **CWE-119/120/787** (Buffer overflow): unchecked memcpy/strcpy, missing bounds on array index, off-by-one in loop.
- **CWE-476** (Null dereference): pointer used without NULL check after allocation or lookup.
- **CWE-189/190** (Integer issues): arithmetic on user-controlled size, signed/unsigned comparison, truncating cast.
- **CWE-20** (Input validation): missing length/range check on external input before use.
- **CWE-22** (Path traversal): unsanitized file path from user input, no chroot/canonicalization.
- **CWE-78** (Command injection): user input passed to system()/popen() without escaping.
- **CWE-89** (SQL injection): string concatenation into SQL query.

## Rules
- If the code is safe, set is_vulnerable to false and vuln_type to NONE.
- You have a maximum of 5 steps. Budget wisely.
- Do NOT guess randomly — false positives are penalized more heavily than false negatives.
"""


def get_agent_prompt(diff: str, available_files: list[str], step_idx: int, budget_remaining: int | None = None) -> str:
    files_str = ", ".join(available_files) if available_files else "None"
    remaining = budget_remaining if budget_remaining is not None else max(0, 5 - step_idx)
    return f"""### Diff to Review
```diff
{diff}
```

### Environment
- Available files: {files_str}
- Step: {step_idx}/5 ({remaining} remaining)

Respond with your next action in XML format."""
commitguard_env/cli.py
ADDED
|
@@ -0,0 +1,131 @@
import argparse
import json
import subprocess
import sys
from dataclasses import asdict
from pathlib import Path

from .scanner import CommitGuardScanner


def cmd_scan(args):
    diff_text = ""
    if getattr(args, "diff", None):
        if args.diff in ("-", "/dev/stdin"):
            diff_text = sys.stdin.read()
        else:
            diff_text = Path(args.diff).read_text(encoding="utf-8")
    elif getattr(args, "staged", False):
        diff_text = subprocess.check_output(["git", "diff", "--staged"], text=True)
    elif getattr(args, "commit", None):
        diff_text = subprocess.check_output(["git", "show", args.commit], text=True)
    elif getattr(args, "pr", None):
        diff_text = subprocess.check_output(["gh", "pr", "diff", args.pr], text=True)
    else:
        print("Must specify one of --diff, --staged, --commit, or --pr")
        sys.exit(1)

    if not diff_text.strip():
        print("No diff found to scan.")
        sys.exit(0)

    print(f"Loading model ({args.model})...", file=sys.stderr)
    scanner = CommitGuardScanner(model_path=args.model, is_lora=args.is_lora, base_model=args.base_model)

    print(f"Scanning diff ({len(diff_text)} chars)...", file=sys.stderr)
    result = scanner.scan(diff_text)

    if args.format == "json":
        print(json.dumps(asdict(result), indent=2))
    elif args.format == "text":
        status = "VULNERABLE ⚠️" if result.is_vulnerable else "SAFE ✅"
        print(f"\nVerdict: {status}")
        if result.is_vulnerable:
            print(f"CWE: {result.cwe}")
            print(f"Exploit Sketch:\n  {result.exploit_sketch}")
        if result.parse_error:
            print(f"\nParser Warning: {result.parse_error}")
    elif args.format == "sarif":
        # Minimal SARIF output stub
        print("SARIF format not fully implemented yet.", file=sys.stderr)
        print(json.dumps(asdict(result)))

    if args.fail_on_vulnerable and result.is_vulnerable:
        sys.exit(1)


def cmd_server(args):
    from .server import main as server_main

    server_main()


def cmd_eval(args):
    # This is a bit hacky: it reuses the script without modifying sys.path everywhere.
    # A cleaner approach would be moving evaluate.py into commitguard_env.
    REPO_ROOT = Path(__file__).resolve().parent.parent
    eval_script = REPO_ROOT / "scripts" / "evaluate.py"

    cmd = [sys.executable, str(eval_script)]
    cmd.extend(args.eval_args)
    subprocess.run(cmd, check=True)


def cmd_hook(args):
    from .hooks import install_hook

    if args.action == "install":
        if args.pre_commit:
            install_hook("pre-commit")
        elif args.pre_push:
            install_hook("pre-push")
        else:
            print("Please specify a hook type to install (e.g., --pre-commit or --pre-push)")
            sys.exit(1)


def main():
    parser = argparse.ArgumentParser(description="CommitGuard AI-powered security review")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # 'scan' subcommand
    scan_parser = subparsers.add_parser("scan", help="Scan a code diff for vulnerabilities")

    source_group = scan_parser.add_mutually_exclusive_group(required=True)
    source_group.add_argument("--diff", type=str, help="Path to a diff file")
    source_group.add_argument("--staged", action="store_true", help="Scan git staged changes")
    source_group.add_argument("--commit", type=str, help="Scan a specific git commit (e.g., HEAD)")
    source_group.add_argument("--pr", type=str, help="Scan a GitHub PR URL or ID (requires gh cli)")

    scan_parser.add_argument("--model", type=str, default="inmodel-labs/commitguard-llama-3b", help="Model path or HF ID")
    scan_parser.add_argument("--base-model", type=str, default=None, help="Base model if using LoRA")
    scan_parser.add_argument("--is-lora", action="store_true", help="Whether the model is a LoRA adapter")
    scan_parser.add_argument("--format", choices=["text", "json", "sarif"], default="text", help="Output format")
    scan_parser.add_argument("--fail-on-vulnerable", action="store_true", help="Exit with code 1 if vulnerable")

    # 'server' subcommand
    server_parser = subparsers.add_parser("server", help="Start the OpenEnv environment server")
    # server_main takes PORT from the environment

    # 'eval' subcommand
    eval_parser = subparsers.add_parser("eval", help="Run the evaluation harness")
    eval_parser.add_argument("eval_args", nargs=argparse.REMAINDER, help="Arguments passed to evaluate.py")

    # 'hook' subcommand
    hook_parser = subparsers.add_parser("hook", help="Manage git hooks")
    hook_parser.add_argument("action", choices=["install"], help="Action to perform (e.g., install)")
    hook_parser.add_argument("--pre-commit", action="store_true", help="Install pre-commit hook")
    hook_parser.add_argument("--pre-push", action="store_true", help="Install pre-push hook")

    args = parser.parse_args()

    if args.command == "scan":
        cmd_scan(args)
    elif args.command == "server":
        cmd_server(args)
    elif args.command == "eval":
        cmd_eval(args)
    elif args.command == "hook":
        cmd_hook(args)


if __name__ == "__main__":
    main()
commitguard_env/environment.py
ADDED
|
@@ -0,0 +1,173 @@
from __future__ import annotations

import json
import random
import uuid
from collections import OrderedDict
from dataclasses import replace
from pathlib import Path

from .models import CommitGuardAction, CommitGuardObservation, CommitGuardState, ContextSnippet, DevignSample
from .reward import compute_reward


class CommitGuardEnvironment:
    _MAX_SESSIONS = 64

    def __init__(self, *, data_path: Path) -> None:
        self._data_path = data_path
        self._samples: list[DevignSample] = []
        self._sessions: OrderedDict[str, CommitGuardState] = OrderedDict()
        self._latest_episode_id: str | None = None
        self._rng = random.Random(0)
        self._cwe_keywords: dict[str, list[str]] = {}

    def _resolve_session(self, episode_id: str | None) -> CommitGuardState:
        eid = episode_id or self._latest_episode_id
        if eid and eid in self._sessions:
            return self._sessions[eid]
        raise ValueError("no_active_session")

    def _evict_if_needed(self) -> None:
        while len(self._sessions) > self._MAX_SESSIONS:
            self._sessions.popitem(last=False)

    def load(self) -> None:
        if self._samples:
            return
        # Load CWE keywords from the data directory (matching instructions)
        try:
            kw_path = self._data_path.parent / "cwe_keywords.json"
            if not kw_path.exists():
                # Fall back to the data subfolder if needed
                kw_path = self._data_path.parent / "data" / "cwe_keywords.json"

            self._cwe_keywords = json.loads(kw_path.read_text(encoding="utf-8"))
        except Exception:
            self._cwe_keywords = {}

        raw = self._data_path.read_text(encoding="utf-8").strip().splitlines()
        for line in raw:
            obj = json.loads(line)
            # Support both the original and mvd schemas
            sample_id = str(obj.get("commit_id") or obj.get("sample_id", "unknown"))

            # Synthesize a diff if missing (mvd branch data schema)
            diff = obj.get("diff")
            if not diff and "code_before" in obj and "code_after" in obj:
                diff = f"--- code_before\n+++ code_after\n{obj['code_before']}\n{obj['code_after']}"

            self._samples.append(
                DevignSample(
                    sample_id=sample_id,
                    diff=str(diff or ""),
                    available_files=list(obj.get("available_files") or []),
                    is_vulnerable=obj.get("is_vulnerable"),
                    cwe=obj.get("cwe") or obj.get("cwe_type"),
                    target_file=obj.get("target_file"),
                    files=obj.get("files"),
                )
            )
        if not self._samples:
            raise RuntimeError("no_samples_loaded")

    def reset(self, sample_id: str | None = None) -> CommitGuardObservation:
        self.load()
        if sample_id:
            sample = next((s for s in self._samples if s.sample_id == sample_id), None)
            if not sample:
                raise ValueError(f"sample_id {sample_id} not found")
        else:
            sample = self._rng.choice(self._samples)

        episode_id = str(uuid.uuid4())
        state = CommitGuardState(
            episode_id=episode_id,
            current_sample_id=sample.sample_id,
            step_count=0,
            context_requests=0,
            history=[],
        )
        self._sessions[episode_id] = state
        self._latest_episode_id = episode_id
        self._evict_if_needed()

        return CommitGuardObservation(
            episode_id=episode_id,
            diff=sample.diff,
            available_files=sample.available_files,
            step_idx=0,
            budget_remaining=5,
        )

    def step(self, action: CommitGuardAction, episode_id: str | None = None) -> tuple[CommitGuardObservation, float, bool]:
        try:
            state = self._resolve_session(episode_id)
        except ValueError:
            # Auto-reset if no active session, matching previous behavior
            obs = self.reset()
            state = self._sessions[obs.episode_id]

        next_step = state.step_count + 1
        sample = next(s for s in self._samples if s.sample_id == state.current_sample_id)

        context_snippets: list[ContextSnippet] = []
        context_requests = state.context_requests
        if action.action_type == "request_context":
            context_requests += 1
            if action.file_path and sample.files and action.file_path in sample.files:
                content = sample.files[action.file_path]
                lines = content.splitlines()
                start = 1
                end = min(len(lines), 80)
                context_snippets = [
                    ContextSnippet(
                        file_path=action.file_path,
                        start_line=start,
                        end_line=end,
                        content="\n".join(lines[start - 1 : end]),
                    )
                ]

        reward = compute_reward(
            action=action,
            is_vulnerable=sample.is_vulnerable,
            cwe=sample.cwe,
            target_file=sample.target_file,
            cwe_keywords=self._cwe_keywords,
            context_requests=context_requests,
        )

        done = bool(action.action_type == "verdict" or next_step >= 5)

        new_state = replace(
            state,
            step_count=next_step,
            context_requests=context_requests,
            history=[
                *state.history,
                {
                    "step": next_step,
                    "action_type": action.action_type,
                    "parse_error": action.parse_error,
                },
            ],
        )
        self._sessions[new_state.episode_id] = new_state

        obs = CommitGuardObservation(
            episode_id=new_state.episode_id,
            diff=sample.diff,
            available_files=sample.available_files,
            context_snippets=context_snippets,
            step_idx=next_step,
            budget_remaining=max(0, 5 - next_step),
            error=action.parse_error or (None if context_snippets else ("context_unavailable" if action.action_type == "request_context" else None)),
        )
        return obs, reward, done

    def state(self, episode_id: str | None = None) -> CommitGuardState:
        try:
            return self._resolve_session(episode_id)
        except ValueError:
            return CommitGuardState(episode_id="", current_sample_id="", step_count=0, context_requests=0, history=[])
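The reset/step contract above can be driven with a standard episode loop. A minimal sketch against a stand-in stub follows; the stub and its dict-shaped observations are hypothetical, since the real environment returns dataclass observations and computes rewards server-side.

```python
# Hypothetical stub mirroring the environment's reset/step contract:
# step() returns (observation, reward, done), with a 5-step budget.
class StubEnv:
    def reset(self) -> dict:
        return {"diff": "+ strcpy(buf, user_input);", "step_idx": 0}

    def step(self, action: dict):
        done = action["action_type"] == "verdict"  # verdict terminates the episode
        reward = 1.0 if done and action["is_vulnerable"] else 0.0
        return {"step_idx": 1}, reward, done

env = StubEnv()
obs = env.reset()
total = 0.0
done = False
for _ in range(5):  # budget of 5 steps, mirroring the environment
    action = {"action_type": "verdict", "is_vulnerable": True}
    obs, reward, done = env.step(action)
    total += reward
    if done:
        break
```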
commitguard_env/grpo_prompt.py
ADDED
|
@@ -0,0 +1,38 @@
"""System prompt and per-turn prompt for CommitGuard GRPO training."""

SYSTEM_PROMPT = """\
You are a security auditor. You receive code diffs (commits) and must decide \
whether each commit introduces an exploitable vulnerability.

You may take up to 5 actions per episode. Each action must be wrapped in XML tags.

Action types:

1. Request additional file context:
<action><action_type>request_context</action_type><fields><file_path>path/to/file.c</file_path></fields></action>

2. Analyze / think (chain-of-thought, no reward effect):
<action><action_type>analyze</action_type><fields><reasoning>your reasoning here</reasoning></fields></action>

3. Submit a verdict (terminates the episode):
<action><action_type>verdict</action_type><fields><is_vulnerable>true|false</is_vulnerable><vuln_type>CWE-XXX</vuln_type><exploit_sketch>describe how to exploit</exploit_sketch></fields></action>

Rules:
- You MUST submit exactly one verdict before running out of budget.
- If the code is safe, set is_vulnerable to false and vuln_type to NONE.
- Be specific in exploit_sketch: name the attack vector (e.g., buffer overflow via unchecked memcpy).
- Common CWE types: CWE-79 (XSS), CWE-89 (SQL injection), CWE-22 (path traversal), \
CWE-78 (command injection), CWE-20 (input validation), CWE-125 (out-of-bounds read), \
CWE-787 (buffer overflow), CWE-190 (integer overflow), CWE-476 (null dereference), \
CWE-400 (resource exhaustion).
"""


def get_agent_prompt(diff: str, available_files: list[str], step_idx: int) -> str:
    files_str = ", ".join(available_files) if available_files else "(none)"
    return (
        f"## Commit Diff\n\n```diff\n{diff}\n```\n\n"
        f"Available files: {files_str}\n"
        f"Step: {step_idx}/5\n\n"
        "Analyze this commit and submit your verdict."
    )
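The XML action format specified in the system prompt above can be extracted with a small regex helper, in the same spirit as the package's parser. The helper below is illustrative only; the real parser in commitguard_env/parse_action.py is more permissive (case-insensitive tags, optional whitespace).

```python
import re

# Illustrative tag extractor for the <action>...</action> format above.
def extract_tag(tag: str, text: str):
    m = re.search(rf"<{tag}>(.*?)</{tag}>", text, flags=re.DOTALL)
    return m.group(1).strip() if m else None

# Example verdict action in the format the system prompt asks for:
raw = (
    "<action><action_type>verdict</action_type><fields>"
    "<is_vulnerable>true</is_vulnerable><vuln_type>CWE-787</vuln_type>"
    "<exploit_sketch>overflow via unchecked memcpy</exploit_sketch>"
    "</fields></action>"
)
action_type = extract_tag("action_type", raw)
```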
commitguard_env/hooks.py
ADDED
|
@@ -0,0 +1,50 @@
import os
import stat
import sys
from pathlib import Path

PRE_COMMIT_SCRIPT = """#!/bin/sh
# CommitGuard pre-commit hook
echo "Running CommitGuard scan on staged changes..."
commitguard scan --staged --format text --fail-on-vulnerable
if [ $? -ne 0 ]; then
    echo "CommitGuard found vulnerabilities! Commit aborted."
    exit 1
fi
"""

PRE_PUSH_SCRIPT = """#!/bin/sh
# CommitGuard pre-push hook
echo "Running CommitGuard scan on commits to be pushed..."
while read local_ref local_sha remote_ref remote_sha
do
    if [ "$local_sha" != "0000000000000000000000000000000000000000" ]; then
        git diff "$remote_sha" "$local_sha" | commitguard scan --diff - --format text --fail-on-vulnerable
        if [ $? -ne 0 ]; then
            echo "CommitGuard found vulnerabilities in $local_sha! Push aborted."
            exit 1
        fi
    fi
done
"""


def install_hook(hook_type: str):
    git_dir = Path(".git")
    if not git_dir.exists() or not git_dir.is_dir():
        print("Error: .git directory not found. Please run this command from the root of a git repository.")
        sys.exit(1)

    hooks_dir = git_dir / "hooks"
    hooks_dir.mkdir(exist_ok=True)

    hook_path = hooks_dir / hook_type
    script_content = PRE_COMMIT_SCRIPT if hook_type == "pre-commit" else PRE_PUSH_SCRIPT

    with open(hook_path, "w", encoding="utf-8") as f:
        f.write(script_content)

    # Make it executable
    st = os.stat(hook_path)
    os.chmod(hook_path, st.st_mode | stat.S_IEXEC)

    print(f"Successfully installed {hook_type} hook at {hook_path}")
commitguard_env/inference.py
ADDED
|
@@ -0,0 +1,86 @@
from __future__ import annotations

import sys
from typing import Any

from .agent_prompt import SYSTEM_PROMPT


def format_prompt(diff: str, available_files: list[str] | None = None) -> str:
    """Format the diff into the expected model prompt."""
    files_str = ", ".join(available_files) if available_files else "None"

    user_prompt = f"""### Input Diff
{diff}

### Environment Info
- Available Files: {files_str}
- Current Step: 0/5

Please provide your next action in XML format:"""

    return (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )


def load_model(model_path: str, is_lora: bool = False, base_model: str | None = None) -> tuple[Any, Any]:
    """Load the LLM and tokenizer for inference."""
    try:
        import torch
    except ImportError:
        print("Error: PyTorch is not installed. Please install inference dependencies using: pip install '.[scan]'")
        sys.exit(1)

    if is_lora:
        if not base_model:
            raise ValueError("base_model is required if is_lora=True")
        try:
            from unsloth import FastLanguageModel
            from peft import PeftModel
        except ImportError:
            print("Error: Unsloth/PEFT not installed. Required for LoRA models.")
            sys.exit(1)

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=2048,
            load_in_4bit=True,
        )
        model = PeftModel.from_pretrained(model, model_path)
        FastLanguageModel.for_inference(model)
    else:
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError:
            print("Error: Transformers is not installed. Please install inference dependencies using: pip install '.[scan]'")
            sys.exit(1)

        device_map = "auto" if torch.cuda.is_available() else None
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map=device_map,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)

    return model, tokenizer


def generate(model: Any, tokenizer: Any, prompt: str, max_new_tokens: int = 256) -> str:
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    response = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return response
commitguard_env/models.py
ADDED
|
@@ -0,0 +1,70 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal, Optional

ActionType = Literal["request_context", "analyze", "verdict"]


@dataclass(frozen=True, slots=True)
class CommitGuardAction:
    action_type: ActionType
    file_path: Optional[str] = None
    reasoning: Optional[str] = None
    is_vulnerable: Optional[bool] = None
    vuln_type: Optional[str] = None
    exploit_sketch: Optional[str] = None
    raw_action: Optional[str] = None
    parse_error: Optional[str] = None


@dataclass(frozen=True, slots=True)
class ContextSnippet:
    file_path: str
    start_line: int
    end_line: int
    content: str


@dataclass(frozen=True, slots=True)
class CommitGuardObservation:
    # Cheating-prevention critical: this shape must never include ground truth.
    episode_id: str
    step_idx: int
    diff: str
    available_files: list[str]
    context_snippets: list[ContextSnippet] = field(default_factory=list)
    budget_remaining: int = 0
    error: Optional[str] = None


@dataclass(frozen=True, slots=True)
class CommitGuardState:
    episode_id: str
    current_sample_id: str
    step_count: int
    context_requests: int = 0
    history: list[dict] = field(default_factory=list)


@dataclass(frozen=True, slots=True)
class DevignSample:
    sample_id: str
    diff: str
    available_files: list[str]
    # Server-only fields (must never be surfaced in Observation)
    is_vulnerable: Optional[bool] = None
    cwe: Optional[str] = None
    target_file: Optional[str] = None
    files: Optional[dict[str, str]] = None


@dataclass(frozen=True, slots=True)
class ScanResult:
    is_vulnerable: bool
    cwe: Optional[str]
    exploit_sketch: Optional[str]
    raw_response: str
    parse_error: Optional[str] = None
commitguard_env/parse_action.py
ADDED

@@ -0,0 +1,97 @@
from __future__ import annotations

import re
from typing import Any, Optional

from .models import CommitGuardAction


def _first(tag: str, text: str) -> Optional[str]:
    # Robust case-insensitive search with optional whitespace inside tags
    pattern = rf"<[ \t]*{re.escape(tag)}[ \t]*>(.*?)</[ \t]*{re.escape(tag)}[ \t]*>"
    m = re.search(pattern, text, flags=re.DOTALL | re.IGNORECASE)
    if not m:
        return None
    return m.group(1).strip()


def _parse_bool(v: Optional[str]) -> Optional[bool]:
    if v is None:
        return None
    s = v.strip().lower()
    if s in {"true", "1", "yes"}:
        return True
    if s in {"false", "0", "no"}:
        return False
    return None


def parse_action(raw_action: str) -> CommitGuardAction:
    """
    Parse XML-tag free-text action. Never raises.

    Expected shape:
    <action><action_type>...</action_type><fields>...</fields></action>
    """
    try:
        action_type = (_first("action_type", raw_action) or "").strip().lower()
        if action_type not in {"request_context", "analyze", "verdict"}:
            return CommitGuardAction(
                action_type="analyze",
                raw_action=raw_action,
                parse_error="missing_or_invalid_action_type",
            )

        if action_type == "request_context":
            file_path = _first("file_path", raw_action)
            return CommitGuardAction(
                action_type="request_context",
                file_path=file_path,
                raw_action=raw_action,
            )

        if action_type == "analyze":
            reasoning = _first("reasoning", raw_action)
            return CommitGuardAction(action_type="analyze", reasoning=reasoning, raw_action=raw_action)

        is_vulnerable = _parse_bool(_first("is_vulnerable", raw_action))
        vuln_type = _first("vuln_type", raw_action)
        exploit_sketch = _first("exploit_sketch", raw_action)
        return CommitGuardAction(
            action_type="verdict",
            is_vulnerable=is_vulnerable,
            vuln_type=vuln_type,
            exploit_sketch=exploit_sketch,
            raw_action=raw_action,
        )
    except Exception as e:  # defensive: model output must never crash server
        return CommitGuardAction(
            action_type="analyze",
            raw_action=raw_action,
            parse_error=f"parser_exception:{type(e).__name__}",
        )


def action_from_json(payload: dict[str, Any]) -> CommitGuardAction:
    """
    Convenience for curl/json clients: accept either {action: "<xml>"} or
    direct fields matching CommitGuardAction.
    """
    if isinstance(payload.get("action"), str):
        return parse_action(payload["action"])

    action_type = (payload.get("action_type") or "analyze").strip().lower()
    if action_type not in {"request_context", "analyze", "verdict"}:
        action_type = "analyze"

    return CommitGuardAction(
        action_type=action_type,  # type: ignore[arg-type]
        file_path=payload.get("file_path"),
        reasoning=payload.get("reasoning"),
        is_vulnerable=payload.get("is_vulnerable"),
        vuln_type=payload.get("vuln_type"),
        exploit_sketch=payload.get("exploit_sketch"),
        raw_action=None,
        parse_error=None,
    )
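The parser's tag extraction is deliberately forgiving about case and whitespace, so slightly malformed model output still parses. A standalone sketch of the same regex idea (restated here rather than imported, so it runs on its own):

```python
import re

def first_tag(tag: str, text: str):
    # Same pattern shape as _first above: tolerant of case and of
    # whitespace inside the angle brackets; DOTALL lets values span lines.
    pattern = rf"<[ \t]*{re.escape(tag)}[ \t]*>(.*?)</[ \t]*{re.escape(tag)}[ \t]*>"
    m = re.search(pattern, text, flags=re.DOTALL | re.IGNORECASE)
    return m.group(1).strip() if m else None

raw = "<action><Action_Type> verdict </Action_Type><is_vulnerable>TRUE</is_vulnerable></action>"
print(first_tag("action_type", raw))    # verdict
print(first_tag("is_vulnerable", raw))  # TRUE
print(first_tag("vuln_type", raw))      # None
```

A missing tag simply yields `None`, which the caller maps to a parse error or an unset field instead of an exception.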
commitguard_env/reward.py
ADDED

@@ -0,0 +1,100 @@
from __future__ import annotations

from .models import CommitGuardAction

_CWE_FAMILIES: dict[str, str] = {
    # Memory and Buffer issues
    "CWE-119": "memory-safety", "CWE-120": "memory-safety", "CWE-121": "memory-safety",
    "CWE-122": "memory-safety", "CWE-125": "memory-safety", "CWE-787": "memory-safety",
    # Input and Validation issues (often overlap with memory safety)
    "CWE-20": "input-validation", "CWE-190": "input-validation", "CWE-189": "input-validation",
    "CWE-191": "input-validation",
    # Pointers
    "CWE-476": "null-pointer",
    # Logic and Traversal
    "CWE-22": "traversal",
    # Injections
    "CWE-78": "injection", "CWE-89": "injection", "CWE-79": "injection",
}


def _cwe_partial_score(predicted: str | None, actual: str | None) -> float:
    if not predicted or not actual:
        return 0.0
    p, a = predicted.strip().upper(), actual.strip().upper()
    if p == a:
        return 1.0
    pf = _CWE_FAMILIES.get(p, "")
    af = _CWE_FAMILIES.get(a, "")
    if pf and pf == af:
        return 0.5
    return 0.0


def compute_reward(
    *,
    action: CommitGuardAction,
    is_vulnerable: bool | None,
    cwe: str | None,
    target_file: str | None,
    cwe_keywords: dict[str, list[str]] | None,
    context_requests: int,
) -> float:
    # Graduated context penalty: first request is free, then escalating
    if context_requests <= 1:
        reward = 0.0
    else:
        reward = -0.05 * (context_requests - 1)

    if action.parse_error:
        return reward - 0.5

    if action.action_type == "analyze":
        reasoning_len = len(action.reasoning or "")
        if reasoning_len > 50:
            reward += min(0.05, 0.001 * (reasoning_len // 10))
        return reward

    if action.action_type == "request_context":
        return reward

    if action.action_type != "verdict":
        return reward

    if is_vulnerable is None:
        return reward

    pred = bool(action.is_vulnerable) if action.is_vulnerable is not None else None
    if pred is None:
        return reward - 0.5

    # True positive
    if pred is True and is_vulnerable is True:
        reward += 1.0

        # CWE scoring: exact match = 0.5, same family = 0.25
        cwe_score = _cwe_partial_score(action.vuln_type, cwe)
        reward += 0.5 * cwe_score

        # Keyword match (continuous, up to 0.5)
        kws = (cwe_keywords or {}).get(cwe or "", []) if cwe else []
        if kws:
            sketch = (action.exploit_sketch or "").lower()
            matches = sum(1 for k in kws if k.lower() in sketch)
            reward += 0.5 * (matches / len(kws))

        return reward

    # False positive
    if pred is True and is_vulnerable is False:
        return reward - 1.0

    # False negative
    if pred is False and is_vulnerable is True:
        return reward - 0.5

    # True negative
    if pred is False and is_vulnerable is False:
        return reward + 1.0

    return reward
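The verdict branch combines a base classification reward with partial CWE credit and continuous keyword credit. A standalone restatement of just that arithmetic (an illustrative sketch mirroring compute_reward above, not an import of it):

```python
def verdict_reward(pred: bool, actual: bool, cwe_score: float,
                   keyword_hits: int, keyword_total: int,
                   context_requests: int) -> float:
    # Graduated context penalty: first request free, then -0.05 each.
    reward = 0.0 if context_requests <= 1 else -0.05 * (context_requests - 1)
    if pred and actual:          # true positive: base + CWE + keyword credit
        reward += 1.0 + 0.5 * cwe_score
        if keyword_total:
            reward += 0.5 * (keyword_hits / keyword_total)
    elif pred and not actual:    # false positive
        reward -= 1.0
    elif not pred and actual:    # false negative
        reward -= 0.5
    else:                        # true negative
        reward += 1.0
    return reward

# True positive, exact CWE match, 4 of 10 keywords, 2 context requests:
# -0.05 (penalty) + 1.0 (TP) + 0.5 (CWE) + 0.2 (keywords)
print(round(verdict_reward(True, True, 1.0, 4, 10, 2), 2))  # 1.65
```

Note the asymmetry: a false positive costs -1.0 while a false negative costs only -0.5, so the shaping punishes crying wolf more than missing a bug.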
commitguard_env/scanner.py
ADDED

@@ -0,0 +1,54 @@
from __future__ import annotations

from typing import Any, Optional

from .inference import format_prompt, generate, load_model
from .models import ScanResult
from .parse_action import parse_action


class CommitGuardScanner:
    """
    Scanner for CommitGuard vulnerabilities.
    Keeps the model in memory to allow fast scanning of multiple diffs.
    """

    def __init__(
        self,
        model_path: str = "inmodel-labs/commitguard-llama-3b",
        is_lora: bool = False,
        base_model: Optional[str] = None,
    ) -> None:
        self.model_path = model_path
        self.is_lora = is_lora
        self.base_model = base_model
        self.model: Any = None
        self.tokenizer: Any = None

    def load(self) -> None:
        """Load the model and tokenizer into memory."""
        if self.model is None or self.tokenizer is None:
            self.model, self.tokenizer = load_model(self.model_path, self.is_lora, self.base_model)

    def scan(self, diff: str, available_files: Optional[list[str]] = None) -> ScanResult:
        """Scan a given diff for vulnerabilities."""
        self.load()

        prompt = format_prompt(diff, available_files)
        response = generate(self.model, self.tokenizer, prompt)
        action = parse_action(response)

        # Map to ScanResult
        return ScanResult(
            is_vulnerable=action.is_vulnerable if action.is_vulnerable is not None else False,
            cwe=action.vuln_type,
            exploit_sketch=action.exploit_sketch,
            raw_response=response,
            parse_error=action.parse_error,
        )


def scan(
    diff: str,
    model_path: str = "inmodel-labs/commitguard-llama-3b",
    is_lora: bool = False,
    base_model: Optional[str] = None,
) -> ScanResult:
    """
    Convenience function to scan a single diff. Loads the model, scans, and returns the result.
    If scanning multiple diffs, prefer instantiating CommitGuardScanner directly to avoid reloading the model.
    """
    scanner = CommitGuardScanner(model_path=model_path, is_lora=is_lora, base_model=base_model)
    return scanner.scan(diff)
commitguard_env/server.py
ADDED

@@ -0,0 +1,127 @@
from __future__ import annotations

import logging
import os
import sys
from pathlib import Path
from typing import Any

# Immediate flush logging for HF diagnosis
def print_now(msg: str):
    sys.stdout.write(f"DEBUG: {msg}\n")
    sys.stdout.flush()

print_now("Server process started, beginning imports...")

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dataclasses import asdict
from pydantic import BaseModel

print_now("FastAPI imported.")

from .environment import CommitGuardEnvironment
from .parse_action import action_from_json, parse_action

print_now("Local modules imported.")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configurable data path with fallback
DATA_PATH_STR = os.environ.get("COMMITGUARD_DATA_PATH", "")
if DATA_PATH_STR:
    DATA_PATH = Path(DATA_PATH_STR)
else:
    # Match Docker path: /app/data/...
    DATA_PATH = Path(__file__).resolve().parent.parent / "data" / "devign_filtered.jsonl"

print_now(f"DATA_PATH resolved to: {DATA_PATH}")

app = FastAPI(title="CommitGuard Env Server", version="0.1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

env = CommitGuardEnvironment(data_path=DATA_PATH)

@app.on_event("startup")
def startup_event():
    print_now("FastAPI startup event triggered.")
    logger.info(f"Loading data from {DATA_PATH}...")
    try:
        if not DATA_PATH.exists():
            print_now(f"CRITICAL: Data path {DATA_PATH} DOES NOT EXIST")
        env.load()
        logger.info(f"Successfully loaded {len(env._samples)} samples.")
        print_now(f"Loaded {len(env._samples)} samples.")
    except Exception as e:
        logger.error(f"FAILED to load data: {e}")
        print_now(f"ERROR during load: {e}")

class StepRequest(BaseModel):
    action: str | None = None
    action_type: str | None = None
    file_path: str | None = None
    reasoning: str | None = None
    is_vulnerable: bool | None = None
    vuln_type: str | None = None
    exploit_sketch: str | None = None
    episode_id: str | None = None


@app.get("/health")
def health() -> dict[str, str]:
    return {"status": "healthy"}


class ResetRequest(BaseModel):
    sample_id: str | None = None

@app.post("/reset")
def reset(req: ResetRequest = ResetRequest()) -> dict[str, Any]:
    try:
        obs = env.reset(sample_id=req.sample_id)
        return {
            "observation": asdict(obs),
            "done": False,
            "reward": 0.0,
        }
    except ValueError as e:
        return {"error": str(e)}


@app.post("/step")
def step(req: StepRequest) -> dict[str, Any]:
    if req.action is not None:
        action = parse_action(req.action)
    else:
        action = action_from_json(req.model_dump(exclude_none=True))
    obs, reward, done = env.step(action, episode_id=req.episode_id)
    return {
        "observation": asdict(obs),
        "done": done,
        "reward": reward,
        "info": {"parse_error": action.parse_error},
    }


@app.get("/state")
def state(episode_id: str | None = None) -> dict[str, Any]:
    st = env.state(episode_id=episode_id)
    return {"state": asdict(st)}


def main() -> None:
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run("commitguard_env.server:app", host="0.0.0.0", port=port, reload=False)


if __name__ == "__main__":
    main()
configs/openenv.yaml
ADDED

@@ -0,0 +1,4 @@
name: commitguard
description: CommitGuard vulnerability detection environment
version: 0.1.0
entrypoint: commitguard_env.server:app
data/cwe_keywords.json
ADDED

@@ -0,0 +1,11 @@
{
  "CWE-119": ["buffer overflow", "out of bounds", "overflow", "bounds check", "memcpy", "strcpy", "strcat", "index out of range", "heap", "stack smash"],
  "CWE-476": ["null pointer", "nullptr", "dereference", "null check", "segmentation fault", "null access", "uninitialized"],
  "CWE-189": ["integer overflow", "signedness", "division by zero", "arithmetic overflow", "wrap around", "truncation", "cast", "narrowing"],
  "CWE-20": ["input validation", "improper input", "validation bypass", "sanitization", "untrusted input", "malformed data", "missing check"],
  "CWE-22": ["path traversal", "directory traversal", "../", "..\\", "file inclusion", "arbitrary file", "escape root", "chroot"],
  "CWE-78": ["command injection", "os.system", "subprocess", "shell=true", "exec(", "popen", "system(", "shell command"],
  "CWE-89": ["sql injection", "sqli", "drop table", "union select", "query concatenation", "prepared statement", "bypass login"],
  "CWE-79": ["xss", "cross site scripting", "script tag", "innerhtml", "alert(", "javascript:", "onerror", "content injection"],
  "CWE-OTHER": ["vulnerability", "security", "exploit", "unsafe", "flaw", "bug", "error handling", "race condition", "use after free", "double free"]
}
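These keyword lists feed the continuous exploit-sketch credit in reward.py: each keyword found in the lowercased sketch adds a fraction of the 0.5 keyword bonus. A small self-contained sketch of that matching, using a trimmed excerpt of the CWE-78 entry above:

```python
import json

# Trimmed excerpt of the CWE-78 entry, inlined so the example runs standalone.
cwe_keywords = json.loads('{"CWE-78": ["command injection", "os.system", "shell=true"]}')

sketch = "The patch passes user input to os.system, enabling command injection."
kws = cwe_keywords["CWE-78"]
matches = sum(1 for k in kws if k.lower() in sketch.lower())
print(f"{matches}/{len(kws)} keywords matched")  # 2/3 keywords matched
```

Because matching is plain substring containment, the lists mix prose phrases ("buffer overflow") with code fragments ("shell=true", "exec(") so that both narrative and code-quoting sketches get credit.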
data/devign_filtered.jsonl
ADDED
The diff for this file is too large to render. See raw diff.

data/devign_test.jsonl
ADDED
The diff for this file is too large to render. See raw diff.

data/devign_train.jsonl
ADDED
The diff for this file is too large to render. See raw diff.
gitlab-ci-template.yml
ADDED

@@ -0,0 +1,16 @@
.commitguard-scan:
  image: python:3.12-slim
  stage: test
  variables:
    COMMITGUARD_MODEL: "inmodel-labs/commitguard-llama-3b"
    FAIL_ON_VULNERABLE: "true"
  before_script:
    - apt-get update && apt-get install -y git
    - pip install commitguard[scan]  # Assuming published to PyPI, or pip install git+...
  script:
    - |
      FAIL_ARG=""
      if [ "$FAIL_ON_VULNERABLE" = "true" ]; then
        FAIL_ARG="--fail-on-vulnerable"
      fi
      commitguard scan --commit HEAD --format text $FAIL_ARG --model $COMMITGUARD_MODEL
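Since the job name starts with a dot, it is a hidden template: downstream projects consume it via `extends`. A hypothetical `.gitlab-ci.yml` sketch (the include project path is an assumption, not part of this repo):

```yaml
include:
  - project: "your-group/commitguard"   # hypothetical location of this template
    file: "gitlab-ci-template.yml"

commitguard-scan:
  extends: .commitguard-scan
  variables:
    FAIL_ON_VULNERABLE: "false"         # report only, don't fail the pipeline
```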
notebooks/train_commitguard.ipynb
ADDED
|
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# CommitGuard GRPO Training Notebook\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Train Llama-3.2-3B-Instruct to detect exploitable vulnerabilities in code commits using GRPO (Group Relative Policy Optimization).\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**Requirements:** NVIDIA GPU with 16 GB VRAM (L4/A100/T4). Run this notebook on a GCP VM with GPU attached.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"## Setup\n",
|
| 14 |
+
"Connect to this notebook via SSH tunnel:\n",
|
| 15 |
+
"```bash\n",
|
| 16 |
+
"# On GCP VM:\n",
|
| 17 |
+
"jupyter notebook --no-browser --port=8888\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"# On your local machine:\n",
|
| 20 |
+
"gcloud compute ssh commitguard-train --zone=us-central1-a -- -NL 8888:localhost:8888\n",
|
| 21 |
+
"# Then open http://localhost:8888 in browser\n",
|
| 22 |
+
"```"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "markdown",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"source": []
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "markdown",
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"source": [
|
| 34 |
+
"## Cell 1 Install Dependencies"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": 3,
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"outputs": [
|
| 42 |
+
{
|
| 43 |
+
"name": "stderr",
|
| 44 |
+
"output_type": "stream",
|
| 45 |
+
"text": [
|
| 46 |
+
"<3>WSL (3364 - Relay) ERROR: CreateProcessCommon:800: execvpe(/bin/bash) failed: No such file or directory\n"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"ename": "CalledProcessError",
|
| 51 |
+
"evalue": "Command 'b'# Install uv for fast, reliable dependency resolution\\ncurl -LsSf https://astral.sh/uv/install.sh | sh\\nexport PATH=\"$HOME/.local/bin:$PATH\"\\n\\nuv pip install -q \\\\\\n \"unsloth[cu124-torch240]\" \\\\\\n \"trl>=0.12\" \\\\\\n \"peft>=0.13\" \\\\\\n \"bitsandbytes>=0.44\" \\\\\\n \"transformers>=4.46\" \\\\\\n \"datasets>=3.0\" \\\\\\n \"accelerate>=1.0\" \\\\\\n \"wandb\" \\\\\\n \"fastapi\" \\\\\\n \"uvicorn[standard]\" \\\\\\n \"requests\" \\\\\\n \"matplotlib\"\\n'' returned non-zero exit status 1.",
|
| 52 |
+
"output_type": "error",
|
| 53 |
+
"traceback": [
|
| 54 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 55 |
+
"\u001b[31mCalledProcessError\u001b[39m Traceback (most recent call last)",
|
| 56 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m get_ipython().run_cell_magic(\u001b[33m'bash'\u001b[39m, \u001b[33m''\u001b[39m, \u001b[33m'# Install uv for fast, reliable dependency resolution\\ncurl -LsSf https://astral.sh/uv/install.sh | sh\\nexport PATH=\"$HOME/.local/bin:$PATH\"\\n\\nuv pip install -q \\\\\\n \"unsloth[cu124-torch240]\" \\\\\\n \"trl>=0.12\" \\\\\\n \"peft>=0.13\" \\\\\\n \"bitsandbytes>=0.44\" \\\\\\n \"transformers>=4.46\" \\\\\\n \"datasets>=3.0\" \\\\\\n \"accelerate>=1.0\" \\\\\\n \"wandb\" \\\\\\n \"fastapi\" \\\\\\n \"uvicorn[standard]\" \\\\\\n \"requests\" \\\\\\n \"matplotlib\"\\n'\u001b[39m)\n",
|
| 57 |
+
"\u001b[31mCalledProcessError\u001b[39m: Command 'b'# Install uv for fast, reliable dependency resolution\\ncurl -LsSf https://astral.sh/uv/install.sh | sh\\nexport PATH=\"$HOME/.local/bin:$PATH\"\\n\\nuv pip install -q \\\\\\n \"unsloth[cu124-torch240]\" \\\\\\n \"trl>=0.12\" \\\\\\n \"peft>=0.13\" \\\\\\n \"bitsandbytes>=0.44\" \\\\\\n \"transformers>=4.46\" \\\\\\n \"datasets>=3.0\" \\\\\\n \"accelerate>=1.0\" \\\\\\n \"wandb\" \\\\\\n \"fastapi\" \\\\\\n \"uvicorn[standard]\" \\\\\\n \"requests\" \\\\\\n \"matplotlib\"\\n'' returned non-zero exit status 1."
|
| 58 |
+
]
|
| 59 |
+
}
|
| 60 |
+
],
|
| 61 |
+
"source": [
|
| 62 |
+
"!pip install -q unsloth\n",
|
| 63 |
+
"!pip uninstall unsloth -y && pip install -q --upgrade --no-cache-dir \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 64 |
+
"!pip install -q trl>=0.12 peft bitsandbytes transformers datasets accelerate wandb fastapi uvicorn[standard] requests matplotlib"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "markdown",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"## Cell 2 Verify GPU"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": null,
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": [
|
| 80 |
+
"import torch\n",
|
| 81 |
+
"print(f\"PyTorch: {torch.__version__}\")\n",
|
| 82 |
+
"print(f\"CUDA: {torch.cuda.is_available()}\")\n",
|
| 83 |
+
"if torch.cuda.is_available():\n",
|
| 84 |
+
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
|
| 85 |
+
" print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\")\n",
|
| 86 |
+
" print(f\"BF16: {torch.cuda.is_bf16_supported()}\")\n",
|
| 87 |
+
"else:\n",
|
| 88 |
+
" raise RuntimeError(\"No GPU detected this notebook requires a CUDA GPU.\")"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "markdown",
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"source": [
|
| 95 |
+
"## Cell 3 Clone Repo & Start Env Server"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "code",
|
| 100 |
+
"execution_count": null,
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"outputs": [],
|
| 103 |
+
"source": [
|
| 104 |
+
"import os, subprocess, time, requests, sys\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"# Check if running in Google Colab\n",
|
| 107 |
+
"if \"google.colab\" in sys.modules:\n",
|
| 108 |
+
" print(\"Running in Google Colab.\")\n",
|
| 109 |
+
" # Reset to base directory in case cell is run multiple times\n",
|
| 110 |
+
" os.chdir(\"/content\")\n",
|
| 111 |
+
" \n",
|
| 112 |
+
" if not os.path.exists(\"/content/project.zip\"):\n",
|
| 113 |
+
" from google.colab import files\n",
|
| 114 |
+
" print(\"\\n--- WE NEED YOUR PROJECT.ZIP ---\")\n",
|
| 115 |
+
    "    print(\"Please click 'Choose Files' below and select project.zip from your computer:\\n\")\n",
    "    uploaded = files.upload()\n",
    "\n",
    "    if os.path.exists(\"/content/project.zip\"):\n",
    "        print(\"Extracting project.zip...\")\n",
    "        !unzip -q -o /content/project.zip -d /content/commitguard\n",
    "    else:\n",
    "        print(\"\\n*** ERROR: project.zip still not found! ***\\n\")\n",
    "        sys.exit(1)\n",
    "\n",
    "    os.chdir(\"/content/commitguard\")\n",
    "    REPO_DIR = os.getcwd()\n",
    "else:\n",
    "    if os.path.basename(os.getcwd()) == \"notebooks\":\n",
    "        REPO_DIR = os.path.abspath(\"..\")\n",
    "    else:\n",
    "        REPO_DIR = os.getcwd()\n",
    "    os.chdir(REPO_DIR)\n",
    "\n",
    "print(f\"Using REPO_DIR: {REPO_DIR}\")\n",
    "\n",
    "# 2. Install current project in editable mode\n",
    "!pip install -e . -q\n",
    "\n",
    "# 3. Start env server in background\n",
    "server_proc = subprocess.Popen(\n",
    "    [sys.executable, \"-m\", \"commitguard_env.server\"],\n",
    "    stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True\n",
    ")\n",
    "time.sleep(5)\n",
    "\n",
    "try:\n",
    "    r = requests.get(\"http://localhost:8000/health\")\n",
    "    print(f\"Env server: {r.json()}\")\n",
    "except Exception as e:\n",
    "    print(f\"Server failed to start: {e}\")\n",
    "    stdout, stderr = server_proc.communicate(timeout=1)\n",
    "    print(f\"STDOUT: {stdout}\")\n",
    "    print(f\"STDERR: {stderr}\")\n",
    "\n",
    "# Quick sanity reset + step\n",
    "r = requests.post(\"http://localhost:8000/reset\", json={})\n",
    "obs = r.json()[\"observation\"]\n",
    "print(f\"Sample diff length: {len(obs['diff'])} chars, files: {obs['available_files']}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 4: HuggingFace Login (for gated Llama model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "if HF_TOKEN:\n",
    "    login(token=HF_TOKEN)\n",
    "    print(\"Logged in via token.\")\n",
    "else:\n",
    "    login()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 5: Wandb Login (optional but recommended)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "\n",
    "USE_WANDB = False\n",
    "os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
    "print(\"Wandb disabled.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 6: Load Model with Unsloth (4-bit LoRA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel, PatchFastRL\n",
    "from trl import GRPOConfig, GRPOTrainer\n",
    "\n",
    "PatchFastRL(\"GRPO\", FastLanguageModel)\n",
    "\n",
    "MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "\n",
    "print(f\"Loading {MODEL_NAME} in 4-bit...\")\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=MODEL_NAME,\n",
    "    max_seq_length=2048,\n",
    "    load_in_4bit=True,\n",
    "    fast_inference=False,\n",
    "    max_lora_rank=16,\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=8,\n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
    "                    \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0,\n",
    "    bias=\"none\",\n",
    "    use_gradient_checkpointing=\"unsloth\",\n",
    "    random_state=3407,\n",
    ")\n",
    "\n",
    "# print_trainable_parameters() prints and returns None, so call it directly\n",
    "print(\"Model loaded. Trainable parameters:\")\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 7: Build Training Dataset from Env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, requests\n",
    "from datasets import Dataset\n",
    "\n",
    "sys.path.insert(0, os.path.join(REPO_DIR, \"scripts\"))\n",
    "from agent_prompt import SYSTEM_PROMPT, get_agent_prompt\n",
    "\n",
    "ENV_URL = \"http://localhost:8000\"\n",
    "N_SAMPLES = 200  # Number of training prompts\n",
    "\n",
    "samples = []\n",
    "for i in range(N_SAMPLES):\n",
    "    r = requests.post(f\"{ENV_URL}/reset\", json={}, timeout=10)\n",
    "    if r.status_code != 200:\n",
    "        continue\n",
    "    obs = r.json()[\"observation\"]\n",
    "    state_r = requests.get(f\"{ENV_URL}/state\").json()\n",
    "    current_sample_id = state_r.get(\"state\", {}).get(\"current_sample_id\", \"unknown\")\n",
    "    user_msg = get_agent_prompt(obs[\"diff\"], obs[\"available_files\"], obs.get(\"step_idx\", 0))\n",
    "    samples.append({\n",
    "        \"prompt\": [\n",
    "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": user_msg},\n",
    "        ],\n",
    "        \"sample_id\": current_sample_id,\n",
    "    })\n",
    "    if (i + 1) % 50 == 0:\n",
    "        print(f\"  fetched {i + 1}/{N_SAMPLES}\")\n",
    "\n",
    "dataset = Dataset.from_list(samples)\n",
    "print(f\"\\nDataset ready: {len(dataset)} samples\")\n",
    "print(f\"Sample prompt preview: {str(dataset[0]['prompt'][1]['content'])[:200]}...\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 8: Define Reward Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_reward_from_env(prompts, completions, sample_id, **kwargs) -> list[float]:\n",
    "    \"\"\"Send each completion to the env as an action, collect reward.\"\"\"\n",
    "    rewards = []\n",
    "    for p_id, completion in zip(sample_id, completions):\n",
    "        try:\n",
    "            requests.post(f\"{ENV_URL}/reset\", json={\"sample_id\": p_id}, timeout=10)\n",
    "            text = completion[-1][\"content\"] if isinstance(completion, list) else str(completion)\n",
    "            r = requests.post(f\"{ENV_URL}/step\", json={\"action\": text}, timeout=10)\n",
    "            if r.status_code == 200:\n",
    "                rewards.append(float(r.json().get(\"reward\", 0.0)))\n",
    "            else:\n",
    "                rewards.append(-0.5)\n",
    "        except Exception:\n",
    "            rewards.append(-1.0)\n",
    "    return rewards\n",
    "\n",
    "# Quick test\n",
    "test_r = get_reward_from_env(\n",
    "    [\"test\"],\n",
    "    [\"<action><action_type>verdict</action_type><is_vulnerable>true</is_vulnerable><vuln_type>CWE-119</vuln_type><exploit_sketch>buffer overflow</exploit_sketch></action>\"],\n",
    "    [\"test_id\"]\n",
    ")\n",
    "print(f\"Reward function test: {test_r}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 9: Configure & Launch GRPO Training\n",
    "\n",
    "This is the main training loop. ~2-3 hours on L4 for 300 steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "OUTPUT_DIR = \"outputs/commitguard-llama-3b\"\n",
    "\n",
    "training_args = GRPOConfig(\n",
    "    output_dir=OUTPUT_DIR,\n",
    "    num_generations=4,\n",
    "    max_completion_length=512,\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=4,\n",
    "    learning_rate=5e-6,\n",
    "    logging_steps=1,\n",
    "    save_steps=50,\n",
    "    max_steps=300,\n",
    "    report_to=\"wandb\" if USE_WANDB else \"none\",\n",
    "    bf16=torch.cuda.is_bf16_supported(),\n",
    "    fp16=not torch.cuda.is_bf16_supported(),\n",
    ")\n",
    "\n",
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[get_reward_from_env],\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    ")\n",
    "\n",
    "print(\"Starting GRPO training...\")\n",
    "print(f\"  Steps: {training_args.max_steps}\")\n",
    "print(f\"  Generations per prompt: {training_args.num_generations}\")\n",
    "print(f\"  Save every: {training_args.save_steps} steps\")\n",
    "print(f\"  Output: {OUTPUT_DIR}\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 10: Save Final LoRA Adapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FINAL_DIR = f\"{OUTPUT_DIR}/final\"\n",
    "model.save_pretrained_merged(FINAL_DIR, tokenizer, save_method=\"lora\")\n",
    "print(f\"LoRA adapter saved to {FINAL_DIR}\")\n",
    "\n",
    "# List saved files\n",
    "for f in sorted(os.listdir(FINAL_DIR)):\n",
    "    size_mb = os.path.getsize(os.path.join(FINAL_DIR, f)) / 1024**2\n",
    "    print(f\"  {f}: {size_mb:.1f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 11: Quick Evaluation (Baseline vs Trained)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Package is installed in editable mode, so import the parser directly\n",
    "from commitguard_env.parse_action import parse_action\n",
    "\n",
    "# Load test set\n",
    "test_path = os.path.join(REPO_DIR, \"data\", \"devign_test.jsonl\")\n",
    "with open(test_path) as f:\n",
    "    test_samples = [json.loads(l) for l in f if l.strip()]\n",
    "\n",
    "print(f\"Evaluating on {len(test_samples)} held-out samples...\")\n",
    "\n",
    "# Run trained model on test set\n",
    "FastLanguageModel.for_inference(model)\n",
    "\n",
    "correct = 0\n",
    "results = []\n",
    "\n",
    "for i, sample in enumerate(test_samples):\n",
    "    user_msg = get_agent_prompt(sample[\"diff\"], sample[\"available_files\"], 0)\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "        {\"role\": \"user\", \"content\": user_msg},\n",
    "    ]\n",
    "    inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", add_generation_prompt=True).to(model.device)\n",
    "    with torch.no_grad():\n",
    "        output = model.generate(inputs, max_new_tokens=512, temperature=0.1, do_sample=True)\n",
    "    response = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)\n",
    "\n",
    "    # Parse verdict\n",
    "    action = parse_action(response)\n",
    "\n",
    "    pred_vuln = bool(action.is_vulnerable) if action.is_vulnerable is not None else False\n",
    "    truth_vuln = sample[\"is_vulnerable\"]\n",
    "\n",
    "    if pred_vuln == truth_vuln:\n",
    "        correct += 1\n",
    "\n",
    "    results.append({\n",
    "        \"sample_id\": sample[\"sample_id\"],\n",
    "        \"pred\": pred_vuln,\n",
    "        \"truth\": truth_vuln,\n",
    "        \"cwe\": sample.get(\"cwe\"),\n",
    "        \"vuln_type\": action.vuln_type,\n",
    "    })\n",
    "\n",
    "    if (i + 1) % 20 == 0:\n",
    "        print(f\"  {i+1}/{len(test_samples)}  running accuracy: {100*correct/(i+1):.1f}%\")\n",
    "\n",
    "accuracy = 100 * correct / len(test_samples)\n",
    "print(f\"\\nFinal trained accuracy: {accuracy:.1f}%\")\n",
    "\n",
    "with open(os.path.join(REPO_DIR, \"eval_trained.json\"), \"w\") as f:\n",
    "    json.dump(results, f, indent=2)\n",
    "print(\"Results saved to eval_trained.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 12: Generate Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "\n",
    "os.makedirs(os.path.join(REPO_DIR, \"plots\"), exist_ok=True)\n",
    "\n",
    "# --- Plot 1: Training reward curve (from trainer logs) ---\n",
    "if hasattr(trainer, 'state') and trainer.state.log_history:\n",
    "    steps = [l[\"step\"] for l in trainer.state.log_history if \"loss\" in l]\n",
    "    losses = [l[\"loss\"] for l in trainer.state.log_history if \"loss\" in l]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 5))\n",
    "    ax.plot(steps, losses, color=\"#2ecc71\", linewidth=2)\n",
    "    ax.set_xlabel(\"Training Step\")\n",
    "    ax.set_ylabel(\"Loss\")\n",
    "    ax.set_title(\"CommitGuard GRPO Training Loss\")\n",
    "    ax.grid(True, linestyle=\"--\", alpha=0.5)\n",
    "    fig.savefig(os.path.join(REPO_DIR, \"plots\", \"reward_curve.png\"), dpi=150)\n",
    "    plt.show()\n",
    "    print(\"Saved plots/reward_curve.png\")\n",
    "\n",
    "    # --- Plot 2: Accuracy comparison ---\n",
    "    with open(os.path.join(REPO_DIR, \"eval_baseline.json\")) as f:\n",
    "        b_data = json.load(f)\n",
    "    baseline_acc = 100 * sum(1 for x in b_data if x['pred'] == x['truth']) / len(b_data)\n",
    "    trained_acc = accuracy\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 5))\n",
    "    bars = ax.bar([\"Baseline (Untrained)\", \"CommitGuard (Trained)\"],\n",
    "                  [baseline_acc, trained_acc],\n",
    "                  color=[\"#95a5a6\", \"#3498db\"])\n",
    "    ax.set_ylabel(\"Detection Accuracy (%)\")\n",
    "    ax.set_title(\"Vulnerability Detection: Baseline vs. Trained\")\n",
    "    ax.set_ylim(0, 100)\n",
    "    for bar in bars:\n",
    "        h = bar.get_height()\n",
    "        ax.text(bar.get_x() + bar.get_width()/2., h + 1, f\"{h:.1f}%\",\n",
    "                ha=\"center\", fontweight=\"bold\")\n",
    "    fig.savefig(os.path.join(REPO_DIR, \"plots\", \"baseline_vs_trained.png\"), dpi=150)\n",
    "    plt.show()\n",
    "    print(\"Saved plots/baseline_vs_trained.png\")\n",
    "\n",
    "    # --- Plot 3: Per-CWE breakdown ---\n",
    "    cwe_correct = Counter()\n",
    "    cwe_total = Counter()\n",
    "    for r in results:\n",
    "        if r[\"cwe\"]:\n",
    "            cwe_total[r[\"cwe\"]] += 1\n",
    "            if r[\"pred\"] == r[\"truth\"]:\n",
    "                cwe_correct[r[\"cwe\"]] += 1\n",
    "\n",
    "    cwes = sorted(cwe_total.keys())\n",
    "    accs = [100 * cwe_correct[c] / cwe_total[c] if cwe_total[c] > 0 else 0 for c in cwes]\n",
    "\n",
    "    if cwes:\n",
    "        fig, ax = plt.subplots(figsize=(10, 5))\n",
    "        ax.bar(cwes, accs, color=\"#e67e22\")\n",
    "        ax.set_ylabel(\"Accuracy (%)\")\n",
    "        ax.set_title(\"Trained Model Accuracy by CWE Type\")\n",
    "        ax.set_ylim(0, 100)\n",
    "        plt.xticks(rotation=45)\n",
    "        plt.tight_layout()\n",
    "        fig.savefig(os.path.join(REPO_DIR, \"plots\", \"per_cwe.png\"), dpi=150)\n",
    "        plt.show()\n",
    "        print(\"Saved plots/per_cwe.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 13: Cleanup\n",
    "\n",
    "Stop the env server and print final summary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server_proc.terminate()\n",
    "print(\"Env server stopped.\")\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\" TRAINING COMPLETE\")\n",
    "print(\"=\"*50)\n",
    "print(f\"  Model: {MODEL_NAME}\")\n",
    "print(f\"  Steps: {training_args.max_steps}\")\n",
    "print(f\"  Accuracy: {baseline_acc:.1f}% -> {trained_acc:.1f}% (+{trained_acc - baseline_acc:.1f}pp)\")\n",
    "print(f\"  Adapter: {FINAL_DIR}\")\n",
    "print(f\"  Plots: plots/reward_curve.png, baseline_vs_trained.png, per_cwe.png\")\n",
    "\n",
    "print(\"\\nNext: copy outputs/ and plots/ back to your local machine.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
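The verdict action that the notebook's reward function sends to the environment is a small XML document. Assuming well-formed model output in that shape, it can be decoded with the standard library; `parse_verdict` below is a hypothetical helper for illustration only, and the repo's own `commitguard_env.parse_action` remains the authoritative (and more defensive) parser:

```python
import xml.etree.ElementTree as ET

def parse_verdict(action_xml: str) -> dict:
    """Decode a <action> verdict string into a plain dict (illustrative sketch)."""
    root = ET.fromstring(action_xml)
    return {
        "action_type": root.findtext("action_type"),
        # the env's schema uses lowercase "true"/"false"
        "is_vulnerable": root.findtext("is_vulnerable") == "true",
        "vuln_type": root.findtext("vuln_type"),
        "exploit_sketch": root.findtext("exploit_sketch"),
    }

sample = (
    "<action><action_type>verdict</action_type>"
    "<is_vulnerable>true</is_vulnerable>"
    "<vuln_type>CWE-119</vuln_type>"
    "<exploit_sketch>buffer overflow</exploit_sketch></action>"
)
print(parse_verdict(sample))
```

A real parser also has to cope with completions that are not valid XML at all, which is why the reward function penalizes failed steps rather than raising.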
pyproject.toml
ADDED
@@ -0,0 +1,48 @@
[project]
name = "commitguard"
version = "0.1.0"
description = "CommitGuard OpenEnv RL environment for commit-time vuln detection"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "fastapi>=0.110",
    "uvicorn[standard]>=0.27",
    "pydantic>=2.6",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0",
    "requests>=2.31",
]
scan = [
    "torch>=2.4",
    "transformers>=4.46",
    "accelerate>=1.0",
]
train = [
    "requests",
    "torch>=2.4",
    "transformers>=4.46",
    "trl>=0.12",
    "accelerate>=1.0",
    "peft>=0.13",
    "datasets>=3.0",
    "wandb",
    "matplotlib",
    "unsloth",
    "bitsandbytes>=0.44",
    "jupyter",
    "ipywidgets",
]

[project.scripts]
commitguard = "commitguard_env.cli:main"
server = "commitguard_env.server:main"

[tool.setuptools]
packages = ["commitguard_env"]

[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"
pyrightconfig.json
ADDED
@@ -0,0 +1,16 @@
{
    "venvPath": ".",
    "venv": ".venv",
    "include": [
        "scripts",
        "commitguard_env",
        "server",
        "."
    ],
    "extraPaths": [
        "${workspaceFolder}",
        "${workspaceFolder}/scripts"
    ],
    "reportMissingImports": true,
    "typeCheckingMode": "basic"
}
scratch/extract_sample.py
ADDED
@@ -0,0 +1,24 @@
import json
import os

target_id = "2bf3aa85f08186b8162b76e7e8efe5b5a44306a6"
data_dir = r"c:\Users\DIVYANK BHARDWAJ\Desktop\hackathon project\commitguard\data"
files = ["devign_test.jsonl", "devign_filtered.jsonl"]

found = False
for filename in files:
    path = os.path.join(data_dir, filename)
    if not os.path.exists(path):
        continue
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            if data.get("sample_id") == target_id:
                print(json.dumps(data, indent=2))
                found = True
                break
    if found:
        break

if not found:
    print(f"Sample {target_id} not found in {files}")
scripts/README.md
ADDED
@@ -0,0 +1,7 @@
## Scripts

This directory is for repeatable CLI-first ops (dataset preprocessing, local smoke runs).

Primary expected script (Deepak):
- `preprocess_devign.py` → produces `data/devign_filtered.jsonl`
scripts/__init__.py
ADDED
@@ -0,0 +1 @@
# Marking scripts as a package for resolution
scripts/check_cuda.py
ADDED
@@ -0,0 +1,6 @@
import torch
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'Device count: {torch.cuda.device_count()}')
    print(f'Device name: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
scripts/check_disjoint.py
ADDED
@@ -0,0 +1,20 @@
import json
from pathlib import Path

def get_ids(file_path):
    ids = set()
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            obj = json.loads(line)
            ids.add(obj.get('commit_id') or obj.get('sample_id'))
    return ids

train_ids = get_ids('data/devign_train.jsonl')
test_ids = get_ids('data/devign_test.jsonl')

overlap = train_ids.intersection(test_ids)
print(f"Train IDs: {len(train_ids)}")
print(f"Test IDs: {len(test_ids)}")
print(f"Overlap: {len(overlap)}")
if overlap:
    print(f"Overlapping IDs: {list(overlap)[:5]}")
scripts/check_unsloth.py
ADDED
@@ -0,0 +1,13 @@
import torch
from unsloth import FastLanguageModel

try:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
        max_seq_length=1024,
        load_in_4bit=True,
    )
    print("Successfully loaded model in 4-bit on this GPU.")
    print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
except Exception as e:
    print(f"Failed to load model: {e}")
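As a rough sanity check on why the 1B smoke-test model in `check_unsloth.py` fits comfortably on a small GPU, 4-bit quantization stores about half a byte per weight (ignoring quantization constants, KV cache, and activations; the 1.24B parameter count is an approximation, not a figure from this repo):

```python
params = 1.24e9          # approximate parameter count of a Llama-3.2-1B-class model (assumption)
bytes_per_weight = 0.5   # 4-bit quantized weights ~= half a byte each
vram_gb = params * bytes_per_weight / 1024**3
print(f"~{vram_gb:.2f} GB for the quantized weights alone")
```

The real `torch.cuda.memory_allocated()` figure printed by the script will be somewhat higher, since quantization metadata and buffers are allocated on top of the raw weights.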