File size: 2,413 Bytes
0342a1d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | """Base class for all optimizer backends."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor
from physics_informed_bo.config import OptimizationConfig
from physics_informed_bo.models.base import SurrogateModel
from physics_informed_bo.priors.physics_prior import PhysicsPrior
class BaseOptimizer(ABC):
    """Abstract base class for optimizer backends (BoTorch, AX, BoFire).

    Subclasses implement :meth:`suggest` and :meth:`update`. The base class
    holds state a backend can use (the config, a surrogate model, search-space
    bounds and an optional physics prior) and provides
    :meth:`_filter_feasible` for screening candidates against that prior.
    """

    def __init__(self, config: "OptimizationConfig"):
        """Store the optimization config; optional components start unset.

        Args:
            config: Optimization settings for this backend.
        """
        self.config = config
        self._surrogate: Optional[SurrogateModel] = None
        self._bounds: Optional[Tensor] = None
        self._physics_prior: Optional[PhysicsPrior] = None

    def set_surrogate(self, surrogate: "SurrogateModel") -> None:
        """Attach the surrogate model for this backend."""
        self._surrogate = surrogate

    def set_bounds(self, bounds: Tensor) -> None:
        """Set search space bounds.

        Args:
            bounds: Tensor of shape ``(2, d)`` where ``bounds[0]`` holds the
                lower and ``bounds[1]`` the upper bound of each dimension.

        Raises:
            ValueError: If ``bounds`` is not 2-D with exactly two rows, or if
                any lower bound is greater than its upper bound.
        """
        if bounds.dim() != 2 or bounds.shape[0] != 2:
            raise ValueError(
                f"bounds must have shape (2, d), got {tuple(bounds.shape)}"
            )
        if bool((bounds[0] > bounds[1]).any()):
            raise ValueError("every lower bound must be <= its upper bound")
        self._bounds = bounds

    def set_physics_prior(self, physics_prior: "PhysicsPrior") -> None:
        """Attach the physics prior used by :meth:`_filter_feasible`."""
        self._physics_prior = physics_prior

    @abstractmethod
    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiment(s) to run.

        Args:
            n_candidates: Number of candidates to suggest.
            X_observed: All observed inputs so far.
            y_observed: All observed outputs so far.

        Returns:
            Tensor of shape (n_candidates, d) with suggested experiments.
        """

    @abstractmethod
    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Update the optimizer with new observations.

        Args:
            X_new: Newly observed inputs.
            y_new: Outputs observed for ``X_new``.
        """

    def _filter_feasible(
        self, candidates: Tensor, *, n_fallback: int = 1
    ) -> Tensor:
        """Filter candidates through physics constraints.

        Args:
            candidates: Candidate points, one per row.
            n_fallback: Number of least-violating candidates to return when
                no candidate is feasible. Defaults to 1.

        Returns:
            ``candidates`` unchanged when no physics prior is set. Otherwise,
            the feasible rows; if none are feasible, the ``n_fallback`` rows
            with the smallest constraint violation, in ascending order of
            violation.

        Raises:
            ValueError: If ``n_fallback`` is less than 1.
        """
        if n_fallback < 1:
            raise ValueError(f"n_fallback must be >= 1, got {n_fallback}")
        if self._physics_prior is None:
            return candidates

        # Coerce to a boolean mask. Indexing with an integer 0/1 tensor would
        # gather rows 0 and 1 instead of masking, and a float tensor would
        # raise.
        feasible_mask = torch.as_tensor(
            self._physics_prior.check_feasibility(candidates),
            dtype=torch.bool,
            device=candidates.device,
        )
        feasible = candidates[feasible_mask]
        if feasible.shape[0] > 0:
            return feasible

        # Nothing feasible: fall back to the least-violating candidates.
        # Assumes constraint_violation returns one score per row; TODO confirm.
        violations = self._physics_prior.constraint_violation(candidates)
        order = violations.argsort()
        return candidates[order[:n_fallback]]
|