Spaces:
Runtime error
Runtime error
Commit ·
8f4e44a
1
Parent(s): 6398066
Feat (Phase 1 & 2): Extract scanner module and add CLI interface
Browse files- commitguard_env/cli.py +107 -0
- commitguard_env/inference.py +118 -0
- commitguard_env/models.py +9 -0
- commitguard_env/scanner.py +54 -0
- pyproject.toml +1 -0
commitguard_env/cli.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import subprocess
|
| 4 |
+
import sys
|
| 5 |
+
from dataclasses import asdict
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from .scanner import CommitGuardScanner
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def cmd_scan(args):
    """Resolve a diff from the selected source, scan it, print the verdict.

    Exits 1 on missing source or (with --fail-on-vulnerable) on a
    vulnerable verdict; exits 0 early when the resolved diff is empty.
    """
    # Exactly one of the mutually exclusive sources provides the diff text.
    if getattr(args, "diff", None):
        diff_text = Path(args.diff).read_text(encoding="utf-8")
    elif getattr(args, "staged", False):
        diff_text = subprocess.check_output(["git", "diff", "--staged"], text=True)
    elif getattr(args, "commit", None):
        diff_text = subprocess.check_output(["git", "show", args.commit], text=True)
    elif getattr(args, "pr", None):
        diff_text = subprocess.check_output(["gh", "pr", "diff", args.pr], text=True)
    else:
        print("Must specify one of --diff, --staged, --commit, or --pr")
        sys.exit(1)

    if not diff_text.strip():
        print("No diff found to scan.")
        sys.exit(0)

    # Progress goes to stderr so stdout stays clean for machine-readable output.
    print(f"Loading model ({args.model})...", file=sys.stderr)
    guard = CommitGuardScanner(model_path=args.model, is_lora=args.is_lora, base_model=args.base_model)

    print(f"Scanning diff ({len(diff_text)} chars)...", file=sys.stderr)
    verdict = guard.scan(diff_text)

    if args.format == "json":
        print(json.dumps(asdict(verdict), indent=2))
    elif args.format == "text":
        if verdict.is_vulnerable:
            status = "VULNERABLE ⚠️"
        else:
            status = "SAFE ✅"
        print(f"\nVerdict: {status}")
        if verdict.is_vulnerable:
            print(f"CWE: {verdict.cwe}")
            print(f"Exploit Sketch:\n {verdict.exploit_sketch}")
        if verdict.parse_error:
            print(f"\nParser Warning: {verdict.parse_error}")
    elif args.format == "sarif":
        # SARIF is still a stub: warn on stderr, emit plain JSON on stdout.
        print("SARIF format not fully implemented yet.", file=sys.stderr)
        print(json.dumps(asdict(verdict)))

    if args.fail_on_vulnerable and verdict.is_vulnerable:
        sys.exit(1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def cmd_server(args):
    """Start the OpenEnv environment server.

    The import is deferred so server-only dependencies are not paid
    by the scan/eval code paths.
    """
    from . import server
    server.main()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def cmd_eval(args):
    """Run scripts/evaluate.py in a subprocess, forwarding extra arguments.

    NOTE: reusing the script via subprocess avoids sys.path tweaks
    everywhere; a cleaner approach would be moving evaluate.py into
    commitguard_env.
    """
    repo_root = Path(__file__).resolve().parent.parent
    eval_script = repo_root / "scripts" / "evaluate.py"
    argv = [sys.executable, str(eval_script), *args.eval_args]
    subprocess.run(argv, check=True)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def main():
    """Build the CLI parser and dispatch to the chosen subcommand."""
    cli = argparse.ArgumentParser(description="CommitGuard AI-paced security review")
    commands = cli.add_subparsers(dest="command", required=True)

    # 'scan' subcommand
    scan_cmd = commands.add_parser("scan", help="Scan a code diff for vulnerabilities")

    # Exactly one diff source must be chosen.
    source = scan_cmd.add_mutually_exclusive_group(required=True)
    source.add_argument("--diff", type=str, help="Path to a diff file")
    source.add_argument("--staged", action="store_true", help="Scan git staged changes")
    source.add_argument("--commit", type=str, help="Scan a specific git commit (e.g., HEAD)")
    source.add_argument("--pr", type=str, help="Scan a GitHub PR URL or ID (requires gh cli)")

    scan_cmd.add_argument("--model", type=str, default="inmodel-labs/commitguard-llama-3b", help="Model path or HF ID")
    scan_cmd.add_argument("--base-model", type=str, default=None, help="Base model if using LoRA")
    scan_cmd.add_argument("--is-lora", action="store_true", help="Whether the model is a LoRA adapter")
    scan_cmd.add_argument("--format", choices=["text", "json", "sarif"], default="text", help="Output format")
    scan_cmd.add_argument("--fail-on-vulnerable", action="store_true", help="Exit with code 1 if vulnerable")

    # 'server' subcommand (server_main takes PORT from environment)
    commands.add_parser("server", help="Start the OpenEnv environment server")

    # 'eval' subcommand
    eval_cmd = commands.add_parser("eval", help="Run the evaluation harness")
    eval_cmd.add_argument("eval_args", nargs=argparse.REMAINDER, help="Arguments passed to evaluate.py")

    args = cli.parse_args()

    # Table dispatch; required=True above guarantees args.command is a key.
    handlers = {"scan": cmd_scan, "server": cmd_server, "eval": cmd_eval}
    handlers[args.command](args)
|
| 105 |
+
|
| 106 |
+
if __name__ == "__main__":  # script entry point (also exposed via console script)
    main()
|
commitguard_env/inference.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations

import sys
from pathlib import Path
from typing import Any

# Add project root to path for imports to find agent_prompt if run directly
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

# Prefer the canonical prompt from agent_prompt; the inline copy below is
# a fallback so the module still imports when agent_prompt is unavailable.
try:
    from agent_prompt import SYSTEM_PROMPT
except ImportError:
    # Fallback if not found
    SYSTEM_PROMPT = """You are a senior security researcher and pentester. Your task is to analyze code commits (diffs) to determine if they introduce exploitable vulnerabilities.

You operate in a multi-step environment (up to 5 steps). You can request more context, analyze your thoughts, or issue a final verdict.

### Action Format
You MUST respond with exactly ONE action per turn, wrapped in XML tags:

1. **Request Context:** Use this if you need to see the full content of a file listed in 'available_files'.
<action>
<action_type>request_context</action_type>
<file_path>filename.c</file_path>
</action>

2. **Analyze:** Use this for your internal Chain-of-Thought reasoning. Be detailed.
<action>
<action_type>analyze</action_type>
<reasoning>Your detailed step-by-step security analysis here...</reasoning>
</action>

3. **Verdict:** Use this to terminate the episode with your final judgment.
<action>
<action_type>verdict</action_type>
<is_vulnerable>true/false</is_vulnerable>
<vuln_type>CWE-XX (e.g., CWE-89)</vuln_type>
<exploit_sketch>Brief description of how this could be exploited...</exploit_sketch>
</action>

### Rules & Constraints
- If the code is safe, set is_vulnerable to false and vuln_type to NONE.
- Be specific in exploit_sketch: name the attack vector (e.g., buffer overflow via unchecked memcpy).
- Common CWE types: CWE-89 (SQLi), CWE-79 (XSS), CWE-78 (Command Inj), CWE-22 (Path Traversal), CWE-119 (Buffer Overflow), CWE-476 (Null Dereference), CWE-190 (Integer Overflow).
- You have a maximum of 5 steps per episode.
- Context requests have a small cost; be efficient.
- Verifiable rewards (RLVR) are based on the accuracy of your final verdict and the presence of correct exploit keywords.
"""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def format_prompt(diff: str, available_files: list[str] | None = None) -> str:
    """Render *diff* into the Llama-3-style chat prompt expected by the model.

    Args:
        diff: Unified diff text to analyze.
        available_files: File names the model may request context for;
            rendered as the literal string "None" when absent or empty.

    Returns:
        Full prompt string including the special header tokens, ending at
        the assistant turn so generation continues from there.
    """
    files_str = ", ".join(available_files) if available_files else "None"

    user_prompt = f"""### Input Diff
{diff}

### Environment Info
- Available Files: {files_str}
- Current Step: 0/5

Please provide your next action in XML format:"""

    return (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
|
| 70 |
+
|
| 71 |
+
def load_model(model_path: str, is_lora: bool = False, base_model: str | None = None) -> tuple[Any, Any]:
    """
    Load the LLM and tokenizer for inference.

    Args:
        model_path: HF model ID or local path; when is_lora is True this is
            the LoRA adapter path applied on top of base_model.
        is_lora: Load via unsloth + peft as a LoRA adapter over base_model.
        base_model: Base model ID; required when is_lora is True.

    Returns:
        (model, tokenizer) ready for generation.

    Raises:
        ValueError: if is_lora is True but base_model is not given.
    """
    # Imported lazily so merely importing this module stays cheap.
    import torch

    if is_lora:
        if not base_model:
            raise ValueError("base_model is required if is_lora=True")
        from unsloth import FastLanguageModel
        from peft import PeftModel

        # Base model first (4-bit to fit on smaller GPUs), then the adapter.
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=2048,
            load_in_4bit=True,
        )
        model = PeftModel.from_pretrained(model, model_path)
        # Switches unsloth into its optimized inference mode.
        FastLanguageModel.for_inference(model)
    else:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # fp16 + device_map="auto" on GPU; fp32 on CPU (no device_map).
        device_map = "auto" if torch.cuda.is_available() else None
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map=device_map
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)

    return model, tokenizer
