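# Hugging Face upload metadata (page residue, kept for provenance):
# uploader: ravimohan19
# commit message: Upload optimizers/base_optimizer.py with huggingface_hub
# commit: 0342a1d (verified)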
"""Base class for all optimizer backends."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor
from physics_informed_bo.config import OptimizationConfig
from physics_informed_bo.models.base import SurrogateModel
from physics_informed_bo.priors.physics_prior import PhysicsPrior
class BaseOptimizer(ABC):
    """Abstract base class for optimizer backends (BoTorch, AX, BoFire).

    Subclasses implement :meth:`suggest` and :meth:`update`. Shared state
    (the surrogate model, the search-space bounds and an optional physics
    prior) is injected through the ``set_*`` methods. It is stored on private
    attributes that start out as ``None``.
    """

    def __init__(self, config: OptimizationConfig):
        self.config = config
        self._surrogate: Optional[SurrogateModel] = None
        self._bounds: Optional[Tensor] = None
        self._physics_prior: Optional[PhysicsPrior] = None

    def set_surrogate(self, surrogate: SurrogateModel) -> None:
        """Attach the surrogate model used by this optimizer."""
        self._surrogate = surrogate

    def set_bounds(self, bounds: Tensor) -> None:
        """Set search space bounds. Shape: (2, d) where [0] = lower, [1] = upper."""
        self._bounds = bounds

    def set_physics_prior(self, physics_prior: PhysicsPrior) -> None:
        """Attach a physics prior; once set, ``_filter_feasible`` consults it."""
        self._physics_prior = physics_prior

    @abstractmethod
    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiment(s) to run.

        Args:
            n_candidates: Number of candidates to suggest.
            X_observed: All observed inputs so far.
            y_observed: All observed outputs so far.

        Returns:
            Tensor of shape (n_candidates, d) with suggested experiments.
        """

    @abstractmethod
    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Update the optimizer with new observations."""

    def _filter_feasible(self, candidates: Tensor, n_fallback: int = 1) -> Tensor:
        """Filter candidates through physics constraints.

        Args:
            candidates: Candidate points; rows (dim 0) are the individual
                candidates.
            n_fallback: If no candidate is feasible, return up to this many
                of the least-violating candidates instead. Defaults to 1,
                which matches the original behavior.

        Returns:
            ``candidates`` unchanged when no physics prior is set. Otherwise,
            the feasible rows. If none are feasible, up to ``n_fallback`` rows
            with the smallest constraint violation, least-violating first.
        """
        if self._physics_prior is None:
            return candidates

        # Coerce to a bool mask on the candidates' device. An integer 0/1
        # tensor or a Python list would otherwise be interpreted as row
        # *indices* by advanced indexing, not as a mask.
        feasible_mask = torch.as_tensor(
            self._physics_prior.check_feasibility(candidates),
            dtype=torch.bool,
            device=candidates.device,
        )
        feasible = candidates[feasible_mask]
        if len(feasible) > 0:
            return feasible

        # No feasible candidates: fall back to the least-violating ones.
        violations = torch.as_tensor(
            self._physics_prior.constraint_violation(candidates),
            device=candidates.device,
        )
        if violations.dim() > 1:
            # NOTE(review): assumes trailing dims hold per-constraint
            # violations for each candidate. Summing them yields one score
            # per row, so argsort ranks candidates rather than constraints.
            violations = violations.flatten(start_dim=1).sum(dim=-1)
        sorted_idx = violations.argsort()
        return candidates[sorted_idx[:n_fallback]]