# NOTE(review): scrape residue from a Hugging Face Space page header
# ("Spaces: Running on A10G") — kept as a comment so the file stays valid.
| import argparse | |
| import json | |
| import random | |
| from collections import Counter | |
| from pathlib import Path | |
| def _read_jsonl(path: Path) -> list[dict]: | |
| rows = [] | |
| for line in path.read_text(encoding="utf-8").splitlines(): | |
| line = line.strip() | |
| if not line: | |
| continue | |
| rows.append(json.loads(line)) | |
| return rows | |
| def _write_jsonl(path: Path, rows: list[dict]) -> None: | |
| path.parent.mkdir(parents=True, exist_ok=True) | |
| with path.open("w", encoding="utf-8", newline="\n") as f: | |
| for r in rows: | |
| f.write(json.dumps(r, ensure_ascii=False) + "\n") | |
| # --------------------------------------------------------------------------- | |
| # Fix 2: CWE classification using vulnerable lines, not the whole function. | |
| # Scored rules — highest-scoring match wins. Falls back to CWE-OTHER. | |
| # --------------------------------------------------------------------------- | |
# Scored keyword rules: (CWE id, substrings to look for, weight). Hits in the
# labelled vulnerable lines count at full weight; hits anywhere in the function
# count at half weight. The highest-scoring rule wins; ties keep the earlier rule.
_CWE_RULES: list[tuple[str, list[str], int]] = [
    ("CWE-119", ["memcpy", "strcpy", "strcat", "strncpy", "memmove", "sprintf",
                 "gets(", "buffer", "overflow", "oob", "av_malloc", "av_realloc",
                 "realloc", "malloc", "alloc", "g_malloc", "g_realloc",
                 "qemu_malloc", "len ", "length", "copy_from", "copy_to"], 5),
    ("CWE-476", ["null", "nullptr", "!= null", "== null", "if (!",
                 "dereference", "segfault", "!obj", "!ctx", "!s->", "!p"], 5),
    ("CWE-189", ["integer overflow", "signedness", "truncat", "wrap",
                 "size_t", "underflow", "narrowing", "(int)", "(uint",
                 "(unsigned)", ">> ", "<< ", "0xffff", "max_", "min_"], 5),
    ("CWE-78", ["system(", "popen(", "exec(", "execve", "shell",
                "command", "subprocess"], 8),
    ("CWE-22", ["../", "..\\", "traversal", "chroot", "realpath",
                "canonicalize", "symlink", "path"], 7),
    ("CWE-89", ["sql", "query", "select ", "insert ", "union ",
                "prepared", "sqlite", "mysql"], 7),
    ("CWE-79", ["xss", "innerhtml", "script", "sanitize", "escape",
                "htmlentit", "content-type"], 6),
    ("CWE-20", ["valid", "saniti", "untrusted", "input", "bounds",
                "assert", "range", "check", "error", "return -1",
                "goto fail", "goto err", "goto out"], 2),
]