|
| 102 |
+
|
| 103 |
+
def generate(model: Any, tokenizer: Any, prompt: str, max_new_tokens: int = 256) -> str:
    """Greedy-decode a completion for *prompt* and return only the new text.

    Args:
        model: Loaded causal LM (see load_model).
        tokenizer: Matching tokenizer.
        prompt: Fully formatted prompt string.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded completion with special tokens stripped; prompt tokens
        are sliced off so only newly generated text is returned.
    """
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): assumes the model's input embeddings live on this
    # device; with device_map="auto" multi-GPU sharding, confirm inputs
    # should target cuda:0.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. The previous temperature=0.1 was dropped:
            # transformers ignores temperature when do_sample=False and
            # warns about the unused flag.
            do_sample=False,
        )

    # Decode only tokens past the prompt length.
    response = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return response
|
commitguard_env/models.py
CHANGED
|
@@ -59,3 +59,12 @@ class DevignSample:
|
|
| 59 |
target_file: Optional[str] = None
|
| 60 |
files: Optional[dict[str, str]] = None
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
target_file: Optional[str] = None
|
| 60 |
files: Optional[dict[str, str]] = None
|
| 61 |
|
| 62 |
+
|
| 63 |
+
@dataclass(frozen=True, slots=True)
class ScanResult:
    """Immutable result of scanning a single diff for vulnerabilities."""

    is_vulnerable: bool  # final verdict extracted from the model's response
    cwe: Optional[str]  # CWE identifier (e.g. "CWE-89"); None when absent
    exploit_sketch: Optional[str]  # brief attack description, if the model gave one
    raw_response: str  # unparsed model output, kept for debugging/auditing
    parse_error: Optional[str] = None  # set when the response could not be fully parsed
