# commitguard-env/scripts/preprocess_devign.py
# Deployment Build (Final): Professional Structure + Blog (commit 95cbc5b)
import argparse
import json
import random
from collections import Counter
from pathlib import Path
def _read_jsonl(path: Path) -> list[dict]:
rows = []
for line in path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
rows.append(json.loads(line))
return rows
def _write_jsonl(path: Path, rows: list[dict]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8", newline="\n") as f:
for r in rows:
f.write(json.dumps(r, ensure_ascii=False) + "\n")
# ---------------------------------------------------------------------------
# Fix 2: CWE classification using vulnerable lines, not the whole function.
# Scored rules — highest-scoring match wins. Falls back to CWE-OTHER.
# ---------------------------------------------------------------------------
# Each rule: (CWE id, trigger keywords, weight). Keywords in the vulnerable
# lines score full weight; keywords anywhere in the function score half.
_CWE_RULES: list[tuple[str, list[str], int]] = [
    ("CWE-119", ["memcpy", "strcpy", "strcat", "strncpy", "memmove", "sprintf",
                 "gets(", "buffer", "overflow", "oob", "av_malloc", "av_realloc",
                 "realloc", "malloc", "alloc", "g_malloc", "g_realloc",
                 "qemu_malloc", "len ", "length", "copy_from", "copy_to"], 5),
    ("CWE-476", ["null", "nullptr", "!= null", "== null", "if (!",
                 "dereference", "segfault", "!obj", "!ctx", "!s->", "!p"], 5),
    ("CWE-189", ["integer overflow", "signedness", "truncat", "wrap",
                 "size_t", "underflow", "narrowing", "(int)", "(uint",
                 "(unsigned)", ">> ", "<< ", "0xffff", "max_", "min_"], 5),
    ("CWE-78", ["system(", "popen(", "exec(", "execve", "shell",
                "command", "subprocess"], 8),
    ("CWE-22", ["../", "..\\", "traversal", "chroot", "realpath",
                "canonicalize", "symlink", "path"], 7),
    ("CWE-89", ["sql", "query", "select ", "insert ", "union ",
                "prepared", "sqlite", "mysql"], 7),
    ("CWE-79", ["xss", "innerhtml", "script", "sanitize", "escape",
                "htmlentit", "content-type"], 6),
    ("CWE-20", ["valid", "saniti", "untrusted", "input", "bounds",
                "assert", "range", "check", "error", "return -1",
                "goto fail", "goto err", "goto out"], 2),
]


def infer_cwe(vul_lines_code: list[str], func: str) -> str:
    """Heuristically pick the best-matching CWE class for a function.

    Keyword hits inside the flagged vulnerable lines count at the rule's
    full weight; hits anywhere in the function count at half weight.
    Earlier rules win ties, and anything scoring below 2 falls back to
    "CWE-OTHER".
    """
    vul_text = " ".join(vul_lines_code).lower() if vul_lines_code else ""
    func_text = func.lower()

    def _score(keywords: list[str], weight: int) -> int:
        # Count keyword containment, not word boundaries — intentional,
        # these are substring heuristics over C source.
        in_vul = sum(k in vul_text for k in keywords) if vul_text else 0
        in_func = sum(k in func_text for k in keywords)
        return in_vul * weight + in_func * (weight // 2)

    # Seed with the fallback at score 0; max() keeps the first maximal
    # entry, which reproduces the original first-rule-wins tie-breaking.
    scored: list[tuple[str, int]] = [("CWE-OTHER", 0)]
    scored.extend((cwe, _score(keywords, weight)) for cwe, keywords, weight in _CWE_RULES)
    best_cwe, best_score = max(scored, key=lambda pair: pair[1])
    return best_cwe if best_score >= 2 else "CWE-OTHER"
# ---------------------------------------------------------------------------
# Fix 1: Real unified diffs from per-line vulnerability labels.
# ---------------------------------------------------------------------------
def _build_diff(func: str, label: list[int], rng: random.Random, is_vuln: bool) -> str:
lines = func.splitlines()
if is_vuln and label and len(label) == len(lines):
changed_indices = {i for i, l in enumerate(label) if l == 1}
elif is_vuln and label and any(l == 1 for l in label):
changed_indices = {i for i, l in enumerate(label) if l == 1}
else:
block_size = max(1, min(5, len(lines) // 4))
start = rng.randint(0, max(0, len(lines) - block_size))
changed_indices = set(range(start, min(start + block_size, len(lines))))
if not changed_indices:
changed_indices = {0}
ctx = 3
visible: set[int] = set()
for ci in changed_indices:
for offset in range(-ctx, ctx + 1):
idx = ci + offset
if 0 <= idx < len(lines):
visible.add(idx)
sorted_visible = sorted(visible)
hunks: list[list[int]] = []
current_hunk: list[int] = []
for idx in sorted_visible:
if current_hunk and idx > current_hunk[-1] + 1:
hunks.append(current_hunk)
current_hunk = [idx]
else:
current_hunk.append(idx)
if current_hunk:
hunks.append(current_hunk)
diff_parts = ["--- a/source.c", "+++ b/source.c"]
for hunk in hunks:
start_line = hunk[0] + 1
hunk_size = len(hunk)
diff_parts.append(f"@@ -{start_line},{hunk_size} +{start_line},{hunk_size} @@")
for idx in hunk:
line = lines[idx]
if idx in changed_indices:
diff_parts.append(f"+{line}")
else:
diff_parts.append(f" {line}")
return "\n".join(diff_parts)
# ---------------------------------------------------------------------------
# Fix 3: CWE rebalancing — cap dominant CWEs, merge tiny ones.
# ---------------------------------------------------------------------------
_MAX_PER_CWE_FRAC = 0.25
_MIN_CWE_SAMPLES = 20
def _rebalance(samples: list[dict], rng: random.Random, limit: int) -> list[dict]:
by_cwe: dict[str, list[dict]] = {}
for s in samples:
by_cwe.setdefault(s["cwe"] or "CWE-OTHER", []).append(s)
for cwe, items in list(by_cwe.items()):
if len(items) < _MIN_CWE_SAMPLES and cwe != "CWE-OTHER":
by_cwe.setdefault("CWE-OTHER", []).extend(items)
for item in items:
item["cwe"] = "CWE-OTHER"
del by_cwe[cwe]
cap = int(limit * _MAX_PER_CWE_FRAC)
kept: list[dict] = []
for cwe, items in by_cwe.items():
rng.shuffle(items)
kept.extend(items[:cap])
rng.shuffle(kept)
return kept[:limit]
def main() -> None:
    """Build CommitGuard train/test JSONL splits from Devign samples.

    Loads rows either from a ``--in`` JSONL file or from the
    DetectVul/devign dataset on Hugging Face, converts each function into
    a pseudo-diff sample with an inferred CWE, writes a disjoint,
    CWE-stratified test split, then a rebalanced training split.
    """
    ap = argparse.ArgumentParser(description="Preprocess Devign-derived samples into CommitGuard JSONL.")
    ap.add_argument("--in", dest="inp", type=Path, default=None, help="Optional input JSONL.")
    ap.add_argument("--out", dest="out", type=Path, default=Path("data/devign_filtered.jsonl"))
    ap.add_argument("--test-out", dest="test_out", type=Path, default=Path("data/devign_test.jsonl"))
    ap.add_argument("--limit", type=int, default=5000)
    ap.add_argument("--test-limit", type=int, default=100)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    # One seeded RNG shared by diff generation, shuffling, and sampling so
    # the whole pipeline is reproducible for a given --seed.
    rng = random.Random(args.seed)
    if args.inp is None:
        try:
            # Lazy import: `datasets` is only required for the HF path.
            from datasets import load_dataset
            print("Loading DetectVul/devign from Hugging Face...")
            ds = load_dataset('DetectVul/devign', split='train')
            raw_rows = list(ds)
            print(f"Loaded {len(raw_rows)} rows from HF.")
        except Exception as e:
            # Best-effort: report and abort cleanly (missing dependency or
            # network failure) instead of crashing with a traceback.
            print(f"Failed to load from HF: {e}")
            return
    else:
        raw_rows = _read_jsonl(args.inp)
    all_samples: list[dict] = []
    # Process all rows first
    seen_ids = set()
    for i, r in enumerate(raw_rows):
        func = r.get("func")
        if not func:
            continue
        # Skip long functions (> 80 lines) to keep samples compact.
        if len(func.split("\n")) > 80:
            continue
        target = bool(r.get("target", False))
        label = r.get("label", [])
        vul_lines_code = []
        vl = r.get("vul_lines")
        # vul_lines is expected to be a dict with a "code" list of the
        # flagged source lines — TODO confirm against the dataset schema.
        if vl and isinstance(vl, dict):
            vul_lines_code = vl.get("code", [])
        # Only vulnerable samples get a CWE; safe samples carry None.
        cwe = infer_cwe(vul_lines_code, func) if target else None
        diff = _build_diff(func, label, rng, target)
        # Ensure unique sample_id
        original_id = str(r.get("commit_id") or r.get("id") or f"row-{i}")
        sample_id = original_id
        suffix = 0
        # Append _1, _2, ... until the id is unique within this run.
        while sample_id in seen_ids:
            suffix += 1
            sample_id = f"{original_id}_{suffix}"
        seen_ids.add(sample_id)
        target_file = "source.c"
        sample = {
            "sample_id": sample_id,
            "diff": diff,
            "available_files": [target_file],
            "is_vulnerable": target,
            "cwe": cwe,
            "target_file": target_file,
            "files": {target_file: func},
        }
        all_samples.append(sample)
    print(f"Total processed samples: {len(all_samples)}")
    # Shuffle and split to ensure NO overlap
    rng.shuffle(all_samples)
    # We want to ensure test set has all CWEs if possible
    # Let's pick test set first by picking a few from each CWE
    test_samples: list[dict] = []
    vuln_all = [s for s in all_samples if s["is_vulnerable"]]
    safe_all = [s for s in all_samples if not s["is_vulnerable"]]
    by_cwe: dict[str, list[dict]] = {}
    for s in vuln_all:
        by_cwe.setdefault(s["cwe"] or "CWE-OTHER", []).append(s)
    # Try to pick 5 from each CWE for test set
    for cwe in by_cwe:
        test_samples.extend(by_cwe[cwe][:5])
        by_cwe[cwe] = by_cwe[cwe][5:]
    # Fill the rest of test set with random samples (half vuln, half safe)
    remaining_vuln = [s for items in by_cwe.values() for s in items]
    needed_vuln = (args.test_limit // 2) - sum(1 for s in test_samples if s["is_vulnerable"])
    if needed_vuln > 0:
        test_samples.extend(remaining_vuln[:needed_vuln])
        remaining_vuln = remaining_vuln[needed_vuln:]
    needed_safe = args.test_limit - len(test_samples)
    test_samples.extend(safe_all[:needed_safe])
    safe_all = safe_all[needed_safe:]
    # Now remaining samples go to train
    train_pool_vuln = remaining_vuln
    train_pool_safe = safe_all
    print(f"Test set: {len(test_samples)} samples")
    _write_jsonl(args.test_out, test_samples)
    # Rebalance training set
    target_each = args.limit // 2
    vuln_keep = _rebalance(train_pool_vuln, rng, target_each)
    safe_keep = rng.sample(train_pool_safe, min(target_each, len(train_pool_safe)))
    train_rows = vuln_keep + safe_keep
    rng.shuffle(train_rows)
    _write_jsonl(args.out, train_rows)
    print(f"Wrote {len(train_rows)} training samples to {args.out}")
    print(f"Wrote {len(test_samples)} test samples to {args.test_out}")


if __name__ == "__main__":
    main()