File size: 2,413 Bytes
0342a1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Base class for all optimizer backends."""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor

from physics_informed_bo.config import OptimizationConfig
from physics_informed_bo.models.base import SurrogateModel
from physics_informed_bo.priors.physics_prior import PhysicsPrior


class BaseOptimizer(ABC):
    """Abstract base class for optimizer backends (BoTorch, AX, BoFire).

    Subclasses implement :meth:`suggest` and :meth:`update`. The surrogate
    model, search-space bounds and optional physics prior are injected via
    the ``set_*`` methods and kept on private attributes (``None`` until set).
    """

    def __init__(self, config: OptimizationConfig):
        """Store ``config``; surrogate, bounds and physics prior start unset.

        Args:
            config: Optimization settings for this backend.
        """
        self.config = config
        self._surrogate: Optional[SurrogateModel] = None
        self._bounds: Optional[Tensor] = None
        self._physics_prior: Optional[PhysicsPrior] = None

    def set_surrogate(self, surrogate: SurrogateModel) -> None:
        """Attach the surrogate model the backend should use."""
        self._surrogate = surrogate

    def set_bounds(self, bounds: Tensor) -> None:
        """Set search space bounds. Shape: (2, d) where [0] = lower, [1] = upper.

        Args:
            bounds: Tensor whose first row holds the lower bounds and whose
                second row holds the upper bounds, one column per input
                dimension.

        Raises:
            ValueError: If ``bounds`` is not 2-D with exactly two rows, or if
                any lower bound is greater than the matching upper bound.
        """
        # Validate up front so a malformed search space fails here with a
        # clear message instead of somewhere inside a backend's optimizer.
        if bounds.dim() != 2 or bounds.shape[0] != 2:
            raise ValueError(
                f"bounds must have shape (2, d), got {tuple(bounds.shape)}"
            )
        if bool((bounds[0] > bounds[1]).any()):
            raise ValueError(
                "bounds[0] (lower) must not exceed bounds[1] (upper)"
            )
        self._bounds = bounds

    def set_physics_prior(self, physics_prior: PhysicsPrior) -> None:
        """Attach the physics prior consulted by ``_filter_feasible``."""
        self._physics_prior = physics_prior

    @abstractmethod
    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiment(s) to run.

        Args:
            n_candidates: Number of candidates to suggest.
            X_observed: All observed inputs so far.
            y_observed: All observed outputs so far.

        Returns:
            Tensor of shape (n_candidates, d) with suggested experiments.
        """

    @abstractmethod
    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Update the optimizer with new observations.

        Args:
            X_new: Newly observed inputs.
            y_new: Outputs observed for ``X_new``.
        """

    def _filter_feasible(
        self, candidates: Tensor, n_fallback: int = 1
    ) -> Tensor:
        """Filter candidates through physics constraints.

        Args:
            candidates: Candidate points, one per row along dim 0.
            n_fallback: Number of least-violating candidates to return when
                no candidate is feasible. Defaults to 1.

        Returns:
            ``candidates`` unchanged when no physics prior is set or there
            are no candidates; otherwise the rows the prior marks feasible,
            or, if none are feasible, up to ``n_fallback`` rows with the
            smallest constraint violation.

        Raises:
            ValueError: If ``n_fallback`` is less than 1.
        """
        if n_fallback < 1:
            raise ValueError(f"n_fallback must be >= 1, got {n_fallback}")
        if self._physics_prior is None or candidates.shape[0] == 0:
            return candidates
        # Coerce to a bool tensor on the candidates' device: a 0/1 integer
        # mask would otherwise be read as row indices (gathering rows 0 and
        # 1) instead of selecting rows.
        feasible_mask = torch.as_tensor(
            self._physics_prior.check_feasibility(candidates),
            dtype=torch.bool,
            device=candidates.device,
        )
        feasible = candidates[feasible_mask]
        if feasible.shape[0] > 0:
            return feasible
        # No feasible candidates: return the least-violating ones rather than
        # an empty tensor, so the caller still has something to evaluate.
        # NOTE(review): assumes constraint_violation returns one value per
        # candidate, i.e. shape (n,) — confirm against PhysicsPrior.
        violations = self._physics_prior.constraint_violation(candidates)
        order = violations.argsort()  # ascending: least violation first
        return candidates[order[:n_fallback]]