| | import logging |
| | import random |
| | from collections.abc import Iterator |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| |
|
| | from src.data.batch_composer import BatchComposer, ComposedDataset |
| | from src.data.containers import BatchTimeSeriesContainer |
| | from src.data.frequency import parse_frequency |
| | from src.gift_eval.constants import ALL_DATASETS |
| | from src.gift_eval.data import Dataset as GiftEvalDataset |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| | class GiftEvalDataLoader: |
| | """ |
| | Data loader for GIFT-eval datasets, converting them to BatchTimeSeriesContainer format. |
| | Supports both training and validation modes. |
| | """ |
| |
|
| | TERMS = ["short", "medium", "long"] |
| |
|
| | def __init__( |
| | self, |
| | mode: str = "train", |
| | batch_size: int = 32, |
| | device: torch.device | None = None, |
| | shuffle: bool = True, |
| | to_univariate: bool = False, |
| | max_context_length: int | None = None, |
| | max_windows: int = 20, |
| | skip_datasets_with_nans: bool = False, |
| | datasets_to_use: list[str] | None = None, |
| | dataset_storage_path: str | None = None, |
| | ): |
| | """ |
| | Initialize GIFT-eval data loader. |
| | |
| | Args: |
| | mode: Either "train" or "validation" |
| | batch_size: Number of samples per batch |
| | device: Device to load data to |
| | shuffle: Whether to shuffle data |
| | to_univariate: Whether to convert multivariate data to multiple univariate series |
| | max_context_length: Optional maximum total window length (context + forecast) to prevent memory issues |
| | max_windows: Number of windows to use for training/validation |
| | skip_datasets_with_nans: Whether to skip datasets/series that contain NaN values |
| | datasets_to_use: Optional list of dataset names to use. If None, uses all available datasets |
| | dataset_storage_path: Path on disk where GIFT-eval HuggingFace datasets are stored |
| | """ |
| | |
| | if datasets_to_use is not None and len(datasets_to_use) > 0: |
| | |
| | invalid_datasets = [ds for ds in datasets_to_use if ds not in ALL_DATASETS] |
| | if invalid_datasets: |
| | logger.warning(f"Invalid datasets requested: {invalid_datasets}") |
| | logger.warning(f"Available datasets: {ALL_DATASETS}") |
| | |
| | self.dataset_names = [ds for ds in datasets_to_use if ds in ALL_DATASETS] |
| | else: |
| | self.dataset_names = datasets_to_use |
| | else: |
| | self.dataset_names = ALL_DATASETS |
| |
|
| | |
| | if datasets_to_use is not None and len(datasets_to_use) > 0: |
| | logger.info(f"Using subset of datasets: {len(self.dataset_names)}/{len(ALL_DATASETS)} datasets") |
| | logger.info(f"Selected datasets: {self.dataset_names}") |
| | else: |
| | logger.info(f"Using all available datasets: {len(self.dataset_names)} datasets") |
| |
|
| | self.terms = self.TERMS |
| | self.mode = mode |
| | self.batch_size = batch_size |
| | self.device = device |
| | self.shuffle = shuffle |
| | self.to_univariate = to_univariate |
| | self.max_context_length = max_context_length |
| | self.skip_datasets_with_nans = skip_datasets_with_nans |
| |
|
| | |
| | self.max_windows = max_windows |
| | self.dataset_storage_path = dataset_storage_path |
| |
|
| | |
| | self._load_datasets() |
| |
|
| | |
| | self._current_idx = 0 |
| | self._epoch_data = [] |
| | self._prepare_epoch_data() |
| |
|
| | def _load_datasets(self) -> None: |
| | """Load all specified GIFT-eval datasets.""" |
| | self.datasets = {} |
| | self.dataset_prediction_lengths = {} |
| |
|
| | for dataset_name in self.dataset_names: |
| | if dataset_name.startswith("m4_"): |
| | max_windows = 1 |
| | else: |
| | max_windows = self.max_windows |
| | try: |
| | |
| | |
| | temp_dataset = GiftEvalDataset( |
| | name=dataset_name, |
| | term=self.terms[0], |
| | to_univariate=False, |
| | max_windows=max_windows, |
| | storage_path=self.dataset_storage_path, |
| | ) |
| |
|
| | |
| | to_univariate = self.to_univariate and temp_dataset.target_dim > 1 |
| |
|
| | |
| | for term in self.terms: |
| | dataset_key = f"{dataset_name}_{term}" |
| | dataset = GiftEvalDataset( |
| | name=dataset_name, |
| | term=term, |
| | to_univariate=to_univariate, |
| | max_windows=max_windows, |
| | storage_path=self.dataset_storage_path, |
| | ) |
| |
|
| | self.datasets[dataset_key] = dataset |
| | self.dataset_prediction_lengths[dataset_key] = dataset.prediction_length |
| |
|
| | logger.info( |
| | f"Loaded {dataset_key} - prediction_length: {dataset.prediction_length}, " |
| | f"frequency: {dataset.freq}, target_dim: {dataset.target_dim}, " |
| | f"min_length: {dataset._min_series_length}, windows: {dataset.windows}" |
| | ) |
| |
|
| | except Exception as e: |
| | logger.warning(f"Failed to load dataset {dataset_name}: {str(e)}") |
| | continue |
| |
|
| | def _contains_nan(self, data_entry: dict) -> bool: |
| | """Check if a data entry contains NaN values.""" |
| | target = data_entry.get("target") |
| | if target is None: |
| | return False |
| |
|
| | |
| | try: |
| | target_np = np.asarray(target, dtype=np.float32) |
| | return np.isnan(target_np).any() |
| | except Exception: |
| | logger.warning("NaN check: failed to coerce target to float32; skipping entry") |
| | return True |
| |
|
| | def _convert_to_container( |
| | self, data_entries: list[dict], prediction_length: int, dataset_freq: str |
| | ) -> BatchTimeSeriesContainer: |
| | """Convert a batch of data entries to BatchTimeSeriesContainer format with fixed future length.""" |
| | batch_size = len(data_entries) |
| | max_history_len = 0 |
| |
|
| | |
| | for entry in data_entries: |
| | target = np.asarray(entry["target"], dtype=np.float32) |
| | if target.ndim == 1: |
| | target = target.reshape(1, -1) |
| |
|
| | _, seq_len = target.shape |
| |
|
| | |
| | effective_max_context = self.max_context_length if self.max_context_length is not None else seq_len |
| | if seq_len > effective_max_context: |
| | seq_len = effective_max_context |
| |
|
| | |
| | history_len = max(0, min(seq_len, effective_max_context) - prediction_length) |
| | max_history_len = max(max_history_len, history_len) |
| |
|
| | |
| | first_target = np.asarray(data_entries[0]["target"], dtype=np.float32) |
| | if first_target.ndim == 1: |
| | |
| | first_target = first_target.reshape(1, -1) |
| | num_channels = first_target.shape[0] |
| |
|
| | |
| | history_values = np.full((batch_size, max_history_len, num_channels), np.nan, dtype=np.float32) |
| | future_values = np.full((batch_size, prediction_length, num_channels), np.nan, dtype=np.float32) |
| | history_mask = np.zeros((batch_size, max_history_len), dtype=bool) |
| |
|
| | |
| | for i, entry in enumerate(data_entries): |
| | target = np.asarray(entry["target"], dtype=np.float32) |
| | if target.ndim == 1: |
| | target = target.reshape(1, -1) |
| |
|
| | |
| | full_seq_len = target.shape[1] |
| | total_len_allowed = self.max_context_length if self.max_context_length is not None else full_seq_len |
| | total_len_for_entry = min(full_seq_len, total_len_allowed) |
| |
|
| | if total_len_for_entry < prediction_length + 1: |
| | |
| | raise ValueError("Entry too short after max_context_length truncation to form history+future window") |
| |
|
| | truncated = target[:, -total_len_for_entry:] |
| | cur_history_len = total_len_for_entry - prediction_length |
| |
|
| | hist = truncated[:, :cur_history_len] |
| | fut = truncated[:, cur_history_len : cur_history_len + prediction_length] |
| |
|
| | |
| | history_values[i, :cur_history_len, :] = hist.T |
| | future_values[i, :, :] = fut.T |
| | history_mask[i, :cur_history_len] = True |
| |
|
| | |
| | start_timestamp = data_entries[0]["start"] |
| | if hasattr(start_timestamp, "to_timestamp"): |
| | start_numpy = start_timestamp.to_timestamp().to_numpy() |
| | else: |
| | start_numpy = pd.Timestamp(start_timestamp).to_numpy() |
| | start_list = [start_numpy for _ in range(batch_size)] |
| |
|
| | |
| | frequency_enum = parse_frequency(dataset_freq) |
| | frequency_list = [frequency_enum for _ in range(batch_size)] |
| |
|
| | |
| | return BatchTimeSeriesContainer( |
| | history_values=torch.tensor(history_values, dtype=torch.float32), |
| | future_values=torch.tensor(future_values, dtype=torch.float32), |
| | start=start_list, |
| | frequency=frequency_list, |
| | history_mask=torch.tensor(history_mask, dtype=torch.bool) if self.mode == "train" else None, |
| | ) |
| |
|
| | def _prepare_epoch_data(self) -> None: |
| | """Prepare all batches for one epoch.""" |
| | self._epoch_data = [] |
| |
|
| | for dataset_key, dataset in self.datasets.items(): |
| | try: |
| | |
| | if self.mode == "train": |
| | data = dataset.training_dataset |
| | else: |
| | data = dataset.validation_dataset |
| |
|
| | |
| | valid_entries = [] |
| | dataset_freq = dataset.freq |
| | prediction_length = self.dataset_prediction_lengths[dataset_key] |
| |
|
| | for entry in data: |
| | |
| | if self.skip_datasets_with_nans and self._contains_nan(entry): |
| | continue |
| |
|
| | |
| | target = np.asarray(entry["target"]) |
| | if target.ndim == 1: |
| | seq_len = len(target) |
| | else: |
| | seq_len = target.shape[1] |
| |
|
| | |
| | if self.mode == "train" and seq_len < prediction_length + 1: |
| | continue |
| |
|
| | valid_entries.append(entry) |
| |
|
| | if not valid_entries: |
| | logger.warning(f"No valid entries found for {dataset_key}") |
| | continue |
| |
|
| | |
| | for i in range(0, len(valid_entries), self.batch_size): |
| | batch_entries = valid_entries[i : i + self.batch_size] |
| | try: |
| | batch_container = self._convert_to_container(batch_entries, prediction_length, dataset_freq) |
| | self._epoch_data.append((dataset_key, batch_container)) |
| | except Exception as e: |
| | logger.warning(f"Failed to create batch for {dataset_key}: {str(e)}") |
| | continue |
| |
|
| | except Exception as e: |
| | logger.warning( |
| | f"Failed to process dataset {dataset_key}: {str(e)}. " |
| | f"Dataset may be too short for the required offset." |
| | ) |
| | continue |
| |
|
| | |
| | if self.mode == "train" and self.shuffle: |
| | random.shuffle(self._epoch_data) |
| |
|
| | logger.info(f"Prepared {len(self._epoch_data)} batches for {self.mode} mode") |
| |
|
| | def __iter__(self) -> Iterator[BatchTimeSeriesContainer]: |
| | """Iterate through batches for one epoch.""" |
| | |
| | self._current_idx = 0 |
| |
|
| | |
| | if self.mode == "train" and self.shuffle: |
| | random.shuffle(self._epoch_data) |
| |
|
| | return self |
| |
|
| | def __next__(self) -> BatchTimeSeriesContainer: |
| | """Get next batch.""" |
| | if not self._epoch_data: |
| | raise StopIteration("No valid data available") |
| |
|
| | |
| | if self._current_idx >= len(self._epoch_data): |
| | raise StopIteration |
| |
|
| | |
| | dataset_key, batch = self._epoch_data[self._current_idx] |
| | self._current_idx += 1 |
| |
|
| | |
| | if self.device is not None: |
| | batch.to_device(self.device) |
| |
|
| | return batch |
| |
|
| | def __len__(self) -> int: |
| | """Return number of batches per epoch.""" |
| | return len(self._epoch_data) |
| |
|
| |
|
class CyclicGiftEvalDataLoader:
    """
    Wrapper for GiftEvalDataLoader that provides cycling behavior for training.
    This allows training for a fixed number of iterations per epoch, cycling through
    the available data as needed.
    """

    def __init__(self, base_loader: "GiftEvalDataLoader", num_iterations_per_epoch: int):
        """
        Initialize the cyclic data loader.

        Args:
            base_loader: The underlying GiftEvalDataLoader to cycle over
            num_iterations_per_epoch: Number of iterations to run per epoch
        """
        self.base_loader = base_loader
        self.num_iterations_per_epoch = num_iterations_per_epoch
        # Mirror the base loader's metadata for convenient access by callers.
        self.dataset_names = base_loader.dataset_names
        self.device = base_loader.device

    def __iter__(self) -> "Iterator[BatchTimeSeriesContainer]":
        """Start a fresh epoch of exactly ``num_iterations_per_epoch`` batches."""
        self._current_iteration = 0
        self._base_iter = iter(self.base_loader)
        return self

    def _pull(self) -> "BatchTimeSeriesContainer":
        """Fetch one batch, restarting the base iterator once it is exhausted."""
        try:
            return next(self._base_iter)
        except StopIteration:
            self._base_iter = iter(self.base_loader)
            return next(self._base_iter)

    def __next__(self) -> "BatchTimeSeriesContainer":
        """Get next batch, cycling through the base loader as needed."""
        if self._current_iteration >= self.num_iterations_per_epoch:
            raise StopIteration
        batch = self._pull()
        self._current_iteration += 1
        return batch

    def __len__(self) -> int:
        """Return the configured number of iterations per epoch."""
        return self.num_iterations_per_epoch
