File size: 5,519 Bytes
88d6cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""BoTorch-based optimizer backend."""

from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor

from botorch.acquisition import (
    ExpectedImprovement,
    UpperConfidenceBound,
    ProbabilityOfImprovement,
    qExpectedImprovement,
    qNoisyExpectedImprovement,
    qKnowledgeGradient,
)
from botorch.optim import optimize_acqf
from botorch.utils.transforms import standardize, normalize, unnormalize

from physics_informed_bo.config import AcquisitionType, OptimizationConfig
from physics_informed_bo.optimizers.base_optimizer import BaseOptimizer


class PhysicsInformedEI(ExpectedImprovement):
    """Expected Improvement down-weighted where physics constraints are violated.

    The analytic EI value is multiplied by
    ``exp(-penalty_weight * violation)``, where ``violation`` is whatever
    ``physics_prior.constraint_violation`` returns for the candidate points.
    Without a ``physics_prior`` this behaves exactly like
    ``ExpectedImprovement``.
    """

    def __init__(self, model, best_f, physics_prior=None, penalty_weight=10.0, **kwargs):
        """Store the physics prior and penalty weight on top of standard EI.

        Args:
            model: Surrogate model forwarded to ``ExpectedImprovement``.
            best_f: Incumbent value forwarded to ``ExpectedImprovement``.
            physics_prior: Optional object exposing ``constraint_violation(X)``.
                When ``None`` no penalty is applied.
            penalty_weight: Multiplier applied to the violation inside the
                exponential; larger values penalize violations more strongly.
            **kwargs: Extra keyword arguments forwarded to
                ``ExpectedImprovement``.
        """
        super().__init__(model=model, best_f=best_f, **kwargs)
        self.physics_prior = physics_prior
        self.penalty_weight = penalty_weight

    def forward(self, X: Tensor) -> Tensor:
        """Evaluate EI at ``X`` and scale it by the feasibility factor.

        Args:
            X: Candidate points as accepted by ``ExpectedImprovement.forward``.

        Returns:
            The EI values, multiplied by ``exp(-penalty_weight * violation)``
            when a physics prior is configured.
        """
        ei = super().forward(X)

        if self.physics_prior is None:
            return ei

        # BoTorch puts the q-dimension second to last; analytic EI works with
        # q == 1, so drop that axis whatever the number of leading batch
        # dimensions is (the previous `squeeze(1)` only handled exactly one).
        # A plain 2-D input is passed through untouched, as before.
        # NOTE(review): assumes constraint_violation returns one value per
        # point with the same leading shape as `ei` - confirm against the prior.
        X_points = X.squeeze(-2) if X.dim() >= 3 else X
        violation = self.physics_prior.constraint_violation(X_points)
        feasibility = torch.exp(-self.penalty_weight * violation)
        return ei * feasibility


class BoTorchOptimizer(BaseOptimizer):
    """BoTorch-based Bayesian optimization backend.

    Builds a BoTorch acquisition function from ``config.acquisition_type``,
    maximizes it with ``optimize_acqf`` inside the configured bounds and passes
    the resulting candidates through ``_filter_feasible`` before returning
    them.
    """

    # Tuning constants, kept as class attributes so subclasses can override
    # them without re-implementing the methods below.
    _UCB_BETA = 2.0
    _KG_NUM_FANTASIES = 8
    _NUM_RESTARTS = 10
    _RAW_SAMPLES = 256

    # Analytic acquisition classes used here. BoTorch evaluates these one
    # point at a time (q == 1) and they do not accept pending points.
    # PhysicsInformedEI is covered through its ExpectedImprovement base.
    _ANALYTIC_ACQUISITIONS = (
        ExpectedImprovement,
        UpperConfidenceBound,
        ProbabilityOfImprovement,
    )

    def __init__(self, config: OptimizationConfig):
        """Initialise the backend with ``config``; nothing is built yet."""
        super().__init__(config)
        self._acq_function = None  # acquisition function of the last suggest()
        self._best_f = None  # incumbent value used by the last suggest()

    def _get_acquisition_function(self, model, best_f: float):
        """Build the acquisition function selected by ``config.acquisition_type``.

        Args:
            model: BoTorch model handed to the acquisition function.
            best_f: Incumbent value, used by the improvement-based functions.

        Returns:
            The constructed acquisition function.

        Raises:
            ValueError: If the acquisition type is not supported, or if
                ``NOISY_EXPECTED_IMPROVEMENT`` is requested while no observed
                inputs are available to serve as ``X_baseline``.
        """
        acq_type = self.config.acquisition_type

        if acq_type == AcquisitionType.EXPECTED_IMPROVEMENT:
            return ExpectedImprovement(model=model, best_f=best_f)

        elif acq_type == AcquisitionType.UPPER_CONFIDENCE_BOUND:
            return UpperConfidenceBound(model=model, beta=self._UCB_BETA)

        elif acq_type == AcquisitionType.PROBABILITY_OF_IMPROVEMENT:
            return ProbabilityOfImprovement(model=model, best_f=best_f)

        elif acq_type == AcquisitionType.NOISY_EXPECTED_IMPROVEMENT:
            # `_X_observed` is only assigned inside suggest(); use getattr so a
            # direct call fails with a clear message instead of AttributeError.
            X_baseline = getattr(self, "_X_observed", None)
            if X_baseline is None:
                raise ValueError(
                    "NOISY_EXPECTED_IMPROVEMENT requires X_observed to be provided."
                )
            return qNoisyExpectedImprovement(
                model=model,
                X_baseline=X_baseline,
            )

        elif acq_type == AcquisitionType.KNOWLEDGE_GRADIENT:
            return qKnowledgeGradient(
                model=model, num_fantasies=self._KG_NUM_FANTASIES
            )

        elif acq_type == AcquisitionType.PHYSICS_INFORMED_EI:
            return PhysicsInformedEI(
                model=model,
                best_f=best_f,
                physics_prior=self._physics_prior,
                penalty_weight=self.config.physics_constraint_penalty,
            )

        else:
            raise ValueError(f"Unsupported acquisition type: {acq_type}")

    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
        X_pending: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiments using BoTorch optimization.

        Args:
            n_candidates: Number of points to optimize jointly (BoTorch ``q``).
            X_observed: Observed inputs; used as ``X_baseline`` for noisy EI.
            y_observed: Observed outputs. Their maximum is used as the
                incumbent ``best_f``; when missing or empty, ``0.0`` is used.
            X_pending: Optional points already selected but not yet evaluated.
                They are registered on the acquisition function when it
                supports pending points and ignored for the analytic ones.

        Returns:
            The output of ``_filter_feasible`` applied to the optimized
            candidates, truncated to at most ``n_candidates`` rows.

        Raises:
            RuntimeError: If the surrogate model or the bounds are not set.
            ValueError: If ``n_candidates`` is smaller than 1, or if more than
                one candidate is requested from an analytic acquisition
                function.
        """
        if self._surrogate is None:
            raise RuntimeError("Surrogate model not set. Call set_surrogate() first.")
        if self._bounds is None:
            raise RuntimeError("Bounds not set. Call set_bounds() first.")
        if n_candidates < 1:
            raise ValueError(f"n_candidates must be >= 1, got {n_candidates}")

        self._X_observed = X_observed
        model = self._surrogate.model

        # `.max()` raises on an empty tensor, so treat "no data" uniformly.
        has_observations = y_observed is not None and y_observed.numel() > 0
        best_f = float(y_observed.max()) if has_observations else 0.0
        self._best_f = best_f

        acq_function = self._get_acquisition_function(model, best_f)
        is_analytic = isinstance(acq_function, self._ANALYTIC_ACQUISITIONS)

        if is_analytic and n_candidates > 1:
            # Fail early with an actionable message rather than BoTorch's
            # shape assertion raised from inside the optimizer.
            raise ValueError(
                f"{type(acq_function).__name__} only supports n_candidates=1; "
                "use sequential suggestions or a q-batch acquisition type."
            )

        if X_pending is not None and X_pending.numel() > 0 and not is_analytic:
            acq_function.set_X_pending(X_pending)

        self._acq_function = acq_function

        # Optimize the acquisition function; the acquisition value itself is
        # not needed by callers.
        candidates, _ = optimize_acqf(
            acq_function=acq_function,
            bounds=self._bounds,
            q=n_candidates,
            num_restarts=self._NUM_RESTARTS,
            raw_samples=self._RAW_SAMPLES,
        )

        # Filter through physics constraints
        candidates = self._filter_feasible(candidates)

        return candidates[:n_candidates]

    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Update is handled by re-fitting the surrogate externally."""
        pass

    def suggest_batch(
        self,
        batch_size: int,
        X_observed: Tensor,
        y_observed: Tensor,
        sequential: bool = True,
    ) -> Tensor:
        """Suggest a batch of experiments.

        Args:
            batch_size: Number of experiments to suggest.
            X_observed: All observed inputs.
            y_observed: All observed outputs.
            sequential: If True, pick the points one at a time, passing the
                points chosen so far as pending points to the next round.
                If False, optimize all points jointly as one q-batch.

        Returns:
            The suggested candidates concatenated along dimension 0.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.

        Note:
            The analytic acquisition functions cannot take pending points, so
            for them every sequential round optimizes the same objective and
            the rounds may return near-identical candidates.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if not sequential:
            return self.suggest(batch_size, X_observed, y_observed)

        picked: List[Tensor] = []
        for _ in range(batch_size):
            # Everything chosen in earlier rounds is "pending" for this one.
            pending = torch.cat(picked, dim=0) if picked else None
            picked.append(
                self.suggest(1, X_observed, y_observed, X_pending=pending)
            )
        return torch.cat(picked, dim=0)