File size: 2,947 Bytes
95cbc5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from __future__ import annotations

from .models import CommitGuardAction

# Maps a CWE identifier to the coarse vulnerability family it belongs to.
# Used by _cwe_partial_score to grant partial credit when a predicted CWE
# is wrong but falls in the same family as the ground-truth CWE.
_CWE_FAMILIES: dict[str, str] = {
    cwe_id: family
    for family, cwe_ids in {
        # Memory and Buffer issues
        "memory-safety": (
            "CWE-119", "CWE-120", "CWE-121", "CWE-122", "CWE-125", "CWE-787",
        ),
        # Input and Validation issues (often overlap with memory safety)
        "input-validation": ("CWE-20", "CWE-190", "CWE-189", "CWE-191"),
        # Pointers
        "null-pointer": ("CWE-476",),
        # Logic and Traversal
        "traversal": ("CWE-22",),
        # Injections
        "injection": ("CWE-78", "CWE-89", "CWE-79"),
    }.items()
    for cwe_id in cwe_ids
}


def _cwe_partial_score(predicted: str | None, actual: str | None) -> float:
    """Score how well a predicted CWE identifier matches the actual one.

    Comparison is case-insensitive and ignores surrounding whitespace.
    Returns 1.0 for an exact match, 0.5 when both IDs belong to the same
    known family in ``_CWE_FAMILIES``, and 0.0 otherwise — including when
    either value is missing/empty or when either ID has no known family.
    """
    if predicted and actual:
        pred_id = predicted.strip().upper()
        true_id = actual.strip().upper()
        if pred_id == true_id:
            return 1.0
        # Partial credit only when both IDs resolve to the same family.
        pred_family = _CWE_FAMILIES.get(pred_id)
        if pred_family is not None and pred_family == _CWE_FAMILIES.get(true_id):
            return 0.5
    return 0.0


def compute_reward(
    *,
    action: CommitGuardAction,
    is_vulnerable: bool | None,
    cwe: str | None,
    target_file: str | None,
    cwe_keywords: dict[str, list[str]] | None,
    context_requests: int,
) -> float:
    """Score a single CommitGuard step.

    The base reward is a graduated context penalty (the first context
    request is free; each additional one costs 0.05). On top of that:

    * a parse error costs 0.5;
    * an "analyze" action earns a small shaping bonus (capped at 0.05)
      for reasoning text longer than 50 characters;
    * "request_context" and unrecognized action types add nothing;
    * a verdict is scored against the ground-truth ``is_vulnerable`` label
      (when one is available): +1.0 for a correct call in either direction,
      -1.0 for a false positive, -0.5 for a false negative, and on true
      positives extra credit for the CWE (up to 0.5: exact match = 0.5,
      same family = 0.25) plus exploit-sketch keyword coverage (up to 0.5).

    ``target_file`` is accepted for interface compatibility but is not
    used by the current scoring.
    """
    # Graduated context penalty: the first request is free, then -0.05 each.
    reward = -0.05 * max(0, context_requests - 1)

    if action.parse_error:
        return reward - 0.5

    if action.action_type == "analyze":
        # Shaping bonus: 0.001 per 10 chars of reasoning past 50, capped at 0.05.
        text_len = len(action.reasoning or "")
        if text_len > 50:
            reward += min(0.05, 0.001 * (text_len // 10))
        return reward

    if action.action_type != "verdict":
        # "request_context" (and anything unrecognized) carries no extra signal.
        return reward

    if is_vulnerable is None:
        # No ground-truth label; nothing to score the verdict against.
        return reward

    if action.is_vulnerable is None:
        # A verdict that never commits to a prediction is penalized.
        return reward - 0.5
    predicted = bool(action.is_vulnerable)

    if is_vulnerable is True:
        if predicted is True:
            # True positive.
            reward += 1.0
            # CWE credit (scaled): exact match adds 0.5, same family adds 0.25.
            reward += 0.5 * _cwe_partial_score(action.vuln_type, cwe)
            # Continuous keyword coverage of the exploit sketch, up to 0.5.
            keywords = (cwe_keywords or {}).get(cwe, []) if cwe else []
            if keywords:
                sketch = (action.exploit_sketch or "").lower()
                hits = sum(kw.lower() in sketch for kw in keywords)
                reward += 0.5 * (hits / len(keywords))
            return reward
        # False negative: a missed vulnerability.
        return reward - 0.5
    if is_vulnerable is False:
        if predicted is True:
            # False positive: a spurious alarm.
            return reward - 1.0
        # True negative.
        return reward + 1.0
    # Non-bool ground truth matches no branch; fall back to the base reward.
    return reward