# commitguard-env/tests/test_reward.py
# Contributor: Nitishkumar-ai
# Commit 95cbc5b: Deployment Build (Final): Professional Structure + Blog
from __future__ import annotations
from commitguard_env.models import CommitGuardAction
from commitguard_env.reward import compute_reward
def test_reward_true_positive_correct_cwe_and_exploit_match() -> None:
    """Expect a reward of 2.0 for a correct verdict with the exact CWE.

    The verdict flags the commit as vulnerable, names the same CWE as the
    ground truth, and its exploit sketch contains every keyword listed for
    that CWE (ignoring letter case).
    """
    verdict = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-89",
        exploit_sketch="This is classic SQL injection: SELECT ... WHERE ... concat",
    )
    keywords = {"CWE-89": ["sql", "select", "where", "concat", "injection"]}

    reward = compute_reward(
        action=verdict,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords=keywords,
        context_requests=0,
    )

    assert reward == 2.0
def test_reward_true_positive_wrong_cwe_same_family() -> None:
    """Expect 1.75 when the verdict is right but the CWE is only a family match.

    The action reports CWE-79 while the ground truth is CWE-89.
    """
    # CWE-79 and CWE-89 are both in the "injection" family,
    # so the family bonus is 0.5 * 0.5.
    verdict = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-79",
        exploit_sketch="sql injection",
    )
    keywords = {"CWE-89": ["sql"]}

    reward = compute_reward(
        action=verdict,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords=keywords,
        context_requests=0,
    )

    # 1.0 (TP) + 0.25 (family match: 0.5 * 0.5) + 0.5 (keyword: 1/1) = 1.75
    assert reward == 1.75
def test_reward_false_positive() -> None:
    """Expect -1.0 when the action flags a commit whose ground truth is not vulnerable."""
    verdict = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-89",
        exploit_sketch="sql",
    )

    # Ground truth: no vulnerability, so there is no CWE, target file or keywords.
    reward = compute_reward(
        action=verdict,
        is_vulnerable=False,
        cwe=None,
        target_file=None,
        cwe_keywords={},
        context_requests=0,
    )

    assert reward == -1.0
def test_reward_false_negative() -> None:
    """Expect -0.5 when the action clears a commit whose ground truth is vulnerable."""
    verdict = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=False,
        vuln_type="NONE",
        exploit_sketch="",
    )
    keywords = {"CWE-89": ["sql"]}

    reward = compute_reward(
        action=verdict,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords=keywords,
        context_requests=0,
    )

    assert reward == -0.5
def test_reward_malformed_action_penalty_no_crash() -> None:
    """Expect -0.5, with no exception, for an action that carries a parse error."""
    # The action is not a verdict: it holds the raw text and a parse_error marker.
    malformed = CommitGuardAction(
        action_type="analyze",
        raw_action="<<<",
        parse_error="bad_xml",
    )
    keywords = {"CWE-89": ["sql"]}

    reward = compute_reward(
        action=malformed,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords=keywords,
        context_requests=0,
    )

    assert reward == -0.5