"""Multi-fidelity model: physics model as low-fidelity, experiments as high-fidelity."""

from typing import Callable, Optional, Tuple

import torch
from torch import Tensor
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskMultiFidelityGP
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood

from physics_informed_bo.models.base import SurrogateModel


class MultiFidelitySurrogate(SurrogateModel):
    """Multi-fidelity BO model using physics as cheap low-fidelity source.

    Uses BoTorch's SingleTaskMultiFidelityGP to jointly model:
    - Low-fidelity: physics model predictions (cheap, approximate), fidelity = 0
    - High-fidelity: experimental observations (expensive, accurate), fidelity = 1

    The fidelity indicator is stored as an extra input column whose position
    is controlled by ``fidelity_dim`` (Python-style index into the augmented
    input, so ``-1`` means "last column"). Every method that builds or
    consumes augmented inputs uses that same position.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        fidelity_dim: int = -1,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ):
        """
        Args:
            physics_fn: Physics model function (low-fidelity source). Called
                on an ``(n, d)`` tensor; expected to return ``(n,)`` or
                ``(n, 1)`` predictions.
            fidelity_dim: Column index of the fidelity indicator in the
                augmented ``(…, d+1)`` inputs. Negative values count from
                the end (``-1`` = append as last column).
            device: Torch device.
            dtype: Torch dtype.
        """
        self.physics_fn = physics_fn
        self.fidelity_dim = fidelity_dim
        self.device = torch.device(device)
        self.dtype = dtype
        self._model: Optional[SingleTaskMultiFidelityGP] = None

    def _fidelity_index(self, n_cols: int) -> int:
        """Resolve ``fidelity_dim`` to a non-negative index for ``n_cols`` total columns.

        Raises:
            ValueError: If ``fidelity_dim`` does not address a valid column.
        """
        idx = self.fidelity_dim if self.fidelity_dim >= 0 else n_cols + self.fidelity_dim
        if not 0 <= idx < n_cols:
            raise ValueError(
                f"fidelity_dim={self.fidelity_dim} is out of range for inputs "
                f"with {n_cols} columns (including the fidelity column)."
            )
        return idx

    def _insert_fidelity(self, X: Tensor, value: float) -> Tensor:
        """Return ``X`` (shape ``(…, d)``) with a constant fidelity column inserted.

        The result has shape ``(…, d+1)`` with ``value`` at the configured
        fidelity position; leading batch dimensions are preserved.
        """
        idx = self._fidelity_index(X.shape[-1] + 1)
        fidelity = torch.full(
            (*X.shape[:-1], 1), value, device=X.device, dtype=X.dtype
        )
        # Slicing around idx handles both "append" (idx == d) and interior positions.
        return torch.cat([X[..., :idx], fidelity, X[..., idx:]], dim=-1)

    def _require_model(self) -> SingleTaskMultiFidelityGP:
        """Return the fitted GP, or raise a clear error if ``fit`` was not called."""
        if self._model is None:
            raise RuntimeError("Model has not been fitted; call fit() first.")
        return self._model

    @staticmethod
    def _as_column(y: Tensor) -> Tensor:
        """Promote a 1-D target vector to shape ``(n, 1)``; leave other shapes as-is."""
        return y.unsqueeze(-1) if y.dim() == 1 else y

    def build_multi_fidelity_data(
        self,
        X_experiment: Tensor,
        y_experiment: Tensor,
        X_physics_grid: Optional[Tensor] = None,
        n_physics_points: int = 100,
    ) -> Tuple[Tensor, Tensor]:
        """Combine physics predictions and experimental data into multi-fidelity dataset.

        Args:
            X_experiment: Experimental inputs (n_exp, d).
            y_experiment: Experimental outputs, (n_exp,) or (n_exp, 1).
            X_physics_grid: Optional grid for physics evaluations (n_phys, d).
            n_physics_points: Number of physics evaluation points if no grid
                given; they are drawn uniformly from the unit hypercube
                [0, 1)^d via ``torch.rand``.

        Returns:
            X_mf: Combined inputs with fidelity column (n_total, d+1); physics
                rows (fidelity 0) come first, then experiment rows (fidelity 1).
            y_mf: Combined outputs (n_total, 1).
        """
        # Normalise everything to the model's device/dtype so torch.cat cannot fail on mismatch.
        X_experiment = X_experiment.to(device=self.device, dtype=self.dtype)
        y_experiment = self._as_column(
            y_experiment.to(device=self.device, dtype=self.dtype)
        )
        d = X_experiment.shape[-1]

        # Generate physics data
        if X_physics_grid is None:
            X_physics_grid = torch.rand(
                n_physics_points, d, device=self.device, dtype=self.dtype
            )
        else:
            X_physics_grid = X_physics_grid.to(device=self.device, dtype=self.dtype)

        with torch.no_grad():
            y_physics = self.physics_fn(X_physics_grid)
        # Accept physics outputs shaped (n,) or (n, 1).
        y_physics = self._as_column(y_physics.to(device=self.device, dtype=self.dtype))

        # Fidelity indicator: 0 = low (physics), 1 = high (experiment)
        X_physics_mf = self._insert_fidelity(X_physics_grid, 0.0)
        X_experiment_mf = self._insert_fidelity(X_experiment, 1.0)

        X_mf = torch.cat([X_physics_mf, X_experiment_mf], dim=0)
        y_mf = torch.cat([y_physics, y_experiment], dim=0)
        return X_mf, y_mf

    def fit(
        self,
        X: Tensor,
        y: Tensor,
        training_iterations: int = 200,
        lr: float = 0.05,
    ) -> None:
        """Fit the multi-fidelity GP model.

        X should include a fidelity column (use build_multi_fidelity_data).

        Args:
            X: Augmented inputs (n, d+1) including the fidelity column.
            y: Targets, (n,) or (n, 1).
            training_iterations: Accepted for interface compatibility; currently
                unused because ``fit_gpytorch_mll`` chooses its own optimizer
                and stopping criteria.
            lr: Accepted for interface compatibility; currently unused (see above).
        """
        X = X.to(device=self.device, dtype=self.dtype)
        y = self._as_column(y.to(device=self.device, dtype=self.dtype))

        # Resolve negative indices (e.g. -1, -2) to the absolute column BoTorch expects.
        fidelity_col = self._fidelity_index(X.shape[-1])

        self._model = SingleTaskMultiFidelityGP(
            train_X=X,
            train_Y=y,
            data_fidelities=[fidelity_col],
            outcome_transform=Standardize(m=1),
        ).to(device=self.device, dtype=self.dtype)

        # Optimize hyperparameters by maximizing the exact marginal log likelihood.
        mll = ExactMarginalLogLikelihood(self._model.likelihood, self._model)
        fit_gpytorch_mll(mll)

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict at high fidelity (fidelity=1).

        Args:
            X: Inputs either without the fidelity column (…, d), in which case
                the high-fidelity indicator is inserted, or already augmented
                (…, d+1).

        Returns:
            Posterior mean and variance at ``X``.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        posterior = self.posterior(X)
        return posterior.mean, posterior.variance

    def posterior(self, X: Tensor, **kwargs):
        """Return the GP posterior at ``X``, defaulting to high fidelity.

        If ``X`` lacks the fidelity column, a fidelity of 1 is inserted at the
        configured position (batch dimensions are preserved). Extra keyword
        arguments are forwarded to the underlying BoTorch ``posterior`` call.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        model = self._require_model()
        X = X.to(device=self.device, dtype=self.dtype)
        if X.shape[-1] == model.train_inputs[0].shape[-1] - 1:
            X = self._insert_fidelity(X, 1.0)
        return model.posterior(X, **kwargs)

    @property
    def model(self):
        """The fitted ``SingleTaskMultiFidelityGP`` (``None`` before ``fit``)."""
        return self._model