File size: 7,873 Bytes
7b4f5dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""
Tests for SecurityAgent — static scan only (no LLM / GPU required).
"""
from __future__ import annotations

import json
import pathlib
import pytest

from agents.security_agent import SecurityAgent
from api.models import Severity
from tools.code_parser import FileEntry

# ──────────────────────────────────────────
# Fixtures
# ──────────────────────────────────────────

# Directory holding the test fixture files read below
# (vulnerable_ml_code.py, clean_ml_code.py, expected_findings.json).
FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures"


@pytest.fixture(scope="module")
def vulnerable_code() -> str:
    """Raw text of the deliberately vulnerable ML fixture file."""
    fixture_path = FIXTURES_DIR / "vulnerable_ml_code.py"
    return fixture_path.read_text(encoding="utf-8")


@pytest.fixture(scope="module")
def clean_code() -> str:
    """Raw text of the clean (non-vulnerable) ML fixture file."""
    fixture_path = FIXTURES_DIR / "clean_ml_code.py"
    return fixture_path.read_text(encoding="utf-8")


@pytest.fixture(scope="module")
def expected() -> dict:
    """Expected findings parsed from the JSON fixture.

    NOTE(review): no test in this module currently consumes this fixture —
    confirm whether it is used elsewhere or should back a comparison test.
    """
    raw = (FIXTURES_DIR / "expected_findings.json").read_text(encoding="utf-8")
    return json.loads(raw)


@pytest.fixture(scope="module")
def agent() -> SecurityAgent:
    """One SecurityAgent instance shared by every test in this module."""
    security_agent = SecurityAgent()
    return security_agent


@pytest.fixture(scope="module")
def vulnerable_files(vulnerable_code: str) -> list[FileEntry]:
    """The vulnerable source wrapped as the (filename, content) list the scanner takes."""
    entry = ("vulnerable_ml_code.py", vulnerable_code)
    return [entry]


@pytest.fixture(scope="module")
def vulnerable_findings(agent: SecurityAgent, vulnerable_files: list[FileEntry]):
    """Static-scan findings for the vulnerable fixture, computed once per module."""
    findings = agent.static_scan(vulnerable_files)
    return findings


# ──────────────────────────────────────────
# Tests
# ──────────────────────────────────────────

class TestPromptInjectionDetection:
    def test_detects_prompt_injection(self, vulnerable_findings):
        """LLM01: Should detect user input concatenated directly into prompt."""
        llm01_findings = [
            f for f in vulnerable_findings
            if f.owasp_category == "LLM01" or "Prompt Injection" in f.title
        ]
        assert len(llm01_findings) > 0, (
            "Expected at least one LLM01 Prompt Injection finding"
        )

    def test_prompt_injection_severity(self, vulnerable_findings):
        """Prompt injection must be rated critical or high."""
        llm01_findings = [
            f for f in vulnerable_findings
            if f.owasp_category == "LLM01" or "Prompt Injection" in f.title
        ]
        assert any(
            f.severity in (Severity.critical, Severity.high) for f in llm01_findings
        ), "Prompt injection finding must be critical or high severity"


class TestPickleDetection:
    def test_detects_pickle_deserialization(self, vulnerable_findings):
        """A04 / CWE-502: Should detect pickle.load() from untrusted source."""
        pickle_findings = [
            f for f in vulnerable_findings
            if (f.cwe and "502" in f.cwe) or "pickle" in f.title.lower() or "Insecure Design" in (f.owasp_category or "")
        ]
        assert len(pickle_findings) > 0, (
            "Expected CWE-502 finding for pickle.load()"
        )

    def test_pickle_is_critical(self, vulnerable_findings):
        pickle_findings = [
            f for f in vulnerable_findings
            if f.cwe and "502" in f.cwe
        ]
        if pickle_findings:
            assert any(f.severity == Severity.critical for f in pickle_findings)


class TestHardcodedAPIKeyDetection:
    def test_detects_hardcoded_api_key(self, vulnerable_findings):
        """LLM06 / A07: Should detect hardcoded HF_TOKEN and OpenAI key."""
        key_findings = [
            f for f in vulnerable_findings
            if f.owasp_category in ("LLM06", "A07")
            or any(kw in f.title.lower() for kw in ("hardcoded", "api key", "token", "secret"))
        ]
        assert len(key_findings) > 0, (
            "Expected at least one hardcoded API key finding (LLM06 / A07)"
        )

    def test_hardcoded_key_severity_high_or_critical(self, vulnerable_findings):
        key_findings = [
            f for f in vulnerable_findings
            if f.owasp_category in ("LLM06", "A07")
        ]
        if key_findings:
            assert any(f.severity in (Severity.critical, Severity.high) for f in key_findings)


class TestEvalDetection:
    def test_detects_eval_of_llm_output(self, vulnerable_findings):
        """LLM02: Should detect eval() used on model output."""
        llm02_findings = [
            f for f in vulnerable_findings
            if f.owasp_category == "LLM02"
            or any(kw in f.title.lower() for kw in ("eval", "insecure output"))
        ]
        assert len(llm02_findings) > 0, (
            "Expected LLM02 finding for eval(llm_output)"
        )


class TestSeverityRanking:
    """Findings must come back sorted from most to least severe."""

    def test_severity_ranking_order(self, vulnerable_findings):
        """Critical findings must appear before high, which appear before medium."""
        if len(vulnerable_findings) < 2:
            pytest.skip("Need at least 2 findings to test ordering")

        rank = {
            Severity.critical: 0,
            Severity.high: 1,
            Severity.medium: 2,
            Severity.low: 3,
            Severity.info: 4,
        }
        ranks = [rank[f.severity] for f in vulnerable_findings]
        # Walk adjacent pairs: each finding must rank at least as severe
        # as the one that follows it.
        for i, (earlier, later) in enumerate(zip(ranks, ranks[1:])):
            assert earlier <= later, (
                f"Finding {i} ({vulnerable_findings[i].severity}) should not come after "
                f"finding {i+1} ({vulnerable_findings[i+1].severity})"
            )

    def test_has_critical_findings(self, vulnerable_findings):
        """Vulnerable code must produce at least one critical finding."""
        assert any(
            f.severity == Severity.critical for f in vulnerable_findings
        ), "Expected at least one critical severity finding"


class TestOWASPLLMCoverage:
    def test_owasp_llm_coverage(self, vulnerable_findings):
        """
        Assert findings cover the key OWASP LLM Top-10 categories
        present in the vulnerable fixture.
        """
        found_categories = {f.owasp_category for f in vulnerable_findings if f.owasp_category}
        # These categories have triggers in the vulnerable fixture
        expected_categories = {"LLM01", "LLM02", "LLM06"}
        missing = expected_categories - found_categories
        assert not missing, (
            f"Missing OWASP LLM categories in findings: {missing}. "
            f"Found: {found_categories}"
        )

    def test_no_false_positives_on_clean_code(self, agent: SecurityAgent, clean_code: str):
        """Clean code should produce significantly fewer critical findings."""
        clean_files: list[FileEntry] = [("clean_ml_code.py", clean_code)]
        clean_findings = agent.static_scan(clean_files)
        critical_clean = [f for f in clean_findings if f.severity == Severity.critical]
        # Clean code may still trigger some pattern matches, but should have far fewer
        assert len(critical_clean) < 3, (
            f"Clean code produced {len(critical_clean)} critical findings β€” too many false positives"
        )


class TestFindingSchema:
    def test_all_findings_have_required_fields(self, vulnerable_findings):
        """Every finding must have severity, title, and explanation."""
        for i, finding in enumerate(vulnerable_findings):
            assert finding.severity is not None, f"Finding {i} missing severity"
            assert finding.title, f"Finding {i} missing title"
            assert finding.explanation, f"Finding {i} missing explanation"

    def test_findings_are_not_empty(self, vulnerable_findings):
        assert len(vulnerable_findings) > 0, (
            "SecurityAgent.static_scan() returned no findings for vulnerable code"
        )