# pytorch-training-debugger / tests/test_models.py
# Author: omkarrr88 — commit 7336adb ("Minor fixes")
"""Test all Pydantic models instantiate and serialize correctly."""
from __future__ import annotations
import json
from openenv.core.env_server.types import Action, Observation
from ml_training_debugger.models import (
EpisodeState,
GradientStats,
MLTrainingAction,
MLTrainingObservation,
RootCauseDiagnosis,
TrainingConfig,
)
class TestRootCauseDiagnosis:
    """Sanity checks on the RootCauseDiagnosis enum."""

    def test_all_values_exist(self):
        """The enum defines exactly seven diagnosis members."""
        assert len(RootCauseDiagnosis) == 7

    def test_values_are_strings(self):
        """Every member's value is a plain string."""
        assert all(isinstance(member.value, str) for member in RootCauseDiagnosis)

    def test_specific_values(self):
        """Spot-check two well-known member values."""
        expected = {
            RootCauseDiagnosis.LR_TOO_HIGH: "lr_too_high",
            RootCauseDiagnosis.CODE_BUG: "code_bug",
        }
        for member, value in expected.items():
            assert member.value == value
class TestTrainingConfig:
    """TrainingConfig defaults and JSON round-tripping."""

    def test_default_instantiation(self):
        """A bare config carries the documented default values."""
        cfg = TrainingConfig()
        assert cfg.learning_rate == 0.001
        assert cfg.gradient_clip_norm is None

    def test_json_roundtrip(self):
        """Serializing to JSON and validating back preserves field values."""
        original = TrainingConfig(learning_rate=0.01, weight_decay=0.1)
        payload = json.loads(original.model_dump_json())
        rebuilt = TrainingConfig.model_validate(payload)
        assert rebuilt.learning_rate == 0.01
        assert rebuilt.weight_decay == 0.1
class TestGradientStats:
    """GradientStats flag combinations for single-sample histories."""

    @staticmethod
    def _single_sample(layer, norm, *, exploding, vanishing):
        """Build stats whose history holds one sample, so mean == max == norm."""
        return GradientStats(
            layer_name=layer,
            norm_history=[norm],
            mean_norm=norm,
            max_norm=norm,
            is_exploding=exploding,
            is_vanishing=vanishing,
        )

    def test_exploding(self):
        """A large norm flagged as exploding reports is_exploding."""
        stats = self._single_sample("fc", 15.0, exploding=True, vanishing=False)
        assert stats.is_exploding

    def test_vanishing(self):
        """A tiny norm flagged as vanishing reports is_vanishing."""
        stats = self._single_sample("conv1", 1e-7, exploding=False, vanishing=True)
        assert stats.is_vanishing

    def test_normal(self):
        """A moderate norm with neither flag set reports neither condition."""
        stats = self._single_sample("conv1", 0.5, exploding=False, vanishing=False)
        assert not stats.is_exploding
        assert not stats.is_vanishing
class TestEpisodeState:
    """Episode bookkeeping and gated action availability."""

    def test_fresh_state(self):
        """A new episode starts with no steps and no flags set."""
        fresh = EpisodeState()
        assert fresh.step_count == 0
        assert not fresh.gradients_inspected
        assert not fresh.diagnosis_submitted

    def test_available_actions_initial(self):
        """Inspection and diagnosis are open initially; fixes are gated."""
        available = EpisodeState().compute_available_actions()
        for always_open in ("inspect_gradients", "mark_diagnosed"):
            assert always_open in available
        for gated in ("fix_code", "restart_run"):
            assert gated not in available

    def test_fix_code_available_after_code_inspected(self):
        """Inspecting the code unlocks fix_code."""
        available = EpisodeState(code_inspected=True).compute_available_actions()
        assert "fix_code" in available

    def test_restart_run_available_after_fix(self):
        """Taking a fix action unlocks restart_run."""
        available = EpisodeState(fix_action_taken=True).compute_available_actions()
        assert "restart_run" in available

    def test_mark_diagnosed_disappears_after_submission(self):
        """mark_diagnosed is offered only once per episode."""
        available = EpisodeState(diagnosis_submitted=True).compute_available_actions()
        assert "mark_diagnosed" not in available
class TestMLTrainingObservation:
    """MLTrainingObservation base-class contract and serialization."""

    def test_extends_observation(self):
        """The model subclasses the OpenEnv Observation type."""
        assert issubclass(MLTrainingObservation, Observation)

    def test_has_done_and_reward(self):
        """done and reward pass through from the base Observation fields."""
        terminal = MLTrainingObservation(done=True, reward=0.5)
        assert terminal.done is True
        assert terminal.reward == 0.5

    def test_json_serialization(self):
        """Histories and run metadata survive a JSON dump."""
        observation = MLTrainingObservation(
            run_id="test",
            training_loss_history=[1.0, 2.0],
            val_accuracy_history=[0.5],
        )
        payload = json.loads(observation.model_dump_json())
        assert payload["run_id"] == "test"
        # Default framework is serialized even when not passed explicitly.
        assert payload["framework"] == "pytorch"
class TestMLTrainingAction:
    """MLTrainingAction base-class contract and per-action payload fields."""

    def test_extends_action(self):
        """The model subclasses the OpenEnv Action type."""
        assert issubclass(MLTrainingAction, Action)

    def test_basic_action(self):
        """An action can be built from its type alone."""
        inspect = MLTrainingAction(action_type="inspect_gradients")
        assert inspect.action_type == "inspect_gradients"

    def test_modify_config_action(self):
        """modify_config carries a target field name and a new value."""
        modify = MLTrainingAction(
            action_type="modify_config",
            target="learning_rate",
            value=0.001,
        )
        assert modify.target == "learning_rate"

    def test_mark_diagnosed_action(self):
        """mark_diagnosed carries the submitted diagnosis label."""
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="lr_too_high",
        )
        assert diagnose.diagnosis == "lr_too_high"

    def test_fix_code_action(self):
        """fix_code carries a line number and its replacement text."""
        fix = MLTrainingAction(
            action_type="fix_code",
            line=13,
            replacement="loss = criterion(output, batch_y)",
        )
        assert fix.line == 13