File size: 4,319 Bytes
# src/pytorch_debug_env/bug_library.py
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional
import numpy as np
@dataclass
class BugTemplate:
    """Declarative description of one injectable bug scenario.

    Field order is part of the interface (the generated ``__init__`` accepts
    the fields positionally in the order declared here).
    """

    # --- identification -------------------------------------------------
    bug_type: str    # key identifying the bug, e.g. "missing_zero_grad"
    category: str    # coarse grouping label, e.g. "optimization"
    difficulty: str  # free-form difficulty label, e.g. "easy"

    # --- where the bug lives --------------------------------------------
    primary_bug_file: str            # repo-relative path of the buggy file
    related_files: List[str]         # additional repo-relative paths
    red_herring_file: Optional[str]  # distractor path, or None

    # --- ground truth ---------------------------------------------------
    fix_strategy: str      # human-readable description of the fix
    line_range: List[int]  # two-element list of line numbers
    description: str       # short human-readable summary

    # --- behaviour hooks ------------------------------------------------
    artifact_generator: Callable  # callable producing diagnostic artifacts
    repo_mutator: Callable        # callable injecting the bug into repo files

    # Free-form extras; default_factory gives every instance its own dict.
    metadata: Dict = field(default_factory=dict)
# Maps every known bug type to its category label.
# Insertion order is kept exactly as declared.
BUG_CATEGORIES = {
    # model
    "shape_mismatch": "model",
    # optimization
    "missing_zero_grad": "optimization",
    "wrong_loss_function": "optimization",
    "learning_rate_too_high": "optimization",
    "gradient_explosion": "optimization",
    # resource
    "memory_leak": "resource",
    # data
    "data_leakage": "data",
    "incorrect_normalization": "data",
    # distributed
    "distributed_sync_error": "distributed",
    # numerics
    "amp_overflow": "numerics",
}
def dummy_artifact_generator(artifact_type: str, rng):
    """Produce a canned diagnostic artifact of the requested kind.

    The output depends only on ``artifact_type``; nothing is randomised.

    Args:
        artifact_type: One of ``"loss_curve"``, ``"gpu_profile"`` or
            ``"training_log"``.
        rng: Accepted but never read by this function.

    Returns:
        * ``"loss_curve"``: list of 100 dicts ``{"step", "train_loss"}``
          following an exponential decay with a slowly damped sine ripple.
        * ``"gpu_profile"``: list of 100 dicts ``{"step", "allocated_mb"}``
          growing linearly from 2048 by 2.4 per step.
        * ``"training_log"``: a fixed two-line log string.
        * any other value: an empty list.
    """
    if artifact_type == "training_log":
        return "Epoch 1, Step 0: loss 2.45\nEpoch 1, Step 1: loss 2.43\n"

    if artifact_type == "loss_curve":
        steps = np.arange(100)
        decay = 2.3 * np.exp(-0.01 * steps) + 0.15
        ripple = 0.22 * np.sin(0.25 * steps) * np.exp(-0.002 * steps)
        # One element-wise add, then convert to plain Python floats.
        losses = (decay + ripple).tolist()
        return [
            {"step": idx, "train_loss": value}
            for idx, value in enumerate(losses)
        ]

    if artifact_type == "gpu_profile":
        steps = np.arange(100)
        usage = (2048 + 2.4 * steps).tolist()
        return [
            {"step": idx, "allocated_mb": value}
            for idx, value in enumerate(usage)
        ]

    # Unknown artifact types yield an empty list.
    return []
def mutate_missing_zero_grad(repo_files, rng):
    """Inject the "missing zero_grad" bug by overwriting ``train.py``.

    Args:
        repo_files: Mutable mapping of repo-relative path to file contents.
            It is modified in place; other entries are left untouched.
        rng: Accepted but not used by this mutator.

    Returns:
        The same ``repo_files`` object, with ``"train.py"`` replaced by a
        training script whose ``optimizer.zero_grad()`` call is commented out.
    """
    # The embedded script carries real indentation so the generated file is
    # syntactically valid Python. The blank lines place the inner training
    # loop on lines 9-14, the range declared for this bug in BUG_TEMPLATES.
    repo_files["train.py"] = """import torch
from model.architecture import Net

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(10):
    for x, y in dataloader:
        # optimizer.zero_grad() # BUG: commented out
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
"""
    return repo_files