| |
|
| |
|
def create_synthetic_dataloader(
    base_data_dir: str,
    batch_size: int = 128,
    num_batches_per_epoch: int = 1000,
    generator_proportions: dict[str, float] | None = None,
    mixed_batches: bool = True,
    augmentations: dict[str, bool] | None = None,
    augmentation_probabilities: dict[str, float] | None = None,
    device: torch.device | None = None,
    num_workers: int = 0,
    pin_memory: bool = True,
    global_seed: int = 42,
    nan_stats_path: str | None = None,
    nan_patterns_path: str | None = None,
    chosen_scaler_name: str | None = None,
) -> torch.utils.data.DataLoader:
    """
    Create a PyTorch DataLoader for training with saved generator batches.

    Args:
        base_data_dir: Base directory containing generator subdirectories
        batch_size: Size of each training batch
        num_batches_per_epoch: Number of batches per epoch
        generator_proportions: Dict mapping generator names to proportions
        mixed_batches: Whether to create mixed or uniform batches
        augmentations: Dict mapping augmentation names to booleans
        augmentation_probabilities: Dict mapping augmentation names to probabilities
        device: Target device
        num_workers: Number of DataLoader workers
        pin_memory: Whether to pin memory
        global_seed: Global random seed
        nan_stats_path: Path to nan stats file
        nan_patterns_path: Path to nan patterns file
        chosen_scaler_name: Name of the scaler used in training

    Returns:
        PyTorch DataLoader yielding one BatchTimeSeriesContainer per iteration
    """
    # The composer assembles pre-generated batches according to the configuration.
    batch_composer = BatchComposer(
        base_data_dir=base_data_dir,
        generator_proportions=generator_proportions,
        mixed_batches=mixed_batches,
        device=device,
        augmentations=augmentations,
        augmentation_probabilities=augmentation_probabilities,
        global_seed=global_seed,
        nan_stats_path=nan_stats_path,
        nan_patterns_path=nan_patterns_path,
        chosen_scaler_name=chosen_scaler_name,
    )

    composed_dataset = ComposedDataset(
        batch_composer=batch_composer,
        num_batches_per_epoch=num_batches_per_epoch,
        batch_size=batch_size,
    )

    def _unwrap_single(batch):
        """Custom collate function that returns a single BatchTimeSeriesContainer."""
        # Each dataset item is already a full batch; DataLoader wraps it in a
        # length-1 list, so just unwrap it.
        return batch[0]

    # batch_size=1 because each dataset item is itself a composed batch.
    loader = torch.utils.data.DataLoader(
        composed_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=_unwrap_single,
        drop_last=False,
    )

    logger.info(
        f"Created DataLoader with {len(composed_dataset)} batches per epoch, "
        f"batch_size={batch_size}, mixed_batches={mixed_batches}"
    )

    return loader
