| """MLTrainingEnvironment — extends openenv Environment. |
| |
| Session isolation, progressive information reveal, error handling. |
| step() never raises an unhandled exception. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import dataclasses |
| import logging |
| import uuid |
| from typing import Any, Optional, Union |
|
|
| import torch |
| from openenv.core.env_server.interfaces import Environment |
|
|
| from ml_training_debugger.code_templates import ( |
| generate_code_snippet, |
| validate_fix, |
| ) |
| from ml_training_debugger.graders import grade_episode |
| from ml_training_debugger.models import ( |
| ALL_ACTION_TYPES, |
| VALID_CONFIG_KEYS, |
| VALID_DIAGNOSES, |
| CodeSnippet, |
| DataBatchStats, |
| EpisodeState, |
| GradientStats, |
| MLTrainingAction, |
| MLTrainingObservation, |
| ModelWeightStats, |
| TrainingConfig, |
| ) |
| from ml_training_debugger.pytorch_engine import ( |
| create_model_and_inject_fault, |
| extract_gradient_stats, |
| extract_model_modes, |
| extract_weight_stats, |
| ) |
| from ml_training_debugger.reward_engine import compute_reward |
| from ml_training_debugger.scenarios import ScenarioParams, sample_scenario |
| from ml_training_debugger.simulation import ( |
| gen_data_batch_stats, |
| gen_loss_history, |
| gen_val_accuracy_history, |
| gen_val_loss_history, |
| ) |
| from server._baseline_results import store_grader_result |
|
|
# Module-level logger; structured fields are attached via the `extra` kwarg.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclasses.dataclass
class SessionData:
    """Per-session episode data.

    Bundles everything the environment tracks for one debugging episode:
    the sampled scenario, the PyTorch model with its injected fault, the
    mutable episode/config state, lazily revealed diagnostics, and the
    pre-generated training curves.

    Note: annotations use PEP 604 ``X | Y`` unions throughout (the file
    enables ``from __future__ import annotations``), replacing the previous
    mix of ``typing.Union`` and ``|`` syntax.
    """

    # Immutable parameters of the sampled scenario (task, seed, fault, ...).
    scenario: ScenarioParams
    # Real PyTorch model with the scenario's fault injected at reset().
    model: torch.nn.Module
    # Mutable per-episode progress flags / action history.
    state: EpisodeState
    # Training config the agent may mutate via modify_config.
    config: TrainingConfig
    # Filled on first inspect_gradients; empty list until then.
    gradient_stats: list[GradientStats]
    # Filled on first inspect_model_weights; None until then.
    weight_stats: list[ModelWeightStats] | None
    # Per-module mode map — presumably "train"/"eval" values; filled on
    # first inspect_model_modes (TODO confirm against extract_model_modes).
    model_modes: dict[str, str] | None
    # Raw kwargs for DataBatchStats; materialized lazily in observations.
    data_batch_stats_raw: dict[str, int | float | list | dict | None] | None
    # Raw kwargs for CodeSnippet; only present for bug_type scenarios.
    code_snippet_raw: dict[str, str | int | list | None] | None
    # Pre-generated (simulated) curves for the whole run.
    loss_history: list[float]
    val_acc_history: list[float]
    val_loss_history: list[float]
    # True once the episode has terminated (diagnosis or step budget).
    done: bool
    # Final grader score, set when the episode completes.
    last_score: float | None
    # Whether the fix applied before restart_run would have converged.
    convergence_after_fix: bool
|
|
|
class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation, dict]):
    """OpenEnv environment for PyTorch training run debugging.

    An episode is a simulated, already-finished training run with a hidden
    injected fault.  The agent progressively reveals evidence (gradients,
    data batches, model modes, weights, code), applies candidate fixes,
    and finally submits a diagnosis.  ``step()`` never raises an unhandled
    exception — every failure path is converted into an observation whose
    ``error_log`` explains what went wrong.
    """

    # NOTE(review): flagged as supporting concurrent sessions, but step()
    # always resolves self._current_session_id and never reads an
    # episode_id from its kwargs, so interleaved sessions on one instance
    # would observe each other's state — confirm the serving layer keeps
    # one active session per environment instance.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Live episodes, keyed by session id (caller-supplied or uuid4).
        self._sessions: dict[str, SessionData] = {}
        # Grader summaries ({score, task_id, steps}) of finished episodes.
        self._last_completed: dict[str, dict] = {}
        # Session created by the most recent reset(); "" before first reset.
        self._current_session_id: str = ""

    def _get_session(self, episode_id: str | None = None) -> SessionData | None:
        """Return the session for *episode_id*, defaulting to the current one."""
        sid = episode_id or self._current_session_id
        return self._sessions.get(sid)

    def _build_observation(
        self, session: SessionData, reward: float = 0.0
    ) -> MLTrainingObservation:
        """Build observation from session data.

        Implements progressive information reveal: each diagnostic
        (gradients, weights, data stats, model modes, code) is included
        only after the corresponding inspect_* action has been taken.
        """
        state = session.state

        # Gradient stats: revealed only after inspect_gradients populated them.
        gradient_stats_models = []
        if state.gradients_inspected and session.gradient_stats:
            gradient_stats_models = session.gradient_stats

        # Weight stats: revealed only after inspect_model_weights.
        weight_stats_models = None
        if state.model_weights_inspected and session.weight_stats is not None:
            weight_stats_models = session.weight_stats

        # Data batch stats: raw dict is materialized into the model lazily,
        # once per observation, only after inspect_data_batch.
        data_batch = None
        if state.data_inspected and session.data_batch_stats_raw is not None:
            data_batch = DataBatchStats(**session.data_batch_stats_raw)

        # Per-module mode info: revealed after inspect_model_modes.
        model_modes = None
        if state.model_modes_inspected and session.model_modes is not None:
            model_modes = session.model_modes

        # Code snippet: only scenarios with a bug_type ever have one.
        code_snippet = None
        if state.code_inspected and session.code_snippet_raw is not None:
            code_snippet = CodeSnippet(**session.code_snippet_raw)

        return MLTrainingObservation(
            run_id=self._current_session_id,
            framework="pytorch",
            # NOTE(review): epoch is hard-coded to 20 — presumably the
            # simulated run always spans 20 epochs; confirm against the
            # history generators in `simulation`.
            epoch=20,
            training_loss_history=session.loss_history,
            val_loss_history=session.val_loss_history,
            val_accuracy_history=session.val_acc_history,
            gradient_stats=gradient_stats_models,
            model_weight_stats=weight_stats_models,
            gpu_memory_used_gb=session.scenario.gpu_memory_used_gb,
            gpu_memory_total_gb=16.0,  # fixed simulated GPU capacity
            learning_rate=session.config.learning_rate,
            current_config=session.config,
            error_log=session.scenario.error_log,
            data_batch_stats=data_batch,
            model_mode_info=model_modes,
            code_snippet=code_snippet,
            available_actions=state.compute_available_actions(),
            episode_state=state,
            notes=session.scenario.notes,
            done=session.done,
            reward=reward,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> MLTrainingObservation:
        """Reset environment for a new episode.

        Args:
            seed: base seed for scenario sampling (defaults to 42).
            episode_id: optional caller-chosen session id; a uuid4 is
                minted when absent.
            **kwargs: ``task_id`` (default "task_001") and
                ``difficulty_level`` (default 3) are read from here.

        Returns:
            The initial observation for the freshly sampled scenario.
        """
        task_id = kwargs.get("task_id", "task_001")

        session_id = episode_id or str(uuid.uuid4())
        if session_id in self._sessions:
            # Re-resetting an unfinished session: grade the abandoned
            # episode first so its result is not silently lost.
            old = self._sessions[session_id]
            if not old.done:
                score = grade_episode(old.scenario.task_id, old.state, old.scenario)
                self._last_completed[session_id] = {
                    "score": score,
                    "task_id": old.scenario.task_id,
                    "steps": old.state.step_count,
                }
                store_grader_result(
                    session_id, score, old.scenario.task_id, old.state.step_count
                )

        self._current_session_id = session_id

        # Sample the scenario deterministically from seed + difficulty.
        base_seed = seed if seed is not None else 42
        difficulty_level = kwargs.get("difficulty_level", 3)
        scenario = sample_scenario(task_id, base_seed, difficulty_level=difficulty_level)

        # Seed torch so fault injection and stat extraction are reproducible.
        torch.manual_seed(scenario.seed)

        # Build the real PyTorch model with the scenario's fault injected.
        # NOTE(review): `info` is currently unused — confirm whether the
        # injection metadata should be surfaced anywhere.
        model, info = create_model_and_inject_fault(scenario)

        # Pre-generate the full (simulated) training curves for the run.
        loss_history = gen_loss_history(scenario)
        val_acc_history = gen_val_accuracy_history(scenario)
        val_loss_history = gen_val_loss_history(scenario)

        data_batch_raw = gen_data_batch_stats(scenario)

        # Only buggy-code scenarios expose a snippet to inspect/fix.
        code_snippet_raw = None
        if scenario.bug_type is not None:
            code_snippet_raw = generate_code_snippet(scenario.bug_type, scenario.seed)

        # Mutable config the agent can edit via modify_config.
        config = TrainingConfig(
            learning_rate=scenario.learning_rate,
            weight_decay=scenario.weight_decay,
        )

        state = EpisodeState()

        session = SessionData(
            scenario=scenario,
            model=model,
            state=state,
            config=config,
            gradient_stats=[],
            weight_stats=None,
            model_modes=None,
            data_batch_stats_raw=data_batch_raw,
            code_snippet_raw=code_snippet_raw,
            loss_history=loss_history,
            val_acc_history=val_acc_history,
            val_loss_history=val_loss_history,
            done=False,
            last_score=None,
            convergence_after_fix=False,
        )

        self._sessions[session_id] = session

        logger.info(
            "reset",
            extra={
                "session_id": session_id,
                "task_id": task_id,
                "scenario_seed": scenario.seed,
            },
        )

        return self._build_observation(session)

    def step(
        self,
        action: MLTrainingAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> MLTrainingObservation:
        """Process one agent action. Never raises.

        Every failure mode (no session, invalid/unavailable/malformed
        action, handler exception) returns an observation with an
        explanatory ``error_log`` instead of propagating an exception.
        NOTE(review): ``timeout_s`` is accepted but not used here —
        presumably enforced by the transport layer; confirm.
        """
        session = self._get_session()

        # No reset() yet (or session evicted): return a terminal error obs.
        if session is None:
            return MLTrainingObservation(
                done=True,
                reward=0.0,
                error_log="Error: no active episode. Call reset(task_id) first.",
            )

        # Episode already over: echo the final state with zero reward.
        if session.done:
            return self._build_observation(session, reward=0.0)

        state = session.state
        scenario = session.scenario
        action_type = action.action_type

        # Step budget is consumed even by invalid/malformed actions —
        # the increment deliberately precedes all validation below.
        state.step_count += 1

        # 1) Unknown action type.
        if action_type not in ALL_ACTION_TYPES:
            reward = compute_reward(action, state, scenario, is_valid_action=False)
            state.actions_taken.append(f"invalid:{action_type}")
            obs = self._build_observation(session, reward=reward)
            obs.error_log = (
                f"Invalid action_type: {action_type}. "
                f"Valid types: {sorted(ALL_ACTION_TYPES)}"
            )
            return obs

        # 2) Known but not currently available (gated by episode state).
        available = state.compute_available_actions()
        if action_type not in available:
            reward = compute_reward(action, state, scenario, is_valid_action=False)
            state.actions_taken.append(f"unavailable:{action_type}")
            obs = self._build_observation(session, reward=reward)
            obs.error_log = (
                f"Action '{action_type}' not available. " f"Available: {available}"
            )
            return obs

        # 3) Available but missing/invalid required fields.
        error = self._validate_action_fields(action)
        if error is not None:
            reward = compute_reward(action, state, scenario, is_valid_action=False)
            state.actions_taken.append(f"malformed:{action_type}")
            obs = self._build_observation(session, reward=reward)
            obs.error_log = error
            return obs

        is_correct_fix: bool | None = None
        convergence = False

        # Snapshot the pre-action state: compute_reward() below is fed the
        # state as it was BEFORE the handler mutated it.
        state_before = state.model_copy(deep=True)

        try:
            is_correct_fix, convergence = self._dispatch_action(action, session)
        except Exception as exc:
            # "Never raises" guarantee: convert handler failures into an
            # error observation, rewarded as an invalid action.
            logger.error(
                "step_error",
                extra={
                    "session_id": self._current_session_id,
                    "action": action_type,
                    "error": str(exc),
                },
                exc_info=True,
            )
            reward = compute_reward(action, state_before, scenario, is_valid_action=False)
            obs = self._build_observation(session, reward=reward)
            obs.error_log = f"Internal error processing {action_type}: {exc}"
            return obs

        # Record the action (diagnoses carry their payload in the label).
        if action_type == "mark_diagnosed" and action.diagnosis:
            state.actions_taken.append(f"mark_diagnosed:{action.diagnosis}")
        else:
            state.actions_taken.append(action_type)

        # Reward is computed against the pre-action snapshot.
        reward = compute_reward(
            action,
            state_before,
            scenario,
            is_valid_action=True,
            is_correct_fix=is_correct_fix,
            convergence_confirmed=convergence,
        )

        # Terminate on step-budget exhaustion (if not already done).
        if state.step_count >= scenario.max_steps and not session.done:
            session.done = True

        # On termination (diagnosis or budget): grade and persist the result.
        if session.done:
            score = grade_episode(scenario.task_id, state, scenario)
            session.last_score = score
            self._last_completed[self._current_session_id] = {
                "score": score,
                "task_id": scenario.task_id,
                "steps": state.step_count,
            }
            store_grader_result(
                self._current_session_id, score, scenario.task_id, state.step_count
            )
            logger.info(
                "episode_completed",
                extra={
                    "session_id": self._current_session_id,
                    "task_id": scenario.task_id,
                    "steps": state.step_count,
                    "score": score,
                },
            )

        logger.info(
            "step",
            extra={
                "session_id": self._current_session_id,
                "step_count": state.step_count,
                "action_type": action_type,
                "reward": reward,
            },
        )

        return self._build_observation(session, reward=reward)

    def _validate_action_fields(self, action: MLTrainingAction) -> str | None:
        """Validate required fields for specific actions. Return error or None.

        Only modify_config, mark_diagnosed and fix_code carry payloads;
        every other action type passes trivially.
        """
        if action.action_type == "modify_config":
            if action.target is None or action.value is None:
                return "modify_config requires 'target' and 'value' fields"
            if action.target not in VALID_CONFIG_KEYS:
                return f"Unknown config key: {action.target}. Valid: {sorted(VALID_CONFIG_KEYS)}"

        if action.action_type == "mark_diagnosed":
            if action.diagnosis is None:
                return "mark_diagnosed requires 'diagnosis' field"
            if action.diagnosis not in VALID_DIAGNOSES:
                return (
                    f"Invalid diagnosis: {action.diagnosis}. "
                    f"Valid: {sorted(VALID_DIAGNOSES)}"
                )

        if action.action_type == "fix_code":
            if action.line is None or action.replacement is None:
                return "fix_code requires 'line' and 'replacement' fields"

        return None

    def _dispatch_action(
        self, action: MLTrainingAction, session: SessionData
    ) -> tuple[bool | None, bool]:
        """Dispatch action to handler. Returns (is_correct_fix, convergence).

        Inspect actions are idempotent: diagnostics are extracted from the
        real model only on first use and cached on the session.
        ``is_correct_fix`` stays None for non-fix_code actions.
        """
        state = session.state
        scenario = session.scenario
        is_correct_fix: bool | None = None
        convergence = False

        at = action.action_type

        if at == "inspect_gradients":
            if not state.gradients_inspected:
                stats = extract_gradient_stats(session.model, scenario)
                session.gradient_stats = stats
                state.gradients_inspected = True
                # Remember whether gradients looked healthy (no explosion)
                # for downstream reward/grading logic.
                state.gradients_were_normal = all(not s.is_exploding for s in stats)

        elif at == "inspect_data_batch":
            # Stats were pre-generated at reset(); just unlock the reveal.
            state.data_inspected = True

        elif at == "inspect_model_modes":
            if not state.model_modes_inspected:
                modes = extract_model_modes(session.model)
                session.model_modes = modes
                state.model_modes_inspected = True

        elif at == "inspect_model_weights":
            if not state.model_weights_inspected:
                stats = extract_weight_stats(session.model)
                session.weight_stats = stats
                state.model_weights_inspected = True

        elif at == "inspect_code":
            state.code_inspected = True

        elif at == "modify_config":
            # Fields were validated upstream; apply directly to the config.
            if action.target and action.value is not None:
                setattr(session.config, action.target, action.value)
            state.fix_action_taken = True

        elif at == "add_callback":
            state.fix_action_taken = True

        elif at == "replace_optimizer":
            state.fix_action_taken = True

        elif at == "patch_data_loader":
            state.fix_action_taken = True

        elif at == "fix_model_mode":
            state.fix_action_taken = True

        elif at == "fix_code":
            state.fix_action_taken = True
            if scenario.bug_type and action.line and action.replacement:
                is_correct_fix = validate_fix(
                    scenario.bug_type, action.line, action.replacement
                )
            else:
                # No bug in this scenario, or payload missing: a wrong fix.
                is_correct_fix = False

        elif at == "restart_run":
            state.restart_after_fix = True
            # Evaluate whether the fixes applied so far would have resolved
            # the scenario's root cause.
            convergence = self._check_convergence(session)
            session.convergence_after_fix = convergence

        elif at == "mark_diagnosed":
            # Submitting a diagnosis always terminates the episode.
            state.diagnosis_submitted = True
            session.done = True

        return is_correct_fix, convergence

    def _check_convergence(self, session: SessionData) -> bool:
        """Check if the applied fix would resolve the root cause.

        Rule-based per root cause; unknown causes default to False.
        NOTE(review): the 0.001 learning-rate thresholds below are
        presumably calibrated to the scenario sampler's LR ranges —
        confirm against `scenarios`.
        """
        scenario = session.scenario
        state = session.state
        root = scenario.root_cause.value

        # LR too high: must have lowered the LR to <= 1e-3.
        if root == "lr_too_high":
            return (
                "modify_config" in state.actions_taken
                and session.config.learning_rate <= 0.001
            )

        # Vanishing gradients: must have raised the LR to >= 1e-3.
        if root == "vanishing_gradients":
            return (
                "modify_config" in state.actions_taken
                and session.config.learning_rate >= 0.001
            )

        if root == "data_leakage":
            return "patch_data_loader" in state.actions_taken

        # Overfitting: either regularization change or a callback counts.
        if root == "overfitting":
            return (
                "modify_config" in state.actions_taken
                or "add_callback" in state.actions_taken
            )

        if root == "batchnorm_eval_mode":
            return "fix_model_mode" in state.actions_taken

        if root == "code_bug":
            return "fix_code" in state.actions_taken and state.fix_action_taken

        if root == "scheduler_misconfigured":
            return "modify_config" in state.actions_taken

        return False

    @property
    def state(self) -> dict:
        """Return current environment state.

        A JSON-friendly summary of the current session's progress flags,
        or ``{"status": "no_active_episode"}`` before the first reset().
        """
        session = self._get_session()
        if session is None:
            return {"status": "no_active_episode"}
        st = session.state
        return {
            "status": "active",
            "task_id": session.scenario.task_id,
            "step_count": st.step_count,
            "done": session.done,
            "gradients_inspected": st.gradients_inspected,
            "data_inspected": st.data_inspected,
            "model_modes_inspected": st.model_modes_inspected,
            "model_weights_inspected": st.model_weights_inspected,
            "code_inspected": st.code_inspected,
            "fix_action_taken": st.fix_action_taken,
            "restart_after_fix": st.restart_after_fix,
            "diagnosis_submitted": st.diagnosis_submitted,
            "available_actions": st.compute_available_actions(),
        }

    def get_last_completed(self, session_id: str | None = None) -> dict | None:
        """Get last completed episode data for grader endpoint.

        With *session_id*, look up that session's result; otherwise return
        the most recently recorded completion (dict insertion order), or
        None if nothing has completed yet.
        """
        if session_id:
            return self._last_completed.get(session_id)

        if self._last_completed:
            return list(self._last_completed.values())[-1]
        return None
|
|