File size: 4,382 Bytes
05ccdc6
e2f8b29
05ccdc6
 
 
 
 
 
 
 
e2f8b29
 
 
 
 
 
 
 
 
 
 
05ccdc6
 
aa0bed2
05ccdc6
 
 
 
 
 
aa0bed2
05ccdc6
aa0bed2
05ccdc6
e2f8b29
05ccdc6
e2f8b29
 
05ccdc6
 
 
e2f8b29
 
 
05ccdc6
 
e2f8b29
 
 
05ccdc6
 
e2f8b29
 
0b9b77b
05ccdc6
 
 
 
0b9b77b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2f8b29
 
 
 
 
 
0b9b77b
 
e2f8b29
0b9b77b
e2f8b29
 
 
 
 
 
 
 
 
0b9b77b
e2f8b29
 
 
 
 
 
 
 
 
 
 
0b9b77b
e2f8b29
 
 
 
 
 
 
 
 
 
 
0b9b77b
e2f8b29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Training curve generation — real PyTorch mini-training.

All curves come from run_real_training() in pytorch_engine.py:
  - Real torch.nn.Module (SimpleCNN or SimpleMLP)
  - Real torch.autograd forward + backward passes
  - Real torch.optim optimizer steps
  - Real validation on held-out data
  - 20 epochs, cached per (task_id, seed, model_type)

Zero numpy. Zero parametric formulas. Zero synthetic curves.
"""

from __future__ import annotations

import torch

from ml_training_debugger.scenarios import ScenarioParams

EPOCHS = 20


def _get_real_curves(scenario: ScenarioParams) -> dict[str, list[float]]:
    """Execute a real PyTorch training run and hand back its curves.

    Delegates to pytorch_engine.run_real_training(), which:
    - builds a real SimpleCNN or SimpleMLP model
    - synthesizes random CIFAR-10 style inputs (3x32x32)
    - performs 20 epochs of genuine forward/backward passes
    - injects the scenario's actual fault (wrong LR, eval mode, data
      leakage, etc.)
    - returns the real loss_history, val_loss_history, val_acc_history

    The engine caches results per (task_id, seed, model_type), so repeated
    calls (e.g. on resets) are effectively instant.
    """
    # Imported lazily so module import stays cheap and avoids a cycle.
    from ml_training_debugger.pytorch_engine import run_real_training

    curves = run_real_training(scenario)
    return curves


def gen_loss_history(scenario: ScenarioParams) -> list[float]:
    """Generate training loss history (20 epochs) from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["loss_history"]


def gen_val_accuracy_history(scenario: ScenarioParams) -> list[float]:
    """Generate validation accuracy history (20 epochs) from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["val_acc_history"]


def gen_val_loss_history(scenario: ScenarioParams) -> list[float]:
    """Generate validation loss history (20 epochs) from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["val_loss_history"]


def _gen_confusion_matrix(scenario: ScenarioParams) -> list[list[float]]:
    """Generate a 10x10 confusion matrix based on the fault type.

    Uses torch.Tensor operations on random data shaped by the fault scenario.
    """
    torch.manual_seed(scenario.seed + 10)
    root = scenario.root_cause.value
    n = 10

    if root == "data_leakage":
        # High diagonal but with leakage-induced off-diagonal noise
        base = torch.eye(n) * 0.8
        noise = torch.rand(n, n) * scenario.leakage_pct * 0.3
        cm = base + noise
    elif root == "overfitting":
        # Near-perfect diagonal (memorized)
        cm = torch.eye(n) * 0.95 + torch.rand(n, n) * 0.02
    else:
        # Normal confusion with moderate accuracy
        cm = torch.eye(n) * 0.6 + torch.rand(n, n) * 0.08

    # Normalize rows to sum to ~1.0
    row_sums = cm.sum(dim=1, keepdim=True)
    cm = cm / row_sums
    return cm.tolist()


def gen_data_batch_stats(scenario: ScenarioParams) -> dict:
    """Generate data batch statistics for the scenario.

    Returns a dict with the label distribution, feature mean/std, null and
    duplicate counts, a class-overlap score, the batch size, and a 10x10
    confusion matrix — all shaped by the scenario's root cause.

    Bug fix vs. the previous version: the confusion matrix is generated
    *before* torch.manual_seed(seed + 3) is applied.  _gen_confusion_matrix()
    reseeds the global torch RNG with seed + 10 internally, so in the old
    order the seed + 3 call was immediately clobbered and never actually
    governed the randn() draws below.
    """
    # Must run first: this helper reseeds the global torch RNG internally.
    cm = _gen_confusion_matrix(scenario)

    torch.manual_seed(scenario.seed + 3)
    root = scenario.root_cause.value

    def _stats(
        feature_mean: float,
        feature_std: float,
        class_overlap_score: float,
        duplicate_ratio: float,
    ) -> dict:
        # Shared payload shape; only the four fault-dependent fields vary.
        return {
            "label_distribution": {i: 0.1 for i in range(10)},
            "feature_mean": feature_mean,
            "feature_std": feature_std,
            "null_count": 0,
            "class_overlap_score": class_overlap_score,
            "batch_size": 64,
            "duplicate_ratio": duplicate_ratio,
            "confusion_matrix": cm,
        }

    if root == "data_leakage":
        # Overlap grows with leakage but is capped well below 1.0.
        overlap = min(0.5 + scenario.leakage_pct * 1.5, 0.92)
        return _stats(
            0.45 + torch.randn(1).item() * 0.05,
            0.22 + torch.randn(1).item() * 0.02,
            overlap,
            scenario.leakage_pct,
        )

    if root == "overfitting":
        return _stats(
            0.48 + torch.randn(1).item() * 0.03,
            0.25 + torch.randn(1).item() * 0.02,
            0.0,
            0.0,
        )

    # Default: healthy data with a small positive overlap jitter.
    return _stats(
        0.47 + torch.randn(1).item() * 0.03,
        0.24 + torch.randn(1).item() * 0.02,
        torch.randn(1).abs().item() * 0.05,
        0.0,
    )