| |
| """Validate parametric exploding gradient curves against real PyTorch training. |
| |
| Trains a CNN with lr=0.1 for 20 epochs, compares loss curve to simulation. |
| Asserts R² > 0.85 between real and simulated curves. |
| """ |
|
|
| from __future__ import annotations |
|
|
import math

import torch
import torch.nn as nn

from ml_training_debugger.pytorch_engine import SimpleCNN
from ml_training_debugger.scenarios import sample_scenario
from ml_training_debugger.simulation import gen_loss_history
|
|
|
|
def run_real_training(lr: float = 0.1, epochs: int = 20) -> list[float]:
    """Run real training with a high learning rate and capture the loss history.

    Args:
        lr: SGD learning rate; 0.1 is deliberately high to provoke divergence.
        epochs: number of single-batch training steps to run.

    Returns:
        One loss value per step. NaN losses are recorded as +inf so that
        divergence is unambiguous in downstream comparisons.
    """
    torch.manual_seed(42)  # deterministic weights and synthetic batches
    model = SimpleCNN()
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    losses: list[float] = []
    for _ in range(epochs):
        # Fresh random batch each step: 16 RGB 32x32 images, 10-class labels.
        batch_x = torch.randn(16, 3, 32, 32)
        batch_y = torch.randint(0, 10, (16,))
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        loss_val = loss.item()
        # `loss_val != loss_val` was an obscure NaN test; torch.isnan is explicit.
        losses.append(float("inf") if torch.isnan(loss) else loss_val)
    return losses
|
|
|
|
def compute_r_squared(real: list[float], simulated: list[float]) -> float:
    """Compute R² between two curves, ignoring non-finite values.

    Pairs where either value is nan or ±inf are dropped (the original check
    only excluded +inf, so a -inf entry would poison the sums with inf/nan).

    Args:
        real: observed loss curve.
        simulated: simulated loss curve (compared pointwise via zip).

    Returns:
        R² over the finite pairs; 0.0 if fewer than 3 finite pairs survive,
        1.0 if the real curve has zero variance (ss_tot == 0).
    """
    pairs = [
        (r, s)
        for r, s in zip(real, simulated)
        if math.isfinite(r) and math.isfinite(s)
    ]
    if len(pairs) < 3:
        return 0.0
    real_t = torch.tensor([p[0] for p in pairs])
    sim_t = torch.tensor([p[1] for p in pairs])
    ss_res = ((real_t - sim_t) ** 2).sum()
    ss_tot = ((real_t - real_t.mean()) ** 2).sum()
    if ss_tot == 0:
        # Degenerate flat curve: avoid division by zero.
        return 1.0
    return (1 - ss_res / ss_tot).item()
|
|
|
|
def main() -> None:
    """Compare a real high-LR training run against the simulated loss curve.

    Prints the R² between the two curves and asserts that both diverge
    (any step hitting +inf or exceeding 100).
    """
    def fmt_end(value: float) -> str:
        # Render the final loss value, collapsing +inf to the literal "INF".
        return "INF" if value == float("inf") else f"{value:.2f}"

    def diverges(curve: list[float]) -> bool:
        # A curve counts as divergent once any value blows past 100 or hits inf.
        return any(v == float("inf") or v > 100 for v in curve)

    scenario = sample_scenario("task_001", seed=42)
    simulated = gen_loss_history(scenario)
    real = run_real_training(lr=scenario.learning_rate, epochs=20)

    r2 = compute_r_squared(real, simulated)
    print(f"Exploding Gradients — R²: {r2:.4f}")
    print(f"  Real loss trend: {real[0]:.2f} → {fmt_end(real[-1])}")
    print(f"  Sim loss trend: {simulated[0]:.2f} → {fmt_end(simulated[-1])}")

    real_diverges = diverges(real)
    sim_diverges = diverges(simulated)
    print(f"  Real diverges: {real_diverges}, Sim diverges: {sim_diverges}")
    assert real_diverges and sim_diverges, "Both curves should diverge"
    print("  PASS: Both curves diverge as expected")
|
|
|
|
# Script entry point: run the validation when executed directly.
if __name__ == "__main__":
    main()
|
|