"""
Tests for SecurityAgent: static scan only (no LLM / GPU required).
"""
from __future__ import annotations

import json
import pathlib

import pytest

from agents.security_agent import SecurityAgent
from api.models import Severity
from tools.code_parser import FileEntry

# ──────────────────────────────────────────
# Fixtures
# ──────────────────────────────────────────

FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures"


@pytest.fixture(scope="module")
def vulnerable_code() -> str:
    return (FIXTURES_DIR / "vulnerable_ml_code.py").read_text(encoding="utf-8")


@pytest.fixture(scope="module")
def clean_code() -> str:
    return (FIXTURES_DIR / "clean_ml_code.py").read_text(encoding="utf-8")


@pytest.fixture(scope="module")
def expected() -> dict:
    return json.loads((FIXTURES_DIR / "expected_findings.json").read_text(encoding="utf-8"))


@pytest.fixture(scope="module")
def agent() -> SecurityAgent:
    return SecurityAgent()


@pytest.fixture(scope="module")
def vulnerable_files(vulnerable_code: str) -> list[FileEntry]:
    return [("vulnerable_ml_code.py", vulnerable_code)]


@pytest.fixture(scope="module")
def vulnerable_findings(agent: SecurityAgent, vulnerable_files: list[FileEntry]):
    return agent.static_scan(vulnerable_files)
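
# Note: FileEntry is assumed to be a (filename, source) tuple, matching how
# the fixtures above build the lists handed to SecurityAgent.static_scan().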

# ──────────────────────────────────────────
# Tests
# ──────────────────────────────────────────

class TestPromptInjectionDetection:
def test_detects_prompt_injection(self, vulnerable_findings):
"""LLM01: Should detect user input concatenated directly into prompt."""
llm01_findings = [
f for f in vulnerable_findings
if f.owasp_category == "LLM01" or "Prompt Injection" in f.title
]
assert len(llm01_findings) > 0, (
"Expected at least one LLM01 Prompt Injection finding"
)

    def test_prompt_injection_severity(self, vulnerable_findings):
"""Prompt injection must be rated critical or high."""
llm01_findings = [
f for f in vulnerable_findings
if f.owasp_category == "LLM01" or "Prompt Injection" in f.title
]
assert any(
f.severity in (Severity.critical, Severity.high) for f in llm01_findings
), "Prompt injection finding must be critical or high severity"


class TestPickleDetection:
def test_detects_pickle_deserialization(self, vulnerable_findings):
"""A04 / CWE-502: Should detect pickle.load() from untrusted source."""
pickle_findings = [
f for f in vulnerable_findings
if (f.cwe and "502" in f.cwe) or "pickle" in f.title.lower() or "Insecure Design" in (f.owasp_category or "")
]
assert len(pickle_findings) > 0, (
"Expected CWE-502 finding for pickle.load()"
)

    def test_pickle_is_critical(self, vulnerable_findings):
        """CWE-502 findings, when present, must include a critical rating."""
pickle_findings = [
f for f in vulnerable_findings
if f.cwe and "502" in f.cwe
]
if pickle_findings:
assert any(f.severity == Severity.critical for f in pickle_findings)


class TestHardcodedAPIKeyDetection:
def test_detects_hardcoded_api_key(self, vulnerable_findings):
"""LLM06 / A07: Should detect hardcoded HF_TOKEN and OpenAI key."""
key_findings = [
f for f in vulnerable_findings
if f.owasp_category in ("LLM06", "A07")
or any(kw in f.title.lower() for kw in ("hardcoded", "api key", "token", "secret"))
]
assert len(key_findings) > 0, (
"Expected at least one hardcoded API key finding (LLM06 / A07)"
)

    def test_hardcoded_key_severity_high_or_critical(self, vulnerable_findings):
        """Hardcoded credential findings, when present, must be critical or high."""
key_findings = [
f for f in vulnerable_findings
if f.owasp_category in ("LLM06", "A07")
]
if key_findings:
assert any(f.severity in (Severity.critical, Severity.high) for f in key_findings)


class TestEvalDetection:
def test_detects_eval_of_llm_output(self, vulnerable_findings):
"""LLM02: Should detect eval() used on model output."""
llm02_findings = [
f for f in vulnerable_findings
if f.owasp_category == "LLM02"
or any(kw in f.title.lower() for kw in ("eval", "insecure output"))
]
assert len(llm02_findings) > 0, (
"Expected LLM02 finding for eval(llm_output)"
)


class TestSeverityRanking:
def test_severity_ranking_order(self, vulnerable_findings):
"""Critical findings must appear before high, which appear before medium."""
if len(vulnerable_findings) < 2:
pytest.skip("Need at least 2 findings to test ordering")
severity_order = {
Severity.critical: 0,
Severity.high: 1,
Severity.medium: 2,
Severity.low: 3,
Severity.info: 4,
}
for i in range(len(vulnerable_findings) - 1):
a = severity_order[vulnerable_findings[i].severity]
b = severity_order[vulnerable_findings[i + 1].severity]
assert a <= b, (
f"Finding {i} ({vulnerable_findings[i].severity}) should not come after "
f"finding {i+1} ({vulnerable_findings[i+1].severity})"
)

    def test_has_critical_findings(self, vulnerable_findings):
"""Vulnerable code must produce at least one critical finding."""
critical = [f for f in vulnerable_findings if f.severity == Severity.critical]
assert len(critical) > 0, "Expected at least one critical severity finding"


class TestOWASPLLMCoverage:
def test_owasp_llm_coverage(self, vulnerable_findings):
"""
Assert findings cover the key OWASP LLM Top-10 categories
present in the vulnerable fixture.
"""
found_categories = {f.owasp_category for f in vulnerable_findings if f.owasp_category}
# These categories have triggers in the vulnerable fixture
expected_categories = {"LLM01", "LLM02", "LLM06"}
missing = expected_categories - found_categories
assert not missing, (
f"Missing OWASP LLM categories in findings: {missing}. "
f"Found: {found_categories}"
)

    def test_no_false_positives_on_clean_code(self, agent: SecurityAgent, clean_code: str):
        """Clean code must yield fewer than three critical findings."""
clean_files: list[FileEntry] = [("clean_ml_code.py", clean_code)]
clean_findings = agent.static_scan(clean_files)
critical_clean = [f for f in clean_findings if f.severity == Severity.critical]
# Clean code may still trigger some pattern matches, but should have far fewer
assert len(critical_clean) < 3, (
f"Clean code produced {len(critical_clean)} critical findings β€” too many false positives"
)


class TestFindingSchema:
def test_all_findings_have_required_fields(self, vulnerable_findings):
"""Every finding must have severity, title, and explanation."""
for i, finding in enumerate(vulnerable_findings):
assert finding.severity is not None, f"Finding {i} missing severity"
assert finding.title, f"Finding {i} missing title"
assert finding.explanation, f"Finding {i} missing explanation"

    def test_findings_are_not_empty(self, vulnerable_findings):
        """The vulnerable fixture must produce at least one finding."""
assert len(vulnerable_findings) > 0, (
"SecurityAgent.static_scan() returned no findings for vulnerable code"
)