# omkarrr88 - minor changes - 206438f
"""All Pydantic models, enums, and typed data structures.
No business logic. Pure data definitions.
"""
from __future__ import annotations
import enum
from typing import Optional, Union
import torch # noqa: F401
from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field
@enum.unique
class RootCauseDiagnosis(str, enum.Enum):
    """Closed enumeration of ML failure root causes.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (``RootCauseDiagnosis.OVERFITTING == "overfitting"``) and
    can be looked up from that string (``RootCauseDiagnosis("overfitting")``).
    ``enum.unique`` makes class creation fail if two names ever share a
    value, keeping the set of diagnoses free of silent aliases.
    """

    LR_TOO_HIGH = "lr_too_high"
    VANISHING_GRADIENTS = "vanishing_gradients"
    DATA_LEAKAGE = "data_leakage"
    OVERFITTING = "overfitting"
    BATCHNORM_EVAL_MODE = "batchnorm_eval_mode"
    CODE_BUG = "code_bug"
    SCHEDULER_MISCONFIGURED = "scheduler_misconfigured"
# Raw string value of every RootCauseDiagnosis member, collected into a set.
VALID_DIAGNOSES: set[str] = {member.value for member in RootCauseDiagnosis}
class TrainingConfig(BaseModel):
    """Typed hyperparameter configuration.

    Every field has a default, so ``TrainingConfig()`` is a complete
    configuration. No range or value validation is declared here beyond
    the field types.
    """

    learning_rate: float = 0.001
    weight_decay: float = 0.0001
    batch_size: int = 64
    hidden_dim: int = 64
    num_layers: int = 3
    # Free-form string; the set of accepted optimizer names is not
    # constrained in this module.
    optimizer: str = "adam"
    dropout_rate: float = 0.0
    # None by default -- presumably means "no gradient clipping";
    # TODO confirm against the code that consumes this config.
    gradient_clip_norm: Optional[float] = None
# Names of all fields declared on TrainingConfig (iterating the
# model_fields mapping yields its keys).
VALID_CONFIG_KEYS: set[str] = set(TrainingConfig.model_fields)
class GradientStats(BaseModel):
    """Per-layer gradient information from real torch.autograd.

    All fields are required. The two boolean flags are plain stored
    values; this model does not derive them from ``mean_norm`` itself,
    so the producer is responsible for keeping them consistent.
    """

    layer_name: str
    # Presumably one gradient-norm sample per recorded step -- TODO confirm.
    norm_history: list[float]
    mean_norm: float
    max_norm: float
    is_exploding: bool  # True when mean_norm > 10.0
    is_vanishing: bool  # True when mean_norm < 1e-6
class ModelWeightStats(BaseModel):
    """Per-layer weight statistics from real state_dict().

    ``layer_name`` and the five ``weight_*`` summaries are required; the
    remaining health indicators default to "nothing wrong" values.
    """

    layer_name: str
    weight_norm: float
    weight_mean: float
    weight_std: float
    weight_min: float
    weight_max: float
    # NOTE(review): scale (0-1 fraction vs. 0-100 percent) is not
    # established in this module -- confirm with the producer.
    dead_neuron_pct: float = 0.0
    has_nan: bool = False
    has_inf: bool = False
class DataBatchStats(BaseModel):
    """Data batch inspection results.

    ``null_count``, ``duplicate_ratio`` and ``confusion_matrix`` are
    optional (defaults ``0``, ``0.0`` and ``None``); every other field
    must be supplied.
    """

    # Presumably maps class label -> share of the batch; TODO confirm
    # whether values are fractions or counts.
    label_distribution: dict[int, float]
    feature_mean: float
    feature_std: float
    null_count: int = 0
    class_overlap_score: float
    batch_size: int
    duplicate_ratio: float = 0.0
    # Nested list of floats when present; row/column meaning is not
    # defined in this module.
    confusion_matrix: Optional[list[list[float]]] = None
class CodeSnippet(BaseModel):
    """PyTorch code for Task 6 inspection.

    ``code``, ``line_count`` and ``imports`` are required. ``line_count``
    is a stored value; this model does not compute it from ``code``.
    """

    code: str
    filename: str = "train.py"
    line_count: int
    imports: list[str]
    hint: Optional[str] = None
