"""Multi-fidelity model: physics model as low-fidelity, experiments as high-fidelity."""

from typing import Callable, Optional, Tuple

import torch
from torch import Tensor

from botorch.models import SingleTaskMultiFidelityGP
from botorch.models.transforms.outcome import Standardize

from physics_informed_bo.models.base import SurrogateModel
|
class MultiFidelitySurrogate(SurrogateModel):
    """Multi-fidelity BO model using physics as cheap low-fidelity source.

    Uses BoTorch's SingleTaskMultiFidelityGP to jointly model:

    - Low-fidelity: physics model predictions (cheap, approximate)
    - High-fidelity: experimental observations (expensive, accurate)

    The model learns the correlation between fidelities to transfer
    knowledge from physics to improve experimental predictions.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        fidelity_dim: int = -1,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ):
        """
        Args:
            physics_fn: Physics model function (low-fidelity source). Expected
                to map an (n, d) input to an (n,) or (n, 1) output tensor.
            fidelity_dim: Column index for the fidelity indicator. Negative
                indices are resolved Python-style (relative to the last column).
            device: Torch device.
            dtype: Torch dtype.
        """
        self.physics_fn = physics_fn
        self.fidelity_dim = fidelity_dim
        self.device = torch.device(device)
        self.dtype = dtype
        self._model = None  # set by fit()

    @staticmethod
    def _as_column(y: Tensor) -> Tensor:
        """Return *y* shaped (n, 1), adding a trailing dim only when missing.

        The original code unsqueezed unconditionally, which mis-shapes a
        physics_fn that already returns (n, 1) into (n, 1, 1).
        """
        return y.unsqueeze(-1) if y.dim() == 1 else y

    def build_multi_fidelity_data(
        self,
        X_experiment: Tensor,
        y_experiment: Tensor,
        X_physics_grid: Optional[Tensor] = None,
        n_physics_points: int = 100,
    ) -> Tuple[Tensor, Tensor]:
        """Combine physics predictions and experimental data into multi-fidelity dataset.

        Args:
            X_experiment: Experimental inputs (n_exp, d).
            y_experiment: Experimental outputs (n_exp, 1) or (n_exp,).
            X_physics_grid: Optional grid for physics evaluations.
            n_physics_points: Number of physics evaluation points if no grid given.

        Returns:
            X_mf: Combined inputs with fidelity column (n_total, d+1).
            y_mf: Combined outputs (n_total, 1).
        """
        # Normalize all inputs onto the model's device/dtype so the torch.cat
        # calls below cannot fail on mixed devices or precisions.
        X_experiment = X_experiment.to(device=self.device, dtype=self.dtype)
        y_experiment = self._as_column(
            y_experiment.to(device=self.device, dtype=self.dtype)
        )
        d = X_experiment.shape[-1]

        if X_physics_grid is None:
            # NOTE(review): sampling uniform in [0, 1)^d assumes a unit-cube
            # design space — pass an explicit grid otherwise.
            X_physics_grid = torch.rand(
                n_physics_points, d, device=self.device, dtype=self.dtype
            )
        else:
            X_physics_grid = X_physics_grid.to(device=self.device, dtype=self.dtype)

        # Physics evaluations are training data; no gradients needed.
        with torch.no_grad():
            y_physics = self._as_column(self.physics_fn(X_physics_grid))

        # Fidelity indicator: 0.0 = physics (low), 1.0 = experiment (high).
        fidelity_low = torch.zeros(
            len(X_physics_grid), 1, device=self.device, dtype=self.dtype
        )
        fidelity_high = torch.ones(
            len(X_experiment), 1, device=self.device, dtype=self.dtype
        )

        X_physics_mf = torch.cat([X_physics_grid, fidelity_low], dim=-1)
        X_experiment_mf = torch.cat([X_experiment, fidelity_high], dim=-1)

        X_mf = torch.cat([X_physics_mf, X_experiment_mf], dim=0)
        y_mf = torch.cat([y_physics, y_experiment], dim=0)

        return X_mf, y_mf

    def fit(
        self,
        X: Tensor,
        y: Tensor,
        training_iterations: int = 200,
        lr: float = 0.05,
    ) -> None:
        """Fit the multi-fidelity GP model.

        X should include a fidelity column (use build_multi_fidelity_data).

        Args:
            X: Training inputs (n, d+1) including the fidelity column.
            y: Training targets (n, 1) or (n,).
            training_iterations: Unused; retained for interface compatibility.
                fit_gpytorch_mll runs its own optimizer to convergence.
            lr: Unused; retained for interface compatibility (see above).
        """
        X = X.to(device=self.device, dtype=self.dtype)
        y = y.to(device=self.device, dtype=self.dtype)
        if y.dim() == 1:
            y = y.unsqueeze(-1)

        d = X.shape[-1]
        # Resolve any Python-style negative index; the original handled only -1,
        # so e.g. fidelity_dim=-2 would have been passed through incorrectly.
        fidelity_col = self.fidelity_dim % d if self.fidelity_dim < 0 else self.fidelity_dim

        self._model = SingleTaskMultiFidelityGP(
            train_X=X,
            train_Y=y,
            data_fidelities=[fidelity_col],
            outcome_transform=Standardize(m=1),
        ).to(device=self.device, dtype=self.dtype)

        # Imported lazily (as in the original) to keep module import light.
        from botorch.fit import fit_gpytorch_mll
        from gpytorch.mlls import ExactMarginalLogLikelihood

        mll = ExactMarginalLogLikelihood(self._model.likelihood, self._model)
        fit_gpytorch_mll(mll)

    def _with_target_fidelity(self, X: Tensor) -> Tensor:
        """Append a fidelity=1 column to X when it lacks one (batch-shape safe).

        Shared by predict() and posterior(); the original predict() built the
        column with torch.ones(len(X), 1), which breaks for batched inputs.
        """
        if self._model is None:
            raise RuntimeError("Model has not been fit; call fit() first.")
        X = X.to(device=self.device, dtype=self.dtype)
        if X.shape[-1] == self._model.train_inputs[0].shape[-1] - 1:
            fidelity_col = torch.ones(
                *X.shape[:-1], 1, device=self.device, dtype=self.dtype
            )
            X = torch.cat([X, fidelity_col], dim=-1)
        return X

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict at high fidelity (fidelity=1).

        Returns:
            Posterior mean and variance tensors.

        Raises:
            RuntimeError: If the model has not been fit.
        """
        posterior = self._model.posterior(self._with_target_fidelity(X))
        return posterior.mean, posterior.variance

    def posterior(self, X: Tensor):
        """Return the BoTorch posterior at high fidelity (fidelity=1).

        Raises:
            RuntimeError: If the model has not been fit.
        """
        return self._model.posterior(self._with_target_fidelity(X))

    @property
    def model(self):
        """The fitted SingleTaskMultiFidelityGP, or None before fit()."""
        return self._model
|
|
|