"""Multi-fidelity model: physics model as low-fidelity, experiments as high-fidelity."""
from typing import Callable, Optional, Tuple
import torch
from torch import Tensor
from botorch.models import SingleTaskMultiFidelityGP
from botorch.models.transforms.outcome import Standardize
from physics_informed_bo.models.base import SurrogateModel
class MultiFidelitySurrogate(SurrogateModel):
    """Multi-fidelity BO model using physics as a cheap low-fidelity source.

    Uses BoTorch's ``SingleTaskMultiFidelityGP`` to jointly model:

    - Low-fidelity: physics model predictions (cheap, approximate).
    - High-fidelity: experimental observations (expensive, accurate).

    The model learns the correlation between fidelities to transfer
    knowledge from physics to improve experimental predictions.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        fidelity_dim: int = -1,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ):
        """
        Args:
            physics_fn: Physics model function (low-fidelity source). Expected
                to map an ``(n, d)`` input tensor to ``n`` scalar outputs
                (shape ``(n,)`` or ``(n, 1)`` are both accepted).
            fidelity_dim: Column index for the fidelity indicator; ``-1``
                means "last column", which matches the layout produced by
                :meth:`build_multi_fidelity_data`.
            device: Torch device.
            dtype: Torch dtype.
        """
        self.physics_fn = physics_fn
        self.fidelity_dim = fidelity_dim
        self.device = torch.device(device)
        self.dtype = dtype
        # Set by fit(); predict()/posterior() guard against use before fitting.
        self._model: Optional[SingleTaskMultiFidelityGP] = None

    def _require_model(self) -> None:
        """Raise a clear error when prediction is attempted before fitting."""
        if self._model is None:
            raise RuntimeError(
                "Model has not been fitted yet; call fit() before predicting."
            )

    def _append_high_fidelity(self, X: Tensor) -> Tensor:
        """Append a high-fidelity indicator column (1.0) if ``X`` lacks one.

        Batch-aware: the indicator column matches ``X.shape[:-1]``, so both
        ``(n, d)`` and ``(b, n, d)`` inputs work.
        """
        if X.shape[-1] == self._model.train_inputs[0].shape[-1] - 1:
            fidelity_col = torch.ones(
                *X.shape[:-1], 1, device=self.device, dtype=self.dtype
            )
            X = torch.cat([X, fidelity_col], dim=-1)
        return X

    def build_multi_fidelity_data(
        self,
        X_experiment: Tensor,
        y_experiment: Tensor,
        X_physics_grid: Optional[Tensor] = None,
        n_physics_points: int = 100,
    ) -> Tuple[Tensor, Tensor]:
        """Combine physics predictions and experimental data into a multi-fidelity dataset.

        Args:
            X_experiment: Experimental inputs (n_exp, d).
            y_experiment: Experimental outputs (n_exp, 1) or (n_exp,).
            X_physics_grid: Optional grid for physics evaluations.
                Assumed to lie in the unit hypercube, like the random
                fallback grid — TODO confirm against callers.
            n_physics_points: Number of physics evaluation points if no grid given.

        Returns:
            X_mf: Combined inputs with fidelity column (n_total, d+1).
            y_mf: Combined outputs (n_total, 1).
        """
        # Normalize device/dtype up front so torch.cat below cannot fail on
        # mixed float32/float64 or CPU/GPU inputs.
        X_experiment = X_experiment.to(device=self.device, dtype=self.dtype)
        y_experiment = y_experiment.to(device=self.device, dtype=self.dtype)
        if y_experiment.dim() == 1:
            y_experiment = y_experiment.unsqueeze(-1)
        d = X_experiment.shape[-1]
        # Generate physics data on a random unit-cube grid if none was given.
        if X_physics_grid is None:
            X_physics_grid = torch.rand(
                n_physics_points, d, device=self.device, dtype=self.dtype
            )
        else:
            X_physics_grid = X_physics_grid.to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            # reshape(-1, 1) accepts both (n,) and (n, 1) physics outputs,
            # whereas unsqueeze(-1) would produce (n, 1, 1) for the latter.
            y_physics = self.physics_fn(X_physics_grid).reshape(-1, 1)
        # Add fidelity column: 0 = low (physics), 1 = high (experiment).
        fidelity_low = torch.zeros(
            len(X_physics_grid), 1, device=self.device, dtype=self.dtype
        )
        fidelity_high = torch.ones(
            len(X_experiment), 1, device=self.device, dtype=self.dtype
        )
        X_physics_mf = torch.cat([X_physics_grid, fidelity_low], dim=-1)
        X_experiment_mf = torch.cat([X_experiment, fidelity_high], dim=-1)
        X_mf = torch.cat([X_physics_mf, X_experiment_mf], dim=0)
        y_mf = torch.cat([y_physics, y_experiment], dim=0)
        return X_mf, y_mf

    def fit(
        self,
        X: Tensor,
        y: Tensor,
        training_iterations: int = 200,
        lr: float = 0.05,
    ) -> None:
        """Fit the multi-fidelity GP model.

        ``X`` should include a fidelity column (use
        :meth:`build_multi_fidelity_data`).

        Args:
            X: Training inputs with fidelity column (n, d+1).
            y: Training targets, (n,) or (n, 1).
            training_iterations: Adam iterations for the fallback optimizer
                used when ``fit_gpytorch_mll`` fails.
            lr: Adam learning rate for the fallback optimizer.
        """
        X = X.to(device=self.device, dtype=self.dtype)
        y = y.to(device=self.device, dtype=self.dtype)
        if y.dim() == 1:
            y = y.unsqueeze(-1)
        d = X.shape[-1]
        fidelity_col = d - 1 if self.fidelity_dim == -1 else self.fidelity_dim
        self._model = SingleTaskMultiFidelityGP(
            train_X=X,
            train_Y=y,
            data_fidelities=[fidelity_col],
            outcome_transform=Standardize(m=1),
        ).to(device=self.device, dtype=self.dtype)
        # Optimize hyperparameters. Imported locally to keep module import cheap.
        from botorch.fit import fit_gpytorch_mll
        from gpytorch.mlls import ExactMarginalLogLikelihood

        mll = ExactMarginalLogLikelihood(self._model.likelihood, self._model)
        try:
            fit_gpytorch_mll(mll)
        except Exception:
            # NOTE(review): broad catch is deliberate — fit_gpytorch_mll can
            # fail in several library-specific ways (e.g. optimizer
            # non-convergence); we fall back to a plain Adam loop rather than
            # aborting. This is also where the ``training_iterations``/``lr``
            # arguments are actually used (previously they were dead parameters).
            self._model.train()
            optimizer = torch.optim.Adam(self._model.parameters(), lr=lr)
            for _ in range(training_iterations):
                optimizer.zero_grad()
                output = self._model(*self._model.train_inputs)
                loss = -mll(output, self._model.train_targets)
                loss.backward()
                optimizer.step()
            self._model.eval()

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict at high fidelity (fidelity=1).

        Args:
            X: Inputs of shape (..., d) or (..., d+1); a high-fidelity
                indicator column is appended when missing.

        Returns:
            Tuple of posterior mean and variance tensors.

        Raises:
            RuntimeError: If called before :meth:`fit`.
        """
        self._require_model()
        X = X.to(device=self.device, dtype=self.dtype)
        X = self._append_high_fidelity(X)
        posterior = self._model.posterior(X)
        return posterior.mean, posterior.variance

    def posterior(self, X: Tensor):
        """Return the BoTorch posterior at high fidelity (fidelity=1).

        Raises:
            RuntimeError: If called before :meth:`fit`.
        """
        self._require_model()
        X = X.to(device=self.device, dtype=self.dtype)
        X = self._append_high_fidelity(X)
        return self._model.posterior(X)

    @property
    def model(self):
        """The underlying fitted GP, or ``None`` before :meth:`fit`."""
        return self._model