"""Physics model wrappers for use as GP mean functions and standalone surrogates."""

from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
import gpytorch
from gpytorch.means import Mean

from physics_informed_bo.models.base import SurrogateModel


class PhysicsMeanFunction(Mean):
    """Wraps a user-defined physics function as a GPyTorch mean function.

    The physics function becomes the prior mean of the GP, so the GP only needs
    to learn the residual (discrepancy) between physics predictions and actual
    observations.

    Example:
        import math

        def arrhenius(X):
            # X[:, 0] = temperature, X[:, 1] = activation energy
            T, Ea = X[:, 0], X[:, 1]
            R = 8.314
            # log(1e13 * exp(-Ea / (R * T))) evaluated in log space, so a large
            # Ea / (R * T) cannot underflow exp() to 0 and produce -inf.
            return math.log(1e13) - Ea / (R * T)

        mean_fn = PhysicsMeanFunction(arrhenius)
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        output_scale: float = 1.0,
        learnable_scale: bool = True,
    ):
        """
        Args:
            physics_fn: Callable mapping the input tensor X to the physics
                prediction; ``forward`` multiplies its output by
                ``output_scale``.
            output_scale: Initial multiplicative scale for the physics
                prediction.
            learnable_scale: If True, the scale is registered as a trainable
                ``torch.nn.Parameter``; otherwise as a fixed (non-trained)
                buffer.
        """
        super().__init__()
        self.physics_fn = physics_fn
        initial_scale = torch.tensor(output_scale)
        if learnable_scale:
            self.register_parameter(
                "raw_output_scale",
                torch.nn.Parameter(initial_scale),
            )
        else:
            self.register_buffer("raw_output_scale", initial_scale)

    @property
    def output_scale(self) -> Tensor:
        """Scale applied to the physics prediction (raw value, no constraint)."""
        return self.raw_output_scale

    def forward(self, X: Tensor) -> Tensor:
        """Evaluate the physics model at X and scale the output."""
        physics_pred = self.physics_fn(X)
        return self.output_scale * physics_pred


class PhysicsModel(SurrogateModel):
    """Standalone physics model as a surrogate (no GP, deterministic).

    Useful as a baseline or when no experimental data is available yet.
    The predictive variance is a constant observation-noise term, plus an
    optional first-order (central finite-difference) propagation of per-input
    uncertainties through the physics function.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        noise_std: float = 0.1,
        param_uncertainty: Optional[Dict[str, float]] = None,
    ):
        """
        Args:
            physics_fn: Callable that takes (n, d) tensor and returns (n,) tensor.
            noise_std: Assumed observation noise standard deviation. ``fit``
                overwrites it with the standard deviation of the physics
                residuals when given at least two observations.
            param_uncertainty: Dict mapping input-dimension names to the
                standard deviation of that input. The uncertainty is propagated
                through the physics model via finite-difference sensitivities.
                Only keys of the form ``"x{i}"`` (``"x0"``, ``"x1"``, ...),
                which refer to column ``i`` of X, are used; other keys are
                ignored.
        """
        self.physics_fn = physics_fn
        self.noise_std = noise_std
        self.param_uncertainty = param_uncertainty or {}
        self._train_X = None
        self._train_y = None

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict with the physics model and an estimated uncertainty.

        Args:
            X: Input tensor whose last dimension indexes input features,
                e.g. shape (n, d).

        Returns:
            ``(mean, variance)``. ``mean`` is ``physics_fn(X)`` with a trailing
            singleton dimension added, i.e. (n, 1) for an (n,)-valued physics
            function. ``variance`` has shape ``X.shape[:-1] + (1,)``.
        """
        with torch.no_grad():
            mean = self.physics_fn(X).unsqueeze(-1)

        # Constant noise variance plus local sensitivity-based input uncertainty.
        variance = self._estimate_variance(X)
        return mean, variance

    def _estimate_variance(self, X: Tensor) -> Tensor:
        """Estimate predictive variance via finite-difference sensitivity analysis.

        variance = noise_std**2 + sum_i (df/dx_i * sigma_i)**2, where sigma_i is
        ``param_uncertainty["x{i}"]`` and df/dx_i is a central difference with
        step ``eps``. Dimensions without an uncertainty entry contribute nothing
        and are not evaluated.

        Args:
            X: Input tensor of shape (..., d).

        Returns:
            Variance tensor of shape (..., 1), matching X's dtype and device.
        """
        eps = 1e-4
        *batch_shape, d = X.shape
        var = torch.full(
            (*batch_shape, 1), self.noise_std**2, dtype=X.dtype, device=X.device
        )

        uncertainties = self.param_uncertainty or {}
        for i in range(d):
            sigma = uncertainties.get(f"x{i}")
            if sigma is None:
                # No uncertainty for this input: skip the two physics evaluations.
                continue
            X_plus = X.clone()
            X_plus[..., i] += eps
            X_minus = X.clone()
            X_minus[..., i] -= eps
            df_dx = (self.physics_fn(X_plus) - self.physics_fn(X_minus)) / (2 * eps)
            # Reshape so an (n, 1)-valued physics_fn cannot silently broadcast.
            var = var + (df_dx.reshape(batch_shape) * sigma).pow(2).unsqueeze(-1)
        return var

    def fit(self, X: Tensor, y: Tensor) -> None:
        """Store observations (physics model is not fitted, but we track data).

        If both X and y are given, ``noise_std`` is replaced by the standard
        deviation of the residuals ``y - physics_fn(X)``. This needs at least
        two observations: the unbiased std of one residual is NaN, so with
        fewer points the previous ``noise_std`` is kept.
        """
        self._train_X = X
        self._train_y = y

        if X is None or y is None:
            return

        with torch.no_grad():
            pred = self.physics_fn(X).reshape(-1)
            # Flatten both sides: y.squeeze() against an (n, 1) prediction would
            # broadcast to an (n, n) matrix instead of elementwise residuals.
            residuals = y.reshape(-1) - pred

        if residuals.numel() > 1:
            self.noise_std = float(residuals.std())

    def posterior(self, X: Tensor):
        """Return a simple posterior-like object for BoTorch compatibility."""
        mean, var = self.predict(X)
        return _SimplePosterior(mean, var)


class _SimplePosterior:
    """Minimal posterior wrapper for BoTorch compatibility.

    Represents independent Gaussians with the given mean and variance.
    """

    def __init__(self, mean: Tensor, variance: Tensor):
        self._mean = mean
        self._variance = variance

    @property
    def mean(self) -> Tensor:
        return self._mean

    @property
    def variance(self) -> Tensor:
        return self._variance

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        """Draw reparameterized samples ``mean + std * eps``, eps ~ N(0, 1).

        Returns:
            Tensor of shape ``sample_shape + mean.shape``.
        """
        std = self._variance.clamp_min(0).sqrt()
        # Match the posterior's dtype so float64 models get float64 noise.
        eps = torch.randn(
            *sample_shape,
            *self._mean.shape,
            dtype=self._mean.dtype,
            device=self._mean.device,
        )
        return self._mean + std * eps