"""Base class for all optimizer backends."""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor

from physics_informed_bo.config import OptimizationConfig
from physics_informed_bo.models.base import SurrogateModel
from physics_informed_bo.priors.physics_prior import PhysicsPrior


class BaseOptimizer(ABC):
    """Abstract base class for optimizer backends (BoTorch, AX, BoFire).

    Subclasses must implement :meth:`suggest` and :meth:`update`. The base
    class stores the configuration together with optional collaborators
    (surrogate model, search-space bounds, physics prior). These are
    injected after construction through the ``set_*`` methods.
    """

    def __init__(self, config: OptimizationConfig):
        self.config = config
        # Collaborators stay None until the matching set_* method is called.
        self._surrogate: Optional[SurrogateModel] = None
        self._bounds: Optional[Tensor] = None
        self._physics_prior: Optional[PhysicsPrior] = None

    def set_surrogate(self, surrogate: SurrogateModel) -> None:
        """Store the surrogate model on this optimizer."""
        self._surrogate = surrogate

    def set_bounds(self, bounds: Tensor) -> None:
        """Set search space bounds.

        Shape: (2, d) where [0] = lower, [1] = upper. The tensor is stored
        as-is; no shape validation is performed here.
        """
        self._bounds = bounds

    def set_physics_prior(self, physics_prior: PhysicsPrior) -> None:
        """Store the physics prior consulted by :meth:`_filter_feasible`."""
        self._physics_prior = physics_prior

    @abstractmethod
    def suggest(
        self,
        n_candidates: int = 1,
        X_observed: Optional[Tensor] = None,
        y_observed: Optional[Tensor] = None,
    ) -> Tensor:
        """Suggest next experiment(s) to run.

        Args:
            n_candidates: Number of candidates to suggest.
            X_observed: All observed inputs so far.
            y_observed: All observed outputs so far.

        Returns:
            Tensor of shape (n_candidates, d) with suggested experiments.
        """

    @abstractmethod
    def update(self, X_new: Tensor, y_new: Tensor) -> None:
        """Update the optimizer with new observations."""

    def _filter_feasible(
        self,
        candidates: Tensor,
        n_fallback: int = 1,
    ) -> Tensor:
        """Filter candidates through physics constraints.

        Args:
            candidates: Candidate points, indexed along the first dimension.
            n_fallback: Number of least-violating candidates to return when
                none of ``candidates`` is feasible. Defaults to 1, which
                preserves the previous behaviour.

        Returns:
            ``candidates`` unchanged if no physics prior is set. Otherwise,
            the rows the prior marks as feasible. If no row is feasible, the
            ``n_fallback`` rows with the smallest constraint violation
            (fewer if there are not that many candidates).

        Raises:
            ValueError: If ``n_fallback`` is less than 1.
        """
        if n_fallback < 1:
            raise ValueError(f"n_fallback must be >= 1, got {n_fallback}")
        if self._physics_prior is None:
            return candidates

        # NOTE(review): assumes check_feasibility returns a boolean mask with
        # one entry per candidate row — confirm against PhysicsPrior.
        feasible_mask = self._physics_prior.check_feasibility(candidates)
        feasible = candidates[feasible_mask]
        if len(feasible) > 0:
            return feasible

        # No feasible candidates: fall back to the least-violating ones, so
        # the result is non-empty whenever ``candidates`` is non-empty.
        # NOTE(review): assumes constraint_violation returns one scalar per
        # candidate (shape (n,)), so argsort ranks rows — confirm.
        violations = self._physics_prior.constraint_violation(candidates)
        sorted_idx = violations.argsort()
        return candidates[sorted_idx[:n_fallback]]