# Source: Hugging Face Hub upload by ravimohan19 (commit b82edbd, via huggingface_hub).
"""Physics model wrappers for use as GP mean functions and standalone surrogates."""
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
import gpytorch
from gpytorch.means import Mean
from physics_informed_bo.models.base import SurrogateModel
class PhysicsMeanFunction(Mean):
    """Wrap a user-supplied physics function as a GPyTorch mean function.

    With the physics model as the prior mean, the GP only has to learn the
    residual (discrepancy) between physics predictions and observations.

    Example:
        def arrhenius(X):
            # X[:, 0] = temperature, X[:, 1] = activation energy
            T, Ea = X[:, 0], X[:, 1]
            R = 8.314
            return torch.log(1e13 * torch.exp(-Ea / (R * T)))

        mean_fn = PhysicsMeanFunction(arrhenius)
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        output_scale: float = 1.0,
        learnable_scale: bool = True,
    ):
        super().__init__()
        self.physics_fn = physics_fn
        scale_tensor = torch.tensor(output_scale)
        if learnable_scale:
            # Registered as a parameter so the GP marginal-likelihood fit can
            # tune how much weight the physics prior carries.
            self.register_parameter(
                "raw_output_scale", torch.nn.Parameter(scale_tensor)
            )
        else:
            # Buffer: moves with .to(device) but is excluded from optimization.
            self.register_buffer("raw_output_scale", scale_tensor)

    @property
    def output_scale(self) -> Tensor:
        # No constraint transform is applied; the raw value IS the scale.
        return self.raw_output_scale

    def forward(self, X: Tensor) -> Tensor:
        """Return the scaled physics prediction used as the GP prior mean."""
        return self.output_scale * self.physics_fn(X)
class PhysicsModel(SurrogateModel):
    """Standalone deterministic physics surrogate (no GP).

    Useful as a baseline or when no experimental data is available yet.
    Predictive uncertainty combines a fixed observation-noise floor with
    first-order propagation of parameter uncertainty via central finite
    differences.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        noise_std: float = 0.1,
        param_uncertainty: Optional[Dict[str, float]] = None,
    ):
        """
        Args:
            physics_fn: Callable that takes (n, d) tensor and returns (n,) tensor.
            noise_std: Assumed observation noise standard deviation.
            param_uncertainty: Dict mapping input names ("x0", "x1", ...) to
                their standard deviations for propagating uncertainty through
                the physics model.
        """
        self.physics_fn = physics_fn
        self.noise_std = noise_std
        self.param_uncertainty = param_uncertainty or {}
        self._train_X = None
        self._train_y = None

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict using the physics model.

        Returns:
            (mean, variance), each of shape (n, 1).
        """
        with torch.no_grad():
            mean = self.physics_fn(X).unsqueeze(-1)
            # Estimate uncertainty via local sensitivity + noise
            variance = self._estimate_variance(X)
            return mean, variance

    def _estimate_variance(self, X: Tensor) -> Tensor:
        """Estimate predictive variance via finite-difference sensitivity.

        var(x) = noise_std**2 + sum_i (df/dx_i * sigma_i)**2, summed over every
        input dimension i that has an entry "x{i}" in ``param_uncertainty``.
        """
        eps = 1e-4
        n, d = X.shape
        var = torch.full((n, 1), self.noise_std**2, dtype=X.dtype, device=X.device)
        for i in range(d):
            sigma = self.param_uncertainty.get(f"x{i}")
            if sigma is None:
                # Skip the two extra physics evaluations for dimensions
                # without a declared uncertainty.
                continue
            X_plus = X.clone()
            X_plus[:, i] += eps
            X_minus = X.clone()
            X_minus[:, i] -= eps
            # Central difference: O(eps**2) accurate derivative estimate.
            df_dx = (self.physics_fn(X_plus) - self.physics_fn(X_minus)) / (2 * eps)
            var[:, 0] += (df_dx * sigma) ** 2
        return var

    def fit(self, X: Tensor, y: Tensor) -> None:
        """Store observations and recalibrate the noise level from residuals.

        The physics model has no trainable parameters; fitting only updates
        ``noise_std`` to the sample std of (y - physics_fn(X)).
        """
        self._train_X = X
        self._train_y = y
        if X is not None and y is not None:
            with torch.no_grad():
                residuals = y.squeeze() - self.physics_fn(X)
            # std() of fewer than two points is NaN (Bessel correction);
            # keep the prior noise estimate in that case.
            if residuals.numel() > 1:
                self.noise_std = float(residuals.std())

    def posterior(self, X: Tensor):
        """Return a simple posterior-like object for BoTorch compatibility."""
        mean, var = self.predict(X)
        return _SimplePosterior(mean, var)
class _SimplePosterior:
"""Minimal posterior wrapper for BoTorch compatibility."""
def __init__(self, mean: Tensor, variance: Tensor):
self._mean = mean
self._variance = variance
@property
def mean(self) -> Tensor:
return self._mean
@property
def variance(self) -> Tensor:
return self._variance
def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
"""Draw reparameterized samples from the posterior."""
std = self._variance.sqrt()
eps = torch.randn(*sample_shape, *self._mean.shape, device=self._mean.device)
return self._mean + std * eps