"""Abstract base classes for surrogate models.""" from abc import ABC, abstractmethod from typing import Optional, Tuple import torch from torch import Tensor class SurrogateModel(ABC): """Abstract base class for all surrogate models in the platform. A surrogate model provides predictions (mean + uncertainty) and can be updated with new observations. """ @abstractmethod def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]: """Return posterior mean and variance at input locations X. Args: X: Input tensor of shape (n, d). Returns: mean: Predicted mean of shape (n, 1). variance: Predicted variance of shape (n, 1). """ @abstractmethod def fit(self, X: Tensor, y: Tensor) -> None: """Fit/update the surrogate model with observed data. Args: X: Training inputs of shape (n, d). y: Training targets of shape (n, 1). """ @abstractmethod def posterior(self, X: Tensor): """Return the full posterior distribution at X (for BoTorch compatibility). Args: X: Input tensor of shape (batch, n, d). """ def condition_on_observations(self, X: Tensor, y: Tensor) -> "SurrogateModel": """Return a new model conditioned on additional observations. Default implementation refits the model. Subclasses can override for fantasy-based conditioning. """ raise NotImplementedError( "Fantasy conditioning not implemented for this model. Use fit() instead." )