| |
|
| |
|
| | class SyntheticValidationDataset(torch.utils.data.Dataset): |
| | """ |
| | Fixed synthetic validation dataset that generates a small number of batches |
| | using the same composition approach as training data. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | base_data_dir: str, |
| | batch_size: int = 128, |
| | num_batches: int = 2, |
| | future_length: int = 512, |
| | generator_proportions: dict[str, float] | None = None, |
| | augmentations: dict[str, bool] | None = None, |
| | augmentation_probabilities: dict[str, float] | None = None, |
| | device: torch.device | None = None, |
| | global_seed: int = 42, |
| | chosen_scaler_name: str | None = None, |
| | nan_stats_path: str | None = None, |
| | nan_patterns_path: str | None = None, |
| | rank: int = 0, |
| | world_size: int = 1, |
| | ): |
| | """ |
| | Initialize the validation dataset. |
| | |
| | Args: |
| | base_data_dir: Base directory containing generator subdirectories |
| | batch_size: Size of each validation batch |
| | num_batches: Number of validation batches to generate (1 or 2) |
| | generator_proportions: Dict mapping generator names to proportions |
| | device: Device to load tensors to |
| | global_seed: Global random seed |
| | chosen_scaler_name: Name of the scaler that used in training |
| | """ |
| | self.batch_size = batch_size |
| | self.num_batches = num_batches |
| | self.device = device |
| |
|
| | |
| | val_augmentations = dict(augmentations or {}) |
| | val_augmentations["length_shortening"] = False |
| |
|
| | self.batch_composer = BatchComposer( |
| | base_data_dir=base_data_dir, |
| | generator_proportions=generator_proportions, |
| | mixed_batches=True, |
| | device=device, |
| | global_seed=global_seed + 999999, |
| | augmentations=val_augmentations, |
| | augmentation_probabilities=augmentation_probabilities, |
| | nan_stats_path=nan_stats_path, |
| | nan_patterns_path=nan_patterns_path, |
| | chosen_scaler_name=chosen_scaler_name, |
| | rank=rank, |
| | world_size=world_size, |
| | ) |
| |
|
| | |
| | self.validation_batches = [] |
| | for i in range(num_batches): |
| | batch, _ = self.batch_composer.create_batch( |
| | batch_size=batch_size, |
| | future_length=future_length, |
| | seed=global_seed + 999999 + i, |
| | ) |
| | self.validation_batches.append(batch) |
| |
|
| | logger.info(f"Created {num_batches} fixed validation batches with batch_size={batch_size}") |
| |
|
| | def __len__(self) -> int: |
| | return self.num_batches |
| |
|
| | def __getitem__(self, idx: int) -> BatchTimeSeriesContainer: |
| | """ |
| | Get a pre-generated validation batch by index. |
| | |
| | Args: |
| | idx: Batch index |
| | |
| | Returns: |
| | BatchTimeSeriesContainer |
| | """ |
| | if idx >= len(self.validation_batches): |
| | raise IndexError(f"Batch index {idx} out of range") |
| |
|
| | batch = self.validation_batches[idx] |
| |
|
| | |
| | if self.device is not None: |
| | batch.to_device(self.device) |
| |
|
| | return batch |
| |
|
| |
|
def create_synthetic_dataset(
    base_data_dir: str,
    batch_size: int = 128,
    num_batches_per_epoch: int = 1000,
    generator_proportions: dict[str, float] | None = None,
    mixed_batches: bool = True,
    augmentations: dict[str, bool] | None = None,
    augmentation_probabilities: dict[str, float] | None = None,
    global_seed: int = 42,
    nan_stats_path: str | None = None,
    nan_patterns_path: str | None = None,
    chosen_scaler_name: str | None = None,
    rank: int = 0,
    world_size: int = 1,
) -> "ComposedDataset":
    """
    Creates the ComposedDataset for training with saved generator batches.

    Args:
        base_data_dir: Base directory containing generator subdirectories.
        batch_size: Size of each training batch.
        num_batches_per_epoch: Number of batches per epoch.
        generator_proportions: Dict mapping generator names to proportions.
        mixed_batches: Whether to create mixed or uniform batches.
        augmentations: Dict mapping augmentation names to booleans.
        augmentation_probabilities: Dict mapping augmentation names to probabilities.
        global_seed: Global random seed.
        nan_stats_path: Path to nan stats file.
        nan_patterns_path: Path to nan patterns file.
        chosen_scaler_name: Name of the scaler to use.
        rank: Distributed rank of this process.
        world_size: Total number of distributed processes.

    Returns:
        A ComposedDataset instance.
    """
    # device=None: tensors stay on CPU here; the training loop moves them later.
    batch_composer = BatchComposer(
        base_data_dir=base_data_dir,
        generator_proportions=generator_proportions,
        mixed_batches=mixed_batches,
        device=None,
        augmentations=augmentations,
        augmentation_probabilities=augmentation_probabilities,
        global_seed=global_seed,
        nan_stats_path=nan_stats_path,
        nan_patterns_path=nan_patterns_path,
        chosen_scaler_name=chosen_scaler_name,
        rank=rank,
        world_size=world_size,
    )

    composed_dataset = ComposedDataset(
        batch_composer=batch_composer,
        num_batches_per_epoch=num_batches_per_epoch,
        batch_size=batch_size,
    )

    logger.info(
        f"Created ComposedDataset with {len(composed_dataset)} batches per epoch, "
        f"batch_size={batch_size}, mixed_batches={mixed_batches}"
    )

    return composed_dataset
| |
|