|
| 70 |
+
|
commitguard_env/scanner.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from .inference import format_prompt, generate, load_model
|
| 6 |
+
from .models import ScanResult
|
| 7 |
+
from .parse_action import parse_action
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CommitGuardScanner:
    """
    Scanner for CommitGuard vulnerabilities.

    Keeps the model in memory so multiple diffs can be scanned without
    reloading the model each time.
    """

    def __init__(self, model_path: str = "inmodel-labs/commitguard-llama-3b", is_lora: bool = False, base_model: str | None = None) -> None:
        """Store configuration; the model is loaded lazily on first scan.

        Args:
            model_path: HF model ID or local path (full model or LoRA adapter).
            is_lora: True when model_path is a LoRA adapter over base_model.
            base_model: Base model ID; required when is_lora is True.
        """
        self.model_path = model_path
        self.is_lora = is_lora
        self.base_model = base_model
        self.model: Any = None
        self.tokenizer: Any = None

    def load(self) -> None:
        """Load the model and tokenizer into memory (no-op if already loaded)."""
        if self.model is None or self.tokenizer is None:
            self.model, self.tokenizer = load_model(self.model_path, self.is_lora, self.base_model)

    def scan(self, diff: str, available_files: list[str] | None = None) -> ScanResult:
        """
        Scan a given diff for vulnerabilities.

        Args:
            diff: Unified diff text to analyze.
            available_files: File names advertised to the model for context
                requests (the single-shot scan never serves them).

        Returns:
            A ScanResult; is_vulnerable falls back to False when the model
            did not produce an explicit verdict.
        """
        self.load()

        prompt = format_prompt(diff, available_files)
        response = generate(self.model, self.tokenizer, prompt)
        action = parse_action(response)

        # Map the parsed action onto the public result type; a missing
        # verdict is treated as "not vulnerable" rather than an error.
        return ScanResult(
            is_vulnerable=action.is_vulnerable if action.is_vulnerable is not None else False,
            cwe=action.vuln_type,
            exploit_sketch=action.exploit_sketch,
            raw_response=response,
            parse_error=action.parse_error,
        )
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def scan(diff: str, model_path: str = "inmodel-labs/commitguard-llama-3b", is_lora: bool = False, base_model: str | None = None) -> ScanResult:
    """
    Convenience function to scan a single diff. Loads the model, scans, and returns the result.

    If scanning multiple diffs, prefer instantiating CommitGuardScanner
    directly to avoid reloading the model.

    Args:
        diff: Unified diff text to analyze.
        model_path: HF model ID or local path.
        is_lora: True when model_path is a LoRA adapter over base_model.
        base_model: Base model ID; required when is_lora is True.

    Returns:
        The ScanResult for this diff.
    """
    scanner = CommitGuardScanner(model_path=model_path, is_lora=is_lora, base_model=base_model)
    return scanner.scan(diff)
|
pyproject.toml
CHANGED
|
@@ -33,6 +33,7 @@ train = [
|
|
| 33 |
]
|
| 34 |
|
| 35 |
[project.scripts]
|
|
|
|
| 36 |
server = "commitguard_env.server:main"
|
| 37 |
|
| 38 |
[tool.setuptools]
|
|
|
|
| 33 |
]
|
| 34 |
|
| 35 |
[project.scripts]
|
| 36 |
+
commitguard = "commitguard_env.cli:main"
|
| 37 |
server = "commitguard_env.server:main"
|
| 38 |
|
| 39 |
[tool.setuptools]
|