| """Predictor implementation wrapping the TimeSeriesModel for GIFT-Eval.""" |
|
|
| import logging |
| from collections.abc import Iterator |
|
|
| import numpy as np |
| import torch |
| import yaml |
| from gluonts.model.forecast import QuantileForecast |
| from gluonts.model.predictor import Predictor |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
| from src.data.containers import BatchTimeSeriesContainer |
| from src.data.frequency import parse_frequency |
| from src.data.scalers import RobustScaler |
| from src.models.model import TimeSeriesModel |
| from src.utils.utils import device |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class TimeSeriesPredictor(Predictor):
    """GIFT-Eval predictor wrapping a trained :class:`TimeSeriesModel`.

    Implements the gluonts ``Predictor`` interface. Instances can be built
    directly from an in-memory model (:meth:`from_model`) or from a
    checkpoint + YAML config on disk (:meth:`from_paths`), and reused across
    datasets via :meth:`set_dataset_context`.
    """

    def __init__(
        self,
        model: TimeSeriesModel,
        config: dict,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> None:
        """Initialize the predictor.

        Args:
            model: Trained model; a DDP-wrapped model is unwrapped here.
            config: Experiment config; ``config["TimeSeriesModel"]`` supplies
                the scaler type and epsilon used at train time.
            ds_prediction_length: Forecast horizon for the current dataset.
            ds_freq: Frequency string of the current dataset (e.g. ``"H"``).
            batch_size: Number of series per forward pass.
            max_context_length: If set, history is truncated to the most
                recent ``max_context_length`` steps before inference.
            debug: Flag stored for downstream debug behavior.

        Raises:
            ValueError: If the configured scaler type is unsupported.
        """
        self.ds_prediction_length = ds_prediction_length
        self.ds_freq = ds_freq
        self.batch_size = batch_size
        self.max_context_length = max_context_length
        self.debug = debug

        # Unwrap DDP so attribute access (e.g. ``self.model.quantiles``)
        # resolves on the underlying module.
        self.model = model.module if isinstance(model, DDP) else model
        self.model.eval()
        self.config = config

        # Recreate the train-time scaler so model outputs can be mapped
        # back to the original data scale at forecast time.
        model_cfg = self.config.get("TimeSeriesModel", {})
        scaler_type = model_cfg.get("scaler", "custom_robust")
        epsilon = model_cfg.get("epsilon", 1e-3)
        if scaler_type == "custom_robust":
            self.scaler = RobustScaler(epsilon=epsilon)
        else:
            raise ValueError(f"Unsupported scaler type: {scaler_type}")

    def set_dataset_context(
        self,
        prediction_length: int | None = None,
        freq: str | None = None,
        batch_size: int | None = None,
        max_context_length: int | None = None,
    ) -> None:
        """Update lightweight dataset-specific attributes without reloading the model.

        Only the arguments that are not ``None`` are applied, so the
        predictor can be retargeted to a new dataset cheaply.
        """
        if prediction_length is not None:
            self.ds_prediction_length = prediction_length
        if freq is not None:
            self.ds_freq = freq
        if batch_size is not None:
            self.batch_size = batch_size
        if max_context_length is not None:
            self.max_context_length = max_context_length

    @classmethod
    def from_model(
        cls,
        model: TimeSeriesModel,
        config: dict,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> "TimeSeriesPredictor":
        """Build a predictor from an already-instantiated model."""
        return cls(
            model=model,
            config=config,
            ds_prediction_length=ds_prediction_length,
            ds_freq=ds_freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
            debug=debug,
        )

    @classmethod
    def from_paths(
        cls,
        model_path: str,
        config_path: str,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> "TimeSeriesPredictor":
        """Build a predictor by loading a YAML config and a checkpoint from disk."""
        with open(config_path) as f:
            config = yaml.safe_load(f)
        model = cls._load_model_from_path(config=config, model_path=model_path)
        return cls(
            model=model,
            config=config,
            ds_prediction_length=ds_prediction_length,
            ds_freq=ds_freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
            debug=debug,
        )

    @staticmethod
    def _load_model_from_path(config: dict, model_path: str) -> TimeSeriesModel:
        """Instantiate the model from config and load checkpoint weights.

        Expects the checkpoint to contain a ``"model_state_dict"`` entry.
        Re-raises any failure after logging it.
        """
        try:
            model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])
            model.eval()
            logger.info(f"Successfully loaded model from {model_path}")
            return model
        except Exception as exc:
            logger.error(f"Failed to load model from {model_path}: {exc}")
            raise

    def predict(self, test_data_input) -> list[QuantileForecast]:
        """Generate one forecast per input series, preserving input order.

        Series are bucketed by effective (possibly truncated) context length
        so each forward pass sees a rectangular batch; results are then
        scattered back into the caller's original ordering.
        """
        # Materialize generators so the input can be sized and indexed.
        if hasattr(test_data_input, "__iter__") and not isinstance(test_data_input, list):
            test_data_input = list(test_data_input)
        logger.debug(f"Processing {len(test_data_input)} time series")

        def _effective_length(entry) -> int:
            # Context length actually fed to the model (after optional
            # truncation to ``max_context_length``).
            target = entry["target"]
            if target.ndim == 1:
                seq_len = len(target)
            else:
                # Multivariate targets arrive as (channels, time).
                seq_len = target.shape[1]
            if self.max_context_length is not None:
                seq_len = min(seq_len, self.max_context_length)
            return seq_len

        # Bucket (original_index, entry) pairs by effective length so every
        # batch stacks into a single rectangular tensor.
        length_to_items: dict[int, list[tuple[int, object]]] = {}
        for idx, entry in enumerate(test_data_input):
            length_to_items.setdefault(_effective_length(entry), []).append((idx, entry))

        ordered_results: list[QuantileForecast | None] = [None] * len(test_data_input)

        for items in length_to_items.values():
            for i in range(0, len(items), self.batch_size):
                chunk = items[i : i + self.batch_size]
                entries = [entry for (_orig_idx, entry) in chunk]
                batch_forecasts = self._predict_batch(entries)
                # Scatter batch results back to the caller's ordering.
                for forecast_idx, (orig_idx, _entry) in enumerate(chunk):
                    ordered_results[orig_idx] = batch_forecasts[forecast_idx]

        return ordered_results

    def _predict_batch(self, test_data_batch: list) -> list[QuantileForecast]:
        """Generate predictions for a batch of equal-length time series."""
        logger.debug(f"Processing batch of size: {len(test_data_batch)}")

        try:
            batch_container = self._convert_to_batch_container(test_data_batch)

            # ``device`` may be a torch.device or a plain string; autocast is
            # only enabled on CUDA (bfloat16 autocast on CPU is not desired).
            if isinstance(device, torch.device):
                device_type = device.type
            else:
                device_type = "cuda" if "cuda" in str(device).lower() else "cpu"
            enable_autocast = device_type == "cuda"

            with torch.autocast(
                device_type=device_type,
                dtype=torch.bfloat16,
                enabled=enable_autocast,
            ):
                with torch.no_grad():
                    model_output = self.model(batch_container, drop_enc_allow=False)

            forecasts = self._convert_to_forecasts(model_output, test_data_batch, batch_container)

            logger.debug(f"Generated {len(forecasts)} forecasts")
            return forecasts
        except Exception as exc:
            logger.error(f"Error in batch prediction: {exc}")
            raise

    def _convert_to_batch_container(self, test_data_batch: list) -> BatchTimeSeriesContainer:
        """Convert gluonts test data to BatchTimeSeriesContainer.

        All entries in ``test_data_batch`` must share the same effective
        length (guaranteed by the bucketing in :meth:`predict`) so that
        ``np.stack`` produces a rectangular (batch, time, channels) array.
        """
        batch_size = len(test_data_batch)
        history_values_list = []
        start_dates = []
        frequencies = []

        for entry in test_data_batch:
            target = entry["target"]

            # Normalize to (time, channels): univariate gets a channel axis,
            # multivariate (channels, time) is transposed.
            if target.ndim == 1:
                target = target.reshape(-1, 1)
            else:
                target = target.T

            # Keep only the most recent context window if a cap is set.
            if self.max_context_length is not None and len(target) > self.max_context_length:
                target = target[-self.max_context_length :]

            history_values_list.append(target)
            start_dates.append(entry["start"].to_timestamp().to_datetime64())
            frequencies.append(parse_frequency(entry["freq"]))

        history_values_np = np.stack(history_values_list, axis=0)
        num_channels = history_values_np.shape[2]

        history_values = torch.tensor(history_values_np, dtype=torch.float32, device=device)

        # Placeholder future values; the model only needs the horizon shape.
        future_values = torch.zeros(
            (batch_size, self.ds_prediction_length, num_channels),
            dtype=torch.float32,
            device=device,
        )

        return BatchTimeSeriesContainer(
            history_values=history_values,
            future_values=future_values,
            start=start_dates,
            frequency=frequencies,
        )

    def _convert_to_forecasts(
        self,
        model_output: dict,
        test_data_batch: list,
        batch_container: BatchTimeSeriesContainer,
    ) -> list[QuantileForecast]:
        """Convert model predictions to QuantileForecast objects.

        A 4-D ``result`` tensor is treated as quantile output — the index
        gymnastics below imply a (batch, horizon, channels, quantiles)
        layout (assumption; confirm against the model) — while a 3-D tensor
        is treated as a point forecast reported at the 0.5 key.
        """
        predictions = model_output["result"]
        scale_statistics = model_output["scale_statistics"]

        # Undo train-time scaling once; branch only on output layout.
        predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)
        is_quantile = predictions.ndim == 4
        quantile_levels = self.model.quantiles if is_quantile else [0.5]

        forecasts: list[QuantileForecast] = []
        for idx, entry in enumerate(test_data_batch):
            # The forecast window starts right after the FULL observed
            # target. Using the (possibly truncated) batch context length
            # here would shift forecasts backwards in time whenever
            # ``max_context_length`` truncates the history — and by varying
            # amounts within one batch, since bucketing uses the capped
            # length.
            target = entry["target"]
            full_history_length = len(target) if target.ndim == 1 else target.shape[1]
            forecast_start = entry["start"] + full_history_length

            pred_array = predictions_unscaled[idx].cpu().numpy()

            if is_quantile:
                # (horizon, channels, quantiles) -> gluonts expects the
                # quantile axis first.
                if pred_array.shape[1] == 1:
                    forecast_arrays = pred_array.squeeze(1).T
                else:
                    forecast_arrays = pred_array.transpose(2, 0, 1)

                forecast = QuantileForecast(
                    forecast_arrays=forecast_arrays,
                    forecast_keys=[str(q) for q in quantile_levels],
                    start_date=forecast_start,
                )
            else:
                # Point forecast: wrap in a single "0.5" quantile key.
                if pred_array.shape[1] == 1:
                    forecast_arrays = pred_array.squeeze(1).reshape(1, -1)
                else:
                    forecast_arrays = pred_array.reshape(1, *pred_array.shape)

                forecast = QuantileForecast(
                    forecast_arrays=forecast_arrays,
                    forecast_keys=["0.5"],
                    start_date=forecast_start,
                )

            forecasts.append(forecast)

        return forecasts
|
|
|
|
# Explicit public API of this module.
__all__ = ["TimeSeriesPredictor"]
|
|