File size: 4,559 Bytes
28656f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""AX Platform optimizer backend for physics-informed BO."""

from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from physics_informed_bo.config import OptimizationConfig
from physics_informed_bo.optimizers.base_optimizer import BaseOptimizer


class AXOptimizer(BaseOptimizer):
    """AX Platform backend for structured experiment design.

    AX provides a higher-level API for experiment management, including:

    - Structured parameter spaces with constraints
    - Multi-objective optimization
    - Human-in-the-loop experiments
    - Automatic model selection

    The physics model is intended to be injected as a custom BoTorch model
    generator within AX's modular framework.

    NOTE(review): the code in this class only creates a plain ``AxClient``
    with its default generation strategy; no custom model generator is wired
    in here. Confirm whether that happens elsewhere.

    Typical usage is ``setup_experiment()`` once, then alternating
    ``suggest()`` / ``update()`` calls.
    """

    def __init__(self, config: OptimizationConfig):
        super().__init__(config)
        self._experiment = None
        self._gs = None
        self._parameter_names: List[str] = []
        # Initialised eagerly so every method can rely on these attributes
        # existing, instead of probing with hasattr().
        self._ax_client = None
        self._objective_name: str = "objective"
        self._last_trial_indices: List[int] = []

    def _require_client(self, message: str):
        """Return the active AxClient or raise ``RuntimeError(message)``."""
        # getattr() keeps this safe even for instances built without __init__.
        client = getattr(self, "_ax_client", None)
        if client is None:
            raise RuntimeError(message)
        return client

    def setup_experiment(
        self,
        parameter_names: List[str],
        bounds: Dict[str, Tuple[float, float]],
        objective_name: str = "objective",
        minimize: bool = False,
        outcome_constraints: Optional[List[str]] = None,
    ) -> None:
        """Set up an AX experiment with physics-informed model.

        Args:
            parameter_names: Names of input parameters. Their order defines
                the column order of the tensors returned by ``suggest()``.
            bounds: Dict of {param_name: (lower, upper)}.
            objective_name: Name of the objective metric. It is remembered and
                used as the metric key when ``update()`` reports results.
            minimize: Whether to minimize (True) or maximize (False).
            outcome_constraints: List of outcome constraint strings
                e.g. ["metric >= 0.5"]. Forwarded to AX when provided.

        Raises:
            ImportError: If the AX Platform package is not installed.
            KeyError: If ``bounds`` lacks an entry for one of
                ``parameter_names``.
        """
        try:
            from ax.service.ax_client import AxClient
            from ax.service.utils.instantiation import ObjectiveProperties
        except ImportError as err:
            raise ImportError(
                "AX Platform is required for AXOptimizer. "
                "Install with: pip install ax-platform"
            ) from err

        missing = [name for name in parameter_names if name not in bounds]
        if missing:
            raise KeyError(f"Missing bounds for parameters: {missing}")

        # Every parameter is a continuous float range in AX's dict format.
        parameters = [
            {
                "name": name,
                "type": "range",
                "bounds": [float(bounds[name][0]), float(bounds[name][1])],
                "value_type": "float",
            }
            for name in parameter_names
        ]

        create_kwargs = {
            "name": "physics_informed_bo",
            "parameters": parameters,
            "objectives": {
                objective_name: ObjectiveProperties(minimize=minimize),
            },
        }
        # Only pass constraints when the caller supplied some, so the call is
        # unchanged for the default case.
        if outcome_constraints:
            create_kwargs["outcome_constraints"] = list(outcome_constraints)

        client = AxClient(verbose_logging=False)
        client.create_experiment(**create_kwargs)

        # Commit state only after AX accepted the experiment definition, so a
        # failed setup does not leave a half-configured optimizer behind.
        self._ax_client = client
        self._parameter_names = list(parameter_names)
        self._objective_name = objective_name
        self._last_trial_indices = []

    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiments using AX.

        Args:
            n_candidates: Number of AX trials to request.
            X_observed: Unused by this backend; observations reach AX through
                ``update()``.
            y_observed: Unused by this backend; see ``X_observed``.

        Returns:
            A float64 tensor with one row per suggested candidate, columns
            ordered like the ``parameter_names`` given to
            ``setup_experiment()``, after passing through
            ``self._filter_feasible``.

        Raises:
            RuntimeError: If ``setup_experiment()`` has not been called.
        """
        client = self._require_client("Call setup_experiment() before suggesting.")

        candidates = []
        trial_indices = []
        for _ in range(n_candidates):
            parameters, trial_index = client.get_next_trial()
            candidates.append([parameters[name] for name in self._parameter_names])
            trial_indices.append(trial_index)

        # Remembered so update() can attach results to these trials.
        self._last_trial_indices = trial_indices
        result = torch.tensor(candidates, dtype=torch.float64)

        # Filter through physics constraints.
        # NOTE(review): _filter_feasible is defined outside this class. If it
        # drops or reorders rows, the positional trial matching in update()
        # no longer lines up with the returned candidates - confirm.
        result = self._filter_feasible(result)
        return result[:n_candidates]

    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Report observations back to AX.

        Observations are matched by position to the trials handed out by the
        most recent ``suggest()`` call that have not been reported yet;
        surplus observations are ignored. Does nothing if no experiment has
        been set up.

        Args:
            X_new: Evaluated inputs, one row per observation. Only used to
                bound how many observations are reported.
            y_new: Objective values aligned with ``X_new``.
        """
        client = getattr(self, "_ax_client", None)
        if client is None:
            return

        pending = getattr(self, "_last_trial_indices", [])
        # Falls back to the original hard-coded key if setup never stored one.
        metric_name = getattr(self, "_objective_name", "objective")

        reported = 0
        for trial_index, _x, y_val in zip(pending, X_new, y_new):
            client.complete_trial(
                trial_index=trial_index,
                # (value, 0.0) pair; the second entry is presumably AX's
                # SEM, i.e. a noiseless observation - confirm with AX docs.
                raw_data={metric_name: (float(y_val), 0.0)},
            )
            reported += 1

        # Drop reported trials so a later update() call cannot try to
        # complete the same trial a second time.
        self._last_trial_indices = list(pending[reported:])

    def get_best_parameters(self) -> Dict:
        """Get the best parameters found so far.

        Returns:
            Dict with keys ``"parameters"`` and ``"values"``, taken from
            ``AxClient.get_best_parameters()``. Both are ``None`` when AX
            returns no result.

        Raises:
            RuntimeError: If ``setup_experiment()`` has not been called.
        """
        client = self._require_client(
            "Call setup_experiment() before requesting the best parameters."
        )
        best = client.get_best_parameters()
        if best is None:
            return {"parameters": None, "values": None}
        best_params, values = best
        return {"parameters": best_params, "values": values}

    def get_trials_dataframe(self):
        """Get all trials as a pandas DataFrame.

        Raises:
            RuntimeError: If ``setup_experiment()`` has not been called.
        """
        client = self._require_client(
            "Call setup_experiment() before requesting the trials data frame."
        )
        return client.get_trials_data_frame()