"""MLTrainingEnvironment — extends openenv Environment.

Session isolation, progressive information reveal, error handling.
step() never raises an unhandled exception.
"""

from __future__ import annotations

import dataclasses
import logging
import uuid
from typing import Any, Optional, Union

import torch
from openenv.core.env_server.interfaces import Environment

from ml_training_debugger.code_templates import (
    generate_code_snippet,
    validate_fix,
)
from ml_training_debugger.graders import grade_episode
from ml_training_debugger.models import (
    ALL_ACTION_TYPES,
    VALID_CONFIG_KEYS,
    VALID_DIAGNOSES,
    CodeSnippet,
    DataBatchStats,
    EpisodeState,
    GradientStats,
    MLTrainingAction,
    MLTrainingObservation,
    ModelWeightStats,
    TrainingConfig,
)
from ml_training_debugger.pytorch_engine import (
    create_model_and_inject_fault,
    extract_gradient_stats,
    extract_model_modes,
    extract_weight_stats,
)
from ml_training_debugger.reward_engine import compute_reward
from ml_training_debugger.scenarios import ScenarioParams, sample_scenario
from ml_training_debugger.simulation import (
    gen_data_batch_stats,
    gen_loss_history,
    gen_val_accuracy_history,
    gen_val_loss_history,
)
from server._baseline_results import store_grader_result

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class SessionData:
    """Per-session episode data."""

    scenario: ScenarioParams
    model: torch.nn.Module
    state: EpisodeState
    config: TrainingConfig
    # Filled in by the matching inspect_* actions; empty/None until then.
    gradient_stats: list[GradientStats]
    weight_stats: list[ModelWeightStats] | None
    model_modes: dict[str, str] | None
    # Pre-generated at reset(), but only exposed in observations once the
    # matching inspect_* flag is set on the episode state.
    data_batch_stats_raw: dict[str, Union[int, float, list, dict, None]] | None
    code_snippet_raw: dict[str, Union[str, int, list, None]] | None
    loss_history: list[float]
    val_acc_history: list[float]
    val_loss_history: list[float]
    done: bool
    # Grader score; set when the episode is graded.
    last_score: float | None
    convergence_after_fix: bool