def infer_cwe(vul_lines_code: list[str], func: str) -> str:
    """Guess a CWE category from the labelled vulnerable lines and the whole function.

    Each rule in ``_CWE_RULES`` is scored as (hits in vulnerable lines * weight)
    + (hits in the full function * weight // 2). Returns the best rule's CWE,
    or "CWE-OTHER" when no rule reaches a score of 2.
    """
    line_text = " ".join(vul_lines_code).lower() if vul_lines_code else ""
    body_text = func.lower()
    winner, top_score = "CWE-OTHER", 0
    for cwe, keywords, weight in _CWE_RULES:
        line_hits = sum(kw in line_text for kw in keywords) if line_text else 0
        body_hits = sum(kw in body_text for kw in keywords)
        score = line_hits * weight + body_hits * (weight // 2)
        if score > top_score:
            winner, top_score = cwe, score
    # Require a minimal amount of evidence before committing to a class.
    return winner if top_score >= 2 else "CWE-OTHER"
| # --------------------------------------------------------------------------- | |
| # Fix 1: Real unified diffs from per-line vulnerability labels. | |
| # --------------------------------------------------------------------------- | |
| def _build_diff(func: str, label: list[int], rng: random.Random, is_vuln: bool) -> str: | |
| lines = func.splitlines() | |
| if is_vuln and label and len(label) == len(lines): | |
| changed_indices = {i for i, l in enumerate(label) if l == 1} | |
| elif is_vuln and label and any(l == 1 for l in label): | |
| changed_indices = {i for i, l in enumerate(label) if l == 1} | |
| else: | |
| block_size = max(1, min(5, len(lines) // 4)) | |
| start = rng.randint(0, max(0, len(lines) - block_size)) | |
| changed_indices = set(range(start, min(start + block_size, len(lines)))) | |
| if not changed_indices: | |
| changed_indices = {0} | |
| ctx = 3 | |
| visible: set[int] = set() | |
| for ci in changed_indices: | |
| for offset in range(-ctx, ctx + 1): | |
| idx = ci + offset | |
| if 0 <= idx < len(lines): | |
| visible.add(idx) | |
| sorted_visible = sorted(visible) | |
| hunks: list[list[int]] = [] | |
| current_hunk: list[int] = [] | |
| for idx in sorted_visible: | |
| if current_hunk and idx > current_hunk[-1] + 1: | |
| hunks.append(current_hunk) | |
| current_hunk = [idx] | |
| else: | |
| current_hunk.append(idx) | |
| if current_hunk: | |
| hunks.append(current_hunk) | |
| diff_parts = ["--- a/source.c", "+++ b/source.c"] | |
| for hunk in hunks: | |
| start_line = hunk[0] + 1 | |
| hunk_size = len(hunk) | |
| diff_parts.append(f"@@ -{start_line},{hunk_size} +{start_line},{hunk_size} @@") | |
| for idx in hunk: | |
| line = lines[idx] | |
| if idx in changed_indices: | |
| diff_parts.append(f"+{line}") | |
| else: | |
| diff_parts.append(f" {line}") | |
| return "\n".join(diff_parts) | |
| # --------------------------------------------------------------------------- | |
| # Fix 3: CWE rebalancing — cap dominant CWEs, merge tiny ones. | |
| # --------------------------------------------------------------------------- | |
| _MAX_PER_CWE_FRAC = 0.25 | |
| _MIN_CWE_SAMPLES = 20 | |
| def _rebalance(samples: list[dict], rng: random.Random, limit: int) -> list[dict]: | |
| by_cwe: dict[str, list[dict]] = {} | |
| for s in samples: | |
| by_cwe.setdefault(s["cwe"] or "CWE-OTHER", []).append(s) | |
| for cwe, items in list(by_cwe.items()): | |
| if len(items) < _MIN_CWE_SAMPLES and cwe != "CWE-OTHER": | |
| by_cwe.setdefault("CWE-OTHER", []).extend(items) | |
| for item in items: | |
| item["cwe"] = "CWE-OTHER" | |
| del by_cwe[cwe] | |
| cap = int(limit * _MAX_PER_CWE_FRAC) | |
| kept: list[dict] = [] | |
| for cwe, items in by_cwe.items(): | |
| rng.shuffle(items) | |
| kept.extend(items[:cap]) | |
| rng.shuffle(kept) | |
| return kept[:limit] | |
def main() -> None:
    """Build CommitGuard-style train/test JSONL splits from Devign samples.

    Reads rows either from ``--in`` (JSONL) or from the DetectVul/devign
    dataset on Hugging Face, converts each function into a synthetic-diff
    sample, carves out a CWE-diverse test set, then rebalances and writes the
    training set. Fix vs. the previous version: the test-set top-up counts are
    clamped at zero — previously, when the per-CWE picks already exceeded
    ``--test-limit``, ``safe_all[:needed_safe]`` used a negative slice bound
    and silently pulled almost the entire safe pool into the test set.
    """
    ap = argparse.ArgumentParser(description="Preprocess Devign-derived samples into CommitGuard JSONL.")
    ap.add_argument("--in", dest="inp", type=Path, default=None, help="Optional input JSONL.")
    ap.add_argument("--out", dest="out", type=Path, default=Path("data/devign_filtered.jsonl"))
    ap.add_argument("--test-out", dest="test_out", type=Path, default=Path("data/devign_test.jsonl"))
    ap.add_argument("--limit", type=int, default=5000)
    ap.add_argument("--test-limit", type=int, default=100)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    rng = random.Random(args.seed)

    if args.inp is None:
        try:
            # Optional third-party dependency, only needed for the HF path.
            from datasets import load_dataset
            print("Loading DetectVul/devign from Hugging Face...")
            ds = load_dataset('DetectVul/devign', split='train')
            raw_rows = list(ds)
            print(f"Loaded {len(raw_rows)} rows from HF.")
        except Exception as e:
            print(f"Failed to load from HF: {e}")
            return
    else:
        raw_rows = _read_jsonl(args.inp)

    all_samples: list[dict] = []
    seen_ids = set()
    for i, r in enumerate(raw_rows):
        func = r.get("func")
        if not func:
            continue
        # Skip very long functions; they dominate token budgets downstream.
        if len(func.split("\n")) > 80:
            continue
        target = bool(r.get("target", False))
        label = r.get("label", [])
        vul_lines_code = []
        vl = r.get("vul_lines")
        if vl and isinstance(vl, dict):
            vul_lines_code = vl.get("code", [])
        # CWE is only meaningful for vulnerable samples.
        cwe = infer_cwe(vul_lines_code, func) if target else None
        diff = _build_diff(func, label, rng, target)
        # Ensure a unique sample_id even when commit ids repeat.
        original_id = str(r.get("commit_id") or r.get("id") or f"row-{i}")
        sample_id = original_id
        suffix = 0
        while sample_id in seen_ids:
            suffix += 1
            sample_id = f"{original_id}_{suffix}"
        seen_ids.add(sample_id)
        target_file = "source.c"
        all_samples.append({
            "sample_id": sample_id,
            "diff": diff,
            "available_files": [target_file],
            "is_vulnerable": target,
            "cwe": cwe,
            "target_file": target_file,
            "files": {target_file: func},
        })
    print(f"Total processed samples: {len(all_samples)}")

    # Shuffle, then split so train and test never overlap.
    rng.shuffle(all_samples)
    test_samples: list[dict] = []
    vuln_all = [s for s in all_samples if s["is_vulnerable"]]
    safe_all = [s for s in all_samples if not s["is_vulnerable"]]
    by_cwe: dict[str, list[dict]] = {}
    for s in vuln_all:
        by_cwe.setdefault(s["cwe"] or "CWE-OTHER", []).append(s)
    # Seed the test set with up to 5 samples per CWE for class coverage.
    for cwe in by_cwe:
        test_samples.extend(by_cwe[cwe][:5])
        by_cwe[cwe] = by_cwe[cwe][5:]
    # Top up to roughly half vulnerable / half safe. Both counts are clamped
    # at zero so a surplus of per-CWE picks cannot produce negative slices.
    remaining_vuln = [s for items in by_cwe.values() for s in items]
    needed_vuln = max(0, (args.test_limit // 2) - sum(1 for s in test_samples if s["is_vulnerable"]))
    if needed_vuln > 0:
        test_samples.extend(remaining_vuln[:needed_vuln])
        remaining_vuln = remaining_vuln[needed_vuln:]
    needed_safe = max(0, args.test_limit - len(test_samples))
    test_samples.extend(safe_all[:needed_safe])
    safe_all = safe_all[needed_safe:]
    print(f"Test set: {len(test_samples)} samples")
    _write_jsonl(args.test_out, test_samples)

    # Everything left over forms the training pool; rebalance the vulnerable
    # half by CWE and sample the safe half uniformly.
    target_each = args.limit // 2
    vuln_keep = _rebalance(remaining_vuln, rng, target_each)
    safe_keep = rng.sample(safe_all, min(target_each, len(safe_all)))
    train_rows = vuln_keep + safe_keep
    rng.shuffle(train_rows)
    _write_jsonl(args.out, train_rows)
    print(f"Wrote {len(train_rows)} training samples to {args.out}")
    print(f"Wrote {len(test_samples)} test samples to {args.test_out}")


if __name__ == "__main__":
    main()