class EpisodeState(BaseModel):
    """Tracks agent history within an episode.

    A fresh instance starts with ``step_count == 0``, every flag ``False``
    and an empty ``actions_taken`` list. This class only reads the flags
    (in ``compute_available_actions``); nothing here sets them.
    """

    step_count: int = 0
    gradients_inspected: bool = False
    gradients_were_normal: bool = False
    data_inspected: bool = False
    model_modes_inspected: bool = False
    model_weights_inspected: bool = False
    code_inspected: bool = False
    fix_action_taken: bool = False
    restart_after_fix: bool = False
    diagnosis_submitted: bool = False
    # default_factory gives each instance its own list.
    actions_taken: list[str] = Field(default_factory=list)

    def compute_available_actions(self) -> list[str]:
        """Dynamically compute available actions based on current state.

        Ten actions are always offered, in a fixed order. Up to three more
        are appended after them, in this order, when their gate holds:

        * ``"fix_code"`` -- when ``code_inspected`` is true
        * ``"restart_run"`` -- when ``fix_action_taken`` is true
        * ``"mark_diagnosed"`` -- when ``diagnosis_submitted`` is false

        Returns:
            A newly built list of action-name strings.
        """
        available: list[str] = [
            "inspect_gradients",
            "inspect_data_batch",
            "inspect_model_modes",
            "inspect_model_weights",
            "inspect_code",
            "modify_config",
            "add_callback",
            "replace_optimizer",
            "patch_data_loader",
            "fix_model_mode",
        ]
        # (gate, action) pairs; tuple order fixes the output order.
        gated = (
            (self.code_inspected, "fix_code"),
            (self.fix_action_taken, "restart_run"),
            (not self.diagnosis_submitted, "mark_diagnosed"),
        )
        available.extend(name for unlocked, name in gated if unlocked)
        return available
# Every action name that EpisodeState.compute_available_actions() can
# return, gathered into one set.
ALL_ACTION_TYPES: set[str] = {
    # Inspection actions (always offered).
    "inspect_gradients",
    "inspect_data_batch",
    "inspect_model_modes",
    "inspect_model_weights",
    "inspect_code",
    # Remaining always-offered actions.
    "modify_config",
    "add_callback",
    "replace_optimizer",
    "patch_data_loader",
    "fix_model_mode",
    # Offered only when the episode state allows it.
    "fix_code",
    "restart_run",
    "mark_diagnosed",
}
class MLTrainingAction(Action):
    """What the agent can do — extends openenv Action.

    Only ``action_type`` is required; every other field defaults to
    ``None``. Which optional fields accompany which action type is not
    specified in this module.
    """

    # Plain string; not validated against ALL_ACTION_TYPES here.
    action_type: str
    target: Optional[str] = None
    value: Optional[Union[float, int, str]] = None
    # Plain string; presumably one of VALID_DIAGNOSES -- not enforced here.
    diagnosis: Optional[str] = None
    # NOTE(review): `line` / `replacement` look like the payload of a
    # "fix_code" action -- confirm against the environment's step logic.
    line: Optional[int] = None
    replacement: Optional[str] = None
class MLTrainingObservation(Observation):
    """Full observation — extends openenv Observation.

    Observation base has built-in: done (bool), reward (float|None), metadata (dict).

    Every field declared here has a default, so no field in this class is
    required. List-valued and nested-model fields use ``default_factory``
    so each observation gets its own fresh object.
    """

    run_id: str = ""
    framework: str = "pytorch"
    epoch: int = 20

    # Metric histories; empty lists by default.
    training_loss_history: list[float] = Field(default_factory=list)
    val_loss_history: list[float] = Field(default_factory=list)
    val_accuracy_history: list[float] = Field(default_factory=list)

    gradient_stats: list[GradientStats] = Field(default_factory=list)
    model_weight_stats: Optional[list[ModelWeightStats]] = None

    gpu_memory_used_gb: float = 6.2
    gpu_memory_total_gb: float = 16.0

    # NOTE(review): duplicates the default of current_config.learning_rate;
    # nothing in this model keeps the two in sync.
    learning_rate: float = 0.001
    current_config: TrainingConfig = Field(default_factory=TrainingConfig)

    error_log: Optional[str] = None

    # Inspection results; None by default.
    data_batch_stats: Optional[DataBatchStats] = None
    model_mode_info: Optional[dict[str, str]] = None
    code_snippet: Optional[CodeSnippet] = None

    # Empty by default; this model does not populate it from
    # episode_state.compute_available_actions() on its own.
    available_actions: list[str] = Field(default_factory=list)
    episode_state: EpisodeState = Field(default_factory=EpisodeState)
    notes: Optional[str] = None