ravimohan19's picture
Upload models/base.py with huggingface_hub
058f8a4 verified
"""Abstract base classes for surrogate models."""
from abc import ABC, abstractmethod
from typing import Optional, Tuple
import torch
from torch import Tensor
class SurrogateModel(ABC):
    """Abstract base class for all surrogate models in the platform.

    A surrogate model provides predictions (mean + uncertainty) and can be
    updated with new observations. Concrete subclasses must implement
    ``predict``, ``fit`` and ``posterior``; ``condition_on_observations`` is
    optional and raises ``NotImplementedError`` unless overridden.
    """

    @abstractmethod
    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Return posterior mean and variance at input locations X.

        Args:
            X: Input tensor of shape (n, d).

        Returns:
            mean: Predicted mean of shape (n, 1).
            variance: Predicted variance of shape (n, 1).
        """

    @abstractmethod
    def fit(self, X: Tensor, y: Tensor) -> None:
        """Fit/update the surrogate model with observed data.

        Args:
            X: Training inputs of shape (n, d).
            y: Training targets of shape (n, 1).
        """

    @abstractmethod
    def posterior(self, X: Tensor):
        """Return the full posterior distribution at X (for BoTorch compatibility).

        Args:
            X: Input tensor of shape (batch, n, d).

        Returns:
            A posterior object. Its concrete type is chosen by the subclass;
            this base class does not constrain it.
        """

    def condition_on_observations(self, X: Tensor, y: Tensor) -> "SurrogateModel":
        """Return a new model conditioned on additional observations.

        The base implementation does NOT support conditioning. It always
        raises ``NotImplementedError``. Subclasses can override this method
        for fantasy-based conditioning. Models without an override should be
        updated through ``fit()`` instead.

        Args:
            X: Additional inputs to condition on.
            y: Targets corresponding to ``X``.

        Returns:
            A model conditioned on the extra data (only in overriding subclasses).

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError(
            "Fantasy conditioning not implemented for this model. Use fit() instead."
        )