class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation, dict]):
    """OpenEnv environment for PyTorch training run debugging."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Fix actions that only flag that a fix was attempted; whether they
    # actually address the root cause is decided later by _check_convergence.
    _FLAG_ONLY_FIX_ACTIONS = frozenset(
        {"add_callback", "replace_optimizer", "patch_data_loader", "fix_model_mode"}
    )

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # episode_id -> episode data (live or finished).
        self._sessions: dict[str, SessionData] = {}
        # episode_id -> grader summary. Insertion order is kept equal to
        # completion order so get_last_completed() can return the newest one.
        self._last_completed: dict[str, dict] = {}
        # Episode that step() and `state` operate on; set by reset().
        self._current_session_id: str = ""

    def _get_session(self, episode_id: str | None = None) -> SessionData | None:
        """Return the session for *episode_id*, defaulting to the current one."""
        sid = episode_id or self._current_session_id
        return self._sessions.get(sid)

    def _build_observation(
        self, session: SessionData, reward: float = 0.0
    ) -> MLTrainingObservation:
        """Build the observation for *session*.

        Diagnostic details (gradient stats, weight stats, data batch stats,
        model modes, code snippet) are only included once the corresponding
        inspect action has set its flag on the episode state; until then they
        are left empty/None (progressive information reveal).
        """
        state = session.state

        gradient_stats_models = []
        if state.gradients_inspected and session.gradient_stats:
            gradient_stats_models = session.gradient_stats

        weight_stats_models = None
        if state.model_weights_inspected and session.weight_stats is not None:
            weight_stats_models = session.weight_stats

        data_batch = None
        if state.data_inspected and session.data_batch_stats_raw is not None:
            data_batch = DataBatchStats(**session.data_batch_stats_raw)

        model_modes = None
        if state.model_modes_inspected and session.model_modes is not None:
            model_modes = session.model_modes

        code_snippet = None
        if state.code_inspected and session.code_snippet_raw is not None:
            code_snippet = CodeSnippet(**session.code_snippet_raw)

        return MLTrainingObservation(
            run_id=self._current_session_id,
            framework="pytorch",
            epoch=20,
            training_loss_history=session.loss_history,
            val_loss_history=session.val_loss_history,
            val_accuracy_history=session.val_acc_history,
            gradient_stats=gradient_stats_models,
            model_weight_stats=weight_stats_models,
            gpu_memory_used_gb=session.scenario.gpu_memory_used_gb,
            gpu_memory_total_gb=16.0,
            learning_rate=session.config.learning_rate,
            current_config=session.config,
            error_log=session.scenario.error_log,
            data_batch_stats=data_batch,
            model_mode_info=model_modes,
            code_snippet=code_snippet,
            available_actions=state.compute_available_actions(),
            episode_state=state,
            notes=session.scenario.notes,
            done=session.done,
            reward=reward,
        )

    def _record_completion(self, session_id: str, session: SessionData) -> float:
        """Grade a finished episode and publish the result.

        Stores the score on the session, in ``_last_completed`` and via
        ``store_grader_result``. Returns the score.
        """
        scenario = session.scenario
        steps = session.state.step_count
        score = grade_episode(scenario.task_id, session.state, scenario)
        session.last_score = score
        # Remove first: re-assigning an existing dict key keeps its original
        # position, which would break get_last_completed()'s "most recent"
        # lookup when an episode_id is reused.
        self._last_completed.pop(session_id, None)
        self._last_completed[session_id] = {
            "score": score,
            "task_id": scenario.task_id,
            "steps": steps,
        }
        store_grader_result(session_id, score, scenario.task_id, steps)
        return score

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> MLTrainingObservation:
        """Reset environment for a new episode.

        Args:
            seed: Base seed passed to ``sample_scenario``; defaults to 42.
            episode_id: Session key to use. A fresh UUID is generated when
                omitted. If it names an unfinished episode, that episode is
                graded and closed before being replaced.
            **kwargs: ``task_id`` (default ``"task_001"``) and
                ``difficulty_level`` (default ``3``) select the scenario.

        Returns:
            The initial observation of the new episode.
        """
        # Determine task_id — passed via kwargs or defaults to task_001
        task_id = kwargs.get("task_id", "task_001")

        # If called with episode_id that has an active session, terminate it
        session_id = episode_id or str(uuid.uuid4())
        old = self._sessions.get(session_id)
        if old is not None and not old.done:
            # Close it before grading so it can never be stepped or graded a
            # second time, even if building the replacement below fails.
            old.done = True
            self._record_completion(session_id, old)

        self._current_session_id = session_id

        # Derive deterministic seed and difficulty
        base_seed = seed if seed is not None else 42
        difficulty_level = kwargs.get("difficulty_level", 3)
        scenario = sample_scenario(task_id, base_seed, difficulty_level=difficulty_level)

        # Set torch seed for reproducibility
        torch.manual_seed(scenario.seed)

        # Create real PyTorch model with fault injection
        model, info = create_model_and_inject_fault(scenario)

        # Generate parametric curves
        loss_history = gen_loss_history(scenario)
        val_acc_history = gen_val_accuracy_history(scenario)
        val_loss_history = gen_val_loss_history(scenario)

        # Pre-generate data batch stats
        data_batch_raw = gen_data_batch_stats(scenario)

        # Pre-generate code snippet (for Task 6)
        code_snippet_raw = None
        if scenario.bug_type is not None:
            code_snippet_raw = generate_code_snippet(scenario.bug_type, scenario.seed)

        # Build initial config from scenario
        config = TrainingConfig(
            learning_rate=scenario.learning_rate,
            weight_decay=scenario.weight_decay,
        )

        session = SessionData(
            scenario=scenario,
            model=model,
            state=EpisodeState(),
            config=config,
            gradient_stats=[],
            weight_stats=None,
            model_modes=None,
            data_batch_stats_raw=data_batch_raw,
            code_snippet_raw=code_snippet_raw,
            loss_history=loss_history,
            val_acc_history=val_acc_history,
            val_loss_history=val_loss_history,
            done=False,
            last_score=None,
            convergence_after_fix=False,
        )
        self._sessions[session_id] = session

        logger.info(
            "reset",
            extra={
                "session_id": session_id,
                "task_id": task_id,
                "scenario_seed": scenario.seed,
            },
        )

        return self._build_observation(session)

    def step(
        self,
        action: MLTrainingAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> MLTrainingObservation:
        """Process one agent action. Never raises.

        Every call on a live episode, including rejected actions, consumes one
        step of the scenario's budget. The episode ends (and is graded) on
        ``mark_diagnosed`` or once ``max_steps`` is reached.

        Args:
            action: The agent's action.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The next observation. Problems are reported through
            ``error_log`` instead of being raised.
        """
        try:
            return self._step_impl(action)
        except Exception as exc:
            # Last-resort boundary: anything that escaped the targeted
            # handling below (reward, grading, observation building) is
            # logged and turned into an error observation.
            logger.error(
                "step_unhandled_error",
                extra={
                    "session_id": self._current_session_id,
                    "error": str(exc),
                },
                exc_info=True,
            )
            session = self._get_session()
            return MLTrainingObservation(
                done=True if session is None else session.done,
                reward=0.0,
                error_log=f"Internal error processing step: {exc}",
            )

    def _step_impl(self, action: MLTrainingAction) -> MLTrainingObservation:
        """Validate, dispatch and reward one action (may raise; see step())."""
        session = self._get_session()

        # No active episode
        if session is None:
            return MLTrainingObservation(
                done=True,
                reward=0.0,
                error_log="Error: no active episode. Call reset(task_id) first.",
            )

        # Episode already done
        if session.done:
            return self._build_observation(session, reward=0.0)

        state = session.state
        scenario = session.scenario
        action_type = action.action_type

        # Every call counts against the step budget, including rejected ones.
        state.step_count += 1

        # Validate action_type is a known type
        if action_type not in ALL_ACTION_TYPES:
            return self._reject_action(
                action,
                session,
                f"invalid:{action_type}",
                f"Invalid action_type: {action_type}. "
                f"Valid types: {sorted(ALL_ACTION_TYPES)}",
            )

        # Check if action is in available_actions
        available = state.compute_available_actions()
        if action_type not in available:
            return self._reject_action(
                action,
                session,
                f"unavailable:{action_type}",
                f"Action '{action_type}' not available. "
                f"Available: {available}",
            )

        # Validate required fields for specific actions
        error = self._validate_action_fields(action)
        if error is not None:
            return self._reject_action(action, session, f"malformed:{action_type}", error)

        # Snapshot state BEFORE dispatch — reward engine needs pre-action state
        # to correctly compute investigation bonuses and context-gated penalties
        state_before = state.model_copy(deep=True)

        try:
            is_correct_fix, convergence = self._dispatch_action(action, session)
        except Exception as exc:
            logger.error(
                "step_error",
                extra={
                    "session_id": self._current_session_id,
                    "action": action_type,
                    "error": str(exc),
                },
                exc_info=True,
            )
            reward = compute_reward(action, state_before, scenario, is_valid_action=False)
            # A failed step still consumed budget, so the limit applies here too.
            self._finalize_if_done(session)
            obs = self._build_observation(session, reward=reward)
            obs.error_log = f"Internal error processing {action_type}: {exc}"
            return obs

        # Record action
        if action_type == "mark_diagnosed" and action.diagnosis:
            state.actions_taken.append(f"mark_diagnosed:{action.diagnosis}")
        else:
            state.actions_taken.append(action_type)

        # Compute reward using pre-action state
        reward = compute_reward(
            action,
            state_before,
            scenario,
            is_valid_action=True,
            is_correct_fix=is_correct_fix,
            convergence_confirmed=convergence,
        )

        self._finalize_if_done(session)

        logger.info(
            "step",
            extra={
                "session_id": self._current_session_id,
                "step_count": state.step_count,
                "action_type": action_type,
                "reward": reward,
            },
        )

        return self._build_observation(session, reward=reward)

    def _reject_action(
        self,
        action: MLTrainingAction,
        session: SessionData,
        record: str,
        error_log: str,
    ) -> MLTrainingObservation:
        """Penalize an action that failed validation and report why.

        *record* is appended to ``actions_taken``. *error_log* replaces the
        observation's error log.
        """
        reward = compute_reward(
            action, session.state, session.scenario, is_valid_action=False
        )
        session.state.actions_taken.append(record)
        # Rejected actions consume budget too, so they can end the episode.
        self._finalize_if_done(session)
        obs = self._build_observation(session, reward=reward)
        obs.error_log = error_log
        return obs

    def _finalize_if_done(self, session: SessionData) -> None:
        """Close the episode if its step budget is spent, and grade it if over.

        Only call this for a session that was not done when the current step
        began, so that each episode is graded exactly once.
        """
        state = session.state
        scenario = session.scenario

        # Check step limit
        if state.step_count >= scenario.max_steps:
            session.done = True

        if not session.done:
            return

        score = self._record_completion(self._current_session_id, session)
        logger.info(
            "episode_completed",
            extra={
                "session_id": self._current_session_id,
                "task_id": scenario.task_id,
                "steps": state.step_count,
                "score": score,
            },
        )

    def _validate_action_fields(self, action: MLTrainingAction) -> str | None:
        """Validate required fields for specific actions. Return error or None."""
        if action.action_type == "modify_config":
            if action.target is None or action.value is None:
                return "modify_config requires 'target' and 'value' fields"
            if action.target not in VALID_CONFIG_KEYS:
                return f"Unknown config key: {action.target}. Valid: {sorted(VALID_CONFIG_KEYS)}"

        if action.action_type == "mark_diagnosed":
            if action.diagnosis is None:
                return "mark_diagnosed requires 'diagnosis' field"
            if action.diagnosis not in VALID_DIAGNOSES:
                return (
                    f"Invalid diagnosis: {action.diagnosis}. "
                    f"Valid: {sorted(VALID_DIAGNOSES)}"
                )

        if action.action_type == "fix_code":
            if action.line is None or action.replacement is None:
                return "fix_code requires 'line' and 'replacement' fields"

        return None

    def _dispatch_action(
        self, action: MLTrainingAction, session: SessionData
    ) -> tuple[bool | None, bool]:
        """Dispatch action to handler.

        Returns:
            ``(is_correct_fix, convergence)``. ``is_correct_fix`` is only
            set (True/False) for ``fix_code``. ``convergence`` is only
            computed for ``restart_run``.
        """
        state = session.state
        scenario = session.scenario
        is_correct_fix: bool | None = None
        convergence = False

        at = action.action_type

        if at == "inspect_gradients":
            # Extract once; repeated inspections reuse the cached stats.
            if not state.gradients_inspected:
                stats = extract_gradient_stats(session.model, scenario)
                session.gradient_stats = stats
                state.gradients_inspected = True
                # Set gradients_were_normal: True if ALL layers is_exploding=False
                state.gradients_were_normal = all(not s.is_exploding for s in stats)

        elif at == "inspect_data_batch":
            state.data_inspected = True

        elif at == "inspect_model_modes":
            if not state.model_modes_inspected:
                session.model_modes = extract_model_modes(session.model)
                state.model_modes_inspected = True

        elif at == "inspect_model_weights":
            if not state.model_weights_inspected:
                session.weight_stats = extract_weight_stats(session.model)
                state.model_weights_inspected = True

        elif at == "inspect_code":
            state.code_inspected = True

        elif at == "modify_config":
            if action.target and action.value is not None:
                setattr(session.config, action.target, action.value)
            state.fix_action_taken = True

        elif at in self._FLAG_ONLY_FIX_ACTIONS:
            state.fix_action_taken = True

        elif at == "fix_code":
            state.fix_action_taken = True
            if scenario.bug_type and action.line and action.replacement:
                is_correct_fix = validate_fix(
                    scenario.bug_type, action.line, action.replacement
                )
            else:
                is_correct_fix = False

        elif at == "restart_run":
            state.restart_after_fix = True
            # Check convergence — did the fix address the root cause?
            convergence = self._check_convergence(session)
            session.convergence_after_fix = convergence

        elif at == "mark_diagnosed":
            state.diagnosis_submitted = True
            session.done = True

        return is_correct_fix, convergence

    def _check_convergence(self, session: SessionData) -> bool:
        """Check if the applied fix would resolve the root cause.

        Based on the scenario's ``root_cause`` and on which fix actions appear
        in ``actions_taken``. For some causes it also checks the resulting
        learning rate. Unknown root causes never converge.
        """
        scenario = session.scenario
        state = session.state
        root = scenario.root_cause.value

        if root == "lr_too_high":
            return (
                "modify_config" in state.actions_taken
                and session.config.learning_rate <= 0.001
            )
        if root == "vanishing_gradients":
            return (
                "modify_config" in state.actions_taken
                and session.config.learning_rate >= 0.001
            )
        if root == "data_leakage":
            return "patch_data_loader" in state.actions_taken
        if root == "overfitting":
            return (
                "modify_config" in state.actions_taken
                or "add_callback" in state.actions_taken
            )
        if root == "batchnorm_eval_mode":
            return "fix_model_mode" in state.actions_taken
        if root == "code_bug":
            # NOTE(review): any fix_code counts here, even one validate_fix
            # rejected — confirm this is the intended convergence rule.
            return "fix_code" in state.actions_taken and state.fix_action_taken
        if root == "scheduler_misconfigured":
            return "modify_config" in state.actions_taken

        return False

    @property
    def state(self) -> dict:
        """Return current environment state."""
        session = self._get_session()
        if session is None:
            return {"status": "no_active_episode"}

        st = session.state
        return {
            "status": "active",
            "task_id": session.scenario.task_id,
            "step_count": st.step_count,
            "done": session.done,
            "gradients_inspected": st.gradients_inspected,
            "data_inspected": st.data_inspected,
            "model_modes_inspected": st.model_modes_inspected,
            "model_weights_inspected": st.model_weights_inspected,
            "code_inspected": st.code_inspected,
            "fix_action_taken": st.fix_action_taken,
            "restart_after_fix": st.restart_after_fix,
            "diagnosis_submitted": st.diagnosis_submitted,
            "available_actions": st.compute_available_actions(),
        }

    def get_last_completed(self, session_id: str | None = None) -> dict | None:
        """Get last completed episode data for grader endpoint.

        Args:
            session_id: Episode to look up. When falsy, the most recently
                completed episode is returned.

        Returns:
            The ``{"score", "task_id", "steps"}`` summary, or None if there
            is no matching completed episode.
        """
        if session_id:
            return self._last_completed.get(session_id)
        if not self._last_completed:
            return None
        # Insertion order == completion order (see _record_completion).
        return next(reversed(self._last_completed.values()))