| """AX Platform optimizer backend for physics-informed BO."""
|
|
|
import warnings
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from physics_informed_bo.config import OptimizationConfig
from physics_informed_bo.optimizers.base_optimizer import BaseOptimizer
|
|
|
|
|
class AXOptimizer(BaseOptimizer):
    """AX Platform backend for structured experiment design.

    AX provides a higher-level API for experiment management, including:

    - Structured parameter spaces with constraints
    - Multi-objective optimization
    - Human-in-the-loop experiments
    - Automatic model selection

    This backend drives an ``AxClient`` created with its default generation
    strategy (no ``generation_strategy`` is passed). NOTE: this class does not
    itself inject a physics-informed model; a custom BoTorch model generator
    would have to be wired in through a custom generation strategy in
    ``setup_experiment``.

    Typical usage::

        opt = AXOptimizer(config)
        opt.setup_experiment(["x1", "x2"], {"x1": (0, 1), "x2": (0, 1)})
        X = opt.suggest(n_candidates=2)
        opt.update(X, y)  # y evaluated by the caller, row-aligned with X
    """

    def __init__(self, config: OptimizationConfig):
        super().__init__(config)
        # Not read anywhere in this class; kept so external references still work.
        self._experiment = None
        self._gs = None
        # AxClient created by setup_experiment(); None until then.
        self._ax_client = None
        self._parameter_names: List[str] = []
        self._objective_name: str = "objective"
        # Ax trial indices of the rows returned by the last suggest() call that
        # have not been reported through update() yet, in row order.
        self._last_trial_indices: List[int] = []

    def _require_client(self, action: str):
        """Return the AxClient, or raise RuntimeError naming ``action`` if unset."""
        if self._ax_client is None:
            raise RuntimeError(f"Call setup_experiment() before {action}.")
        return self._ax_client

    def setup_experiment(
        self,
        parameter_names: List[str],
        bounds: Dict[str, Tuple[float, float]],
        objective_name: str = "objective",
        minimize: bool = False,
        outcome_constraints: Optional[List[str]] = None,
    ) -> None:
        """Set up an AX experiment with one float range parameter per input.

        Calling this again replaces the previous experiment and forgets any
        trials still awaiting results.

        Args:
            parameter_names: Names of input parameters. This order defines the
                column order of tensors returned by ``suggest``.
            bounds: Dict of {param_name: (lower, upper)}. Must contain every
                name in ``parameter_names``.
            objective_name: Name of the objective metric. ``update`` reports
                observations under this name.
            minimize: Whether to minimize (True) or maximize (False).
            outcome_constraints: List of outcome constraint strings,
                e.g. ["metric >= 0.5"], forwarded to
                ``AxClient.create_experiment``. NOTE: ``update`` only reports
                the objective metric, so values for constrained metrics are
                not supplied by this class.

        Raises:
            ImportError: If ``ax-platform`` is not installed.
            KeyError: If a name in ``parameter_names`` is missing from ``bounds``.
        """
        try:
            from ax.service.ax_client import AxClient
            from ax.service.utils.instantiation import ObjectiveProperties
        except ImportError as err:
            raise ImportError(
                "AX Platform is required for AXOptimizer. "
                "Install with: pip install ax-platform"
            ) from err

        parameters = []
        for name in parameter_names:
            lb, ub = bounds[name]
            parameters.append(
                {
                    "name": name,
                    "type": "range",
                    "bounds": [float(lb), float(ub)],
                    "value_type": "float",
                }
            )

        ax_client = AxClient(verbose_logging=False)
        ax_client.create_experiment(
            name="physics_informed_bo",
            parameters=parameters,
            objectives={
                objective_name: ObjectiveProperties(minimize=minimize),
            },
            outcome_constraints=outcome_constraints,
        )

        # Commit state only once the experiment was created successfully.
        self._ax_client = ax_client
        self._parameter_names = list(parameter_names)
        self._objective_name = objective_name
        self._last_trial_indices = []

    @staticmethod
    def _match_rows_to_trials(
        kept: Tensor, proposed: Tensor, trial_indices: List[int]
    ) -> Optional[List[int]]:
        """Map each row of ``kept`` to the trial index of an identical ``proposed`` row.

        Each proposal is used at most once. Returns None if some kept row
        equals no unused proposal (e.g. the filter modified values).
        """
        unused = list(range(len(trial_indices)))
        matched: List[int] = []
        for row in kept:
            pos = next(
                (p for p in unused if torch.equal(row.to(proposed), proposed[p])),
                None,
            )
            if pos is None:
                return None
            unused.remove(pos)
            matched.append(trial_indices[pos])
        return matched

    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiments using AX.

        Requests ``n_candidates`` trials from AX, passes them through
        ``self._filter_feasible`` and returns the surviving rows. Trials whose
        rows the filter dropped are abandoned in AX so they are not treated as
        pending forever. Any trials from a previous ``suggest`` call that were
        never reported are no longer tracked by ``update``.

        Args:
            n_candidates: Number of trials to request from AX.
            X_observed: Unused; AX keeps its own history, fed via ``update``.
            y_observed: Unused; see ``X_observed``.

        Returns:
            float64 tensor with one row per returned candidate (at most
            ``n_candidates``) and one column per parameter, in the order given
            to ``setup_experiment``.

        Raises:
            RuntimeError: If ``setup_experiment`` has not been called.
        """
        ax_client = self._require_client("suggesting")

        if n_candidates <= 0:
            self._last_trial_indices = []
            return torch.empty(
                (0, len(self._parameter_names)), dtype=torch.float64
            )

        proposals: List[List[float]] = []
        trial_indices: List[int] = []
        for _ in range(n_candidates):
            parameters, trial_index = ax_client.get_next_trial()
            proposals.append([parameters[name] for name in self._parameter_names])
            trial_indices.append(trial_index)

        proposed = torch.tensor(proposals, dtype=torch.float64)
        result = self._filter_feasible(proposed)[:n_candidates]

        kept_indices = self._match_rows_to_trials(result, proposed, trial_indices)
        if kept_indices is None:
            # The filter altered row values, so rows cannot be traced back to
            # their trials; fall back to positional alignment.
            kept_indices = trial_indices[: len(result)]
        else:
            kept = set(kept_indices)
            for trial_index in trial_indices:
                if trial_index not in kept:
                    ax_client.abandon_trial(
                        trial_index, reason="Rejected by feasibility filter."
                    )

        self._last_trial_indices = kept_indices
        return result

    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Report observations for the most recently suggested trials back to AX.

        Rows are matched to trials by position: each row of ``y_new``
        completes the oldest not-yet-reported trial from the last ``suggest``
        call. Completed trials are consumed, so a later call cannot complete
        them twice. ``X_new`` values are not checked against the suggested
        parameters; ``zip`` stops at the shorter of ``X_new`` and ``y_new``.
        Observations beyond the pending trials are skipped with a warning.
        Does nothing if ``setup_experiment`` has not been called.

        Args:
            X_new: Evaluated inputs, row-aligned with ``y_new``.
            y_new: Objective values; each row must convert with ``float()``
                (i.e. hold a single element).
        """
        if self._ax_client is None:
            return

        n_unmatched = 0
        for _, y_val in zip(X_new, y_new):
            if not self._last_trial_indices:
                n_unmatched += 1
                continue
            self._ax_client.complete_trial(
                trial_index=self._last_trial_indices[0],
                # (mean, SEM) pair; SEM 0.0 declares the observation noise-free.
                raw_data={self._objective_name: (float(y_val), 0.0)},
            )
            # Consume the index only after AX accepted the result.
            self._last_trial_indices.pop(0)

        if n_unmatched:
            warnings.warn(
                f"{n_unmatched} observation(s) had no pending AX trial and "
                "were not reported.",
                stacklevel=2,
            )

    def get_best_parameters(self) -> Dict:
        """Get the best parameters found so far.

        Returns:
            ``{"parameters": ..., "values": ...}`` unpacked from
            ``AxClient.get_best_parameters()``. Both entries are None when AX
            reports no best point yet (e.g. before any trial is completed).

        Raises:
            RuntimeError: If ``setup_experiment`` has not been called.
        """
        best = self._require_client("querying the best parameters").get_best_parameters()
        if best is None:
            return {"parameters": None, "values": None}
        best_params, values = best
        return {"parameters": best_params, "values": values}

    def get_trials_dataframe(self):
        """Get all trials as a pandas DataFrame.

        Raises:
            RuntimeError: If ``setup_experiment`` has not been called.
        """
        return self._require_client("querying trials").get_trials_data_frame()
|
|
|