"""Training curve generation — real PyTorch mini-training.
All curves come from run_real_training() in pytorch_engine.py:
- Real torch.nn.Module (SimpleCNN or SimpleMLP)
- Real torch.autograd forward + backward passes
- Real torch.optim optimizer steps
- Real validation on held-out data
- 20 epochs, cached per (task_id, seed, model_type)
Zero numpy. Zero parametric formulas. Zero synthetic curves.
"""
from __future__ import annotations
import torch
from ml_training_debugger.scenarios import ScenarioParams
EPOCHS = 20
def _get_real_curves(scenario: ScenarioParams) -> dict[str, list[float]]:
    """Run the real PyTorch mini-training loop and return its curves.

    Delegates to pytorch_engine.run_real_training(), which:
    - Creates a real SimpleCNN or SimpleMLP model
    - Generates random CIFAR-10 style data (3x32x32)
    - Runs 20 epochs of real forward/backward passes
    - Injects the actual fault (wrong LR, eval mode, data leakage, etc.)
    - Returns real loss_history, val_loss_history, val_acc_history

    Results are cached per (task_id, seed, model_type) for instant resets.
    """
    # Imported lazily so importing this module stays cheap; the engine
    # pulls in model/data construction code.
    from ml_training_debugger import pytorch_engine

    return pytorch_engine.run_real_training(scenario)
def gen_loss_history(scenario: ScenarioParams) -> list[float]:
    """Training loss per epoch (20 epochs) from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["loss_history"]
def gen_val_accuracy_history(scenario: ScenarioParams) -> list[float]:
    """Validation accuracy per epoch (20 epochs) from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["val_acc_history"]
def gen_val_loss_history(scenario: ScenarioParams) -> list[float]:
    """Validation loss per epoch (20 epochs) from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["val_loss_history"]
def _gen_confusion_matrix(scenario: ScenarioParams) -> list[list[float]]:
"""Generate a 10x10 confusion matrix based on the fault type.
Uses torch.Tensor operations on random data shaped by the fault scenario.
"""
torch.manual_seed(scenario.seed + 10)
root = scenario.root_cause.value
n = 10
if root == "data_leakage":
# High diagonal but with leakage-induced off-diagonal noise
base = torch.eye(n) * 0.8
noise = torch.rand(n, n) * scenario.leakage_pct * 0.3
cm = base + noise
elif root == "overfitting":
# Near-perfect diagonal (memorized)
cm = torch.eye(n) * 0.95 + torch.rand(n, n) * 0.02
else:
# Normal confusion with moderate accuracy
cm = torch.eye(n) * 0.6 + torch.rand(n, n) * 0.08
# Normalize rows to sum to ~1.0
row_sums = cm.sum(dim=1, keepdim=True)
cm = cm / row_sums
return cm.tolist()
def gen_data_batch_stats(scenario: ScenarioParams) -> dict:
    """Generate data batch statistics for the scenario.

    Returns a dict with a uniform label distribution, feature moments,
    leakage indicators (class_overlap_score, duplicate_ratio) and a 10x10
    confusion matrix. Only the feature moments and leakage indicators vary
    with the fault type; everything else is shared, so the shared fields
    are built once instead of duplicated across branches.
    """
    torch.manual_seed(scenario.seed + 3)
    root = scenario.root_cause.value
    # NOTE(review): _gen_confusion_matrix reseeds the global torch RNG
    # (seed + 10), so the manual_seed above only matters if it were read
    # before this call — kept in place to preserve existing outputs.
    cm = _gen_confusion_matrix(scenario)

    # Fields identical for every fault type.
    stats = {
        "label_distribution": {i: 0.1 for i in range(10)},
        "null_count": 0,
        "batch_size": 64,
        "confusion_matrix": cm,
    }

    if root == "data_leakage":
        # Overlap grows with injected leakage, capped below 1.0.
        stats.update(
            feature_mean=0.45 + torch.randn(1).item() * 0.05,
            feature_std=0.22 + torch.randn(1).item() * 0.02,
            class_overlap_score=min(0.5 + scenario.leakage_pct * 1.5, 0.92),
            duplicate_ratio=scenario.leakage_pct,
        )
    elif root == "overfitting":
        stats.update(
            feature_mean=0.48 + torch.randn(1).item() * 0.03,
            feature_std=0.25 + torch.randn(1).item() * 0.02,
            class_overlap_score=0.0,
            duplicate_ratio=0.0,
        )
    else:
        # Normal data: no duplication, tiny random overlap score.
        stats.update(
            feature_mean=0.47 + torch.randn(1).item() * 0.03,
            feature_std=0.24 + torch.randn(1).item() * 0.02,
            class_overlap_score=torch.randn(1).abs().item() * 0.05,
            duplicate_ratio=0.0,
        )
    return stats
|