"""AX Platform optimizer backend for physics-informed BO."""

import warnings
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from physics_informed_bo.config import OptimizationConfig
from physics_informed_bo.optimizers.base_optimizer import BaseOptimizer


class AXOptimizer(BaseOptimizer):
    """AX Platform backend for structured experiment design.

    AX provides a higher-level API for experiment management, including:
    - Structured parameter spaces with constraints
    - Multi-objective optimization
    - Human-in-the-loop experiments
    - Automatic model selection

    Candidate generation is delegated to an ``AxClient`` created by
    ``setup_experiment()``. Every candidate AX proposes is passed through
    ``self._filter_feasible`` (inherited from ``BaseOptimizer``) before it
    is returned. Trials whose candidates are rejected by that filter are
    abandoned in AX.

    NOTE(review): earlier documentation said the physics model is injected
    as a custom BoTorch model generator inside AX. Nothing in this class
    does that; AX runs its default generation strategy. Confirm whether
    that integration is meant to live elsewhere.
    """

    def __init__(self, config: OptimizationConfig):
        super().__init__(config)
        # Not used within this class; kept in case external code reads them.
        self._experiment = None
        self._gs = None
        # AxClient instance; None until setup_experiment() succeeds.
        self._ax_client = None
        self._parameter_names: List[str] = []
        # Metric name used as the raw_data key in update().
        self._objective_name: str = "objective"
        # AX trial index for each row returned by the latest suggest() call.
        # None marks a returned row that could not be matched to a trial.
        self._last_trial_indices: List[Optional[int]] = []

    def setup_experiment(
        self,
        parameter_names: List[str],
        bounds: Dict[str, Tuple[float, float]],
        objective_name: str = "objective",
        minimize: bool = False,
        outcome_constraints: Optional[List[str]] = None,
    ) -> None:
        """Set up an AX experiment over a box of float range parameters.

        Args:
            parameter_names: Names of input parameters. This order defines
                the column order of tensors returned by ``suggest()``.
            bounds: Dict of {param_name: (lower, upper)}. Must contain
                every name in ``parameter_names``.
            objective_name: Name of the objective metric. ``update()``
                reports observations under this name.
            minimize: Whether to minimize (True) or maximize (False).
            outcome_constraints: List of outcome constraint strings,
                e.g. ["metric >= 0.5"], forwarded to AX unchanged.
                NOTE(review): ``update()`` only reports the objective
                metric, so data for constrained metrics must reach AX by
                some other route.

        Raises:
            ImportError: If ``ax-platform`` is not installed.
            KeyError: If a parameter name has no entry in ``bounds``.
        """
        try:
            from ax.service.ax_client import AxClient
            from ax.service.utils.instantiation import ObjectiveProperties
        except ImportError as err:
            raise ImportError(
                "AX Platform is required for AXOptimizer. "
                "Install with: pip install ax-platform"
            ) from err

        # Build parameter list for AX.
        parameters = []
        for name in parameter_names:
            lb, ub = bounds[name]
            parameters.append(
                {
                    "name": name,
                    "type": "range",
                    "bounds": [float(lb), float(ub)],
                    "value_type": "float",
                }
            )

        ax_client = AxClient(verbose_logging=False)
        ax_client.create_experiment(
            name="physics_informed_bo",
            parameters=parameters,
            objectives={
                objective_name: ObjectiveProperties(minimize=minimize),
            },
            outcome_constraints=outcome_constraints,
        )

        # Commit state only after AX accepted the experiment definition.
        self._ax_client = ax_client
        self._parameter_names = list(parameter_names)
        self._objective_name = objective_name
        self._last_trial_indices = []

    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiments using AX.

        Requests ``n_candidates`` trials from AX and drops any rejected by
        ``self._filter_feasible``. The AX trial behind each returned row is
        remembered so ``update()`` can report results against it.
        ``X_observed`` and ``y_observed`` are accepted for interface
        compatibility but are not used; AX keeps its own trial data.

        Returns:
            float64 tensor of shape (k, len(parameter_names)), where
            k <= n_candidates. Columns follow the ``parameter_names`` order.

        Raises:
            RuntimeError: If ``setup_experiment()`` has not been called.
        """
        if self._ax_client is None:
            raise RuntimeError("Call setup_experiment() before suggesting.")

        rows: List[List[float]] = []
        trial_indices: List[int] = []
        for _ in range(n_candidates):
            parameters, trial_index = self._ax_client.get_next_trial()
            rows.append([parameters[name] for name in self._parameter_names])
            trial_indices.append(trial_index)

        # Explicit reshape keeps a 2-D (0, d) shape when no rows were drawn.
        raw = torch.tensor(rows, dtype=torch.float64).reshape(
            len(rows), len(self._parameter_names)
        )

        # Filter through physics constraints.
        result = self._filter_feasible(raw)[:n_candidates]
        self._last_trial_indices = self._match_trials(raw, trial_indices, result)
        return result

    def _match_trials(
        self, raw: Tensor, trial_indices: List[int], kept: Tensor
    ) -> List[Optional[int]]:
        """Map each kept row to the AX trial that generated it and abandon the rest.

        Rows are matched by value (``torch.allclose``) because
        ``_filter_feasible`` may drop rows. Positional pairing would then
        attribute results to the wrong trials. A kept row with no matching
        raw row maps to None.
        """
        unused = list(range(raw.shape[0]))
        matched: List[Optional[int]] = []
        for row in kept:
            row64 = row.detach().to(device=raw.device, dtype=torch.float64)
            hit = next((j for j in unused if torch.allclose(row64, raw[j])), None)
            if hit is None:
                matched.append(None)
            else:
                unused.remove(hit)
                matched.append(trial_indices[hit])

        # Trials whose candidates were not returned would otherwise stay
        # RUNNING in AX indefinitely.
        for j in unused:
            self._ax_client.abandon_trial(
                trial_index=trial_indices[j],
                reason="Rejected by physics feasibility filter.",
            )
        return matched

    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Report observations back to AX.

        Observation ``i`` is credited to the AX trial behind row ``i`` of
        the tensor returned by the most recent ``suggest()`` call. The
        values in ``X_new`` are not inspected; it only bounds how many
        pairs are read. Each ``y`` value is reported under the objective
        name given to ``setup_experiment()``, with an SEM of 0.0.

        Does nothing if ``setup_experiment()`` has not been called. Emits a
        ``RuntimeWarning`` for observations with no matching trial.
        """
        if self._ax_client is None:
            return

        skipped = 0
        for i, (_, y_val) in enumerate(zip(X_new, y_new)):
            trial_index = (
                self._last_trial_indices[i]
                if i < len(self._last_trial_indices)
                else None
            )
            if trial_index is None:
                skipped += 1
                continue
            self._ax_client.complete_trial(
                trial_index=trial_index,
                raw_data={self._objective_name: (float(y_val), 0.0)},
            )

        if skipped:
            warnings.warn(
                f"{skipped} observation(s) had no matching AX trial from the "
                "last suggest() call and were not reported to AX.",
                RuntimeWarning,
                stacklevel=2,
            )

    def get_best_parameters(self) -> Dict:
        """Get the best parameters found so far.

        Returns:
            {"parameters": ..., "values": ...}, unpacked from
            ``AxClient.get_best_parameters()``.

        Raises:
            RuntimeError: If ``setup_experiment()`` has not been called, or
                if AX cannot report a best point yet.
        """
        if self._ax_client is None:
            raise RuntimeError("Call setup_experiment() before querying results.")
        best = self._ax_client.get_best_parameters()
        if best is None:
            raise RuntimeError(
                "AX has no best parameters yet; report observations with "
                "update() first."
            )
        best_params, values = best
        return {"parameters": best_params, "values": values}

    def get_trials_dataframe(self):
        """Get all trials as a pandas DataFrame.

        Raises:
            RuntimeError: If ``setup_experiment()`` has not been called.
        """
        if self._ax_client is None:
            raise RuntimeError("Call setup_experiment() before querying results.")
        return self._ax_client.get_trials_data_frame()