| """Test real PyTorch model instantiation and fault injection.""" |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from ml_training_debugger.pytorch_engine import ( |
| SimpleCNN, |
| create_model_and_inject_fault, |
| extract_gradient_stats, |
| extract_model_modes, |
| extract_weight_stats, |
| ) |
| from ml_training_debugger.scenarios import sample_scenario |
|
|
|
|
class TestSimpleCNN:
    """Sanity checks on a default-constructed ``SimpleCNN``."""

    def test_is_nn_module(self):
        """A freshly built ``SimpleCNN`` is an ``nn.Module`` instance."""
        assert isinstance(SimpleCNN(), nn.Module)

    def test_param_count(self):
        """Total parameter element count lies strictly between 30k and 100k."""
        total_elements = 0
        for parameter in SimpleCNN().parameters():
            total_elements += parameter.numel()
        assert 30_000 < total_elements < 100_000

    def test_forward_pass(self):
        """A random (2, 3, 32, 32) input produces an output of shape (2, 10)."""
        network = SimpleCNN()
        batch = torch.randn(2, 3, 32, 32)
        output = network(batch)
        assert output.shape == (2, 10)
|
|
|
|
class TestFaultInjection:
    """Check model state and gradient stats after fault injection.

    Each test samples a scenario with ``seed=42``, builds the model through
    ``create_model_and_inject_fault`` and inspects the result.
    """

    def test_task_001_exploding_gradients(self):
        """task_001 yields stats where at least one layer has mean_norm > 1.0."""
        scenario = sample_scenario("task_001", seed=42)
        # Only the model is needed; the second element of the pair is unused.
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        assert len(stats) > 0
        assert any(s.mean_norm > 1.0 for s in stats)

    def test_task_005_eval_mode(self):
        """task_005 returns a model whose ``training`` flag is false."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        assert not model.training

    def test_task_005_gradients_not_exploding(self):
        """No layer reported for task_005 is flagged ``is_exploding``."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        # Assert per layer so a failure names the offending layer.
        for s in stats:
            assert not s.is_exploding, f"Layer {s.layer_name} should not be exploding"
|
|
|
|
class TestExtractGradientStats:
    """Shape and type checks on the output of ``extract_gradient_stats``."""

    def test_returns_gradient_stats(self):
        """task_001 yields 4 entries, each with a float mean and a 5-item history."""
        scenario = sample_scenario("task_001", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        layer_stats = extract_gradient_stats(faulty_model, scenario)
        assert len(layer_stats) == 4
        for entry in layer_stats:
            history = entry.norm_history
            assert isinstance(entry.mean_norm, float)
            assert isinstance(history, list)
            assert len(history) == 5
|
|
|
|
class TestExtractWeightStats:
    """Type checks on the output of ``extract_weight_stats``."""

    def test_returns_weight_stats(self):
        """Stats for the task_001 model are non-empty with float norm and bool NaN flag."""
        faulty_model, _ = create_model_and_inject_fault(
            sample_scenario("task_001", seed=42)
        )
        weight_stats = extract_weight_stats(faulty_model)
        assert len(weight_stats) > 0
        for entry in weight_stats:
            assert isinstance(entry.weight_norm, float)
            assert isinstance(entry.has_nan, bool)
|
|
|
|
class TestExtractModelModes:
    """Check the mode strings reported by ``extract_model_modes``."""

    def test_train_mode(self):
        """After ``train()``, every reported mode equals "train"."""
        # nn.Module.train() returns the module itself, so the call can be chained.
        reported = extract_model_modes(SimpleCNN().train())
        mismatches = [mode for mode in reported.values() if mode != "train"]
        assert not mismatches

    def test_eval_mode(self):
        """After ``eval()``, every reported mode equals "eval"."""
        # nn.Module.eval() likewise returns the module itself.
        reported = extract_model_modes(SimpleCNN().eval())
        mismatches = [mode for mode in reported.values() if mode != "eval"]
        assert not mismatches
|
|
|
|
class TestTask005RedHerrings:
    """Task 5 red-herring checks: near-vanishing conv1 and a non-exploding spike layer."""

    @staticmethod
    def _task_005_gradient_stats():
        """Return ``(scenario, gradient_stats)`` for task_005 sampled with seed 42."""
        scenario = sample_scenario("task_005", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        return scenario, extract_gradient_stats(faulty_model, scenario)

    def test_conv1_near_vanishing_red_herring(self):
        """conv1 has mean_norm < 0.01 yet is not flagged vanishing, unless it is the spike layer."""
        scenario, layer_stats = self._task_005_gradient_stats()
        # Look up conv1 before the guard, as the lookup itself must succeed.
        conv1_entry = next(
            filter(lambda entry: entry.layer_name == "conv1", layer_stats)
        )
        if scenario.red_herring_spike_layer == "conv1":
            # NOTE(review): nothing is asserted when conv1 is itself the spike
            # layer, so the test passes silently in that case.
            return
        assert conv1_entry.mean_norm < 0.01
        assert not conv1_entry.is_vanishing

    def test_fc_spike_not_exploding(self):
        """The red-herring spike layer is not flagged exploding and has mean_norm > 0.

        The original note gives "mean < 10.0" as the reason; that threshold is
        not visible in this file.
        """
        scenario, layer_stats = self._task_005_gradient_stats()
        spike_entry = next(
            filter(
                lambda entry: entry.layer_name == scenario.red_herring_spike_layer,
                layer_stats,
            )
        )
        assert not spike_entry.is_exploding
        assert spike_entry.mean_norm > 0

    def test_all_layers_not_exploding(self):
        """No task_005 layer is flagged ``is_exploding``.

        Per the original note this gates ``gradients_were_normal``; that
        consumer is defined outside this file.
        """
        _, layer_stats = self._task_005_gradient_stats()
        for entry in layer_stats:
            assert not entry.is_exploding, f"{entry.layer_name} should not be exploding"
|
|
|
|
class TestVanishingGradientInjection:
    """Checks on the task_002 (vanishing gradient) fault injection."""

    def test_task_002_vanishing(self):
        """At least one layer of the task_002 model is flagged ``is_vanishing``."""
        scenario = sample_scenario("task_002", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        vanishing_flags = [
            entry.is_vanishing
            for entry in extract_gradient_stats(faulty_model, scenario)
        ]
        assert any(vanishing_flags)

    def test_task_002_model_in_train_mode(self):
        """The task_002 model is returned with its ``training`` flag set."""
        faulty_model, _ = create_model_and_inject_fault(
            sample_scenario("task_002", seed=42)
        )
        assert faulty_model.training
|
|
|
|
class TestCodeBugFaultInjection:
    """Checks on the task_006 (code bug) scenario: the model itself looks normal."""

    def test_task_006_model_trains_normally(self):
        """The task_006 model is in training mode and no layer is flagged exploding."""
        scenario = sample_scenario("task_006", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        assert faulty_model.training
        layer_stats = extract_gradient_stats(faulty_model, scenario)
        assert all(not entry.is_exploding for entry in layer_stats)
|
|
|
|
class TestDataLeakageFaultInjection:
    """Checks on the task_003 (data leakage) scenario: the model itself looks normal."""

    def test_task_003_normal_model(self):
        """The task_003 model is in training mode and no layer is flagged exploding."""
        scenario = sample_scenario("task_003", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        assert faulty_model.training
        exploding_layers = [
            entry
            for entry in extract_gradient_stats(faulty_model, scenario)
            if entry.is_exploding
        ]
        assert not exploding_layers
|
|