ravimohan19 commited on
Commit
f7040fb
·
verified ·
1 Parent(s): 814f98a

Upload models/hybrid_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/hybrid_model.py +220 -0
models/hybrid_model.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hybrid surrogate model combining physics models with data-driven GP."""
2
+
3
+ from typing import Callable, Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from physics_informed_bo.models.base import SurrogateModel
9
+ from physics_informed_bo.models.gp_model import PhysicsInformedGP, StandardGP
10
+ from physics_informed_bo.models.physics_model import PhysicsModel
11
+
12
+
13
+ class HybridSurrogate(SurrogateModel):
14
+ """Hybrid model that combines a physics model with a GP.
15
+
16
+ Provides multiple operating modes:
17
+
18
+ 1. **Physics-as-mean** (default): Physics function is the GP mean,
19
+ GP learns the residual/discrepancy.
20
+ 2. **Weighted ensemble**: Weighted combination of physics prediction
21
+ and GP prediction, with weight adapting based on data.
22
+ 3. **Physics-only**: Pure physics model when no data is available.
23
+ 4. **GP-only**: Pure GP when physics model is unreliable.
24
+
25
+ The model automatically transitions from physics-only → hybrid → GP-dominant
26
+ as more experimental data becomes available.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ physics_fn: Callable[[Tensor], Tensor],
32
+ mode: str = "physics_as_mean",
33
+ kernel: str = "matern",
34
+ noise_variance: float = 0.01,
35
+ learn_noise: bool = True,
36
+ initial_physics_weight: float = 1.0,
37
+ adapt_weight: bool = True,
38
+ device: str = "cpu",
39
+ dtype: torch.dtype = torch.float64,
40
+ ):
41
+ """
42
+ Args:
43
+ physics_fn: Physics model callable. Takes (n, d) tensor, returns (n,) tensor.
44
+ mode: One of 'physics_as_mean', 'weighted_ensemble', 'physics_only', 'gp_only'.
45
+ kernel: GP kernel type ('rbf' or 'matern').
46
+ noise_variance: Initial observation noise variance.
47
+ learn_noise: Whether to learn noise variance from data.
48
+ initial_physics_weight: Starting weight for physics model (0 to 1).
49
+ adapt_weight: Auto-adapt physics weight based on residual analysis.
50
+ device: Torch device.
51
+ dtype: Torch dtype.
52
+ """
53
+ self.physics_fn = physics_fn
54
+ self.mode = mode
55
+ self.kernel = kernel
56
+ self.noise_variance = noise_variance
57
+ self.learn_noise = learn_noise
58
+ self.physics_weight = initial_physics_weight
59
+ self.adapt_weight = adapt_weight
60
+ self.device = torch.device(device)
61
+ self.dtype = dtype
62
+
63
+ # Internal models
64
+ self._physics_model = PhysicsModel(physics_fn, noise_std=noise_variance**0.5)
65
+ self._gp_model: Optional[PhysicsInformedGP] = None
66
+ self._standard_gp: Optional[StandardGP] = None
67
+ self._is_fitted = False
68
+ self._train_X = None
69
+ self._train_y = None
70
+
71
+ def fit(
72
+ self,
73
+ X: Tensor,
74
+ y: Tensor,
75
+ training_iterations: int = 200,
76
+ lr: float = 0.05,
77
+ ) -> None:
78
+ """Fit the hybrid model.
79
+
80
+ If mode is 'physics_as_mean', fits a PhysicsInformedGP.
81
+ If mode is 'weighted_ensemble', fits both physics and standard GP,
82
+ then determines optimal weighting.
83
+ """
84
+ X = X.to(device=self.device, dtype=self.dtype)
85
+ y = y.to(device=self.device, dtype=self.dtype)
86
+ if y.dim() == 1:
87
+ y = y.unsqueeze(-1)
88
+
89
+ self._train_X = X
90
+ self._train_y = y
91
+
92
+ if self.mode == "physics_only":
93
+ self._physics_model.fit(X, y)
94
+
95
+ elif self.mode == "physics_as_mean":
96
+ self._gp_model = PhysicsInformedGP(
97
+ physics_fn=self.physics_fn,
98
+ kernel=self.kernel,
99
+ noise_variance=self.noise_variance,
100
+ learn_noise=self.learn_noise,
101
+ device=str(self.device),
102
+ dtype=self.dtype,
103
+ )
104
+ self._gp_model.fit(X, y, training_iterations, lr)
105
+
106
+ elif self.mode == "weighted_ensemble":
107
+ # Fit physics-informed GP
108
+ self._gp_model = PhysicsInformedGP(
109
+ physics_fn=self.physics_fn,
110
+ kernel=self.kernel,
111
+ noise_variance=self.noise_variance,
112
+ learn_noise=self.learn_noise,
113
+ device=str(self.device),
114
+ dtype=self.dtype,
115
+ )
116
+ self._gp_model.fit(X, y, training_iterations, lr)
117
+
118
+ # Fit standard GP
119
+ self._standard_gp = StandardGP(
120
+ kernel=self.kernel,
121
+ noise_variance=self.noise_variance,
122
+ learn_noise=self.learn_noise,
123
+ device=str(self.device),
124
+ dtype=self.dtype,
125
+ )
126
+ self._standard_gp.fit(X, y, training_iterations, lr)
127
+
128
+ if self.adapt_weight:
129
+ self._adapt_physics_weight(X, y)
130
+
131
+ elif self.mode == "gp_only":
132
+ self._standard_gp = StandardGP(
133
+ kernel=self.kernel,
134
+ noise_variance=self.noise_variance,
135
+ learn_noise=self.learn_noise,
136
+ device=str(self.device),
137
+ dtype=self.dtype,
138
+ )
139
+ self._standard_gp.fit(X, y, training_iterations, lr)
140
+
141
+ self._is_fitted = True
142
+
143
+ def _adapt_physics_weight(self, X: Tensor, y: Tensor) -> None:
144
+ """Adapt physics weight based on LOO cross-validation of residuals.
145
+
146
+ If physics model is accurate (small residuals), keep high weight.
147
+ If physics model is inaccurate, reduce weight toward pure GP.
148
+ """
149
+ with torch.no_grad():
150
+ physics_pred = self.physics_fn(X)
151
+ residuals = y.squeeze() - physics_pred
152
+ relative_error = (residuals.abs() / (y.squeeze().abs() + 1e-8)).mean()
153
+
154
+ # Sigmoid mapping: high error → low physics weight
155
+ self.physics_weight = float(torch.sigmoid(-5.0 * (relative_error - 0.5)))
156
+
157
+ def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
158
+ X = X.to(device=self.device, dtype=self.dtype)
159
+
160
+ if self.mode == "physics_only" or not self._is_fitted:
161
+ return self._physics_model.predict(X)
162
+
163
+ elif self.mode == "physics_as_mean":
164
+ return self._gp_model.predict(X)
165
+
166
+ elif self.mode == "weighted_ensemble":
167
+ gp_mean, gp_var = self._gp_model.predict(X)
168
+ std_mean, std_var = self._standard_gp.predict(X)
169
+ w = self.physics_weight
170
+ mean = w * gp_mean + (1 - w) * std_mean
171
+ variance = w**2 * gp_var + (1 - w) ** 2 * std_var
172
+ return mean, variance
173
+
174
+ elif self.mode == "gp_only":
175
+ return self._standard_gp.predict(X)
176
+
177
+ def posterior(self, X: Tensor):
178
+ if self.mode in ("physics_as_mean", "weighted_ensemble") and self._gp_model:
179
+ return self._gp_model.posterior(X)
180
+ elif self.mode == "gp_only" and self._standard_gp:
181
+ return self._standard_gp.posterior(X)
182
+ else:
183
+ return self._physics_model.posterior(X)
184
+
185
+ @property
186
+ def model(self):
187
+ """Return the primary BoTorch-compatible model for optimization."""
188
+ if self._gp_model is not None:
189
+ return self._gp_model.model
190
+ elif self._standard_gp is not None:
191
+ return self._standard_gp.model
192
+ return None
193
+
194
+ def get_physics_residuals(self) -> Optional[Tensor]:
195
+ """Return residuals between physics predictions and training data."""
196
+ if self._train_X is None or self._train_y is None:
197
+ return None
198
+ with torch.no_grad():
199
+ physics_pred = self.physics_fn(self._train_X)
200
+ return self._train_y.squeeze() - physics_pred
201
+
202
+ def physics_model_quality(self) -> Dict:
203
+ """Assess how well the physics model matches the data."""
204
+ if self._train_X is None:
205
+ return {"status": "no_data"}
206
+
207
+ residuals = self.get_physics_residuals()
208
+ rmse = float((residuals**2).mean().sqrt())
209
+ mae = float(residuals.abs().mean())
210
+ r2 = float(
211
+ 1 - (residuals**2).sum() / ((self._train_y.squeeze() - self._train_y.mean()) ** 2).sum()
212
+ )
213
+
214
+ return {
215
+ "rmse": rmse,
216
+ "mae": mae,
217
+ "r2": r2,
218
+ "physics_weight": self.physics_weight,
219
+ "n_observations": len(self._train_X),
220
+ }