| """Physics model wrappers for use as GP mean functions and standalone surrogates."""
|
|
|
| from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
|
| import torch
|
| from torch import Tensor
|
| import gpytorch
|
| from gpytorch.means import Mean
|
|
|
| from physics_informed_bo.models.base import SurrogateModel
|
|
|
|
|
class PhysicsMeanFunction(Mean):
    """Wraps a user-defined physics function as a GPyTorch mean function.

    The physics function becomes the prior mean of the GP, so the GP
    only needs to learn the residual (discrepancy) between physics
    predictions and actual observations.

    The physics prediction is multiplied by a scalar ``output_scale`` that
    is either a learnable parameter or a fixed buffer.

    Example:
        def arrhenius(X):
            # X[:, 0] = temperature, X[:, 1] = activation energy
            T, Ea = X[:, 0], X[:, 1]
            R = 8.314
            return torch.log(1e13 * torch.exp(-Ea / (R * T)))

        mean_fn = PhysicsMeanFunction(arrhenius)
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        output_scale: float = 1.0,
        learnable_scale: bool = True,
    ):
        """
        Args:
            physics_fn: Callable mapping an input tensor to the physics
                prediction tensor.
            output_scale: Initial (or fixed) multiplier applied to the
                physics prediction.
            learnable_scale: If True the scale is registered as a trainable
                parameter; otherwise it is registered as a constant buffer.
        """
        super().__init__()
        self.physics_fn = physics_fn

        scale = torch.tensor(output_scale)
        # An integer scale (e.g. ``output_scale=1``) would produce an int64
        # tensor, which cannot be wrapped in a gradient-requiring Parameter.
        # Promote it to the default floating dtype; floating/complex values
        # are left untouched.
        if not (scale.is_floating_point() or scale.is_complex()):
            scale = scale.to(torch.get_default_dtype())

        if learnable_scale:
            self.register_parameter("raw_output_scale", torch.nn.Parameter(scale))
        else:
            self.register_buffer("raw_output_scale", scale)

    @property
    def output_scale(self) -> Tensor:
        """Scalar multiplier applied to the physics prediction (unconstrained)."""
        return self.raw_output_scale

    def forward(self, X: Tensor) -> Tensor:
        """Evaluate the physics model at X and scale the output."""
        physics_pred = self.physics_fn(X)
        return self.output_scale * physics_pred
|
|
|
|
|
class PhysicsModel(SurrogateModel):
    """Standalone physics model as a surrogate (no GP, deterministic).

    Useful as a baseline or when no experimental data is available yet.
    The predictive variance is the squared observation noise plus, for each
    input column listed in ``param_uncertainty``, a first-order
    (central finite-difference) propagation of that input's uncertainty.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        noise_std: float = 0.1,
        param_uncertainty: Optional[Dict[str, float]] = None,
    ):
        """
        Args:
            physics_fn: Callable that takes (n, d) tensor and returns (n,) tensor.
            noise_std: Assumed observation noise standard deviation.
            param_uncertainty: Dict mapping input names to their uncertainty
                (standard deviation) for propagating uncertainty through the
                physics model. Keys must have the form ``"x<i>"`` where
                ``<i>`` is the zero-based column index of X; other keys are
                ignored.
        """
        # NOTE(review): the base-class initializer is not invoked here
        # (unchanged from the original) -- confirm SurrogateModel needs none.
        self.physics_fn = physics_fn
        self.noise_std = noise_std
        self.param_uncertainty = param_uncertainty or {}
        self._train_X = None
        self._train_y = None

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict using the physics model with estimated uncertainty.

        Args:
            X: (n, d) tensor of inputs.

        Returns:
            Tuple ``(mean, variance)``, each of shape (n, 1).
        """
        with torch.no_grad():
            # reshape (rather than unsqueeze) so a physics_fn returning
            # (n, 1) still yields an (n, 1) mean matching the variance.
            mean = self.physics_fn(X).reshape(-1, 1)

        variance = self._estimate_variance(X)
        return mean, variance

    def _estimate_variance(self, X: Tensor) -> Tensor:
        """Estimate predictive variance via finite-difference sensitivity analysis.

        Starts from ``noise_std**2`` and, for every column ``i`` with an
        entry ``"x<i>"`` in ``param_uncertainty``, adds
        ``(df/dx_i * sigma_i)**2`` using a central difference.

        Args:
            X: (n, d) tensor of inputs.

        Returns:
            (n, 1) tensor of variances.
        """
        eps = 1e-4
        n, d = X.shape
        var = torch.full((n, 1), self.noise_std**2, dtype=X.dtype, device=X.device)

        for i in range(d):
            sigma = self.param_uncertainty.get(f"x{i}")
            if sigma is None:
                # No uncertainty declared for this column: skip the two
                # physics evaluations a finite difference would cost.
                continue

            X_plus = X.clone()
            X_plus[:, i] += eps
            X_minus = X.clone()
            X_minus[:, i] -= eps

            diff = self.physics_fn(X_plus) - self.physics_fn(X_minus)
            # Flatten so an (n, 1) physics output still adds element-wise.
            df_dx = diff.reshape(-1) / (2 * eps)
            var[:, 0] += (df_dx * sigma) ** 2

        return var

    def fit(self, X: Tensor, y: Tensor) -> None:
        """Store observations (physics model is not fitted, but we track data).

        When both X and y are given and there are at least two observations,
        ``noise_std`` is re-estimated as the sample standard deviation of the
        residuals ``y - physics_fn(X)``.
        """
        self._train_X = X
        self._train_y = y

        if X is None or y is None:
            return

        with torch.no_grad():
            residuals = y.reshape(-1) - self.physics_fn(X).reshape(-1)

        # The sample standard deviation of fewer than two points is NaN;
        # keep the previous noise level instead of poisoning every variance.
        if residuals.numel() < 2:
            return

        residual_std = residuals.std()
        if torch.isfinite(residual_std):
            self.noise_std = float(residual_std)

    def posterior(self, X: Tensor):
        """Return a simple posterior-like object for BoTorch compatibility."""
        mean, var = self.predict(X)
        return _SimplePosterior(mean, var)
|
|
|
|
|
| class _SimplePosterior:
|
| """Minimal posterior wrapper for BoTorch compatibility."""
|
|
|
| def __init__(self, mean: Tensor, variance: Tensor):
|
| self._mean = mean
|
| self._variance = variance
|
|
|
| @property
|
| def mean(self) -> Tensor:
|
| return self._mean
|
|
|
| @property
|
| def variance(self) -> Tensor:
|
| return self._variance
|
|
|
| def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
|
| """Draw reparameterized samples from the posterior."""
|
| std = self._variance.sqrt()
|
| eps = torch.randn(*sample_shape, *self._mean.shape, device=self._mean.device)
|
| return self._mean + std * eps
|
|
|