| """Extended simulation tests — adapted for real mini-training curves.""" |
|
|
| from __future__ import annotations |
|
|
| from ml_training_debugger.scenarios import sample_scenario |
| from ml_training_debugger.simulation import ( |
| gen_data_batch_stats, |
| gen_loss_history, |
| gen_val_accuracy_history, |
| gen_val_loss_history, |
| ) |
|
|
|
|
class TestVanishingGradients:
    """Smoke tests for scenario ``task_002`` (vanishing gradients, per the class name).

    Every test samples the scenario with ``seed=42``, runs one history
    generator on it, and checks that ``_HISTORY_LEN`` points come back.
    """

    _TASK = "task_002"
    _HISTORY_LEN = 20

    def _history(self, generator):
        """Sample this class's scenario (seed 42) and return ``generator(scenario)``."""
        scenario = sample_scenario(self._TASK, seed=42)
        return generator(scenario)

    def test_loss_barely_decreases(self):
        # NOTE(review): despite the name, only the history length is asserted;
        # nothing here checks that the loss "barely decreases".
        assert len(self._history(gen_loss_history)) == self._HISTORY_LEN

    def test_val_acc_low(self):
        # NOTE(review): only the length is asserted; "low" accuracy is not checked.
        assert len(self._history(gen_val_accuracy_history)) == self._HISTORY_LEN

    def test_val_loss_present(self):
        assert len(self._history(gen_val_loss_history)) == self._HISTORY_LEN
|
|
|
|
class TestOverfitting:
    """Smoke tests for scenario ``task_004`` (overfitting, per the class name).

    The history tests sample the scenario with ``seed=42`` and check the
    generated length. The batch-stats test checks that two data-quality
    fields are exactly zero.
    """

    _TASK = "task_004"
    _HISTORY_LEN = 20

    def _scenario(self):
        """Return this class's scenario sampled with seed 42."""
        return sample_scenario(self._TASK, seed=42)

    def test_loss_history_present(self):
        history = gen_loss_history(self._scenario())
        assert len(history) == self._HISTORY_LEN

    def test_val_acc_present(self):
        history = gen_val_accuracy_history(self._scenario())
        assert len(history) == self._HISTORY_LEN

    def test_val_loss_present(self):
        history = gen_val_loss_history(self._scenario())
        assert len(history) == self._HISTORY_LEN

    def test_data_batch_stats_clean(self):
        batch_stats = gen_data_batch_stats(self._scenario())
        # Both fields are compared to exactly 0.0, checked in this order.
        for field in ("class_overlap_score", "duplicate_ratio"):
            assert batch_stats[field] == 0.0
|
|
|
|
class TestCodeBug:
    """Smoke tests for scenario ``task_006`` (code bug, per the class name).

    Each test samples the scenario with ``seed=42`` and asserts that one
    generated history has ``_HISTORY_LEN`` entries.
    """

    _TASK = "task_006"
    _HISTORY_LEN = 20

    def _history(self, generator):
        """Sample this class's scenario (seed 42) and return ``generator(scenario)``."""
        scenario = sample_scenario(self._TASK, seed=42)
        return generator(scenario)

    def test_loss_history(self):
        assert len(self._history(gen_loss_history)) == self._HISTORY_LEN

    def test_val_acc(self):
        assert len(self._history(gen_val_accuracy_history)) == self._HISTORY_LEN

    def test_val_loss(self):
        assert len(self._history(gen_val_loss_history)) == self._HISTORY_LEN
|
|
|
|
class TestBatchNormEval:
    """Smoke tests for scenario ``task_005`` (batch-norm eval mode, per the class name).

    Each test samples the scenario with ``seed=42`` and asserts that one
    generated validation history has ``_HISTORY_LEN`` entries.
    """

    _TASK = "task_005"
    _HISTORY_LEN = 20

    def _history(self, generator):
        """Sample this class's scenario (seed 42) and return ``generator(scenario)``."""
        scenario = sample_scenario(self._TASK, seed=42)
        return generator(scenario)

    def test_val_loss_present(self):
        assert len(self._history(gen_val_loss_history)) == self._HISTORY_LEN

    def test_val_acc_near_zero(self):
        # NOTE(review): despite the name, nothing checks that accuracy is near
        # zero. The original had a blank line here, which looks like a removed
        # assertion. TODO: restore a value check once the curve shape is known.
        assert len(self._history(gen_val_accuracy_history)) == self._HISTORY_LEN
|
|
|
|
class TestSchedulerMisconfigured:
    """Smoke tests for scenario ``task_007`` (misconfigured LR scheduler, per the class name).

    Each test samples the scenario with ``seed=42`` and asserts that one
    generated history has ``_HISTORY_LEN`` entries.
    """

    _TASK = "task_007"
    _HISTORY_LEN = 20

    def _history(self, generator):
        """Sample this class's scenario (seed 42) and return ``generator(scenario)``."""
        scenario = sample_scenario(self._TASK, seed=42)
        return generator(scenario)

    def test_loss_history(self):
        assert len(self._history(gen_loss_history)) == self._HISTORY_LEN

    def test_val_acc(self):
        assert len(self._history(gen_val_accuracy_history)) == self._HISTORY_LEN

    def test_val_loss(self):
        assert len(self._history(gen_val_loss_history)) == self._HISTORY_LEN
|
|