File size: 5,437 Bytes
b82edbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Physics model wrappers for use as GP mean functions and standalone surrogates."""

from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
import gpytorch
from gpytorch.means import Mean

from physics_informed_bo.models.base import SurrogateModel


class PhysicsMeanFunction(Mean):
    """GPyTorch mean function that delegates to a user-supplied physics model.

    With the physics function installed as the GP prior mean, the GP only
    has to learn the residual (discrepancy) between physics predictions
    and the actual observations.

    Example:
        def arrhenius(X):
            # X[:, 0] = temperature, X[:, 1] = activation energy
            T, Ea = X[:, 0], X[:, 1]
            R = 8.314
            return torch.log(1e13 * torch.exp(-Ea / (R * T)))

        mean_fn = PhysicsMeanFunction(arrhenius)
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        output_scale: float = 1.0,
        learnable_scale: bool = True,
    ):
        super().__init__()
        self.physics_fn = physics_fn
        scale = torch.tensor(output_scale)
        if learnable_scale:
            # Trainable multiplier on the physics output, optimized with the GP.
            self.register_parameter("raw_output_scale", torch.nn.Parameter(scale))
        else:
            # Fixed multiplier; a buffer still follows .to(device)/state_dict().
            self.register_buffer("raw_output_scale", scale)

    @property
    def output_scale(self) -> Tensor:
        # No constraint transform: the raw value is used as-is.
        return self.raw_output_scale

    def forward(self, X: Tensor) -> Tensor:
        """Evaluate the physics model at X and scale the result."""
        return self.output_scale * self.physics_fn(X)


class PhysicsModel(SurrogateModel):
    """Standalone deterministic physics model used as a surrogate (no GP).

    Useful as a baseline or when no experimental data is available yet.
    Predictive uncertainty is the assumed observation noise plus
    first-order propagation of declared input-parameter uncertainties.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        noise_std: float = 0.1,
        param_uncertainty: Optional[Dict[str, float]] = None,
    ):
        """
        Args:
            physics_fn: Callable that takes (n, d) tensor and returns (n,) tensor.
            noise_std: Assumed observation noise standard deviation.
            param_uncertainty: Dict mapping parameter names ("x0", "x1", ...)
                to their standard deviation, used to propagate input
                uncertainty through the physics model.
        """
        self.physics_fn = physics_fn
        self.noise_std = noise_std
        self.param_uncertainty = param_uncertainty or {}
        self._train_X = None  # last observed inputs (kept for reference only)
        self._train_y = None  # last observed targets

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict using the physics model; returns (mean, variance), each (n, 1)."""
        with torch.no_grad():
            mean = self.physics_fn(X).unsqueeze(-1)

        # Estimate uncertainty via local sensitivity + noise
        variance = self._estimate_variance(X)
        return mean, variance

    def _estimate_variance(self, X: Tensor) -> Tensor:
        """Estimate predictive variance via central finite differences.

        Base variance is noise_std**2; for each dimension i with a declared
        uncertainty sigma_i, (df/dx_i * sigma_i)**2 is added (first-order
        error propagation).
        """
        eps = 1e-4
        n, d = X.shape
        var = torch.full((n, 1), self.noise_std**2, dtype=X.dtype, device=X.device)

        # Add sensitivity-based uncertainty if we have param uncertainties.
        if self.param_uncertainty:
            # Finite differences need no autograd graph.
            with torch.no_grad():
                for i in range(d):
                    sigma = self.param_uncertainty.get(f"x{i}")
                    if sigma is None:
                        # Skip the two physics evaluations for dimensions
                        # without a declared uncertainty.
                        continue
                    X_plus = X.clone()
                    X_plus[:, i] += eps
                    X_minus = X.clone()
                    X_minus[:, i] -= eps
                    df_dx = (self.physics_fn(X_plus) - self.physics_fn(X_minus)) / (2 * eps)
                    var[:, 0] += (df_dx * sigma) ** 2

        return var

    def fit(self, X: Tensor, y: Tensor) -> None:
        """Store observations (physics model is not fitted, but we track data).

        Also refreshes the noise estimate from model-vs-data residuals.
        """
        self._train_X = X
        self._train_y = y

        # A sample std needs at least two residuals; with fewer, keep the
        # prior noise level instead of propagating NaN into every variance.
        if X is not None and y is not None and y.numel() > 1:
            with torch.no_grad():
                pred = self.physics_fn(X)
                residuals = y.squeeze() - pred
                self.noise_std = float(residuals.std())

    def posterior(self, X: Tensor):
        """Return a simple posterior-like object for BoTorch compatibility."""
        mean, var = self.predict(X)
        return _SimplePosterior(mean, var)


class _SimplePosterior:
    """Minimal posterior wrapper for BoTorch compatibility."""

    def __init__(self, mean: Tensor, variance: Tensor):
        self._mean = mean
        self._variance = variance

    @property
    def mean(self) -> Tensor:
        return self._mean

    @property
    def variance(self) -> Tensor:
        return self._variance

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        """Draw reparameterized samples from the posterior."""
        std = self._variance.sqrt()
        eps = torch.randn(*sample_shape, *self._mean.shape, device=self._mean.device)
        return self._mean + std * eps