File size: 4,053 Bytes
42d0b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6da61f5
 
 
 
 
 
 
 
 
 
 
 
 
42d0b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0462352
 
 
 
 
 
 
42d0b80
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import unittest
from importlib import import_module
from unittest import mock

# Optional-dependency gate: try to load the FastAPI test client and the
# application module. ``server.app`` is loaded via ``import_module`` inside
# the same ``try`` so that a missing module raised while importing it is
# handled the same way as a missing ``fastapi`` install.
# NOTE(review): this also swallows a ModuleNotFoundError for ``server.app``
# itself, which would silently skip the whole suite - confirm that is intended.
try:
    from fastapi.testclient import TestClient
    server_app = import_module("server.app")
except ModuleNotFoundError:  # pragma: no cover - dependency-gated
    # Both names fall back to None; the ``skipIf`` decorator on the test
    # class checks ``TestClient is None`` to skip instead of erroring.
    TestClient = None
    server_app = None


@unittest.skipIf(TestClient is None, "fastapi test dependencies are not installed in the local test env")
class ServerApiTests(unittest.TestCase):
    """HTTP contract tests for the ``server.app`` application.

    Every test talks to ``server_app.app`` through FastAPI's ``TestClient``
    and checks status codes plus the presence/values of selected response
    fields. The whole class is skipped when the optional imports at the top
    of the module failed (``TestClient is None``).
    """

    def setUp(self) -> None:
        """Install a fresh environment on the app module and build a client.

        ``server_app.env`` is replaced with a new ``IDPIExfilEnv`` so each
        test starts from an environment that has not been reset or stepped
        (``test_step_before_reset_returns_400`` relies on this). The swap is
        done with ``mock.patch.object`` and undone via ``addCleanup`` so the
        module-level global is restored after the test instead of leaking
        the last test's environment to whatever runs next in the process.
        """
        # create=True keeps the original "just assign it" semantics even if
        # the module does not define ``env`` until it is first assigned.
        env_patch = mock.patch.object(
            server_app, "env", server_app.IDPIExfilEnv(), create=True
        )
        env_patch.start()
        # Cleanups run even when the test fails or errors.
        self.addCleanup(env_patch.stop)
        self.client = TestClient(server_app.app)

    def test_health_endpoint(self) -> None:
        """GET /health answers 200 with the expected status and env name."""
        response = self.client.get("/health")
        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertEqual(payload["status"], "healthy")
        self.assertEqual(payload["env"], "idpi-exfil-env")

    def test_root_landing_page_is_html(self) -> None:
        """GET / serves an HTML page that names the env and links to /docs."""
        response = self.client.get("/")
        self.assertEqual(response.status_code, 200)
        # Default to "" so a missing header fails the assertion, not .get().
        self.assertIn("text/html", response.headers.get("content-type", ""))
        self.assertIn("IDPI Exfil Environment", response.text)
        self.assertIn("/docs", response.text)

    def test_web_landing_page_is_html(self) -> None:
        """GET /web serves an HTML page containing the OpenEnv blurb."""
        response = self.client.get("/web")
        self.assertEqual(response.status_code, 200)
        self.assertIn("text/html", response.headers.get("content-type", ""))
        self.assertIn("OpenEnv-compatible", response.text)

    def test_metadata_endpoint(self) -> None:
        """GET /metadata reports the env name, description and task list."""
        response = self.client.get("/metadata")
        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertEqual(payload["name"], "idpi-exfil-env")
        # Case-insensitive match on the description text.
        self.assertIn("prompt injection", payload["description"].lower())
        self.assertEqual(payload["tasks"], ["easy", "medium", "hard"])

    def test_schema_endpoint(self) -> None:
        """GET /schema exposes a section for each of the five model kinds."""
        response = self.client.get("/schema")
        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertIn("action", payload)
        self.assertIn("observation", payload)
        self.assertIn("state", payload)
        self.assertIn("reward", payload)
        self.assertIn("result", payload)

    def test_mcp_endpoint_returns_jsonrpc_contract(self) -> None:
        """POST /mcp with only an id gets a JSON-RPC 2.0 error envelope.

        The request carries no method, and the test expects HTTP 200 with
        the request id echoed back and an ``error`` member present.
        """
        response = self.client.post("/mcp", json={"id": "check-1"})
        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertEqual(payload["jsonrpc"], "2.0")
        self.assertEqual(payload["id"], "check-1")
        self.assertIn("error", payload)

    def test_reset_and_step_contract(self) -> None:
        """A reset followed by one step returns the expected result fields."""
        reset = self.client.post("/reset", json={"difficulty": "easy", "seed": 11})
        self.assertEqual(reset.status_code, 200)
        observation = reset.json()
        self.assertEqual(observation["task_difficulty"], "easy")

        step = self.client.post("/step", json={"action": {"type": "ReadDoc"}})
        self.assertEqual(step.status_code, 200)
        payload = step.json()
        self.assertIn("observation", payload)
        self.assertIn("reward", payload)
        self.assertIn("done", payload)
        self.assertIn("info", payload)
        self.assertIn("policy", payload["info"])
        self.assertIn("grader_inputs", payload["info"])

    def test_reset_without_body_uses_defaults(self) -> None:
        """POST /reset with no body yields an easy task at step count 0."""
        response = self.client.post("/reset")
        self.assertEqual(response.status_code, 200)
        observation = response.json()
        self.assertEqual(observation["task_difficulty"], "easy")
        self.assertEqual(observation["step_count"], 0)

    def test_invalid_reset_payload_returns_422(self) -> None:
        """POST /reset with an unrecognized difficulty is rejected with 422."""
        response = self.client.post("/reset", json={"difficulty": "legendary"})
        self.assertEqual(response.status_code, 422)

    def test_step_before_reset_returns_400(self) -> None:
        """POST /step on the fresh env from setUp (never reset) returns 400."""
        response = self.client.post("/step", json={"action": {"type": "ReadDoc"}})
        self.assertEqual(response.status_code, 400)


# Script entry point: when this file is executed directly (rather than
# imported by a test runner), hand control to unittest's command-line runner,
# which loads and runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()