pytorch-training-debugger / tests /test_websocket.py
omkarrr88
Major fixes + gap fixes
4f58e42
"""WebSocket integration tests.
Verifies the /ws endpoint works with correct message formats.
Mirrors the flow automated validators exercise: connect -> reset -> step -> diagnose.
Note: WSResetMessage carries a `data: Dict[str, Any]` field, so reset options
such as task selection go inside it, e.g.
{"type": "reset", "data": {"task_id": "task_003"}}
"""
from __future__ import annotations
import json
import pytest
from fastapi.testclient import TestClient
from server.app import app
class TestWebSocketEndpoint:
    """Integration tests for the WebSocket ``/ws`` endpoint.

    Each test opens its own connection via ``TestClient(app)`` and uses the
    message shapes the server expects: ``{"type": "reset"}`` (optionally with
    a ``data`` dict such as ``{"task_id": ..., "seed": ...}``) and
    ``{"type": "step", "data": {...action fields...}}``.
    """

    # Every task that must be selectable through a WS reset.
    _ALL_TASK_IDS = (
        "task_001",
        "task_002",
        "task_003",
        "task_004",
        "task_005",
        "task_006",
    )

    @staticmethod
    def _observation(resp: dict) -> dict:
        """Return ``resp["data"]["observation"]`` after asserting the reply type.

        Checking ``type`` first turns a server ``error`` reply into a readable
        assertion failure instead of a ``KeyError`` on the missing payload.
        """
        assert resp.get("type") == "observation", f"expected observation, got {resp!r}"
        return resp["data"]["observation"]

    @classmethod
    def _reset(cls, ws, data=None) -> dict:
        """Send a reset message and return the resulting observation.

        With ``data=None`` the bare ``{"type": "reset"}`` message is sent;
        otherwise *data* (e.g. ``task_id`` / ``seed``) is attached under
        ``"data"``.
        """
        message = {"type": "reset"}
        if data is not None:
            message["data"] = data
        ws.send_json(message)
        return cls._observation(ws.receive_json())

    @staticmethod
    def _step(ws, action: dict) -> dict:
        """Send one step message carrying *action* and return the raw reply."""
        ws.send_json({"type": "step", "data": action})
        return ws.receive_json()

    def test_ws_endpoint_exists(self) -> None:
        paths = [r.path for r in app.routes if hasattr(r, "path")]
        assert "/ws" in paths

    def test_ws_reset_returns_observation(self) -> None:
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            obs = self._reset(ws)
            assert len(obs["training_loss_history"]) == 20
            assert len(obs["val_accuracy_history"]) == 20
            assert len(obs["val_loss_history"]) == 20
            assert obs["framework"] == "pytorch"
            assert obs["epoch"] == 20
            assert isinstance(obs["available_actions"], list)
            assert len(obs["available_actions"]) > 0
            assert obs["episode_state"]["step_count"] == 0

    def test_ws_reset_with_task_selection(self) -> None:
        """Task selection via WS using data field."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            # Task 3 is data leakage — has specific notes
            obs = self._reset(ws, {"task_id": "task_003", "seed": 42})
            # ``notes`` may be absent or null; treat both as empty text so the
            # check fails as an assertion rather than an AttributeError.
            notes = obs.get("notes") or ""
            assert "architecture upgraded" in notes.lower()
            assert obs["error_log"] is None  # Task 3 has no error log

    def test_ws_task_selection_all_tasks(self) -> None:
        """Verify all 6 tasks can be selected via WS."""
        client = TestClient(app)
        for task_id in self._ALL_TASK_IDS:
            with client.websocket_connect("/ws") as ws:
                ws.send_json({"type": "reset", "data": {"task_id": task_id, "seed": 42}})
                resp = ws.receive_json()
                assert resp["type"] == "observation", f"{task_id} failed reset"
                obs = resp["data"]["observation"]
                assert len(obs["training_loss_history"]) == 20, f"{task_id} missing loss history"

    def test_ws_step_inspect_gradients(self) -> None:
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            self._reset(ws)
            obs = self._observation(self._step(ws, {"action_type": "inspect_gradients"}))
            assert len(obs["gradient_stats"]) == 4
            assert obs["episode_state"]["gradients_inspected"] is True
            required_keys = ("layer_name", "mean_norm", "is_exploding", "is_vanishing")
            for g in obs["gradient_stats"]:
                for key in required_keys:
                    assert key in g, f"gradient entry missing {key!r}: {g!r}"

    def test_ws_full_episode_flow(self) -> None:
        """Full episode: reset -> inspect -> fix -> restart -> diagnose."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            # Reset to task_001 (exploding gradients)
            obs = self._reset(ws, {"task_id": "task_001", "seed": 42})
            assert obs["error_log"] is not None

            # Inspect gradients
            obs = self._observation(self._step(ws, {"action_type": "inspect_gradients"}))
            assert any(g["is_exploding"] for g in obs["gradient_stats"])

            # Fix: reduce learning rate
            obs = self._observation(
                self._step(
                    ws,
                    {
                        "action_type": "modify_config",
                        "target": "learning_rate",
                        "value": 0.001,
                    },
                )
            )
            assert obs["episode_state"]["fix_action_taken"] is True

            # Restart
            obs = self._observation(self._step(ws, {"action_type": "restart_run"}))
            assert obs["episode_state"]["restart_after_fix"] is True

            # Diagnose. The reply is read field by field because completion may
            # be reported through the optional ``done`` flag.
            resp = self._step(
                ws,
                {
                    "action_type": "mark_diagnosed",
                    "diagnosis": "lr_too_high",
                },
            )
            done = resp["data"].get("done", False)
            obs = resp["data"]["observation"]
            assert done or obs["episode_state"]["diagnosis_submitted"]

    def test_ws_task_005_red_herrings(self) -> None:
        """Task 5 via WS — verify red herrings and correct diagnosis path."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            obs = self._reset(ws, {"task_id": "task_005", "seed": 42})
            # Task 5 has GPU memory warning
            assert obs.get("error_log") is not None
            assert obs["gpu_memory_used_gb"] > 14.0  # 91% of 16GB

            # Inspect gradients — all should be non-exploding
            obs = self._observation(self._step(ws, {"action_type": "inspect_gradients"}))
            # An empty list would make the "none exploding" check pass vacuously.
            assert obs["gradient_stats"], "inspect_gradients returned no layers"
            assert not any(g["is_exploding"] for g in obs["gradient_stats"])

            # Inspect model modes — should reveal eval mode
            obs = self._observation(self._step(ws, {"action_type": "inspect_model_modes"}))
            assert any(v == "eval" for v in obs["model_mode_info"].values())

    def test_ws_task_006_code_inspection(self) -> None:
        """Task 6 via WS — verify code inspection and fix."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            self._reset(ws, {"task_id": "task_006", "seed": 42})
            # Inspect code
            obs = self._observation(self._step(ws, {"action_type": "inspect_code"}))
            snippet = obs["code_snippet"]
            assert snippet is not None
            assert snippet["filename"] == "train.py"
            assert snippet["line_count"] > 0

    def test_ws_invalid_message_returns_error(self) -> None:
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            self._reset(ws)
            # Wrong format — "action" instead of "data"
            ws.send_json(
                {"type": "step", "action": {"action_type": "inspect_gradients"}}
            )
            resp = ws.receive_json()
            assert resp["type"] == "error"

    def test_ws_step_data_batch(self) -> None:
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            self._reset(ws)
            obs = self._observation(self._step(ws, {"action_type": "inspect_data_batch"}))
            stats = obs["data_batch_stats"]
            assert stats is not None
            assert "class_overlap_score" in stats
            assert obs["episode_state"]["data_inspected"] is True