"""All Pydantic models, enums, and typed data structures.

No business logic. Pure data definitions.
"""
from __future__ import annotations
import enum
from typing import Optional, Union
import torch # noqa: F401
from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field
class RootCauseDiagnosis(str, enum.Enum):
    """Closed set of root causes that can be reported for a failing ML run.

    Because the enum mixes in ``str``, every member compares equal to its raw
    string value, e.g. ``RootCauseDiagnosis.CODE_BUG == "code_bug"``.
    """

    LR_TOO_HIGH = "lr_too_high"
    VANISHING_GRADIENTS = "vanishing_gradients"
    DATA_LEAKAGE = "data_leakage"
    OVERFITTING = "overfitting"
    BATCHNORM_EVAL_MODE = "batchnorm_eval_mode"
    CODE_BUG = "code_bug"
    SCHEDULER_MISCONFIGURED = "scheduler_misconfigured"


# Raw string values of every RootCauseDiagnosis member. Aliases (if any were
# ever added) share their canonical member's value, so walking __members__
# yields the same set as iterating the enum itself.
VALID_DIAGNOSES: set[str] = {
    member.value for member in RootCauseDiagnosis.__members__.values()
}
class TrainingConfig(BaseModel):
    """Typed hyperparameter configuration for a training run.

    Every field has a default, so ``TrainingConfig()`` alone is a complete
    configuration. Field declaration order is kept stable because pydantic
    uses it for schema and serialization order.
    """

    learning_rate: float = 0.001
    weight_decay: float = 0.0001
    batch_size: int = 64
    hidden_dim: int = 64
    num_layers: int = 3
    # Free-form optimizer name; this model does not restrict the allowed values.
    optimizer: str = "adam"
    dropout_rate: float = 0.0
    # None presumably means "no gradient clipping" — confirm with the trainer.
    gradient_clip_norm: Optional[float] = None


# Names of every TrainingConfig field; iterating the model_fields mapping
# yields its keys (the field names).
VALID_CONFIG_KEYS: set[str] = set(TrainingConfig.model_fields)
class GradientStats(BaseModel):
    """Per-layer gradient information from real torch.autograd.

    One instance describes a single layer. The model only stores the values
    it is given; it does not derive the summary fields or the boolean flags
    from ``norm_history``, nor check them against each other.
    """
    layer_name: str
    # Recorded gradient norms for this layer; whether entries are per step or
    # per epoch is not specified here — confirm with the producer.
    norm_history: list[float]
    mean_norm: float
    max_norm: float
    is_exploding: bool # True when mean_norm > 10.0 (producer convention; not validated here)
    is_vanishing: bool # True when mean_norm < 1e-6 (producer convention; not validated here)
class ModelWeightStats(BaseModel):
    """Per-layer weight statistics from real state_dict().

    The name and the five summary floats are required. The dead-neuron and
    NaN/Inf indicators default to "nothing detected" when the producer omits
    them.
    """
    layer_name: str
    weight_norm: float
    weight_mean: float
    weight_std: float
    weight_min: float
    weight_max: float
    # NOTE(review): unclear whether this is a 0-1 fraction or a 0-100 percentage
    # despite the "_pct" suffix — confirm with the producer.
    dead_neuron_pct: float = 0.0
    has_nan: bool = False  # presumably True if any weight in the layer is NaN
    has_inf: bool = False  # presumably True if any weight in the layer is +/-inf
class DataBatchStats(BaseModel):
    """Data batch inspection results.

    Carried by ``MLTrainingObservation.data_batch_stats``. Unlike a dataclass,
    a pydantic model may declare required fields (``class_overlap_score``,
    ``batch_size``) after defaulted ones, so the ordering below is valid.
    """
    # Maps class label -> a float per label; presumably a frequency or
    # proportion — confirm with the producer.
    label_distribution: dict[int, float]
    feature_mean: float
    feature_std: float
    null_count: int = 0
    class_overlap_score: float
    batch_size: int
    duplicate_ratio: float = 0.0
    # Optional square matrix as nested lists; row/column orientation
    # (true vs. predicted) is not specified here — confirm with the producer.
    confusion_matrix: Optional[list[list[float]]] = None
class CodeSnippet(BaseModel):
    """PyTorch code for Task 6 inspection.

    Carried by ``MLTrainingObservation.code_snippet``. ``line_count`` and
    ``imports`` are stored exactly as given; nothing here checks them against
    ``code``.
    """
    code: str  # source text of the snippet
    filename: str = "train.py"
    line_count: int
    imports: list[str]
    hint: Optional[str] = None  # optional free-text hint; None when absent
