# NOTE: removed Hugging Face upload-page residue ("ravimohan19's picture",
# "Upload models/multi_fidelity.py with huggingface_hub", commit hash) that
# was scraped into the file and is not valid Python.
"""Multi-fidelity model: physics model as low-fidelity, experiments as high-fidelity."""
from typing import Callable, Optional, Tuple
import torch
from torch import Tensor
from botorch.models import SingleTaskMultiFidelityGP
from botorch.models.transforms.outcome import Standardize
from physics_informed_bo.models.base import SurrogateModel
class MultiFidelitySurrogate(SurrogateModel):
    """Multi-fidelity BO model using physics as a cheap low-fidelity source.

    Uses BoTorch's ``SingleTaskMultiFidelityGP`` to jointly model:

    - Low-fidelity: physics model predictions (cheap, approximate)
    - High-fidelity: experimental observations (expensive, accurate)

    The model learns the correlation between fidelities to transfer
    knowledge from physics to improve experimental predictions.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        fidelity_dim: int = -1,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
    ):
        """
        Args:
            physics_fn: Physics model function (low-fidelity source).
                Assumed to map ``(n, d)`` inputs to ``(n,)`` or ``(n, 1)``
                outputs — TODO confirm against callers.
            fidelity_dim: Column index for the fidelity indicator. Negative
                values index from the end (default ``-1`` = last column).
            device: Torch device.
            dtype: Torch dtype.
        """
        self.physics_fn = physics_fn
        self.fidelity_dim = fidelity_dim
        self.device = torch.device(device)
        self.dtype = dtype
        # Set by fit(); None until then.
        self._model = None

    def build_multi_fidelity_data(
        self,
        X_experiment: Tensor,
        y_experiment: Tensor,
        X_physics_grid: Optional[Tensor] = None,
        n_physics_points: int = 100,
    ) -> Tuple[Tensor, Tensor]:
        """Combine physics predictions and experimental data into one dataset.

        Args:
            X_experiment: Experimental inputs (n_exp, d).
            y_experiment: Experimental outputs (n_exp, 1) or (n_exp,).
            X_physics_grid: Optional grid for physics evaluations; if omitted,
                ``n_physics_points`` uniform-random points in [0, 1)^d are used.
            n_physics_points: Number of physics evaluation points if no grid given.

        Returns:
            X_mf: Combined inputs with fidelity column (n_total, d+1).
            y_mf: Combined outputs (n_total, 1).
        """
        # Normalize everything to the model's device/dtype up front so the
        # torch.cat calls below cannot fail on a device/dtype mismatch.
        X_experiment = X_experiment.to(device=self.device, dtype=self.dtype)
        y_experiment = y_experiment.to(device=self.device, dtype=self.dtype)
        if y_experiment.dim() == 1:
            # fit() accepts 1-D targets; accept the same shape here.
            y_experiment = y_experiment.unsqueeze(-1)
        d = X_experiment.shape[-1]
        # Generate physics (low-fidelity) data.
        if X_physics_grid is None:
            X_physics_grid = torch.rand(
                n_physics_points, d, device=self.device, dtype=self.dtype
            )
        else:
            X_physics_grid = X_physics_grid.to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            y_physics = self.physics_fn(X_physics_grid)
        if y_physics.dim() == 1:
            # Only add the output column if physics_fn returned (n,); an
            # unconditional unsqueeze would turn (n, 1) into (n, 1, 1).
            y_physics = y_physics.unsqueeze(-1)
        y_physics = y_physics.to(device=self.device, dtype=self.dtype)
        # Fidelity indicator column: 0 = low (physics), 1 = high (experiment).
        fidelity_low = torch.zeros(len(X_physics_grid), 1, device=self.device, dtype=self.dtype)
        fidelity_high = torch.ones(len(X_experiment), 1, device=self.device, dtype=self.dtype)
        X_physics_mf = torch.cat([X_physics_grid, fidelity_low], dim=-1)
        X_experiment_mf = torch.cat([X_experiment, fidelity_high], dim=-1)
        X_mf = torch.cat([X_physics_mf, X_experiment_mf], dim=0)
        y_mf = torch.cat([y_physics, y_experiment], dim=0)
        return X_mf, y_mf

    def fit(
        self,
        X: Tensor,
        y: Tensor,
        training_iterations: int = 200,
        lr: float = 0.05,
    ) -> None:
        """Fit the multi-fidelity GP model.

        Args:
            X: Training inputs including a fidelity column
                (use :meth:`build_multi_fidelity_data`).
            y: Training targets, (n, 1) or (n,).
            training_iterations: Unused by the current fitting routine
                (``fit_gpytorch_mll`` manages its own optimization); kept for
                interface compatibility.
            lr: Unused, kept for interface compatibility.
        """
        X = X.to(device=self.device, dtype=self.dtype)
        y = y.to(device=self.device, dtype=self.dtype)
        if y.dim() == 1:
            y = y.unsqueeze(-1)
        d = X.shape[-1]
        # Resolve any negative index (not just -1) against the input width.
        fidelity_col = self.fidelity_dim if self.fidelity_dim >= 0 else d + self.fidelity_dim
        self._model = SingleTaskMultiFidelityGP(
            train_X=X,
            train_Y=y,
            data_fidelities=[fidelity_col],
            outcome_transform=Standardize(m=1),
        ).to(device=self.device, dtype=self.dtype)
        # Optimize hyperparameters by maximizing the exact marginal likelihood.
        # Imported lazily to keep module import light.
        from botorch.fit import fit_gpytorch_mll
        from gpytorch.mlls import ExactMarginalLogLikelihood

        mll = ExactMarginalLogLikelihood(self._model.likelihood, self._model)
        fit_gpytorch_mll(mll)

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict mean and variance at high fidelity (fidelity=1).

        Delegates to :meth:`posterior` so batched inputs (any number of
        leading batch dims) are handled consistently.

        Raises:
            RuntimeError: If called before :meth:`fit`.
        """
        posterior = self.posterior(X)
        return posterior.mean, posterior.variance

    def posterior(self, X: Tensor):
        """Return the model posterior at high fidelity.

        If ``X`` is missing the fidelity column (its last dim is one less
        than the training inputs'), a column of ones (high fidelity) is
        appended; batch dimensions are preserved.

        Raises:
            RuntimeError: If called before :meth:`fit`.
        """
        if self._model is None:
            raise RuntimeError("Model has not been fitted; call fit() first.")
        X = X.to(device=self.device, dtype=self.dtype)
        if X.shape[-1] == self._model.train_inputs[0].shape[-1] - 1:
            fidelity_col = torch.ones(
                *X.shape[:-1], 1, device=self.device, dtype=self.dtype
            )
            X = torch.cat([X, fidelity_col], dim=-1)
        return self._model.posterior(X)

    @property
    def model(self):
        """The underlying fitted ``SingleTaskMultiFidelityGP`` (or None)."""
        return self._model