# Uploaded by ravimohan19 (optimizers/ax_optimizer.py via huggingface_hub),
# commit 28656f7 (verified).
"""AX Platform optimizer backend for physics-informed BO."""
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
from physics_informed_bo.config import OptimizationConfig
from physics_informed_bo.optimizers.base_optimizer import BaseOptimizer
class AXOptimizer(BaseOptimizer):
    """AX Platform backend for structured experiment design.

    AX provides a higher-level API for experiment management (structured
    parameter spaces, experiment bookkeeping, human-in-the-loop workflows,
    automatic model selection).

    This backend configures a single objective on an ``AxClient`` and lets
    Ax's default generation strategy propose candidates. Physics knowledge
    enters only through ``_filter_feasible`` (inherited from
    ``BaseOptimizer``), which is applied to the candidates Ax suggests.
    Candidates rejected by that filter are abandoned in Ax so they do not
    linger as pending trials.
    """

    def __init__(self, config: OptimizationConfig):
        super().__init__(config)
        # Reserved attributes; nothing in this class populates them.
        self._experiment = None
        self._gs = None
        self._parameter_names: List[str] = []
        # Metric key used when reporting results back to Ax in update().
        self._objective_name: str = "objective"
        # Created by setup_experiment(); None means "not set up yet".
        self._ax_client = None
        # Ax trial indices aligned row-for-row with the tensor most recently
        # returned by suggest() and not yet reported through update().
        self._last_trial_indices: List[int] = []

    def _require_client(self, action: str):
        """Return the AxClient, or raise if setup_experiment() was not called."""
        client = getattr(self, "_ax_client", None)
        if client is None:
            raise RuntimeError(f"Call setup_experiment() before {action}.")
        return client

    def setup_experiment(
        self,
        parameter_names: List[str],
        bounds: Dict[str, Tuple[float, float]],
        objective_name: str = "objective",
        minimize: bool = False,
        outcome_constraints: Optional[List[str]] = None,
    ) -> None:
        """Create (or replace) the AX experiment backing this optimizer.

        Args:
            parameter_names: Names of input parameters. Their order defines
                the column order of tensors returned by ``suggest``.
            bounds: Dict of {param_name: (lower, upper)}; every name in
                ``parameter_names`` must be present (KeyError otherwise).
            objective_name: Name of the objective metric; ``update`` reports
                observations under this name.
            minimize: Whether to minimize (True) or maximize (False).
            outcome_constraints: Outcome constraint strings, e.g.
                ``["metric >= 0.5"]``. Currently NOT applied: ``update`` only
                reports the objective metric, so constraint metrics would
                have no data. A ``UserWarning`` is emitted if any are given.

        Raises:
            ImportError: If ``ax-platform`` is not installed.
        """
        try:
            from ax.service.ax_client import AxClient
            from ax.service.utils.instantiation import ObjectiveProperties
        except ImportError as err:
            raise ImportError(
                "AX Platform is required for AXOptimizer. "
                "Install with: pip install ax-platform"
            ) from err

        if outcome_constraints:
            import warnings

            # Previously these were dropped silently; make the gap visible.
            warnings.warn(
                "AXOptimizer ignores outcome_constraints: update() reports "
                "only the objective metric, so constraint metrics would have "
                "no data.",
                UserWarning,
                stacklevel=2,
            )

        # Copy so later mutation of the caller's list cannot reorder columns.
        self._parameter_names = list(parameter_names)
        self._objective_name = objective_name

        # One continuous "range" parameter per input, in the given order.
        parameters = [
            {
                "name": name,
                "type": "range",
                "bounds": [float(bounds[name][0]), float(bounds[name][1])],
                "value_type": "float",
            }
            for name in self._parameter_names
        ]

        self._ax_client = AxClient(verbose_logging=False)
        self._ax_client.create_experiment(
            name="physics_informed_bo",
            parameters=parameters,
            objectives={
                objective_name: ObjectiveProperties(minimize=minimize),
            },
        )
        # A fresh experiment has no outstanding trials.
        self._last_trial_indices = []

    @staticmethod
    def _match_trial_indices(
        candidates: Tensor,
        returned: Tensor,
        trial_indices: List[int],
    ) -> Optional[List[int]]:
        """Map each returned row to the trial index of the identical candidate row.

        Each candidate can be matched at most once. Returns None if any returned
        row is not an exact copy of a still-unmatched candidate (for example, if
        the filter modified values), in which case no reliable mapping exists.
        """
        unused = list(range(len(trial_indices)))
        matched: List[int] = []
        for row in returned:
            row = row.to(candidates)  # align dtype/device before comparing
            for pos in unused:
                if torch.equal(candidates[pos], row):
                    matched.append(trial_indices[pos])
                    unused.remove(pos)
                    break  # safe: we stop iterating `unused` right after mutating it
            else:
                return None
        return matched

    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiments using AX, filtered by physics feasibility.

        Args:
            n_candidates: Number of Ax trials to generate.
            X_observed: Unused; Ax keeps its own history via ``update``.
            y_observed: Unused; Ax keeps its own history via ``update``.

        Returns:
            A float64 tensor with one row per feasible candidate (at most
            ``n_candidates``), columns ordered as ``parameter_names``.

        Raises:
            RuntimeError: If ``setup_experiment`` has not been called.
        """
        client = self._require_client("suggesting")

        trials = [client.get_next_trial() for _ in range(n_candidates)]
        trial_indices = [trial_index for _, trial_index in trials]
        candidates = torch.tensor(
            [[params[name] for name in self._parameter_names] for params, _ in trials],
            dtype=torch.float64,
        ).reshape(len(trials), len(self._parameter_names))  # keep 2-D even when empty

        # Filter through physics constraints.
        feasible = self._filter_feasible(candidates)[:n_candidates]

        kept = self._match_trial_indices(candidates, feasible, trial_indices)
        if kept is None:
            # NOTE(review): filtered rows are not verbatim candidates, so they
            # cannot be tied back to Ax trials; keep the legacy positional mapping.
            self._last_trial_indices = trial_indices
        else:
            kept_set = set(kept)
            for trial_index in trial_indices:
                if trial_index not in kept_set:
                    # Rejected candidates would otherwise stay pending in Ax forever.
                    client.abandon_trial(
                        trial_index=trial_index,
                        reason="Rejected by physics feasibility filter.",
                    )
            self._last_trial_indices = kept
        return feasible

    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Report observations for the most recently suggested trials back to AX.

        Rows of ``X_new``/``y_new`` are matched positionally to the outstanding
        trials from the last ``suggest`` call. Reported trials are removed from
        the outstanding list, so a batch may be reported across several calls
        and no trial is completed twice. Does nothing before ``setup_experiment``.
        """
        client = getattr(self, "_ax_client", None)
        if client is None:
            return
        pending = list(self._last_trial_indices)
        n_reported = 0
        for trial_index, _x, y_val in zip(pending, X_new, y_new):
            client.complete_trial(
                trial_index=trial_index,
                # SEM 0.0: observations are reported as noiseless.
                raw_data={self._objective_name: (float(y_val), 0.0)},
            )
            n_reported += 1
        self._last_trial_indices = pending[n_reported:]

    def get_best_parameters(self) -> Dict:
        """Get the best parameters found so far.

        Returns:
            ``{"parameters": ..., "values": ...}`` as reported by
            ``AxClient.get_best_parameters``; both entries are None when Ax has
            no best point yet.

        Raises:
            RuntimeError: If ``setup_experiment`` has not been called.
        """
        client = self._require_client("querying the best parameters")
        best = client.get_best_parameters()
        if best is None:
            return {"parameters": None, "values": None}
        best_params, values = best
        return {"parameters": best_params, "values": values}

    def get_trials_dataframe(self):
        """Get all trials as a pandas DataFrame.

        Raises:
            RuntimeError: If ``setup_experiment`` has not been called.
        """
        return self._require_client("fetching trials").get_trials_data_frame()