def mutate_data_leakage(repo_files, rng):
    """Inject the "data leakage" bug by overwriting ``data/dataset.py``.

    Args:
        repo_files: Mutable mapping of repo-relative path to file contents.
            It is modified in place; other entries are left untouched.
        rng: Accepted but not used by this mutator.

    Returns:
        The same ``repo_files`` object, with ``"data/dataset.py"`` replaced
        by a dataset class that stores the full ``data`` regardless of
        ``split``.
    """
    # The embedded module carries real indentation so the generated file is
    # syntactically valid Python. With this layout, lines 4-6 (the range
    # declared for this bug in BUG_TEMPLATES) run from the __init__ header
    # through the leaking assignment.
    repo_files["data/dataset.py"] = """from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, data, split="train"):
        # BUG: We use the entire data instead of just the split
        self.data = data
        self.split = split
"""
    return repo_files
def mutate_memory_leak(repo_files, rng):
    """Inject the "memory leak" bug by overwriting ``data/dataset.py``.

    Args:
        repo_files: Mutable mapping of repo-relative path to file contents.
            It is modified in place; other entries are left untouched.
        rng: Accepted but not used by this mutator.

    Returns:
        The same ``repo_files`` object, with ``"data/dataset.py"`` replaced
        by a dataset class whose ``load`` appends every item to a cache
        that is never cleared.
    """
    # The embedded module carries real indentation so the generated file is
    # syntactically valid Python. With this layout, lines 5-9 (the range
    # declared for this bug in BUG_TEMPLATES) run from the BUG comment
    # through the accumulating append.
    # NOTE(review): the injected comment says "class-level variable" while the
    # injected code uses an instance attribute -- confirm which is intended.
    repo_files["data/dataset.py"] = """from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self):
        # BUG: Storing huge tensors in a class-level variable leading to memory accumulation
        self.cache = []

    def load(self, x):
        self.cache.append(x)
        return x
"""
    return repo_files
# Registry of the bug scenarios that currently have a concrete template.
# Each entry's ``category`` equals the BUG_CATEGORIES value for its
# ``bug_type``; every entry uses the shared artifact generator.
BUG_TEMPLATES = [
    # Easy: gradients are never reset in the training loop.
    BugTemplate(
        bug_type="missing_zero_grad",
        category="optimization",
        difficulty="easy",
        primary_bug_file="train.py",
        related_files=[],
        red_herring_file="model/architecture.py",
        fix_strategy="Call optimizer.zero_grad() before loss.backward()",
        line_range=[9, 14],
        description="Missing zero grad",
        artifact_generator=dummy_artifact_generator,
        repo_mutator=mutate_missing_zero_grad,
    ),
    # Medium: the dataset ignores the requested split.
    BugTemplate(
        bug_type="data_leakage",
        category="data",
        difficulty="medium",
        primary_bug_file="data/dataset.py",
        related_files=["data/preprocessing.py"],
        red_herring_file="train.py",
        fix_strategy="Ensure validation split is strictly separate from training",
        line_range=[4, 6],
        description="Data leakage",
        artifact_generator=dummy_artifact_generator,
        repo_mutator=mutate_data_leakage,
    ),
    # Hard: the dataset keeps every loaded item in a growing cache.
    BugTemplate(
        bug_type="memory_leak",
        category="resource",
        difficulty="hard",
        primary_bug_file="data/dataset.py",
        related_files=["train.py"],
        red_herring_file="model/attention.py",
        fix_strategy="Avoid holding reference to tensors in class cache",
        line_range=[5, 9],
        description="Memory leak",
        artifact_generator=dummy_artifact_generator,
        repo_mutator=mutate_memory_leak,
    ),
]
|