class EpisodeState(BaseModel):
    """Running record of what the agent has done during one episode.

    All flags start ``False``, the step counter starts at 0, and
    ``actions_taken`` starts as a fresh empty list per instance. The flags are
    presumably flipped by the environment as the episode progresses. The only
    derived behavior is :meth:`compute_available_actions`.
    """

    step_count: int = 0
    gradients_inspected: bool = False
    gradients_were_normal: bool = False
    data_inspected: bool = False
    model_modes_inspected: bool = False
    model_weights_inspected: bool = False
    code_inspected: bool = False
    fix_action_taken: bool = False
    restart_after_fix: bool = False
    diagnosis_submitted: bool = False
    actions_taken: list[str] = Field(default_factory=list)

    def compute_available_actions(self) -> list[str]:
        """Return the action types the agent may pick from in the current state.

        Ten inspection/intervention actions are always offered, in a fixed
        order. Up to three more are appended after them, in this order:

        * ``fix_code`` once ``code_inspected`` is set,
        * ``restart_run`` once ``fix_action_taken`` is set,
        * ``mark_diagnosed`` while ``diagnosis_submitted`` is still unset.

        Returns:
            A newly built list on every call.
        """
        always_offered = (
            "inspect_gradients",
            "inspect_data_batch",
            "inspect_model_modes",
            "inspect_model_weights",
            "inspect_code",
            "modify_config",
            "add_callback",
            "replace_optimizer",
            "patch_data_loader",
            "fix_model_mode",
        )
        # (action, is_enabled) pairs, listed in the order they get appended.
        conditional = (
            ("fix_code", self.code_inspected),
            ("restart_run", self.fix_action_taken),
            ("mark_diagnosed", not self.diagnosis_submitted),
        )
        unlocked = [name for name, enabled in conditional if enabled]
        return [*always_offered, *unlocked]
# Every action type name: the ten that EpisodeState.compute_available_actions
# always offers, plus the three it only offers conditionally.
ALL_ACTION_TYPES: set[str] = set().union(
    # inspect_* actions.
    (
        "inspect_gradients",
        "inspect_data_batch",
        "inspect_model_modes",
        "inspect_model_weights",
        "inspect_code",
    ),
    # Remaining always-offered actions.
    (
        "modify_config",
        "add_callback",
        "replace_optimizer",
        "patch_data_loader",
        "fix_model_mode",
    ),
    # Conditionally offered actions.
    (
        "fix_code",
        "restart_run",
        "mark_diagnosed",
    ),
)
class MLTrainingAction(Action):
    """What the agent can do — extends openenv Action.

    ``action_type`` is a free-form string: the known names are collected in
    ALL_ACTION_TYPES, but this model does not validate against that set.
    Likewise ``diagnosis`` is not checked against VALID_DIAGNOSES here. All
    other fields are optional; which ones matter presumably depends on
    ``action_type`` — confirm with the environment's step logic.
    """
    action_type: str
    target: Optional[str] = None  # presumably the config key / component the action acts on — confirm
    # NOTE(review): float is listed before int; confirm integer inputs are not
    # coerced to float under the installed pydantic version's union mode.
    value: Optional[Union[float, int, str]] = None
    diagnosis: Optional[str] = None  # presumably one of VALID_DIAGNOSES for mark_diagnosed — confirm
    line: Optional[int] = None  # presumably the snippet line targeted by fix_code; 0- or 1-based unconfirmed
    replacement: Optional[str] = None  # presumably the new source text for fix_code — confirm
class MLTrainingObservation(Observation):
    """Full observation — extends openenv Observation.
    Observation base has built-in: done (bool), reward (float|None), metadata (dict).

    List- and model-typed fields use ``default_factory`` so each instance gets
    its own fresh container. The inspection results (weight stats, batch stats,
    mode info, code snippet) default to ``None``, presumably until the
    corresponding inspect action has been taken — confirm with the env.
    """
    run_id: str = ""
    framework: str = "pytorch"
    epoch: int = 20  # placeholder default; presumably overwritten by the env
    training_loss_history: list[float] = Field(default_factory=list)
    val_loss_history: list[float] = Field(default_factory=list)
    val_accuracy_history: list[float] = Field(default_factory=list)
    gradient_stats: list[GradientStats] = Field(default_factory=list)
    model_weight_stats: Optional[list[ModelWeightStats]] = None
    gpu_memory_used_gb: float = 6.2  # placeholder default
    gpu_memory_total_gb: float = 16.0  # placeholder default
    # NOTE(review): overlaps current_config.learning_rate; nothing in this
    # model keeps the two values in sync.
    learning_rate: float = 0.001
    current_config: TrainingConfig = Field(default_factory=TrainingConfig)
    error_log: Optional[str] = None
    data_batch_stats: Optional[DataBatchStats] = None
    # Maps a string key to a string value; presumably module name -> mode
    # ("train"/"eval") — confirm with the producer.
    model_mode_info: Optional[dict[str, str]] = None
    code_snippet: Optional[CodeSnippet] = None
    # Presumably filled from EpisodeState.compute_available_actions() — confirm.
    available_actions: list[str] = Field(default_factory=list)
    episode_state: EpisodeState = Field(default_factory=EpisodeState)
    notes: Optional[str] = None
|