"""WebSocket integration tests.

Verifies that the /ws endpoint works with the expected message formats.
The automated validators exercise the flow: connect -> reset -> step -> diagnose.

Note: WSResetMessage has a `data: Dict[str, Any]` field, so a task is
selected over WS with: {"type": "reset", "data": {"task_id": "task_003"}}
"""
|
|
| from __future__ import annotations |
|
|
| import json |
|
|
| import pytest |
| from fastapi.testclient import TestClient |
|
|
| from server.app import app |
|
|
|
|
class TestWebSocketEndpoint:
    """Integration tests for the WebSocket ``/ws`` endpoint.

    Each test opens a WebSocket session against the FastAPI ``app`` and
    exchanges JSON messages shaped like ``{"type": ..., "data": {...}}``.
    Successful replies are expected to look like
    ``{"type": "observation", "data": {"observation": {...}, ...}}``.
    """

    # Every task id the environment is expected to accept on reset.
    TASK_IDS = (
        "task_001",
        "task_002",
        "task_003",
        "task_004",
        "task_005",
        "task_006",
    )

    @staticmethod
    def _exchange(ws, message: dict) -> dict:
        """Send one JSON ``message`` over ``ws`` and return the next JSON reply."""
        ws.send_json(message)
        return ws.receive_json()

    @staticmethod
    def _observation(resp: dict, context: str = "") -> dict:
        """Return ``resp["data"]["observation"]`` after checking the reply type.

        Asserting the type first turns a server error reply into a readable
        failure that shows the whole reply. Without the check, the same reply
        would surface as a bare ``KeyError`` on ``["data"]["observation"]``.

        Args:
            resp: A decoded JSON reply from the ``/ws`` endpoint.
            context: Optional prefix for the failure message.
        """
        assert resp.get("type") == "observation", (
            f"{context}expected an observation reply, got {resp!r}"
        )
        return resp["data"]["observation"]

    def test_ws_endpoint_exists(self) -> None:
        """The app registers a route whose path is ``/ws``."""
        paths = [r.path for r in app.routes if hasattr(r, "path")]
        assert "/ws" in paths

    def test_ws_reset_returns_observation(self) -> None:
        """A bare reset (no ``data`` field) returns a populated initial observation."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            resp = self._exchange(ws, {"type": "reset"})

            obs = self._observation(resp)
            assert len(obs["training_loss_history"]) == 20
            assert len(obs["val_accuracy_history"]) == 20
            assert len(obs["val_loss_history"]) == 20
            assert obs["framework"] == "pytorch"
            assert obs["epoch"] == 20
            assert isinstance(obs["available_actions"], list)
            assert len(obs["available_actions"]) > 0
            assert obs["episode_state"]["step_count"] == 0

    def test_ws_reset_with_task_selection(self) -> None:
        """Task selection via WS works through the reset message's ``data`` field."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            resp = self._exchange(
                ws, {"type": "reset", "data": {"task_id": "task_003", "seed": 42}}
            )

            obs = self._observation(resp)
            # ``notes`` may be missing or explicitly null; treat both as empty
            # so the check fails as an assertion rather than an AttributeError.
            notes = obs.get("notes") or ""
            assert "architecture upgraded" in notes.lower()
            assert obs["error_log"] is None

    def test_ws_task_selection_all_tasks(self) -> None:
        """Every task id in ``TASK_IDS`` can be selected via a WS reset."""
        client = TestClient(app)
        for task_id in self.TASK_IDS:
            with client.websocket_connect("/ws") as ws:
                resp = self._exchange(
                    ws, {"type": "reset", "data": {"task_id": task_id, "seed": 42}}
                )
                obs = self._observation(resp, f"{task_id} failed reset: ")
                assert (
                    len(obs["training_loss_history"]) == 20
                ), f"{task_id} missing loss history"

    def test_ws_step_inspect_gradients(self) -> None:
        """``inspect_gradients`` returns per-layer stats and marks them inspected."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            self._observation(self._exchange(ws, {"type": "reset"}))

            resp = self._exchange(
                ws, {"type": "step", "data": {"action_type": "inspect_gradients"}}
            )

            obs = self._observation(resp)
            assert len(obs["gradient_stats"]) == 4
            assert obs["episode_state"]["gradients_inspected"] is True
            for g in obs["gradient_stats"]:
                assert "layer_name" in g
                assert "mean_norm" in g
                assert "is_exploding" in g
                assert "is_vanishing" in g

    def test_ws_full_episode_flow(self) -> None:
        """Full episode: reset -> inspect -> fix -> restart -> diagnose."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            # Reset into task_001, which starts with an error log present.
            resp = self._exchange(
                ws, {"type": "reset", "data": {"task_id": "task_001", "seed": 42}}
            )
            obs = self._observation(resp, "reset: ")
            assert obs["error_log"] is not None

            # Inspect gradients: at least one layer should report exploding.
            resp = self._exchange(
                ws, {"type": "step", "data": {"action_type": "inspect_gradients"}}
            )
            obs = self._observation(resp, "inspect_gradients: ")
            assert any(g["is_exploding"] for g in obs["gradient_stats"])

            # Apply a fix by lowering the learning rate.
            resp = self._exchange(
                ws,
                {
                    "type": "step",
                    "data": {
                        "action_type": "modify_config",
                        "target": "learning_rate",
                        "value": 0.001,
                    },
                },
            )
            obs = self._observation(resp, "modify_config: ")
            assert obs["episode_state"]["fix_action_taken"] is True

            # Restart the run after the fix.
            resp = self._exchange(
                ws, {"type": "step", "data": {"action_type": "restart_run"}}
            )
            obs = self._observation(resp, "restart_run: ")
            assert obs["episode_state"]["restart_after_fix"] is True

            # Submit the diagnosis; either the episode ends or the submission
            # is recorded in the episode state.
            resp = self._exchange(
                ws,
                {
                    "type": "step",
                    "data": {
                        "action_type": "mark_diagnosed",
                        "diagnosis": "lr_too_high",
                    },
                },
            )
            obs = self._observation(resp, "mark_diagnosed: ")
            done = resp["data"].get("done", False)
            assert done or obs["episode_state"]["diagnosis_submitted"]

    def test_ws_task_005_red_herrings(self) -> None:
        """Task 5 via WS: check the red herrings and the model-mode signal.

        At reset the task shows an error log and high GPU memory. Gradients
        are not exploding, and at least one module reports ``"eval"`` mode.
        """
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            resp = self._exchange(
                ws, {"type": "reset", "data": {"task_id": "task_005", "seed": 42}}
            )
            obs = self._observation(resp, "reset: ")
            # Symptoms visible at reset (presumably the red herrings of this task).
            assert obs.get("error_log") is not None
            assert obs["gpu_memory_used_gb"] > 14.0

            resp = self._exchange(
                ws, {"type": "step", "data": {"action_type": "inspect_gradients"}}
            )
            obs = self._observation(resp, "inspect_gradients: ")
            gradient_stats = obs["gradient_stats"]
            # Guard against a vacuous pass: an empty list would satisfy the
            # per-layer check below without verifying anything.
            assert gradient_stats, "inspect_gradients returned no gradient stats"
            for g in gradient_stats:
                assert not g["is_exploding"], f"unexpected exploding gradient: {g!r}"

            resp = self._exchange(
                ws, {"type": "step", "data": {"action_type": "inspect_model_modes"}}
            )
            obs = self._observation(resp, "inspect_model_modes: ")
            mode_info = obs["model_mode_info"]
            assert mode_info, "inspect_model_modes returned no mode info"
            assert any(v == "eval" for v in mode_info.values())

    def test_ws_task_006_code_inspection(self) -> None:
        """Task 6 via WS: ``inspect_code`` returns a non-empty ``train.py`` snippet."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            self._observation(
                self._exchange(
                    ws, {"type": "reset", "data": {"task_id": "task_006", "seed": 42}}
                ),
                "reset: ",
            )

            resp = self._exchange(
                ws, {"type": "step", "data": {"action_type": "inspect_code"}}
            )
            obs = self._observation(resp, "inspect_code: ")
            snippet = obs["code_snippet"]
            assert snippet is not None
            assert snippet["filename"] == "train.py"
            assert snippet["line_count"] > 0

    def test_ws_invalid_message_returns_error(self) -> None:
        """A step message whose payload is under ``action`` (not ``data``) is rejected."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            self._exchange(ws, {"type": "reset"})

            # Deliberately malformed: the step payload must live under "data".
            resp = self._exchange(
                ws, {"type": "step", "action": {"action_type": "inspect_gradients"}}
            )
            assert resp["type"] == "error"

    def test_ws_step_data_batch(self) -> None:
        """``inspect_data_batch`` returns batch stats and marks the data inspected."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            self._observation(self._exchange(ws, {"type": "reset"}))

            resp = self._exchange(
                ws, {"type": "step", "data": {"action_type": "inspect_data_batch"}}
            )
            obs = self._observation(resp)
            assert obs["data_batch_stats"] is not None
            assert "class_overlap_score" in obs["data_batch_stats"]
            assert obs["episode_state"]["data_inspected"] is True
|